In [1]:
#!conda install -c conda-forge pytorch-lightning 

In [2]:
# autoreload 
%load_ext autoreload
%autoreload 2

seed = 1

from pathlib import Path
from functools import partial

from typing import List, Dict, Tuple, Union, Optional, Callable

import os
import pandas as pd 
import numpy as np 
np.random.seed(seed)

import h5py
import torch
import torch_geometric

from tqdm import tqdm

# Set random seed for torch 
torch.manual_seed(seed)


<torch._C.Generator at 0x7f4e58e60d70>

In [3]:
# Manually add T5 embeddings 
from graphein.utils.utils import annotate_node_features



#### Function definitions.

In [4]:
def get_pyg_from_uniprot(
    uniprot_id: str,
    pdb_dir: Optional[Union[str, Path]] = None,
    node_distance_threshold: float = 6.0,
    long_interaction_threshold: int = 5,
) -> torch_geometric.data.Data:
    """
    Creates a PyG Data object from a uniprot ID.

    The steps mirror the cell-by-cell walkthrough in this notebook: build a
    CA-granularity Graphein graph from the AlphaFold PDB file with
    distance-threshold edges, attach the per-residue embedding as node
    attribute ``"x"``, and convert the result to PyG.

    Parameters
    ----------
    uniprot_id : str
        UniProt accession; used for the PDB file name and the embedding lookup.
    pdb_dir : str or Path, optional
        Directory containing ``AF-{uniprot_id}-F1-model_v4.pdb``. Defaults to
        the relative directory ``pdb_structures``.
    node_distance_threshold : float
        Forwarded to ``add_distance_threshold`` as ``threshold`` (in Å).
    long_interaction_threshold : int
        Forwarded to ``add_distance_threshold`` as
        ``long_interaction_threshold``.

    Returns
    -------
    torch_geometric.data.Data
        Graph carrying the ``b_factor``, ``edge_index`` and ``x`` columns.

    Raises
    ------
    FileNotFoundError
        If the expected PDB file is not present in ``pdb_dir``.
    """
    pdb_dir = Path("pdb_structures") if pdb_dir is None else Path(pdb_dir)
    pdb_path = pdb_dir / f"AF-{uniprot_id}-F1-model_v4.pdb"
    if not pdb_path.is_file():
        raise FileNotFoundError(f"No AlphaFold structure found at {pdb_path}")

    # Imported locally so the function does not depend on which notebook cells
    # have already been executed.
    from graphein.ml.conversion import GraphFormatConvertor
    from graphein.protein.config import ProteinGraphConfig
    from graphein.protein.edges.distance import add_distance_threshold
    from graphein.protein.graphs import construct_graph
    from phosphosite.graphs.features import add_residue_embedding
    from phosphosite.protein.embeddings import get_embedding

    # Edges come solely from the thresholded-distance rule.
    config = ProteinGraphConfig(
        pdb_dir=pdb_dir,
        granularity="CA",
        edge_construction_functions=[
            partial(
                add_distance_threshold,
                long_interaction_threshold=long_interaction_threshold,
                threshold=node_distance_threshold,
            )
        ],
    )
    graph = construct_graph(config=config, path=pdb_path, verbose=False)

    # Store the per-residue embedding under the node attribute "x".
    graph = add_residue_embedding(graph, get_embedding(uniprot_id), label="x")

    convertor = GraphFormatConvertor(
        src_format="nx",
        dst_format="pyg",
        verbose="gnn",
        columns=["b_factor", "edge_index", "x"],
    )
    return convertor(graph)
    

#### Generate a phosphosite graph with T5 embeddings.

In [5]:



# Graph-construction entry points from Graphein.
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph

# Instantiate the default configuration and display its fields as a dict.
config = ProteinGraphConfig()
config.dict()

{'granularity': 'CA',
 'keep_hets': [],
 'insertions': True,
 'alt_locs': 'max_occupancy',
 'pdb_dir': None,
 'verbose': False,
 'exclude_waters': True,
 'deprotonate': False,
 'protein_df_processing_functions': None,
 'edge_construction_functions': [<function graphein.protein.edges.distance.add_peptide_bonds(G: 'nx.Graph') -> 'nx.Graph'>],
 'node_metadata_functions': [<function graphein.protein.features.nodes.amino_acid.meiler_embedding(n: str, d: Dict[str, Any], return_array: bool = False) -> Union[pandas.core.series.Series, numpy.ndarray]>],
 'edge_metadata_functions': None,
 'graph_metadata_functions': None,
 'get_contacts_config': None,
 'dssp_config': None}

#### Load in PSP dataset.

In [6]:
# Location of the PSP phosphorylation-site table under the user's home directory.
psp_path = Path.home().joinpath(
    "STRUCTURAL_MOTIFS", "DATA", "PSP", "Phosphorylation_site_dataset"
)
# Stop here if the dataset file is not present.
assert psp_path.is_file()

In [7]:
# Load the PSP table: tab-separated, with the first three lines of the file
# skipped before the header row is parsed.
df = pd.read_csv(
    psp_path,
    sep="\t",
    skiprows=3,
)

# Keep human entries and rename the two columns used downstream. `.rename`
# returns a new frame, which avoids mutating a filtered slice in place.
df = df[df.ORGANISM == "human"].rename(
    columns={"ACC_ID": "uniprot_id", "MOD_RSD": "mod_rsd"}
)
psp = df[["uniprot_id", "mod_rsd"]]

# Keep only phosphorylated S/T/Y sites such as "S21-p". The pattern is anchored
# at both ends so the residue letter must be the first character and "-p" the
# suffix; an unanchored "S|T|Y" or "-p" test would match those substrings
# anywhere in the value. `na=False` drops rows whose mod_rsd is missing.
psp = psp[psp.mod_rsd.str.match(r"^[STY]\d+-p$", na=False)]
psp

Unnamed: 0,uniprot_id,mod_rsd
3,P31946,T2-p
5,P31946,S6-p
7,P31946,Y21-p
9,P31946,T32-p
10,P31946,S39-p
...,...,...
378765,Q8IYH5,S474-p
378766,Q8IYH5,S606-p
378768,Q8IYH5,Y670-p
378769,Q8IYH5,S677-p


In [8]:
# Test with EF1A = P68104
uid = "P68104"
# Show the first two phosphosite rows recorded for this protein.
psp.loc[psp["uniprot_id"] == uid].head(2)


Unnamed: 0,uniprot_id,mod_rsd
93065,P68104,S21-p
93066,P68104,T22-p


#### Get embedding.

In [9]:
from phosphosite.protein.embeddings import get_embedding 
# Fetch the embedding for the test protein and display its shape.
emb = get_embedding(uid)
emb.shape

(462, 1024)

In [10]:
from phosphosite.uniprot import sequence_dict
# Look up the sequence for the same protein; its length can be compared with
# the first dimension of the embedding shape shown in the previous cell.
seq = sequence_dict[uid]
len(seq)

462

In [11]:
#g.nodes(data=True)['A:MET:1']

NameError: name 'g' is not defined

In [12]:
# UniProt accession used by the graph-construction cells (same value as `uid`).
uniprot_id = "P68104"

In [13]:
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.edges.distance import add_distance_threshold
from phosphosite.graphs.features import add_residue_embedding
from phosphosite.protein.embeddings import get_embedding

# Directory (relative to the working directory) holding the PDB files.
pdb_dir = Path("pdb_structures")

NODE_DISTANCE_THRESHOLD = 6.0  # Å
# Forwarded as `long_interaction_threshold`; presumably a sequence-separation
# cutoff for edge creation. TODO confirm the exact semantics in Graphein.
LONG_INTERACTION_THRESHOLD = 5

# Replace the default edge functions with a single thresholded-distance rule.
new_edge_funcs = {
    "edge_construction_functions": [
        partial(
            add_distance_threshold,
            threshold=NODE_DISTANCE_THRESHOLD,
            long_interaction_threshold=LONG_INTERACTION_THRESHOLD,
        )
    ]
}

# Node metadata functions are left at their defaults.
config = ProteinGraphConfig(pdb_dir=pdb_dir, granularity="CA", **new_edge_funcs)

# File-name template for the structure file of a given accession.
af_format = "AF-{uniprot_id}-F1-model_v4.pdb"
pdb_path = pdb_dir.joinpath(af_format.format(uniprot_id=uniprot_id))

# Construct the graph from the PDB file.
g = construct_graph(config=config, path=pdb_path, verbose=False)

# Attach the embedding to the graph's nodes under the attribute "x".
emb = get_embedding(uniprot_id)
g = add_residue_embedding(g, emb, label="x")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


In [14]:
# Attribute dict of the first residue's node.
g.nodes["A:MET:1"]

{'chain_id': 'A',
 'residue_name': 'MET',
 'residue_number': 1,
 'atom_type': 'CA',
 'element_symbol': 'C',
 'coords': array([17.446,  9.974, 26.153]),
 'b_factor': 35.54,
 'meiler': dim_1    2.35
 dim_2    0.22
 dim_3    4.43
 dim_4    1.23
 dim_5    5.71
 dim_6    0.38
 dim_7    0.32
 Name: MET, dtype: float64,
 'x': array([ 0.05770874, -0.07293701, -0.09008789, ...,  0.10064697,
         0.40844727,  0.12176514], dtype=float32)}

In [15]:
""" CONVERT TO PYG """

from graphein.ml.conversion import GraphFormatConvertor

# Attributes carried over to the PyG object ("coords" is deliberately left out;
# "x" is the T5 per-residue embedding).
columns = ["b_factor", "edge_index", "x"]

convertor = GraphFormatConvertor(
    src_format="nx",
    dst_format="pyg",
    verbose="gnn",
    columns=columns,
)
pyg = convertor(g)

# The conversion must yield exactly a PyG Data object.
assert type(pyg) is torch_geometric.data.Data

In [17]:
from phosphosite import PHOSPHOSITE_PREDICT_DIR
from graphein.protein.utils import download_alphafold_structure
#protein_path = download_alphafold_structure("Q8W3K0", out_dir=PHOSPHOSITE_PREDICT_DIR / "pdb_structures", aligned_score=False)

# TODO: fix this error in graphein!

In [18]:
# Distinct accessions present in the filtered table.
psp["uniprot_id"].unique()

array(['P31946', 'P62258', 'Q04917', ..., 'Q15942', 'O43149', 'Q8IYH5'],
      dtype=object)

In [42]:
uniprot_ids = ["P68104", "Q04917", "P31946", "P62258"]

# The output directory is identical for every download, so build it once.
structure_dir = str(PHOSPHOSITE_PREDICT_DIR / "pdb_structures")

# Use a dedicated loop variable so the notebook-level `uniprot_id` keeps its
# value, and keep every returned path rather than only the last one.
protein_paths = {}
for accession in uniprot_ids:
    protein_path = download_alphafold_structure(
        accession, out_dir=structure_dir, aligned_score=False
    )
    protein_paths[accession] = protein_path

In [43]:
# Position of the node with id "A:ILE:10" in the node list.
i = pyg.node_id.index("A:ILE:10")
# Columns of edge_index whose first row (the source) equals that position.
pyg.edge_index[:, pyg.edge_index[0] == i]

tensor([[  9,   9,   9,   9,   9,   9],
        [ 88,  89, 108, 109, 110, 111]])

In [44]:
# For every attribute of pyg, print its name and type, plus the dtype for tensors.
for k, v in pyg:
    fields = [str(k), str(type(v))]
    if isinstance(v, torch.Tensor):
        fields.append(str(v.dtype))
    # Each line ends with a trailing space before the newline.
    print(" ".join(fields) + " ")

x <class 'torch.Tensor'> torch.float32 
edge_index <class 'torch.Tensor'> torch.int64 
node_id <class 'list'> 
b_factor <class 'torch.Tensor'> torch.float64 
num_nodes <class 'int'> 


In [45]:
# Shape of the feature row for the node at position i.
pyg.x[i].shape

torch.Size([1024])

In [46]:
# Add y labels 

In [47]:
# Validate the PyG object; raise_on_error=True asks for an exception instead of
# a silent failure.
pyg.validate(raise_on_error=True)

True

In [48]:
# Number of features per node.
pyg.num_node_features

1024

### Loss function masking

In [49]:
from phosphosite.ml import MaskedBinaryCrossEntropy, MaskedMSELoss

# Toy example: three predictions and their targets.
# NOTE: `input` shadows the Python builtin of the same name.
input = torch.tensor([0.8, 0.5, 0.2], dtype=torch.float32)
target = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32)

# Four masks selecting different subsets of the three positions.
mask0 = torch.tensor([1, 1, 1])
mask1 = torch.tensor([0, 1, 1])
mask2 = torch.tensor([0, 0, 1])
mask3 = torch.tensor([1, 0, 1])
toy_masks = (mask0, mask1, mask2, mask3)

# Print one list of losses per loss function, one entry per mask.
for loss_func in (MaskedMSELoss(), MaskedBinaryCrossEntropy()):
    print([loss_func(input, target, m) for m in toy_masks])



[tensor(0.1100), tensor(0.1450), tensor(0.0400), tensor(0.0400)]
[tensor(0.3798), tensor(0.4581), tensor(0.2231), tensor(0.2231)]


### Train 

In [50]:
# pytorch lightning from graphein 

### Validate that model architecture works with graph input.

In [51]:
# Display a summary of the single-protein graph.
pyg

Data(x=[462, 1024], edge_index=[2, 438], node_id=[462], b_factor=[462], num_nodes=462)

In [52]:
# Wrap the single graph in a batch of size one.
from torch_geometric.data import Data, Batch
batch = Batch.from_data_list([pyg])
batch

DataBatch(x=[462, 1024], edge_index=[2, 438], node_id=[1], b_factor=[462], num_nodes=462, batch=[462], ptr=[2])

In [53]:
def _half_to_float32(value):
    """Return float16 tensors upcast to float32; return anything else unchanged."""
    if value.dtype == torch.float16:
        return value.to(torch.float32)
    return value


# Only float16 values are converted; other dtypes (e.g. float64) are kept.
pyg = pyg.apply(_half_to_float32)
pyg.x.dtype

torch.float32

### Run model forward pass on batch.

In [54]:
# Directory the AlphaFold structures were downloaded to. The following cell
# reads `pdb_dir`, so the value has to be bound to that name.
pdb_dir = PHOSPHOSITE_PREDICT_DIR / "pdb_structures"
# Alias kept so that any code referring to the misspelled name still works.
pbd_dir = pdb_dir

from torch_geometric.data import Data, Batch
from phosphosite.model import PhosphoGAT
from phosphosite.graphs.pyg import get_pyg_graph

In [55]:
# Accessions that will be turned into graphs.
uniprot_ids

['P68104', 'Q04917', 'P31946', 'P62258']

In [56]:
# Bind the structure directory once, then build one graph per accession.
func = partial(get_pyg_graph, pdb_dir=pdb_dir)
pyg_list = list(map(func, uniprot_ids))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


In [57]:
# Inspect the first element of the list.
# NOTE(review): `g` held the NetworkX graph earlier in the notebook; it is
# rebound here to an element of pyg_list.
g = pyg_list[0]
len(pyg_list), g.num_node_features, g.num_nodes

(4, 1024, 462)

In [58]:
# Combine all graphs in pyg_list into one batch (rebinds `batch`).
batch = Batch.from_data_list(pyg_list)
batch

DataBatch(x=[1209, 1024], edge_index=[2, 611], node_id=[4], b_factor=[1209], num_nodes=1209, batch=[1209], ptr=[5])

In [85]:

# Instantiate the model and switch it to evaluation mode.
model = PhosphoGAT()
model.eval()

# Forward pass on the single graph first, then on the batch.
y_hat1, y_hat = (model(graph_input) for graph_input in (pyg, batch))

# Compare the two output shapes.
y_hat1.shape, y_hat.shape

(torch.Size([462, 1]), torch.Size([1209, 1]))

In [86]:
# Model outputs for the batch.
y_hat

tensor([[0.4600],
        [0.4729],
        [0.4683],
        ...,
        [0.4710],
        [0.5040],
        [0.4971]], grad_fn=<SigmoidBackward0>)

In [87]:
# Threshold the outputs at 0.5: values strictly above become 1, all others 0.
above_half = y_hat > 0.5
y_pred = above_half.to(torch.int32)
y_pred

tensor([[0],
        [0],
        [0],
        ...,
        [0],
        [1],
        [0]], dtype=torch.int32)

In [62]:
# Shape of the thresholded predictions.
y_pred.shape

torch.Size([1209, 1])

In [84]:
from phosphosite.ml import MaskedBinaryCrossEntropy

loss_func = MaskedBinaryCrossEntropy()

# Dense target and mask, both all-zero and shaped like the model output.
y, mask = torch.zeros_like(y_hat), torch.zeros_like(y_hat)

# Mark four positions in the mask: the first row and the last three rows.
n = y.size(0)
indexes = torch.tensor([0, n - 3, n - 2, n - 1])
mask[indexes] = 1

# Number of positions selected by the mask.
mask.sum()

tensor(4.)

In [88]:
# y_labels is the sparse label vector: one label per entry of `indexes`.
y_labels = torch.tensor([1, 1, 1, 1])

# Write all labels into the dense target with one indexed assignment. This
# needs no loop counter, so the notebook-level `i` (a node position) is left
# untouched. The labels are cast to y's dtype and shaped as a column to match
# the rows selected by `indexes`.
y[indexes] = y_labels.to(y.dtype).reshape(-1, 1)

y

tensor([[1.],
        [0.],
        [0.],
        ...,
        [1.],
        [1.],
        [1.]])

In [89]:
# Model outputs, shown next to the targets for comparison.
y_hat

tensor([[0.4600],
        [0.4729],
        [0.4683],
        ...,
        [0.4710],
        [0.5040],
        [0.4971]], grad_fn=<SigmoidBackward0>)

In [73]:
# Mask selecting the labelled positions.
mask

tensor([[1.],
        [0.],
        [0.],
        ...,
        [1.],
        [1.],
        [1.]])

In [76]:
# Masked loss of the model outputs against the targets.
loss = loss_func(y_hat, y, mask)
loss

tensor(0.6741, grad_fn=<DivBackward0>)

Make `y_hat` exactly correct at the first masked position (index 0)...

In [90]:
# Work on a copy so that y_hat itself is not modified by this cell.
y_hat_better = y_hat.clone()
# Set the first prediction to 1, which equals its label.
y_hat_better[0, 0] = 1
loss = loss_func(y_hat_better, y, mask)
loss

tensor(0.5343, grad_fn=<DivBackward0>)

Change `mask` to only consider first two...

In [81]:
# Build a second mask that keeps only the first two labelled positions.
# This cell only constructs the mask; it does not modify `y_hat`.
# NOTE(review): `mask2` is also the name of a 1-D mask in the toy loss example;
# it is rebound here.
mask2 = torch.zeros_like(mask)
mask2[indexes[0:2]] = 1
mask2.sum(), mask2


(tensor(2.),
 tensor([[1.],
         [0.],
         [0.],
         ...,
         [1.],
         [0.],
         [0.]]))

In [91]:
# Loss of the improved predictions under the two-position mask.
loss = loss_func(y_hat_better, y, mask2)
loss

tensor(0.3764, grad_fn=<DivBackward0>)

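The loss here is even lower when the mask covers *fewer* values: the masked loss is averaged over the selected positions only, and one of the two remaining positions is predicted perfectly.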

In [93]:
# Overwrite the first prediction with 1 (its label), then recompute the loss
# under the full mask.
# NOTE: this assigns into `y_hat` itself (no clone), so every later cell sees
# the modified value.
y_hat[0, 0] = 1
loss = loss_func(y_hat, y, mask)
loss

tensor(0.5343, grad_fn=<DivBackward0>)

In [94]:
# Copy the predictions and set the four labelled positions to 1.0, written as a
# (4, 1) float32 column so that it lines up with the selected rows.
perfect_y_hat = y_hat.clone()
perfect_y_hat[indexes] = torch.ones((4, 1), dtype=torch.float32)

# NOTE(review): new_mask is created here, but the loss below is computed with `mask`.
new_mask = torch.zeros_like(mask)
loss = loss_func(perfect_y_hat, y, mask)
loss

tensor(0., grad_fn=<DivBackward0>)

As expected, when all values are correct, we get 0 loss.  

What happens when the predictions are not all correct at the labelled `y` positions, but the mask covers only a correctly predicted one?

In [95]:
# Current predictions (the first entry was overwritten earlier).
y_hat

tensor([[1.0000],
        [0.4729],
        [0.4683],
        ...,
        [0.4710],
        [0.5040],
        [0.4971]], grad_fn=<CopySlices>)

In [96]:
# Two masks, each selecting exactly one labelled position:
# m1 selects the first entry of `indexes`, m2 the second.
m1, m2 = torch.zeros_like(mask), torch.zeros_like(mask)
m1[indexes[0]] = 1
m2[indexes[1]] = 1

# Same predictions and targets, evaluated under each mask.
loss_func(y_hat, y, m1), loss_func(y_hat, y, m2)

(tensor(0., grad_fn=<DivBackward0>), tensor(0.7528, grad_fn=<DivBackward0>))

As expected, for the same `y_hat`, when the mask covers a 100% accurate prediction, you get 0 loss.

In [97]:
# Same check after slicing predictions, targets and mask down to the labelled positions.
loss = loss_func(perfect_y_hat[indexes], y[indexes], mask[indexes])
loss

tensor(0., grad_fn=<DivBackward0>)

### Testing Binary cross entropy loss

In [383]:
# Predictions at the labelled positions.
y_hat[indexes]

tensor([[0.5156],
        [0.5218],
        [0.5211],
        [0.5292]], grad_fn=<IndexBackward0>)

In [384]:
# Targets at the labelled positions.
y[indexes]

tensor([[1.],
        [1.],
        [1.],
        [1.]])

In [387]:

# Cross-check with torch's built-in binary cross-entropy, applied to the
# labelled positions only (no mask argument is involved).
from torch.nn import BCELoss
bce = BCELoss()
loss = bce(perfect_y_hat[indexes], y[indexes])
loss

tensor(0., grad_fn=<BinaryCrossEntropyBackward0>)