In [26]:
#data manipulation
import math
import pandas as pd
import numpy as np

#Pytorch geometric
import torch
from torch import Tensor   
import torch_geometric
from torch_geometric.data import Data
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, MessagePassing
from torch.nn import Linear, CrossEntropyLoss, Sequential, ReLU, BCELoss 
import torch_geometric.transforms as T
import torch.nn.functional as F

from torch_geometric.nn import global_mean_pool as gap,  global_max_pool as gmp, global_add_pool as gsp


#rdkit
from rdkit import Chem                      
from rdkit.Chem import GetAdjacencyMatrix       
from scipy.sparse import coo_matrix
from rdkit.Chem import AllChem
from rdkit import Chem, DataStructs

#matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

#chemprop
import chemprop
from chemprop.args import TrainArgs, PredictArgs
from chemprop.train import cross_validate, run_training, make_predictions

#sklearn
import sklearn
from sklearn.metrics import classification_report, roc_auc_score

#shuffle
from random import shuffle

#for word embeddings
import re
import gensim
from torchtext.vocab import build_vocab_from_iterator

from torch.nn.utils.rnn import pad_sequence

#import the pretrained word-embeddings
from gensim.models import KeyedVectors
wv = KeyedVectors.load("wordvectors.kv", mmap='r')

#GPU
import gc

## Utility Functions

### Chem

In [27]:
def onek_encoding_unk(value, choices):
    """
    One-hot encodes a value, reserving one trailing slot for values outside the known choices.

    :param value: The value to encode.
    :param choices: A list of the recognised values.
    :return: A list of length :code:`len(choices) + 1` holding a single 1 and zeros elsewhere.
             The 1 sits at the position of :code:`value` within :code:`choices`; when
             :code:`value` is not among the choices, the final element is set to 1 instead.
    """
    n_known = len(choices)

    # Unknown values share the extra slot appended after the known choices.
    if value in choices:
        hot_position = choices.index(value)
    else:
        hot_position = n_known

    return [1 if position == hot_position else 0 for position in range(n_known + 1)]

In [28]:
class Featurization_parameters:
    """
    Bundles the constants used when turning molecules into atom/bond feature vectors.

    Every setting is exposed as an instance attribute; the derived sizes
    (:code:`ATOM_FDIM`, :code:`THREE_D_DISTANCE_BINS`) are computed once in the constructor.
    """
    def __init__(self) -> None:
        hybridization = Chem.rdchem.HybridizationType

        # ---- Atom features -------------------------------------------------
        self.MAX_ATOMIC_NUM = 100
        # Known categories for each one-hot encoded atom property.
        self.ATOM_FEATURES = {
            'atomic_num': list(range(self.MAX_ATOMIC_NUM)),
            'degree': [0, 1, 2, 3, 4, 5],
            'formal_charge': [-1, -2, 1, 2, 0],
            'chiral_tag': [0, 1, 2, 3],
            'num_Hs': [0, 1, 2, 3, 4],
            'hybridization': [
                hybridization.SP,
                hybridization.SP2,
                hybridization.SP3,
                hybridization.SP3D,
                hybridization.SP3D2,
            ],
        }

        # ---- Distance features ---------------------------------------------
        self.PATH_DISTANCE_BINS = list(range(10))
        self.THREE_D_DISTANCE_MAX = 20
        self.THREE_D_DISTANCE_STEP = 1
        self.THREE_D_DISTANCE_BINS = list(
            range(0, self.THREE_D_DISTANCE_MAX + 1, self.THREE_D_DISTANCE_STEP)
        )

        # ---- Feature vector sizes ------------------------------------------
        # Each categorical property contributes its choices plus one "unknown" slot;
        # the two extra entries are the aromaticity flag and the scaled atomic mass.
        n_choices = sum(map(len, self.ATOM_FEATURES.values()))
        n_unknown_slots = len(self.ATOM_FEATURES)
        self.ATOM_FDIM = n_choices + n_unknown_slots + 2
        self.EXTRA_ATOM_FDIM = 0
        self.BOND_FDIM = 14
        self.EXTRA_BOND_FDIM = 0

        # ---- Mode flags ----------------------------------------------------
        self.REACTION_MODE = None
        self.EXPLICIT_H = False
        self.REACTION = False

