In [1]:
import sys
import os
import pickle as pkl
import pandas as pd
import numpy as np
import torch

path = os.path.join('..', '/Users/dylandissanayake/Desktop/DPhil/Comp Disc/Repositories/TB-PNCA-GNN') if "SSH_CONNECTION" not in os.environ else os.path.join('..', '/mnt/alphafold-volume-1/dylan2/repos/tb-pnca-gnn')
if path not in sys.path:
    sys.path.append(os.path.abspath(path))

from src.protein_graph import pncaGraph
from src.run_model import pnca_GCN_vary_graph

from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

%aimport src

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_seqs = pd.read_csv('../data/real_train_sequences.csv')
test_seqs = pd.read_csv('../data/real_test_sequences.csv')

### Create graphs and corresponding Data objects

In [5]:
def create_graphs(structs_path, seqs):
    
    output_dict = {}
    
    for f in tqdm(os.listdir(structs_path)):
        
        index = f[:f.find('_')]
        name = 'pnca_mut_' + index

        pnca_m = pncaGraph(
                        pdb=f'../pdb/3PL1-PZA.pdb',
                        lig_resname='PZA', 
                        self_loops=False,
                        cutoff_distance=12)
        
        metadata = seqs.iloc[[int(index)]]
        # display(metadata)

        output_dict[name] = {
            'graph':pnca_m, 
            'metadata':metadata
            }
        
    return output_dict

In [6]:
test_structs = '../pdb/test_pza'
train_structs = '../pdb/train_pza'

# test_structs = '../pdb/test_structures_w_pza'
# train_structs = '../pdb/train_structures_w_pza'

train_graph_dict = create_graphs(train_structs, train_seqs)
test_graph_dict = create_graphs(test_structs, test_seqs)

  0%|          | 0/464 [00:00<?, ?it/s]

100%|██████████| 464/464 [00:30<00:00, 15.32it/s]
100%|██████████| 200/200 [00:13<00:00, 14.87it/s]


In [8]:
train_seqs

Unnamed: 0,phenotype_label,number_resistant_mutations,number_susceptible_mutations,allele,mutation
0,S,0,1,MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVA...,A102V
1,S,0,1,MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVA...,A134D
2,R,1,0,MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVA...,A134P
3,S,0,1,MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVA...,A134S
4,R,1,0,MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVA...,A134V
...,...,...,...,...,...
459,S,0,1,MRALIIVDVQNDFCEGGSLAVTGGAALARAISDHLAEAADYHHVVA...,Y34H
460,S,0,1,MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVA...,Y64N
461,S,0,1,MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVA...,Y95H
462,S,0,1,MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVA...,Y99C


In [12]:
# get node features for wildtype to fill nans

wt_seq = 'MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVATKDFHIDPGDHFSGTPDYSSSWPPHCVSGTPGADFHPSLDTSAIEAVFYKGAYTGAYSGFEGVDENGTPLLNWLRQRGVDEVDVVGIATDHCVRQTAEDAVRNGLATRVLVDLTAGVSADTTVAALEEMRTASVELVCS'
wt_df = pd.DataFrame([{
    'phenotype_label' : 'S',
    'number_resistant_mutations' : 0,
    'number_susceptible_mutations' : 0,
    'allele' : wt_seq
}])


pnca = pncaGraph(
                        pdb=f'../pdb/3PL1-PZA.pdb',
                        lig_resname='PZA', 
                        self_loops=False,
                        cutoff_distance=12)

pnca.gen_dataset(
        sequences= wt_df,
        edge_weights= 'exp',
        lambda_param=2,
        normalise=True
        )

In [15]:
torch.isnan(pnca.dataset[0].x).any()

tensor(False)

In [16]:
# attach sequence from fasta to each key in dictionary
# iterate through dict and creat one sample with gen_dataset per graph
# assign dataset object to dictionary
# will need to fix how to do CLAV dist


for sample in tqdm(test_graph_dict):
    test_graph_dict[sample]['graph'].gen_dataset(
        sequences= test_graph_dict[sample]['metadata'],
        edge_weights= 'exp',
        lambda_param=2,
        normalise=True
        )

for sample in tqdm(train_graph_dict):
    
    train_graph_dict[sample]['graph'].gen_dataset(
        sequences= train_graph_dict[sample]['metadata'],
        edge_weights= 'exp',
        lambda_param=2,
        normalise=True
        )

  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [02:36<00:00,  1.28it/s]
