In [None]:
## Standard libraries
import os
import json
import math
import numpy as np
import time
import sys

## Imports for plotting
#import matplotlib.pyplot as plt

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import pytorch_lightning as pl
import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data
from torch.nn import Linear, BatchNorm1d, ModuleList
from torch_geometric.nn import GATConv, GraphConv, TopKPooling, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.loader import DataLoader
from sklearn.metrics import confusion_matrix, f1_score, \
    accuracy_score, precision_score, recall_score, roc_auc_score
import mlflow.pytorch


## RDKit
import rdkit
from rdkit import Chem

In [None]:
class ProteinDataset(geom_data.Dataset):
    '''
    PyG dataset that turns PDB files into molecular graphs.

    Layout: raw inputs are every file in ``<root>/raw``; one
    ``<root>/processed/<stem>.pt`` file is written per raw file RDKit can
    parse. Nodes are atoms (feature: atomic mass), edges are the non-zero
    entries of RDKit's adjacency matrix (feature: 3D distance between the two
    atoms), and the graph label is the float on the first line of the raw file.

    NOTE: raw files RDKit cannot parse are DELETED from ``raw/`` during
    processing, so afterwards ``len()`` counts only the remaining files.
    NOTE: processing is skipped when all processed files already exist;
    delete ``processed/`` to regenerate graphs after changing a featurizer.
    '''

    def __init__(self, root, transform=None, pre_transform=None):
        # root = directory holding the raw/ and processed/ sub-folders.
        # The base class already sets self.root before it triggers processing
        # (raw_file_names relies on it), so it is not reassigned here.
        super(ProteinDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        '''Every file in raw/, sorted so dataset indices are stable across runs.'''
        return sorted(os.listdir(self.raw_dir))

    @property
    def processed_file_names(self):
        '''
        If ALL of these files are found in processed_dir, processing is skipped.
        If one or more are NOT found, all files in raw_dir are processed again.
        '''
        return [f'{self._file_stem(path)}.pt' for path in self.raw_paths]

    @staticmethod
    def _file_stem(path):
        '''Processed-file stem for a raw path: its basename cut at the first ".p"
        (e.g. "1abc.pdb" -> "1abc"); matches the names already used in processed/.'''
        return os.path.basename(path).split('.p')[0]

    def download(self):
        # Raw files are expected to be placed in raw/ manually.
        pass

    def process(self):
        '''Convert each raw PDB file into a Data object saved as processed/<stem>.pt.'''
        for pdb in self.raw_paths:
            # MolFromPDBFile reports a parse failure by returning None (it does
            # not raise), so check for None explicitly.
            mol_obj = Chem.rdmolfiles.MolFromPDBFile(pdb)
            if mol_obj is None:
                os.remove(pdb)
                continue

            node_feats = self._get_node_features(mol_obj)
            if isinstance(node_feats, str):  # 'NaN' sentinel from the featurizer
                os.remove(pdb)
                continue

            # Edge index and edge features share the same edge ordering.
            edge_index = self._get_adjacency_info(mol_obj)
            edge_feats = self._get_edge_features(mol_obj)
            label = self._get_labels(pdb)

            graph = geom_data.Data(x=node_feats,
                                   edge_index=edge_index,
                                   edge_attr=edge_feats,
                                   y=label)
            if self.pre_transform is not None:
                graph = self.pre_transform(graph)

            torch.save(graph, os.path.join(self.processed_dir,
                                           f'{self._file_stem(pdb)}.pt'))

    def _get_node_features(self, mol):
        '''
        Return a float tensor of shape [num_atoms, 1] holding each atom's mass.

        Returns the string 'NaN' when ``mol`` has no ``GetAtoms`` (e.g. None);
        process() treats that as "skip this file".
        '''
        try:
            masses = [atom.GetMass() for atom in mol.GetAtoms()]
        except AttributeError:
            return 'NaN'
        return torch.tensor(np.asarray(masses), dtype=torch.float).reshape([-1, 1])

    @staticmethod
    def _bonded_pairs(mol):
        '''Row/column index arrays of the non-zero entries of the RDKit adjacency matrix.

        np.nonzero scans row-major, so the order is deterministic and shared by
        _get_adjacency_info and _get_edge_features. A symmetric adjacency
        matrix yields each bond i-j as both (i, j) and (j, i).
        '''
        adj_matrix = Chem.rdmolops.GetAdjacencyMatrix(mol)
        return np.nonzero(adj_matrix)

    def _get_edge_features(self, mol):
        '''
        Return a float tensor of shape [num_edges, 1]: the 3D distance between
        the two atoms of every edge, aligned row-for-column with the
        edge_index produced by _get_adjacency_info.
        '''
        row, col = self._bonded_pairs(mol)
        dists = Chem.rdmolops.Get3DDistanceMatrix(mol)
        return torch.tensor(dists[row, col], dtype=torch.float).reshape([-1, 1])

    def _get_adjacency_info(self, mol):
        '''
        Return edge_index as a long tensor of shape [2, num_edges].

        The index arrays are stacked, not reshaped: reshaping an [E, 2] pair
        list into [2, E] interleaves source/target ids and yields edges that
        do not exist in the molecule.
        '''
        row, col = self._bonded_pairs(mol)
        return torch.tensor(np.stack([row, col]), dtype=torch.long)

    def _get_labels(self, fn):
        '''Return the graph label: the first line of ``fn`` parsed as a float, as a [1] tensor.'''
        with open(fn, 'r') as f:
            label = float(f.readline())
        return torch.tensor(np.asarray([label]), dtype=torch.float)

    def len(self):
        # One graph per (remaining) raw file.
        return len(self.raw_paths)

    def get(self, inx):
        '''
        Same as '__getitem__'
        (Not needed for PyG's InMemoryDataset class)
        '''
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True, which may refuse to unpickle Data objects —
        # confirm against the installed version.
        return torch.load(self.processed_paths[inx])

In [None]:
class GNN(torch.nn.Module):
    '''
    Graph-level regressor: three (GATConv -> Linear -> TopKPooling) blocks
    followed by a 3-layer MLP that outputs one value per graph.

    After each block the surviving node embeddings are read out with global
    max- and mean-pooling ([num_graphs, 2 * embedding_size]). The three
    read-outs are summed element-wise and fed to the MLP.

    Args:
        feature_size: number of input node features.
        embedding_size: node embedding width after each block.
        heads: attention heads per GATConv. Their outputs are concatenated,
            then projected back to ``embedding_size`` by ``head1..3``.
        dropout: attention dropout passed to each GATConv.
        pool_ratios: TopKPooling keep-ratios for the three blocks.
        hidden_size: width of the first MLP layer.
    '''

    def __init__(self, feature_size, embedding_size=1024, heads=3, dropout=0.3,
                 pool_ratios=(0.8, 0.5, 0.2), hidden_size=1024):
        super(GNN, self).__init__()
        ratio1, ratio2, ratio3 = pool_ratios

        # GNN layers (attribute names kept so existing state_dicts still load)
        self.conv1 = GATConv(feature_size, embedding_size, heads=heads, dropout=dropout)
        self.head1 = Linear(embedding_size * heads, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=ratio1)

        self.conv2 = GATConv(embedding_size, embedding_size, heads=heads, dropout=dropout)
        self.head2 = Linear(embedding_size * heads, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=ratio2)

        self.conv3 = GATConv(embedding_size, embedding_size, heads=heads, dropout=dropout)
        self.head3 = Linear(embedding_size * heads, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=ratio3)

        # Linear layers: input width is one read-out (max-pool ++ mean-pool)
        self.fc1 = Linear(embedding_size * 2, hidden_size)
        self.fc2 = Linear(hidden_size, 128)
        self.fc3 = Linear(128, 1)

    def _conv_pool_readout(self, conv, head, pool, x, edge_index, batch_index):
        '''One block: attention conv + ReLU, linear projection, TopK pooling,
        then the per-graph [max ++ mean] read-out of the pooled nodes.'''
        x = head(conv(x, edge_index).relu())
        x, edge_index, _, batch_index, _, _ = pool(x, edge_index, None, batch_index)
        readout = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)
        return x, edge_index, batch_index, readout

    def forward(self, x, edge_attr, edge_index, batch_index):
        '''
        Args:
            x: node features, [num_nodes, feature_size].
            edge_attr: accepted for interface compatibility but unused; the
                convolutions and pooling are called without edge features.
            edge_index: [2, num_edges] connectivity.
            batch_index: graph id of every node.

        Returns:
            Tensor of shape [num_graphs, 1].
        '''
        x, edge_index, batch_index, x1 = self._conv_pool_readout(
            self.conv1, self.head1, self.pool1, x, edge_index, batch_index)
        x, edge_index, batch_index, x2 = self._conv_pool_readout(
            self.conv2, self.head2, self.pool2, x, edge_index, batch_index)
        _, _, _, x3 = self._conv_pool_readout(
            self.conv3, self.head3, self.pool3, x, edge_index, batch_index)

        # Sum (not concatenate) the read-outs: each is
        # [num_graphs, 2 * embedding_size], which is fc1's input width.
        x = x1 + x2 + x3

        x = self.fc1(x).relu()
        x = self.fc2(x).relu()
        return self.fc3(x)

    def initialize_weights(self, m):
        '''Fill ``m.weight`` with U(0, 1) samples (torch.nn.init.uniform_
        defaults) when ``m`` has a weight of rank > 1. Presumably meant for
        ``model.apply(...)``. NOTE(review): no call is visible in this notebook.'''
        if hasattr(m, 'weight') and m.weight.dim() > 1:
            nn.init.uniform_(m.weight.data)

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def count_parameters(model):
    """Return the total number of scalar entries in `model`'s trainable
    (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [None]:
#%% Load the train/test splits (raw PDBs are processed if processed/ files are missing)
train_dataset = ProteinDataset(root="data/lys50_2/train")
test_dataset = ProteinDataset(root="data/lys50_2/test")

#%% Build the model; its input width comes from the first training graph's node features
model = GNN(feature_size=train_dataset[0].x.shape[1]).to(device)
print(f"Number of parameters: {count_parameters(model)}")
model

In [None]:
# train()/test() report the square root of this MSE (i.e. RMSE).
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=2.5e-6)
# NOTE(review): the training loop in this notebook never calls
# scheduler.step(), so this exponential decay currently has no effect.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

In [None]:
#%% Prepare training
NUM_GRAPHS_PER_BATCH = 10
# Shuffle only the training split. Evaluation gains nothing from shuffling,
# and a fixed order keeps test passes reproducible across epochs.
train_loader = DataLoader(train_dataset,
                          batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
test_loader = DataLoader(test_dataset,
                         batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

In [None]:
def train(epoch, loader=None, net=None, opt=None, criterion=None, dev=None):
    """Run one optimisation pass over the training data.

    Each batch is optimised on its own RMSE (sqrt of `criterion`).

    Args:
        epoch: current epoch index (kept for metric logging; currently unused).
        loader, net, opt, criterion, dev: optional overrides. Each defaults to
            the notebook globals `train_loader`, `model`, `optimizer`,
            `loss_fn` and `device` respectively.

    Returns:
        0-d float tensor: RMSE over every training graph seen this epoch
        (predictions taken before each batch's optimiser step), or NaN if the
        loader yields no batches.
    """
    loader = train_loader if loader is None else loader
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    criterion = loss_fn if criterion is None else criterion
    dev = device if dev is None else dev

    all_preds = []
    all_labels = []
    for batch in loader:
        batch = batch.to(dev)
        opt.zero_grad()
        pred = net(batch.x.float(),
                   batch.edge_attr.float(),
                   batch.edge_index,
                   batch.batch)
        # pred is [num_graphs, 1] while y is [num_graphs]. Flatten both so the
        # loss compares them element-wise instead of broadcasting to
        # [num_graphs, num_graphs].
        pred_flat = pred.view(-1)
        target = batch.y.view(-1)
        loss = torch.sqrt(criterion(pred_flat, target))
        loss.backward()
        opt.step()

        # Regression output: keep the raw values (argmax over one column is always 0).
        all_preds.append(pred_flat.detach().cpu().numpy())
        all_labels.append(target.detach().cpu().numpy())

    if not all_preds:
        return torch.tensor(float('nan'))
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    # calculate_metrics (ROC-AUC) is disabled: these targets are continuous values.
    #calculate_metrics(all_preds, all_labels, epoch, "train")
    return torch.tensor(float(np.sqrt(np.mean((all_preds - all_labels) ** 2))))

In [None]:
def test(epoch):
    all_preds = []
    all_labels = []
    for batch in test_loader:
        batch.to(device)
        pred = model(batch.x.float(),
                        batch.edge_attr.float(),
                        batch.edge_index,
                        batch.batch)
        loss = torch.sqrt(loss_fn(pred, batch.y))
        all_preds.append(np.argmax(pred.cpu().detach().numpy(), axis=1))
        all_labels.append(batch.y.cpu().detach().numpy())

    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    #calculate_metrics(all_preds, all_labels, epoch, "test")
    return loss

In [None]:
def calculate_metrics(y_pred, y_true, epoch, type):
    """Print ROC-AUC and log it to MLflow under the key ``ROC-AUC-<type>``.

    Args:
        y_pred: predicted scores (array-like).
        y_true: ground-truth binary labels (array-like).
        epoch: step index passed to MLflow.
        type: split name used in the metric key (e.g. "train" / "test").

    When sklearn cannot compute ROC-AUC (it raises ValueError, e.g. when only
    one class is present or the targets are continuous), 0.0 is logged
    instead. Other errors, such as MLflow failures, propagate.
    """
    try:
        # sklearn's signature is roc_auc_score(y_true, y_score); order matters.
        roc = roc_auc_score(y_true, y_pred)
    except ValueError:
        mlflow.log_metric(key=f"ROC-AUC-{type}", value=float(0), step=epoch)
        print(f"ROC AUC: notdefined")
        return
    print(f"ROC AUC: {roc}")
    mlflow.log_metric(key=f"ROC-AUC-{type}", value=float(roc), step=epoch)

In [None]:
# Training loop: 500 epochs, evaluating on the test split every 10th epoch.
for epoch in range(500):
    model.train()
    loss = train(epoch=epoch).detach().cpu().numpy()
    if epoch % 10 != 0:
        print(epoch, loss)
        continue
    model.eval()
    testloss = test(epoch=epoch).detach().cpu().numpy()
    print(epoch, loss, testloss)

In [None]:
def main():
    """Build (or load) the training split and print every graph's label tensor."""
    train_split = ProteinDataset(root='data/lys50_2/train')
    for graph in train_split:
        print(graph.y)


if __name__ == '__main__':
    main()