In [24]:
#data manipulation
import pandas as pd
import numpy as np

#Pytorch geometric
import torch   
import torch_geometric
from torch_geometric.data import Data
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv
from torch.nn import Linear, CrossEntropyLoss, BCELoss
import torch_geometric.transforms as T
import torch.nn.functional as F

from torch_geometric.nn import global_mean_pool as gap,  global_max_pool as gmp

#rdkit
from rdkit import Chem                      
from rdkit.Chem import GetAdjacencyMatrix       
from scipy.sparse import coo_matrix
from rdkit.Chem import AllChem
from rdkit import Chem, DataStructs

#matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

#sklearn
import sklearn
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.model_selection import StratifiedKFold

#shuffle
from random import shuffle

## Utility Functions

### Chem

In [25]:
def onek_encoding_unk(value, choices):
    """
    One-hot encode a value against a list of choices, with a trailing "unknown" slot.

    :param value: The value to encode.
    :param choices: A list of the recognised values.
    :return: A list of length :code:`len(choices) + 1` containing a single 1. The 1 sits at the
             position of :code:`value` within :code:`choices`; if :code:`value` is not in
             :code:`choices`, the final element of the encoding is 1 instead.
    """
    width = len(choices) + 1
    # Unrecognised values land in the last position, which acts as the "unknown" bucket.
    hot = choices.index(value) if value in choices else width - 1
    return [1 if position == hot else 0 for position in range(width)]

In [26]:
class Featurization_parameters:
    """
    Bundle of constants that control how atoms and bonds are turned into feature vectors.
    """
    def __init__(self) -> None:
        hybrid = Chem.rdchem.HybridizationType

        # Atom feature sizes
        self.MAX_ATOMIC_NUM = 100
        # Allowed values for every one-hot encoded atom property
        self.ATOM_FEATURES = dict(
            atomic_num=[*range(self.MAX_ATOMIC_NUM)],
            degree=[0, 1, 2, 3, 4, 5],
            formal_charge=[-1, -2, 1, 2, 0],
            chiral_tag=[0, 1, 2, 3],
            num_Hs=[0, 1, 2, 3, 4],
            hybridization=[hybrid.SP, hybrid.SP2, hybrid.SP3, hybrid.SP3D, hybrid.SP3D2],
        )

        # Distance feature sizes
        self.PATH_DISTANCE_BINS = [*range(10)]
        self.THREE_D_DISTANCE_MAX = 20
        self.THREE_D_DISTANCE_STEP = 1
        self.THREE_D_DISTANCE_BINS = [
            *range(0, self.THREE_D_DISTANCE_MAX + 1, self.THREE_D_DISTANCE_STEP)
        ]

        # Every property contributes len(choices) slots plus one slot for uncommon values;
        # the final + 2 accounts for the IsAromatic flag and the scaled mass.
        one_hot_width = sum(map(len, self.ATOM_FEATURES.values())) + len(self.ATOM_FEATURES)
        self.ATOM_FDIM = one_hot_width + 2
        self.EXTRA_ATOM_FDIM = 0
        self.BOND_FDIM = 14
        self.EXTRA_BOND_FDIM = 0
        self.REACTION_MODE = None
        self.EXPLICIT_H = False
        self.REACTION = False

In [27]:
# Shared featurization settings; atom_features() and bond_features() read their sizes and choice lists from here.
PARAMS = Featurization_parameters()

