## Model Sucks

In [6]:
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.data import Data

In [7]:
df = pd.read_csv("/Users/williamsilver/Ai in materials workshop/AI-in-materials-workshop/AI-in-materials-workshop/mmc2_cleaned_2_no_zero_pce.csv")

In [14]:
# ---------- helpers ----------
def one_hot(value, choices):
    """One-hot with an extra 'other' bucket."""
    out = [0.0] * (len(choices) + 1)
    idx = choices.index(value) if value in choices else len(choices)
    out[idx] = 1.0
    return out

def index_add(src, index, dim_size):
    """
    src: (E, hidden)
    index: (E,)  destination indices
    returns: (dim_size, hidden)
    """
    out = torch.zeros(
        dim_size,
        src.size(1),
        device=src.device,
        dtype=src.dtype
    )
    out.index_add_(0, index, src)
    return out
def safe_float(x, default=0.0):
    try:
        if x is None: 
            return default
        if isinstance(x, str):
            x = float(x)
        if np.isnan(x) or np.isinf(x):
            return default
        return float(x)
    except Exception:
        return default

In [10]:
# ---------- atom features ----------
HYBRID_CHOICES = [
    Chem.rdchem.HybridizationType.S,
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
]
CHIRAL_CHOICES = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]

def get_atom_features(atom, gasteiger_charge=0.0):
    """
    Chemprop-ish: categorical one-hots + a couple numeric signals.
    Returns a float list.
    """
    Z = atom.GetAtomicNum()
    Z = int(Z) if Z is not None else 0
    # atomic number one-hot (1..100), with 0 as "unknown"
    Z_oh = [0.0] * 101
    Z_oh[min(max(Z, 0), 100)] = 1.0

    degree_oh = one_hot(atom.GetTotalDegree(), [0,1,2,3,4,5])
    formal_charge_oh = one_hot(atom.GetFormalCharge(), [-3,-2,-1,0,1,2,3])
    hybrid_oh = one_hot(atom.GetHybridization(), HYBRID_CHOICES)
    chiral_oh = one_hot(atom.GetChiralTag(), CHIRAL_CHOICES)

    total_h_oh = one_hot(atom.GetTotalNumHs(), [0,1,2,3,4])
    implicit_val_oh = one_hot(atom.GetImplicitValence(), [0,1,2,3,4,5,6])
    explicit_val_oh = one_hot(atom.GetExplicitValence(), [0,1,2,3,4,5,6])
    radical_oh = one_hot(atom.GetNumRadicalElectrons(), [0,1,2])

    is_aromatic = [1.0 if atom.GetIsAromatic() else 0.0]
    in_ring = [1.0 if atom.IsInRing() else 0.0]

    # Numeric features
    mass = [atom.GetMass() / 200.0]  # roughly scaled
    # clamp charge to a reasonable range
    q = float(np.clip(safe_float(gasteiger_charge, 0.0), -2.0, 2.0))
    g_charge = [q]

    feats = (
        Z_oh + degree_oh + formal_charge_oh + hybrid_oh + chiral_oh +
        total_h_oh + implicit_val_oh + explicit_val_oh + radical_oh +
        is_aromatic + in_ring +
        mass + g_charge
    )
    return feats

# ---------- bond features ----------
STEREO_CHOICES = [
    Chem.rdchem.BondStereo.STEREONONE,
    Chem.rdchem.BondStereo.STEREOZ,
    Chem.rdchem.BondStereo.STEREOE,
    Chem.rdchem.BondStereo.STEREOCIS,
    Chem.rdchem.BondStereo.STEREOTRANS,
    Chem.rdchem.BondStereo.STEREOANY,
]
BONDDIR_CHOICES = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT,
]

def get_bond_features(bond):
    bt = bond.GetBondType()
    bond_type_oh = [
        1.0 if bt == Chem.rdchem.BondType.SINGLE else 0.0,
        1.0 if bt == Chem.rdchem.BondType.DOUBLE else 0.0,
        1.0 if bt == Chem.rdchem.BondType.TRIPLE else 0.0,
        1.0 if bt == Chem.rdchem.BondType.AROMATIC else 0.0,
    ]
    conj = [1.0 if bond.GetIsConjugated() else 0.0]
    ring = [1.0 if bond.IsInRing() else 0.0]
    stereo_oh = one_hot(bond.GetStereo(), STEREO_CHOICES)
    dir_oh = one_hot(bond.GetBondDir(), BONDDIR_CHOICES)

    return bond_type_oh + conj + ring + stereo_oh + dir_oh

