<a href="https://colab.research.google.com/github/james-yu2005/Affi-NN-ity/blob/main/Base_Model_Affi_NN_ity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
 !pip install PyTDC rdkit-pypi torch-geometric pandas tqdm

# PyTDC: Loading clean DTI datasets (like DAVIS)
# rdkit-pypi: Parsing drug SMILES and converting to molecule graphs
# torch-geometric: Building GNN architecture for drug inputs
# tqdm: Shows real-time progress for preprocessing steps like SMILES parsing)



In [None]:
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import Draw
from tdc.multi_pred import DTI
import pandas as pd

In [None]:
# Load the DAVIS drug-target interaction dataset through PyTDC
# (the saved output below shows it being downloaded on this run).
data = DTI(name = 'DAVIS')

Downloading...
100%|██████████| 21.4M/21.4M [00:01<00:00, 20.8MiB/s]
Loading...
Done!


In [None]:
# Print the number of unique drugs, unique targets and drug-target pairs.
data.print_stats()

--- Dataset Statistics ---
68 unique drugs.
379 unique targets.
25772 drug-target pairs.
--------------------------


In [None]:
# To see what the original dataset looks like:
# one row per drug-target pair, with the drug SMILES in "Drug", the protein
# sequence in "Target" and the raw (not yet log-scaled) affinity in "Y".
df_DAVIS = data.get_data()
df_DAVIS.head()

Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y
0,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,AAK1,MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQV...,43.0
1,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ABL1p,PFWKILNPLLERGTYYYFMGQQPGKVLGDQRRPSLPALHFIKGAGK...,10000.0
2,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ABL2,MVLGTVLLPPNSYGRDQDTSLCCLCTEASESALPDLTDHFASCVED...,10000.0
3,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ACVR1,MVDGVMILPVLIMIALPSPSMEDEKPKVNPKLYMCVCEGLSCGNED...,10000.0
4,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ACVR1B,MAESAGASSFFPLVVLLLAGSGGSGPRGVQALLCACTSCLQANYTC...,10000.0


In [None]:
# Kd values have a wide range and are skewed, so they are hard for a neural network to learn from.
# PyTDC transforms the Y column from Kd to pKd, i.e. a log-scaled value that is better suited for regression.
# In the saved outputs, Y = 10000.0 becomes ~5.0 and Y = 43.0 becomes ~7.37, which is consistent with
# pKd = -log10(Kd in molar) for Kd given in nM (NOTE(review): confirm against the PyTDC documentation).
# The conversion is applied inside `data`, so the dataframe is fetched again afterwards.

data.convert_to_log(form = 'binding')
df_DAVIS = data.get_data()

To log space...


In [None]:
# Remove exact duplicate rows.
# The name is re-bound instead of using inplace=True so the frame returned by
# PyTDC is not mutated and re-running this cell always gives the same result.
df_DAVIS = df_DAVIS.drop_duplicates()
df_DAVIS.head(5)

Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y
0,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,AAK1,MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQV...,7.365523
1,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ABL1p,PFWKILNPLLERGTYYYFMGQQPGKVLGDQRRPSLPALHFIKGAGK...,4.999996
2,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ABL2,MVLGTVLLPPNSYGRDQDTSLCCLCTEASESALPDLTDHFASCVED...,4.999996
3,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ACVR1,MVDGVMILPVLIMIALPSPSMEDEKPKVNPKLYMCVCEGLSCGNED...,4.999996
4,11314340,Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12,ACVR1B,MAESAGASSFFPLVVLLLAGSGGSGPRGVQALLCACTSCLQANYTC...,4.999996


