In [1]:
# Autoreload 
%load_ext autoreload
%autoreload 2


In [13]:
import glob
import re
import time

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch 
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT
import torch_geometric 

import torch.nn as nn
import torch.nn.functional as F


from typing import Dict, List, Union
from pathlib import Path
from functools import reduce
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.callbacks import StochasticWeightAveraging, EarlyStopping, ModelCheckpoint

from phospholite import SAVED_MODEL_DIR, DATASET_DIR
from phospholite.model import PhosphoGAT
from phospholite.ml import get_dataloader_split
from phospholite.ml.dataset import PhosphositeGraphDataset 
from phospholite.utils.io import load_index_dict
from phospholite import INDEX_DICT_PATH

In [3]:
dataset_root_dir =  DATASET_DIR / "protein_graph_dataset"
root_dir = dataset_root_dir



In [43]:
verbose = True
processed_filenames = [Path(a).stem for a in glob.glob(str(root_dir / "processed" / "*.pt"))]
if verbose: print(f"Using {len(processed_filenames)} processed files.")

from graphein.protein.config import ProteinGraphConfig
from graphein.protein.edges.distance import add_distance_threshold
from functools import partial

long_interaction_threshold = 5 # seq positions 
edge_threshold_distance = 6.0 # Å
new_edge_funcs = {"edge_construction_functions": [
    partial(
    add_distance_threshold, long_interaction_threshold=long_interaction_threshold, threshold=edge_threshold_distance)
]}
config = ProteinGraphConfig(
    granularity="CA",
    **new_edge_funcs,
)
from graphein.ml.conversion import GraphFormatConvertor

columns = [
    "b_factor",
    "name",
    "edge_index",
    "x", # T5 per-residue embedding
]
convertor = GraphFormatConvertor(
    src_format="nx", dst_format="pyg", verbose="gnn",
    columns=columns,
)

# List of functions that consume a nx.Graph and return a nx.Graph. Applied to graphs after construction but before conversion to pyg
#from phosphosite.graphs.pyg import add_per_residue_embedding
graph_transforms = [
    #add_per_residue_embedding,
]

indexes_dict = load_index_dict(filepath=INDEX_DICT_PATH)

kwargs = dict(
    root=root_dir,
    graphein_config=config, 
    graph_transformation_funcs=graph_transforms,
    graph_format_convertor=convertor,
    pre_transform=None, # before saved to disk , after PyG conversion 
    pre_filter=None,    # whether it will be in final dataset
)
uniprot_ids_to_use = [u for u in processed_filenames if u in indexes_dict.keys()]
ds = PhosphositeGraphDataset(
    uniprot_ids=uniprot_ids_to_use,
    y_label_map=indexes_dict,
    **kwargs,
)
if verbose: print(ds)

train_loader, valid_loader, test_loader = get_dataloader_split(
    ds, batch_size=32, train_batch_size=32,
    #num_workers=num_workers,
)

Using 17067 processed files.
PhosphositeGraphDataset(17063)


In [44]:
len(train_loader)

342

In [6]:
dropout = 0.1 
batch_size = 64 # 100
num_heads = 8 # 4 
learning_rate = 0.001

model = PhosphoGAT(
    dropout=dropout,
    batch_size=batch_size,
    learning_rate=learning_rate,
)

In [12]:
for i, d in enumerate(ds):
    if not isinstance(d, torch_geometric.data.Data):
        print(i, d)
        break


NameError: name 'torch_geometric' is not defined

In [55]:
i = 118
'Q8TBC3' in [d.name for d in ds[i*batch_size:(i+1)*batch_size]]

False

In [56]:
for i, d in enumerate(ds): 
    if 'Q8TBC3' in d.name:
        break

KeyboardInterrupt: 

In [None]:
i

In [57]:
# Manually loading in from files 
from torch import Tensor 
y_sparse = Tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
        0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1.])

