Necessary imports
dataset: Represents mutations
model: GNN


In [None]:
%load_ext autoreload
%autoreload 2

In [1]:
import sys
sys.path.append("./src")


from src.make_dataset import make_dataset
from src.dataset import MutationDataset
from src.model import ProBindNN
from src.train import train
from src.visualize import comparator


from torch_geometric.loader import DataLoader
import torch
from torch import nn
from torch.optim.lr_scheduler import ExponentialLR


import copy
import os
import time
from datetime import datetime



  from .autonotebook import tqdm as notebook_tqdm
To use the Graphein submodule graphein.protein.features.sequence.embeddings, you need to install: biovec 
biovec cannot be installed via conda
To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d 
To do so, use the following command: conda install -c pytorch3d pytorch3d


Make the dataset if needed

In [None]:
make_dataset(index_xlsx="../index.xlsx", root="../dataset")

Dataset/dataloaders

In [2]:
dataset = MutationDataset(index_xlsx="index.xlsx", root="dataset12aa")
train_size = int(len(dataset)*0.9)
val_size = len(dataset)-train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=True)
loaders = {"val_loader": val_loader, "train_loader":train_loader}

In [3]:
print("Length test dataset: ", len(val_dataset))
print("Length train dataset: ", len(train_dataset))
print("Take a look at the data: ", dataset[0])

Length test dataset:  495
Length train dataset:  4448
Take a look at the data:  {'mutated': Data(x=[71, 11], edge_index=[2, 109], edge_weights=[109], ddg=-1.47152859), 'non_mutated': Data(x=[71, 11], edge_index=[2, 109], edge_weights=[109])}


In [4]:
#CUDA/CPU Training
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

Using cuda device


In [7]:

model = ProBindNN().to(device)

#Uncomment next line if you have  a pretrained model
model.load_state_dict(torch.load("models/model_2022_09_02_14_18_22.pt"))


<All keys matched successfully>

Optimizer and Loss Function and Scheduler

In [None]:
optimizer = torch.optim.Adam(model.parameters())
scheduler = ExponentialLR(optimizer, gamma=0.9)
loss_fn =  nn.MSELoss()

In [None]:
best_model, path = train(model, loaders, optimizer, loss_fn, scheduler, n_epochs=1500)

In [10]:
#change path to point to your model if available

#model.load_state_dict(torch.load(path))
 
#identifier for saved datapoints

t = time.time()
stamp = datetime.utcfromtimestamp(t).strftime('%Y_%m_%d_%H_%M_%S')

N = len(val_dataset)

comparator(model,train_dataset, N, "_data/predictions/{}_12aa.csv".format(stamp))

 10%|â–‰         | 48/495 [00:01<00:15, 28.04it/s]

If you want to take a look at the raw predictions_

In [8]:
model.eval()
d = next(iter(train_loader))
model(d["mutated"].cuda(), d["non_mutated"].cuda()).squeeze()

tensor(3.2161, device='cuda:0', grad_fn=<SqueezeBackward0>)

In [None]:
from graphein.protein.graphs import construct_graph
from graphein.ml.conversion import GraphFormatConvertor
from graphein.protein.edges.atomic import add_atomic_edges, add_bond_order, add_ring_status
from graphein.protein.edges.distance import add_hydrogen_bond_interactions, add_ionic_interactions, add_peptide_bonds
from graphein.protein.visualisation import plotly_protein_structure_graph
from graphein.protein.config import ProteinGraphConfig

params_to_change = {"granularity": "atom", "edge_construction_functions": [add_atomic_edges, add_bond_order, add_hydrogen_bond_interactions, add_ionic_interactions, add_peptide_bonds]}
config = ProteinGraphConfig(**params_to_change)
graph_mutated = construct_graph(config=config,pdb_path="datasetmf/raw/temp/1_mutated_interface.pdb")


p = plotly_protein_structure_graph(
    graph_mutated,
    colour_edges_by="kind",
    colour_nodes_by="element_symbol",
    label_node_ids=False,
    node_size_min=5,
    node_alpha=0.85,
    node_size_multiplier=1,
    plot_title="Atom-level graph. Nodes coloured by their Element"
    )
p.show()