In [None]:
class MoleculeDataset(Dataset):
    """In-memory PyTorch Geometric dataset holding one split of a drug-target dataframe.

    Every row of the dataframe must provide a SMILES string in ``"Drug"``, a
    protein sequence in ``"Target"`` and a regression label in ``"Y"``. Each
    row becomes one ``Data`` graph (atoms = nodes, bonds = bidirectional
    edges) with the flattened one-hot protein encoding attached as
    ``target_features`` and the label stored in ``y``.

    Args:
        root: Root directory handed to the PyG ``Dataset`` base class.
        dataframe: Full dataframe; it is split internally and only the rows
            of ``split`` are kept.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        test_fraction: Fraction of all rows assigned to the test split.
        val_fraction: Fraction of all rows assigned to the validation split
            (``0`` gives an empty validation split).
        transform, pre_transform: Forwarded to the PyG base class.
        max_seq_length: Max number of amino acids to one-hot encode per
            protein (default is 1000). Longer sequences are truncated,
            shorter ones are zero-padded.
        random_state: Seed for the splits, so that the three split objects
            built from the same dataframe are mutually consistent.
    """

    # The 20 standard amino acids; the position in this string is the one-hot column
    _AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
    _AMINO_ACID_TO_COLUMN = {aa: col for col, aa in enumerate(_AMINO_ACIDS)}

    # Number of values produced per atom / per bond by the feature extractors below
    _NUM_NODE_FEATURES = 9
    _NUM_EDGE_FEATURES = 2

    _VALID_SPLITS = ('train', 'val', 'test')

    def __init__(self, root, dataframe, split='train', test_fraction=0.2, val_fraction=0.1,
                 transform=None, pre_transform=None, max_seq_length=1000, random_state=42):
        self.dataframe = dataframe.reset_index()

        # Following defines how the data will be split for train, test and val
        self.split = split
        self.test_fraction = test_fraction
        self.val_fraction = val_fraction
        self.random_state = random_state

        # Holds the fully processed dataset: one PyG Data object per drug-target pair
        self.molecule_data = []
        self.max_seq_length = max_seq_length

        # Must happen before the base-class constructor runs, because that
        # constructor may already call process(), which reads self.dataframe.
        self._split_data()

        super(MoleculeDataset, self).__init__(root, transform, pre_transform)

    def _split_data(self):
        """Replace ``self.dataframe`` by the rows belonging to ``self.split``.

        Raises:
            ValueError: If ``self.split`` is not 'train', 'val' or 'test'.
        """
        from sklearn.model_selection import train_test_split

        # Fail fast on a bad split name before doing any work
        if self.split not in self._VALID_SPLITS:
            raise ValueError(f"Split '{self.split}' not recognized. Use 'train', 'val', or 'test'.")

        # First split off the test set
        train_val_df, test_df = train_test_split(
            self.dataframe,
            test_size=self.test_fraction,
            random_state=self.random_state
        )

        # Then split the remainder into train and validation. val_fraction is
        # expressed relative to the full dataset, so it is rescaled to the
        # size of the remaining (non-test) part.
        if self.val_fraction > 0:
            train_df, val_df = train_test_split(
                train_val_df,
                test_size=self.val_fraction / (1 - self.test_fraction),
                random_state=self.random_state
            )
        else:
            train_df = train_val_df
            val_df = train_val_df.iloc[0:0]  # Empty DataFrame with same columns

        self.dataframe = {'train': train_df, 'val': val_df, 'test': test_df}[self.split]

    def process(self):
        """Convert every row of the split into a PyG ``Data`` object kept in memory.

        Rows whose SMILES string cannot be parsed by RDKit are skipped. The
        sample list is rebuilt from scratch on every call, so calling
        process() again (in addition to the call the PyG base class may make
        from the constructor) does not duplicate samples.
        """
        samples = []

        # Going row-by-row through the dataset, showing a progress bar with tqdm
        for _, row in tqdm(self.dataframe.iterrows(), total=self.dataframe.shape[0]):
            smiles = row["Drug"]        # SMILES string of the drug
            target_seq = row["Target"]  # Amino-acid sequence of the protein

            # Using RDKit to convert the SMILES string into a molecule object.
            # If it fails (invalid SMILES), skip the row.
            mol_obj = Chem.MolFromSmiles(smiles)
            if mol_obj is None:
                continue

            # Data is a class that represents a single drug molecule graph
            data = Data(
                # [num_atoms, 9]: one row of features per atom
                x=self._get_node_features(mol_obj),
                # [2, 2 * num_bonds]: NOT an adjacency matrix. Each column is one
                # directed edge (first row = source atom, second row = destination atom)
                edge_index=self._get_adjacency_info(mol_obj),
                # [2 * num_bonds, 2]: bond type and ring membership, stored once per direction
                edge_attr=self._get_edge_features(mol_obj),
                y=torch.tensor([row["Y"]], dtype=torch.float)
            )

            # Attach the one-hot-encoded protein to its drug graph
            data.target_features = self._sequence_to_one_hot(target_seq)

            # Store in list instead of saving to disk
            samples.append(data)

        self.molecule_data = samples

    def _sequence_to_one_hot(self, sequence):
        """One-hot encode a protein sequence into a tensor of shape [1, max_seq_length * 20].

        Each of the first ``max_seq_length`` residues gets one row with a 1 in
        the column of its amino acid; characters outside the 20 standard
        amino acids leave their row all-zero, as do padding positions.

        The [max_seq_length, 20] matrix is flattened and given a leading
        dimension of size 1. Without it, PyG would concatenate a batch of 32
        flat vectors of 20,000 values into one vector of 640,000 values; with
        it, the batch becomes a [32, 20,000] matrix, which is what the Linear
        layer expects.
        """
        encoding = np.zeros((self.max_seq_length, len(self._AMINO_ACIDS)), dtype=np.float32)
        for position, amino_acid in enumerate(sequence[:self.max_seq_length]):
            column = self._AMINO_ACID_TO_COLUMN.get(amino_acid)
            if column is not None:
                encoding[position, column] = 1
        return torch.tensor(encoding.flatten(), dtype=torch.float).unsqueeze(0)

    def _get_node_features(self, mol):
        """Return a float tensor of shape [num_atoms, 9] with per-atom features."""
        all_node_feats = []
        for atom in mol.GetAtoms():
            all_node_feats.append([
                atom.GetAtomicNum(),
                atom.GetDegree(),
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),  # RDKit enum -> plain int
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetNumRadicalElectrons(),
                int(atom.IsInRing()),
                int(atom.GetChiralTag())       # RDKit enum -> plain int
            ])
        node_feats = torch.tensor(np.array(all_node_feats), dtype=torch.float)
        # reshape keeps the result 2-D even for a molecule without atoms
        return node_feats.reshape(-1, self._NUM_NODE_FEATURES)

    def _get_edge_features(self, mol):
        """Return a float tensor of shape [2 * num_bonds, 2] (bond type, in-ring flag)."""
        all_edge_feats = []
        for bond in mol.GetBonds():
            edge_feats = [
                bond.GetBondTypeAsDouble(),
                float(bond.IsInRing())
            ]
            all_edge_feats += [edge_feats, edge_feats]  # Bidirectional edges
        edge_attr = torch.tensor(np.array(all_edge_feats), dtype=torch.float)
        # reshape keeps the result 2-D even for a molecule without bonds
        return edge_attr.reshape(-1, self._NUM_EDGE_FEATURES)

    def _get_adjacency_info(self, mol):
        """Return a long tensor of shape [2, 2 * num_bonds] listing each bond in both directions."""
        edge_indices = []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]  # Bidirectional edges
        pairs = torch.tensor(edge_indices, dtype=torch.long).reshape(-1, 2)
        # For a molecule without bonds this yields shape [2, 0] instead of [0]
        return pairs.t().contiguous()

    def len(self):
        """Number of processed samples (rows with an unparsable SMILES are not counted)."""
        return len(self.molecule_data)

    def get(self, idx):
        """Return the processed drug-protein pair (a PyG Data object) at index idx."""
        return self.molecule_data[idx]

    # Return processed file names. Since everything is kept in memory, no
    # processed files exist and an empty list is returned.
    # NOTE(review): the PyG base class declares this as a property; it is kept as
    # a plain method here so existing callers keep working.
    def processed_file_names(self):
        return []