y_index = Tensor([  143,   160,   184,   186,   623,   636,   639,   646,   648,   661,
          162,   471,   681,   692,   695,   579,    92,   615,   612,     5,
           12,   483,   535,   571,   621,   387,   552,    53,    54,   424,
          346,   243,   456,    32,    31,   545,   433,   400,   329,   390,
           42,   430,   721,   722,   725,   727,   735,   740,  1012,   823,
          824,   816,   856,   851,  1005,   961,   924,  1093,  1069,  1076,
         1105,  1097,  1084,  1149,  1152,  1155,  1187,  1205,  1208,  1211,
         1214,  1235,  1236,  1244,  1256,  1265,  1274,  1462,  1481,  1543,
         1548,  1550,  1589,  1594,  1617,  1662,  1666,  1672,  1684,  1750,
         1769,  1842,  1858,  1892,  1897,  1918,  1960,  1965,  1969,  1970,
         1973,  1974,  2033,  2051,  2053,  2112,  2130,  2145,  2161,  2166,
         2237,  1153,  1217,  1238,  1277,  1379,  1557,  1657,  1658,  1663,
         2012,  2013,  2139,  2141,  2058,  2125,  2160,  2702,  1998,  1982,
         1984,  2002,  1996,  2004,  1713,  1987,  1361,  1781,  1362,  1993,
         2382,  2540,  2327,  2587,  2487,  2464,  2763,  2457,  2419,  2670,
         2500,  1930,  2394,  2370,  2435,  2483,  3021,  3011,  2999,  3199,
         3099,  3082,  3032,  3410,  3448,  3465,  3526,  3391,  3529,  3243,
         3244,  3349,  3727,  3789,  3790,  4068,  3867,  3696,  3710,  3712,
         3714,  3826,  3580,  3918,  4052,  4047,  3908,  3607,  3645,  3646,
         3673,  4061,  3734,  3806,  3556,  3872,  4161,  4168,  4176,  4179,
         4182,  4200,  4202,  4205,  4206,  4218,  4224,  4236,  4271,  4346,
         4352,  4353,  4356,  4382,  4409,  4452,  4465,  4478,  4583,  4586,
         4587,  4604,  4692,  4864,  5194,  5305,  5624,  5636,  5642,  5652,
         5681,  5690,  5703,  5798,  5804,  5807,  5817,  5824,  5826,  5994,
         4341,  4350,  4357,  4423,  4456,  4477,  4486,  4560,  4617,  4826,
         4948,  5286,  4233,  4466,  5626,  5966,  4574,  5562,  5360,  4802,
         4695,  4815,  4873,  4797,  4955,  4849,  5127,  4998,  5082,  5335,
         4368,  4363,  5869,  4370,  5896,  4293,  4501,  4671,  5319,  5104,
         5418,  5488,  4938,  5067,  5213,  5103,  4644,  4871,  4756,  4652,
         6057,  6258,  6313,  6184,  6176,  6233,  6029,  6030,  6360,  6035,
         6348,  6045,  6031,  6356,  6026,  6209,  6118,  6192,  6351,  6540,
         6388,  6383,  6372,  6419,  6693,  6849,  7007,  6899,  6906,  6938,
         6879,  6564,  6951,  6874,  7181,  6567,  6952,  6786,  6604,  6768,
         6629,  6740,  7008,  7177,  6969,  7093,  6972,  7021,  7052,  6798,
         7023,  6947,  7112,  7508,  7279,  7276,  7493,  7492,  7425,  7550,
         7542,  7540,  7567,  7370,  7241,  7378,  7664,  7662,  7632,  7599,
         7703,  7716,  7697,  7741,  7746,  7747,  7784,  7800,  7802,  7809,
         7822,  7823,  7833,  7851,  7851,  7883,  7885,  7911,  7934,  7948,
         7957,  8015,  8049,  8066,  8111,  8115,  8118,  8121,  8233,  8270,
         8273,  8275,  8279,  8282,  8387,  8397,  8456,  7808,  7834,  7838,
         8432,  8449,  8437,  8307,  8489,  8488,  8454,  8175,  8050,  8265,
         8300,  8020,  7829,  8513,  8950,  8989,  8980,  8981,  8997,  9123,
         8778,  9160,  8939,  9095,  8749,  9139,  8861,  8823,  8675,  8806,
         8688,  8683,  9214,  9318,  9302,  9369,  9441,  9272,  9363,  9275,
         9725,  9505,  9499,  9703,  9686,  9712,  9529,  9714,  9675,  9653,
         9781,  9788,  9912,  9951, 10082, 10184,  9793, 10028,  9940,  9904,
         9905, 10196,  9946, 10218, 10136, 10213, 10231, 10241,  9929, 10234,
         9803, 10219, 10077,  9785,  9807, 10358, 10359, 10364, 10355, 10547,
        10592, 10821, 10822, 10541, 10555, 10669, 10670, 10712, 10860, 10869,
        10871, 10900, 10395, 10386, 10624, 10317, 10537, 10382, 10619, 10610,
        10662, 10896, 10831, 10420, 10668, 11273, 11688, 11291, 11282, 11658,
        11722, 11640, 11037, 11051, 11336, 11133, 11016, 11447, 11453, 11054,
        11137, 11466, 11604, 11019, 11599, 11255, 11445, 11542, 11192, 11425,
        11388, 11770, 11795, 11996, 12015, 12051, 12078, 12156, 12218, 12230,
        12237, 12473, 12475, 12480, 12483, 12811, 12872, 12928, 12930, 12979,
        13005, 13013, 13061, 13063, 13065, 13068, 13069, 13072, 13098, 13103,
        13117, 13120, 13127, 13131, 13140, 13143, 13153, 13157, 13158, 13159,
        13166, 13185, 13192, 13215, 13218, 13235, 13237, 13240, 13242, 13257,
        13261, 13261, 13263, 13266, 13268, 13270, 13291, 12016, 12048, 12591,
        12664, 12878, 12885, 12923, 12971, 13010, 13038, 13040, 13090, 13093,
        13109, 13126, 13163, 13169, 13172, 13195, 13196, 13236, 13262, 11848,
        11917, 12094, 12284, 12288, 12378, 12571, 12596, 13287, 12366, 11816,
        12293, 12416, 12701, 12194, 12990, 11914, 12456, 12138, 11804, 11869,
        12607, 12658, 12086, 12684, 12765, 12686, 11984, 11947, 12357, 12356,
        12721, 12544, 13424, 13457, 13458, 13474, 13487, 13642, 13669, 13672,
        13456, 13489, 13588, 13668, 13671, 13715, 13603, 13324, 13786, 13561,
        13787, 13790, 13449, 13314, 13613, 14112, 14116, 14167, 14169, 14172,
        13921, 13995, 14080, 14073, 14043, 13841, 13846, 14001, 13899, 13873,
        13869, 13885, 13930, 13832, 14180, 14524, 14697, 14728, 14187, 14210,
        14518, 14201, 14319, 14266, 14789, 14185, 14688, 14340, 14478, 14370,
        14814, 14579, 14343, 14496, 14823, 14799, 14335, 14332, 14212, 14651,
        14547, 14973, 14869, 14899, 14956, 14858, 14940, 14995, 15015, 15030,
        15118, 15145, 15013, 15026, 15114, 15119, 15119, 15170, 15104, 15497,
        15546, 15564, 15330, 15321, 15690, 15753, 15334, 15898, 15591, 15592,
        15344, 15694, 15825, 15337, 15659, 15474, 15588, 15896, 15636, 15599,
        15535, 15556, 15204, 15303, 16128, 15951, 15929, 15931, 15909, 16088,
        16134, 16044, 16042, 16127, 16058, 16121, 16165, 16197, 16837, 16846,
        16903, 16914, 16921, 17107, 17113, 17114, 17115, 17119, 17128, 16697,
        16812, 17120, 16331, 17087, 16151, 16566, 16386, 16655, 17048, 17064,
        16231, 16598, 16662, 16226, 16218, 16791, 16867, 16322, 17576, 17480,
        17663, 17664, 17906, 17342, 17420, 17306, 17855, 17536, 17509, 17362,
        17633, 17257, 17483, 17442, 17353, 17317, 17786, 17764, 17734, 17729,
        17889, 17933, 17175, 17146, 17371, 17621, 17914, 17302, 17658, 17925,
        17805, 17594, 18220, 18309, 17972, 17977, 18144, 18278, 18194, 18241,
        18383, 18230, 18268, 18236, 18336, 18348, 17961, 18253, 18131, 17970,
        17971, 18175, 18018, 18180, 18029, 18017, 18095, 18179, 18038, 18314,
        17942, 18452, 18455, 18456, 18694, 18715, 18717, 18723, 18725, 18441,
        18699, 18724])

