# タンパク質のシーケンスおよび構造の探索

In [None]:
!pip install dgl dgllife biopython seaborn transformers
!pip uninstall -y ipywidgets
!pip install ipywidgets

In [None]:
!git clone https://github.com/aws-samples/lm-gvp.git
%cd lm-gvp/
!git reset --hard 3e0cd99bad9134466f6588eea278a7bce9fd60a9
%cd ../

In [None]:
import torch
import dgl

import boto3
import gzip
import numpy as np
import pandas as pd
from io import StringIO
from sklearn.metrics import pairwise_distances
from Bio import pairwise2
from Bio.Seq import Seq
from Bio.PDB.Polypeptide import three_to_one, is_aa
from Bio.PDB import MMCIFParser

In [None]:
# custom utils
import sys
sys.path.append('lm-gvp')
from data.contact_map_utils import gunzip_to_ram
from data.prepare_GO_data import chain_to_coords

In [None]:
import matplotlib.pylab as plt
import seaborn as sns

In [None]:
def read_file_from_s3(bucket, prefix):
    """Open an S3 object and return its body stream.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Key of the object inside the bucket.

    Returns:
        The ``'Body'`` entry of the object's ``get()`` response.
    """
    response = boto3.resource('s3').Object(bucket, prefix).get()
    return response['Body']

## 1. AWS Open Data レジストリを介して PDB からタンパク質構造を取得する

https://registry.opendata.aws/pdb-3d-structural-biology-data/

In [None]:
# Public S3 bucket with PDB snapshots, and the PDB entry to fetch from it
pdb_bucket_name = 'pdbsnapshots'
pdb_id = '6XWU'

In [None]:
# Fetch the gzipped mmCIF file for `pdb_id` from the PDB snapshot bucket and parse it
cif_parser = MMCIFParser()

cif_key = f'20220103/pub/pdb/data/structures/all/mmCIF/{pdb_id.lower()}.cif.gz'
pdb_file = read_file_from_s3(pdb_bucket_name, cif_key)
structure = cif_parser.get_structure(pdb_id, gunzip_to_ram(pdb_file))

In [None]:
# Show the parsed structure together with its Python type
print(f"{structure} {type(structure)}")

In [None]:
def extract_coords(structure, target_atoms=None):
    '''
    Extract the atomic coordinates for all the chains.

    Args:
        structure: Parsed structure; its ``get_chains()`` is iterated and its ``id``
            is used as the record name.
        target_atoms: Atom names whose coordinates are extracted per residue.
            Defaults to ``["N", "CA", "C", "O"]``. A fresh list is created on each
            call, so the default is never shared between calls.

    Returns:
        List of records returned by ``chain_to_coords``, one per distinct chain id.
        Chains for which ``chain_to_coords`` returns ``None`` are dropped, so the
        list may be empty.
    '''
    if target_atoms is None:
        target_atoms = ["N", "CA", "C", "O"]
    records = []
    chain_ids = set()
    for chain in structure.get_chains():
        # Only the first chain seen for each chain id is kept (presumably repeats come
        # from additional models in the same structure).
        if chain.id in chain_ids:
            continue
        chain_ids.add(chain.id)
        record = chain_to_coords(chain,
                                 name=structure.id,
                                 target_atoms=target_atoms)
        if record is not None:
            records.append(record)
    return records

In [None]:
records = extract_coords(structure)
# extract_coords drops chains whose coordinates could not be extracted, so the list may be empty;
# fail with a clear message instead of an IndexError on records[0].
assert records, f'no chain coordinates could be extracted from {structure.id}'
# take the first chain from the structure
protein = records[0]
print(protein['seq'])
print('sequence length:', len(protein['seq']))

In [None]:
# Backbone coordinates of the chain: per residue, the xyz positions of the atoms ["N", "CA", "C", "O"]
coords = np.asarray(protein['coords'])
print(np.shape(coords))

In [None]:
# Visualize the four key backbone atoms (N, CA, C, O) across all amino-acid residues
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')

colors = sns.color_palette()
atoms = ["N", "CA", "C", "O"]
for i, atom in enumerate(atoms):
    ax.scatter(*(coords[:, i, axis] for axis in range(3)), color=colors[i])

# Connect consecutive CA atoms to trace the backbone
coords_CA = coords[:, 1]
N_residues = coords.shape[0]
edge_pos = np.array(list(zip(coords_CA[:-1], coords_CA[1:])))

# Draw one grey segment per pair of consecutive CA atoms
for vizedge in edge_pos:
    ax.plot(*vizedge.T, color="tab:gray")

In [None]:
# Visualize the structure as a contact map: residues are "in contact" when their
# C-alpha atoms are closer than the distance threshold below
dist_thresh = 10.0
dist_mat = pairwise_distances(coords_CA, metric="euclidean")
adj = (dist_mat < dist_thresh).astype(int)
print(adj.shape)
sns.heatmap(adj, cmap='Reds');

In [None]:
# Alternatively, build a k-nearest-neighbour graph (k=30) over the C-alpha coordinates.
# The tensor is built from `coords` so the cell works no matter what type `coords_CA` currently has.
g = dgl.knn_graph(torch.tensor(coords[:, 1]), k=30)
# Newer DGL releases export SciPy matrices via `adj_external`; `adj(scipy_fmt=...)` only
# exists in older releases. Use whichever this DGL version provides.
if hasattr(g, 'adj_external'):
    adj = g.adj_external(scipy_fmt='coo')
else:
    adj = g.adj(scipy_fmt='coo')
print(adj.shape)
sns.heatmap(adj.todense(), cmap='Reds');

## 2. 事前学習済みのタンパク質言語モデルを用いたタンパク質配列のコンテキスト埋め込みの計算

In [None]:
from transformers import BertTokenizer, AlbertModel

In [None]:
# Load the tokenizer and the model from the SAME checkpoint. Token ids from one checkpoint's
# vocabulary are not meaningful to a model trained with another one, so the ProtBert tokenizer
# must not be fed into the ProtAlbert model.
from transformers import BertModel

tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
prot_lm = BertModel.from_pretrained("Rostlab/prot_bert")

In [None]:
import re
def prep_seq(seq):
    """Format an amino-acid sequence for the ProtTrans tokenizer.

    Every residue letter is separated by a single space, and the rare
    amino acids U, Z, O and B are replaced by X.
    ref: https://huggingface.co/Rostlab/prot_bert.
    """
    return re.sub(r"[UZOB]", "X", " ".join(seq))

In [None]:
# Tokenize the space-separated sequence into PyTorch tensors
spaced_seq = prep_seq(protein['seq'])
encodings = tokenizer(spaced_seq, padding=True, return_tensors="pt")
encodings

In [None]:
# Run the language model on the first GPU when one is available, otherwise on the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
prot_lm = prot_lm.to(device)

In [None]:
# Switch the model to inference mode (equivalent to .eval())
prot_lm.train(False)

In [None]:
# Per-residue contextual embeddings from the protein language model.
# This is pure inference, so no_grad() keeps autograd from recording the forward pass (saves memory).
# Slicing 1:-1 skips the [CLS] and [SEP] tokens placed around the sequence.
with torch.no_grad():
    residue_embeddings = prot_lm(
        input_ids=encodings['input_ids'].to(device),
        attention_mask=encodings['attention_mask'].to(device),
    ).last_hidden_state[:, 1:-1, :]
print(residue_embeddings.shape)

## 3. タンパク質の追加特徴量の生成

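- Edge
    + ベクトル特徴量：エッジベクトル（kNN グラフで結ばれた C-alpha 原子間の変位）
    + スカラー特徴量：エッジベクトル長の RBF 埋め込み
- Node
    + ベクトル特徴量：隣接する C-alpha 原子への方向 (orientations)、側鎖の方向 (sidechains)
    + スカラー特徴量：主鎖二面角の cos / sin (dihedrals)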

In [None]:
# Construct a kNN graph (k=30) from the C-alpha coordinates.
# The tensor is rebuilt from `coords` instead of converting `coords_CA` in place. The cell therefore
# gives the same result however often, and after whichever other cells, it is re-run.
coords_CA = torch.tensor(coords[:, 1])
g = dgl.knn_graph(coords_CA, k=30)
edge_index = g.edges()
g

In [None]:
# Edge vector features: for every kNN edge, the displacement between the C-alpha atoms at its
# two endpoints (edges connect spatial neighbours, not only sequence-adjacent residues)
E_vectors = torch.sub(coords_CA[edge_index[0]], coords_CA[edge_index[1]])
E_vectors.size()

In [None]:
import math
import torch.nn.functional as F

def get_rbf(D, D_min=0.0, D_max=20.0, D_count=16, device=None):
    """
    From https://github.com/jingraham/neurips19-graph-protein-design

    Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
    That is, if `D` has shape [...dims], then the returned tensor will have
    shape [...dims, D_count].

    Args:
        D: generic torch tensor
        D_min: Float. Centre of the first RBF.
        D_max: Float. Centre of the last RBF.
        D_count: Positive integer. Number of evenly spaced RBF centres in
            [D_min, D_max]; also the length of the new last dimension.
        device: Device on which to create the RBF centres. Defaults to
            ``D.device``, so the centres always live next to ``D``.

    Return:
        Tensor of shape [...dims, D_count] with
        exp(-((D - mu_k) / sigma) ** 2) for every centre mu_k, where
        sigma = (D_max - D_min) / D_count.
    """
    if device is None:
        device = D.device
    # 1-D centres broadcast against the trailing singleton axis of D_expand,
    # so the output is exactly [...dims, D_count] (also for 0-d D).
    D_mu = torch.linspace(D_min, D_max, D_count, device=device)
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)

    RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))
    return RBF

In [None]:
# Edge scalar features: RBF expansion (16 centres) of the length of every edge vector
edge_lengths = E_vectors.norm(dim=-1)
rbf_E_vectors = get_rbf(edge_lengths, D_count=16)
rbf_E_vectors.size()

In [None]:
def _normalize(tensor, dim=-1):
    """
    Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
    """
    return torch.nan_to_num(
        torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))
    )

def get_dihedrals(X, eps=1e-7):
    """Compute sines and cosines dihedral angles (phi, psi, and omega)

    Args:
        X: torch.Tensor specifying coordinates of key atoms (N, CA, C, O) in 3D space with shape [seq_len, 4, 3]
        eps: Float defining the epsilon using to clamp the angle between normals: min= -1*eps, max=1-eps

    Returns:
        Sines and cosines dihedral angles as a torch.Tensor of shape [seq_len, 6]
    """
    # From https://github.com/jingraham/neurips19-graph-protein-design
    X = torch.reshape(X[:, :3], [3 * X.shape[0], 3])
    dX = X[1:] - X[:-1]
    U = _normalize(dX, dim=-1)
    u_2 = U[:-2]
    u_1 = U[1:-1]
    u_0 = U[2:]

    # Backbone normals
    n_2 = _normalize(torch.cross(u_2, u_1), dim=-1)
    n_1 = _normalize(torch.cross(u_1, u_0), dim=-1)

    # Angle between normals
    cosD = torch.sum(n_2 * n_1, -1)
    cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
    D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)

    # This scheme will remove phi[0], psi[-1], omega[-1]
    D = F.pad(D, [1, 2])
    D = torch.reshape(D, [-1, 3])
    # Lift angle representations to the circle
    D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
    return D_features

In [None]:
# Node scalar features: cos/sin of the backbone dihedral angles of each residue
backbone = torch.tensor(coords)
dihedrals = get_dihedrals(backbone)
dihedrals.size()

In [None]:
def get_orientations(X):
    """Unit vectors from each atom towards its neighbours in the sequence.

    Args:
        X: torch.Tensor representing atom coordinates with shape [n_atoms, 3]

    Returns:
        torch.Tensor of shape [n_atoms, 2, 3]. Slot 0 on axis -2 points to the next
        atom (zero for the last atom); slot 1 points to the previous atom (zero for the first).
    """
    to_next = F.pad(_normalize(X[1:] - X[:-1]), [0, 0, 0, 1])
    to_prev = F.pad(_normalize(X[:-1] - X[1:]), [0, 0, 1, 0])
    return torch.stack([to_next, to_prev], dim=-2)

def get_sidechains(X):
    """Compute the unit vector representing the imputed side chain directions (C_beta - C_alpha).

    Args:
        X: torch.Tensor specifying coordinates of key atoms (N, CA, C, O) in 3D space with shape [seq_len, 4, 3]

    Returns:
        Torch tensor representing side chain directions with shape [seq_len, 3]
    """
    n, origin, c = X[:, 0], X[:, 1], X[:, 2]
    c, n = _normalize(c - origin), _normalize(n - origin)
    bisector = _normalize(c + n)
    # Cross product over the xyz axis explicitly: without `dim`, torch.cross uses the first
    # axis of size 3, which is the residue axis when seq_len == 3.
    perp = _normalize(torch.cross(c, n, dim=-1))
    # bisector lies in the N-CA-C plane and perp is normal to it, so these weights
    # (1/3 + 2/3 = 1) combine them into a unit vector.
    vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
    return vec

In [None]:
# Node vector features: orientations towards neighbouring C-alpha atoms, and side-chain directions.
# The C-alpha tensor is built from `coords` directly, so the result does not depend on whether
# `coords_CA` currently holds a NumPy array or a torch.Tensor.
orientations = get_orientations(torch.tensor(coords[:, 1]))
print(orientations.shape)
sidechains = get_sidechains(torch.tensor(coords))
print(sidechains.shape)

In [None]:
# Shape of the C-alpha coordinates (one xyz row per residue)
np.shape(coords_CA)

In [None]:
# Plot the side chain vector feature as a 3D vector field anchored on the C-alpha atoms
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')

atoms = ["N", "CA", "C", "O"]
i = 1  # index of "CA" in `atoms`
ax.scatter(coords[:, i, 0],
           coords[:, i, 1],
           coords[:, i, 2],
           color=colors[i]
          );

# Connect CA atoms sequentially. The NumPy C-alpha coordinates get their own name instead of
# rebinding `coords_CA`, which other cells use as a torch.Tensor and would break on re-run.
ca_xyz = coords[:, i]
N_residues = coords.shape[0]
edge_pos = np.array([(ca_xyz[u], ca_xyz[u + 1]) for u in range(N_residues - 1)])

# Plot the edges
for vizedge in edge_pos:
    ax.plot(*vizedge.T, color="tab:gray")

# Plot the direction of the side chain for each node (amino acid residue)
ax.quiver(ca_xyz[:, 0],
          ca_xyz[:, 1],
          ca_xyz[:, 2],
          sidechains[:, 0],
          sidechains[:, 1],
          sidechains[:, 2],
          length=2,
          normalize=True
         );

## 4. タンパク質機能 (Function) データセットを調べる

トレーニング データセットは、LM-GVP 論文で使用されている元のデータのうち10%をランダムにサンプリングしたサブセットです。

In [None]:
import json

In [None]:
%%bash
# Download the protein function prediction data (DeepFRI GO splits and GO annotation table).
# -nc (no-clobber) skips files that already exist. Re-running the cell therefore does not
# leave duplicate copies such as proteins_train.json.1 next to the originals.
mkdir -p protein_data/DeepFRI_GO

files=(
    proteins_train.json
    proteins_valid.json
    proteins_test.json
    nrPDB-GO_2019.06.18_annot.tsv
)
for file in "${files[@]}"; do
    echo "Downloading " $file
    wget -nc "https://d2125kp0qwrvcx.cloudfront.net/DeepFRI_GO_data/$file" \
        -P protein_data/DeepFRI_GO
done

In [None]:
# Load the (10% subsampled) training split. The `with` block closes the file handle after reading.
with open('protein_data/DeepFRI_GO/proteins_train.json', 'r') as train_fh:
    train_data = json.load(train_fh)

In [None]:
N_train = len(train_data)
print('Number of proteins in training set:', N_train)

# Examine one protein record. It gets its own name, so `protein` (the chain parsed from PDB
# in section 1 and tokenized in section 2) is not silently overwritten if those cells are re-run.
example_protein = train_data[0]
print(example_protein.keys())
print('name:', example_protein['name'])
print('seq:', example_protein['seq'])
print('coords.shape:', np.asarray(example_protein['coords']).shape)

In [None]:
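# Recipe used to draw the 10% training subset (kept for reference; not executed).
# NOTE: np.random.choice samples WITH replacement by default, so this can pick the same protein
# more than once. Pass replace=False and fix a seed to reproduce a duplicate-free subset.
# sample_idx = np.random.choice(N_train, int(0.1*N_train))
# train_data_sample = [train_data[idx] for idx in sample_idx]
# len(train_data_sample)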

In [None]:
# json.dump(train_data_sample, open('proteins_train.json', 'w'))

In [None]:
from lmgvp.deepfrier_utils import load_GO_annot

In [None]:
# !aws s3 cp s3://gnn-in-lifesci-kdd2022/DeepFRI_GO_data/nrPDB-GO_2019.06.18_annot.tsv .

In [None]:
# Parse the DeepFRI GO annotation table. goterms / gonames / counts are indexed by
# annotation category below (e.g. 'cc'); prot2annot presumably maps proteins to their annotations.
go_annot_path = 'protein_data/DeepFRI_GO/nrPDB-GO_2019.06.18_annot.tsv'
prot2annot, goterms, gonames, counts = load_GO_annot(go_annot_path)

In [None]:
# Number of GO terms per annotation category
for key in goterms:
    print(key, len(goterms[key]))

In [None]:
# Length of the per-term protein-count list in each annotation category
for key, term_counts in counts.items():
    print(key, len(term_counts))

In [None]:
# Table of 'cc' GO terms (presumably the cellular-component ontology) with their names and
# number of annotated proteins; the most frequent terms are shown first
cc_columns = {
    'term': goterms['cc'],
    'name': gonames['cc'],
    'protein_counts': counts['cc'],
}
go_cc_meta = pd.DataFrame(cc_columns).set_index('term')
go_cc_meta.sort_values('protein_counts', ascending=False).head(5)