In [8]:
import pandas as pd
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

from AttentiveFP import Fingerprint  # assumed to be your attentive FP model
from AttentiveFP import get_smiles_dicts, get_smiles_array, num_atom_features, \
    num_bond_features  # assumed to be your actual featurizer

In [9]:
# --- Device ---
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Name of the endpoint being modelled; interpolated into file paths further down.
model_name = 'Irritation_Corrosion_Eye_Irritation'

# --- Dataset ---
class MoleculeDataset(Dataset):
    """Torch dataset that serves featurized molecules together with their labels.

    Args:
        smiles_list: Sequence of SMILES strings, one per sample.
        targets: Sequence of labels aligned index-by-index with ``smiles_list``.
        feature_dicts: Featurizer output passed straight to ``get_smiles_array``.
            It may describe more molecules than ``smiles_list`` (e.g. the whole
            dataset), so it must never be used to size this dataset.

    Raises:
        ValueError: If ``smiles_list`` and ``targets`` differ in length.
    """

    def __init__(self, smiles_list, targets, feature_dicts):
        if len(smiles_list) != len(targets):
            raise ValueError(
                f"smiles_list has {len(smiles_list)} entries but targets has {len(targets)}"
            )
        self.smiles_list = smiles_list
        self.targets = targets
        self.feature_dicts = feature_dicts
        # The sixth value returned by get_smiles_array is not needed here.
        # NOTE(review): assumes each returned array has one row per entry of
        # smiles_list, in the same order — confirm against the featurizer.
        (
            self.x_atom,
            self.x_bond,
            self.x_atom_index,
            self.x_bond_index,
            self.x_mask,
            _,
        ) = get_smiles_array(smiles_list, feature_dicts)

    def __len__(self):
        """Return the number of samples (molecules) in this dataset."""
        # Size by the sample list, not by feature_dicts: the latter is the
        # featurizer's dictionary and its length is unrelated to the sample count.
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Return ``(atom, bond, atom_index, bond_index, mask, target)`` tensors for sample ``idx``."""
        return (
            torch.tensor(self.x_atom[idx], dtype=torch.float32),
            torch.tensor(self.x_bond[idx], dtype=torch.float32),
            torch.tensor(self.x_atom_index[idx], dtype=torch.long),
            torch.tensor(self.x_bond_index[idx], dtype=torch.long),
            torch.tensor(self.x_mask[idx], dtype=torch.float32),
            torch.tensor(self.targets[idx], dtype=torch.float32)
        )

In [None]:
# --- Load data ---
from collections import Counter

import numpy as np
from imblearn.over_sampling import RandomOverSampler
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

df = pd.read_csv(f"../data/{model_name}.csv")
smiles_list = df["smiles"].tolist()
targets = df["active"].tolist()
# Featurization of the complete dataset (kept available for later cells).
feature_dicts = get_smiles_dicts(smiles_list)

# Stratified split: 60% train, 20% val, 20% test
train_smiles, temp_smiles, train_targets, temp_targets = train_test_split(
    smiles_list, targets, test_size=0.4, random_state=42, stratify=targets
)

val_smiles, test_smiles, val_targets, test_targets = train_test_split(
    temp_smiles, temp_targets, test_size=0.5, random_state=42, stratify=temp_targets
)

# Apply oversampling only to training data so val/test keep the real class balance.
# The sampler expects 2-D input, hence each SMILES is wrapped in a one-element row.
ros = RandomOverSampler(random_state=42)
X_res, y_res = ros.fit_resample([[s] for s in train_smiles], train_targets)
train_smiles_resampled = [x[0] for x in X_res]

print("Class distribution in original training set:", Counter(train_targets))
print("Class distribution after resampling:", Counter(y_res))

# Class weights based on the original (not resampled) training labels.
# `classes` is passed as an ndarray because recent scikit-learn releases
# reject a plain list for this parameter.
class_weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=train_targets)
class_weights_dict = {0: class_weights[0], 1: class_weights[1]}
print("Class weights:", class_weights_dict)

# Convert to tensor for PyTorch (explicit float32 so it matches the model's logits).
# NOTE(review): this weight reflects the pre-oversampling imbalance; combining it
# with the oversampled training set corrects for class imbalance twice.
pos_weight = torch.tensor(
    [float(class_weights_dict[1] / class_weights_dict[0])],
    dtype=torch.float32,
    device=device,
)

# Get feature dictionaries for all sets (each split is featurized exactly once)
feature_dicts_train = get_smiles_dicts(train_smiles_resampled)
feature_dicts_val = get_smiles_dicts(val_smiles)
feature_dicts_test = get_smiles_dicts(test_smiles)

# Create datasets
train_dataset = MoleculeDataset(train_smiles_resampled, y_res, feature_dicts_train)
val_dataset = MoleculeDataset(val_smiles, val_targets, feature_dicts_val)
test_dataset = MoleculeDataset(test_smiles, test_targets, feature_dicts_test)

# Create data loaders; only the training loader is shuffled
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Legacy names kept as aliases of the objects above; they used to be rebuilt
# from identical inputs, which repeated the featurization for no benefit.
smiles_list_resampled = train_smiles_resampled
feature_dicts_resampled = feature_dicts_train
dataset = train_dataset
loader = train_loader


In [None]:
# --- Model ---
model = Fingerprint(
    radius=5,
    T=4,  # Increased T for better message passing
    input_feature_dim=num_atom_features(),
    input_bond_dim=num_bond_features(),
    fingerprint_dim=256,  # Increased dimension for more complex feature learning
    output_units_num=1,
    p_dropout=0.5  # Increased dropout for better regularization
).to(device)

# Using a lower learning rate and higher weight decay for better generalization
optimizer = Adam(model.parameters(), lr=0.0005, weight_decay=1e-4)

# Weight the positive class by the negative/positive ratio of the TRAINING labels
# only, so no information about the validation/test labels reaches the loss.
n_samples = len(train_targets)
n_positives = sum(train_targets)
if n_positives == 0:
    raise ValueError("Training labels contain no positive samples; cannot derive pos_weight")
pos_weight = torch.tensor(
    [(n_samples - n_positives) / n_positives], dtype=torch.float32, device=device
)

# Use weighted loss function
loss_fn = BCEWithLogitsLoss(pos_weight=pos_weight)


In [None]:
from sklearn.model_selection import train_test_split

# NOTE(review): this cell rebinds train_loader / val_loader / test_loader, so the
# oversampled training loader built in the data-loading cell is replaced by a
# loader over the original (non-resampled) training split. Confirm which of the
# two pipelines is intended and drop the other.

# Stratified split: 60% train, 20% val, 20% test.
# Stratifying keeps the class ratio identical across splits, which matters for
# imbalanced labels and matches the split used in the data-loading cell.
train_smiles, temp_smiles, train_targets, temp_targets = train_test_split(
    smiles_list, targets, test_size=0.4, random_state=42, stratify=targets
)

val_smiles, test_smiles, val_targets, test_targets = train_test_split(
    temp_smiles, temp_targets, test_size=0.5, random_state=42, stratify=temp_targets
)

# Datasets (all three reuse the featurization of the full dataset)
train_set = MoleculeDataset(train_smiles, train_targets, feature_dicts)
val_set = MoleculeDataset(val_smiles, val_targets, feature_dicts)
test_set = MoleculeDataset(test_smiles, test_targets, feature_dicts)

# Loaders: shuffle only the training data; evaluation order stays deterministic
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)


In [None]:
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import numpy as np
import os

def evaluate(loader):
    """Score the module-level ``model`` on every batch of ``loader`` without gradients.

    Uses the module-level ``model``, ``loss_fn`` and ``device``. The model is
    switched to eval mode and left in that mode on return.

    Args:
        loader: Iterable yielding ``(atom, bond, atom_deg, bond_deg, mask, target)``
            tensor batches.

    Returns:
        Tuple ``(total_loss, acc, f1, auc)`` where ``total_loss`` is the SUM of
        the per-batch loss values (not an average), ``acc``/``f1`` use a 0.5
        probability threshold, and ``auc`` is 0.0 when ROC AUC cannot be computed.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for atom, bond, atom_deg, bond_deg, mask, target in loader:
            atom, bond, atom_deg, bond_deg, mask, target = [
                t.to(device) for t in (atom, bond, atom_deg, bond_deg, mask, target)
            ]
            # Only the second of the model's three outputs is used as the logit.
            _, pred, _ = model(atom, bond, atom_deg, bond_deg, mask)
            pred = pred.view(-1)
            target = target.view(-1).float()

            loss = loss_fn(pred, target)
            total_loss += loss.item()

            prob = torch.sigmoid(pred).cpu().numpy()
            all_probs.extend(prob)
            all_preds.extend((prob > 0.5).astype(int))
            all_targets.extend(target.cpu().numpy())

    acc = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds, zero_division=0)
    try:
        auc = roc_auc_score(all_targets, all_probs)
    except ValueError:
        # roc_auc_score raises ValueError when the targets contain a single
        # class. Only that case is tolerated; anything else (including
        # KeyboardInterrupt) must propagate instead of being silently swallowed.
        auc = 0.0
    return total_loss, acc, f1, auc

