In [None]:
# !pip install biopython
!pip install --upgrade --no-cache-dir biopython
!pip install rdkit-pypi
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.2.0+cu118.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.2.0+cu118.html
!pip install -q torch-geometric
!pip install fair-esm


In [None]:
import numpy as np
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from tqdm import tqdm
import torch
from torch_geometric.data import Data
from torch_geometric.data import Batch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
import pickle
from torch.utils.data import DataLoader, Subset, random_split
import esm
from joblib import Parallel, delayed
import pickle




In [None]:
DATA_DIR = "/kaggle/input/drug-virus-features"

pIC50 = np.load(f"{DATA_DIR}/pIC50.npy")
# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted source.
with open(f"{DATA_DIR}/drug_graphs.pkl", "rb") as f:
    drug_graphs = pickle.load(f)
with open(f"{DATA_DIR}/protein_graphs.pkl", "rb") as f:
    protein_graphs = pickle.load(f)

# Sample i pairs drug_graphs[i] with protein_graphs[i] and label pIC50[i],
# so the three collections must be the same length.
assert len(pIC50) == len(drug_graphs) == len(protein_graphs), (
    f"misaligned inputs: {len(pIC50)} labels, "
    f"{len(drug_graphs)} drug graphs, {len(protein_graphs)} protein graphs"
)



In [None]:
def drug_graph_to_data(drug_graph):
    """Convert a precomputed drug graph tuple into a PyG ``Data`` object.

    Args:
        drug_graph: 4-tuple ``(mol_size, nodes, edges, edges_type)``.
            ``nodes`` becomes the float node-feature matrix. ``edges`` is
            expected as ``[num_edges, 2]`` node-index pairs. ``edges_type``
            holds one scalar per edge. ``mol_size`` is not used.

    Returns:
        ``Data`` with ``x`` [num_nodes, node_features] (float32),
        ``edge_index`` [2, num_edges] (int64) and
        ``edge_attr`` [num_edges, 1] (float32).
    """
    _mol_size, nodes, edges, edges_type = drug_graph
    x = torch.tensor(nodes, dtype=torch.float)

    # reshape(-1, 2) keeps edge_index at [2, 0] when `edges` is empty. A bare
    # .t() on the resulting empty 1-D tensor yields shape [0], which cannot be
    # concatenated with other graphs' [2, E] edge indices when batching.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    # NOTE(review): these per-edge values are presumably bond-type codes;
    # confirm their encoding, since they are later consumed as GCN edge weights.
    edge_attr = torch.tensor(edges_type, dtype=torch.float).reshape(-1, 1)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [None]:
def protein_graph_to_data(protein_graph):
    """Convert a precomputed protein graph tuple into a PyG ``Data`` object.

    Args:
        protein_graph: 3-tuple ``(node_features, edge_index, edge_attr)``.
            ``edge_index`` is expected as ``[num_edges, 2]`` node-index pairs
            and ``edge_attr`` as one scalar per edge.

    Returns:
        ``Data`` with float32 ``x``, int64 ``edge_index`` [2, num_edges] and
        float32 ``edge_attr`` [num_edges, 1].
    """
    node_features, edge_pairs, edge_values = protein_graph

    # as_tensor accepts tensors, numpy arrays and lists alike, and returns the
    # input unchanged when it is already a float32 tensor. Forcing float32
    # keeps x compatible with the float32 GCNConv weights; a numpy or float64
    # x would fail when batched or convolved.
    x = torch.as_tensor(node_features, dtype=torch.float)

    # reshape(-1, 2) keeps the [2, 0] layout PyG expects even with no edges.
    # A bare .t() on an empty 1-D tensor gives shape [0].
    edge_index = torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    edge_attr = torch.tensor(edge_values, dtype=torch.float).reshape(-1, 1)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [None]:
class DrugProteinDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(protein_graph, drug_graph, pIC50)`` triples.

    The three inputs are parallel sequences. Item ``idx`` pairs
    ``protein_graphs[idx]`` with ``drug_graphs[idx]`` and label
    ``pIC50_values[idx]``. Raw graph tuples are converted to PyG ``Data``
    objects on every access; nothing is cached.

    Raises:
        ValueError: if the three inputs do not have the same length.
    """

    def __init__(self, protein_graphs, drug_graphs, pIC50_values):
        # Fail fast on misaligned inputs. __len__ follows the labels, so a
        # shorter graph list would only surface later as an IndexError inside
        # the DataLoader, and surplus graphs would be silently ignored.
        n_labels = len(pIC50_values)
        if len(protein_graphs) != n_labels or len(drug_graphs) != n_labels:
            raise ValueError(
                f"length mismatch: {len(protein_graphs)} protein graphs, "
                f"{len(drug_graphs)} drug graphs, {n_labels} pIC50 values"
            )
        self.protein_graphs = protein_graphs
        self.drug_graphs = drug_graphs
        self.pIC50_values = pIC50_values

    def __len__(self):
        return len(self.pIC50_values)

    def __getitem__(self, idx):
        protein_graph = protein_graph_to_data(self.protein_graphs[idx])
        drug_graph = drug_graph_to_data(self.drug_graphs[idx])
        # 0-d float32 label tensor; custom_collate stacks these into [batch_size].
        pIC50_value = torch.tensor(self.pIC50_values[idx], dtype=torch.float)
        return protein_graph, drug_graph, pIC50_value

def custom_collate(batch):
    """Collate ``(protein_graph, drug_graph, label)`` samples into one mini-batch.

    Returns:
        ``(protein_batch, drug_batch, labels)``. Each graph list is merged by
        ``Batch.from_data_list`` into a single batched graph whose ``batch``
        vector maps nodes back to their sample. Labels are stacked into a
        ``[batch_size]`` tensor.
    """
    proteins, drugs, targets = [], [], []
    for sample in batch:
        proteins.append(sample[0])
        drugs.append(sample[1])
        targets.append(sample[2])

    labels = torch.stack(targets)  # [batch_size]
    return Batch.from_data_list(proteins), Batch.from_data_list(drugs), labels


In [None]:
class DrugTargetGNN(nn.Module):
    """Two-tower GCN regressor predicting pIC50 for a (protein, drug) pair.

    Each tower applies three GCNConv layers (ReLU after each), mean-pools the
    node embeddings per graph, then applies two ReLU-activated Linear layers
    down to ``hidden_dim``. The drug and protein embeddings are concatenated
    and passed through an MLP with a single regression output.

    Args:
        node_feature_dim: width of the drug node features (``drug_graph.x``).
        protein_feature_dim: width of the protein node features (``protein_graph.x``).
        hidden_dim: width of each tower's final embedding.
    """

    def __init__(self, node_feature_dim=78, protein_feature_dim=20, hidden_dim=128):
        super().__init__()
        # Drug tower. Submodules keep their original names and creation order,
        # so state_dict keys and seeded weight initialisation are unchanged.
        self.drugconv1 = GCNConv(node_feature_dim, node_feature_dim)
        self.drugconv2 = GCNConv(node_feature_dim, node_feature_dim * 2)
        self.drugconv3 = GCNConv(node_feature_dim * 2, node_feature_dim * 4)
        self.druglinear1 = nn.Linear(node_feature_dim * 4, 1024)
        self.druglinear2 = nn.Linear(1024, hidden_dim)

        # Protein tower.
        self.proteinconv1 = GCNConv(protein_feature_dim, protein_feature_dim)
        self.proteinconv2 = GCNConv(protein_feature_dim, protein_feature_dim * 2)
        self.proteinconv3 = GCNConv(protein_feature_dim * 2, protein_feature_dim * 4)
        self.proteinlinear1 = nn.Linear(protein_feature_dim * 4, 1024)
        self.proteinlinear2 = nn.Linear(1024, hidden_dim)

        # Head on the concatenated [drug | protein] embedding.
        self.final_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1),  # regression output for pIC50
        )

    @staticmethod
    def _encode(graph, convs, linears):
        """Embed a batched graph: GCN stack, then per-graph mean pool, then dense stack."""
        # GCNConv documents edge_weight as a 1-D [num_edges] tensor. The graph
        # conversion helpers store edge_attr as [num_edges, 1], so flatten it
        # (the numerical result is the same).
        # NOTE(review): edge attribute values act as GCN edge weights here.
        # Confirm that their encoding is meant to scale messages and that no
        # edge carries a 0 weight.
        edge_weight = graph.edge_attr.reshape(-1)
        h = graph.x
        for conv in convs:
            h = F.relu(conv(h, graph.edge_index, edge_weight))
        h = global_mean_pool(h, graph.batch)  # [batch_size, last conv width]
        for linear in linears:
            h = F.relu(linear(h))
        return h  # [batch_size, hidden_dim]

    def forward(self, protein_graph, drug_graph):
        """Predict pIC50 for each (protein, drug) pair in the batch.

        Args:
            protein_graph: batched protein graphs (``x``, ``edge_index``,
                ``edge_attr``, ``batch``).
            drug_graph: batched drug graphs with the same attributes.

        Returns:
            Tensor of shape ``[batch_size]``, including when batch_size == 1.
        """
        drug_emb = self._encode(
            drug_graph,
            (self.drugconv1, self.drugconv2, self.drugconv3),
            (self.druglinear1, self.druglinear2),
        )
        protein_emb = self._encode(
            protein_graph,
            (self.proteinconv1, self.proteinconv2, self.proteinconv3),
            (self.proteinlinear1, self.proteinlinear2),
        )

        combined = torch.cat([drug_emb, protein_emb], dim=1)  # [batch_size, 2 * hidden_dim]
        # squeeze(-1) rather than squeeze(): a batch of one must stay shape [1]
        # instead of collapsing to a 0-d scalar. A 0-d output mis-shapes the
        # MSE loss and cannot be torch.cat-ed when collecting predictions.
        return self.final_mlp(combined).squeeze(-1)

In [None]:

dataset = DrugProteinDataset(protein_graphs, drug_graphs, pIC50)

# 80 / 10 / 10 split. The test split takes whatever int() truncation leaves
# over, so the three sizes always sum to len(dataset).
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - (train_size + val_size)

# A fixed-seed generator makes the split identical on every run.
split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_rng
)

# Shared loader settings; only the training loader reshuffles each epoch.
loader_opts = {"batch_size": 128, "collate_fn": custom_collate}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_opts)

In [None]:
SEED = 42
EPOCHS = 50
LEARNING_RATE = 1e-3

# Seed the global torch RNG. This fixes weight initialisation and the
# train_loader shuffle order (its sampler draws from this RNG). CUDA scatter
# kernels may still add small run-to-run nondeterminism.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DrugTargetGNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()  # regression on pIC50

model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    for protein_graph, drug_graph, values in train_loader:
        protein_graph = protein_graph.to(device)
        drug_graph = drug_graph.to(device)
        values = values.to(device)

        optimizer.zero_grad()
        # view(-1) guarantees [batch_size] predictions even for a 1-sample
        # batch, so MSELoss compares matching shapes instead of broadcasting.
        outputs = model(protein_graph, drug_graph).view(-1)
        loss = criterion(outputs, values)
        loss.backward()
        optimizer.step()
        # Weight each batch's mean loss by its size; the last batch may be smaller.
        total_loss += loss.item() * values.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

In [None]:

def evaluate(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and return regression metrics.

    Switches the model to eval mode (and leaves it there) and disables
    gradient tracking for the pass.

    Args:
        model: callable as ``model(protein_graph, drug_graph)``, returning one
            prediction per sample, with an ``eval()`` method.
        dataloader: iterable of ``(protein_graph, drug_graph, labels)`` batches,
            e.g. a DataLoader using ``custom_collate``.
        device: device the graph batches are moved to before the forward pass.

    Returns:
        dict with keys "MSE", "RMSE" and "Pearson" (Pearson correlation between
        labels and predictions). All are computed in float64.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for protein_graph, drug_graph, labels in dataloader:
            outputs = model(protein_graph.to(device), drug_graph.to(device))
            # reshape(-1): a model that squeezes its output returns a 0-d
            # tensor for a 1-sample batch, which torch.cat cannot concatenate.
            pred_chunks.append(outputs.reshape(-1).cpu())
            # Labels are only compared on the CPU, so skip the device round-trip.
            label_chunks.append(labels.reshape(-1).cpu())

    if not pred_chunks:
        raise ValueError("evaluate() received an empty dataloader")

    # Accumulate the metrics in float64 rather than the model's float32.
    preds = torch.cat(pred_chunks).numpy().astype(np.float64)
    labels = torch.cat(label_chunks).numpy().astype(np.float64)

    mse = float(np.mean((labels - preds) ** 2))
    rmse = mse ** 0.5
    pearson_corr, _ = pearsonr(labels, preds)

    return {
        "MSE": mse,
        "RMSE": rmse,
        "Pearson": pearson_corr
    }


In [None]:
# Score the trained model on the held-out validation split.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)  # nn.Module.to moves in place and returns the same module

val_metrics = evaluate(model, val_loader, device)
print("Validation Metrics:", val_metrics)