In [None]:
# Create train, validation, and test datasets. Each row is converted into a PyTorch Geometric Data object.
# NOTE: constructing a MoleculeDataset already runs process() through the PyG base class (this is why every
# progress bar in the saved output appears twice). process() is therefore only called explicitly when the
# constructor left the dataset empty; calling it unconditionally would do the work twice and store every
# sample a second time.
_split_datasets = {}
for _split_name in ('train', 'val', 'test'):
    _split_dataset = MoleculeDataset(root='.', dataframe=df_DAVIS, split=_split_name)
    if len(_split_dataset.molecule_data) == 0:
        _split_dataset.process()
    _split_datasets[_split_name] = _split_dataset

train_dataset = _split_datasets['train']
val_dataset = _split_datasets['val']
test_dataset = _split_datasets['test']

# Create DataLoaders
from torch_geometric.loader import DataLoader

# Only the training data is shuffled; validation/test order stays fixed so results are comparable
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

100%|██████████| 18039/18039 [00:38<00:00, 471.12it/s]
100%|██████████| 18039/18039 [00:29<00:00, 621.53it/s]
100%|██████████| 2578/2578 [00:04<00:00, 611.75it/s]
100%|██████████| 2578/2578 [00:04<00:00, 591.81it/s]
100%|██████████| 5155/5155 [00:09<00:00, 572.64it/s]
100%|██████████| 5155/5155 [00:09<00:00, 570.85it/s]