# Checkpoint directory and file for the best weights seen so far
os.makedirs("../models", exist_ok=True)
path = f"../models/model_{model_name}.pt"
best_loss = float("inf")

for epoch in range(30):
    model.train()
    train_loss = 0

    for batch in train_loader:
        # Move every tensor of the batch onto the compute device.
        atom, bond, atom_deg, bond_deg, mask, target = (t.to(device) for t in batch)

        # Clear stale gradients before the forward/backward pass.
        optimizer.zero_grad()
        _, pred, _ = model(atom, bond, atom_deg, bond_deg, mask)
        pred = pred.view(-1)
        target = target.view(-1).float()
        loss = loss_fn(pred, target)
        loss.backward()
        optimizer.step()

        # Accumulate the summed per-batch loss for this epoch.
        train_loss += loss.item()

    # Validation metrics first, then the training loss re-measured in eval mode.
    val_loss, val_acc, val_f1, val_auc = evaluate(val_loader)
    train_eval_loss = evaluate(train_loader)[0]

    print(
        f"Epoch {epoch+1} - Train Loss: {train_loss:.4f} - Eval Train Loss: {train_eval_loss:.4f} - Val Loss: {val_loss:.4f} - Acc: {val_acc:.3f} - F1: {val_f1:.3f} - AUC: {val_auc:.3f}"
    )

    # Checkpoint only when the validation loss strictly improves.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), path)
        print(f"✅ Saved best model to {path} (val loss: {best_loss:.4f})")


In [None]:
# Load best model
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU machine still loads when only the CPU is available.
model.load_state_dict(torch.load(path, map_location=device))
model.to(device)

# Final, one-off evaluation on the held-out test split
test_loss, test_acc, test_f1, test_auc = evaluate(test_loader)
print(f"🧪 Test Loss: {test_loss:.4f} - Acc: {test_acc:.3f} - F1: {test_f1:.3f} - AUC: {test_auc:.3f}")