# In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from rdkit import Chem
from torch_geometric.data import Data
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import optuna

def atom_features(atom):
    """Build the node-feature vector for one RDKit atom.

    The vector has four entries, in this order: atomic number
    (``GetAtomicNum``), degree (``GetDegree``), formal charge
    (``GetFormalCharge``) and an aromaticity flag (``int(GetIsAromatic())``).

    Args:
        atom: An RDKit atom, or any object exposing the four getters above.

    Returns:
        A 1-D ``torch.float`` tensor of length 4.
    """
    atomic_number = atom.GetAtomicNum()
    degree = atom.GetDegree()
    charge = atom.GetFormalCharge()
    aromatic = int(atom.GetIsAromatic())
    feature_values = [atomic_number, degree, charge, aromatic]
    return torch.tensor(feature_values, dtype=torch.float)

def mol_to_graph(smiles):
    """Convert a SMILES string into a PyG ``Data`` graph with node features.

    Each atom becomes a node whose features come from ``atom_features``.
    Each RDKit bond becomes two directed edges (i -> j and j -> i), so message
    passing treats the molecule as an undirected graph.

    Args:
        smiles: SMILES string to parse with RDKit.

    Returns:
        ``Data(x=..., edge_index=...)``. ``x`` is a ``(num_atoms, 4)`` float
        tensor. ``edge_index`` is a ``(2, 2 * num_bonds)`` long tensor, and is
        ``(2, 0)`` for bond-less molecules. Returns ``None`` when RDKit cannot
        parse ``smiles`` or the parsed molecule has no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    # An empty SMILES can parse to a zero-atom molecule; torch.stack([])
    # would raise, so treat it the same as an unparseable string.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    x = torch.stack([atom_features(atom) for atom in mol.GetAtoms()], dim=0)

    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])  # reverse direction -> undirected graph

    # reshape(-1, 2) keeps the (2, E) long layout even when there are no
    # bonds (e.g. a single atom). torch.tensor([]) alone would give a 1-D
    # float tensor that GCNConv cannot use as edge_index.
    edge_index = (
        torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )

    return Data(x=x, edge_index=edge_index)

def load_dataset(file_path):
    """Load an Excel sheet of molecules and convert each row to a labelled graph.

    The sheet must contain a ``smiles`` column and a ``Target`` column. Rows
    are skipped when:

    * the SMILES or the target cell is empty (NaN), or
    * RDKit cannot turn the SMILES into a graph (``mol_to_graph`` returns
      ``None``).

    Args:
        file_path: Path to the spreadsheet, passed to ``pd.read_excel``.

    Returns:
        List of ``Data`` graphs. Each has ``y`` set to a 1-element float
        tensor holding that row's ``Target`` value.
    """
    df = pd.read_excel(file_path)
    graphs = []
    for _, row in df.iterrows():
        smiles = row['smiles']
        target = row['Target']
        # Empty cells are read as NaN. A non-string SMILES is rejected by
        # RDKit, and a NaN label would make BCELoss return NaN, so drop
        # such rows instead of crashing or corrupting training.
        if pd.isna(smiles) or pd.isna(target):
            continue
        graph = mol_to_graph(smiles)
        if graph is None:
            continue
        graph.y = torch.tensor([target], dtype=torch.float)
        graphs.append(graph)
    return graphs

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for graph-level prediction.

    Pipeline:

    1. GCNConv -> ReLU -> GCNConv -> ReLU
    2. global mean pooling per graph -> dropout
    3. Linear(hidden, hidden // 2) -> ReLU -> Linear(hidden // 2, out)
    4. sigmoid

    Args:
        in_channels: Size of each node feature vector.
        hidden_channels: Width of both GCN layers. The first dense layer has
            ``hidden_channels // 2`` units.
        out_channels: Number of sigmoid outputs per graph.
        dropout_rate: Dropout probability applied to the pooled graph embedding.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout_rate):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.fc1 = torch.nn.Linear(hidden_channels, hidden_channels // 2)
        self.fc2 = torch.nn.Linear(hidden_channels // 2, out_channels)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self, data):
        """Return per-graph sigmoid outputs, one row per graph in ``data``.

        ``data`` must provide ``x`` (node features) and ``edge_index``. Its
        ``batch`` vector may be missing or ``None``, as for a single ``Data``
        object not produced by a ``DataLoader``. In that case all nodes are
        pooled as a single graph.
        """
        x = data.x.float()
        edge_index = data.edge_index.long()
        batch = getattr(data, 'batch', None)
        if batch is None:
            # An un-batched graph has no batch vector; assign every node to
            # graph 0 instead of failing on `None.long()`.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        else:
            batch = batch.long()

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Probabilities rather than logits: the training code pairs this with
        # BCELoss and thresholds predictions at 0.5.
        return torch.sigmoid(x)

def train(model, loader, optimizer, criterion):
    """Run one training epoch and report mean batch loss and accuracy.

    Args:
        model: Module mapping a batch to per-graph probabilities in [0, 1],
            shaped ``(batch_size, 1)``.
        loader: Iterable of batches with a ``y`` label tensor. It must support
            ``len()`` and expose ``dataset``, as a PyG ``DataLoader`` does.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss comparing probabilities with ``y`` (e.g. ``BCELoss``).

    Returns:
        ``(mean_loss, accuracy)``:

        * ``mean_loss`` is the unweighted average of the per-batch losses.
        * ``accuracy`` is the fraction of graphs in ``loader.dataset`` whose
          prediction (probability > 0.5) equals ``y``.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    for data in loader:
        optimizer.zero_grad()
        # reshape(-1) instead of squeeze(): squeeze() turns a size-1 batch's
        # (1, 1) output into a 0-d tensor. BCELoss then raises ValueError
        # because it no longer matches the (1,) target.
        out = model(data).reshape(-1)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = (out > 0.5).float()
        correct += pred.eq(data.y).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)

