# In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATv2Conv, global_mean_pool
from rdkit import Chem
from torch_geometric.data import Data
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (roc_auc_score, accuracy_score, 
                             matthews_corrcoef, cohen_kappa_score,
                             brier_score_loss, confusion_matrix, 
                             precision_score, f1_score, recall_score)

def atom_features(atom):
    """Encode a single atom as a 1-D float tensor of four descriptors.

    The entries are, in order: atomic number, degree, formal charge and an
    aromaticity flag (1 when the atom reports itself aromatic, else 0).
    """
    descriptors = (
        atom.GetAtomicNum(),
        atom.GetDegree(),
        atom.GetFormalCharge(),
        int(atom.GetIsAromatic()),
    )
    return torch.tensor(descriptors, dtype=torch.float)

def mol_to_graph(smiles):
    """Convert a SMILES string into a graph with per-atom node features.

    Args:
        smiles: SMILES representation of the molecule.

    Returns:
        A ``Data`` object whose ``x`` stacks ``atom_features`` for every atom
        (one row per atom) and whose ``edge_index`` is a long tensor of shape
        ``[2, num_edges]`` holding each bond in both directions.

    Raises:
        ValueError: If the SMILES string cannot be parsed or yields a
            molecule without atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    # RDKit signals a parse failure by returning None rather than raising.
    if mol is None:
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    if mol.GetNumAtoms() == 0:
        raise ValueError(f"SMILES string contains no atoms: {smiles!r}")

    x = torch.stack([atom_features(atom) for atom in mol.GetAtoms()], dim=0)

    # Store every bond twice so the graph is treated as undirected.
    pairs = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        pairs.append((i, j))
        pairs.append((j, i))

    if pairs:
        edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    else:
        # A bond-free molecule (e.g. a single atom) must still produce a
        # [2, 0] long tensor; torch.tensor([]) would be a float tensor of
        # shape [0], which cannot be batched with other graphs.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)

def load_dataset(file_path):
    """Read an Excel sheet and build one labelled molecular graph per row.

    The sheet must provide a ``smiles`` column (converted with
    ``mol_to_graph``) and a ``Target`` column, which is attached to each
    graph as a one-element float tensor ``y``.
    """
    frame = pd.read_excel(file_path)
    dataset = []
    for _, record in frame.iterrows():
        smiles, target = record['smiles'], record['Target']
        mol_graph = mol_to_graph(smiles)
        mol_graph.y = torch.tensor([target], dtype=torch.float)
        dataset.append(mol_graph)
    return dataset

class GAT(torch.nn.Module):
    """Graph-level classifier built from two GATv2 attention layers.

    Node features pass through two ``GATv2Conv`` layers, are mean-pooled per
    graph, regularised with dropout and fed to a two-layer MLP whose output
    is squashed with a sigmoid.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Per-head width of each attention layer.
        out_channels: Number of outputs per graph.
        num_heads: Number of attention heads in each ``GATv2Conv`` layer.
        dropout: Dropout probability applied to the pooled graph embedding.
            Defaults to 0.5, the value that used to be hard-coded.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=4, dropout=0.5):
        super().__init__()
        self.conv1 = GATv2Conv(in_channels, hidden_channels, heads=num_heads)
        # The layer widths below assume each conv emits
        # hidden_channels * num_heads features per node.
        self.conv2 = GATv2Conv(hidden_channels * num_heads, hidden_channels, heads=num_heads)
        self.fc1 = torch.nn.Linear(hidden_channels * num_heads, hidden_channels // 2)
        self.fc2 = torch.nn.Linear(hidden_channels // 2, out_channels)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, data):
        """Return sigmoid outputs in (0, 1) for the graph(s) held in ``data``.

        ``data`` must expose ``x`` and ``edge_index``; ``batch`` is optional.
        """
        x = data.x.float()
        edge_index = data.edge_index.long()

        batch = getattr(data, 'batch', None)
        if batch is None:
            # No batch vector (a single un-batched graph): assign every node
            # to graph 0 instead of failing on ``None.long()``.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        else:
            batch = batch.long()

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.sigmoid(x)

def train(model, loader, optimizer, criterion):
    """Run one training epoch and report loss and accuracy.

    Args:
        model: Module called as ``model(data)`` that returns one probability
            per graph in the batch.
        loader: Iterable of batches exposing ``y``; must support ``len()``
            and provide a ``dataset`` attribute with the sample count.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss called as ``criterion(predictions, targets)``.

    Returns:
        Tuple ``(mean_loss, accuracy)`` where ``mean_loss`` is the loss
        averaged over batches and ``accuracy`` is the fraction of samples in
        ``loader.dataset`` classified correctly at a 0.5 threshold.
    """
    model.train()
    total_loss = 0
    correct = 0
    for data in loader:
        optimizer.zero_grad()
        # Flatten explicitly instead of squeeze(): squeeze() turns a batch
        # holding a single graph into a 0-d tensor, which no longer matches
        # the shape of data.y and makes the loss raise.
        out = model(data).reshape(-1)
        target = data.y.reshape(-1)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = (out > 0.5).float()
        correct += pred.eq(target).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)

def test(model, loader):
    """Return the model's accuracy over ``loader`` at a 0.5 threshold.

    Args:
        model: Module called as ``model(data)`` that returns one probability
            per graph in the batch.
        loader: Iterable of batches exposing ``y``; must provide a
            ``dataset`` attribute with the sample count.

    Returns:
        Fraction of samples in ``loader.dataset`` classified correctly.
    """
    model.eval()
    correct = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves
    # memory and time without changing the predictions.
    with torch.no_grad():
        for data in loader:
            # reshape(-1) keeps a 1-D shape even for a single-graph batch.
            out = model(data).reshape(-1)
            pred = (out > 0.5).float()
            correct += pred.eq(data.y.reshape(-1)).sum().item()
    return correct / len(loader.dataset)

def test_with_metrics(model, loader):
    """Evaluate a binary classifier and return a dictionary of metrics.

    Args:
        model: Module called as ``model(data)`` that returns one probability
            per graph in the batch.
        loader: Iterable of batches exposing binary (0/1) labels as ``y``.

    Returns:
        Dict with the keys "MCC", "Kappa", "Sensitivity", "Specificity",
        "F1 Score", "Precision", "AUC", "Brier Loss" and "Accuracy".
        Hard predictions use a 0.5 threshold. "AUC" is NaN when the labels
        contain a single class, because ROC AUC is undefined in that case.
    """
    model.eval()
    y_true = []
    y_pred_probs = []

    with torch.no_grad():
        for data in loader:
            # reshape(-1) instead of squeeze(): a single-graph batch would
            # otherwise become a 0-d array that list.extend cannot iterate.
            out = model(data).reshape(-1)
            y_true.extend(data.y.reshape(-1).cpu().numpy())
            y_pred_probs.extend(out.cpu().numpy())

    y_true = np.array(y_true).astype(int)
    y_pred_probs = np.array(y_pred_probs)

    y_pred = (y_pred_probs > 0.5).astype(int)

    mcc = matthews_corrcoef(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)
    # Fix the label set so the matrix is always 2x2, even when only one
    # class occurs; otherwise the four-way unpacking below fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    f1 = f1_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    if np.unique(y_true).size < 2:
        # roc_auc_score raises when only one class is present.
        auc = float("nan")
    else:
        auc = roc_auc_score(y_true, y_pred_probs)
    brier_loss = brier_score_loss(y_true, y_pred_probs)
    accuracy = accuracy_score(y_true, y_pred)

    return {
        "MCC": mcc,
        "Kappa": kappa,
        "Sensitivity": sensitivity,
        "Specificity": specificity,
        "F1 Score": f1,
        "Precision": precision,
        "AUC": auc,
        "Brier Loss": brier_loss,
        "Accuracy": accuracy,
    }

def _plot_training_curves(train_losses, train_accuracies, test_accuracies):
    """Plot per-epoch accuracy and loss side by side, save the figure and show it."""
    plt.figure(figsize=(8, 6), dpi=300)

    plt.subplot(1, 2, 1)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(test_accuracies, label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Train vs Test Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_losses, label='Train Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Train Loss')
    plt.legend()

    plt.tight_layout()
    # Save before show(): showing may clear the current figure.
    plt.savefig('GAT_model.png')
    plt.show()


def main():
    """Train the GAT classifier, plot its learning curves and evaluate it.

    Loads 'molecules.xlsx', holds out 20% of the graphs for testing, trains
    with the hyper-parameters in ``best_params``, saves the curves to
    'GAT_model.png' and prints metrics for the held-out split and for the
    external dataset 'externalshiamiles_797.xlsx'.
    """
    file_path = 'molecules.xlsx'
    graphs = load_dataset(file_path)
    train_graphs, test_graphs = train_test_split(graphs, test_size=0.2, random_state=42)

    best_params = {
        'hidden_channels': 38,
        'learning_rate': 0.014771344254742054,
        'dropout_rate': 0.11860822630764481,
        'num_epochs': 140,
    }

    train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_graphs, batch_size=32, shuffle=False)

    model = GAT(in_channels=4, hidden_channels=best_params['hidden_channels'], out_channels=1, num_heads=4)
    # Apply the tuned dropout rate on the model's Dropout module so the value
    # in best_params is actually used rather than silently ignored.
    model.dropout.p = best_params['dropout_rate']
    optimizer = torch.optim.Adam(model.parameters(), lr=best_params['learning_rate'])
    criterion = torch.nn.BCELoss()

    train_losses = []
    train_accuracies = []
    test_accuracies = []

    for epoch in range(best_params['num_epochs']):
        train_loss, train_acc = train(model, train_loader, optimizer, criterion)
        test_acc = test(model, test_loader)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        test_accuracies.append(test_acc)

        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")

    _plot_training_curves(train_losses, train_accuracies, test_accuracies)

    test_metrics = test_with_metrics(model, test_loader)
    print("Test Metrics:", test_metrics)

    external_file_path = 'externalshiamiles_797.xlsx'  # External dataset
    external_graphs = load_dataset(external_file_path)
    external_loader = DataLoader(external_graphs, batch_size=32, shuffle=False)

    external_metrics = test_with_metrics(model, external_loader)
    print("External Validation Metrics:", external_metrics)

# Run the full training/evaluation pipeline only when this file is executed
# as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()