In [None]:
# Each protein is one-hot encoded as a [1000, 20] matrix and flattened to 20000 values
seq_dim = 1000 * 20

class GINDrugTargetModel(torch.nn.Module):
    """Two-branch regressor for drug-target pairs.

    The drug graph goes through a node-embedding MLP and two GIN convolutions
    and is sum-pooled into one vector per graph; the flattened one-hot
    protein goes through an MLP. Both vectors are concatenated and passed to
    the predictor head.

    Only names reachable from the notebook's top import cell (``torch`` and
    ``torch_geometric``) are used, so the class does not depend on imports
    made in later cells.

    Args:
        node_feat_dim: Number of features per atom.
        seq_dim: Length of the flattened protein encoding.
        hidden_dim: Width of all hidden layers.
        output_dim: Number of predicted values per pair.
    """

    def __init__(self, node_feat_dim=9, seq_dim=20000, hidden_dim=128, output_dim=1):
        super(GINDrugTargetModel, self).__init__()

        # Node feature embedding
        self.node_embedding = torch.nn.Sequential(
            torch.nn.Linear(node_feat_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim)
        )

        # GIN convolution layers
        # Atoms exchanging information with their neighbors happens here
        self.conv1 = torch_geometric.nn.GINConv(self._gin_mlp(hidden_dim))
        self.bn1 = torch.nn.BatchNorm1d(hidden_dim)

        self.conv2 = torch_geometric.nn.GINConv(self._gin_mlp(hidden_dim))
        self.bn2 = torch.nn.BatchNorm1d(hidden_dim)

        # Protein branch
        self.seq_embedding = torch.nn.Sequential(
            torch.nn.Linear(seq_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )

        # Head operating on the concatenated [drug, protein] embedding
        self.predictor = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(hidden_dim, hidden_dim // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim // 2, output_dim)
        )

    @staticmethod
    def _gin_mlp(hidden_dim):
        """Build the Linear-ReLU-Linear network wrapped by one GINConv layer."""
        return torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim)
        )

    # Forward Pass
    def forward(self, batch):
        """Predict one value per graph in ``batch``.

        ``batch`` must provide ``x``, ``edge_index``, ``batch`` (node-to-graph
        assignment) and ``target_features`` of shape [batch_size, seq_dim].
        Returns a tensor of shape [batch_size, output_dim].
        """
        relu = torch.nn.functional.relu

        x = self.node_embedding(batch.x)
        x = self.bn1(relu(self.conv1(x, batch.edge_index)))
        x = self.bn2(relu(self.conv2(x, batch.edge_index)))

        # Pool exactly once: sums the atom embeddings of each graph -> [batch_size, hidden_dim]
        x = torch_geometric.nn.global_add_pool(x, batch.batch)

        seq = batch.target_features        # shape: [batch_size, seq_dim]
        seq_emb = self.seq_embedding(seq)  # shape: [batch_size, hidden_dim]

        combined = torch.cat([x, seq_emb], dim=1)  # shape: [batch_size, 2*hidden_dim]
        return self.predictor(combined)