In [28]:
def atom_features(atom: Chem.rdchem.Atom, functional_groups=None):
    """
    Assemble the feature vector describing a single atom.

    :param atom: An RDKit atom, or None for a placeholder atom.
    :param functional_groups: A k-hot vector indicating the functional groups the atom belongs to.
                              It is appended to the features of a real atom and ignored when
                              :code:`atom` is None.
    :return: A list containing the atom features.
    """
    if atom is None:
        # Placeholder atom: all zeros, sized to the base atom feature dimension.
        return [0] * PARAMS.ATOM_FDIM

    choices = PARAMS.ATOM_FEATURES
    # (raw value, key into ATOM_FEATURES) for every one-hot encoded property, in output order.
    categorical = (
        (atom.GetAtomicNum() - 1, 'atomic_num'),
        (atom.GetTotalDegree(), 'degree'),
        (atom.GetFormalCharge(), 'formal_charge'),
        (int(atom.GetChiralTag()), 'chiral_tag'),
        (int(atom.GetTotalNumHs()), 'num_Hs'),
        (int(atom.GetHybridization()), 'hybridization'),
    )

    features = []
    for raw_value, key in categorical:
        features += onek_encoding_unk(raw_value, choices[key])

    features.append(1 if atom.GetIsAromatic() else 0)
    features.append(atom.GetMass() * 0.01)  # scaled to about the same range as other features

    if functional_groups is not None:
        features += functional_groups
    return features

In [29]:
def bond_features(bond: Chem.rdchem.Bond):
    """
    Assemble the feature vector describing a single bond.

    :param bond: An RDKit bond, or None for a placeholder bond.
    :return: A list containing the bond features. For None the first entry is 1 and the
             remaining :code:`PARAMS.BOND_FDIM - 1` entries are 0.
    """
    if bond is None:
        return [1] + [0] * (PARAMS.BOND_FDIM - 1)

    kinds = Chem.rdchem.BondType
    bond_type = bond.GetBondType()
    typed = bond_type is not None

    features = [0]  # leading flag is 0 because a real bond was supplied
    features.extend(
        bond_type == kind
        for kind in (kinds.SINGLE, kinds.DOUBLE, kinds.TRIPLE, kinds.AROMATIC)
    )
    features.append(bond.GetIsConjugated() if typed else 0)
    features.append(bond.IsInRing() if typed else 0)
    # Stereo code is one-hot encoded over 0..5 plus an extra slot for anything else.
    features.extend(onek_encoding_unk(int(bond.GetStereo()), list(range(6))))
    return features

In [30]:
# Default Morgan fingerprint settings, used as the default arguments of
# morgan_binary_features_generator().
MORGAN_RADIUS = 2
MORGAN_NUM_BITS = 2048
# With these defaults each molecule gets a 2048-bit fingerprint as its molecule-level feature vector.

In [31]:
def morgan_binary_features_generator(mol,
                                     radius: int = MORGAN_RADIUS,
                                     num_bits: int = MORGAN_NUM_BITS):
    """
    Generates a binary Morgan fingerprint for a molecule.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :param radius: Morgan fingerprint radius.
    :param num_bits: Number of bits in Morgan fingerprint.
    :return: A 1D numpy array containing the binary Morgan fingerprint.
    :raises ValueError: If :code:`mol` is None or is a SMILES string that RDKit cannot parse.
    """
    source = mol
    # isinstance (rather than an exact type check) also accepts str subclasses such as numpy.str_.
    if isinstance(mol, str):
        mol = Chem.MolFromSmiles(mol)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here with a readable
        # message instead of passing None on to the fingerprint call.
        raise ValueError(f'Could not build a molecule from {source!r}')

    features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=num_bits)
    # ConvertToNumpyArray fills (and resizes) the destination array in place.
    features = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(features_vec, features)

    return features


### Data processing

