


### AI for Drug Discovery

Drug discovery traditionally involves:
1. Identifying a biological target (e.g., an enzyme in bacteria).
2. Screening large collections of molecules to find those with activity against the target.
3. Optimizing hits into drug-like molecules.

AI reshapes this pipeline by:
* Predicting properties (activity, toxicity, pharmacokinetics) from molecular structures.
* Generating new molecules with desired properties.
* Providing natural language interface to interact with scientists.

Deep learning is particularly well-suited because it can automatically extract useful representations of molecules from raw formats (SMILES, molecular graphs, or 3D structures) rather than relying entirely on handcrafted features.

---

### Molecules and SMILES

A molecule is a collection of atoms bonded together, forming the fundamental units of chemicals that make up the world around us. In drug discovery, molecules are potential small-molecule drugs that can interact with biological targets such as proteins, DNA, or membranes.

To represent molecules in a form that computers can understand, one common format is SMILES (Simplified Molecular Input Line Entry System).
* SMILES encodes the structure of a molecule as a text string, using characters to denote atoms (e.g., C for carbon, O for oxygen), symbols for bonds (e.g., = for double bond), and parentheses for branches.
* Examples:
  * Ethanol has SMILES CCO (two carbons and an oxygen).
  * Benzene has SMILES c1ccccc1.

SMILES are convenient because they are compact, human-readable, and directly usable in sequence models (CNNs, RNNs, Transformers). But they are not unique — the same molecule can have multiple valid SMILES strings — and they lose explicit 3D information. Therefore, we will also explore graph-based models (GNNs and EGNNs)

---

### Molecular Property Prediction

Molecular property prediction is the task of predicting how a molecule behaves — its physical, chemical, or biological properties — directly from its structure. Examples include:
* Solubility in water.
* Toxicity to human cells.
* Binding affinity to a protein target.
* Biological activity, such as antibacterial activity.

This problem is central to drug discovery, because experimentally measuring properties for millions of molecules is slow and expensive. AI models can act as filters, rapidly prioritizing which molecules are promising candidates for laboratory testing.

---

### Antibiotic Discovery

In this project, I explore the use of deep learning models to predict whether a small molecule has antibacterial activity against Pseudomonas aeruginosa. The dataset includes SMILES strings paired with binary activity labels (1 = active, 0 = inactive), split into training, validation, and test sets.

The goal is to compare several neural architectures for molecular property prediction, using PyTorch implementations built from scratch. The models include:

A feedforward network trained on Morgan fingerprints

1D CNNs applied to tokenized SMILES

Recurrent models (GRU/LSTM) over SMILES sequences

Transformer-based sequence models

Graph neural networks (GCN, GIN, GraphSAGE) operating on molecular graphs

Equivariant graph neural networks (EGNN) incorporating 3D structural information

I built a full training pipeline with early stopping, checkpointing, and evaluation on a held-out test set. After training all models, I compared their predictive performance and analyzed how different input representations (fingerprints, sequences, 2D graphs, 3D graphs) influence accuracy on this antibacterial prediction task.

In [3]:
import torch, torch_geometric
print(torch.__version__, torch_geometric.__version__)


2.8.0 2.6.1


In [40]:
# %pip install torch torchvision #--index-url https://download.pytorch.org/whl/cu126  # uncomment this if you have a GPU
# %pip install pandas numpy scikit-learn tqdm
# %pip install rdkit
# %pip install torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv #-f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__.split('+')[0])").html

## 1) Configuration

In [3]:
from dataclasses import dataclass, asdict
from pathlib import Path
import json

@dataclass
class Config:
    # Data
    data_dir: str = ""      # Insert path containing train.csv, val.csv, test.csv
    smiles_col: str = "smiles"
    target_cols: tuple = ("activity",)  # can be multiple targets: ("y_0","y_1",...)
    task_type: str = "classification"  # "regression" or "classification"

    # General training
    seed: int = 42
    batch_size: int = 8
    num_workers: int = 0
    max_epochs: int = 50
    patience: int = 10           # early stopping
    lr: float = 3e-4
    weight_decay: float = 1e-5
    grad_clip: float = 1.0

    # Fingerprints
    fp_nbits: int = 2048
    fp_radius: int = 2

    # SMILES sequence models
    max_len: int = 256
    vocab_extra: str = ""        # optional custom tokens
    embed_dim: int = 256
    cnn_channels: int = 256
    cnn_kernel_sizes: tuple = (3,5,7)
    rnn_hidden: int = 256
    rnn_layers: int = 2
    transformer_heads: int = 8
    transformer_layers: int = 4
    transformer_ff: int = 512
    dropout: float = 0.1

    # Graph models
    pyg_backend: str = "gcn"     # "gcn", "gin", "graphsage"
    gnn_hidden: int = 256
    gnn_layers: int = 4
    gnn_dropout: float = 0.1
    global_pool: str = "mean"    # "mean", "add", "max"

    # EGNN (3D)
    egnn_hidden: int = 128
    egnn_layers: int = 4
    egnn_use_radial: bool = True
    egnn_cutoff: float = 8.0
    egnn_conformers: int = 1     # number of conformers to generate if 3D not given

    # Logging/checkpoint
    out_dir: str = "" #insert output dir

cfg = Config()
Path(cfg.out_dir).mkdir(parents=True, exist_ok=True)
print(json.dumps(asdict(cfg), indent=2))

{
  "data_dir": "",
  "smiles_col": "smiles",
  "target_cols": [
    "activity"
  ],
  "task_type": "classification",
  "seed": 42,
  "batch_size": 8,
  "num_workers": 0,
  "max_epochs": 50,
  "patience": 10,
  "lr": 0.0003,
  "weight_decay": 1e-05,
  "grad_clip": 1.0,
  "fp_nbits": 2048,
  "fp_radius": 2,
  "max_len": 256,
  "vocab_extra": "",
  "embed_dim": 256,
  "cnn_channels": 256,
  "cnn_kernel_sizes": [
    3,
    5,
    7
  ],
  "rnn_hidden": 256,
  "rnn_layers": 2,
  "transformer_heads": 8,
  "transformer_layers": 4,
  "transformer_ff": 512,
  "dropout": 0.1,
  "pyg_backend": "gcn",
  "gnn_hidden": 256,
  "gnn_layers": 4,
  "gnn_dropout": 0.1,
  "global_pool": "mean",
  "egnn_hidden": 128,
  "egnn_layers": 4,
  "egnn_use_radial": true,
  "egnn_cutoff": 8.0,
  "egnn_conformers": 1,
  "out_dir": ""
}


## 2) Utilities (Seed, Device, Metrics)

In [5]:
import os, random, math
import numpy as np
import torch

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
set_seed(cfg.seed)

# Metrics
from sklearn.metrics import roc_auc_score, average_precision_score

def classification_metrics(y_true, y_pred_proba):
    metrics = {}
    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba)
    metrics["ROC-AUC"] = roc_auc_score(y_true, y_pred_proba)
    metrics["AP"] = average_precision_score(y_true, y_pred_proba)
    return metrics

Device: cpu


## 3) Data Loading

We support three representations:
- **Fingerprints (Morgan/ECFP)** → fast MLP baseline.
- **SMILES-as-sequence** → CNN/RNN/Transformer.
- **Molecular graphs (± 3D)** → GNN/EGNN via RDKit + PyG.

In [6]:
import pandas as pd

def load_split(csv_path, smiles_col, target_cols):
    df = pd.read_csv(csv_path)
    assert smiles_col in df.columns, f"missing {smiles_col}"
    for c in target_cols:
        assert c in df.columns, f"missing target col {c}"
    return df[[smiles_col] + list(target_cols)].copy()

train_df = load_split(Path(cfg.data_dir)/"train.csv", cfg.smiles_col, cfg.target_cols)
val_df = load_split(Path(cfg.data_dir)/"val.csv", cfg.smiles_col, cfg.target_cols)
test_df = load_split(Path(cfg.data_dir)/"test.csv", cfg.smiles_col, cfg.target_cols)

print("Train/Val/Test sizes:", len(train_df), len(val_df), len(test_df))
train_df.head()

Train/Val/Test sizes: 1421 474 474


Unnamed: 0,smiles,activity
0,C1=CC(=C(C=C1C(CN)O)O)O.[C@@H]([C@H](C(=O)O)O)...,0
1,CC(=O)OC1=CC=C(C=C1)O,0
2,CC(C)NC(=O)C1=CC=C(C=C1)CNNC.Cl,0
3,CC1=CC(=CC(=C1C(=O)O)O)O,0
4,C1=CC=C(C=C1)CN=C=S,0


### 3.1 SMILES Tokenization (Character-level)

Build a vocabulary over the training set and write functions to encode (batch of) smiles to a vector (or a matrix)

In [7]:
# Build vocabulary from training SMILES
PAD, BOS, EOS, UNK = "<pad>", "<bos>", "<eos>", "<unk>"

def build_vocab(smiles_list, extra=""):
    charset = set()
    for s in smiles_list:
        for ch in s:
            charset.add(ch)
    for ch in extra:
        charset.add(ch)
    vocab = [PAD, BOS, EOS, UNK] + sorted(list(charset))
    stoi = {c:i for i,c in enumerate(vocab)}
    itos = {i:c for c,i in stoi.items()}
    return vocab, stoi, itos

vocab, stoi, itos = build_vocab(train_df[cfg.smiles_col].tolist(), cfg.vocab_extra)
vocab_size = len(vocab)
print("Vocab size:", vocab_size)

def encode_smiles(s, max_len, stoi):
    seq = [stoi.get(ch, stoi[UNK]) for ch in s[:max_len-2]]
    seq = [stoi[BOS]] + seq + [stoi[EOS]]
    if len(seq) < max_len:
        seq += [stoi[PAD]] * (max_len - len(seq))
    else:
        seq = seq[:max_len]
    return np.array(seq, dtype=np.int64)

def batch_collate_sequence(batch):
    xs = torch.tensor([b[0] for b in batch], dtype=torch.long)
    ys = torch.tensor([b[1] for b in batch], dtype=torch.float32)
    return xs.to(device), ys.to(device)

Vocab size: 49


### 3.2 Morgan Fingerprint featurization (ECFP)

In [8]:
try:
    from rdkit import Chem
    from rdkit.Chem import AllChem
    has_rdkit = True
except Exception as e:
    print("RDKit not available:", e)
    has_rdkit = False

def morgan_fp_from_smiles(smiles: str, n_bits: int, radius: int):
    if not has_rdkit:
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None: return None
    bv = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    arr = np.zeros((n_bits,), dtype=np.float32)
    from rdkit.DataStructs.cDataStructs import ConvertToNumpyArray
    ConvertToNumpyArray(bv, arr)
    return arr

### 3.3 Graph Construction (RDKit → PyG)


In [36]:
# Minimal RDKit -> PyG conversion
try:
    from rdkit import Chem
    from rdkit.Chem import AllChem
    has_rdkit = True
except Exception as e:
    print("RDKit not available:", e)
    has_rdkit = False

try:
    import torch_geometric as tg
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader as GeoDataLoader
    from torch_geometric.nn import GCNConv, GINConv, SAGEConv, global_mean_pool, global_add_pool, global_max_pool
    has_pyg = True
except Exception as e:
    print("PyG not available:", e)
    has_pyg = False

def atom_features(atom):
    atom_types = ['C','N','O','F','P','S','Cl','Br','I']
    elem_onehot = [int(atom.GetSymbol() == e) for e in atom_types]

    degree = [int(atom.GetDegree() == i) for i in range(5)]

    formal_charge = [atom.GetFormalCharge()]

    hybs = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
    ]
    hybridization = [int(atom.GetHybridization() == h) for h in hybs]

    aromatic = [int(atom.GetIsAromatic())]
    in_ring = [int(atom.IsInRing())]
    implicit_h = [atom.GetTotalNumHs(includeNeighbors=True)]
    chiral = [int(atom.HasProp("_ChiralityPossible"))]

    feats = (
        elem_onehot
        + degree
        + formal_charge
        + hybridization
        + aromatic
        + in_ring
        + implicit_h
        + chiral
    )
    return feats

def bond_features(bond):
    btype = [
        int(bond.GetBondType() == Chem.rdchem.BondType.SINGLE),
        int(bond.GetBondType() == Chem.rdchem.BondType.DOUBLE),
        int(bond.GetBondType() == Chem.rdchem.BondType.TRIPLE),
        int(bond.GetBondType() == Chem.rdchem.BondType.AROMATIC),
    ]
    conjugated = [int(bond.GetIsConjugated())]
    in_ring = [int(bond.IsInRing())]
    stereo = [
        int(bond.GetStereo() == Chem.rdchem.BondStereo.STEREOZ),
        int(bond.GetStereo() == Chem.rdchem.BondStereo.STEREOE),
    ]
    feats = btype + conjugated + in_ring + stereo
    assert len(feats) == 8, f"Inconsistent bond feature length: {len(feats)}"
    return feats