In [29]:
# Module-level featurization settings read by atom_features and bond_features.
PARAMS = Featurization_parameters()

In [30]:
def atom_features(atom: Chem.rdchem.Atom, functional_groups=None):
    """
    Assembles the feature vector describing a single atom.

    :param atom: An RDKit atom, or None to request an all-zero placeholder vector.
    :param functional_groups: Optional k-hot vector of functional-group memberships; when given
                              (and :code:`atom` is not None) it is appended to the features.
    :return: A list with the atom features: six one-hot encoded properties followed by an
             aromaticity flag and the scaled atomic mass.
    """
    # Placeholder atom: zeros of the standard width, no functional groups appended.
    if atom is None:
        return [0] * PARAMS.ATOM_FDIM

    choice_table = PARAMS.ATOM_FEATURES
    categorical = (
        (atom.GetAtomicNum() - 1, 'atomic_num'),  # shifted so atomic number 1 maps to choice 0
        (atom.GetTotalDegree(), 'degree'),
        (atom.GetFormalCharge(), 'formal_charge'),
        (int(atom.GetChiralTag()), 'chiral_tag'),
        (int(atom.GetTotalNumHs()), 'num_Hs'),
        (int(atom.GetHybridization()), 'hybridization'),
    )

    features = []
    for observed, key in categorical:
        features.extend(onek_encoding_unk(observed, choice_table[key]))

    features.append(1 if atom.GetIsAromatic() else 0)
    features.append(atom.GetMass() * 0.01)  # scaled to about the same range as other features

    if functional_groups is not None:
        features += functional_groups
    return features

In [31]:
def bond_features(bond: Chem.rdchem.Bond):
    """
    Assembles the feature vector describing a single bond.

    :param bond: An RDKit bond, or None for a "no bond" placeholder.
    :return: A list of bond features. The first entry is 1 only for the None placeholder;
             otherwise it is 0 and is followed by bond-type flags, conjugation and ring
             membership flags, and a one-hot encoding of the stereo configuration.
    """
    # Placeholder: only the leading "bond is None" indicator is set.
    if bond is None:
        return [1] + [0] * (PARAMS.BOND_FDIM - 1)

    bond_kinds = Chem.rdchem.BondType
    bond_type = bond.GetBondType()
    type_known = bond_type is not None

    fbond = [0]  # bond is not None
    fbond.extend(
        bond_type == kind
        for kind in (bond_kinds.SINGLE, bond_kinds.DOUBLE, bond_kinds.TRIPLE, bond_kinds.AROMATIC)
    )
    fbond.append(bond.GetIsConjugated() if type_known else 0)
    fbond.append(bond.IsInRing() if type_known else 0)

    stereo_one_hot = onek_encoding_unk(int(bond.GetStereo()), list(range(6)))
    return fbond + stereo_one_hot

In [32]:
# Default settings for the Morgan fingerprint used as the molecule-level feature.
MORGAN_RADIUS = 2
MORGAN_NUM_BITS = 2048
# With these defaults each molecule is represented by a binary vector of 2048 entries.

In [33]:
def morgan_binary_features_generator(mol,
                                     radius: int = MORGAN_RADIUS,
                                     num_bits: int = MORGAN_NUM_BITS):
    """
    Generates a binary Morgan fingerprint for a molecule.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :param radius: Morgan fingerprint radius.
    :param num_bits: Number of bits in Morgan fingerprint.
    :return: A 1D numpy array containing the binary Morgan fingerprint.
    :raises ValueError: If :code:`mol` is None or is a SMILES string that RDKit cannot parse.
    """
    source = mol
    if isinstance(mol, str):
        mol = Chem.MolFromSmiles(mol)

    # RDKit signals an unparsable SMILES by returning None; fail here with a readable
    # message instead of letting the fingerprint call raise an opaque argument error.
    if mol is None:
        raise ValueError(f'Cannot build a Morgan fingerprint from {source!r}: not a valid molecule.')

    features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=num_bits)
    # ConvertToNumpyArray fills (and resizes) the destination array in place.
    features = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(features_vec, features)

    return features


