In [35]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score
import matplotlib.pyplot as plt
from tqdm import tqdm

# Seed both NumPy and PyTorch so repeated runs draw the same random numbers
np.random.seed(42)
torch.manual_seed(42)

# Read the Tox21 table from the current working directory
df = pd.read_csv('tox21.csv')

# Report how many rows carry no SMILES string
print(f"Number of molecules with missing SMILES: {df['smiles'].isnull().sum()}")

# Label columns, one per prediction task (order defines the model's output order)
target_columns = [
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

# Define atom feature extraction function
def get_atom_features(atom):
    """
    Builds the per-atom feature vector.

    Layout of the returned 1-D array (109 entries):
      [0:100]   one-hot of the atomic number (left all-zero when it is >= 100)
      [100]     formal charge shifted by +4
      [101:107] one-hot of the hybridization name over
                S, SP, SP2, SP3, SP3D, SP3D2 (all-zero for any other name)
      [107]     1 if the atom is aromatic, else 0
      [108]     total number of hydrogens reported by the atom
    """
    known_hybridizations = ('S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2')

    # Element identity
    element_slots = np.zeros(100)
    atomic_number = atom.GetAtomicNum()
    if atomic_number < 100:
        element_slots[atomic_number] = 1

    # Scalar descriptors, read in the same order as the atom API exposes them
    shifted_charge = atom.GetFormalCharge() + 4  # +4 offset is meant to keep typical charges non-negative
    hybridization_name = atom.GetHybridization().name
    aromatic_flag = int(atom.GetIsAromatic())
    hydrogen_count = atom.GetTotalNumHs()

    # Hybridization state
    hybridization_slots = np.zeros(len(known_hybridizations))
    if hybridization_name in known_hybridizations:
        hybridization_slots[known_hybridizations.index(hybridization_name)] = 1

    return np.concatenate([
        element_slots,
        np.array([shifted_charge]),
        hybridization_slots,
        np.array([aromatic_flag, hydrogen_count]),
    ])

# Define bond feature extraction function
def get_bond_features(bond):
    """
    Builds the per-bond feature vector.

    Layout of the returned 1-D array (6 entries):
      [0:4] one-hot of the bond type over SINGLE, DOUBLE, TRIPLE, AROMATIC
            (all-zero for any other bond type)
      [4]   1 if the bond is conjugated, else 0
      [5]   1 if the bond is in a ring, else 0
    """
    bond_kinds = Chem.rdchem.BondType
    ordered_kinds = (
        bond_kinds.SINGLE,
        bond_kinds.DOUBLE,
        bond_kinds.TRIPLE,
        bond_kinds.AROMATIC,
    )

    # Mark the slot of the first (and only) matching bond type
    kind_slots = np.zeros(4)
    observed_kind = bond.GetBondType()
    for slot, kind in enumerate(ordered_kinds):
        if observed_kind == kind:
            kind_slots[slot] = 1
            break

    # Boolean descriptors stored as 0/1
    flags = np.array([int(bond.GetIsConjugated()), int(bond.IsInRing())])

    return np.concatenate([kind_slots, flags])

# Convert SMILES to molecular graphs
def smiles_to_graph(smiles, edge_feature_dim=6):
    """
    Converts a SMILES string to a PyTorch Geometric Data object containing the molecular graph.

    Args:
        smiles: SMILES string to parse.
        edge_feature_dim: Width of ``edge_attr`` used when the molecule has no
            bonds. Must equal the length of the vector returned by
            ``get_bond_features`` (4 bond-type slots + 2 flags = 6).

    Returns:
        A ``Data`` object with
          x          -- (num_atoms, num_atom_features) float tensor
          edge_index -- (2, 2 * num_bonds) long tensor; every bond appears in
                        both directions
          edge_attr  -- (2 * num_bonds, edge_feature_dim) float tensor
        or ``None`` when the SMILES cannot be parsed, describes a molecule with
        no atoms, or any error occurs during featurization (the error is printed).
    """
    try:
        # Convert SMILES to RDKit molecule
        mol = Chem.MolFromSmiles(smiles)

        # A molecule without atoms would produce a graph with no nodes, which
        # cannot be pooled into a graph-level prediction, so treat it as invalid.
        if mol is None or mol.GetNumAtoms() == 0:
            return None

        # Get atom features
        atom_features_list = [get_atom_features(atom) for atom in mol.GetAtoms()]
        x = torch.tensor(np.array(atom_features_list), dtype=torch.float)

        # Get edge indices and edge features
        edge_indices = []
        edge_features_list = []

        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_features = get_bond_features(bond)

            # Store both directions so the graph is effectively undirected
            edge_indices.extend(([i, j], [j, i]))
            edge_features_list.extend((edge_features, edge_features))

        if edge_indices:
            edge_index = torch.tensor(np.array(edge_indices).T, dtype=torch.long)
            edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.float)
        else:
            # No bonds (e.g. a single atom or ion): build correctly shaped empty
            # tensors. Converting empty lists would yield 1-D tensors of shape
            # (0,), which breaks edge_attr.size(1) and shape-sensitive layers.
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, edge_feature_dim), dtype=torch.float)

        # Create PyTorch Geometric Data object
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    except Exception as e:
        # Deliberate best-effort: a molecule that cannot be featurized is skipped
        print(f"Error converting SMILES to graph: {e}")
        return None