In [None]:
import torch
import torch.nn.functional as F

def train_model(model, train_loader, val_loader, num_epochs=100, lr=0.001, device='cuda'):
    """Train ``model`` with Adam on an MSE objective and report per-epoch losses.

    Args:
        model: Module called as ``model(batch)`` returning one prediction per graph.
        train_loader: Iterable of training batches exposing ``to(device)``,
            ``y`` and ``num_graphs``.
        val_loader: Iterable of validation batches (may be empty).
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device the model and the batches are moved to.

    Returns:
        The trained model (the same object, moved to ``device``).

    The reported losses are averages over the samples actually seen in the
    epoch; a loader that yields no samples is reported as ``nan`` instead of
    raising ZeroDivisionError.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        train_loss_sum, train_samples = 0.0, 0

        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            # Flatten both to 1-D so prediction and target shapes always match,
            # including for a final batch that holds a single sample
            preds = model(batch).reshape(-1)  # Model predicts pKd scores
            targets = batch.y.reshape(-1)     # True pKd values from the data

            loss = F.mse_loss(preds, targets)  # Compute mean squared error loss
            loss.backward()                    # Backpropagate the error
            optimizer.step()                   # Update weights

            # loss is the batch mean, so weight it by the number of samples in the batch
            train_loss_sum += loss.item() * batch.num_graphs
            train_samples += batch.num_graphs

        avg_train_loss = train_loss_sum / train_samples if train_samples else float('nan')

        # Validation Phase
        model.eval()
        val_loss_sum, val_samples = 0.0, 0

        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)

                preds = model(batch).reshape(-1)
                targets = batch.y.reshape(-1)

                loss = F.mse_loss(preds, targets)
                val_loss_sum += loss.item() * batch.num_graphs
                val_samples += batch.num_graphs

        avg_val_loss = val_loss_sum / val_samples if val_samples else float('nan')

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    return model

In [None]:
# Sanity check: look at the tensor shapes of the first training batch only.
for batch in train_loader:
    labels = batch.y  # true pKd values of the batch (actual, not predicted)
    report_lines = (
        # 1st number = total number of nodes (atoms) across all graphs in the batch
        # 2nd number = number of features per node (like atomic number, valency, etc.)
        f"Node features shape: {batch.x.shape}",
        # 1st number = total number of directed edges (each bond counted twice) across the batch
        # 2nd number = number of features per edge (e.g., bond type, in ring)
        f"Edge features shape: {batch.edge_attr.shape}",
        f"Target features shape: {labels.shape}",
        f"Target features type: {type(labels)}",
        # pKd of the first sample in the training batch
        f"Target features element: {labels[0]}",
    )
    print(*report_lines, sep="\n")
    break

Node features shape: torch.Size([1007, 9])
Edge features shape: torch.Size([2208, 2])
Target features shape: torch.Size([32])
Target features type: <class 'torch.Tensor'>
Target features element: 4.999995708465576


In [None]:
def test_model(model, test_loader, device='cuda'):
    model.eval()
    preds = []
    trues = []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            pred = model(batch)
            preds.append(pred.view(-1).cpu())
            trues.append(batch.y.view(-1).cpu())

    preds = torch.cat(preds, dim=0)
    trues = torch.cat(trues, dim=0)

    mse = F.mse_loss(preds, trues)
    print(f'Test MSE: {mse.item():.4f}')

    # Print few predictions vs actual value
    print("\nSample Predictions vs True Binding Affinities:")
    for i in range(min(100, len(preds))):  # Show 10 samples (or fewer if smaller dataset)
        print(f"True: {trues[i].item():.4f}, Predicted: {preds[i].item():.4f}")

    return preds, trues, mse.item()

In [None]:
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv, global_add_pool
# Length of the flattened protein encoding: 1000 residues x 20 amino acids
seq_dim = 1000 * 20

# Seed PyTorch so weight initialisation and batch shuffling are repeatable between runs
torch.manual_seed(42)
model = GINDrugTargetModel(node_feat_dim=9, seq_dim=seq_dim)

# Train the model on the GPU when one is available, otherwise on the CPU
# (train_model defaults to 'cuda', which fails on a CPU-only runtime)
train_device = 'cuda' if torch.cuda.is_available() else 'cpu'
trained_model = train_model(model, train_loader, val_loader, num_epochs=40, device=train_device)

Epoch 1/40 | Train Loss: 1.1017 | Val Loss: 37.5491
Epoch 2/40 | Train Loss: 0.7322 | Val Loss: 0.7588
Epoch 3/40 | Train Loss: 0.6322 | Val Loss: 0.8940
Epoch 4/40 | Train Loss: 0.5836 | Val Loss: 0.9323
Epoch 5/40 | Train Loss: 0.5471 | Val Loss: 0.6765
Epoch 6/40 | Train Loss: 0.5091 | Val Loss: 1.1971
Epoch 7/40 | Train Loss: 0.4829 | Val Loss: 1.2158
Epoch 8/40 | Train Loss: 0.4742 | Val Loss: 2.7243
Epoch 9/40 | Train Loss: 0.4946 | Val Loss: 0.7800
Epoch 10/40 | Train Loss: 0.4631 | Val Loss: 0.9248
Epoch 11/40 | Train Loss: 0.4317 | Val Loss: 1.5384
Epoch 12/40 | Train Loss: 0.4331 | Val Loss: 0.7947
Epoch 13/40 | Train Loss: 0.4190 | Val Loss: 1.7050
Epoch 14/40 | Train Loss: 0.4036 | Val Loss: 2.9514
Epoch 15/40 | Train Loss: 0.3977 | Val Loss: 1.3010
Epoch 16/40 | Train Loss: 0.3766 | Val Loss: 1.3051
Epoch 17/40 | Train Loss: 0.3694 | Val Loss: 0.7695
Epoch 18/40 | Train Loss: 0.3694 | Val Loss: 1.4793
Epoch 19/40 | Train Loss: 0.3631 | Val Loss: 1.1396
Epoch 20/40 | Train 

In [None]:
# Test the model on the device its parameters already live on, instead of
# relying on test_model's 'cuda' default
eval_device = next(trained_model.parameters()).device
test_preds, test_targets, test_mse = test_model(trained_model, test_loader, device=eval_device)

Test MSE: 0.8480

Sample Predictions vs True Binding Affinities:
True: 5.0000, Predicted: 4.4211
True: 5.0000, Predicted: 4.4903
True: 5.0000, Predicted: 4.6736
True: 5.0000, Predicted: 4.2273
True: 5.0000, Predicted: 4.5009
True: 5.0000, Predicted: 4.2712
True: 5.0000, Predicted: 4.0705
True: 5.6989, Predicted: 5.5111
True: 5.8538, Predicted: 4.5572
True: 5.0000, Predicted: 4.3157
True: 5.0000, Predicted: 4.3462
True: 5.0000, Predicted: 4.3199
True: 7.7670, Predicted: 6.5214
True: 5.0000, Predicted: 4.4849
True: 5.0000, Predicted: 4.3661
True: 5.0000, Predicted: 4.4984
True: 5.0000, Predicted: 4.3369
True: 5.0000, Predicted: 4.1348
True: 5.7695, Predicted: 5.0628
True: 5.0000, Predicted: 4.4543
True: 7.1186, Predicted: 5.3831
True: 5.6383, Predicted: 5.0678
True: 5.0000, Predicted: 4.4818
True: 5.0000, Predicted: 4.6684
True: 7.3862, Predicted: 5.7712
True: 5.0000, Predicted: 4.3926
True: 5.0000, Predicted: 4.4756
True: 5.0000, Predicted: 4.4631
True: 5.0000, Predicted: 4.7591
True: 5