names = ['Q8TBC3', 'Q9Y4X0', 'Q9Y258', 'Q9Y4K1', 'Q8ND90', 'Q9NQ87', 'Q96SF7', 'Q9C0D5', 'P04001', 'Q8TD86', 'Q96JB8', 'P0CB48', 'Q6RSH7', 'P06401', 'Q9NZ20', 'Q9Y257', 'B0FP48', 'Q9ULM6', 'P16435', 'Q9Y5E2', 'P11388', 'O75886', 'P06729', 'Q96BT7', 'B2RV13', 'Q93062', 'Q5VTM2', 'Q96CG8', 'Q13472', 'Q7LBE3', 'P31644', 'Q13491']

In [61]:
uniprot_id_indexes = [ds.uniprot_ids.index(u) for u in names]
uniprot_id_indexes

[7074,
 2162,
 12275,
 14604,
 8965,
 8713,
 11533,
 4270,
 8500,
 7799,
 4953,
 9054,
 1254,
 11180,
 5356,
 11699,
 6754,
 16688,
 4069,
 10389,
 16493,
 9542,
 2999,
 10882,
 13990,
 373,
 3048,
 2580,
 8072,
 11543,
 7878,
 16273]

In [7]:
batch_dict = {}
stopping_num = 149 # epoch 149 / 341
for i, batch in enumerate(train_loader):
    
    if i > stopping_num - 5 and i < stopping_num + 5:
        batch_dict[i] = batch