In [32]:
def data_process(dataset,batch_size):
    """
    Convert a table of SMILES strings and labels into a shuffled DataLoader of molecular graphs.

    Each row becomes one ``Data`` object holding atom features (``x``), directed edges in both
    directions for every bond (``edge_index`` / ``edge_attr``), the label (``y``), the original
    SMILES string (``smiles``) and a binary Morgan fingerprint (``mol_feature``). Hydrogens are
    added explicitly before featurization.

    :param dataset: A DataFrame with a 'SMILES' column and an integer-convertible 'Activity' column.
    :param batch_size: Number of graphs per batch.
    :return: A DataLoader over the processed graphs with shuffling enabled.
    :raises ValueError: If a SMILES string cannot be parsed by RDKit.
    """
    data_list = []
    # Walk SMILES and label together: each row keeps its own label, repeated SMILES strings
    # are handled, and no per-row scan of the whole table is needed.
    for smiles, activity in zip(dataset['SMILES'], dataset['Activity']):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Could not parse SMILES: {smiles!r}')
        mol = Chem.AddHs(mol)

        # global (molecule-level) feature vector: binary Morgan fingerprint
        mol_feature = torch.tensor(morgan_binary_features_generator(mol))

        # node features: one row per atom
        x = torch.tensor(np.array([atom_features(atom) for atom in mol.GetAtoms()]))

        edge_indices, edge_attrs = [], []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            e = bond_features(bond)

            # store every bond once per direction, sharing the same bond features
            edge_indices += [[i,j],[j,i]]
            edge_attrs += [e, e]

        edge_index = torch.tensor(edge_indices).t().to(torch.long).view(2, -1)
        edge_attr = torch.tensor(edge_attrs).view(-1, PARAMS.BOND_FDIM)

        y = torch.tensor(int(activity))  # response variable y

        # keep the SMILES string and the fingerprint as extra attributes of the graph
        data = Data(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr,
                    smiles=smiles, mol_feature=mol_feature)
        data_list.append(data)

    return DataLoader(data_list,batch_size,shuffle=True)

### Training

