In [1]:
#!conda install -c conda-forge pytorch-lightning 

TODO
- decouple the `y` labels and the graphs from each other, so we can change / add / remove `y_labels` dynamically without affecting the fundamental dataset, i.e. the directory containing the pre-processed graphs.
- the dataset can thus load in all the graphs and assign the `y` to the `torch_geometric` `DataBatch` object.



Other TODOs
- implement dataset that loads from disk; as I don't think all graphs will be able to fit into memory. 

In [2]:
# autoreload 
%load_ext autoreload
%autoreload 2

seed = 1

from pathlib import Path
from functools import partial

from typing import List, Dict, Tuple, Union, Optional, Callable

import os
import pandas as pd 
import numpy as np 
np.random.seed(seed)

import h5py
import torch
import torch_geometric

from tqdm import tqdm

# Set random seed for torch 
torch.manual_seed(seed)


<torch._C.Generator at 0x7f4e58e60d70>

In [3]:
# Manually add T5 embeddings 
from graphein.utils.utils import annotate_node_features



#### Function definitions.

In [4]:
def get_pyg_from_uniprot(
    uniprot_id: str,
) -> torch_geometric.data.Data:
    """
    Creates a PyG Data object from a uniprot ID. 

    NOTE: this is an unimplemented stub. The body consists of this docstring
    only, so a call currently returns ``None`` rather than the annotated
    ``torch_geometric.data.Data``.

    Parameters
    ----------
    uniprot_id : str
        UniProt accession of the protein to build the graph for.
    """
    

#### Generate a phosphosite graph with T5 embeddings.

In [5]:



# construct graph 
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph

config = ProteinGraphConfig()
config.dict()

{'granularity': 'CA',
 'keep_hets': [],
 'insertions': True,
 'alt_locs': 'max_occupancy',
 'pdb_dir': None,
 'verbose': False,
 'exclude_waters': True,
 'deprotonate': False,
 'protein_df_processing_functions': None,
 'edge_construction_functions': [<function graphein.protein.edges.distance.add_peptide_bonds(G: 'nx.Graph') -> 'nx.Graph'>],
 'node_metadata_functions': [<function graphein.protein.features.nodes.amino_acid.meiler_embedding(n: str, d: Dict[str, Any], return_array: bool = False) -> Union[pandas.core.series.Series, numpy.ndarray]>],
 'edge_metadata_functions': None,
 'graph_metadata_functions': None,
 'get_contacts_config': None,
 'dssp_config': None}

#### Load in PSP dataset.

In [6]:
psp_path = Path.home() / "STRUCTURAL_MOTIFS" / "DATA" / "PSP" / "Phosphorylation_site_dataset"
assert psp_path.is_file()

In [186]:
# Read the raw PhosphoSitePlus phosphorylation-site table (tab-separated,
# skipping the first 3 lines of the file) and keep human entries only.
df = pd.read_csv(
    psp_path,
    sep="\t",
    skiprows=3,
)
# `rename` without `inplace=True` returns a fresh frame, so nothing is mutated
# through a filtered view of the raw table.
df = df[df.ORGANISM == "human"].rename(
    columns={"ACC_ID": "uniprot_id", "MOD_RSD": "mod_rsd"}
)

# Explicit copy: `psp` is transformed below and must not alias `df`.
psp = df[["uniprot_id", "mod_rsd"]].copy()

# only containing "-p" (matched literally, not as a regex)
psp = psp[psp.mod_rsd.str.contains("-p", regex=False)]

# only S, T, Y sites: the residue letter is the first character of `mod_rsd`
# (e.g. "S21-p"), so the pattern is anchored to the start of the string
psp = psp[psp.mod_rsd.str.contains(r"^[STY]")]

# remove isoforms: keep only the part of the accession before the first "-".
# `assign` builds a new frame instead of writing into a filtered slice.
psp = psp.assign(uniprot_id=psp.uniprot_id.str.split("-").str[0])

# remove cases like Q9Y6M4_VAR_006
# i.e. filter out anything containing underscore
psp = psp[~psp.uniprot_id.str.contains("_", regex=False)]
psp

Unnamed: 0,uniprot_id,mod_rsd
3,P31946,T2-p
5,P31946,S6-p
7,P31946,Y21-p
9,P31946,T32-p
10,P31946,S39-p
...,...,...
378765,Q8IYH5,S474-p
378766,Q8IYH5,S606-p
378768,Q8IYH5,Y670-p
378769,Q8IYH5,S677-p


In [8]:
# Test with EF1A = P68104
uid = "P68104"
psp[psp.uniprot_id == uid][0:2]