def smiles_to_pyg(smiles, need_3d=False, egnn_confs=1):
    if not has_rdkit:
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    mol = Chem.AddHs(mol)
    pos = None
    if need_3d:
        try:
            params = AllChem.ETKDGv3()
            params.randomSeed = 0xf00d
            AllChem.EmbedMultipleConfs(mol, numConfs=max(1, egnn_confs), params=params)
            AllChem.UFFOptimizeMoleculeConfs(mol, maxIters=200)
            conf = mol.GetConformer(id=0)
            pos = []
        except Exception:
            pass
    atoms = [atom_features(a) for a in mol.GetAtoms()]
    x = torch.tensor(atoms, dtype=torch.float32)
    edge_index = [[], []]
    edge_attr = []
    for b in mol.GetBonds():
        u, v = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bf = bond_features(b)
        edge_index[0] += [u, v]
        edge_index[1] += [v, u]
        edge_attr += [bf, bf]
    edge_index = torch.tensor(edge_index, dtype=torch.long)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float32) if len(edge_attr)>0 else None
    data_kwargs = {"x": x, "edge_index": edge_index}
    if edge_attr is None or len(edge_attr) == 0:
        edge_attr = torch.zeros((0, 8), dtype=torch.float32)
    data_kwargs["edge_attr"] = edge_attr

    if need_3d and pos is None and mol.GetNumConformers()>0:
        conf = mol.GetConformer(id=0)
        pos = torch.tensor([[conf.GetAtomPosition(i).x,
                             conf.GetAtomPosition(i).y,
                             conf.GetAtomPosition(i).z] for i in range(mol.GetNumAtoms())],
                           dtype=torch.float32)
        data_kwargs["pos"] = pos
    elif need_3d and pos is None:
        # fallback zeros
        data_kwargs["pos"] = torch.zeros((mol.GetNumAtoms(),3), dtype=torch.float32)
    return Data(**data_kwargs)

### 3.3 Dataset & Loaders

Write a Pytorch Dataset and DataLoader class for SMILES data

In [52]:
from torch.utils.data import Dataset, DataLoader

# Fingerprint dataset
class FingerprintDataset(Dataset):
    def __init__(self, df, smiles_col, target_cols, n_bits, radius):
        assert has_rdkit, "RDKit required for Morgan fingerprints"
        self.df = df.reset_index(drop=True)
        self.smiles_col = smiles_col
        self.target_cols = list(target_cols)
        self.n_bits = n_bits; self.radius = radius
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        arr = morgan_fp_from_smiles(row[self.smiles_col], self.n_bits, self.radius)
        if arr is None: raise ValueError(f"Failed to featurize SMILES: {row[self.smiles_col]}")
        x = torch.tensor(arr, dtype=torch.float32)
        y = torch.tensor(row[self.target_cols].values.astype(np.float32))
        return x, y

def batch_collate_fp(batch):
    xs = torch.stack([b[0] for b in batch], dim=0)
    ys = torch.stack([b[1] for b in batch], dim=0)
    return xs.to(device), ys.to(device)

def get_fp_loaders(train_df, val_df, test_df, batch_size, n_bits, radius):
    train_ds = FingerprintDataset(train_df, cfg.smiles_col, cfg.target_cols, n_bits, radius)
    val_ds   = FingerprintDataset(val_df,   cfg.smiles_col, cfg.target_cols, n_bits, radius)
    test_ds  = FingerprintDataset(test_df,  cfg.smiles_col, cfg.target_cols, n_bits, radius)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=cfg.num_workers, collate_fn=batch_collate_fp)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=cfg.num_workers, collate_fn=batch_collate_fp)
    test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=cfg.num_workers, collate_fn=batch_collate_fp)
    return train_loader, val_loader, test_loader

# Sequence dataset
class SmilesSequenceDataset(Dataset):
    def __init__(self, df, smiles_col, target_cols, stoi, max_len):
        self.df = df.reset_index(drop=True); self.smiles_col=smiles_col
        self.target_cols = list(target_cols); self.stoi=stoi; self.max_len=max_len
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        x = encode_smiles(row[self.smiles_col], self.max_len, self.stoi)
        y = row[self.target_cols].values.astype(np.float32)
        return x, y

def get_sequence_loaders(train_df, val_df, test_df, batch_size):
    train_ds = SmilesSequenceDataset(train_df, cfg.smiles_col, cfg.target_cols, stoi, cfg.max_len)
    val_ds   = SmilesSequenceDataset(val_df,   cfg.smiles_col, cfg.target_cols, stoi, cfg.max_len)
    test_ds  = SmilesSequenceDataset(test_df,  cfg.smiles_col, cfg.target_cols, stoi, cfg.max_len)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=cfg.num_workers, collate_fn=batch_collate_sequence)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=cfg.num_workers, collate_fn=batch_collate_sequence)
    test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=cfg.num_workers, collate_fn=batch_collate_sequence)
    return train_loader, val_loader, test_loader

# Graph dataset
class SmilesGraphDataset(Dataset):
    def __init__(self, df, smiles_col, target_cols, need_3d=False):
        assert has_rdkit, "RDKit required for graph datasets"
        self.df = df.reset_index(drop=True); self.smiles_col=smiles_col
        self.target_cols = list(target_cols); self.need_3d=need_3d
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        data = smiles_to_pyg(row[self.smiles_col], need_3d=self.need_3d, egnn_confs=cfg.egnn_conformers)
        if data is None: raise ValueError("Failed to parse SMILES:", row[self.smiles_col])
        y = torch.tensor(row[self.target_cols].values.astype(np.float32)); data.y = y
        return data

def get_graph_loaders(train_df, val_df, test_df, batch_size, need_3d=False):
    assert has_pyg, "PyTorch Geometric required"
    train_ds = SmilesGraphDataset(train_df, cfg.smiles_col, cfg.target_cols, need_3d=need_3d)
    val_ds   = SmilesGraphDataset(val_df,   cfg.smiles_col, cfg.target_cols, need_3d=need_3d)
    test_ds  = SmilesGraphDataset(test_df,  cfg.smiles_col, cfg.target_cols, need_3d=need_3d)
    train_loader = GeoDataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=cfg.num_workers)
    val_loader   = GeoDataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=cfg.num_workers)
    test_loader  = GeoDataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=cfg.num_workers)
    return train_loader, val_loader, test_loader

## 4) Models

**TODO:** Implement for the model families (DNN, CNN, RNN, Transformer, GNN).

### DNN

In [11]:
import torch.nn as nn
import torch