### Data processing 

In [34]:
def data_process(dataset, batch_size, shuffle=True):
    """
    Converts a table of molecules into a PyTorch Geometric DataLoader of molecular graphs.

    Each row becomes one :code:`Data` object holding atom features (:code:`x`), directed
    edges in both directions for every bond (:code:`edge_index`, :code:`edge_attr`), the
    label (:code:`y`), the original SMILES string (:code:`smiles`) and a Morgan
    fingerprint (:code:`mol_feature`). Hydrogens are added explicitly before featurization.

    :param dataset: A DataFrame with a 'SMILES' column and an 'Activity' column.
    :param batch_size: Number of graphs per batch.
    :param shuffle: Whether the loader reshuffles the graphs every epoch (default True,
                    matching the previous behaviour). Pass False for evaluation loaders.
    :return: A DataLoader over the processed graphs.
    :raises ValueError: If a SMILES string cannot be parsed by RDKit.
    """
    data_list = []

    # Walk the rows pairwise so each molecule gets the label from its own row. Looking the
    # label up by SMILES value is quadratic and breaks when a SMILES appears more than once.
    for smiles, activity in zip(dataset['SMILES'], dataset['Activity']):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Could not parse SMILES: {smiles!r}')
        mol = Chem.AddHs(mol)

        # Global (molecule-level) feature: binary Morgan fingerprint.
        mol_feature = torch.tensor(np.array(morgan_binary_features_generator(mol)))

        # Node features: one row per atom.
        x = torch.tensor(np.array([atom_features(atom) for atom in mol.GetAtoms()]))

        # Edges: every bond is stored twice (i->j and j->i) with identical attributes.
        edge_indices, edge_attrs = [], []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            e = bond_features(bond)

            edge_indices += [[i, j], [j, i]]
            edge_attrs += [e, e]

        # view(2, -1) keeps the (2, num_edges) layout even for a molecule without bonds.
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous().view(2, -1)
        edge_attr = torch.tensor(edge_attrs).view(-1, PARAMS.BOND_FDIM)

        y = torch.tensor(int(activity))  # response variable y

        # smiles and mol_feature travel with the graph as extra attributes.
        data = Data(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr,
                    smiles=smiles, mol_feature=mol_feature)
        data_list.append(data)

    return DataLoader(data_list, batch_size, shuffle=shuffle)

### Training

In [35]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf rather than the np.Inf alias, which was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """
        Records one validation result and updates the stopping state.

        Args:
            val_loss (float): Validation loss of the current epoch (lower is better).
            model: Object exposing ``state_dict()``; it is checkpointed on improvement.

        Sets ``self.early_stop`` to True once ``patience`` consecutive calls have failed
        to improve on the best score by at least ``delta``.
        """
        # Work with the negated loss so that "higher is better".
        score = -val_loss

        if self.best_score is None:
            # First call: always becomes the baseline.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # No sufficient improvement this epoch.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [36]:
def reset_weights(m):
    '''
    Re-initialise the direct children of a module so that no trained
    weights leak from one run into the next.

    Only children that expose a ``reset_parameters`` method are touched;
    each reset is announced on stdout.
    '''
    resettable = (child for child in m.children() if hasattr(child, 'reset_parameters'))
    for child in resettable:
        print(f'Reset trainable parameters of layer = {child}')
        child.reset_parameters()

