# Create Graph Dataset

In [None]:
import sys
import os
import pickle as pkl
import pandas as pd
import torch

# Make the parent of the current working directory (for a notebook, usually
# the project root one level up) importable so `src.protein_graph` resolves.
# The path is made absolute *before* the membership test: comparing the
# relative '../.' against sys.path while appending the absolute form meant
# the guard never matched, so every re-run added a duplicate entry.
path = os.path.abspath(os.path.join('..', '.'))
if path not in sys.path:
    sys.path.append(path)

from src.protein_graph import pncaGraph

from tqdm import tqdm

In [3]:
# Sequence tables for the two splits, loaded train first, then test.
train_seqs, test_seqs = (
    pd.read_csv(csv_path)
    for csv_path in ('../data/real_train_sequences.csv',
                     '../data/real_test_sequences.csv')
)

### Create graphs and corresponding Data objects

Graphs are built from AlphaFold-predicted structures.

In [4]:
def create_graphs(structs_path, seqs, *, lig_resname='PZA', self_loops=False,
                  cutoff_distance=12):
    """Build a ``pncaGraph`` for every structure file in ``structs_path``.

    Each file name must start with an integer row index followed by an
    underscore (``"<index>_..."``). That index selects the matching row of
    ``seqs`` by position.

    Parameters
    ----------
    structs_path : str
        Directory containing the structure files.
    seqs : pandas.DataFrame
        Metadata table; the row at position ``i`` belongs to the file whose
        name starts with ``"i_"``.
    lig_resname, self_loops, cutoff_distance
        Forwarded unchanged to ``pncaGraph``. The defaults are the values that
        were previously hard-coded here.

    Returns
    -------
    dict
        ``{'pnca_mut_<index>': {'graph': <pncaGraph>, 'metadata': <DataFrame>}}``,
        where ``metadata`` is a one-row DataFrame.

    Raises
    ------
    ValueError
        If a file name has no ``'<digits>_'`` prefix.
    """
    output_dict = {}

    # os.listdir returns entries in arbitrary, filesystem-dependent order;
    # sorting makes the resulting dict order reproducible across machines.
    for f in tqdm(sorted(os.listdir(structs_path))):

        index, sep, _ = f.partition('_')
        # f[:f.find('_')] silently chopped the last character when no '_'
        # was present (find returns -1), e.g. '123' -> row 12. Fail loudly.
        if not sep or not index.isdigit():
            raise ValueError(
                f"cannot parse row index from structure file name {f!r}; "
                "expected '<int>_...'"
            )
        name = 'pnca_mut_' + index

        pnca_m = pncaGraph(
            pdb=os.path.join(structs_path, f),
            lig_resname=lig_resname,
            self_loops=self_loops,
            cutoff_distance=cutoff_distance,
        )

        # A list indexer keeps the result a one-row DataFrame, not a Series.
        metadata = seqs.iloc[[int(index)]]

        output_dict[name] = {
            'graph': pnca_m,
            'metadata': metadata,
        }

    return output_dict

In [None]:
# Directories of AlphaFold-predicted structures for each split.
train_structs, test_structs = '../pdb/train_pza', '../pdb/test_pza'

# Build one graph per structure file: training split first, then test.
train_graph_dict = create_graphs(train_structs, train_seqs)
test_graph_dict = create_graphs(test_structs, test_seqs)

In [None]:
# Generate the dataset for every graph. Each sample's one-row metadata table
# (from the sequences CSV loaded above) is passed as `sequences`, together
# with the edge-weighting settings below. Both splits use identical settings,
# so they are defined once instead of being duplicated per loop.
# NOTE(review): the output pickle name later in this notebook appears to
# encode these settings ('exp', 'l2'); keep them in sync if they change.
gen_dataset_kwargs = dict(
    edge_weights='exp',
    lambda_param=2,
    normalise=True,
)

# Same order as before: test split first, then train.
for split_dict in (test_graph_dict, train_graph_dict):
    for entry in tqdm(split_dict.values()):
        entry['graph'].gen_dataset(
            sequences=entry['metadata'],
            **gen_dataset_kwargs,
        )

In [None]:
# Sanity check: node features must contain no NaNs before scaling/training.
# An explicit raise replaces `assert ... == False`, because asserts are
# stripped under `python -O`. The error names the offending split and sample.
for split_name, split_dict in (('test', test_graph_dict), ('train', train_graph_dict)):
    for sample, entry in tqdm(split_dict.items()):
        node_features = entry['graph'].dataset[0].x
        if torch.isnan(node_features).any():
            raise ValueError(
                f"NaN values in node features of {split_name} sample {sample!r}"
            )

In [9]:
# Fit a MinMax scaler on the training node features only, so no statistics
# leak in from the test split, then pickle it so the same scaling can be
# reapplied later.
from sklearn.preprocessing import MinMaxScaler

train_features = [entry['graph'].dataset[0].x for entry in train_graph_dict.values()]
if not train_features:
    # torch.cat([]) would fail with an opaque RuntimeError; say what's wrong.
    raise ValueError("train_graph_dict is empty; nothing to fit the scaler on")

# Stack all training graphs' node-feature rows into one matrix. torch.cat
# takes the list directly, so no intermediate copy is needed. This assumes
# each x is 2-D (nodes, features), as MinMaxScaler requires; confirm.
train_feats_cat = torch.cat(train_features, dim=0)

scaler = MinMaxScaler()
scaler.fit(train_feats_cat.numpy())

with open("../data/scaler.pkl", "wb") as f:
    pkl.dump(scaler, f)

In [10]:
# Rescale node features of both splits with the train-fitted scaler. The test
# split reuses the training min/max, so its values may fall outside [0, 1].
# NOTE(review): the assignment relies on `dataset[0]` returning the stored
# object rather than a fresh copy; confirm for whatever type gen_dataset
# builds, otherwise the scaled features are silently discarded.
for split_dict in (train_graph_dict, test_graph_dict):
    for entry in split_dict.values():
        entry['graph'].dataset[0].x = torch.tensor(
            scaler.transform(entry['graph'].dataset[0].x.numpy()),
            dtype=torch.float,
        )

In [11]:
# Bundle both splits into a single object keyed by split name.
graph_dict = dict(train=train_graph_dict, test=test_graph_dict)

In [12]:
# Persist both splits (graphs + metadata) as one pickle. The output directory
# is created first, so a fresh checkout doesn't fail with FileNotFoundError.
# NOTE(review): unlike the '../data' paths used above, this writes under
# 'datasets/' relative to the working directory; confirm that is intended.
out_path = os.path.join('datasets', 'x18_exp_l2_graph_dict.pkl')
os.makedirs(os.path.dirname(out_path), exist_ok=True)

with open(out_path, 'wb') as f:
    pkl.dump(graph_dict, f)