def smiles_to_graph(smiles, targets):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Compute Gasteiger charges (best on H-added molecule)
    molH = Chem.AddHs(mol)
    try:
        AllChem.ComputeGasteigerCharges(molH)
        charges = []
        for i in range(mol.GetNumAtoms()):
            a = molH.GetAtomWithIdx(i)
            charges.append(a.GetProp("_GasteigerCharge"))
    except Exception:
        charges = [0.0] * mol.GetNumAtoms()

    # Node features
    node_feats = []
    for i, atom in enumerate(mol.GetAtoms()):
        q = charges[i] if i < len(charges) else 0.0
        node_feats.append(get_atom_features(atom, gasteiger_charge=q))
    x = torch.tensor(node_feats, dtype=torch.float)

    # Directed edges + edge features + reverse edge mapping
    edge_pairs = []
    edge_feats = []
    pair_to_idx = {}

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bf = get_bond_features(bond)

        # i -> j
        pair_to_idx[(i, j)] = len(edge_pairs)
        edge_pairs.append([i, j])
        edge_feats.append(bf)

        # j -> i
        pair_to_idx[(j, i)] = len(edge_pairs)
        edge_pairs.append([j, i])
        edge_feats.append(bf)

    if len(edge_pairs) == 0:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 0), dtype=torch.float)  # will be fixed below
        rev_edge_index = torch.empty((0,), dtype=torch.long)
    else:
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_feats, dtype=torch.float)
        rev_edge_index = torch.tensor(
            [pair_to_idx[(dst, src)] for src, dst in edge_pairs],
            dtype=torch.long
        )

    y = torch.tensor([targets], dtype=torch.float)  # (1,4)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
    data.rev_edge_index = rev_edge_index  # store for D-MPNN
    return data

In [11]:
from tqdm import tqdm

final_graphs = []
for _, row in tqdm(df.iterrows(), total=len(df)):
    targets = [row["homo"], row["lumo"], row["gap"], row["pce"]]
    g = smiles_to_graph(row["smiles"], targets)
    if g is not None and g.x is not None:
        # If a molecule has no bonds, edge_attr will be empty; we handle in model
        final_graphs.append(g)

graph_list = final_graphs
print(f"Created {len(graph_list)} graph objects.")
print("Node feature dim:", graph_list[0].x.shape[1])
print("Edge feature dim:", (graph_list[0].edge_attr.shape[1] if graph_list[0].edge_attr.numel() else "empty"))

100%|██████████| 48357/48357 [01:19<00:00, 606.35it/s]

Created 48357 graph objects.
Node feature dim: 158
Edge feature dim: 17





In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import Set2Set