In [37]:
def print_metrics(metrics):
    """
    Prints the mean and standard deviation of each evaluation metric across runs.

    :param metrics: An iterable of entries where entry[0] is the AUC and entry[1] is a
                    classification-report dict containing 'accuracy' and a 'weighted avg'
                    dict with 'precision', 'recall' and 'f1-score'.
    """
    auc_values, accuracy_values = [], []
    precision_values, recall_values, f1_values = [], [], []

    for entry in metrics:
        report = entry[1]
        auc_values.append(entry[0])
        precision_values.append(report['weighted avg']['precision'])
        recall_values.append(report['weighted avg']['recall'])
        f1_values.append(report['weighted avg']['f1-score'])
        accuracy_values.append(report['accuracy'])

    summary_rows = (
        ('AUC:', auc_values),
        ('Accuracy:', accuracy_values),
        ('Precision:', precision_values),
        ('Recall:', recall_values),
        ('F1-score:', f1_values),
    )
    for heading, values in summary_rows:
        print(heading, np.mean(values), '+/-', np.std(values))
    

In [38]:
def train(epoch,train_loader):
    """
    Runs one training epoch and records its loss and accuracy.

    Relies on notebook-level globals: ``model`` and ``optimizer`` are used for the update
    steps, and the epoch's results are appended to the ``train_loss`` and
    ``train_accuracy`` lists.

    :param epoch: Index of the current epoch; a summary is printed when it is a multiple of 10.
    :param train_loader: Loader yielding batches with a ``y`` attribute of binary labels.
    """
    model.train()
    criterion = BCELoss()
    running_loss = 0
    correct = 0
    total = 0

    for batch in train_loader:
        optimizer.zero_grad()
        outputs = model(batch)
        label = batch.y.view(-1,1)
        loss = criterion(outputs.float(),label.float())

        loss.backward()   # Compute the gradient of loss function 
        optimizer.step()  # Update parameters based on gradients.
        running_loss += loss.item()

        # probability that is larger than 0.5, classify as 1 
        pred = (outputs >= 0.5).float()

        total += label.size(0)
        # .item() keeps the counter a plain Python number, so the accuracy stored
        # below is a float rather than a 0-dim tensor.
        correct += (pred == label).sum().item()

    # Mean of the per-batch losses.
    loss = running_loss/len(train_loader)
    accuracy = 100*correct/total

    train_accuracy.append(accuracy)
    train_loss.append(loss)

    if epoch % 10 == 0:
        print('Epoch: '+str(int(epoch)))
        print('Train Loss: %.3f | Accuracy: %.3f'%(loss,accuracy))

In [39]:
def test(epoch,test_loader):
    """
    Evaluates the model on a loader without updating it and records loss and accuracy.

    Relies on notebook-level globals: ``model`` is evaluated, and the results are appended
    to the ``test_loss`` and ``test_accuracy`` lists.

    :param epoch: Index of the current epoch; a summary is printed when it is a multiple of 10.
    :param test_loader: Loader yielding batches with a ``y`` attribute of binary labels.
    """
    model.eval()

    criterion = BCELoss()
    running_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in test_loader:
            outputs = model(batch)
            label = batch.y.view(-1,1)

            loss = criterion(outputs.float(), label.float())
            running_loss += loss.item()

            # probability that is larger than 0.5, classify as 1 
            pred = (outputs >= 0.5).float()

            total += label.size(0)
            # .item() keeps the counter a plain Python number, so the accuracy stored
            # below is a float rather than a 0-dim tensor.
            correct += (pred == label).sum().item()

    # Mean of the per-batch losses.
    loss = running_loss/len(test_loader)
    accuracy = 100*correct/total

    test_accuracy.append(accuracy)
    test_loss.append(loss)
    if epoch % 10 == 0:
        print('Test Loss: %.3f | Accuracy: %.3f'%(loss,accuracy))