AttributeError: 'tuple' object has no attribute 'name'

In [14]:
batch_dict

{145: DataBatch(x=[21487, 1024], edge_index=[2, 7114], node_id=[32], b_factor=[21487], name=[32], num_nodes=21487, y_index=[1062], y=[1062], batch=[21487], ptr=[33]),
 146: DataBatch(x=[17847, 1024], edge_index=[2, 7715], node_id=[32], b_factor=[17847], name=[32], num_nodes=17847, y_index=[833], y=[833], batch=[17847], ptr=[33]),
 147: DataBatch(x=[25468, 1024], edge_index=[2, 6496], node_id=[32], b_factor=[25468], name=[32], num_nodes=25468, y_index=[1336], y=[1336], batch=[25468], ptr=[33]),
 148: DataBatch(x=[13070, 1024], edge_index=[2, 4865], node_id=[32], b_factor=[13070], name=[32], num_nodes=13070, y_index=[632], y=[632], batch=[13070], ptr=[33]),
 149: DataBatch(x=[14740, 1024], edge_index=[2, 6137], node_id=[32], b_factor=[14740], name=[32], num_nodes=14740, y_index=[807], y=[807], batch=[14740], ptr=[33]),
 150: DataBatch(x=[17829, 1024], edge_index=[2, 4540], node_id=[32], b_factor=[17829], name=[32], num_nodes=17829, y_index=[794], y=[794], batch=[17829], ptr=[33]),
 151: 

In [38]:
for i, batch in batch_dict.items():
    print(i)
    print(batch.name[0:10])

145
['O75182', 'P52789', 'P33076', 'Q66K14', 'O15121', 'P50148', 'Q96PH1', 'P15382', 'Q5SQ64', 'Q5GH70']
146
['P05023', 'P00738', 'P30085', 'P84090', 'Q99767', 'Q8IY47', 'O43572', 'Q6PIJ6', 'Q3B726', 'Q15436']
147
['Q9BZ23', 'P35542', 'Q5T8I3', 'Q9HCJ0', 'Q96PF2', 'Q9GZN6', 'Q86UE4', 'A6H8Y1', 'Q9HAH7', 'Q9UQB9']
148
['P09912', 'P31213', 'P52735', 'Q9Y3A4', 'Q8NGP0', 'P0CE71', 'Q9H227', 'Q9H7E9', 'Q15291', 'P15260']
149
['O00238', 'P37108', 'P34932', 'Q96BY2', 'O95218', 'O75083', 'Q9BXY4', 'P0DTE5', 'Q86T24', 'A6NNA2']
150
['Q9NWT6', 'P20138', 'Q9UIH9', 'Q9BRK5', 'O43148', 'Q6UXU4', 'A6NHQ4', 'Q8WXH4', 'Q15032', 'Q15797']
151
['Q92769', 'P47211', 'Q9NPI7', 'P30939', 'Q8IWB1', 'Q6ZMQ8', 'P43243', 'P35222', 'O95626', 'Q02535']
152
['Q96H55', 'P49763', 'O43543', 'Q02040', 'Q9BYB4', 'P07199', 'Q6ZNW5', 'P52198', 'Q6UXS9', 'Q96QE3']
153
['Q96KC8', 'A8MXV6', 'Q13200', 'Q8N5H7', 'Q9BY66', 'Q86UP8', 'Q9H2G2', 'Q8IZY5', 'A8MW95', 'Q5VT66']


In [36]:
for i, batch in batch_dict.items():
    x = batch 
    y_sparse = x.y 
    y_index = x.y_index
    y_index = y_index.to(torch.long)
    y_hat = model(x)

    # Flatten 
    y_hat = torch.flatten(y_hat)
    y_sparse = torch.flatten(y_sparse)
    y_index = torch.flatten(y_index).detach().cpu()

    # Use `y_index` to only select the values that are in the mask
    y_hat = y_hat[y_index]



In [31]:
x = batch 
y_sparse = x.y 
y_index = x.y_index
y_index = y_index.to(torch.long)
y_hat = model(x)

# Flatten 
y_hat = torch.flatten(y_hat)
y_sparse = torch.flatten(y_sparse)
y_index = torch.flatten(y_index).detach().cpu()

# Use `y_index` to only select the values that are in the mask
y_hat = y_hat[y_index]

In [32]:
batch.y_index.shape

torch.Size([807])

In [33]:
y_index.shape

torch.Size([807])

In [34]:
y_hat.shape

torch.Size([807])

#### Manually create databatch that we have on NCI

In [39]:
uniprot_ids = ['Q8IYF3', 'Q92990', 'P0DMS8', 'P78545', 'Q9BY43', 'Q14D04', 'P0DJI9', 'O75843', 'Q86WX3', 'Q8N7X4', 'P52742', 'Q9HD45', 'Q70EL2', 'A6NFQ7', 'O43716', 'P06756', 'A6NJ46', 'Q9Y2R2', 'P20333', 'Q9NP66', 'Q69YN4', 'Q9NYL4', 'Q8NB16', 'P47928', 'Q9BUA6', 'Q8NGP0', 'P00736', 'Q9P2X3', 'Q8TBM7', 'O75290', 'P58753', 'Q12756']

In [42]:
# manually create databatch from uniprot_ids 

ds[0]

Data(x=[1311, 1024], edge_index=[2, 59], node_id=[1311], b_factor=[1311], name='P51816', num_nodes=1311, y_index=[54], y=[54])

### Test problem batch

In [94]:
from torch_geometric.data import Batch
problem_batch = Batch.from_data_list([ds[i] for i in uniprot_id_indexes])

def test_batch(
    batch = None, 
):
    
    y_sparse = batch.y 
    y_index = batch.y_index
    y_index = y_index.to(torch.long)
    y_hat = model(batch)

    # Flatten 
    y_hat = torch.flatten(y_hat)
    y_sparse = torch.flatten(y_sparse)
    y_index = torch.flatten(y_index).detach().cpu()

    # Use `y_index` to only select the values that are in the mask
    y_hat = y_hat[y_index]

    return y_hat, y_sparse

In [95]:
test_batch(problem_batch)

IndexError: index 18694 is out of bounds for dimension 0 with size 18663

In [96]:
batch = problem_batch 
y_sparse = batch.y 
y_index = batch.y_index
y_index = y_index.to(torch.long)
y_hat = model(x)



In [97]:
y_sparse.shape, y_index.shape, y_hat.shape, batch.x.shape

(torch.Size([872]),
 torch.Size([872]),
 torch.Size([18663, 1]),
 torch.Size([18663, 1024]))

In [98]:
sum([len(i) for i in batch.node_id])

18663

In [103]:
batch.ptr

tensor([    0,   707,  1040,  1134,  2857,  3210,  3538,  4140,  6001,  6365,
         6546,  7183,  7583,  7722,  8655,  9164,  9477,  9740, 10297, 10974,
        11767, 13298, 13823, 14174, 14838, 15002, 15198, 15901, 16144, 17145,
        17936, 18398, 18663])

In [99]:
batch.num_nodes

18663

In [104]:
y_hat.shape

torch.Size([18663, 1])

In [None]:
# Why the fuck are there more `y_hat` PREDICTIONS from running the model than there are actual `x` values?

In [100]:
max(batch.y_index)

tensor(18725.)

In [105]:
y_index

tensor([  143,   160,   184,   186,   623,   636,   639,   646,   648,   661,
          162,   471,   681,   692,   695,   579,    92,   615,   612,     5,
           12,   483,   535,   571,   621,   387,   552,    53,    54,   424,
          346,   243,   456,    32,    31,   545,   433,   400,   329,   390,
           42,   430,   721,   722,   725,   727,   735,   740,  1012,   823,
          824,   816,   856,   851,  1005,   961,   924,  1093,  1069,  1076,
         1105,  1097,  1084,  1149,  1152,  1155,  1187,  1205,  1208,  1211,
         1214,  1235,  1236,  1244,  1256,  1265,  1274,  1462,  1481,  1543,
         1548,  1550,  1589,  1594,  1617,  1662,  1666,  1672,  1684,  1750,
         1769,  1842,  1858,  1892,  1897,  1918,  1960,  1965,  1969,  1970,
         1973,  1974,  2033,  2051,  2053,  2112,  2130,  2145,  2161,  2166,
         2237,  1153,  1217,  1238,  1277,  1379,  1557,  1657,  1658,  1663,
         2012,  2013,  2139,  2141,  2058,  2125,  2160,  2702, 

In [102]:
x = batch.x 
batch

DataBatch(x=[18663, 1024], edge_index=[2, 7705], node_id=[32], b_factor=[18663], name=[32], num_nodes=18663, y_index=[872], y=[872], batch=[18663], ptr=[33])

In [106]:
# Flatten 
y_hat = torch.flatten(y_hat)
y_sparse = torch.flatten(y_sparse)
y_index = torch.flatten(y_index).detach().cpu()



In [108]:
y

NameError: name 'y' is not defined

In [111]:
# How many values are out of bounds ? 

for i, idx in enumerate(y_index):
    if int(idx) >= batch.num_nodes:
        print(i)

864
865
866
867
868
870
871


In [117]:
names[-1]

'Q13491'

In [116]:
y_index[871]

tensor(18724)

In [None]:
# # Use `y_index` to only select the values that are in the mask
y_hat = y_hat[y_index]

### Track down the culprit and see where the mismatch is

In [119]:
uid = "Q13491"

In [120]:
indexes_dict[uid]

{'idx': tensor([ 54.,  57.,  58., 296., 317., 319., 325., 327.,  43., 301., 326.]),
 'y': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])}

In [122]:
indexes_dict[uid]["idx"].shape, indexes_dict[uid]["y"].shape

(torch.Size([11]), torch.Size([11]))

In [124]:
from phosphosite.dataset import sequence_dict
seq = sequence_dict[uid]

In [125]:
len(seq)

265