In [33]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the current validation loss and the model. Whenever
    the loss improves, the model's state_dict is saved to ``path``; after ``patience``
    consecutive calls without improvement, ``early_stop`` is set to True.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf, not the np.Inf alias: the alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss, checkpointing or counting towards the patience limit."""
        # Work with the negated loss so that "higher is better".
        score = -val_loss

        if self.best_score is None:
            # First call: always treated as the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not improved by at least delta.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [34]:
def reset_weights(m):
    '''
    Re-initialise the direct children of a module to avoid weight leakage.

    Every immediate child that exposes ``reset_parameters`` is announced on stdout
    and then reset; children without that method are left untouched.
    '''
    resettable = (child for child in m.children() if hasattr(child, 'reset_parameters'))
    for child in resettable:
        print(f'Reset trainable parameters of layer = {child}')
        child.reset_parameters()

In [35]:
def train(epoch,train_loader):
    """
    Run one training epoch over ``train_loader`` and record its loss and accuracy.

    Relies on the module-level ``model`` and ``optimizer``, and appends the epoch's results to
    the module-level lists ``train_loss`` and ``train_accuracy``.

    :param epoch: Index of the current epoch; a summary is printed when it is a multiple of 10.
    :param train_loader: Loader yielding batches whose ``y`` holds the binary labels.
    """
    model.train()
    running_loss = 0
    correct = 0
    total = 0
    criterion = BCELoss()
    for batch in train_loader:

        optimizer.zero_grad()
        outputs = model(batch)
        label = batch.y.view(-1,1)
        loss = criterion(outputs.float(),label.float())

        loss.backward()   # Compute the gradient of loss function 
        optimizer.step()  # Update parameters based on gradients.
        running_loss += loss.item()

        # probability that is larger than 0.5, classify as 1 
        pred = (outputs >= 0.5).float()

        total += label.size(0)
        # .item() keeps the running count (and hence the stored accuracy) a plain Python
        # number, matching how the loss is accumulated above.
        correct += (pred == label).sum().item()

    # mean of the per-batch losses
    loss = running_loss/len(train_loader)
    accuracy = 100*correct/total

    train_accuracy.append(accuracy)
    train_loss.append(loss)

    if epoch % 10 == 0:
        print('Epoch: '+str(int(epoch)))
        print('Train Loss: %.3f | Accuracy: %.3f'%(loss,accuracy))

In [36]:
def test(epoch,test_loader):
    """
    Evaluate the model on ``test_loader`` without gradient tracking and record loss and accuracy.

    Relies on the module-level ``model``, and appends the results to the module-level lists
    ``test_loss`` and ``test_accuracy``.

    :param epoch: Index of the current epoch; a summary is printed when it is a multiple of 10.
    :param test_loader: Loader yielding batches whose ``y`` holds the binary labels.
    """
    model.eval()

    running_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        criterion = BCELoss()
        for batch in test_loader:

            outputs = model(batch)
            label = batch.y.view(-1,1)

            loss = criterion(outputs.float(), label.float())
            running_loss += loss.item()

            # probability that is larger than 0.5, classify as 1 
            pred = (outputs >= 0.5).float()

            total += label.size(0)
            # .item() keeps the running count (and hence the stored accuracy) a plain Python
            # number, matching how the loss is accumulated above.
            correct += (pred == label).sum().item()

        # mean of the per-batch losses
        loss = running_loss/len(test_loader)
        accuracy = 100*correct/total

        test_accuracy.append(accuracy)
        test_loss.append(loss)
        if epoch % 10 == 0:
            print('Test Loss: %.3f | Accuracy: %.3f'%(loss,accuracy))

In [37]:
#test_set as a whole loader
def test_metrics(test_loader):
    """
    Compute ROC AUC and a classification report for the module-level ``model``.

    :param test_loader: Loader yielding batches whose ``y`` holds the binary labels.
    :return: A tuple ``(auc, report)`` where ``report`` is the dict produced by
             ``classification_report(..., output_dict=True)``.
    """
    model.eval()

    batch_labels = []
    batch_probs = []
    with torch.no_grad():
        for batch in test_loader:
            batch_labels.append(batch.y.reshape(-1))
            batch_probs.append(model(batch).reshape(-1))

    # Flat 1-D arrays: one label and one predicted probability per graph.
    labels = torch.cat(batch_labels).numpy()
    preds = torch.cat(batch_probs).numpy()

    # Same decision rule as train()/test(): a probability of at least 0.5 is classified as 1.
    pred_labels = (preds >= 0.5).astype(int)
    auc = roc_auc_score(labels, preds, average='weighted')
    report = classification_report(labels, pred_labels,output_dict=True)
    return auc, report
    

In [38]:
def print_metrics(metrics):
    """
    Print mean +/- standard deviation of the evaluation metrics collected over several folds.

    :param metrics: An iterable of ``[auc, report]`` pairs, where ``report`` is a
                    classification-report dict with a 'weighted avg' entry and an 'accuracy' entry.
    """
    # One row per fold: (auc, precision, recall, f1-score, accuracy)
    rows = []
    for fold in metrics:
        weighted = fold[1]['weighted avg']
        rows.append((
            fold[0],
            weighted['precision'],
            weighted['recall'],
            weighted['f1-score'],
            fold[1]['accuracy'],
        ))

    # (label to print, column of `rows` it summarises), in display order
    layout = (
        ('AUC:', 0),
        ('Accuracy:', 4),
        ('Precision:', 1),
        ('Recall:', 2),
        ('F1-score:', 3),
    )
    for label, column in layout:
        values = [row[column] for row in rows]
        print(label,np.mean(values),'+/-',np.std(values))
    

## Model construction

#### 1. A GAT layer updates the node (atom) feature vectors of each graph (molecule).
#### 2. Aggregate the updated node feature vectors to capture global properties,
####     i.e. apply the global_mean_pool function over the node features.
#### 3. Concatenate the aggregated local features with the binary Morgan fingerprint. This vector is the predictor input of the model.
#### 4. Pass the combined features through fully connected layers for binary classification.

In [39]:
class GAT(torch.nn.Module):
    """
    Graph-attention classifier combining pooled atom embeddings with a Morgan fingerprint.

    A single GATConv layer embeds the atoms, the embeddings are mean-pooled per graph,
    concatenated with the graph's 2048-bit fingerprint, and passed through four linear layers
    ending in a sigmoid, giving one probability per graph.
    """
    def __init__(self):
        super().__init__()

        self.hidden = 50   # width of the atom embedding produced by the attention layer
        self.in_head = 5   # number of attention heads

        # 133 input channels is the value PARAMS.ATOM_FDIM evaluates to;
        # concat=False averages the heads instead of concatenating them.
        self.conv1 = GATConv(in_channels=133,
                             out_channels=self.hidden,
                             heads=self.in_head,
                             concat=False)

        # fully connected head on top of [pooled embedding | fingerprint]
        self.linear1 = Linear(self.hidden + 2048, 2048)
        self.linear2 = Linear(2048, 50)
        self.linear3 = Linear(50, 10)
        self.linear4 = Linear(10, 1)

    def forward(self, data):
        """Return a (num_graphs, 1) tensor of sigmoid outputs for a batch of graphs."""
        # update the node vectors with graph attention
        node_state = self.conv1(data.x, data.edge_index)

        # mean-pool the node vectors of every graph into one graph-level vector
        pooled = gap(node_state, data.batch)

        # the batched fingerprints are laid out as one row of 2048 bits per graph
        fingerprint = data.mol_feature.reshape(data.num_graphs, 2048)
        hidden_state = torch.cat([pooled, fingerprint], dim=1)

        for layer in (self.linear1, self.linear2, self.linear3):
            hidden_state = F.relu(layer(hidden_state))

        return torch.sigmoid(self.linear4(hidden_state))

## Model training and evaluation

### GAT model training and evaluation

### 1:3

In [40]:
# Load the training table and keep only the SMILES and Activity columns.
# NOTE(review): absolute, machine-specific path; the notebook cannot run elsewhere without editing it.
df = pd.read_csv('C:/Users/jimmy/Desktop/FYP/training_data.csv')
df = df[['SMILES','Activity']]

In [41]:
# NOTE(review): re-reads the same CSV as the previous cell and rebinds df with all of its columns.
df = pd.read_csv('C:/Users/jimmy/Desktop/FYP/training_data.csv')
# Split the rows by label: Activity == 1 versus Activity == 0.
pos = df[df['Activity'] == 1]
neg = df[df['Activity'] == 0]

# positive : negative = 1 : 3 (three sampled negatives per positive; sampling is unseeded, so the subset changes between runs)
full = pd.concat([neg.sample(3*len(pos)),pos])
full = full[['SMILES','Activity']]

In [42]:
# Per-fold histories (one list per fold) and final metrics for the 5-fold cross-validation.
train_losses = []
train_acc = []
test_acc = []
test_losses= []
metrics = []  # one [auc, classification report] entry per fold

# NOTE(review): no random_state is set, so fold membership changes between runs.
skf = StratifiedKFold(n_splits=5,shuffle=True)
for i, (train_idx, test_idx) in enumerate(skf.split(full['SMILES'],full['Activity'])):

    print('Split '+str(i+1)+' ......')
    train_set = full.iloc[train_idx]
    test_set = full.iloc[test_idx]
    train_loader = data_process(train_set,32)
    # the whole held-out fold is served as a single batch
    test_loader = data_process(test_set,len(test_set))

    # train(), test() and test_metrics() read model / optimizer and the four per-fold
    # history lists below as module-level globals, so these names must not change.
    model = GAT().double()
    model.apply(reset_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,weight_decay=0.001)

    train_loss = []
    test_loss = []
    train_accuracy = []
    test_accuracy = []

    early_stopping = EarlyStopping(patience=10, verbose=True)
    # Train for at most 100 epochs; early stopping may end the fold sooner.
    # NOTE(review): early stopping monitors the loss on the same fold that test_metrics()
    # scores below, so the reported metrics are selected on the evaluation data.
    for epoch in range(100):
        train(epoch,train_loader)
        test(epoch,test_loader)

        early_stopping(test_loss[-1],model)

        if early_stopping.early_stop:
            print('Early stopping')
            break

    # Restore the best weights from the checkpoint that early stopping wrote.
    model.load_state_dict(torch.load(early_stopping.path))
    auc, report = test_metrics(test_loader)
    metric = [auc,report]

    train_losses.append(train_loss)
    train_acc.append(train_accuracy)
    # keep the held-out histories too; these lists were declared but never filled before
    test_losses.append(test_loss)
    test_acc.append(test_accuracy)
    metrics.append(metric)
    
    


Split 1 ......
Reset trainable parameters of layer = Linear(133, 250, bias=False)
Reset trainable parameters of layer = GATConv(133, 50, heads=5)
Reset trainable parameters of layer = Linear(in_features=2098, out_features=2048, bias=True)
Reset trainable parameters of layer = Linear(in_features=2048, out_features=50, bias=True)
Reset trainable parameters of layer = Linear(in_features=50, out_features=10, bias=True)
Reset trainable parameters of layer = Linear(in_features=10, out_features=1, bias=True)
Epoch: 0
Train Loss: 0.615 | Accuracy: 63.802
Test Loss: 0.472 | Accuracy: 75.000
Validation loss decreased (inf --> 0.472192).  Saving model ...
Validation loss decreased (0.472192 --> 0.361444).  Saving model ...
Validation loss decreased (0.361444 --> 0.337647).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
Validation loss decreased (0.337647 --> 0.336577).  Saving model ...


In [43]:
print_metrics(metrics)

AUC: 0.8778935185185187 +/- 0.02477230414130137
Accuracy: 0.8791666666666668 +/- 0.019320038532282726
Precision: 0.8771581731278764 +/- 0.018817928938836766
Recall: 0.8791666666666668 +/- 0.019320038532282726
F1-score: 0.8770742994705367 +/- 0.018495409930806166


### 1:5

In [44]:
# positive : negative = 1 : 5 (five sampled negatives per positive; sampling is unseeded, so the subset changes between runs)
full = pd.concat([neg.sample(5*len(pos)),pos])
full = full[['SMILES','Activity']]

In [45]:
# Per-fold histories (one list per fold) and final metrics for the 5-fold cross-validation.
train_losses = []
train_acc = []
test_acc = []
test_losses= []
metrics = []  # one [auc, classification report] entry per fold

# NOTE(review): no random_state is set, so fold membership changes between runs.
skf = StratifiedKFold(n_splits=5,shuffle=True)
for i, (train_idx, test_idx) in enumerate(skf.split(full['SMILES'],full['Activity'])):

    print('Split '+str(i+1)+' ......')
    train_set = full.iloc[train_idx]
    test_set = full.iloc[test_idx]
    train_loader = data_process(train_set,32)
    # the whole held-out fold is served as a single batch
    test_loader = data_process(test_set,len(test_set))

    # train(), test() and test_metrics() read model / optimizer and the four per-fold
    # history lists below as module-level globals, so these names must not change.
    model = GAT().double()
    model.apply(reset_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,weight_decay=0.001)

    train_loss = []
    test_loss = []
    train_accuracy = []
    test_accuracy = []

    early_stopping = EarlyStopping(patience=10, verbose=True)
    # Train for at most 100 epochs; early stopping may end the fold sooner.
    # NOTE(review): early stopping monitors the loss on the same fold that test_metrics()
    # scores below, so the reported metrics are selected on the evaluation data.
    for epoch in range(100):
        train(epoch,train_loader)
        test(epoch,test_loader)

        early_stopping(test_loss[-1],model)

        if early_stopping.early_stop:
            print('Early stopping')
            break

    # Restore the best weights from the checkpoint that early stopping wrote.
    model.load_state_dict(torch.load(early_stopping.path))
    auc, report = test_metrics(test_loader)
    metric = [auc,report]

    train_losses.append(train_loss)
    train_acc.append(train_accuracy)
    # keep the held-out histories too; these lists were declared but never filled before
    test_losses.append(test_loss)
    test_acc.append(test_accuracy)
    metrics.append(metric)
    
    


Split 1 ......
Reset trainable parameters of layer = Linear(133, 250, bias=False)
Reset trainable parameters of layer = GATConv(133, 50, heads=5)
Reset trainable parameters of layer = Linear(in_features=2098, out_features=2048, bias=True)
Reset trainable parameters of layer = Linear(in_features=2048, out_features=50, bias=True)
Reset trainable parameters of layer = Linear(in_features=50, out_features=10, bias=True)
Reset trainable parameters of layer = Linear(in_features=10, out_features=1, bias=True)
Epoch: 0
Train Loss: 0.638 | Accuracy: 53.472
Test Loss: 0.431 | Accuracy: 83.333
Validation loss decreased (inf --> 0.431129).  Saving model ...
Validation loss decreased (0.431129 --> 0.314351).  Saving model ...
Validation loss decreased (0.314351 --> 0.295263).  Saving model ...
Validation loss decreased (0.295263 --> 0.277856).  Saving model ...
Validation loss decreased (0.277856 --> 0.269210).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
E

In [46]:
print_metrics(metrics)

AUC: 0.8927083333333334 +/- 0.015782386357449284
Accuracy: 0.9083333333333334 +/- 0.009212846639876114
Precision: 0.9055531341326818 +/- 0.010170564715767197
Recall: 0.9083333333333334 +/- 0.009212846639876114
F1-score: 0.9013563897466771 +/- 0.009666139175034885


### 1:10

In [47]:
# positive : negative = 1 : 10 (ten sampled negatives per positive; sampling is unseeded, so the subset changes between runs)
full = pd.concat([neg.sample(10*len(pos)),pos])
full = full[['SMILES','Activity']]

In [48]:
# Per-fold histories (one list per fold) and final metrics for the 5-fold cross-validation.
train_losses = []
train_acc = []
test_acc = []
test_losses= []
metrics = []
# metrics stores the AUC / report record of every fold
# NOTE(review): no random_state is set, so fold membership changes between runs.
skf = StratifiedKFold(n_splits=5,shuffle=True)
for i, (train_idx, test_idx) in enumerate(skf.split(full['SMILES'],full['Activity'])):

    print('Split '+str(i+1)+' ......')
    train_set = full.iloc[train_idx]
    test_set = full.iloc[test_idx]
    train_loader = data_process(train_set,32)
    # the whole held-out fold is served as a single batch
    test_loader = data_process(test_set,len(test_set))

    # train(), test() and test_metrics() read model / optimizer and the four per-fold
    # history lists below as module-level globals, so these names must not change.
    model = GAT().double()
    model.apply(reset_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,weight_decay=0.001)

    train_loss = []
    test_loss = []
    train_accuracy = []
    test_accuracy = []

    early_stopping = EarlyStopping(patience=10, verbose=True)
    # Train for at most 100 epochs; early stopping may end the fold sooner.
    # NOTE(review): early stopping monitors the loss on the same fold that test_metrics()
    # scores below, so the reported metrics are selected on the evaluation data.
    for epoch in range(100):
        train(epoch,train_loader)
        test(epoch,test_loader)

        early_stopping(test_loss[-1],model)

        if early_stopping.early_stop:
            print('Early stopping')
            break

    # Restore the best weights from the checkpoint that early stopping wrote.
    model.load_state_dict(torch.load(early_stopping.path))
    auc, report = test_metrics(test_loader)
    metric = [auc,report]

    train_losses.append(train_loss)
    train_acc.append(train_accuracy)
    # keep the held-out histories too; these lists were declared but never filled before
    test_losses.append(test_loss)
    test_acc.append(test_accuracy)
    metrics.append(metric)
    
    


Split 1 ......
Reset trainable parameters of layer = Linear(133, 250, bias=False)
Reset trainable parameters of layer = GATConv(133, 50, heads=5)
Reset trainable parameters of layer = Linear(in_features=2098, out_features=2048, bias=True)
Reset trainable parameters of layer = Linear(in_features=2048, out_features=50, bias=True)
Reset trainable parameters of layer = Linear(in_features=50, out_features=10, bias=True)
Reset trainable parameters of layer = Linear(in_features=10, out_features=1, bias=True)
Epoch: 0
Train Loss: 0.483 | Accuracy: 71.875
Test Loss: 0.324 | Accuracy: 90.909
Validation loss decreased (inf --> 0.324160).  Saving model ...
Validation loss decreased (0.324160 --> 0.235950).  Saving model ...
Validation loss decreased (0.235950 --> 0.197766).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
EarlyStopping counter: 5 out of 10
EarlyStopping counter: 6 out of 10

In [49]:
print_metrics(metrics)

AUC: 0.8765277777777778 +/- 0.0340446740930933
Accuracy: 0.9431818181818181 +/- 0.011236664374387388
Precision: 0.9414939578212858 +/- 0.011038968128222656
Recall: 0.9431818181818181 +/- 0.011236664374387388
F1-score: 0.9394156232890758 +/- 0.01172392990927586


### 120:2215

In [50]:
# Use every row of df, with no down-sampling of the negatives.
# NOTE(review): the "120:2215" heading presumably gives the positive:negative row counts; confirm against the data.
full = df[['SMILES','Activity']]

In [51]:
# Per-fold histories (one list per fold) and final metrics for the 5-fold cross-validation.
train_losses = []
train_acc = []
test_acc = []
test_losses= []
metrics = []
# metrics stores the AUC / report record of every fold
# NOTE(review): no random_state is set, so fold membership changes between runs.
skf = StratifiedKFold(n_splits=5,shuffle=True)
for i, (train_idx, test_idx) in enumerate(skf.split(full['SMILES'],full['Activity'])):

    print('Split '+str(i+1)+' ......')
    train_set = full.iloc[train_idx]
    test_set = full.iloc[test_idx]
    train_loader = data_process(train_set,32)
    # the whole held-out fold is served as a single batch
    test_loader = data_process(test_set,len(test_set))

    # train(), test() and test_metrics() read model / optimizer and the four per-fold
    # history lists below as module-level globals, so these names must not change.
    model = GAT().double()
    model.apply(reset_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,weight_decay=0.001)

    train_loss = []
    test_loss = []
    train_accuracy = []
    test_accuracy = []

    early_stopping = EarlyStopping(patience=10, verbose=True)
    # Train for at most 100 epochs; early stopping may end the fold sooner.
    # NOTE(review): early stopping monitors the loss on the same fold that test_metrics()
    # scores below, so the reported metrics are selected on the evaluation data.
    for epoch in range(100):
        train(epoch,train_loader)
        test(epoch,test_loader)

        early_stopping(test_loss[-1],model)

        if early_stopping.early_stop:
            print('Early stopping')
            break

    # Restore the best weights from the checkpoint that early stopping wrote.
    model.load_state_dict(torch.load(early_stopping.path))
    auc, report = test_metrics(test_loader)
    metric = [auc,report]

    train_losses.append(train_loss)
    train_acc.append(train_accuracy)
    # keep the held-out histories too; these lists were declared but never filled before
    test_losses.append(test_loss)
    test_acc.append(test_accuracy)
    metrics.append(metric)
    
    


Split 1 ......
Reset trainable parameters of layer = Linear(133, 250, bias=False)
Reset trainable parameters of layer = GATConv(133, 50, heads=5)
Reset trainable parameters of layer = Linear(in_features=2098, out_features=2048, bias=True)
Reset trainable parameters of layer = Linear(in_features=2048, out_features=50, bias=True)
Reset trainable parameters of layer = Linear(in_features=50, out_features=10, bias=True)
Reset trainable parameters of layer = Linear(in_features=10, out_features=1, bias=True)
Epoch: 0
Train Loss: 0.273 | Accuracy: 89.507
Test Loss: 0.140 | Accuracy: 95.503
Validation loss decreased (inf --> 0.140281).  Saving model ...
Validation loss decreased (0.140281 --> 0.122213).  Saving model ...
Validation loss decreased (0.122213 --> 0.117680).  Saving model ...
EarlyStopping counter: 1 out of 10
Validation loss decreased (0.117680 --> 0.116004).  Saving model ...
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [52]:
print_metrics(metrics)

AUC: 0.8897479307750189 +/- 0.037739490081784656
Accuracy: 0.9635974304068521 +/- 0.007896826087617052
Precision: 0.9523966240201037 +/- 0.026404948150059563
Recall: 0.9635974304068521 +/- 0.007896826087617052
F1-score: 0.9570415279908652 +/- 0.017070480279998915