In [40]:
#test_set as a whole loader
def test_metrics(test_loader):
    """
    Computes the ROC AUC and a classification report for the global ``model``.

    :param test_loader: Loader yielding batches with a ``y`` attribute of binary labels.
    :return: A tuple ``(auc, report)`` where ``report`` is the dict produced by
             ``classification_report(..., output_dict=True)``.
    """
    model.eval()

    label_chunks = []
    score_chunks = []
    with torch.no_grad():
        for batch in test_loader:
            label_chunks.append(batch.y.view(-1).cpu())
            score_chunks.append(model(batch).view(-1).cpu())

    # Flat 1-D arrays: one label and one predicted probability per graph.
    labels = torch.cat(label_chunks).numpy()
    preds = torch.cat(score_chunks).numpy()

    # Same decision rule as train()/test(): a probability of at least 0.5 is class 1.
    pred_labels = (preds >= 0.5).astype(int)

    auc = roc_auc_score(labels, preds, average='weighted')
    report = classification_report(labels, pred_labels,output_dict=True)
    return auc, report
    

## Model construction for ECC

#### 1. An ECC layer updates the node (atom) feature vectors of a graph (molecule).
#### 2. Aggregate the updated node feature vectors to capture global properties,
####     i.e. apply the global_add_pool function over the node features.
#### 3. Then pass the pooled features to fully connected layers for binary classification.

In [41]:
class NNConv_graph(torch.nn.Module):
    """
    Graph-level binary classifier built from one edge-conditioned convolution.

    Pipeline: NNConv over the atoms (its weight matrix is produced from the bond features
    by a linear layer) -> sum pooling per graph -> two fully connected layers -> sigmoid.
    """
    def __init__(self, node_dim=133, edge_dim=14, hidden_dim=50):
        """
        :param node_dim: Length of each atom feature vector (default 133).
        :param edge_dim: Length of each bond feature vector (default 14).
        :param hidden_dim: Number of channels produced by the convolution (default 50).
        """
        super().__init__()

        # The edge network maps a bond feature vector to a node_dim x hidden_dim weight matrix.
        self.conv1 = NNConv(in_channels=node_dim, out_channels=hidden_dim,
                            nn=Linear(edge_dim, node_dim * hidden_dim))

        self.linear1 = Linear(hidden_dim, 10)
        self.linear2 = Linear(10, 1)

    def forward(self,data):
        """
        :param data: A batch providing ``x``, ``edge_index``, ``edge_attr`` and ``batch``.
        :return: One probability per graph, shape (num_graphs, 1).
        """
        x, edge_index, batch_index, edge_attr = data.x, data.edge_index, data.batch, data.edge_attr

        # Inputs are cast to double; the notebook builds the model with .double().
        x = self.conv1(x.double(), edge_index, edge_attr.double())

        # Sum the atom embeddings of each graph into a single graph embedding.
        x = gsp(x,batch_index)

        x = F.relu(self.linear1(x))
        x = torch.sigmoid(self.linear2(x))

        return x

## Model training and evaluation

### ECC (NNConv) model training and evaluation

In [42]:
# Folder holding the pre-computed cross-validation splits; edit to match your machine.
DATA_DIR = 'C:/Users/jimmy/Desktop/FYP'
N_SPLITS = 5
MAX_EPOCHS = 100

# Per-split training curves and final metrics.
train_losses = []
train_acc = []
test_acc = []
test_losses= []
metrics = []

