In [154]:
import os
import numpy as np 
import pandas as pd
from tqdm import tqdm
from rdkit import Chem

import torch.nn.functional as F 
from torch.nn import Linear, BatchNorm1d, ModuleList
from torch_geometric.nn import TransformerConv, TopKPooling 
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

import torch
import torch_geometric
from torch_geometric.data import Dataset, Data , DataLoader
from torch_geometric.nn import GATConv
import torch.optim as optim

from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score


## Loading raw data

In [155]:
Data_Path = "~/ML/HIV/HIV.csv"
data = pd.read_csv(Data_Path)
data.head

<bound method NDFrame.head of                                                   smiles activity  HIV_active
0      CCC1=[O+][Cu-3]2([O+]=C(CC)C1)[O+]=C(CC)CC(CC)...       CI           0
1      C(=Cc1ccccc1)C1=[O+][Cu-3]2([O+]=C(C=Cc3ccccc3...       CI           0
2                       CC(=O)N1c2ccccc2Sc2c1ccc1ccccc21       CI           0
3        Nc1ccc(C=Cc2ccc(N)cc2S(=O)(=O)O)c(S(=O)(=O)O)c1       CI           0
4                                 O=S(=O)(O)CCS(=O)(=O)O       CI           0
...                                                  ...      ...         ...
41122  CCC1CCC2c3c([nH]c4ccc(C)cc34)C3C(=O)N(N(C)C)C(...       CI           0
41123  Cc1ccc2[nH]c3c(c2c1)C1CCC(C(C)(C)C)CC1C1C(=O)N...       CI           0
41124  Cc1ccc(N2C(=O)C3c4[nH]c5ccccc5c4C4CCC(C(C)(C)C...       CI           0
41125  Cc1cccc(N2C(=O)C3c4[nH]c5ccccc5c4C4CCC(C(C)(C)...       CI           0
41126  CCCCCC=C(c1cc(Cl)c(OC)c(-c2nc(C)no2)c1)c1cc(Cl...       CI           0

[41127 rows x 3 columns]>

## Define MoleculeDataset class

In [156]:

class MoleculeDataset(Dataset):
    def __init__(self, root, filename, test=False, transform=None, pre_transform=None, data=None):
        self.root = os.path.expanduser(root)
        self.filename = filename
        self.test = test
        self.transform = transform
        self.pre_transform = pre_transform
        self.data = data if data is not None else self.load_data()
        self._indices = list(range(len(self.data)))
        self.my_processed_dir = os.path.join(self.root, "processed")
        os.makedirs(self.my_processed_dir, exist_ok=True)

        if self.pre_transform:
            self.data = self.pre_transform(self.data)
    
    def load_data(self):
        file_path = os.path.join(self.root, self.filename)
        return pd.read_csv(file_path)

    def process(self, title="Processing Data"):
        print(f"{title} and saving to {self.my_processed_dir}...")
        for index, mol in tqdm(self.data.iterrows(), total=self.data.shape[0], disable=True):
            mol_obj = Chem.MolFromSmiles(mol["smiles"])
            if mol_obj is None:
                continue
            data = Data(
                x=self._get_node_features(mol_obj),
                edge_index=self._get_adjacency_info(mol_obj),
                edge_attr=self._get_edge_features(mol_obj),
                y=self._get_labels(mol["HIV_active"]),
                smiles=mol["smiles"]
            )
            file_name = f"data_test_{index}.pt" if self.test else f"data_{index}.pt"
            file_path = os.path.join(self.my_processed_dir, file_name)
            torch.save(data, file_path)

    def _get_node_features(self, mol):
        features = [
            [atom.GetAtomicNum(), atom.GetDegree(), atom.GetFormalCharge(), 
             atom.GetHybridization(), atom.GetIsAromatic(), atom.GetTotalNumHs(), 
             atom.GetNumRadicalElectrons(), atom.IsInRing(), atom.GetChiralTag()]
            for atom in mol.GetAtoms()
        ]
        return torch.tensor(features, dtype=torch.float)

    def _get_edge_features(self, mol):
        features = []
        for bond in mol.GetBonds():
            bond_feats = [bond.GetBondTypeAsDouble(), bond.IsInRing()]
            features.append(bond_feats)
            features.append(bond_feats)  # Add again for the reverse direction
        return torch.tensor(features, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            indices.append([i, j])
            indices.append([j, i])  # Add reverse direction

        edge_indices = torch.tensor(indices).t().to(torch.long)
        return edge_indices

    def _get_labels(self, label):
        return torch.tensor([label], dtype=torch.int64)

    def __len__(self):
        return len(self.data)

    def get(self, idx):
        file_name = f"data_test_{idx}.pt" if self.test else f"data_{idx}.pt"
        file_path = os.path.join(self.my_processed_dir, file_name)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"No such file: '{file_path}'")
        return torch.load(file_path)

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return self.get(idx)
        elif isinstance(idx, slice):
            indices = range(*idx.indices(len(self)))
            data_subset = self.data.iloc[indices].reset_index(drop=True)
            return MoleculeDataset(root=self.root, filename=self.filename, test=self.test, transform=self.transform, pre_transform=self.pre_transform, data=data_subset)
        else:
            raise TypeError("Invalid argument type.")

    @property
    def num_node_features(self):
        sample_data = self[0]
        return sample_data.x.shape[1]

    @property
    def num_classes(self):
        return self.data['HIV_active'].nunique()

    def shuffle(self):
        indices = torch.randperm(len(self.data)).tolist()
        self.data = self.data.iloc[indices].reset_index(drop=True)
        return self


## Preprocessing data

Loads the Data: Reads the raw data from a CSV file.           
Converts SMILES Strings: Converts each SMILES string into an RDKit molecule object.          
Extracts Features: Extracts node features, edge features, and adjacency information from the molecule.         
Creates Data Objects: Creates Data objects containing the extracted features and labels.        
Saves Data Objects: Saves each Data object to a file in the specified directory.          

In [157]:
dataset = MoleculeDataset(root='~/ML/HIV', filename="HIV.csv")
dataset

MoleculeDataset(41127)

In [158]:
print(dataset[0].edge_index)
print(dataset[0].x)
print(dataset[0].edge_attr)
print(dataset[0].y)

tensor([[ 1,  2,  2,  3,  3,  4,  3,  5,  5,  6,  6,  7,  6,  8,  8,  9,  9, 10,
         10, 11, 11, 12, 12, 13, 13,  2, 13,  8],
        [ 2,  1,  3,  2,  4,  3,  5,  3,  6,  5,  7,  6,  8,  6,  9,  8, 10,  9,
         11, 10, 12, 11, 13, 12,  2, 13,  8, 13]])
tensor([[17.,  0.,  0.,  4.,  0.,  1.,  0.,  0.,  0.],
        [ 6.,  1.,  0.,  4.,  0.,  3.,  0.,  0.,  0.],
        [ 7.,  3.,  0.,  3.,  1.,  0.,  0.,  1.,  0.],
        [ 6.,  3.,  0.,  3.,  1.,  0.,  0.,  1.,  0.],
        [ 7.,  1.,  0.,  3.,  0.,  1.,  0.,  0.,  0.],
        [ 7.,  2.,  0.,  3.,  1.,  0.,  0.,  1.,  0.],
        [ 7.,  3.,  1.,  3.,  1.,  0.,  0.,  1.,  0.],
        [ 8.,  1., -1.,  3.,  0.,  0.,  0.,  0.,  0.],
        [ 6.,  3.,  0.,  3.,  1.,  0.,  0.,  1.,  0.],
        [ 6.,  2.,  0.,  3.,  1.,  1.,  0.,  1.,  0.],
        [ 6.,  2.,  0.,  3.,  1.,  1.,  0.,  1.,  0.],
        [ 6.,  2.,  0.,  3.,  1.,  1.,  0.,  1.,  0.],
        [ 6.,  2.,  0.,  3.,  1.,  1.,  0.,  1.,  0.],
        [ 6.,  3.,  0.

In [159]:
dataset.process()  # Ensure the data is processed and saved

Processing Data and saving to /home/mpir0002/ML/HIV/processed...




In [160]:
print(len(dataset))
print(dataset.num_classes)

41127
2


In [161]:
dataset.num_node_features

9

In [162]:
data = dataset[0]
print(data.is_directed())

False


In [163]:
print(data.num_nodes)
print(data.num_edges)

19
40


## Configuring model architecture

In [164]:

class GNN(torch.nn.Module):
    def __init__(self, feature_size):
        super(GNN, self).__init__()
        
        num_classes = 2
        embedding_size = 1024
        
        # GNN layers
        self.conv1 = GATConv(feature_size , embedding_size , heads=3 , dropout = 0.3)
        self.head_transform1 = Linear(embedding_size*3 , embedding_size)
        self.pool1 = TopKPooling(embedding_size , ratio = 0.8)
        self.conv2 = GATConv(embedding_size , embedding_size , heads=3 , dropout = 0.3)
        self.head_transform2 = Linear(embedding_size*3 , embedding_size)
        self.pool2 = TopKPooling(embedding_size , ratio = 0.5)
        self.conv3 = GATConv(embedding_size , embedding_size , heads=3 , dropout = 0.3)
        self.head_transform3 = Linear(embedding_size*3 , embedding_size)
        self.pool3 = TopKPooling(embedding_size , ratio = 0.3)
        
        # Linear layers
        self.linear1 = Linear(embedding_size*2 , 1024)
        self.linear2 = Linear(1024 , num_classes)
        
    def forward(self , x , edge_index , batch_index):
        
        x = self.conv1(x , edge_index)
        x = self.head_transform1(x)
        x , edge_index , edge_attr , batch_index , _ , _ = self.pool1(x , edge_index, None, batch_index)
        x1 = torch.cat([gmp(x , batch_index) , gap(x , batch_index)] , dim = 1)
        
        x = self.conv2(x , edge_index)
        x = self.head_transform2(x)
        x , edge_index , edge_attr , batch_index , _ , _ = self.pool2(x , edge_index, None, batch_index)
        x2 = torch.cat([gmp(x , batch_index) , gap(x , batch_index)] , dim = 1)
        
        x = self.conv3(x , edge_index)
        x = self.head_transform3(x)
        x , edge_index , edge_attr , batch_index , _ , _ = self.pool3(x , edge_index, None, batch_index)
        x3 = torch.cat([gmp(x , batch_index) , gap(x , batch_index)] , dim = 1)
        
        x = x1 + x2 + x3
        
        x = self.linear1(x).relu()
        x = F.dropout(x , p = 0.5 , training = self.training)
        x = self.linear2(x)
        
        return x



## Model training and evaluation

In [165]:
model = GNN(feature_size=train_dataset[0].x.shape[1])
model

GNN(
  (conv1): GATConv(9, 1024, heads=3)
  (head_transform1): Linear(in_features=3072, out_features=1024, bias=True)
  (pool1): TopKPooling(1024, ratio=0.8, multiplier=1.0)
  (conv2): GATConv(1024, 1024, heads=3)
  (head_transform2): Linear(in_features=3072, out_features=1024, bias=True)
  (pool2): TopKPooling(1024, ratio=0.5, multiplier=1.0)
  (conv3): GATConv(1024, 1024, heads=3)
  (head_transform3): Linear(in_features=3072, out_features=1024, bias=True)
  (pool3): TopKPooling(1024, ratio=0.3, multiplier=1.0)
  (linear1): Linear(in_features=2048, out_features=1024, bias=True)
  (linear2): Linear(in_features=1024, out_features=2, bias=True)
)

In [166]:
def calculate_metrics(y_pred, y_true, epoch, phase):
    metrics = {}
    metrics['confusion_matrix'] = confusion_matrix(y_true, y_pred)
    metrics['f1_score'] = f1_score(y_true, y_pred, average='weighted')
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['precision'] = precision_score(y_true, y_pred, average='weighted')
    metrics['recall'] = recall_score(y_true, y_pred, average='weighted')

    print(f"\n{phase.capitalize()} Epoch {epoch} Metrics:")
    print(f"Confusion Matrix:\n{metrics['confusion_matrix']}")
    print(f"F1 Score: {metrics['f1_score']}")
    print(f"Accuracy: {metrics['accuracy']}")
    print(f"Precision: {metrics['precision']}")
    print(f"Recall: {metrics['recall']}")

    return metrics


In [167]:
def train(epoch, model, train_loader, loss_fn, optimizer, scheduler):
    model.train()  # Set the model to training mode
    all_preds = []
    all_labels = []
    running_loss = 0.0
    
    for _, batch in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()  # Reset gradients
        pred = model(batch.x.float(), batch.edge_index, batch.batch)  # Forward pass
        loss = loss_fn(pred, batch.y)  # Calculate loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update parameters
        running_loss += loss.item()
        
        all_preds.append(np.argmax(pred.detach().cpu().numpy(), axis=1))
        all_labels.append(batch.y.cpu().detach().numpy())
    
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    metrics = calculate_metrics(all_preds, all_labels, epoch, "train")
    scheduler.step()  # Update learning rate
    return running_loss, metrics

def test(epoch, model, test_loader, loss_fn):
    model.eval()  # Set the model to evaluation mode
    all_preds = []
    all_labels = []
    running_loss = 0.0
    
    with torch.no_grad():
        for batch in test_loader:
            pred = model(batch.x.float(), batch.edge_index, batch.batch)
            loss = loss_fn(pred, batch.y)
            running_loss += loss.item()
            
            all_preds.append(np.argmax(pred.cpu().numpy(), axis=1))
            all_labels.append(batch.y.cpu().numpy())
    
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    metrics = calculate_metrics(all_preds, all_labels, epoch, "test")
    return running_loss, metrics


In [168]:
weights = torch.tensor([1,10] , dtype = torch.float32)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(model.parameters() , lr = 0.1 , momentum = 0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer , gamma = 0.95)

In [169]:
train_loader = DataLoader(train_dataset , batch_size=40 , shuffle=True)
test_loader = DataLoader(test_dataset , batch_size = 40 , shuffle=True)



In [170]:

for epoch in range(1, 11): 
    train_loss, train_metrics = train(epoch, model, train_loader, loss_fn, optimizer, scheduler)
    test_loss, test_metrics = test(epoch, model, test_loader, loss_fn)
    print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
    print(f"Train Metrics: {train_metrics}")
    print(f"Test Metrics: {test_metrics}")


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.38it/s]



Train Epoch 1 Metrics:
Confusion Matrix:
[[278  13]
 [  7   2]]
F1 Score: 0.9413194444444444
Accuracy: 0.9333333333333333
Precision: 0.9501754385964911
Recall: 0.9333333333333333


  _warn_prf(average, modifier, msg_start, len(result))



Test Epoch 1 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 1, Train Loss: 10.7168, Test Loss: 1.9455
Train Metrics: {'confusion_matrix': array([[278,  13],
       [  7,   2]]), 'f1_score': 0.9413194444444444, 'accuracy': 0.9333333333333333, 'precision': 0.9501754385964911, 'recall': 0.9333333333333333}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.25it/s]



Train Epoch 2 Metrics:
Confusion Matrix:
[[240  51]
 [  8   1]]
F1 Score: 0.8648054989506979
Accuracy: 0.8033333333333333
Precision: 0.9392866004962779
Recall: 0.8033333333333333


  _warn_prf(average, modifier, msg_start, len(result))



Test Epoch 2 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 2, Train Loss: nan, Test Loss: nan
Train Metrics: {'confusion_matrix': array([[240,  51],
       [  8,   1]]), 'f1_score': 0.8648054989506979, 'accuracy': 0.8033333333333333, 'precision': 0.9392866004962779, 'recall': 0.8033333333333333}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.51it/s]
  _warn_prf(average, modifier, msg_start, len(result))



Train Epoch 3 Metrics:
Confusion Matrix:
[[291   0]
 [  9   0]]
F1 Score: 0.9552284263959391
Accuracy: 0.97
Precision: 0.9409
Recall: 0.97


  _warn_prf(average, modifier, msg_start, len(result))



Test Epoch 3 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 3, Train Loss: nan, Test Loss: nan
Train Metrics: {'confusion_matrix': array([[291,   0],
       [  9,   0]]), 'f1_score': 0.9552284263959391, 'accuracy': 0.97, 'precision': 0.9409, 'recall': 0.97}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.54it/s]
  _warn_prf(average, modifier, msg_start, len(result))



Train Epoch 4 Metrics:
Confusion Matrix:
[[291   0]
 [  9   0]]
F1 Score: 0.9552284263959391
Accuracy: 0.97
Precision: 0.9409
Recall: 0.97


  _warn_prf(average, modifier, msg_start, len(result))



Test Epoch 4 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 4, Train Loss: nan, Test Loss: nan
Train Metrics: {'confusion_matrix': array([[291,   0],
       [  9,   0]]), 'f1_score': 0.9552284263959391, 'accuracy': 0.97, 'precision': 0.9409, 'recall': 0.97}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.41it/s]
  _warn_prf(average, modifier, msg_start, len(result))



Train Epoch 5 Metrics:
Confusion Matrix:
[[291   0]
 [  9   0]]
F1 Score: 0.9552284263959391
Accuracy: 0.97
Precision: 0.9409
Recall: 0.97


  _warn_prf(average, modifier, msg_start, len(result))



Test Epoch 5 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 5, Train Loss: nan, Test Loss: nan
Train Metrics: {'confusion_matrix': array([[291,   0],
       [  9,   0]]), 'f1_score': 0.9552284263959391, 'accuracy': 0.97, 'precision': 0.9409, 'recall': 0.97}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.56it/s]
  _warn_prf(average, modifier, msg_start, len(result))



Train Epoch 6 Metrics:
Confusion Matrix:
[[291   0]
 [  9   0]]
F1 Score: 0.9552284263959391
Accuracy: 0.97
Precision: 0.9409
Recall: 0.97


  _warn_prf(average, modifier, msg_start, len(result))



Test Epoch 6 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 6, Train Loss: nan, Test Loss: nan
Train Metrics: {'confusion_matrix': array([[291,   0],
       [  9,   0]]), 'f1_score': 0.9552284263959391, 'accuracy': 0.97, 'precision': 0.9409, 'recall': 0.97}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.45it/s]
  _warn_prf(average, modifier, msg_start, len(result))



Train Epoch 7 Metrics:
Confusion Matrix:
[[291   0]
 [  9   0]]
F1 Score: 0.9552284263959391
Accuracy: 0.97
Precision: 0.9409
Recall: 0.97


  _warn_prf(average, modifier, msg_start, len(result))



Test Epoch 7 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 7, Train Loss: nan, Test Loss: nan
Train Metrics: {'confusion_matrix': array([[291,   0],
       [  9,   0]]), 'f1_score': 0.9552284263959391, 'accuracy': 0.97, 'precision': 0.9409, 'recall': 0.97}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.45it/s]
  _warn_prf(average, modifier, msg_start, len(result))



Train Epoch 8 Metrics:
Confusion Matrix:
[[291   0]
 [  9   0]]
F1 Score: 0.9552284263959391
Accuracy: 0.97
Precision: 0.9409
Recall: 0.97


  _warn_prf(average, modifier, msg_start, len(result))



Test Epoch 8 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 8, Train Loss: nan, Test Loss: nan
Train Metrics: {'confusion_matrix': array([[291,   0],
       [  9,   0]]), 'f1_score': 0.9552284263959391, 'accuracy': 0.97, 'precision': 0.9409, 'recall': 0.97}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.52it/s]
  _warn_prf(average, modifier, msg_start, len(result))



Train Epoch 9 Metrics:
Confusion Matrix:
[[291   0]
 [  9   0]]
F1 Score: 0.9552284263959391
Accuracy: 0.97
Precision: 0.9409
Recall: 0.97


  _warn_prf(average, modifier, msg_start, len(result))



Test Epoch 9 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 9, Train Loss: nan, Test Loss: nan
Train Metrics: {'confusion_matrix': array([[291,   0],
       [  9,   0]]), 'f1_score': 0.9552284263959391, 'accuracy': 0.97, 'precision': 0.9409, 'recall': 0.97}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.55it/s]
  _warn_prf(average, modifier, msg_start, len(result))



Train Epoch 10 Metrics:
Confusion Matrix:
[[291   0]
 [  9   0]]
F1 Score: 0.9552284263959391
Accuracy: 0.97
Precision: 0.9409
Recall: 0.97

Test Epoch 10 Metrics:
Confusion Matrix:
[[96  0]
 [ 3  0]]
F1 Score: 0.9547785547785548
Accuracy: 0.9696969696969697
Precision: 0.9403122130394858
Recall: 0.9696969696969697
Epoch 10, Train Loss: nan, Test Loss: nan
Train Metrics: {'confusion_matrix': array([[291,   0],
       [  9,   0]]), 'f1_score': 0.9552284263959391, 'accuracy': 0.97, 'precision': 0.9409, 'recall': 0.97}
Test Metrics: {'confusion_matrix': array([[96,  0],
       [ 3,  0]]), 'f1_score': 0.9547785547785548, 'accuracy': 0.9696969696969697, 'precision': 0.9403122130394858, 'recall': 0.9696969696969697}


  _warn_prf(average, modifier, msg_start, len(result))


In [151]:
# import torch.optim as optim

# loss_fn = torch.nn.CrossEntropyLoss()

# optimizer = optim.Adam(model.parameters(), lr=0.001)


# def train(epoch, model, train_loader, loss_fn):
#     model.train()
#     running_loss = 0.0
#     for step, batch in enumerate(train_loader):
#         optimizer.zero_grad()

#         # Forward pass
#         pred = model(batch.x.float(), batch.edge_index, batch.batch)
        
#         # Calculate the loss
#         loss = loss_fn(pred, batch.y)

#         # Backward pass
#         loss.backward()

#         # Update parameters
#         optimizer.step()

#         running_loss += loss.item()

#     # Print or log the training loss
#     print(f"Epoch {epoch}, Loss: {running_loss / len(train_loader)}")

# # Example usage:
# num_epochs = 10
# for epoch in range(1, num_epochs + 1):
#     train(epoch, model, train_dataset, loss_fn)
