In [7]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
import pandas as pd
from AttentiveFP import get_smiles_dicts, get_smiles_array, num_atom_features, num_bond_features  # assumed to be your actual featurizer
from AttentiveFP import Fingerprint  # assumed to be your attentive FP model

In [8]:
# --- Device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = 'Endocrine_Disruption_NR-AhR'

# --- Dataset ---
class MoleculeDataset(Dataset):
    def __init__(self, smiles_list, targets, feature_dicts):
        self.smiles_list = smiles_list
        self.targets = targets
        self.feature_dicts = feature_dicts
        self.x_atom, self.x_bond, self.x_atom_index, self.x_bond_index, self.x_mask, _ = get_smiles_array(smiles_list, feature_dicts)

    def __len__(self):
        return len(self.feature_dicts)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.x_atom[idx], dtype=torch.float32),
            torch.tensor(self.x_bond[idx], dtype=torch.float32),
            torch.tensor(self.x_atom_index[idx], dtype=torch.long),
            torch.tensor(self.x_bond_index[idx], dtype=torch.long),
            torch.tensor(self.x_mask[idx], dtype=torch.float32),
            torch.tensor(self.targets[idx], dtype=torch.float32)
        )

In [9]:
# --- Load data ---
df = pd.read_csv(f"../data/{model_name}.csv")
smiles_list = df["smiles"].tolist()
targets = df["active"].tolist()
feature_dicts = get_smiles_dicts(smiles_list)

# --- Dataloader ---
dataset = MoleculeDataset(smiles_list, targets, feature_dicts)
loader = DataLoader(dataset, batch_size=32, shuffle=True)





In [10]:
# --- Model ---
model = Fingerprint(
    radius=5,
    T=3,
    input_feature_dim=num_atom_features(),
    input_bond_dim=num_bond_features(),
    fingerprint_dim=150,
    output_units_num=1,
    p_dropout=0.3
).to(device)

optimizer = Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

pos_weight = torch.tensor([2.0], device=device)
loss_fn = BCEWithLogitsLoss()

In [11]:
from sklearn.model_selection import train_test_split

# Split SMILES and targets
train_smiles, val_smiles, train_targets, val_targets = train_test_split(
    smiles_list, targets, test_size=0.2, random_state=42
)

train_set = MoleculeDataset(train_smiles, train_targets, feature_dicts)
val_set = MoleculeDataset(val_smiles, val_targets, feature_dicts)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)




In [12]:
import os

# Folder to save model
os.makedirs("../models", exist_ok=True)
best_loss = float("inf")
path = f"../models/model_{model_name}.pt"

try:
    # Load existing model
    model.load_state_dict(torch.load(path))
    model.to(device)
    model.eval()
except:
    print("The model could not be loaded")

best_loss = float("inf")
os.makedirs("../models", exist_ok=True)

for epoch in range(30):
    model.train()
    train_loss = 0

    for atom, bond, atom_deg, bond_deg, mask, target in train_loader:
        atom, bond, atom_deg, bond_deg, mask, target = [
            t.to(device) for t in (atom, bond, atom_deg, bond_deg, mask, target)
        ]
        _, pred, _ = model(atom, bond, atom_deg, bond_deg, mask)
        loss = loss_fn(pred.squeeze(), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # --- Validation ---
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for atom, bond, atom_deg, bond_deg, mask, target in val_loader:
            atom, bond, atom_deg, bond_deg, mask, target = [
                t.to(device) for t in (atom, bond, atom_deg, bond_deg, mask, target)
            ]
            _, pred, _ = model(atom, bond, atom_deg, bond_deg, mask)
            loss = loss_fn(pred.squeeze(), target)
            val_loss += loss.item()

    print(f"Epoch {epoch+1} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")

    # Save best model
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), path)
        print(f"✅ Saved best model to {path} (val loss: {best_loss:.4f})")



Epoch 1 - Train Loss: 0.0520 - Val Loss: 4.1277
✅ Saved best model to ../models/model_Endocrine_Disruption_NR-AhR.pt (val loss: 4.1277)
Epoch 2 - Train Loss: 0.0054 - Val Loss: 6.3867
Epoch 3 - Train Loss: 0.0004 - Val Loss: 8.4239
Epoch 4 - Train Loss: 0.0001 - Val Loss: 10.1033
Epoch 5 - Train Loss: 0.0000 - Val Loss: 11.4481
Epoch 6 - Train Loss: 0.0000 - Val Loss: 12.5433
Epoch 7 - Train Loss: 0.0000 - Val Loss: 13.4420
Epoch 8 - Train Loss: 0.0000 - Val Loss: 14.1765
Epoch 9 - Train Loss: 0.0000 - Val Loss: 14.7960
Epoch 10 - Train Loss: 0.0000 - Val Loss: 15.3271
Epoch 11 - Train Loss: 0.0000 - Val Loss: 15.7717
Epoch 12 - Train Loss: 0.0000 - Val Loss: 16.1348
Epoch 13 - Train Loss: 0.0000 - Val Loss: 16.4271
Epoch 14 - Train Loss: 0.0000 - Val Loss: 16.6624
Epoch 15 - Train Loss: 0.0000 - Val Loss: 16.8533
Epoch 16 - Train Loss: 0.0000 - Val Loss: 17.0096
Epoch 17 - Train Loss: 0.0000 - Val Loss: 17.1387
Epoch 18 - Train Loss: 0.0000 - Val Loss: 17.2464
Epoch 19 - Train Loss: 0