class DMPNN_Set2Set(nn.Module):
    def __init__(self, node_in_dim, edge_in_dim, hidden_dim=256, depth=3, out_dim=4, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.depth = depth
        self.dropout = dropout

        self.W_in = nn.Linear(node_in_dim + edge_in_dim, hidden_dim)
        self.W_msg = nn.Linear(hidden_dim, hidden_dim)
        self.W_atom = nn.Linear(node_in_dim + hidden_dim, hidden_dim)

        self.pool = Set2Set(hidden_dim, processing_steps=3)
        self.head = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch
        N = x.size(0)

        # Handle molecules with no bonds
        if edge_index.numel() == 0:
            atom_h = F.relu(
                self.W_atom(
                    torch.cat([x, torch.zeros((N, self.hidden_dim), device=x.device)], dim=1)
                )
            )
            pooled = self.pool(atom_h, batch)
            return self.head(pooled)

        src, dst = edge_index
        rev = data.rev_edge_index.to(x.device)

        # Initial edge embeddings
        h0 = F.relu(self.W_in(torch.cat([x[src], edge_attr], dim=1)))
        h = h0

        for _ in range(self.depth):
            # m_in = scatter_add(h, dst)
            m_in = index_add(h, dst, N)

            # remove reverse edge
            m_e = m_in[src] - h[rev]

            h = F.relu(self.W_msg(h0 + m_e))
            h = F.dropout(h, p=self.dropout, training=self.training)

        # Atom-level aggregation
        atom_msg = index_add(h, dst, N)
        atom_h = F.relu(self.W_atom(torch.cat([x, atom_msg], dim=1)))

        pooled = self.pool(atom_h, batch)
        return self.head(pooled)

In [19]:
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total = 0.0
    n = 0
    for batch in loader:
        batch = batch.to(device)
        pred = model(batch)
        y = batch.y.view(pred.shape).to(device)
        loss = criterion(pred, y)
        total += loss.item() * batch.num_graphs
        n += batch.num_graphs
    return total / max(n, 1)

def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total = 0.0
    n = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch)
        y = batch.y.view(pred.shape).to(device)
        loss = criterion(pred, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        total += loss.item() * batch.num_graphs
        n += batch.num_graphs
    return total / max(n, 1)

class EarlyStopping:
    def __init__(self, patience=15, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.count = 0

    def step(self, val_loss):
        improved = (self.best - val_loss) > self.min_delta
        if improved:
            self.best = val_loss
            self.count = 0
            return False  # don't stop
        else:
            self.count += 1
            return self.count >= self.patience

In [None]:
import copy
import numpy as np
import torch
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

# Safety: only Data objects
graph_list = [g for g in graph_list if isinstance(g, Data)]
assert len(graph_list) > 0, "No valid graphs found."

node_in_dim = graph_list[0].x.shape[1]
# edge dim may be empty for some molecules; infer from a molecule that has edges
edge_in_dim = None
for g in graph_list:
    if g.edge_attr is not None and g.edge_attr.numel() > 0:
        edge_in_dim = g.edge_attr.shape[1]
        break
assert edge_in_dim is not None, "All molecules have no bonds? Edge features couldn't be inferred."

kf = KFold(n_splits=5, shuffle=True, random_state=42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

fold_best = []

for fold, (train_idx, val_idx) in enumerate(kf.split(graph_list), start=1):
    print(f"\n--- Fold {fold}/5 ---")

    train_set = [copy.deepcopy(graph_list[i]) for i in train_idx]
    val_set   = [copy.deepcopy(graph_list[i]) for i in val_idx]

    # Fit scaler on TRAIN targets only
    train_y_raw = np.vstack([g.y.detach().cpu().numpy() for g in train_set])  # (N,4)
    scaler = StandardScaler().fit(train_y_raw)

    # Apply scaling
    for g in train_set:
        g.y = torch.tensor(scaler.transform(g.y.detach().cpu().numpy()), dtype=torch.float)

    for g in val_set:
        g.y = torch.tensor(scaler.transform(g.y.detach().cpu().numpy()), dtype=torch.float)

    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=64, shuffle=False)

    model = DMPNN_Set2Set(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        hidden_dim=256,
        depth=3,
        out_dim=4,
        dropout=0.1,
    ).to(device)

    # Weight decay: try 1e-5 then 1e-4 if needed
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    criterion = torch.nn.MSELoss()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=5,
        min_lr=1e-5,
        verbose=True
    )

    early = EarlyStopping(patience=15, min_delta=1e-4)
    best_val = float("inf")
    best_state = None

    for epoch in range(1, 101):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss   = evaluate(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        if val_loss < best_val:
            best_val = val_loss
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        if epoch % 10 == 0 or epoch == 1:
            lr = optimizer.param_groups[0]["lr"]
            print(f"Epoch {epoch:03d} | LR {lr:.2e} | Train MSE {train_loss:.4f} | Val MSE {val_loss:.4f}")

        if early.step(val_loss):
            print(f"Early stopping at epoch {epoch} (best val {best_val:.4f})")
            break

    # restore best weights for this fold
    if best_state is not None:
        model.load_state_dict({k: v.to(device) for k, v in best_state.items()})

    fold_best.append(best_val)

print(f"\nAverage 5-Fold best Val MSE: {np.mean(fold_best):.4f}")


--- Fold 1/5 ---




Epoch 001 | LR 1.00e-03 | Train MSE 0.3543 | Val MSE 0.2384