# Create a PyTorch Dataset
class Tox21Dataset(Dataset):
    """
    Dataset of (molecular graph, target vector) pairs built from a dataframe.

    Every row's SMILES is converted with ``smiles_to_graph`` at construction
    time; rows for which that returns ``None`` are dropped. After construction
    ``self.graphs[k]`` and row ``k`` of ``self.dataframe`` describe the same
    molecule.
    """

    def __init__(self, dataframe, target_columns, smiles_column='smiles'):
        # Work on a 0..n-1 index so row labels double as positions
        self.dataframe = dataframe.reset_index(drop=True)
        self.target_columns = target_columns
        self.smiles_column = smiles_column

        self.graphs = []
        self.valid_indices = []

        rows_with_progress = tqdm(
            self.dataframe.iterrows(),
            total=len(self.dataframe),
            desc="Converting SMILES to graphs",
        )
        for row_idx, row in rows_with_progress:
            graph = smiles_to_graph(row[self.smiles_column])
            if graph is None:
                continue  # conversion failed; leave this row out
            self.graphs.append(graph)
            self.valid_indices.append(row_idx)

        # Keep only the rows that produced a graph, re-indexed to match self.graphs
        kept_rows = self.dataframe.iloc[self.valid_indices]
        self.dataframe = kept_rows.reset_index(drop=True)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        graph = self.graphs[idx]

        # Missing labels (NaN) become -1 so the loss can mask them out later
        raw_values = [self.dataframe.loc[idx, column] for column in self.target_columns]
        labels = [-1 if pd.isna(value) else value for value in raw_values]

        return graph, torch.tensor(labels, dtype=torch.float)

# Define Graph Convolutional Network (GCN) model
class GCNModel(nn.Module):
    """
    Three stacked GCNConv layers followed by mean pooling and an MLP head.

    ``edge_features`` is accepted so the constructor signature matches
    ``MPNNModel``; this model does not use edge attributes.
    """

    def __init__(self, node_features, edge_features, hidden_dim, output_dim):
        super().__init__()

        # Graph convolutions: the first maps input features to hidden_dim,
        # the remaining two keep the width so they can be added residually
        self.conv1 = GCNConv(node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Prediction head applied to the pooled graph embedding
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, data):
        node_inputs, edge_index, batch = data.x, data.edge_index, data.batch

        # First layer changes the width, so it has no skip connection
        hidden = F.relu(self.conv1(node_inputs, edge_index))

        # Remaining layers: activation of the convolution plus the layer input
        for conv in (self.conv2, self.conv3):
            hidden = F.relu(conv(hidden, edge_index)) + hidden

        # Average node embeddings per graph, then predict one logit per task
        graph_embedding = global_mean_pool(hidden, batch)
        return self.mlp(graph_embedding)