Unnamed: 0,uniprot_id,mod_rsd
93065,P68104,S21-p
93066,P68104,T22-p


#### Get embedding.

In [9]:
from phosphosite.protein.embeddings import get_embedding 
emb = get_embedding(uid)
emb.shape

(462, 1024)

In [10]:
from phosphosite.uniprot import sequence_dict
seq = sequence_dict[uid]
len(seq)

462

In [11]:
#g.nodes(data=True)['A:MET:1']

NameError: name 'g' is not defined

In [12]:
uniprot_id = "P68104"

In [13]:
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.edges.distance import add_distance_threshold
from phosphosite.protein.embeddings import get_embedding 
from phosphosite.graphs.features import add_residue_embedding

# Structure directory, relative to the notebook's working directory.
pdb_dir = Path("pdb_structures")

NODE_DISTANCE_THRESHOLD = 6.0  # Å
# Forwarded to graphein as `long_interaction_threshold`; relates to how many
# sequence positions apart two residues are when an edge is considered.
LONG_INTERACTION_THRESHOLD = 5

# Replace the default edge construction with distance-thresholded edges.
new_edge_funcs = dict(
    edge_construction_functions=[
        partial(
            add_distance_threshold,
            threshold=NODE_DISTANCE_THRESHOLD,
            long_interaction_threshold=LONG_INTERACTION_THRESHOLD,
        ),
    ],
)

# Node features are left at the config defaults; only granularity, the PDB
# directory and the edge functions are overridden.
config = ProteinGraphConfig(
    granularity="CA",
    pdb_dir=pdb_dir,
    **new_edge_funcs,
)

af_format = "AF-{uniprot_id}-F1-model_v4.pdb"
pdb_path = pdb_dir / af_format.format(uniprot_id=uniprot_id)


""" CONSTRUCT GRAPH """
g = construct_graph(config=config, path=pdb_path, verbose=False)

# Attach the per-residue embedding to the graph nodes under the key "x".
emb = get_embedding(uniprot_id)
g = add_residue_embedding(g, emb, label="x")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


In [14]:
g.nodes(data=True)['A:MET:1']

{'chain_id': 'A',
 'residue_name': 'MET',
 'residue_number': 1,
 'atom_type': 'CA',
 'element_symbol': 'C',
 'coords': array([17.446,  9.974, 26.153]),
 'b_factor': 35.54,
 'meiler': dim_1    2.35
 dim_2    0.22
 dim_3    4.43
 dim_4    1.23
 dim_5    5.71
 dim_6    0.38
 dim_7    0.32
 Name: MET, dtype: float64,
 'x': array([ 0.05770874, -0.07293701, -0.09008789, ...,  0.10064697,
         0.40844727,  0.12176514], dtype=float32)}

In [15]:
""" CONVERT TO PYG """

from graphein.ml.conversion import GraphFormatConvertor

# Graph attributes carried over from the networkx graph to the PyG object.
# "coords" is intentionally left out for now.
columns = ["b_factor", "edge_index"]
columns.append("x")  # T5 per-residue embedding

convertor_kwargs = {
    "src_format": "nx",
    "dst_format": "pyg",
    "verbose": "gnn",
    "columns": columns,
}
convertor = GraphFormatConvertor(**convertor_kwargs)
pyg = convertor(g)

# Sanity check: the conversion must yield exactly a PyG `Data` object.
assert type(pyg) is torch_geometric.data.Data

In [17]:
from phosphosite import PHOSPHOSITE_PREDICT_DIR
from graphein.protein.utils import download_alphafold_structure
#protein_path = download_alphafold_structure("Q8W3K0", out_dir=PHOSPHOSITE_PREDICT_DIR / "pdb_structures", aligned_score=False)

# TODO: fix this error in graphein!

### Custom dataset class

In [163]:
print(uniprot_ids)

['P68104', 'Q04917', 'P31946', 'P62258']


### Silence warnings

In [197]:
import graphein
graphein.verbose(enabled=False)
pd.options.mode.chained_assignment = None  # default='warn'
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

##### Construct the dataset from list of PSP uniprot ids

In [211]:
from phosphosite.ml.dataset import PhosphoGraphDataset

# Raw dataset directory that holds one `<uniprot_id>.pdb` file per protein.
outpath = PHOSPHOSITE_PREDICT_DIR / "dataset" / "raw" 

# Every accession present in PSP, then the subset that has a structure file
# on disk (order of first appearance in `psp` is preserved).
all_psp_ids = list(psp.uniprot_id.unique())
available_uniprot_ids = list(
    filter(lambda u: (outpath / f"{u}.pdb").is_file(), all_psp_ids)
)