100%|██████████| 464/464 [06:00<00:00,  1.29it/s]


In [18]:
# replace nans with wt features 
for sample in tqdm(test_graph_dict):
    
    nan_mask = torch.isnan(test_graph_dict[sample]['graph'].dataset[0].x)
    test_graph_dict[sample]['graph'].dataset[0].x[nan_mask] = pnca.dataset[0].x[nan_mask]
    
    assert torch.isnan(test_graph_dict[sample]['graph'].dataset[0].x).any() == False
    
for sample in tqdm(train_graph_dict):
    
    nan_mask = torch.isnan(train_graph_dict[sample]['graph'].dataset[0].x)
    train_graph_dict[sample]['graph'].dataset[0].x[nan_mask] = pnca.dataset[0].x[nan_mask]
    
    assert torch.isnan(train_graph_dict[sample]['graph'].dataset[0].x).any() == False

  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:00<00:00, 27878.39it/s]
100%|██████████| 464/464 [00:00<00:00, 30453.91it/s]


In [19]:
# check it's generated the dataset properly

for sample in test_graph_dict:
    print(sample)
    print(test_graph_dict[sample]['graph'].dataset[0])
    print(test_graph_dict[sample])
    break

pnca_mut_99
Data(x=[185, 18], edge_index=[2, 5264], edge_attr=[5264, 1], y=1)
{'graph': <src.protein_graph.pncaGraph object at 0x2f8a14380>, 'metadata':    phenotype_label  number_resistant_mutations  number_susceptible_mutations  \
99               R                           1                             0   

                                               allele mutation  
99  MRALIIVDVQNDFCEGGSLAVTGGAALARAISDYLAEAADYHHVVA...     I90T  }


In [25]:
# Create MinMax scaler fit only on training data
from sklearn.preprocessing import MinMaxScaler

train_features = [train_graph_dict[sample]['graph'].dataset[0].x for sample in train_graph_dict]
train_feats_cat = torch.cat([x for x in train_features], dim=0)

# fit scaler and save
scaler = MinMaxScaler()
scaler.fit(train_feats_cat.numpy())

# with open("datasets/scaler_x18.pkl", "wb") as f:
#     pkl.dump(scaler, f)

In [26]:
# apply scaler to train and test features

for sample in train_graph_dict:
    train_graph_dict[sample]['graph'].dataset[0].x = torch.tensor(
        scaler.transform(train_graph_dict[sample]['graph'].dataset[0].x.numpy()), 
        dtype=torch.float
    )
    
for sample in test_graph_dict:
    test_graph_dict[sample]['graph'].dataset[0].x = torch.tensor(
        scaler.transform(test_graph_dict[sample]['graph'].dataset[0].x.numpy()), 
        dtype=torch.float
    )

In [29]:
torch.set_printoptions(precision=10)

In [31]:
train_graph_dict['pnca_mut_0']['graph'].dataset[0].x[101]

tensor([0.4764460027, 0.0869565308, 0.3253291845, 0.3992491066, 0.9666666389,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0961175859, 0.1682148278,
        0.8969879150, 0.0543248914, 0.0936825424, 0.0180733204, 0.8917525411,
        0.5736976862, 0.2387447655, 0.1145175770])

In [32]:
graph_dict = {
    'train': train_graph_dict,
    'test': test_graph_dict
}

In [33]:
# save graph_dict as pickle  

with open('datasets/no_af_x18_exp_l2_graph_dict.pkl', 'wb') as f:
    pkl.dump(graph_dict, f)

In [None]:
# compare to dataset with AF§

with open('datasets/x18_exp_l2_graph_dict.pkl', 'rb') as f:
    a = pkl.load(f)

In [36]:
a['train']['pnca_mut_0']['graph'].dataset[0].x[101]

tensor([0.4764460027, 0.0869565308, 0.3253291845, 0.3992491066, 0.9666666389,
        0.0000000000, 0.0000000000, 0.0000000000, 0.1151097417, 0.1654758751,
        0.9462261200, 0.0597532801, 0.0992178172, 0.0592027754, 0.8917525411,
        0.5736976862, 0.2387447655, 0.1145175770])