# Define Message Passing Neural Network (MPNN) model
class MPNNLayer(MessagePassing):
    """
    One message-passing step with sum aggregation.

    Messages are computed by an MLP from the neighbouring node's features
    concatenated with the edge features; the node update is computed by a
    second MLP from the node's own features concatenated with the
    aggregated messages.
    """

    def __init__(self, node_dim, edge_dim, hidden_dim):
        super().__init__(aggr='add')

        def two_layer_mlp(input_dim):
            # Linear -> ReLU -> Linear, always ending at hidden_dim
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )

        # Consumes [neighbour features | edge features]
        self.message_mlp = two_layer_mlp(node_dim + edge_dim)

        # Consumes [own features | aggregated messages]
        self.update_mlp = two_layer_mlp(node_dim + hidden_dim)

    def forward(self, x, edge_index, edge_attr):
        # Hand over to propagate(), which invokes message() and update()
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # x_i is part of the signature but only x_j and edge_attr are used
        neighbour_and_edge = torch.cat((x_j, edge_attr), dim=1)
        return self.message_mlp(neighbour_and_edge)

    def update(self, aggr_out, x):
        # Combine each node's current features with its summed messages
        node_and_messages = torch.cat((x, aggr_out), dim=1)
        return self.update_mlp(node_and_messages)

class MPNNModel(nn.Module):
    """
    Three stacked MPNNLayer steps followed by mean pooling and an MLP head.

    Unlike the GCN model, every layer also receives the edge attributes.
    """

    def __init__(self, node_features, edge_features, hidden_dim, output_dim):
        super().__init__()

        # Message-passing layers: the first maps input features to hidden_dim,
        # the remaining two keep the width so they can be added residually
        self.mpnn1 = MPNNLayer(node_features, edge_features, hidden_dim)
        self.mpnn2 = MPNNLayer(hidden_dim, edge_features, hidden_dim)
        self.mpnn3 = MPNNLayer(hidden_dim, edge_features, hidden_dim)

        # Prediction head applied to the pooled graph embedding
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, data):
        node_inputs = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch

        # First layer changes the width, so it has no skip connection
        hidden = F.relu(self.mpnn1(node_inputs, edge_index, edge_attr))

        # Remaining layers: activation of the layer output plus the layer input
        for layer in (self.mpnn2, self.mpnn3):
            hidden = F.relu(layer(hidden, edge_index, edge_attr)) + hidden

        # Average node embeddings per graph, then predict one logit per task
        graph_embedding = global_mean_pool(hidden, batch)
        return self.mlp(graph_embedding)

# Define a masked loss function to handle missing values
def masked_bce_loss(pred, target):
    """
    Binary cross-entropy on logits, averaged over the non-missing labels only.

    Args:
        pred: Tensor of raw logits (no sigmoid applied).
        target: Float tensor of the same shape as ``pred`` holding 0/1 labels,
            with -1 marking a missing label.

    Returns:
        Scalar tensor: the sum of the per-label losses at positions where
        ``target != -1`` divided by the number of such positions. When every
        label is missing the result is 0 and stays connected to ``pred`` in
        the autograd graph, so calling ``backward()`` on it is valid and
        yields zero gradients.
    """
    observed = target != -1

    # Swap the -1 placeholders for 0 so the loss is computed on valid label
    # values everywhere; those positions are zeroed out again below.
    safe_target = torch.where(observed, target, torch.zeros_like(target))

    per_label = F.binary_cross_entropy_with_logits(pred, safe_target, reduction='none')
    weights = observed.to(per_label.dtype)
    masked_total = (per_label * weights).sum()

    # Clamping the count to at least 1 avoids a division by zero for an
    # all-missing batch without detaching the result from the graph.
    return masked_total / weights.sum().clamp(min=1)