In [204]:
len(available_uniprot_ids)

1105

In [224]:
available_uniprot_ids

['P31946',
 'P62258',
 'Q04917',
 'P61981',
 'P31947',
 'P27348',
 'P63104',
 'Q13541',
 'Q13542',
 'O60516',
 'Q9NRA8',
 'P08908',
 'P28222',
 'P28221',
 'P28566',
 'P30939',
 'P28223',
 'P41595',
 'P46098',
 'O95264',
 'Q8WXA8',
 'Q70Z44',
 'A5X5Y0',
 'Q13639',
 'P47898',
 'P50406',
 'P34969',
 'P09917',
 'Q12888',
 'Q13625',
 'P05408',
 'Q7Z417',
 'P43652',
 'P10243',
 'Q9NWB1',
 'P01023',
 'A8K2U0',
 'Q9NPC4',
 'Q12792',
 'Q6IBS0',
 'Q16613',
 'Q9NRG9',
 'Q86V21',
 'P22760',
 'Q6PIU2',
 'Q6P093',
 'Q5VUY0',
 'Q5VUY2',
 'Q8N5Z0',
 'Q7RTV5',
 'Q6PD74',
 'Q2M2I8',
 'Q9H7C9',
 'Q13685',
 'Q9Y312',
 'P49588',
 'Q5JTZ9',
 'Q4L235',
 'Q9NRN7',
 'Q9UDR5',
 'Q9NY61',
 'P80404',
 'Q6AI08',
 'O95477',
 'Q8WWZ4',
 'Q86UK0',
 'Q9BZC7',
 'Q99758',
 'P78363',
 'Q8WWZ7',
 'Q8N139',
 'Q8IZY2',
 'O94911',
 'Q8IUA7',
 'P08183',
 'Q9NRK6',
 'O95342',
 'P21439',
 'Q2M3G0',
 'Q9NP58',
 'O75027',
 'Q9NUT2',
 'Q9NP78',
 'P33527',
 'Q5T3U5',
 'Q96J66',
 'Q96J65',
 'Q92887',
 'O15438',
 'O15439',
 'O15440',

In [216]:
# Minimal pair for exercising the dataset's failure handling:
#   - Q7Z7G0 fails processing (its embedding length does not match its graph size)
#   - P31946 processes successfully
# The failing accession is spelled out rather than read from a leftover loop
# variable, so the cell gives the same result on a fresh kernel.
will_fail = ["Q7Z7G0", "P31946"]
will_fail

['Q7Z7G0', 'P31946']

In [221]:
dataset = PhosphoGraphDataset(root=PHOSPHOSITE_PREDICT_DIR / "dataset", uniprot_ids=will_fail)

Processing...


Q7Z7G0: Q7Z7G0: Embedding shape (1068, 1024) does not match graph size 1075.


Processing P31946: 100%|██████████| 2/2 [00:00<00:00,  2.63it/s]

Q7Z7G0: Q7Z7G0: Embedding shape (1068, 1024) does not match graph size 1075.
Successfully generated graphs for 1 out of 2 proteins.



Done!


In [232]:
PhosphoGraphDataset(root=PHOSPHOSITE_PREDICT_DIR / "dataset")

PhosphoGraphDataset(1093)

In [None]:
# add y_labels...

In [None]:
# Delete .pt files and start again ....

In [230]:
dataset = PhosphoGraphDataset(root=PHOSPHOSITE_PREDICT_DIR / "dataset", uniprot_ids=available_uniprot_ids)

Processing...


Q7Z7G0: Q7Z7G0: Embedding shape (1068, 1024) does not match graph size 1075.
Q3I5F7: Q3I5F7: Embedding shape (421, 1024) does not match graph size 207.
O43687: O43687: No embedding data.
Q53TS8: Q53TS8: Embedding shape (1820, 1024) does not match graph size 623.
P23109: P23109: Embedding shape (747, 1024) does not match graph size 780.
Q8N957: Q8N957: Embedding shape (1146, 1024) does not match graph size 763.
Q6NY19: Q6NY19: Embedding shape (821, 1024) does not match graph size 840.
B4E2M5: B4E2M5: Embedding shape (196, 1024) does not match graph size 251.
Q9UJX3: Q9UJX3: Embedding shape (565, 1024) does not match graph size 599.
Q8NFD5: Q8NFD5: Embedding shape (2319, 1024) does not match graph size 2236.
Q96Q27: Q96Q27: Embedding shape (635, 1024) does not match graph size 587.
Q96FT7: Q96FT7: Embedding shape (539, 1024) does not match graph size 647.


Processing P00519:  13%|█▎        | 146/1105 [00:55<03:34,  4.47it/s]

Q7Z7G0: Q7Z7G0: Embedding shape (1068, 1024) does not match graph size 1075.


Processing O14734:  18%|█▊        | 203/1105 [01:15<03:28,  4.32it/s]

Q3I5F7: Q3I5F7: Embedding shape (421, 1024) does not match graph size 207.


Processing O43823:  44%|████▍     | 486/1105 [02:54<03:44,  2.76it/s]

O43687: O43687: No embedding data.


Processing Q96Q35:  53%|█████▎    | 586/1105 [03:29<04:26,  1.95it/s]

Q53TS8: Q53TS8: Embedding shape (1820, 1024) does not match graph size 623.


Processing Q01433:  56%|█████▌    | 616/1105 [03:38<03:14,  2.52it/s]

P23109: P23109: Embedding shape (747, 1024) does not match graph size 780.


Processing Q9P2R3:  59%|█████▉    | 656/1105 [03:50<03:30,  2.13it/s]

Q8N957: Q8N957: Embedding shape (1146, 1024) does not match graph size 763.


Processing Q9ULJ7:  65%|██████▍   | 715/1105 [04:19<02:06,  3.08it/s]

Q6NY19: Q6NY19: Embedding shape (821, 1024) does not match graph size 840.


Processing Q96BM1:  66%|██████▌   | 728/1105 [04:23<01:25,  4.41it/s]

B4E2M5: B4E2M5: Embedding shape (196, 1024) does not match graph size 251.


Processing Q8J025:  75%|███████▍  | 828/1105 [04:58<01:40,  2.77it/s]

Q9UJX3: Q9UJX3: Embedding shape (565, 1024) does not match graph size 599.


Processing Q68CP9:  89%|████████▉ | 981/1105 [05:55<01:23,  1.49it/s]

Q8NFD5: Q8NFD5: Embedding shape (2319, 1024) does not match graph size 2236.


Processing Q9Y575:  99%|█████████▉| 1092/1105 [06:25<00:03,  3.43it/s]

Q96Q27: Q96Q27: Embedding shape (635, 1024) does not match graph size 587.


Processing Q9NY37: : 1114it [06:32,  3.35it/s]                        

Q96FT7: Q96FT7: Embedding shape (539, 1024) does not match graph size 647.


Processing P68104: 100%|██████████| 1105/1105 [06:46<00:00,  2.72it/s]


Successfully generated graphs for 1093 out of 1105 proteins.


Done!


In [231]:
len(dataset)

1093

In [None]:
# TODO: 
# generate list of uniprot_ids that are valid i.e. do not fail (embedding shape matches, etc.)

In [179]:
dataset

PhosphoGraphDataset(4)

In [18]:
psp.uniprot_id.unique()

array(['P31946', 'P62258', 'Q04917', ..., 'Q15942', 'O43149', 'Q8IYH5'],
      dtype=object)

In [42]:
# Small set of test proteins; their AlphaFold structures are fetched into the
# project's `pdb_structures` directory.
uniprot_ids = ["P68104", "Q04917", "P31946", "P62258"]
structure_dir = str(PHOSPHOSITE_PREDICT_DIR / "pdb_structures")
for uniprot_id in uniprot_ids:
    # `protein_path` ends up holding the return value for the last ID only.
    protein_path = download_alphafold_structure(
        uniprot_id,
        out_dir=structure_dir,
        aligned_score=False,
    )

In [175]:
# Download all structures
# They go straight into the dataset's raw directory, which is where the
# notebook's availability check looks for `<uniprot_id>.pdb`.
outpath = PHOSPHOSITE_PREDICT_DIR / "dataset" / "raw"
outpath.mkdir(parents=True, exist_ok=True)

for uniprot_id in tqdm(psp.uniprot_id.unique()):
    # Skip structures that are already on disk so the cell can be interrupted
    # and re-run without re-downloading thousands of files.
    if (outpath / f"{uniprot_id}.pdb").is_file():
        continue
    download_alphafold_structure(uniprot_id, out_dir=str(outpath), aligned_score=False)

  0%|          | 0/19827 [00:00<?, ?it/s]

  0%|          | 1/19827 [00:00<5:26:27,  1.01it/s]

  0%|          | 3/19827 [00:01<1:45:05,  3.14it/s]

  0%|          | 4/19827 [00:01<1:21:10,  4.07it/s]

  0%|          | 6/19827 [00:01<54:07,  6.10it/s]  

  0%|          | 8/19827 [00:01<42:49,  7.71it/s]

  0%|          | 11/19827 [00:01<29:40, 11.13it/s]

  0%|          | 13/19827 [00:03<1:28:04,  3.75it/s]

  0%|          | 15/19827 [00:03<1:10:38,  4.67it/s]

  0%|          | 17/19827 [00:03<59:38,  5.54it/s]  

  0%|          | 19/19827 [00:03<50:25,  6.55it/s]

  0%|          | 21/19827 [00:04<1:32:45,  3.56it/s]

  0%|          | 23/19827 [00:04<1:13:28,  4.49it/s]

  0%|          | 24/19827 [00:05<1:23:55,  3.93it/s]

  0%|          | 26/19827 [00:05<1:20:11,  4.12it/s]

  0%|          | 27/19827 [00:06<2:11:20,  2.51it/s]

  0%|          | 28/19827 [00:07<2:13:01,  2.48it/s]

  0%|          | 30/19827 [00:07<1:30:34,  3.64it/s]

  0%|          | 32/19827 [00:07<1:07:01,  4.92it/s]

  0%|          | 34/19827 [00:08<1:39:29,  3.32it/s]

  0%|          | 36/19827 [00:08<1:15:33,  4.37it/s]

  0%|          | 37/19827 [00:08<1:08:14,  4.83it/s]

  0%|          | 38/19827 [00:09<2:07:52,  2.58it/s]

  0%|          | 39/19827 [00:10<2:54:20,  1.89it/s]

  0%|          | 40/19827 [00:10<2:20:09,  2.35it/s]

  0%|          | 42/19827 [00:11<1:32:47,  3.55it/s]

  0%|          | 43/19827 [00:12<2:30:15,  2.19it/s]

  0%|          | 44/19827 [00:13<3:10:20,  1.73it/s]

  0%|          | 45/19827 [00:13<2:50:53,  1.93it/s]

  0%|          | 46/19827 [00:13<2:35:09,  2.12it/s]

  0%|          | 47/19827 [00:14<3:17:59,  1.67it/s]

  0%|          | 48/19827 [00:15<3:50:50,  1.43it/s]

  0%|          | 49/19827 [00:16<4:21:08,  1.26it/s]

  0%|          | 51/19827 [00:16<2:35:25,  2.12it/s]

  0%|          | 53/19827 [00:17<2:03:34,  2.67it/s]

  0%|          | 54/19827 [00:18<2:57:23,  1.86it/s]

  0%|          | 55/19827 [00:18<2:23:08,  2.30it/s]

  0%|          | 56/19827 [00:18<1:55:37,  2.85it/s]

  0%|          | 57/19827 [00:19<2:50:14,  1.94it/s]

  0%|          | 58/19827 [00:20<3:30:25,  1.57it/s]

  0%|          | 59/19827 [00:21<4:10:17,  1.32it/s]

  0%|          | 61/19827 [00:21<2:28:55,  2.21it/s]

  0%|          | 62/19827 [00:22<2:20:43,  2.34it/s]

  0%|          | 63/19827 [00:22<1:53:02,  2.91it/s]

  0%|          | 65/19827 [00:22<1:13:30,  4.48it/s]

  0%|          | 66/19827 [00:22<1:26:46,  3.80it/s]

  0%|          | 67/19827 [00:23<2:27:51,  2.23it/s]

  0%|          | 69/19827 [00:23<1:40:34,  3.27it/s]

  0%|          | 70/19827 [00:25<2:37:44,  2.09it/s]

  0%|          | 71/19827 [00:25<2:27:01,  2.24it/s]

  0%|          | 72/19827 [00:26<3:18:19,  1.66it/s]

  0%|          | 73/19827 [00:26<2:54:18,  1.89it/s]

  0%|          | 74/19827 [00:26<2:16:40,  2.41it/s]

  0%|          | 75/19827 [00:27<2:10:25,  2.52it/s]

  0%|          | 76/19827 [00:28<3:10:40,  1.73it/s]

  0%|          | 77/19827 [00:28<2:47:56,  1.96it/s]

  0%|          | 78/19827 [00:29<3:41:09,  1.49it/s]

  0%|          | 79/19827 [00:30<4:14:26,  1.29it/s]

  0%|          | 80/19827 [00:30<3:10:13,  1.73it/s]

  0%|          | 81/19827 [00:31<3:58:28,  1.38it/s]

  0%|          | 83/19827 [00:32<3:27:29,  1.59it/s]

  0%|          | 85/19827 [00:34<3:16:47,  1.67it/s]

  0%|          | 87/19827 [00:34<2:12:39,  2.48it/s]

  0%|          | 88/19827 [00:35<2:48:26,  1.95it/s]

  0%|          | 89/19827 [00:35<3:19:49,  1.65it/s]

  0%|          | 90/19827 [00:36<2:42:11,  2.03it/s]

  0%|          | 91/19827 [00:36<2:10:33,  2.52it/s]

  0%|          | 93/19827 [00:36<1:24:38,  3.89it/s]

  0%|          | 95/19827 [00:36<59:47,  5.50it/s]  

  0%|          | 97/19827 [00:36<52:03,  6.32it/s]

  0%|          | 99/19827 [00:37<1:33:07,  3.53it/s]

  1%|          | 100/19827 [00:37<1:24:39,  3.88it/s]

  1%|          | 101/19827 [00:38<2:18:43,  2.37it/s]

  1%|          | 102/19827 [00:39<3:03:31,  1.79it/s]

  1%|          | 104/19827 [00:41<2:59:10,  1.83it/s]

  1%|          | 106/19827 [00:41<2:01:00,  2.72it/s]

  1%|          | 108/19827 [00:42<2:08:09,  2.56it/s]


KeyboardInterrupt: 

In [43]:
# All of edge_index where node 9 is the source
i = pyg.node_id.index("A:ILE:10")
pyg.edge_index[:, pyg.edge_index[0] == i]

tensor([[  9,   9,   9,   9,   9,   9],
        [ 88,  89, 108, 109, 110, 111]])

In [44]:
# print dtype of every attribute of pyg 
for k, v in pyg:
    print(k, type(v), end=" ")
    if isinstance(v, torch.Tensor):
        print(v.dtype, end=" ")
    print()

x <class 'torch.Tensor'> torch.float32 
edge_index <class 'torch.Tensor'> torch.int64 
node_id <class 'list'> 
b_factor <class 'torch.Tensor'> torch.float64 
num_nodes <class 'int'> 


In [45]:
pyg.x[i].shape

torch.Size([1024])

In [46]:
# Add y labels 

In [47]:
# Validate that all node and edge indexes are correct
pyg.validate(raise_on_error=True)

True

In [48]:
pyg.num_node_features

1024

### Loss function masking

In [49]:
from phosphosite.ml import MaskedBinaryCrossEntropy, MaskedMSELoss

# toy example
# Predictions are named `y_prob` rather than `input`, so the builtin `input()`
# is not shadowed for the rest of the kernel session.
y_prob = torch.FloatTensor([.8, .5, .2])
target = torch.FloatTensor([1, 0, 0])

# Each mask selects which of the three positions are passed to the loss.
mask0 = torch.tensor([1, 1, 1])
mask1 = torch.tensor([0, 1, 1])
mask2 = torch.tensor([0, 0, 1])
mask3 = torch.tensor([1, 0, 1])

# One printed row per loss function, one value per mask.
for loss_func in (MaskedMSELoss(), MaskedBinaryCrossEntropy()):
    print([loss_func(y_prob, target, m) for m in (mask0, mask1, mask2, mask3)])



[tensor(0.1100), tensor(0.1450), tensor(0.0400), tensor(0.0400)]
[tensor(0.3798), tensor(0.4581), tensor(0.2231), tensor(0.2231)]


### Train 

In [50]:
# pytorch lightning from graphein 

### Validate that model architecture works with graph input.

In [51]:
pyg

Data(x=[462, 1024], edge_index=[2, 438], node_id=[462], b_factor=[462], num_nodes=462)

In [52]:
# Create simple databatch 
from torch_geometric.data import Data, Batch
batch = Batch.from_data_list([pyg])
batch

DataBatch(x=[462, 1024], edge_index=[2, 438], node_id=[1], b_factor=[462], num_nodes=462, batch=[462], ptr=[2])

In [53]:
# set dtype for all (problematic) tensors in pyg to be float32
pyg = pyg.apply(lambda x: x.to(torch.float32) if x.dtype == torch.float16 else x)
pyg.x.dtype

torch.float32

### Run model forward pass on batch.

In [54]:
# Directory the AlphaFold structures for `uniprot_ids` were downloaded into.
# It must be bound to `pdb_dir`: that is the name later cells hand to
# `get_pyg_graph`. Binding it under any other name leaves those cells using
# the earlier relative `Path("pdb_structures")` instead of the project path.
pdb_dir = PHOSPHOSITE_PREDICT_DIR / "pdb_structures"

from torch_geometric.data import Data, Batch
from phosphosite.model import PhosphoGAT
from phosphosite.graphs.pyg import get_pyg_graph

In [55]:
uniprot_ids

['P68104', 'Q04917', 'P31946', 'P62258']

In [155]:
func = partial(get_pyg_graph, pdb_dir=pdb_dir)
pyg_list = [func(uniprot_id) for uniprot_id in uniprot_ids]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  idxs.loc[:, "end_idx"] = ends.line_idx.values
  df["node_id"] = df["node_id"].str.replace(":$", "")
  df["node_id"] = df["node_id"].str.replace(":$", "")


In [156]:
g = pyg_list[0]
len(pyg_list), g.num_node_features, g.num_nodes

(4, 1024, 462)

In [157]:
g

Data(x=[462, 1024], edge_index=[2, 438], node_id=[462], b_factor=[462], name='P68104', num_nodes=462)

In [158]:
# Create simple databatch
batch = Batch.from_data_list(pyg_list)
batch

DataBatch(x=[1209, 1024], edge_index=[2, 611], node_id=[4], b_factor=[1209], name=[4], num_nodes=1209, batch=[1209], ptr=[5])

In [160]:
batch.ptr

tensor([   0,  462,  708,  954, 1209])

In [85]:

model = PhosphoGAT()
model.eval()

y_hat1 = model(pyg)
y_hat1

y_hat = model(batch)
y_hat1.shape, y_hat.shape

(torch.Size([462, 1]), torch.Size([1209, 1]))

In [86]:
y_hat

tensor([[0.4600],
        [0.4729],
        [0.4683],
        ...,
        [0.4710],
        [0.5040],
        [0.4971]], grad_fn=<SigmoidBackward0>)

In [87]:
# Turn into binary class i.e. 0 or 1
# i.e. if y_hat > 0.5, then 1, else 0
y_pred = (y_hat > 0.5).int()
y_pred

tensor([[0],
        [0],
        [0],
        ...,
        [0],
        [1],
        [0]], dtype=torch.int32)

In [62]:
y_pred.shape

torch.Size([1209, 1])

In [84]:
from phosphosite.ml import MaskedBinaryCrossEntropy
loss_func = MaskedBinaryCrossEntropy()

y = torch.zeros_like(y_hat)
mask = torch.zeros_like(y_hat)

# Set indexes of mask to 1 
n = y.shape[0]
indexes = torch.tensor([0, n-3, n-2, n-1])
mask[indexes] = 1
mask.sum()

tensor(4.)

In [88]:
# y_labels is sparse label vector: one label per entry of `indexes`
y_labels = torch.tensor([1, 1, 1, 1])

# Write all labels in a single indexed assignment. `y` has a trailing
# singleton dimension, so the labels are cast to its dtype and reshaped from
# (k,) to (k, 1). Doing this without a Python loop also avoids rebinding the
# notebook-level name `i`, which other cells use as a node index.
y[indexes] = y_labels.to(y.dtype).unsqueeze(-1)

y

tensor([[1.],
        [0.],
        [0.],
        ...,
        [1.],
        [1.],
        [1.]])

In [89]:
y_hat

tensor([[0.4600],
        [0.4729],
        [0.4683],
        ...,
        [0.4710],
        [0.5040],
        [0.4971]], grad_fn=<SigmoidBackward0>)

In [73]:
mask

tensor([[1.],
        [0.],
        [0.],
        ...,
        [1.],
        [1.],
        [1.]])

In [76]:
loss = loss_func(y_hat, y, mask)
loss

tensor(0.6741, grad_fn=<DivBackward0>)

Change `y_hat` to be accurate on the first one...

In [90]:
y_hat_better = y_hat.clone()
y_hat_better[0, 0] = 1
loss = loss_func(y_hat_better, y, mask)
loss

tensor(0.5343, grad_fn=<DivBackward0>)

Change `mask` to only consider first two...

In [81]:
y_hat[0, 0] = 1
mask2 = torch.zeros_like(mask) 
mask2[indexes[0:2]] = 1
mask2.sum(), mask2


(tensor(2.),
 tensor([[1.],
         [0.],
         [0.],
         ...,
         [1.],
         [0.],
         [0.]]))

In [91]:
loss = loss_func(y_hat_better, y, mask2)
loss

tensor(0.3764, grad_fn=<DivBackward0>)

The loss here is even lower when the mask covers *fewer* values.

In [93]:
y_hat[0, 0] = 1
loss = loss_func(y_hat, y, mask)
loss

tensor(0.5343, grad_fn=<DivBackward0>)

In [94]:
perfect_y_hat = y_hat.clone()
perfect_y_hat[indexes] = torch.tensor([1, 1, 1, 1], dtype=torch.float32).reshape(-1, 1) # Turn from (4,) to (4, 1)

new_mask = torch.zeros_like(mask) 
loss = loss_func(perfect_y_hat, y, mask)
loss

tensor(0., grad_fn=<DivBackward0>)

As expected, when all values are correct, we get 0 loss.  

What about when not all values are correct in the `y` positions; but the mask only covers the correct ones?

In [95]:
y_hat

tensor([[1.0000],
        [0.4729],
        [0.4683],
        ...,
        [0.4710],
        [0.5040],
        [0.4971]], grad_fn=<CopySlices>)

In [96]:
m1 = torch.zeros_like(mask)
m1[indexes[0]] = 1
m2 = torch.zeros_like(mask)
m2[indexes[1]] = 1
loss_func(y_hat, y, m1), loss_func(y_hat, y, m2)

(tensor(0., grad_fn=<DivBackward0>), tensor(0.7528, grad_fn=<DivBackward0>))

As expected, for the same `y_hat`, when the mask covers a 100% accurate prediction, you get 0 loss.

In [97]:
loss = loss_func(perfect_y_hat[indexes], y[indexes], mask[indexes])
loss

tensor(0., grad_fn=<DivBackward0>)

### Testing Binary cross entropy loss

In [383]:
y_hat[indexes]

tensor([[0.5156],
        [0.5218],
        [0.5211],
        [0.5292]], grad_fn=<IndexBackward0>)

In [384]:
y[indexes]

tensor([[1.],
        [1.],
        [1.],
        [1.]])

In [387]:

# binary cross entropy loss
from torch.nn import BCELoss
bce = BCELoss()
loss = bce(perfect_y_hat[indexes], y[indexes])
loss

tensor(0., grad_fn=<BinaryCrossEntropyBackward0>)

### Test accuracy metric

In [102]:
from torchmetrics import Accuracy
accuracy = Accuracy(task="binary")

accuracy(perfect_y_hat[indexes], y[indexes])

ImportError: cannot import name 'BinaryAccuracy' from 'torchmetrics' (/home/cim/anaconda3/envs/phosphosite_ml/lib/python3.9/site-packages/torchmetrics/__init__.py)

In [105]:
y_hat

tensor([[1.0000],
        [0.4729],
        [0.4683],
        ...,
        [0.4710],
        [0.5040],
        [0.4971]], grad_fn=<CopySlices>)

In [104]:
from phosphosite.ml import calculate_masked_accuracy
calculate_masked_accuracy(perfect_y_hat, y, mask), calculate_masked_accuracy(y_hat, y, mask)

(tensor(1.), tensor(0.5000))

In [101]:
accuracy(perfect_y_hat, y), accuracy(y_hat, y)

(tensor(0.5699), tensor(0.5682))

In [108]:
y_hat

tensor([[1.0000],
        [0.4729],
        [0.4683],
        ...,
        [0.4710],
        [0.5040],
        [0.4971]], grad_fn=<CopySlices>)

#### F1 Score

In [114]:
# change type to int 
y = y.int()
y

tensor([[1],
        [0],
        [0],
        ...,
        [1],
        [1],
        [1]], dtype=torch.int32)

tensor([[1],
        [0],
        [0],
        ...,
        [0],
        [1],
        [0]], dtype=torch.int32)

In [117]:
# f1 score 
from sklearn.metrics import f1_score 
y_pred = (y_hat > 0.5).int()



0.7217309359380047

In [121]:
indexes

tensor([   0, 1206, 1207, 1208])

In [144]:
# mask where element is 1
idxs = torch.nonzero(mask == 1, as_tuple=True)[0] # get first dimension only
idxs

masked_y = y[idxs]
masked_y_pred = y_pred[idxs]
masked_y_pred[3] = 1

f1_score(masked_y.detach().cpu().numpy(), masked_y_pred.detach().cpu().numpy(), average="weighted")

0.8571428571428571

In [142]:
masked_y_pred, masked_y

(tensor([[1],
         [0],
         [1],
         [1]], dtype=torch.int32),
 tensor([[1],
         [1],
         [1],
         [1]], dtype=torch.int32))

In [153]:
from phosphosite.ml import calculate_masked_f1
calculate_masked_f1(y_hat, y, mask, average="micro")

0.5

In [154]:
from phosphosite import SAVED_MODEL_DIR
SAVED_MODEL_DIR

PosixPath('/home/cim/STRUCTURAL_MOTIFS/phosphosite/train/saved_models')