for i in range(N_SPLITS):

    print('Split '+str(i+1)+' ......')
    train_set = pd.read_csv(f'{DATA_DIR}/train_split{i+1}.csv')
    test_set = pd.read_csv(f'{DATA_DIR}/test_split{i+1}.csv')
    train_loader = data_process(train_set,32)
    # The whole test split is served as a single batch.
    test_loader = data_process(test_set,len(test_set))

    # Fresh model and optimizer for every split so nothing carries over between folds.
    model = NNConv_graph().double()
    model.apply(reset_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,weight_decay=0.001)

    # train() and test() append to these notebook-level lists.
    train_loss = []
    test_loss = []
    train_accuracy = []
    test_accuracy = []

    # NOTE(review): early stopping and checkpoint selection are driven by the test split,
    # so the reported test metrics are optimistic; consider a separate validation split.
    early_stopping = EarlyStopping(patience=10, verbose=True)
    # Train for at most MAX_EPOCHS epochs, stopping early when the loss stops improving.
    for epoch in range(MAX_EPOCHS):
        train(epoch,train_loader)
        test(epoch,test_loader)

        early_stopping(test_loss[-1],model)

        if early_stopping.early_stop:
            print('Early stopping')
            break

    # Restore the best checkpoint written by EarlyStopping before computing final metrics.
    model.load_state_dict(torch.load(early_stopping.path))
    auc, report = test_metrics(test_loader)
    metric = [auc,report]

    train_losses.append(train_loss)
    train_acc.append(train_accuracy)
    # Keep the evaluation curves as well; these lists were previously left empty.
    test_losses.append(test_loss)
    test_acc.append(test_accuracy)
    metrics.append(metric)
    
    


Split 1 ......
Reset trainable parameters of layer = Linear(in_features=14, out_features=6650, bias=True)
Reset trainable parameters of layer = Linear(133, 50, bias=False)
Reset trainable parameters of layer = NNConv(133, 50, aggr=add, nn=Linear(in_features=14, out_features=6650, bias=True))
Reset trainable parameters of layer = Linear(in_features=50, out_features=10, bias=True)
Reset trainable parameters of layer = Linear(in_features=10, out_features=1, bias=True)
Epoch: 0
Train Loss: 0.607 | Accuracy: 78.472
Test Loss: 0.548 | Accuracy: 83.333
Validation loss decreased (inf --> 0.548013).  Saving model ...
Validation loss decreased (0.548013 --> 0.517220).  Saving model ...
Validation loss decreased (0.517220 --> 0.498044).  Saving model ...
Validation loss decreased (0.498044 --> 0.489252).  Saving model ...
Validation loss decreased (0.489252 --> 0.479912).  Saving model ...
Validation loss decreased (0.479912 --> 0.467204).  Saving model ...
Validation loss decreased (0.467204 -->

EarlyStopping counter: 1 out of 10
Epoch: 30
Train Loss: 0.377 | Accuracy: 86.458
Test Loss: 0.372 | Accuracy: 85.417
Validation loss decreased (0.375447 --> 0.372213).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
Validation loss decreased (0.372213 --> 0.363583).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
Epoch: 40
Train Loss: 0.365 | Accuracy: 86.806
Test Loss: 0.365 | Accuracy: 88.889
EarlyStopping counter: 5 out of 10
EarlyStopping counter: 6 out of 10
Validation loss decreased (0.363583 --> 0.348076).  Saving model ...
Validation loss decreased (0.348076 --> 0.346638).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
EarlyStopping counter: 5 out of 1

Epoch: 40
Train Loss: 0.407 | Accuracy: 85.590
Test Loss: 0.448 | Accuracy: 82.639
Validation loss decreased (0.453546 --> 0.447669).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
Validation loss decreased (0.447669 --> 0.428338).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
Epoch: 50
Train Loss: 0.311 | Accuracy: 88.194
Test Loss: 0.422 | Accuracy: 84.028
Validation loss decreased (0.428338 --> 0.421970).  Saving model ...
EarlyStopping counter: 1 out of 10
Validation loss decreased (0.421970 --> 0.394153).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
EarlyStopping counter: 5 out of 10
EarlyStopping counter: 6 out of 10
EarlyStopping counter: 7 out of 1

In [43]:
# Report the mean +/- standard deviation of each metric over the cross-validation splits.
print_metrics(metrics)

AUC: 0.8209722222222222 +/- 0.026067440948882144
Accuracy: 0.8777777777777777 +/- 0.020412414523193145
Precision: 0.8661339717571608 +/- 0.02565406421852644
Recall: 0.8777777777777777 +/- 0.020412414523193145
F1-score: 0.8672254982868836 +/- 0.023653096214976493