# Training function
def train_model(model, train_loader, val_loader, optimizer, num_epochs=50, device='cuda'):
    """
    Trains ``model`` and restores the weights of the epoch with the best
    validation mean ROC-AUC.

    Args:
        model: Module called as ``model(data)`` that returns one logit per task.
        train_loader: Iterable of ``(data, targets)`` batches used for optimization.
        val_loader: Iterable of ``(data, targets)`` batches used for validation.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.

    Returns:
        Tuple ``(model, train_losses, val_losses, val_roc_aucs)``. The loss
        lists hold one value per epoch. ``val_roc_aucs`` holds one value per
        epoch in which at least one task had a computable ROC-AUC.
    """
    # Local import: only needed to snapshot the best weights
    import copy

    model.to(device)
    best_val_roc_auc = 0
    best_model_state = None

    train_losses = []
    val_losses = []
    val_roc_aucs = []

    for epoch in range(num_epochs):
        # Training
        model.train()
        epoch_loss = 0

        for data, targets in train_loader:
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = model(data)
            loss = masked_bce_loss(outputs, targets)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size for a per-sample epoch average
            epoch_loss += loss.item() * len(targets)

        avg_train_loss = epoch_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        val_loss = 0
        all_targets = []
        all_outputs = []

        with torch.no_grad():
            for data, targets in val_loader:
                data = data.to(device)
                targets = targets.to(device)

                # Forward pass
                outputs = model(data)
                loss = masked_bce_loss(outputs, targets)

                val_loss += loss.item() * len(targets)

                # Keep everything; missing labels are filtered per task below
                all_targets.append(targets.cpu().numpy())
                all_outputs.append(outputs.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(avg_val_loss)

        # Compute ROC-AUC for each task
        all_targets = np.vstack(all_targets)
        all_outputs = np.vstack(all_outputs)

        task_roc_aucs = []
        for task_idx in range(all_targets.shape[1]):
            # Drop the rows whose label for this task is missing (-1)
            mask = all_targets[:, task_idx] != -1
            y_true = all_targets[mask, task_idx]
            # ROC-AUC is undefined unless both classes are present
            if len(np.unique(y_true)) < 2:
                continue
            try:
                task_roc_aucs.append(roc_auc_score(y_true, all_outputs[mask, task_idx]))
            except ValueError:
                # The score could not be computed for this task; leave it out of the mean
                continue

        if task_roc_aucs:
            mean_roc_auc = np.mean(task_roc_aucs)
            val_roc_aucs.append(mean_roc_auc)

            # Save best model. state_dict() hands back references to the live
            # parameter tensors, so a deep copy is required; a shallow copy
            # would be silently overwritten by the following optimizer steps.
            if mean_roc_auc > best_val_roc_auc:
                best_val_roc_auc = mean_roc_auc
                best_model_state = copy.deepcopy(model.state_dict())

            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Mean ROC-AUC: {mean_roc_auc:.4f}")
        else:
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Mean ROC-AUC: N/A")

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, train_losses, val_losses, val_roc_aucs

# Evaluation function
def evaluate_model(model, test_loader, device='cuda'):
    """
    Computes per-task ROC-AUC and accuracy of ``model`` on ``test_loader``.

    Args:
        model: Module called as ``model(data)`` that returns one logit per task,
            in the order of the module-level ``target_columns``.
        test_loader: Iterable of ``(data, targets)`` batches; a target of -1
            marks a missing label.
        device: Device the model and every batch are moved to.

    Returns:
        Dict mapping task name to ``{'roc_auc': ..., 'accuracy': ...}``. Tasks
        with no labelled samples, or with only one class among the labelled
        samples, are omitted. If a metric raises ``ValueError`` both values
        for that task are NaN.
    """
    # Batches are moved to `device` below, so the model has to live there too
    model.to(device)
    model.eval()
    all_targets = []
    all_outputs = []

    with torch.no_grad():
        for data, targets in test_loader:
            data = data.to(device)

            # Forward pass
            outputs = model(data)

            # Store predictions and targets
            all_targets.append(targets.cpu().numpy())
            all_outputs.append(outputs.cpu().numpy())

    all_targets = np.vstack(all_targets)
    all_outputs = np.vstack(all_outputs)

    # Compute metrics for each task
    task_metrics = {}
    for task_idx, task_name in enumerate(target_columns):
        # Drop the rows whose label for this task is missing (-1)
        mask = all_targets[:, task_idx] != -1
        y_true = all_targets[mask, task_idx]
        # ROC-AUC is undefined unless both classes are present
        if len(np.unique(y_true)) < 2:
            continue

        y_score = all_outputs[mask, task_idx]
        y_pred = (y_score > 0).astype(int)  # Logit > 0 is equivalent to sigmoid > 0.5

        try:
            task_metrics[task_name] = {
                'roc_auc': roc_auc_score(y_true, y_score),
                'accuracy': accuracy_score(y_true, y_pred),
            }
        except ValueError:
            # Keep the task in the report but mark its metrics as unavailable
            task_metrics[task_name] = {
                'roc_auc': np.nan,
                'accuracy': np.nan,
            }

    return task_metrics

# Main function to run the full pipeline
def _train_and_report(label, model, train_loader, val_loader, test_loader, device,
                      num_epochs=30, lr=0.001):
    """Train one model with Adam, evaluate it on the test set and print its metrics."""
    print(f"\n--- Training {label} Model ---")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model, train_losses, val_losses, val_roc_aucs = train_model(
        model, train_loader, val_loader, optimizer, num_epochs=num_epochs, device=device
    )

    print(f"\n--- Evaluating {label} Model ---")
    metrics = evaluate_model(model, test_loader, device=device)

    print(f"\n{label} Test Metrics:")
    for task, task_metrics in metrics.items():
        print(f"{task}: ROC-AUC={task_metrics['roc_auc']:.4f}, Accuracy={task_metrics['accuracy']:.4f}")

    # nanmean ignores tasks whose ROC-AUC was recorded as NaN
    mean_roc_auc = np.nanmean([task_metrics['roc_auc'] for task_metrics in metrics.values()])
    print(f"Mean {label} ROC-AUC: {mean_roc_auc:.4f}")

    return {
        'metrics': metrics,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_roc_aucs': val_roc_aucs,
        'mean_roc_auc': mean_roc_auc,
    }


def _plot_training_curves(gcn_run, mpnn_run):
    """Save the loss and validation ROC-AUC curves of both runs to gnn_training_curves.png."""
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(gcn_run['train_losses'], label='GCN Train Loss')
    plt.plot(gcn_run['val_losses'], label='GCN Val Loss')
    plt.plot(mpnn_run['train_losses'], label='MPNN Train Loss')
    plt.plot(mpnn_run['val_losses'], label='MPNN Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')

    plt.subplot(1, 2, 2)
    plt.plot(gcn_run['val_roc_aucs'], label='GCN ROC-AUC')
    plt.plot(mpnn_run['val_roc_aucs'], label='MPNN ROC-AUC')
    plt.xlabel('Epoch')
    plt.ylabel('ROC-AUC')
    plt.legend()
    plt.title('Validation ROC-AUC')

    plt.tight_layout()
    plt.savefig('gnn_training_curves.png')
    plt.close()


def _plot_performance_comparison(gcn_metrics, mpnn_metrics):
    """Save a per-task ROC-AUC bar chart of both models to gnn_performance_comparison.png."""
    plt.figure(figsize=(14, 8))

    tasks = list(gcn_metrics.keys())
    gcn_scores = [gcn_metrics[task]['roc_auc'] for task in tasks]
    # Look MPNN scores up by task name so the bars stay aligned even if the
    # two metric dicts ever differ in keys or ordering.
    mpnn_scores = [mpnn_metrics.get(task, {}).get('roc_auc', np.nan) for task in tasks]

    x = np.arange(len(tasks))
    width = 0.35

    plt.bar(x - width/2, gcn_scores, width, label='GCN')
    plt.bar(x + width/2, mpnn_scores, width, label='MPNN')

    plt.xlabel('Target')
    plt.ylabel('ROC-AUC')
    plt.title('Model Performance Comparison')
    plt.xticks(x, tasks, rotation=45, ha='right')
    plt.legend()

    plt.tight_layout()
    plt.savefig('gnn_performance_comparison.png')
    plt.close()


# Main function to run the full pipeline
def main():
    """
    Runs the full pipeline: load 'tox21.csv', split it, build the graph
    datasets, train and evaluate the GCN and MPNN models, print their test
    metrics and save two comparison figures to the working directory.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load and preprocess data
    df = pd.read_csv('tox21.csv')

    # Split data into train, validation, and test sets (64%, 16%, 20%):
    # first hold out 36%, then split that part into ~44% validation / ~56% test
    train_df, temp_df = train_test_split(df, test_size=0.36, random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5556, random_state=42)

    print(f"Train set size: {len(train_df)}")
    print(f"Validation set size: {len(val_df)}")
    print(f"Test set size: {len(test_df)}")

    # Create datasets
    train_dataset = Tox21Dataset(train_df, target_columns)
    val_dataset = Tox21Dataset(val_df, target_columns)
    test_dataset = Tox21Dataset(test_df, target_columns)

    print(f"Train dataset size after processing: {len(train_dataset)}")
    print(f"Validation dataset size after processing: {len(val_dataset)}")
    print(f"Test dataset size after processing: {len(test_dataset)}")

    # Create data loaders. No collate_fn is passed: the PyTorch Geometric
    # DataLoader always installs its own collater and discards a user-supplied
    # one, and that collater already handles (graph, targets) tuples.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Get dimensions for model initialization
    sample_graph = train_dataset[0][0]
    node_features = sample_graph.x.size(1)
    edge_features = sample_graph.edge_attr.size(1)
    output_dim = len(target_columns)

    print(f"Node feature dimension: {node_features}")
    print(f"Edge feature dimension: {edge_features}")
    print(f"Output dimension: {output_dim}")

    # Initialize both models before any training so their initial weights do
    # not depend on how much randomness the first training run consumed
    gcn_model = GCNModel(node_features, edge_features, hidden_dim=64, output_dim=output_dim)
    mpnn_model = MPNNModel(node_features, edge_features, hidden_dim=64, output_dim=output_dim)

    # Train and evaluate both models with identical settings
    gcn_run = _train_and_report("GCN", gcn_model, train_loader, val_loader, test_loader, device)
    mpnn_run = _train_and_report("MPNN", mpnn_model, train_loader, val_loader, test_loader, device)

    # Compare GCN and MPNN results
    print("\n--- Model Comparison ---")
    print(f"GCN Mean ROC-AUC: {gcn_run['mean_roc_auc']:.4f}")
    print(f"MPNN Mean ROC-AUC: {mpnn_run['mean_roc_auc']:.4f}")

    # Plot training curves and the per-task comparison
    _plot_training_curves(gcn_run, mpnn_run)
    _plot_performance_comparison(gcn_run['metrics'], mpnn_run['metrics'])


if __name__ == "__main__":
    main()

Number of molecules with missing SMILES: 0
Using device: cpu
Train set size: 5123
Validation set size: 1281
Test set size: 1602


Converting SMILES to graphs: 100%|██████████| 5123/5123 [00:02<00:00, 2266.43it/s]
Converting SMILES to graphs: 100%|██████████| 1281/1281 [00:00<00:00, 2284.05it/s]
Converting SMILES to graphs: 100%|██████████| 1602/1602 [00:00<00:00, 2295.45it/s]


Train dataset size after processing: 5123
Validation dataset size after processing: 1281
Test dataset size after processing: 1602
Node feature dimension: 109
Edge feature dimension: 6
Output dimension: 12

--- Training GCN Model ---
Epoch 1/30, Train Loss: 0.3014, Val Loss: 0.2500, Mean ROC-AUC: 0.6338
Epoch 2/30, Train Loss: 0.2527, Val Loss: 0.2495, Mean ROC-AUC: 0.6462
Epoch 3/30, Train Loss: 0.2511, Val Loss: 0.2444, Mean ROC-AUC: 0.6563
Epoch 4/30, Train Loss: 0.2468, Val Loss: 0.2418, Mean ROC-AUC: 0.7058
Epoch 5/30, Train Loss: 0.2450, Val Loss: 0.2397, Mean ROC-AUC: 0.7078
Epoch 6/30, Train Loss: 0.2415, Val Loss: 0.2402, Mean ROC-AUC: 0.7159
Epoch 7/30, Train Loss: 0.2422, Val Loss: 0.2392, Mean ROC-AUC: 0.7197
Epoch 8/30, Train Loss: 0.2393, Val Loss: 0.2363, Mean ROC-AUC: 0.7230
Epoch 9/30, Train Loss: 0.2372, Val Loss: 0.2365, Mean ROC-AUC: 0.7249
Epoch 10/30, Train Loss: 0.2359, Val Loss: 0.2355, Mean ROC-AUC: 0.7281
Epoch 11/30, Train Loss: 0.2350, Val Loss: 0.2368, Mean 