def test(model, loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data in loader:
            out = model(data).squeeze()
            pred = (out > 0.5).float()
            correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

def objective(trial):
    """Optuna objective: train a GCN with sampled hyperparameters, score it on validation data.

    The dataset is split 80/20 with ``random_state=42``, the same split the
    final retraining uses. The 20 % test portion is never touched here.
    Instead, 20 % of the training portion is held out as a validation set
    for scoring. Selecting hyperparameters on the test set would make the
    reported test accuracy optimistically biased.

    Args:
        trial: Optuna trial used to sample ``hidden_channels``,
            ``learning_rate``, ``dropout_rate``, ``num_epochs`` and
            ``batch_size``.

    Returns:
        Validation accuracy of the trained model (the value Optuna maximizes).
    """
    hidden_channels = trial.suggest_int('hidden_channels', 32, 128)
    # suggest_loguniform is deprecated in Optuna 3.x; log=True samples the
    # same log-uniform distribution under the same parameter name.
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
    dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
    num_epochs = trial.suggest_int('num_epochs', 10, 100)
    batch_size = trial.suggest_int('batch_size', 16, 64)

    file_path = 'molecules.xlsx'
    graphs = load_dataset(file_path)
    train_graphs, _test_graphs = train_test_split(graphs, test_size=0.2, random_state=42)
    # Carve the validation set out of the training portion only.
    fit_graphs, val_graphs = train_test_split(train_graphs, test_size=0.2, random_state=42)

    train_loader = DataLoader(fit_graphs, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)

    model = GCN(in_channels=4, hidden_channels=hidden_channels, out_channels=1, dropout_rate=dropout_rate)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.BCELoss()

    for _ in range(num_epochs):
        train(model, train_loader, optimizer, criterion)

    return test(model, val_loader)

def main(n_trials=15):
    """Run an Optuna hyperparameter search for the GCN, then retrain with the winner.

    Args:
        n_trials: Number of Optuna trials to run (default 15).
    """
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=n_trials)

    print("Best hyperparameters: ", study.best_params)
    print("Best accuracy: ", study.best_value)
    # The winning configuration lives on the study; there is no local
    # `best_params` name in this scope.
    retrain_with_best_params(study.best_params)

def retrain_with_best_params(best_params):
    """Retrain a GCN with the given hyperparameters, log per-epoch metrics and plot them.

    Uses the 80/20 train/test split with ``random_state=42``. Each epoch it
    prints train loss, train accuracy and test accuracy, and at the end it
    shows the learning curves.

    Args:
        best_params: Mapping with keys ``hidden_channels``, ``learning_rate``,
            ``dropout_rate``, ``num_epochs`` and ``batch_size``, e.g. an
            Optuna study's ``best_params``.

    Returns:
        The trained ``GCN`` model, so callers can reuse it. Callers that
        ignore the return value are unaffected.
    """
    file_path = 'molecules.xlsx'
    graphs = load_dataset(file_path)
    train_graphs, test_graphs = train_test_split(graphs, test_size=0.2, random_state=42)

    batch_size = best_params['batch_size']
    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

    model = GCN(in_channels=4, hidden_channels=best_params['hidden_channels'], out_channels=1, dropout_rate=best_params['dropout_rate'])
    optimizer = torch.optim.Adam(model.parameters(), lr=best_params['learning_rate'])
    criterion = torch.nn.BCELoss()

    train_losses = []
    train_accuracies = []
    test_accuracies = []

    for epoch in range(best_params['num_epochs']):
        train_loss, train_acc = train(model, train_loader, optimizer, criterion)
        test_acc = test(model, test_loader)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        test_accuracies.append(test_acc)

        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")

    _plot_history(train_losses, train_accuracies, test_accuracies)
    return model


def _plot_history(train_losses, train_accuracies, test_accuracies):
    """Show train/test accuracy (left) and train loss (right) curves side by side."""
    # Start a fresh figure so the subplots do not draw onto whichever figure
    # happens to be current.
    plt.figure(figsize=(10, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(test_accuracies, label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Train vs Test Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_losses, label='Train Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Train Loss')
    plt.legend()

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    # Launch the search-and-retrain pipeline only when executed directly,
    # not when this module is imported.
    main()