class MLPFingerprint(nn.Module):
    """
    Simple configurable MLP that maps a fingerprint vector -> single logit.
    Keep class definition in the models cell. Instantiate in an experiment cell.
    Args:
        in_dim: input dimensionality (fp length)
        hidden_sizes: tuple of hidden layer sizes, e.g. (512,256)
        dropout: dropout probability applied after activation
        use_bn: whether to apply BatchNorm1d between Linear -> ReLU
    Output:
        forward(x) returns logits shaped (B,) (one logit per example)
    """
    def __init__(self, in_dim: int, hidden_sizes=(512,256), dropout: float = 0.1, use_bn: bool = False):
        super().__init__()
        layers = []
        prev = int(in_dim)
        for h in hidden_sizes:
            layers.append(nn.Linear(prev, h))
            if use_bn:
                layers.append(nn.BatchNorm1d(h))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Dropout(dropout))
            prev = int(h)
        layers.append(nn.Linear(prev, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, D) float tensor
        returns: logits of shape (B,)
        """
        logits = self.net(x)
        return logits.view(-1)


# Weight init helper
def init_mlp_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm1d):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)


# Convenience builder: creates model, applies init, moves to device
def build_mlp_fingerprint(in_dim, hidden_sizes=(512,256), dropout=0.1, use_bn=False, device=torch.device("cpu")):
    model = MLPFingerprint(in_dim=in_dim, hidden_sizes=hidden_sizes, dropout=dropout, use_bn=use_bn)
    model.apply(init_mlp_weights)
    return model.to(device)

### CNN

In [12]:
# === CNN over SMILES ===
import torch
import torch.nn as nn

class CNNFingerprint(nn.Module):
    """
    1D CNN model for SMILES sequences.
    Each SMILES string is tokenized as integer IDs and embedded, then
    processed through convolutional layers to learn local substructure patterns.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 128,
        num_filters: int = 128,
        kernel_sizes=(3, 5, 7),
        dropout: float = 0.3,
        use_bn: bool = False
    ):
        super().__init__()

        # Embedding layer for SMILES tokens
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # Convolutional layers (with different kernel sizes)
        convs = []
        for k in kernel_sizes:
            conv = nn.Conv1d(embed_dim, num_filters, kernel_size=k, padding=k//2)
            convs.append(conv)
        self.convs = nn.ModuleList(convs)

        self.use_bn = use_bn
        if use_bn:
            self.bns = nn.ModuleList([nn.BatchNorm1d(num_filters) for _ in kernel_sizes])
        else:
            self.bns = [nn.Identity() for _ in kernel_sizes]

        # Fully connected classifier head
        in_dim = num_filters * len(kernel_sizes)
        self.classifier = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """
        x: (B, L) integer token indices
        """
        # Embedding -> (B, L, E)
        x = self.embedding(x)
        # Rearrange for Conv1d: (B, E, L)
        x = x.transpose(1, 2)

        # Apply each convolution + activation + global max pool
        conv_outputs = []
        for conv, bn in zip(self.convs, self.bns):
            h = conv(x)                # (B, num_filters, L)
            h = bn(h)
            h = torch.relu(h)
            h = torch.max(h, dim=2).values  # Global max pool over sequence length
            conv_outputs.append(h)

        # Concatenate features from all kernel sizes
        z = torch.cat(conv_outputs, dim=1)

        # Classification head -> single logit
        logits = self.classifier(z)
        return logits.view(-1)


def build_cnn_fingerprint(vocab_size, embed_dim=128, num_filters=128,
                          kernel_sizes=(3,5,7), dropout=0.3, use_bn=False,
                          device=torch.device("cpu")):
    model = CNNFingerprint(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_filters=num_filters,
        kernel_sizes=kernel_sizes,
        dropout=dropout,
        use_bn=use_bn
    )
    # Kaiming initialization for convs and linear layers
    for m in model.modules():
        if isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    return model.to(device)


### RNN

In [18]:
# === RNN (LSTM) over SMILES ===
import torch
import torch.nn as nn

class RNNSmiles(nn.Module):
    """
    Recurrent model (LSTM/GRU) over SMILES sequences.
    Captures sequential dependencies in tokenized SMILES strings.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 128,
        hidden_dim: int = 256,
        num_layers: int = 2,
        rnn_type: str = "lstm",
        bidirectional: bool = True,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        if rnn_type.lower() == "gru":
            self.rnn = nn.GRU(
                embed_dim, hidden_dim, num_layers=num_layers,
                batch_first=True, dropout=dropout, bidirectional=bidirectional
            )
        else: 
            self.rnn = nn.LSTM(
                embed_dim, hidden_dim, num_layers=num_layers,
                batch_first=True, dropout=dropout, bidirectional=bidirectional
            )

        # Bidirectional doubles hidden size
        self.rnn_out_dim = hidden_dim * (2 if bidirectional else 1)

        self.classifier = nn.Sequential(
            nn.Linear(self.rnn_out_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """
        x: (B, L) token indices
        returns: logits (B,)
        """
        emb = self.embed(x)  # (B, L, E)
        rnn_out, hidden = self.rnn(emb)

        if isinstance(hidden, tuple): 
            hidden = hidden[0]

        last_hidden = hidden[-2:] if self.rnn.bidirectional else hidden[-1:]
        last_hidden = torch.cat(list(last_hidden), dim=-1).squeeze(0)  # (B, hidden_dim * num_dirs)

        logits = self.classifier(last_hidden)
        return logits.view(-1)


def build_rnn_smiles(vocab_size, embed_dim=128, hidden_dim=256, num_layers=2,
                     rnn_type="lstm", bidirectional=True, dropout=0.3,
                     device=torch.device("cpu")):
    model = RNNSmiles(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        rnn_type=rnn_type,
        bidirectional=bidirectional,
        dropout=dropout
    )
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    return model.to(device)


### Transformer

In [20]:
class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding from Vaswani et al. (2017).
    Adds position-dependent sin/cos patterns to token embeddings.
    """
    def __init__(self, embed_dim: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute sinusoidal encoding matrix [max_len, embed_dim]
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape (1, max_len, embed_dim)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: shape (B, L, D)
        returns: same shape, with position encodings added
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class TransformerSmiles(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int = 256, nhead: int = 8,
                 num_layers: int = 4, ff_dim: int = 512, dropout: float = 0.1,
                 max_len: int = 256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoding = PositionalEncoding(embed_dim, max_len, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=nhead,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation="relu",
            batch_first=True  # ensures (B, L, D)
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, L) long tensor of token IDs
        output: (B,) logits
        """
        mask = (x == 0)  # padding mask
        x = self.embed(x) * math.sqrt(x.size(-1))
        x = self.pos_encoding(x)
        x = self.encoder(x, src_key_padding_mask=mask)
        x = x.mean(dim=1)  # mean pooling over sequence
        logits = self.classifier(x)
        return logits.view(-1)

def build_transformer_smiles(vocab_size, device, embed_dim=256, nhead=8,
                             num_layers=4, ff_dim=512, dropout=0.1, max_len=256):
    model = TransformerSmiles(vocab_size, embed_dim, nhead, num_layers, ff_dim, dropout, max_len)
    model.apply(lambda m: nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) else None)
    return model.to(device)


### GNN

In [22]:
# === GNN over molecular graphs (GCN / GIN / GraphSAGE) ===
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GINConv, SAGEConv, global_mean_pool, global_add_pool, global_max_pool

POOLERS = {
    "mean": global_mean_pool,
    "add":  global_add_pool,
    "max":  global_max_pool,
}

class GraphBlock(nn.Module):
    def __init__(self, in_dim, out_dim, kind="gcn", dropout=0.1):
        super().__init__()
        if kind == "gcn":
            self.conv = GCNConv(in_dim, out_dim)
        elif kind == "graphsage":
            self.conv = SAGEConv(in_dim, out_dim)
        elif kind == "gin":
            mlp = nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.ReLU(inplace=True),
                nn.Linear(out_dim, out_dim),
            )
            self.conv = GINConv(mlp)
        else:
            raise ValueError(f"Unknown GNN kind: {kind}")
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        x = self.act(x)
        x = self.dropout(x)
        return x

class GraphNet(nn.Module):
    """
    Configurable GNN with N layers of {GCN, GIN, GraphSAGE} and a pooled classifier head.
    Expects PyG Batch with .x (node features), .edge_index, .batch, optional .edge_attr ignored.
    """
    def __init__(self, in_dim, hidden=256, layers=3, kind="gcn", pool="mean", dropout=0.1):
        super().__init__()
        self.blocks = nn.ModuleList()
        dims = [in_dim] + [hidden] * layers
        for i in range(layers):
            self.blocks.append(GraphBlock(dims[i], dims[i+1], kind=kind, dropout=dropout))
        self.pool = POOLERS.get(pool, global_mean_pool)
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden // 2, 1),
        )

    def forward(self, data):
        # data: PyG Batch (x, edge_index, batch, [edge_attr], [pos])
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for blk in self.blocks:
            x = blk(x, edge_index)
        g = self.pool(x, batch)          # (B, hidden)
        logits = self.head(g)            # (B, 1)
        return logits.view(-1)


## 5) Training & Evaluation Loops


In [13]:

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')   # silence RDKit warnings

import os, json, time
from pathlib import Path
from collections import defaultdict
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import copy

from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, confusion_matrix

# Helpers for safe metric computation
def safe_roc_auc(y_true, y_score):
    try:
        return float(roc_auc_score(y_true, y_score))
    except Exception:
        return float("nan")

def safe_average_precision(y_true, y_score):
    try:
        return float(average_precision_score(y_true, y_score))
    except Exception:
        return float("nan")

from sklearn.metrics import precision_score, recall_score

def compute_classification_metrics(y_true, y_score, threshold=0.5):
    y_true = np.asarray(y_true).astype(int)
    y_score = np.asarray(y_score).astype(float)
    metrics = {}

    # ROC-AUC & AP (may be NaN if only one class present)
    try:
        metrics["ROC-AUC"] = float(roc_auc_score(y_true, y_score))
    except Exception:
        metrics["ROC-AUC"] = float("nan")

    try:
        metrics["AP"] = float(average_precision_score(y_true, y_score))
    except Exception:
        metrics["AP"] = float("nan")

    # Binary preds at threshold
    y_pred = (y_score >= threshold).astype(int)

    # Accuracy always defined
    try:
        metrics["Accuracy"] = float(accuracy_score(y_true, y_pred))
    except Exception:
        metrics["Accuracy"] = float("nan")

    # Precision / recall: if positive class absent or zero division, set to nan or 0 depending on choice
    try:
        metrics["Precision"] = float(precision_score(y_true, y_pred, zero_division=0))
    except Exception:
        metrics["Precision"] = float("nan")
    try:
        metrics["Recall"] = float(recall_score(y_true, y_pred, zero_division=0))
    except Exception:
        metrics["Recall"] = float("nan")

    # Confusion matrix (as list)
    try:
        metrics["ConfusionMatrix"] = confusion_matrix(y_true, y_pred).tolist()
    except Exception:
        metrics["ConfusionMatrix"] = []

    return metrics


class EarlyStopping:
    def __init__(self, patience=10, mode="max", delta=0.0, verbose=False):
        assert mode in ("min", "max")
        self.patience = int(patience)
        self.mode = mode
        self.delta = delta
        self.verbose = verbose
        self.best = None
        self.num_bad = 0
        self.is_better = (lambda a, b: a < b - delta) if mode=="min" else (lambda a, b: a > b + delta)

    def step(self, metric):
        if self.best is None:
            self.best = metric
            self.num_bad = 0
            return False
        if self.is_better(metric, self.best):
            self.best = metric
            self.num_bad = 0
            return False
        else:
            self.num_bad += 1
            if self.verbose:
                print(f"EarlyStopping: metric did not improve ({self.num_bad}/{self.patience})")
            return self.num_bad >= self.patience

class ModelCheckpoint:
    def __init__(self, out_dir, monitor="val_metric", mode="max", save_best_only=True):
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.monitor = monitor
        self.mode = mode
        self.save_best_only = save_best_only
        self.best = None
        self.is_better = (lambda a, b: a < b) if mode=="min" else (lambda a, b: a > b)

    def save(self, model, optimizer, scheduler, epoch, metric, tag="best"):
        """
        Saves checkpoint. If save_best_only=True, only saves when metric improves wrt self.best.
        Returns path to saved file or None if nothing was saved.
        """
        # check metric improvement
        if self.save_best_only:
            if self.best is None or self.is_better(metric, self.best):
                self.best = metric
            else:
                return None
    
        state = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "opt_state": optimizer.state_dict() if optimizer is not None else None,
            "sch_state": scheduler.state_dict() if scheduler is not None else None,
            "metric": metric,
        }
        fn = self.out_dir / f"checkpoint_{tag}.pt"
        torch.save(state, fn)
        return fn

def get_pos_weight_from_df(train_df, target_col):
    # returns torch.tensor suitable for BCEWithLogitsLoss pos_weight
    counts = train_df[target_col].value_counts().to_dict()
    n_pos = counts.get(1, 0)
    n_neg = counts.get(0, 0)
    if n_pos == 0:
        print("Warning: no positive examples in training set; pos_weight not defined. Returning 1.0")
        return torch.tensor(1.0, dtype=torch.float32)
    pos_weight = max(1.0, float(n_neg) / float(n_pos))
    return torch.tensor(pos_weight, dtype=torch.float32)

def get_loss_fn(task_type="classification", pos_weight=None, device=torch.device("cpu")):
    if task_type == "classification":
        if pos_weight is not None:
            pw = pos_weight.to(device) if isinstance(pos_weight, torch.Tensor) else torch.tensor(pos_weight, dtype=torch.float32, device=device)
            return nn.BCEWithLogitsLoss(pos_weight=pw)
        else:
            return nn.BCEWithLogitsLoss()
    else:
        return nn.MSELoss()

def unpack_batch(batch):
    """
    Returns (inputs, targets) tuples for common batch types:
    - If batch is (x,y) tuple/list => returns as-is (tensors)
    - If batch has attribute 'y' (torch_geometric Batch), returns (batch, batch.y)
    - Else: raises
    """
    # PyG Batch detection: has attribute 'y' and 'to'
    if hasattr(batch, "y") and hasattr(batch, "to"):
        return batch, batch.y
    if isinstance(batch, (list, tuple)) and len(batch) >= 2:
        return batch[0], batch[1]
    raise ValueError("Unsupported batch type for unpacking. Expected (x,y) or PyG Batch with .y")


def train_one_epoch(model, loader, optimizer, loss_fn, device, scaler=None, grad_clip=None):
    model.train()
    running_loss = 0.0
    all_targets = []
    all_scores = []

    for batch in loader:
        inputs, targets = unpack_batch(batch)
        if isinstance(inputs, torch.Tensor):
            inputs = inputs.to(device)
        if isinstance(targets, torch.Tensor):
            targets = targets.to(device)
        optimizer.zero_grad()
        # forward
        if scaler is not None:
            with torch.cuda.amp.autocast():
                logits = model(inputs)
                logits = logits.view(-1)
                targets_float = targets.view(-1).float()
                loss = loss_fn(logits, targets_float)
            scaler.scale(loss).backward()
            if grad_clip is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(inputs)
            logits = logits.view(-1)
            targets_float = targets.view(-1).float()
            loss = loss_fn(logits, targets_float)
            loss.backward()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        running_loss += loss.item() * targets_float.shape[0]
        probs = torch.sigmoid(logits.detach()).cpu().numpy()
        all_scores.append(probs)
        all_targets.append(targets_float.detach().cpu().numpy())

    all_scores = np.concatenate(all_scores, axis=0) if len(all_scores)>0 else np.array([])
    all_targets = np.concatenate(all_targets, axis=0) if len(all_targets)>0 else np.array([])
    avg_loss = running_loss / max(1, len(all_targets))
    metrics = compute_classification_metrics(all_targets, all_scores)
    return avg_loss, metrics

def eval_model(model, loader, loss_fn, device):
    model.eval()
    running_loss = 0.0
    all_targets = []
    all_scores = []
    with torch.no_grad():
        for batch in loader:
            inputs, targets = unpack_batch(batch)
            if isinstance(inputs, torch.Tensor):
                inputs = inputs.to(device)
            if isinstance(targets, torch.Tensor):
                targets = targets.to(device)
            logits = model(inputs)
            logits = logits.view(-1)
            targets_float = targets.view(-1).float()
            loss = loss_fn(logits, targets_float)
            running_loss += loss.item() * targets_float.shape[0]
            probs = torch.sigmoid(logits).cpu().numpy()
            all_scores.append(probs)
            all_targets.append(targets_float.detach().cpu().numpy())

    all_scores = np.concatenate(all_scores, axis=0) if len(all_scores)>0 else np.array([])
    all_targets = np.concatenate(all_targets, axis=0) if len(all_targets)>0 else np.array([])
    avg_loss = running_loss / max(1, len(all_targets))
    metrics = compute_classification_metrics(all_targets, all_scores)
    return avg_loss, metrics

def run_training(model, train_loader, val_loader, optimizer, scheduler, loss_fn, cfg, model_name="model", device=torch.device("cpu")):
    out_dir = Path(cfg.out_dir) / model_name
    ckpt_dir = out_dir / "checkpoints"
    out_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Sch may be None
    checkpoint = ModelCheckpoint(ckpt_dir, monitor="val_ROC-AUC", mode="max")
    earlystop = EarlyStopping(patience=cfg.patience, mode="max", verbose=True)

    scaler = torch.cuda.amp.GradScaler() if (device.type=="cuda") else None

    history = defaultdict(list)
    best_metric = None
    best_epoch = -1

    for epoch in range(1, cfg.max_epochs+1):
        t0 = time.time()
        train_loss, train_metrics = train_one_epoch(model, train_loader, optimizer, loss_fn, device, scaler=scaler, grad_clip=cfg.grad_clip)
        val_loss, val_metrics = eval_model(model, val_loader, loss_fn, device)

        if scheduler is not None:
            from torch.optim.lr_scheduler import ReduceLROnPlateau
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(val_metrics.get("ROC-AUC", val_loss))
            else:
                scheduler.step()

        epoch_time = time.time() - t0
        # Logging
        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_ROC-AUC"].append(train_metrics.get("ROC-AUC"))
        history["val_ROC-AUC"].append(val_metrics.get("ROC-AUC"))
        history["train_AP"].append(train_metrics.get("AP"))
        history["val_AP"].append(val_metrics.get("AP"))

        # print
        print(f"Epoch {epoch}/{cfg.max_epochs} — train_loss={train_loss:.4f} val_loss={val_loss:.4f} val_ROC-AUC={val_metrics.get('ROC-AUC'):.4f} val_AP={val_metrics.get('AP'):.4f} time={epoch_time:.1f}s")

        val_metric = val_metrics.get("ROC-AUC", float("nan"))

        if not math.isnan(val_metric):
            if best_metric is None or val_metric > best_metric:
                best_metric = val_metric
                best_epoch = epoch
                checkpoint.save(model, optimizer, scheduler, epoch, best_metric, tag="best")
                print(f"  -> new best val_ROC-AUC={best_metric:.4f} saved checkpoint")
        else:
            print("  -> val_ROC-AUC is NaN for this epoch (likely single-class in val). Skipping checkpointing.")

        if math.isnan(val_metric):
            stopped = earlystop.step(float("-inf") if earlystop.mode=="max" else float("inf"))
        else:
            stopped = earlystop.step(val_metric)
        if stopped:
            print(f"Early stopping at epoch {epoch} (best epoch {best_epoch} val_ROC-AUC={best_metric})")
            break

    # Save final history
    (out_dir / "training_history.json").write_text(json.dumps(history, indent=2))
    print("Training finished. Best epoch:", best_epoch, "best_metric:", best_metric)
    return history

## 6) Collect Results 


### 6.0 MLP over Morgan Fingerprints

In [16]:
import numpy as np, os, json
from pathlib import Path
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler
import matplotlib.pyplot as plt

out_base = Path(cfg.out_dir) / "fp_baseline"
out_base.mkdir(parents=True, exist_ok=True)

# 1) Stratified internal split
target_col = cfg.target_cols[0]
try:
    train_internal_df, val_internal_df = train_test_split(
        train_df,
        test_size=0.15,
        random_state=cfg.seed,
        stratify=train_df[target_col]
    )
    print("Created stratified internal split:", len(train_internal_df), len(val_internal_df))
except Exception as e:
    print("Stratified split failed:", e)
    # fallback: use original provided val_df for validation
    train_internal_df = train_df.copy()
    val_internal_df = val_df.copy()
    print("Fallback: using provided val_df as validation set")

# Show class balance
def print_class_counts(name, df):
    counts = df[target_col].value_counts().to_dict()
    print(f"{name} counts:", counts)

print_class_counts("train_internal", train_internal_df)
print_class_counts("val_internal", val_internal_df)
print_class_counts("provided_val", val_df)

# 2) Precompute Morgan fingerprints for given df and cache to disk for speed
def precompute_fps(df, fname, n_bits=cfg.fp_nbits, radius=cfg.fp_radius, smiles_col=cfg.smiles_col, target_col=target_col):
    outp = out_base / fname
    if outp.exists():
        print("Loading cached fingerprints:", outp)
        data = np.load(outp)
        X = data["X"]
        y = data["y"]
        return X, y
    X_list = []
    y_list = []
    bad_indices = []
    for i, row in df.reset_index(drop=True).iterrows():
        s = str(row[smiles_col])
        arr = morgan_fp_from_smiles(s, n_bits=n_bits, radius=radius)
        if arr is None:
            bad_indices.append((i, s))
            continue
        X_list.append(arr)
        y_list.append(float(row[target_col]))
    if len(X_list) == 0:
        raise RuntimeError("No valid fingerprints computed for " + str(fname))
    X = np.stack(X_list, axis=0).astype(np.float32)
    y = np.array(y_list, dtype=np.float32)
    np.savez_compressed(outp, X=X, y=y)
    if len(bad_indices) > 0:
        with open(out_base / f"{fname}_bad_smiles.txt", "w") as fh:
            for i,s in bad_indices:
                fh.write(f"{i}\t{s}\n")
    print(f"Saved fingerprints to {outp} (N={len(y)})")
    return X, y

# Precompute for train_internal, val_internal, test_df
X_tr, y_tr = precompute_fps(train_internal_df, "train_internal_fp.npz")
X_val, y_val = precompute_fps(val_internal_df,   "val_internal_fp.npz")
X_test, y_test = precompute_fps(test_df,          "test_fp.npz")

# 3) WeightedRandomSampler for training 
def make_weighted_sampler_from_labels(y_array):
    vals, counts = np.unique(y_array, return_counts=True)
    class_counts = dict(zip(vals.tolist(), counts.tolist()))
    if len(class_counts) <= 1:
        return None
    sample_weights = np.array([1.0 / class_counts[int(lbl)] for lbl in y_array], dtype=np.double)
    sampler = WeightedRandomSampler(torch.from_numpy(sample_weights), num_samples=len(sample_weights), replacement=True)
    return sampler

train_sampler = make_weighted_sampler_from_labels(y_tr)

# DataLoader: use sampler only if not None
if train_sampler is not None:
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=cfg.num_workers)
else:
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)


# 4) TensorDatasets + DataLoaders
train_ds = TensorDataset(torch.from_numpy(X_tr), torch.from_numpy(y_tr))
val_ds   = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
test_ds  = TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=cfg.num_workers)
val_loader   = DataLoader(val_ds,   batch_size=cfg.batch_size*2, shuffle=False, num_workers=cfg.num_workers)
test_loader  = DataLoader(test_ds,  batch_size=cfg.batch_size*2, shuffle=False, num_workers=cfg.num_workers)

print("DL sizes:", len(train_loader), len(val_loader), len(test_loader))
print("Train class counts (actual):", np.unique(y_tr, return_counts=True))
print("Val class counts (actual):", np.unique(y_val, return_counts=True))


# 6) Loss (use pos_weight computed from train_internal_df) and optimizer
pos_weight = get_pos_weight_from_df(train_internal_df, target_col).to(device)
print("Using pos_weight:", float(pos_weight))
loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

# 7) Train using run_training (uses cfg)
cfg_local = copy.deepcopy(cfg)
cfg_local.max_epochs = 50
cfg_local.patience = 8
history = run_training(model, train_loader, val_loader, optimizer, scheduler, loss_fn, cfg_local, model_name="fp_mlp", device=device)

# 8) After training, load best checkpoint and evaluate on test set
ckpt_path = Path(cfg.out_dir) / "fp_mlp" / "checkpoints" / "checkpoint_best.pt"
if ckpt_path.exists():
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model_state"])
    print("Loaded best checkpoint from epoch", state.get("epoch"), "metric", state.get("metric"))
else:
    print("No best checkpoint found at", ckpt_path, "; evaluating last model state.")

# Evaluate on test_loader
model.eval()
all_scores = []
all_targets = []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device); yb = yb.to(device)
        logits = model(xb).view(-1)
        probs = torch.sigmoid(logits).detach().cpu().numpy()
        all_scores.append(probs)
        all_targets.append(yb.detach().cpu().numpy())
all_scores = np.concatenate(all_scores, axis=0)
all_targets = np.concatenate(all_targets, axis=0)

metrics = compute_classification_metrics(all_targets, all_scores)
print("=== TEST METRICS (fingerprint MLP) ===")
for k,v in metrics.items():
    print(f"{k}: {v}")

# 9) Save metrics and plot learning curves
with open(out_base / "fp_mlp_test_metrics.json", "w") as fh:
    json.dump(metrics, fh, indent=2)
# Plot learning curves (train/val loss + ROC-AUC)
hist_path = Path(cfg.out_dir) / "fp_mlp" / "training_history.json"
if hist_path.exists():
    hist = json.load(open(hist_path))
    epochs = hist["epoch"]
    plt.figure(figsize=(6,4))
    plt.plot(epochs, hist["train_loss"], label="train_loss")
    plt.plot(epochs, hist["val_loss"], label="val_loss")
    plt.xlabel("epoch"); plt.ylabel("loss"); plt.legend(); plt.tight_layout()
    plt.savefig(out_base / "loss_curve.png")
    plt.close()

    plt.figure(figsize=(6,4))
    plt.plot(epochs, hist["train_ROC-AUC"], label="train_ROC-AUC")
    plt.plot(epochs, hist["val_ROC-AUC"], label="val_ROC-AUC")
    plt.xlabel("epoch"); plt.ylabel("ROC-AUC"); plt.legend(); plt.tight_layout()
    plt.savefig(out_base / "roc_curve_epochs.png")
    plt.close()
    print("Saved learning curves to", out_base)
else:
    print("No history file found at", hist_path)


Created stratified internal split: 1207 214
train_internal counts: {0: 1165, 1: 42}
val_internal counts: {0: 207, 1: 7}
provided_val counts: {0: 462, 1: 12}
Loading cached fingerprints: runs_hw1/fp_baseline/train_internal_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/val_internal_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/test_fp.npz


NameError: name 'train_ds' is not defined

In [60]:

out_final = Path(cfg.out_dir) / "final_mlp"
out_final.mkdir(parents=True, exist_ok=True)

# 1) Prepare train_full
train_full = pd.concat([train_df, val_df], axis=0).reset_index(drop=True)
print("train_full size:", len(train_full), " test size:", len(test_df))

# small internal val percent
internal_val_frac = 0.08
if internal_val_frac > 0:
    train_for_fit_df, internal_val_df = train_test_split(
        train_full,
        test_size=internal_val_frac,
        random_state=cfg.seed,
        stratify=train_full[cfg.target_cols[0]]
    )
else:
    train_for_fit_df = train_full.copy()
    internal_val_df = val_df.copy() 

print("training for fit:", len(train_for_fit_df), "internal_val:", len(internal_val_df))

# 2) Precompute (or load cached) fingerprints for both sets
train_cache = out_final / "train_full_fp.npz"
val_cache   = out_final / "internal_val_fp.npz"
test_cache  = Path(cfg.out_dir) / "fp_baseline" / "test_fp.npz" 

if 'precompute_fps' in globals():
    X_tr, y_tr = precompute_fps(train_for_fit_df, str(train_cache.name))
    X_val, y_val = precompute_fps(internal_val_df, str(val_cache.name))
    # prefer existing test cache
    if test_cache.exists():
        data = np.load(test_cache)
        X_test, y_test = data["X"].astype(np.float32), data["y"].astype(np.float32)
    else:
        X_test, y_test = precompute_fps(test_df, "test_fp.npz")
else:
    raise RuntimeError("precompute_fps not found. Make sure your featurizer cell is run.")

# 3) Build DataLoaders
from torch.utils.data import TensorDataset, DataLoader
def make_loader_from_numpy(X, y, batch_size, sampler=None, shuffle=False):
    ds = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(y).float())
    if sampler is not None:
        return DataLoader(ds, batch_size=batch_size, sampler=sampler, num_workers=cfg.num_workers)
    else:
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=cfg.num_workers)

# add sampler guard
train_sampler = None
if 'make_weighted_sampler_from_labels' in globals():
    train_sampler = make_weighted_sampler_from_labels(y_tr)
train_loader = make_loader_from_numpy(X_tr, y_tr, batch_size=cfg.batch_size, sampler=train_sampler, shuffle=(train_sampler is None))
val_loader   = make_loader_from_numpy(X_val, y_val, batch_size=max(32, cfg.batch_size*2), shuffle=False)
test_loader  = make_loader_from_numpy(X_test, y_test, batch_size=max(32, cfg.batch_size*2), shuffle=False)

print("Train/Val/Test loader lens:", len(train_loader), len(val_loader), len(test_loader))
print("Train class counts:", np.unique(y_tr, return_counts=True))

# 4) Instantiate final model
hidden_sizes = (512, 256) 
dropout = cfg.dropout
model = build_mlp_fingerprint(in_dim=X_tr.shape[1], hidden_sizes=hidden_sizes, dropout=dropout, use_bn=False, device=device)

# 5) Loss / optimizer / scheduler
raw_pw = float(get_pos_weight_from_df(pd.DataFrame({cfg.target_cols[0]: y_tr}), cfg.target_cols[0])) if False else float((y_tr==0).sum() / max(1.0, (y_tr==1).sum()))
# safe compute directly:
n_pos = float((y_tr==1).sum())
n_neg = float((y_tr==0).sum())
raw_pw = (n_neg / max(1.0, n_pos)) if n_pos>0 else 1.0
pw_cap = 10.0
pos_weight = torch.tensor(min(raw_pw, pw_cap), dtype=torch.float32, device=device)
print(f"pos_weight raw={raw_pw:.2f} capped={float(pos_weight):.2f}")

loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

# 6) Training: run_training (uses EarlyStopping / checkpoint)
cfg_local = copy.deepcopy(cfg)
cfg_local.max_epochs = 100
cfg_local.patience = 10
history = run_training(model, train_loader, val_loader, optimizer, scheduler, loss_fn, cfg_local, model_name="final_mlp", device=device)

# 7) Load best checkpoint and evaluate on test set (safe load)
ckpt_path = Path(cfg.out_dir) / "final_mlp" / "checkpoints" / "checkpoint_best.pt"
if ckpt_path.exists():
    try:
        ck = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ck["model_state"])
        print("Loaded checkpoint epoch", ck.get("epoch"), "metric", ck.get("metric"))
    except Exception as e:
        print("Failed to load checkpoint:", e)
else:
    print("No checkpoint found at", ckpt_path, " — using last model state.")

# evaluate on test
model.eval()
all_scores = []; all_targets = []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        logits = model(xb).view(-1)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_scores.append(probs)
        all_targets.append(yb.numpy())
all_scores = np.concatenate(all_scores)
all_targets = np.concatenate(all_targets)

metrics = compute_classification_metrics(all_targets, all_scores)
print("=== FINAL TEST METRICS ===")
print(json.dumps(metrics, indent=2))

# save final artifacts
np.savez_compressed(out_final/"final_test_probs.npz", probs=all_scores, y=all_targets)
(out_final/"final_test_metrics.json").write_text(json.dumps(metrics, indent=2))
print("Saved final artifacts to", out_final)


train_full size: 1895  test size: 474
training for fit: 1743 internal_val: 152
Saved fingerprints to runs_hw1/fp_baseline/train_full_fp.npz (N=1743)
Saved fingerprints to runs_hw1/fp_baseline/internal_val_fp.npz (N=152)
Train/Val/Test loader lens: 218 5 15
Train class counts: (array([0., 1.], dtype=float32), array([1687,   56]))
pos_weight raw=30.12 capped=10.00
Epoch 1/100 — train_loss=0.5347 val_loss=0.6003 val_ROC-AUC=0.8680 val_AP=0.7699 time=1.2s
  -> new best val_ROC-AUC=0.8680 saved checkpoint
Epoch 2/100 — train_loss=0.0113 val_loss=0.8523 val_ROC-AUC=0.8571 val_AP=0.7692 time=1.0s
EarlyStopping: metric did not improve (1/10)
Epoch 3/100 — train_loss=0.0053 val_loss=0.9074 val_ROC-AUC=0.8694 val_AP=0.7700 time=1.0s
  -> new best val_ROC-AUC=0.8694 saved checkpoint
Epoch 4/100 — train_loss=0.0060 val_loss=1.0332 val_ROC-AUC=0.8680 val_AP=0.8098 time=1.0s
EarlyStopping: metric did not improve (1/10)
Epoch 5/100 — train_loss=0.0086 val_loss=1.0612 val_ROC-AUC=0.8599 val_AP=0.7693 

### 6.1 CNN over SMILES

In [17]:

train_loader, val_loader, test_loader = get_sequence_loaders(
    train_df, val_df, test_df,
    batch_size=cfg.batch_size
)

model = build_cnn_fingerprint(
    vocab_size=len(stoi),
    embed_dim=128,
    num_filters=128,
    kernel_sizes=(3, 5, 7),
    dropout=cfg.dropout,
    device=device
)
print(model)

pos_weight = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

cfg_local = copy.deepcopy(cfg)
cfg_local.max_epochs = 50
cfg_local.patience = 8

history = run_training(
    model, train_loader, val_loader,
    optimizer, scheduler, loss_fn,
    cfg_local, model_name="cnn_smiles", device=device
)

ckpt_path = Path(cfg.out_dir) / "cnn_smiles" / "checkpoints" / "checkpoint_best.pt"
if ckpt_path.exists():
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model_state"])
    print(f"Loaded best checkpoint from epoch {state.get('epoch')} (val ROC-AUC={state.get('metric'):.4f})")
else:
    print("No best checkpoint found; using final model weights.")

model.eval()
all_scores, all_targets = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb).view(-1)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_scores.append(probs)
        all_targets.append(yb.cpu().numpy())
all_scores = np.concatenate(all_scores)
all_targets = np.concatenate(all_targets)

metrics = compute_classification_metrics(all_targets, all_scores)
print("\n=== TEST METRICS (CNN over SMILES) ===")
for k, v in metrics.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

out_dir = Path(cfg.out_dir) / "cnn_smiles"
out_dir.mkdir(parents=True, exist_ok=True)
json.dump(metrics, open(out_dir / "test_metrics.json", "w"), indent=2)

hist_path = out_dir / "training_history.json"
if hist_path.exists():
    hist = json.load(open(hist_path))
    epochs = hist["epoch"]
    plt.figure(figsize=(6,4))
    plt.plot(epochs, hist["train_loss"], label="train_loss")
    plt.plot(epochs, hist["val_loss"], label="val_loss")
    plt.xlabel("epoch"); plt.ylabel("loss"); plt.legend(); plt.tight_layout()
    plt.savefig(out_dir / "loss_curve.png"); plt.close()

    plt.figure(figsize=(6,4))
    plt.plot(epochs, hist["train_ROC-AUC"], label="train_ROC-AUC")
    plt.plot(epochs, hist["val_ROC-AUC"], label="val_ROC-AUC")
    plt.xlabel("epoch"); plt.ylabel("ROC-AUC"); plt.legend(); plt.tight_layout()
    plt.savefig(out_dir / "roc_curve.png"); plt.close()
    print("Saved plots to", out_dir)
else:
    print("No history file found; skipping plots.")

CNNFingerprint(
  (embedding): Embedding(49, 128, padding_idx=0)
  (convs): ModuleList(
    (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(2,))
    (2): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))
  )
  (classifier): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)


  xs = torch.tensor([b[0] for b in batch], dtype=torch.long)


Epoch 1/50 — train_loss=7.2241 val_loss=1.7274 val_ROC-AUC=0.9118 val_AP=0.7927 time=3.6s
  -> new best val_ROC-AUC=0.9118 saved checkpoint
Epoch 2/50 — train_loss=3.7612 val_loss=2.0447 val_ROC-AUC=0.8754 val_AP=0.7506 time=3.1s
EarlyStopping: metric did not improve (1/8)
Epoch 3/50 — train_loss=2.5746 val_loss=3.8311 val_ROC-AUC=0.8808 val_AP=0.7706 time=2.8s
EarlyStopping: metric did not improve (2/8)
Epoch 4/50 — train_loss=1.7028 val_loss=1.9174 val_ROC-AUC=0.8941 val_AP=0.7599 time=3.0s
EarlyStopping: metric did not improve (3/8)
Epoch 5/50 — train_loss=1.0878 val_loss=2.5945 val_ROC-AUC=0.8981 val_AP=0.7593 time=3.0s
EarlyStopping: metric did not improve (4/8)
Epoch 6/50 — train_loss=0.4508 val_loss=1.4469 val_ROC-AUC=0.9004 val_AP=0.7628 time=3.9s
EarlyStopping: metric did not improve (5/8)
Epoch 7/50 — train_loss=0.3555 val_loss=2.0389 val_ROC-AUC=0.8755 val_AP=0.7537 time=2.9s
EarlyStopping: metric did not improve (6/8)
Epoch 8/50 — train_loss=0.1180 val_loss=1.8857 val_ROC-A

### 6.2 RNN (GRU/LSTM) over SMILES

In [19]:
train_loader, val_loader, test_loader = get_sequence_loaders(
    train_df, val_df, test_df,
    batch_size=cfg.batch_size
)

model = build_rnn_smiles(
    vocab_size=len(stoi),
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
    rnn_type="lstm",       # or "gru"
    bidirectional=True,
    dropout=cfg.dropout,
    device=device
)
print(model)

pos_weight = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

cfg_local = copy.deepcopy(cfg)
cfg_local.max_epochs = 50
cfg_local.patience = 8

history = run_training(
    model, train_loader, val_loader,
    optimizer, scheduler, loss_fn,
    cfg_local, model_name="rnn_smiles", device=device
)

ckpt_path = Path(cfg.out_dir) / "rnn_smiles" / "checkpoints" / "checkpoint_best.pt"
if ckpt_path.exists():
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model_state"])
    print(f"Loaded best checkpoint from epoch {state.get('epoch')} (val ROC-AUC={state.get('metric'):.4f})")

model.eval()
all_scores, all_targets = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb).view(-1)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_scores.append(probs)
        all_targets.append(yb.cpu().numpy())
all_scores = np.concatenate(all_scores)
all_targets = np.concatenate(all_targets)

metrics = compute_classification_metrics(all_targets, all_scores)
print("\n=== TEST METRICS (RNN over SMILES) ===")
for k, v in metrics.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

out_dir = Path(cfg.out_dir) / "rnn_smiles"
out_dir.mkdir(parents=True, exist_ok=True)
json.dump(metrics, open(out_dir / "test_metrics.json", "w"), indent=2)


RNNSmiles(
  (embed): Embedding(49, 128, padding_idx=0)
  (rnn): LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
  (classifier): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)
Epoch 1/50 — train_loss=3.0529 val_loss=1.8281 val_ROC-AUC=0.5685 val_AP=0.0301 time=54.5s
  -> new best val_ROC-AUC=0.5685 saved checkpoint
Epoch 2/50 — train_loss=2.7703 val_loss=1.9847 val_ROC-AUC=0.5833 val_AP=0.0308 time=57.8s
  -> new best val_ROC-AUC=0.5833 saved checkpoint
Epoch 3/50 — train_loss=2.8226 val_loss=2.4385 val_ROC-AUC=0.6025 val_AP=0.0325 time=51.8s
  -> new best val_ROC-AUC=0.6025 saved checkpoint
Epoch 4/50 — train_loss=2.6010 val_loss=2.4479 val_ROC-AUC=0.6571 val_AP=0.0496 time=51.9s
  -> new best val_ROC-AUC=0.6571 saved checkpoint
Epoch 5/50 — train_loss=2.7636 val_loss=2.4938 val_ROC-AUC=0.6185 va

### 6.3 Transformer over SMILES

In [21]:

# Reuse SMILES sequence data loaders
train_loader, val_loader, test_loader = get_sequence_loaders(
    train_df, val_df, test_df,
    batch_size=8  # keep small for speed
)

# Build model
model = build_transformer_smiles(
    vocab_size=len(vocab),
    device=device,
    embed_dim=128,  
    nhead=4,       
    num_layers=2,     
    ff_dim=256,
    dropout=0.1,
    max_len=cfg.max_len
)

# Compute class imbalance weight
pos_weight = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)

cfg_local = cfg
cfg_local.max_epochs = 20  
cfg_local.patience = 5

# Train
history = run_training(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    loss_fn,
    cfg_local,
    model_name="transformer_smiles",
    device=device
)

ckpt_path = Path(cfg.out_dir) / "transformer_smiles" / "checkpoints" / "checkpoint_best.pt"
if ckpt_path.exists():
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model_state"])
    print(f"Loaded best checkpoint (epoch {state['epoch']}, val ROC-AUC={state['metric']:.4f})")
else:
    print("No checkpoint found — evaluating last model state.")

model.eval()
all_scores, all_targets = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb).view(-1)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_scores.append(probs)
        all_targets.append(yb.cpu().numpy())

all_scores = np.concatenate(all_scores)
all_targets = np.concatenate(all_targets)

metrics = compute_classification_metrics(all_targets, all_scores)
print("\n=== TEST METRICS (Transformer over SMILES) ===")
for k, v in metrics.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

# Save final results
out_dir = Path(cfg.out_dir) / "transformer_smiles"
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "test_metrics.json", "w") as f:
    json.dump(metrics, f, indent=2)
print("Saved results to", out_dir)


  output = torch._nested_tensor_from_mask(


Epoch 1/20 — train_loss=3.7507 val_loss=1.2989 val_ROC-AUC=0.3896 val_AP=0.0207 time=15.2s
  -> new best val_ROC-AUC=0.3896 saved checkpoint
Epoch 2/20 — train_loss=3.2878 val_loss=1.5063 val_ROC-AUC=0.3526 val_AP=0.0196 time=14.4s
EarlyStopping: metric did not improve (1/5)
Epoch 3/20 — train_loss=3.3546 val_loss=1.3090 val_ROC-AUC=0.3667 val_AP=0.0201 time=14.2s
EarlyStopping: metric did not improve (2/5)
Epoch 4/20 — train_loss=3.1135 val_loss=1.2722 val_ROC-AUC=0.4435 val_AP=0.0229 time=16.2s
  -> new best val_ROC-AUC=0.4435 saved checkpoint
Epoch 5/20 — train_loss=3.0340 val_loss=1.3047 val_ROC-AUC=0.4046 val_AP=0.0213 time=13.0s
EarlyStopping: metric did not improve (1/5)
Epoch 6/20 — train_loss=2.9902 val_loss=1.3464 val_ROC-AUC=0.4722 val_AP=0.0244 time=14.2s
  -> new best val_ROC-AUC=0.4722 saved checkpoint
Epoch 7/20 — train_loss=2.6797 val_loss=1.2934 val_ROC-AUC=0.4293 val_AP=0.0222 time=15.1s
EarlyStopping: metric did not improve (1/5)
Epoch 8/20 — train_loss=2.6546 val_lo

### 6.4 GNN on Molecular Graphs (GCN/GIN/GraphSAGE)

In [38]:

train_loader, val_loader, test_loader = get_graph_loaders(
    train_df, val_df, test_df,
    batch_size=cfg.batch_size,
    need_3d=False 
)

b = next(iter(train_df[cfg.smiles_col])) 
data = smiles_to_pyg(b)
print("Atom feature dim:", data.x.shape[1])
print("Bond feature dim:", data.edge_attr.shape[1])

# Infer node feature size from one batch
sample_batch = next(iter(train_loader))
in_dim = sample_batch.x.size(-1)
print(f"GNN node feature dim: {in_dim}")

# 2) Build model
gnn_kind = getattr(cfg, "pyg_backend", "gcn")
model = GraphNet(
    in_dim=in_dim,
    hidden=cfg.gnn_hidden,
    layers=cfg.gnn_layers,
    kind=gnn_kind,
    pool=cfg.global_pool,
    dropout=cfg.gnn_dropout
).to(device)
print(model)

# 3) Loss/opt/scheduler
pos_weight = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

# 4) Train
cfg_local = copy.deepcopy(cfg)
cfg_local.max_epochs = 40
cfg_local.patience = 8
run_name = f"gnn_{gnn_kind}_{cfg.global_pool}"

history = run_training(
    model, train_loader, val_loader,
    optimizer, scheduler, loss_fn,
    cfg_local, model_name=run_name, device=device
)

# 5) Load best checkpoint and evaluate on test set
ckpt_path = Path(cfg.out_dir) / run_name / "checkpoints" / "checkpoint_best.pt"
if ckpt_path.exists():
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model_state"])
    print(f"Loaded best checkpoint from epoch {state.get('epoch')} (val ROC-AUC={state.get('metric'):.4f})")
else:
    print("No checkpoint found; evaluating last model state.")

# Test
model.eval()
all_scores, all_targets = [], []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        logits = model(batch).view(-1)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_scores.append(probs)
        all_targets.append(batch.y.view(-1).cpu().numpy())

all_scores = np.concatenate(all_scores)
all_targets = np.concatenate(all_targets)
metrics = compute_classification_metrics(all_targets, all_scores)
print("\n=== TEST METRICS (GNN) ===")
for k, v in metrics.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

# Save artifacts
out_dir = Path(cfg.out_dir) / run_name
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "test_metrics.json", "w") as f:
    json.dump(metrics, f, indent=2)
print("Saved results to", out_dir)

Atom feature dim: 22
Bond feature dim: 8
GNN node feature dim: 22
GraphNet(
  (blocks): ModuleList(
    (0): GraphBlock(
      (conv): GCNConv(22, 256)
      (act): ReLU(inplace=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1-3): 3 x GraphBlock(
      (conv): GCNConv(256, 256)
      (act): ReLU(inplace=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (head): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=128, out_features=1, bias=True)
  )
)
Epoch 1/40 — train_loss=2.5742 val_loss=2.2860 val_ROC-AUC=0.7318 val_AP=0.0527 time=4.9s
  -> new best val_ROC-AUC=0.7318 saved checkpoint
Epoch 2/40 — train_loss=3.6497 val_loss=2.5782 val_ROC-AUC=0.6865 val_AP=0.0411 time=3.8s
EarlyStopping: metric did not improve (1/8)
Epoch 3/40 — train_loss=3.5010 val_loss=2.7467 val_ROC-AUC=0.6932 val_AP=0.0403 time=3.8s
EarlyStopping: metric did not impr

## 8) Experiments

Various experiments with different hyperparameters and model architectures. Report can be found in repo.

In [42]:
def run_mlp_experiment(
    name="exp",
    batch_size=32,
    hidden_sizes=(512,256),
    dropout=0.1,
    weight_decay=1e-4,
    lr=1e-4,
    grad_clip=None,
    scheduler_type="plateau",
    sampler_type="weighted",  
    pos_weight=True,
    max_epochs=30,
    patience=8,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
):
    """
    Generic experiment runner for the fingerprint MLP.
    Returns a dict of config + metrics (val/test ROC-AUC, AP, etc.)
    """

    out_dir = Path(cfg.out_dir) / f"experiments/{name}"
    out_dir.mkdir(parents=True, exist_ok=True)

    X_tr, y_tr = precompute_fps(train_df, "train_exp_fp.npz")
    X_val, y_val = precompute_fps(val_df,   "val_exp_fp.npz")
    X_test, y_test = precompute_fps(test_df, "test_exp_fp.npz")

    if sampler_type == "weighted":
        # balance classes by inverse frequency
        vals, counts = np.unique(y_tr, return_counts=True)
        class_w = {v: 1.0 / c for v,c in zip(vals, counts)}
        sample_w = np.array([class_w[int(v)] for v in y_tr])
        sampler = WeightedRandomSampler(
            torch.from_numpy(sample_w), num_samples=len(sample_w), replacement=True
        )
        shuffle = False
    else:
        sampler = None
        shuffle = True

    train_loader = DataLoader(
        TensorDataset(torch.from_numpy(X_tr), torch.from_numpy(y_tr)),
        batch_size=batch_size,
        sampler=sampler,
        shuffle=shuffle,
        num_workers=cfg.num_workers,
    )
    val_loader = DataLoader(
        TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val)),
        batch_size=batch_size * 2,
        shuffle=False,
        num_workers=cfg.num_workers,
    )
    test_loader = DataLoader(
        TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test)),
        batch_size=batch_size * 2,
        shuffle=False,
        num_workers=cfg.num_workers,
    )

    model = build_mlp_fingerprint(
        in_dim=X_tr.shape[1],
        hidden_sizes=hidden_sizes,
        dropout=dropout,
        device=device
    )

    if pos_weight:
        pw = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
    else:
        pw = None
    loss_fn = get_loss_fn("classification", pos_weight=pw, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    if scheduler_type == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)
    elif scheduler_type == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    elif scheduler_type == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, max_epochs//3), gamma=0.5)
    else:
        scheduler = None

    cfg_local = copy.deepcopy(cfg)
    cfg_local.max_epochs = max_epochs
    cfg_local.patience = patience
    cfg_local.grad_clip = grad_clip
    history = run_training(model, train_loader, val_loader, optimizer, scheduler, loss_fn,
                           cfg_local, model_name=name, device=device)

    model.eval()
    all_scores, all_targets = [], []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb).view(-1)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_scores.append(probs)
            all_targets.append(yb.cpu().numpy())
    all_scores = np.concatenate(all_scores)
    all_targets = np.concatenate(all_targets)
    metrics = compute_classification_metrics(all_targets, all_scores)

    results = {
        "config": {
            "batch_size": batch_size,
            "hidden_sizes": hidden_sizes,
            "dropout": dropout,
            "weight_decay": weight_decay,
            "lr": lr,
            "grad_clip": grad_clip,
            "scheduler_type": scheduler_type,
            "sampler_type": sampler_type,
            "pos_weight": pos_weight,
        },
        "val_best": {
            "ROC-AUC": max(history["val_ROC-AUC"]),
            "AP": max(history["val_AP"]),
        },
        "test": metrics,
    }

    with open(out_dir / "result.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"=== {name} finished ===")
    print("Val best ROC-AUC:", results["val_best"]["ROC-AUC"])
    print("Test ROC-AUC:", results["test"]["ROC-AUC"])
    print("Test AP:", results["test"]["AP"])
    return results


In [43]:
# === 8.2.1 Batch Size Experiment ===
batch_sizes = [8, 32, 128]
batch_results = []

for bs in batch_sizes:
    res = run_mlp_experiment(name=f"batch{bs}", batch_size=bs)
    batch_results.append({
        "batch_size": bs,
        "val_ROC-AUC": res["val_best"]["ROC-AUC"],
        "test_ROC-AUC": res["test"]["ROC-AUC"],
        "test_AP": res["test"]["AP"]
    })

print("\nBatch Size Results:")
for r in batch_results:
    print(r)


Loading cached fingerprints: runs_hw1/fp_baseline/train_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/val_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/test_exp_fp.npz
Epoch 1/30 — train_loss=2.3641 val_loss=1.7799 val_ROC-AUC=0.9767 val_AP=0.8591 time=1.1s
  -> new best val_ROC-AUC=0.9767 saved checkpoint
Epoch 2/30 — train_loss=0.7447 val_loss=0.7137 val_ROC-AUC=0.8979 val_AP=0.8323 time=1.0s
EarlyStopping: metric did not improve (1/8)
Epoch 3/30 — train_loss=0.2586 val_loss=0.5711 val_ROC-AUC=0.8849 val_AP=0.8316 time=1.0s
EarlyStopping: metric did not improve (2/8)
Epoch 4/30 — train_loss=0.1027 val_loss=0.7194 val_ROC-AUC=0.8707 val_AP=0.8134 time=1.1s
EarlyStopping: metric did not improve (3/8)
Epoch 5/30 — train_loss=0.0540 val_loss=0.8513 val_ROC-AUC=0.8692 val_AP=0.8226 time=1.4s
EarlyStopping: metric did not improve (4/8)
Epoch 6/30 — train_loss=0.0382 val_loss=0.8879 val_ROC-AUC=0.8714 val_AP=0.8226 time=1.7s
EarlyStopping: metric did not improv

In [44]:
# === 8.2.2 Regularization Experiment ===
dropouts = [0.0, 0.1, 0.3, 0.5]
decays   = [0.0, 1e-5, 1e-4]
gradclips = [None, 1.0]

reg_results = []

for do in dropouts:
    for wd in decays:
        for gc in gradclips:
            name = f"reg_do{do}_wd{wd}_gc{gc}"
            res = run_mlp_experiment(name=name, dropout=do, weight_decay=wd, grad_clip=gc)
            reg_results.append({
                "dropout": do,
                "weight_decay": wd,
                "grad_clip": gc,
                "val_ROC-AUC": res["val_best"]["ROC-AUC"],
                "test_ROC-AUC": res["test"]["ROC-AUC"],
                "test_AP": res["test"]["AP"]
            })

print("\nRegularization Results:")
for r in reg_results:
    print(r)


Loading cached fingerprints: runs_hw1/fp_baseline/train_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/val_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/test_exp_fp.npz
Epoch 1/30 — train_loss=4.1352 val_loss=1.8805 val_ROC-AUC=0.9904 val_AP=0.7775 time=0.2s
  -> new best val_ROC-AUC=0.9904 saved checkpoint
Epoch 2/30 — train_loss=1.4415 val_loss=1.7768 val_ROC-AUC=0.9908 val_AP=0.8870 time=0.2s
  -> new best val_ROC-AUC=0.9908 saved checkpoint
Epoch 3/30 — train_loss=1.0109 val_loss=1.3268 val_ROC-AUC=0.9668 val_AP=0.8486 time=0.2s
EarlyStopping: metric did not improve (1/8)
Epoch 4/30 — train_loss=0.7203 val_loss=0.8989 val_ROC-AUC=0.9241 val_AP=0.8348 time=0.2s
EarlyStopping: metric did not improve (2/8)
Epoch 5/30 — train_loss=0.4333 val_loss=0.6556 val_ROC-AUC=0.9013 val_AP=0.8242 time=0.2s
EarlyStopping: metric did not improve (3/8)
Epoch 6/30 — train_loss=0.2762 val_loss=0.5814 val_ROC-AUC=0.8826 val_AP=0.8231 time=0.2s
EarlyStopping: metric did not 

In [45]:
hidden_variants = [
    (64,),
    (128,),
    (256,),
    (512, 256),
]
depth_results = []

for hs in hidden_variants:
    res = run_mlp_experiment(name=f"depth_{hs}", hidden_sizes=hs)
    depth_results.append({
        "hidden_sizes": hs,
        "val_ROC-AUC": res["val_best"]["ROC-AUC"],
        "test_ROC-AUC": res["test"]["ROC-AUC"],
        "test_AP": res["test"]["AP"]
    })

print("\nModel Depth/Width Results:")
for r in depth_results:
    print(r)


Loading cached fingerprints: runs_hw1/fp_baseline/train_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/val_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/test_exp_fp.npz
Epoch 1/30 — train_loss=10.0655 val_loss=1.0535 val_ROC-AUC=0.9264 val_AP=0.4831 time=0.1s
  -> new best val_ROC-AUC=0.9264 saved checkpoint
Epoch 2/30 — train_loss=7.3925 val_loss=1.0115 val_ROC-AUC=0.9751 val_AP=0.7761 time=0.0s
  -> new best val_ROC-AUC=0.9751 saved checkpoint
Epoch 3/30 — train_loss=5.2449 val_loss=0.9780 val_ROC-AUC=0.9809 val_AP=0.8671 time=0.0s
  -> new best val_ROC-AUC=0.9809 saved checkpoint
Epoch 4/30 — train_loss=4.0991 val_loss=0.9552 val_ROC-AUC=0.9848 val_AP=0.8746 time=0.0s
  -> new best val_ROC-AUC=0.9848 saved checkpoint
Epoch 5/30 — train_loss=2.9904 val_loss=0.9332 val_ROC-AUC=0.9841 val_AP=0.8739 time=0.1s
EarlyStopping: metric did not improve (1/8)
Epoch 6/30 — train_loss=2.4886 val_loss=0.8983 val_ROC-AUC=0.9859 val_AP=0.8758 time=0.0s
  -> new best val

In [46]:
schedulers = ["plateau", "cosine", "step", "none"]
lr_results = []

for sched in schedulers:
    res = run_mlp_experiment(name=f"lr_{sched}", scheduler_type=sched)
    lr_results.append({
        "scheduler": sched,
        "val_ROC-AUC": res["val_best"]["ROC-AUC"],
        "test_ROC-AUC": res["test"]["ROC-AUC"],
        "test_AP": res["test"]["AP"]
    })

print("\nLR Schedule Results:")
for r in lr_results:
    print(r)


Loading cached fingerprints: runs_hw1/fp_baseline/train_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/val_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/test_exp_fp.npz
Epoch 1/30 — train_loss=4.1442 val_loss=1.8599 val_ROC-AUC=0.9785 val_AP=0.8621 time=0.3s
  -> new best val_ROC-AUC=0.9785 saved checkpoint
Epoch 2/30 — train_loss=1.5848 val_loss=1.9293 val_ROC-AUC=0.9740 val_AP=0.8575 time=0.2s
EarlyStopping: metric did not improve (1/8)
Epoch 3/30 — train_loss=1.1750 val_loss=1.5969 val_ROC-AUC=0.9618 val_AP=0.8500 time=0.2s
EarlyStopping: metric did not improve (2/8)
Epoch 4/30 — train_loss=0.8616 val_loss=1.2163 val_ROC-AUC=0.9340 val_AP=0.8433 time=0.2s
EarlyStopping: metric did not improve (3/8)
Epoch 5/30 — train_loss=0.6149 val_loss=0.8345 val_ROC-AUC=0.8912 val_AP=0.8319 time=0.2s
EarlyStopping: metric did not improve (4/8)
Epoch 6/30 — train_loss=0.4273 val_loss=0.7194 val_ROC-AUC=0.8808 val_AP=0.8314 time=0.2s
EarlyStopping: metric did not improv

In [47]:
imbalance_configs = [
    {"pos_weight": True, "sampler_type": "plain",   "name": "posweight"},
    {"pos_weight": False,"sampler_type": "weighted","name": "sampler"},
    {"pos_weight": True, "sampler_type": "weighted","name": "both"},
]

imbalance_results = []

for cfg_i in imbalance_configs:
    res = run_mlp_experiment(name=cfg_i["name"],
                             pos_weight=cfg_i["pos_weight"],
                             sampler_type=cfg_i["sampler_type"])
    imbalance_results.append({
        "config": cfg_i["name"],
        "val_ROC-AUC": res["val_best"]["ROC-AUC"],
        "test_ROC-AUC": res["test"]["ROC-AUC"],
        "test_AP": res["test"]["AP"]
    })

print("\nClass Imbalance Results:")
for r in imbalance_results:
    print(r)


Loading cached fingerprints: runs_hw1/fp_baseline/train_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/val_exp_fp.npz
Loading cached fingerprints: runs_hw1/fp_baseline/test_exp_fp.npz
Epoch 1/30 — train_loss=1.2733 val_loss=0.9797 val_ROC-AUC=0.8689 val_AP=0.6706 time=0.3s
  -> new best val_ROC-AUC=0.8689 saved checkpoint
Epoch 2/30 — train_loss=0.9506 val_loss=0.7805 val_ROC-AUC=0.9093 val_AP=0.7825 time=0.3s
  -> new best val_ROC-AUC=0.9093 saved checkpoint
Epoch 3/30 — train_loss=0.6856 val_loss=0.6533 val_ROC-AUC=0.8959 val_AP=0.7816 time=0.3s
EarlyStopping: metric did not improve (1/8)
Epoch 4/30 — train_loss=0.4827 val_loss=0.5460 val_ROC-AUC=0.9113 val_AP=0.8249 time=0.2s
  -> new best val_ROC-AUC=0.9113 saved checkpoint
Epoch 5/30 — train_loss=0.3309 val_loss=0.5353 val_ROC-AUC=0.9071 val_AP=0.8153 time=0.2s
EarlyStopping: metric did not improve (1/8)
Epoch 6/30 — train_loss=0.2300 val_loss=0.5753 val_ROC-AUC=0.9012 val_AP=0.8149 time=0.2s
EarlyStopping: metric di

In [50]:
batch_size = cfg.batch_size  
train_loader_seq, val_loader_seq, test_loader_seq = get_sequence_loaders(
    train_df, val_df, test_df, batch_size
)
print("Sequence loaders ready:", len(train_loader_seq), len(val_loader_seq), len(test_loader_seq))


Sequence loaders ready: 178 60 60


In [51]:
results_rnn_depth = {}
for num_layers in [1, 2]:
    print(f"\n=== Training RNN with num_layers={num_layers} ===")
    
    model = RNNSmiles(
        vocab_size=len(vocab),
        embed_dim=128,
        hidden_dim=256,
        num_layers=num_layers,
        dropout=0.1
    ).to(device)

    pos_weight = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
    loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)

    cfg_local = copy.deepcopy(cfg)
    cfg_local.max_epochs = 20
    cfg_local.patience = 5

    history = run_training(
        model, train_loader_seq, val_loader_seq,
        optimizer, scheduler, loss_fn, cfg_local,
        model_name=f"rnn_depth_{num_layers}L", device=device
    )

    # Load best checkpoint and evaluate
    ckpt_path = Path(cfg.out_dir) / f"rnn_depth_{num_layers}L" / "checkpoints" / "checkpoint_best.pt"
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model_state"])
    model.eval()

    all_scores, all_targets = [], []
    with torch.no_grad():
        for xb, yb in test_loader_seq:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb).view(-1)
            probs = torch.sigmoid(logits).detach().cpu().numpy()
            all_scores.append(probs)
            all_targets.append(yb.cpu().numpy())
    all_scores = np.concatenate(all_scores)
    all_targets = np.concatenate(all_targets)
    metrics = compute_classification_metrics(all_targets, all_scores)
    results_rnn_depth[num_layers] = metrics

# Summarize
print("\n=== RNN Depth Results (Test set) ===")
for layers, m in results_rnn_depth.items():
    print(f"{layers}-layer: ROC-AUC={m['ROC-AUC']:.4f}, AP={m['AP']:.4f}, "
          f"Precision={m['Precision']:.3f}, Recall={m['Recall']:.3f}")



=== Training RNN with num_layers=1 ===
Epoch 1/20 — train_loss=2.9288 val_loss=2.2309 val_ROC-AUC=0.4986 val_AP=0.0263 time=29.8s
  -> new best val_ROC-AUC=0.4986 saved checkpoint
Epoch 2/20 — train_loss=3.4729 val_loss=2.8186 val_ROC-AUC=0.5317 val_AP=0.0280 time=27.7s
  -> new best val_ROC-AUC=0.5317 saved checkpoint
Epoch 3/20 — train_loss=3.5048 val_loss=2.5929 val_ROC-AUC=0.5503 val_AP=0.0292 time=28.3s
  -> new best val_ROC-AUC=0.5503 saved checkpoint
Epoch 4/20 — train_loss=3.2802 val_loss=2.9068 val_ROC-AUC=0.5402 val_AP=0.0283 time=29.3s
EarlyStopping: metric did not improve (1/5)
Epoch 5/20 — train_loss=3.5956 val_loss=2.1097 val_ROC-AUC=0.5741 val_AP=0.0381 time=27.9s
  -> new best val_ROC-AUC=0.5741 saved checkpoint
Epoch 6/20 — train_loss=3.2550 val_loss=3.2856 val_ROC-AUC=0.5819 val_AP=0.0380 time=27.3s
  -> new best val_ROC-AUC=0.5819 saved checkpoint
Epoch 7/20 — train_loss=3.6956 val_loss=2.6207 val_ROC-AUC=0.6387 val_AP=0.0544 time=27.2s
  -> new best val_ROC-AUC=0.6

In [53]:

train_loader_seq, val_loader_seq, test_loader_seq = get_sequence_loaders(
    train_df, val_df, test_df, batch_size=cfg.batch_size
)

class TransformerSmiles(nn.Module):
    def __init__(self, vocab_size, emb_dim=128, nhead=4, ff_dim=256,
                 num_layers=2, dropout=0.1, max_len=cfg.max_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead,
            dim_feedforward=ff_dim, dropout=dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.Linear(emb_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        mask = (x == 0)
        x = self.embed(x)
        x = self.encoder(x, src_key_padding_mask=mask)
        x = x.mean(dim=1) 
        return self.classifier(x).view(-1)


depth_results = {}

for n_layers in [1, 2, 3]:
    print(f"\n=== Training Transformer ({n_layers} layers) ===")
    model = TransformerSmiles(vocab_size=len(vocab), num_layers=n_layers, dropout=0.1).to(device)

    pos_weight = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
    loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=3)

    cfg_local = copy.deepcopy(cfg)
    cfg_local.max_epochs = 14
    cfg_local.patience = 5

    hist = run_training(model, train_loader_seq, val_loader_seq,
                        opt, sch, loss_fn, cfg_local,
                        model_name=f"transformer_{n_layers}L", device=device)

    # Load best checkpoint
    ckpt_path = Path(cfg.out_dir) / f"transformer_{n_layers}L" / "checkpoints" / "checkpoint_best.pt"
    if ckpt_path.exists():
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state["model_state"])
        print(f"Loaded best checkpoint (epoch {state['epoch']}, val ROC-AUC={state['metric']:.4f})")

    model.eval()
    all_scores, all_targets = [], []
    with torch.no_grad():
        for xb, yb in test_loader_seq:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_scores.append(probs)
            all_targets.append(yb.cpu().numpy())
    all_scores = np.concatenate(all_scores)
    all_targets = np.concatenate(all_targets)
    metrics = compute_classification_metrics(all_targets, all_scores)

    depth_results[n_layers] = metrics
    print(f"=== TEST METRICS (Transformer {n_layers}L) ===")
    for k,v in metrics.items():
        print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

print("\n=== Transformer Depth Results (Test set) ===")
for L, m in depth_results.items():
    print(f"{L}-layer: ROC-AUC={m['ROC-AUC']:.4f}, AP={m['AP']:.4f}, "
          f"Precision={m['Precision']:.3f}, Recall={m['Recall']:.3f}")



=== Training Transformer (1 layers) ===
Epoch 1/14 — train_loss=2.6777 val_loss=1.1693 val_ROC-AUC=0.3535 val_AP=0.0196 time=12.4s
  -> new best val_ROC-AUC=0.3535 saved checkpoint
Epoch 2/14 — train_loss=3.3586 val_loss=1.1654 val_ROC-AUC=0.3615 val_AP=0.0199 time=8.2s
  -> new best val_ROC-AUC=0.3615 saved checkpoint
Epoch 3/14 — train_loss=3.3263 val_loss=1.1677 val_ROC-AUC=0.3665 val_AP=0.0200 time=8.5s
  -> new best val_ROC-AUC=0.3665 saved checkpoint
Epoch 4/14 — train_loss=3.2113 val_loss=1.1747 val_ROC-AUC=0.3730 val_AP=0.0202 time=8.5s
  -> new best val_ROC-AUC=0.3730 saved checkpoint
Epoch 5/14 — train_loss=3.2536 val_loss=1.1678 val_ROC-AUC=0.3842 val_AP=0.0205 time=8.7s
  -> new best val_ROC-AUC=0.3842 saved checkpoint
Epoch 6/14 — train_loss=3.4119 val_loss=1.1698 val_ROC-AUC=0.3900 val_AP=0.0207 time=8.9s
  -> new best val_ROC-AUC=0.3900 saved checkpoint
Epoch 7/14 — train_loss=3.1225 val_loss=1.1716 val_ROC-AUC=0.3992 val_AP=0.0210 time=11.5s
  -> new best val_ROC-AUC=0

In [54]:

train_loader_graph, val_loader_graph, test_loader_graph = get_graph_loaders(
    train_df, val_df, test_df, batch_size=cfg.batch_size, need_3d=False
)

from torch_geometric.nn import GCNConv, global_mean_pool

class GCNModel(nn.Module):
    def __init__(self, in_dim, hidden_dim=64, num_layers=2, dropout=0.1):
        super().__init__()
        self.convs = nn.ModuleList()
        prev_dim = in_dim
        for _ in range(num_layers):
            conv = GCNConv(prev_dim, hidden_dim)
            self.convs.append(conv)
            prev_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs:
            x = conv(x, edge_index)
            x = torch.relu(x)
            x = self.dropout(x)
        x = global_mean_pool(x, batch)
        return self.classifier(x).view(-1)


depth_results_gnn = {}

for n_layers in [1, 2, 3]:
    print(f"\n=== Training GCN ({n_layers} layers) ===")
    model = GCNModel(in_dim=train_loader_graph.dataset[0].x.shape[1],
                     hidden_dim=64, num_layers=n_layers, dropout=0.1).to(device)

    pos_weight = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
    loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=3)

    cfg_local = copy.deepcopy(cfg)
    cfg_local.max_epochs = 30
    cfg_local.patience = 6

    hist = run_training(model, train_loader_graph, val_loader_graph, opt, sch,
                        loss_fn, cfg_local, model_name=f"gcn_{n_layers}L", device=device)

    # Load best checkpoint
    ckpt_path = Path(cfg.out_dir) / f"gcn_{n_layers}L" / "checkpoints" / "checkpoint_best.pt"
    if ckpt_path.exists():
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state["model_state"])
        print(f"Loaded best checkpoint (epoch {state['epoch']}, val ROC-AUC={state['metric']:.4f})")

    # Evaluate on test
    model.eval()
    all_scores, all_targets = [], []
    with torch.no_grad():
        for batch in test_loader_graph:
            batch = batch.to(device)
            logits = model(batch)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_scores.append(probs)
            all_targets.append(batch.y.cpu().numpy())
    all_scores = np.concatenate(all_scores)
    all_targets = np.concatenate(all_targets)
    metrics = compute_classification_metrics(all_targets, all_scores)
    depth_results_gnn[n_layers] = metrics
    print(f"=== TEST METRICS (GCN {n_layers}L) ===")
    for k,v in metrics.items():
        print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

# Summary
print("\n=== GNN Depth Results (Test set) ===")
for L, m in depth_results_gnn.items():
    print(f"{L}-layer: ROC-AUC={m['ROC-AUC']:.4f}, AP={m['AP']:.4f}, "
          f"Precision={m['Precision']:.3f}, Recall={m['Recall']:.3f}")



=== Training GCN (1 layers) ===
Epoch 1/30 — train_loss=1.9254 val_loss=1.7811 val_ROC-AUC=0.5368 val_AP=0.0274 time=2.7s
  -> new best val_ROC-AUC=0.5368 saved checkpoint
Epoch 2/30 — train_loss=2.6926 val_loss=2.1311 val_ROC-AUC=0.5602 val_AP=0.0287 time=2.4s
  -> new best val_ROC-AUC=0.5602 saved checkpoint
Epoch 3/30 — train_loss=2.6326 val_loss=2.1162 val_ROC-AUC=0.5776 val_AP=0.0297 time=2.6s
  -> new best val_ROC-AUC=0.5776 saved checkpoint
Epoch 4/30 — train_loss=2.7195 val_loss=2.3498 val_ROC-AUC=0.5963 val_AP=0.0310 time=2.3s
  -> new best val_ROC-AUC=0.5963 saved checkpoint
Epoch 5/30 — train_loss=2.8986 val_loss=1.8407 val_ROC-AUC=0.6304 val_AP=0.0337 time=2.6s
  -> new best val_ROC-AUC=0.6304 saved checkpoint
Epoch 6/30 — train_loss=2.9256 val_loss=1.9147 val_ROC-AUC=0.6575 val_AP=0.0362 time=2.4s
  -> new best val_ROC-AUC=0.6575 saved checkpoint
Epoch 7/30 — train_loss=2.6198 val_loss=1.7676 val_ROC-AUC=0.6806 val_AP=0.0388 time=2.8s
  -> new best val_ROC-AUC=0.6806 save

In [56]:
train_loader_seq, val_loader_seq, test_loader_seq = get_sequence_loaders(
    train_df, val_df, test_df, batch_size=cfg.batch_size
)

# Lightweight CNN with multi-kernel convs; dropout applied after pooled concat
class CNNSmiles(nn.Module):
    def __init__(self, vocab_size, emb_dim=cfg.embed_dim, channels=cfg.cnn_channels,
                 kernel_sizes=cfg.cnn_kernel_sizes, dropout=0.1, pad_idx=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=emb_dim, out_channels=channels, kernel_size=k)
            for k in kernel_sizes
        ])
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(2 * channels * len(kernel_sizes), 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        )


    def forward(self, x):
        # x: (B, T) indices
        x = self.embed(x)               
        x = x.transpose(1, 2)           
        feats = []
        for conv in self.convs:
            h = torch.relu(conv(x))       
            h_max = torch.amax(h, dim=2)   
            h_avg = torch.mean(h, dim=2) 
            feats.append(torch.cat([h_max, h_avg], dim=1)) 
        h = torch.cat(feats, dim=1)       
        # compress (2C*K) -> (C*K) with a linear layer is optional;
        # here, simply dropout then classifier handles the size.
        h = self.dropout(h)
        return self.classifier(h).view(-1)

# Where to save
out_dir = Path(cfg.out_dir) / "cnn_dropout_sweep"
out_dir.mkdir(parents=True, exist_ok=True)

drop_values = [0.0, 0.1, 0.3, 0.5]
results = {}

for dp in drop_values:
    print(f"\n=== Training CNN (dropout={dp}) ===")
    model = CNNSmiles(vocab_size=len(vocab),
                      emb_dim=128,        
                      channels=128,         
                      kernel_sizes=(3,5,7),
                      dropout=dp,
                      pad_idx=0).to(device)

    pos_weight = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
    loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)

    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=3)

    cfg_local = copy.deepcopy(cfg)
    cfg_local.max_epochs = 20
    cfg_local.patience = 5
    tag = f"cnn_drop_{str(dp).replace('.','_')}"
    _ = run_training(model, train_loader_seq, val_loader_seq, opt, sch, loss_fn, cfg_local,
                     model_name=tag, device=device)

    # Load best checkpoint
    ckpt = Path(cfg.out_dir) / tag / "checkpoints" / "checkpoint_best.pt"
    if ckpt.exists():
        state = torch.load(ckpt, map_location=device)
        model.load_state_dict(state["model_state"])
        print(f"Loaded best checkpoint (epoch {state['epoch']}, val ROC-AUC={state['metric']:.4f})")

    # Evaluate on test
    model.eval()
    all_scores, all_targets = [], []
    with torch.no_grad():
        for xb, yb in test_loader_seq:
            xb, yb = xb.to(device), yb.to(device)
            probs = torch.sigmoid(model(xb)).cpu().numpy()
            all_scores.append(probs); all_targets.append(yb.cpu().numpy())
    all_scores = np.concatenate(all_scores); all_targets = np.concatenate(all_targets)

    metrics = compute_classification_metrics(all_targets, all_scores)
    results[dp] = metrics
    print(f"=== TEST METRICS (CNN dropout={dp}) ===")
    for k,v in metrics.items():
        print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

# Save & print summary
(summary_path := out_dir / "cnn_dropout_results.json").write_text(json.dumps(results, indent=2))
print("\n=== CNN Dropout Sweep: Test Summary ===")
for dp in drop_values:
    m = results[dp]
    print(f"dropout={dp:>3}: ROC-AUC={m['ROC-AUC']:.4f}, AP={m['AP']:.4f}, "
          f"Precision={m['Precision']:.3f}, Recall={m['Recall']:.3f}")
print(f"\nSaved JSON to: {summary_path}")



=== Training CNN (dropout=0.0) ===
Epoch 1/20 — train_loss=5.4127 val_loss=2.6246 val_ROC-AUC=0.7235 val_AP=0.0517 time=4.0s
  -> new best val_ROC-AUC=0.7235 saved checkpoint
Epoch 2/20 — train_loss=3.9597 val_loss=2.9584 val_ROC-AUC=0.8375 val_AP=0.3036 time=3.0s
  -> new best val_ROC-AUC=0.8375 saved checkpoint
Epoch 3/20 — train_loss=3.1135 val_loss=4.6060 val_ROC-AUC=0.9001 val_AP=0.5622 time=3.2s
  -> new best val_ROC-AUC=0.9001 saved checkpoint
Epoch 4/20 — train_loss=2.8828 val_loss=1.2993 val_ROC-AUC=0.9356 val_AP=0.7194 time=3.2s
  -> new best val_ROC-AUC=0.9356 saved checkpoint
Epoch 5/20 — train_loss=2.4286 val_loss=0.5144 val_ROC-AUC=0.9374 val_AP=0.7576 time=3.0s
  -> new best val_ROC-AUC=0.9374 saved checkpoint
Epoch 6/20 — train_loss=2.0808 val_loss=2.0803 val_ROC-AUC=0.9006 val_AP=0.7485 time=3.0s
EarlyStopping: metric did not improve (1/5)
Epoch 7/20 — train_loss=1.5530 val_loss=1.0112 val_ROC-AUC=0.9473 val_AP=0.7612 time=3.2s
  -> new best val_ROC-AUC=0.9473 saved c

In [60]:
out_dir = Path(cfg.out_dir) / "cnn_l2_sweep"
out_dir.mkdir(parents=True, exist_ok=True)

weight_decays = [0.0, 1e-4, 1e-3]
results = {}

for wd in weight_decays:
    print(f"\n=== Training CNN with weight_decay={wd} ===")

    model = CNNFingerprint(vocab_size, embed_dim=128, num_filters=128, kernel_sizes=(3,5,7),
                       dropout=0.1).to(device)

    loss_fn = get_loss_fn("classification", device=device)
    opt = Adam(model.parameters(), lr=cfg.lr, weight_decay=wd)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=3)
    cfg_local = deepcopy(cfg)
    cfg_local.max_epochs = 15
    cfg_local.patience = 5

    hist = run_training(model, train_loader_seq, val_loader_seq, opt, scheduler,
                        loss_fn, cfg_local, model_name=f"cnn_l2_{wd}", device=device)

    model.eval()
    all_scores, all_targets = [], []
    with torch.no_grad():
        for xb, yb in test_loader_seq:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb).view(-1)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_scores.append(probs)
            all_targets.append(yb.cpu().numpy())

    all_scores = np.concatenate(all_scores)
    all_targets = np.concatenate(all_targets)
    metrics = compute_classification_metrics(all_targets, all_scores)
    results[wd] = metrics

    print(f"Test ROC-AUC={metrics['ROC-AUC']:.4f}, AP={metrics['AP']:.4f}")

print("\n=== CNN L2 Regularization Results (Test set) ===")
for wd, m in results.items():
    print(f"weight_decay={wd}: ROC-AUC={m['ROC-AUC']:.4f}, AP={m['AP']:.4f}, "
          f"Precision={m['Precision']:.3f}, Recall={m['Recall']:.3f}")


=== Training CNN with weight_decay=0.0 ===
Epoch 1/15 — train_loss=0.1765 val_loss=0.1508 val_ROC-AUC=0.7258 val_AP=0.0469 time=3.5s
  -> new best val_ROC-AUC=0.7258 saved checkpoint
Epoch 2/15 — train_loss=0.1392 val_loss=0.0958 val_ROC-AUC=0.9307 val_AP=0.5195 time=4.8s
  -> new best val_ROC-AUC=0.9307 saved checkpoint
Epoch 3/15 — train_loss=0.1096 val_loss=0.0816 val_ROC-AUC=0.9625 val_AP=0.7452 time=2.9s
  -> new best val_ROC-AUC=0.9625 saved checkpoint
Epoch 4/15 — train_loss=0.1012 val_loss=0.0545 val_ROC-AUC=0.9295 val_AP=0.7435 time=2.7s
EarlyStopping: metric did not improve (1/5)
Epoch 5/15 — train_loss=0.0712 val_loss=0.0483 val_ROC-AUC=0.9596 val_AP=0.7690 time=2.6s
EarlyStopping: metric did not improve (2/5)
Epoch 6/15 — train_loss=0.0459 val_loss=0.0548 val_ROC-AUC=0.9398 val_AP=0.7093 time=2.5s
EarlyStopping: metric did not improve (3/5)
Epoch 7/15 — train_loss=0.0349 val_loss=0.0739 val_ROC-AUC=0.9417 val_AP=0.7034 time=2.5s
EarlyStopping: metric did not improve (4/5)


In [63]:
import torch
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR

results = {}

def train_transformer_with_scheduler(scheduler_name):
    print(f"\n=== Training Transformer ({scheduler_name}) ===")

    # Build model using correct args
    model = TransformerSmiles(
        vocab_size=vocab_size,
        emb_dim=128,
        nhead=4,
        ff_dim=256,
        num_layers=2,
        dropout=0.1,
        max_len=cfg.max_len
    ).to(device)

    # Weighted BCE loss 
    pos_weight = get_pos_weight_from_df(train_df, cfg.target_cols[0]).to(device)
    loss_fn = get_loss_fn("classification", pos_weight=pos_weight, device=device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

    # Scheduler options
    if scheduler_name == "constant":
        scheduler = None
    elif scheduler_name == "step":
        scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
    elif scheduler_name == "cosine":
        scheduler = CosineAnnealingLR(optimizer, T_max=10)
    else:
        raise ValueError(f"Unknown scheduler: {scheduler_name}")

    # Training configuration
    cfg_local = copy.deepcopy(cfg)
    cfg_local.max_epochs = 15
    cfg_local.patience = 5

    # Train
    history = run_training(
        model,
        train_loader_seq,
        val_loader_seq,
        optimizer,
        scheduler,
        loss_fn,
        cfg_local,
        model_name=f"transformer_{scheduler_name}",
        device=device,
    )

    # Evaluate on test set
    model.eval()
    all_scores, all_targets = [], []
    with torch.no_grad():
        for xb, yb in test_loader_seq:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb).view(-1)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_scores.append(probs)
            all_targets.append(yb.cpu().numpy())
    all_scores = np.concatenate(all_scores)
    all_targets = np.concatenate(all_targets)

    metrics = compute_classification_metrics(all_targets, all_scores)
    results[scheduler_name] = metrics
    print(f"\n=== {scheduler_name.upper()} TEST METRICS ===")
    for k, v in metrics.items():
        if isinstance(v, (float, int)):
            print(f"{k}: {v:.4f}")
        else:
            print(f"{k}: {v}")


# Run all three scheduler variants
for sched in ["constant", "step", "cosine"]:
    train_transformer_with_scheduler(sched)

# Summarize results
print("Transformer LR Schedule Comparison Summary")
for name, m in results.items():
    print(f"{name:10s} | ROC-AUC={m['ROC-AUC']:.4f}, AP={m['AP']:.4f}, Precision={m['Precision']:.3f}, Recall={m['Recall']:.3f}")



=== Training Transformer (constant) ===
Epoch 1/15 — train_loss=2.7354 val_loss=1.1927 val_ROC-AUC=0.3494 val_AP=0.0195 time=15.4s
  -> new best val_ROC-AUC=0.3494 saved checkpoint
Epoch 2/15 — train_loss=3.3391 val_loss=1.1882 val_ROC-AUC=0.3595 val_AP=0.0198 time=19.4s
  -> new best val_ROC-AUC=0.3595 saved checkpoint
Epoch 3/15 — train_loss=3.3635 val_loss=1.1635 val_ROC-AUC=0.4037 val_AP=0.0212 time=15.2s
  -> new best val_ROC-AUC=0.4037 saved checkpoint
Epoch 4/15 — train_loss=3.1914 val_loss=1.1523 val_ROC-AUC=0.4913 val_AP=0.0253 time=15.0s
  -> new best val_ROC-AUC=0.4913 saved checkpoint
Epoch 5/15 — train_loss=3.1389 val_loss=1.0972 val_ROC-AUC=0.6858 val_AP=0.0686 time=13.7s
  -> new best val_ROC-AUC=0.6858 saved checkpoint
Epoch 6/15 — train_loss=2.8054 val_loss=1.0992 val_ROC-AUC=0.6966 val_AP=0.0894 time=14.2s
  -> new best val_ROC-AUC=0.6966 saved checkpoint
Epoch 7/15 — train_loss=2.9911 val_loss=1.0652 val_ROC-AUC=0.7758 val_AP=0.1141 time=16.7s
  -> new best val_ROC-

In [64]:
models_to_plot = {
    "MLP (Fingerprints)": "fp_mlp",
    "CNN (SMILES)": "cnn_smiles",
    "RNN (SMILES)": "rnn_smiles",
    "Transformer (SMILES)": "transformer_smiles",
    "GNN": "gnn_gcn_mean",
}

out_dir = Path("runs_hw1")

for title, subdir in models_to_plot.items():
    hist_path = out_dir / subdir / "training_history.json"
    if not hist_path.exists():
        print(f"Skipping {title}: no history file found.")
        continue

    hist = json.load(open(hist_path))
    epochs = hist["epoch"]

    plt.figure(figsize=(6, 4))
    plt.plot(epochs, hist["train_loss"], label="Train Loss")
    plt.plot(epochs, hist["val_loss"], label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{title} – Training vs Validation Loss")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_dir / subdir / f"{subdir}_loss_curve.png", dpi=150)
    plt.close()

    plt.figure(figsize=(6, 4))
    plt.plot(epochs, hist["train_ROC-AUC"], label="Train ROC-AUC")
    plt.plot(epochs, hist["val_ROC-AUC"], label="Val ROC-AUC")
    plt.xlabel("Epoch")
    plt.ylabel("ROC-AUC")
    plt.title(f"{title} – Training vs Validation ROC-AUC")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_dir / subdir / f"{subdir}_roc_curve.png", dpi=150)
    plt.close()

    print(f"Saved plots for {title}")


Saved plots for MLP (Fingerprints)
Saved plots for CNN (SMILES)
Saved plots for RNN (SMILES)
Saved plots for Transformer (SMILES)
Saved plots for GNN
