In [10]:
#centroids.per
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os

# =====================================================================
# 1) model – everything behind one nn.Module
# =====================================================================
class ProteinModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # -------------------------------- backbone --------------------------------
        self.egnn = StackedEGNN(
            dim=cfg.dim, depth=cfg.depth, hidden_dim=cfg.hidden_dim,
            dropout=cfg.dropout, num_positions=cfg.N_NEIGHBORS,
            num_tokens=118, num_nearest_neighbors=cfg.num_neighbors
        ).to(cfg.device)

        self.rbf  = TunableBlock(
            LearnableRBF(num_basis=cfg.basis, cutoff=10.0).to(cfg.device),
            enabled=cfg.use_rbf
        )
        self.attn = TunableBlock(
            AttentionBlock(embed_dim=cfg.dim+cfg.basis,
                           num_heads=cfg.dim+cfg.basis,
                           hidden_dim=cfg.hidden_dim).to(cfg.device),
            enabled=cfg.use_attn
        )
        # -------------------------------- aggregation choices ---------------------
        if cfg.aggregator == 'linear':
            self.nconv = None; out_dim = cfg.dim+cfg.basis
        elif cfg.aggregator in ('nconv','nconv+linear'):
            k = cfg.dim + cfg.basis
            self.nconv = nn.Conv1d(cfg.N_NEIGHBORS, 1, kernel_size=k, padding=0).to(cfg.device)
            out_dim = 1
        elif cfg.aggregator == 'pool':          # global max – keeps head alive but unused
            self.nconv = None; out_dim = cfg.dim+cfg.basis
        else:
            raise ValueError(f"unknown aggregator {cfg.aggregator}")

        self.pred_head = nn.Linear(out_dim, 1).to(cfg.device) if cfg.use_pred_head else nn.Identity()

        # second‑level EGNN on centroids (can be disabled via TunableBlock too)
        self.prot_egnn = TunableBlock(
            EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(cfg.device),
            enabled=True
        )

    # ------------------------------------------------------------------
    def forward(self, z, x):                      # z:(R,N)  x:(R,N,3)
        h_list, coords = self.egnn(z, x)          # h:(R,N,dim) coords:(R,N,3)
        h = h_list[0] if isinstance(h_list,(list,tuple)) else h

        centroids = coords.mean(dim=1).unsqueeze(1)          # (R,1,3)
        rbf = self.rbf(centroids, coords)                    # (R,N,N,basis) or passthrough

        h_T = h.transpose(1,2)                               # (R,dim,N)
        r_T = rbf.transpose(1,2) if cfg.use_rbf else torch.empty_like(h_T[..., :cfg.basis])
        tok = torch.cat((r_T, h_T), 1)                       # (R,dim+basis,N)

        tok = tok.permute(2,0,1)                             # (N,R,C)
        tok,_ = self.attn(tok)                               # (N,R,C) (or identity)
        tok = tok.permute(1,0,2)                             # (R,N,C)

        # ---- aggregation route --------------------------------------
        if self.nconv is not None:
            tok = self.nconv(tok)                            # (R,1,1)
            tok = tok.squeeze(-1)                            # (R,1)
        elif cfg.aggregator == 'pool':
            tok = tok.max(dim=1).values                      # (R,C)
        else:                                                # 'linear'
            tok = tok                                        # (R,N,C) – max over N
            tok = tok.max(dim=1).values                      # (R,C)

        preds = self.pred_head(tok)                          # (R,1) or (R,1) Identity

        # protein‑level EGNN keeps gradients but (optionally) does nothing
        preds = preds.unsqueeze(0)                           # (1,R,1)
        coords_cent = centroids.permute(1,0,2)               # (1,R,3)
        preds = self.prot_egnn(preds, coords_cent)[0].squeeze(0)  # (R,1)

        return preds                                         # (R,1)

    # ------------------------------------------------------------------
    def n_trainable_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

model = ProteinModel(cfg)
print("Trainable params :", f"{model.n_trainable_parameters():,}")


# ---------------------------------------------------------------
# 0) hyper‑parameter “dashboard” – Py 3.6‑compatible
# ---------------------------------------------------------------
from types import SimpleNamespace
import torch, datetime
# ================================================================
# 0) ── hyper‑parameters / switches  (edit here, nothing else)
# ================================================================
cfg = {
    # backbone
    'dim'           : 12,
    'basis'         : 6,
    'depth'         : 2,
    'hidden_dim'    : 4,
    'num_neighbors' : 8,           # k in EGNN
    'dropout'       : 0.02,
    # neighbourhood tensor
    'N_NEIGHBORS'   : 100,
    # optional modules            ─────────   choose ablation
    #   'linear'        →   pred_head only
    #   'nconv+linear'  →   nconv then pred_head
    #   'nconv'         →   nconv only
    #   'pool'          →   max‑pool, ignores nconv & pred_head
    'aggregator'    : 'linear',
    'use_rbf'       : True,
    'use_attn'      : True,
    'use_nconv'     : False,
    'use_pred_head' : True,
    # loss / scheduler
    'loss_type'     : 'mae',        # 'mae' or 'mse'
    'sched_metric'  : 'val_mae',    # what ReduceLRO looks at
    # training
    'lr'            : 5e-3,
    'epochs'        : 20,
    'batch_size'    : 1,
    'device'        : 'cuda' if torch.cuda.is_available() else 'cpu',
    'runid'         : datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
}

print("Run‑ID:", cfg['runid'])

# 0) start timer
t0 = time.time()
N_NEIGHBORS = 100
BATCH_SIZE  =  1           # not safe to increase
PIN_MEMORY  = torch.cuda.is_available()
# reproducibility + device
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# decide AMP only on GP0
use_amp = (device.type == "cuda")
if use_amp:
    scaler = GradScaler()
else:
    class DummyCM:
        def __enter__(self): pass
        def __exit__(self, *args): pass
    autocast = DummyCM
    scaler   = None

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    
    def build_egnn(dim,depth,hidden_dim,num_neighbors, num_edge_tokens,num_global_tokens,dropout):
        return StackedEGNN(
            dim=dim, depth=depth, hidden_dim=hidden_dim,
            dropout=dropout,
            num_positions=1000, num_tokens=118,
            num_nearest_neighbors=num_neighbors,
            norm_coors=True,
            num_edge_tokens=num_edge_tokens,
            num_global_tokens=num_global_tokens
        )
    net   = build_egnn(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout).to(device)
    mha   = AttentionBlock(embed_dim=dim+basis, num_heads=num_heads, hidden_dim=hidden_dim).to(device)
    RBF   = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device) 
    return net, mha, RBF
#net,mha,RBF=init_model
# 3) instantiate everything
dim, basis = 12, 6 #scale to 3,16 at least # dim must be divisible by 2
depth=2 #scale to 2, at least
hidden_dim=4
num_heads=dim + basis 
num_edge_tokens=256
num_global_tokens=256
dropout=0.02
cutoff=10.0
num_neighbors=2


runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, constructs fixed-size neighbourhoods around each
    site (anchor) and stores the result entirely in RAM.

    For a protein with S sites the shapes are
        z   : (S, N_NEIGHBORS)      int32
        pos : (S, N_NEIGHBORS, 3)   float32
        y   : (S,)                  float32
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                dat = np.load(p, allow_pickle=True)
                z_all   = dat["z"].astype(np.int32)        # (N,)
                pos_all = dat["pos"].astype(np.float32)    # (N,3)
                sites   = dat["sites"].astype(np.float32)  # (S,3)
                y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, N_NEIGHBORS)

                z_hood   = torch.from_numpy(z_all[idx])            # (S,N_NEIGHBORS)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S,N_NEIGHBORS,3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                print(f"skipping {p}: {e}")

    def __len__(self):             return len(self.data)
    def __getitem__(self, idx):    return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pads the variable-length site dimension so the batch can be stacked
    into one tensor.  A boolean mask keeps track of which elements are
    real data (True) vs. padding (False).
    """
    # batch = list[(z,pos,y), ...]         len = B
    B               = len(batch)
    S_max           = max(item[0].shape[0] for item in batch)   # longest protein
    device          = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    zs   = torch.zeros (B, S_max, N_NEIGHBORS ,   dtype=torch.int32 , device=device)
    pos  = torch.zeros (B, S_max, N_NEIGHBORS ,3, dtype=torch.float32, device=device)
    ys   = torch.full  ((B, S_max),  float("nan"), dtype=torch.float32, device=device)
    #ys   = torch.full  (B, S_max,               float("nan"),        dtype=torch.float32, device=device)
    mask = torch.zeros (B, S_max,                                   dtype=torch.bool,     device=device)

    for b,(z,pos_b,y) in enumerate(batch):
        S = z.shape[0]
        zs  [b, :S] = z.to(device)
        pos [b, :S] = pos_b.to(device)
        ys  [b, :S] = y.to(device)
        mask[b, :S] = True

    return zs, pos, ys, mask             # shapes – see above

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
np.random.seed(0)
all_paths = glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
np.random.shuffle(all_paths)
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# =====================================================================
# 2) dataset / dataloader – *identical* to your earlier code
# =====================================================================
N_NEIGHBORS = cfg.N_NEIGHBORS      # just to keep names aligned
# ...  (pad_collate, InMemoryHoodDataset)  ...

train_loader = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)
val_loader   = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)


train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



from architecture import TunableBlock, StackedEGNN, AttentionBlock, LearnableRBF

egnn_net = TunableBlock(
    StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                dropout=dropout, num_positions=1000, num_tokens=118,
                num_nearest_neighbors=num_neighbors, norm_coors=True),
    enabled=True
)

rbf_layer = TunableBlock(
    LearnableRBF(num_basis=basis, cutoff=cutoff),
    enabled=True
)

mha_layer = TunableBlock(
    AttentionBlock(embed_dim=dim + basis, num_heads=num_heads, hidden_dim=hidden_dim),
    enabled=True
)

nconv_layer = TunableBlock(
    nn.Conv1d(in_channels=N_NEIGHBORS, out_channels=1, kernel_size=dim + basis, padding=0),  # match your shape
    enabled=True
)

pred_head = TunableBlock(
    nn.Linear(1, 1),  # could also be Linear(dim + basis, 1) if used earlier
    enabled=True
)

protein_egnn = TunableBlock(
    EGNN(dim=1, update_coors=True, norm_coors=True, norm_feats=True,
         fourier_features=6, valid_radius=8),
    enabled=True
)

conv = TunableBlock(
    nn.Conv1d(1, 1, 7, padding=3),
    enabled=True
)

mha_layer.enabled = False
rbf_layer.enabled = True
nconv_layer.enabled = False
pred_head.enabled = True
conv.enabled = False

modules = [egnn_net, rbf_layer, mha_layer, nconv_layer, pred_head, protein_egnn, conv]
enabled_params = [p for m in modules if m.enabled for p in m.parameters()]
optimizer = torch.optim.AdamW(enabled_params, lr=5e-3)

# =====================================================================
# 3) training utilities (loss + scheduler are user‑switchable)
# =====================================================================
criterion_map = {'mse': nn.MSELoss, 'mae': nn.L1Loss}
criterion = criterion_map[cfg.loss_type.lower()]()

optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

scaler  = GradScaler(enabled=(cfg.device=='cuda'))

epochs = 10  # or whatever you like

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    h_out, coords = egnn_net(z_r, x_r)
    h = h_out[0] if isinstance(h_out, (list, tuple)) else h_out

    centroids = coords.mean(dim=1).unsqueeze(1)
    rbf = rbf_layer(centroids, coords)

    h0 = h.transpose(1, 2)
    r0 = rbf.transpose(1, 2)
    tok = torch.cat((r0, h0), dim=1)

    tok_input = tok.permute(2, 0, 1)  # [N, R, C]

    mha_out = mha_layer(tok_input)

    if isinstance(mha_out, tuple):
        tok = mha_out[0]
    else:
        tok = mha_out  # passthrough from disabled module

    tok = tok.permute(1, 0, 2)  # [R, N, C]

    tok = tok.permute(1, 0, 2)                # (R, N, C)
    
    tok = nconv_layer(tok)                    # (R, 1, 1)
    tok = tok.squeeze(-1)                     # (R, 1)

    preds = pred_head(tok)                    # (R, 1)

    t = preds[:, 0].T.unsqueeze(2)            # (1, R, 1)
    coords = centroids.permute(1, 0, 2)       # (1, R, 3)
    preds = protein_egnn(t, coords)[0].permute(1, 2, 0)  # (R, 1, 1)

    preds = conv(preds)                       # (R, 1, 1)
    return preds.squeeze(-1), coords
# =====================================================================
# 4) training / validation loop (records both MAE & MSE every epoch)
# =====================================================================
def run_epoch(loader, train:bool):
    if train:
        model.train()
    else:
        model.eval()
    mae_meter, mse_meter = [], []
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg.device)
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg.device)
        y_r   = y.view(-1)[valid].to(cfg.device)

        with (autocast(enabled=(cfg.device=='cuda'))):
            preds = model(z_r, x_r).flatten()
            loss  = criterion(preds, y_r)
            mae   = nn.L1Loss()(preds, y_r).item()
            mse   = nn.MSELoss()(preds, y_r).item()

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        mae_meter.append(mae)
        mse_meter.append(mse)
    return np.mean(mae_meter), np.mean(mse_meter)

history = {'train_mae':[], 'val_mae':[], 'train_mse':[], 'val_mse':[]}

t0 = time.time()
for epoch in range(cfg.epochs):
    tr_mae, tr_mse = run_epoch(train_loader, True)
    vl_mae, vl_mse = run_epoch(val_loader,   False)

    history['train_mae'].append(tr_mae); history['train_mse'].append(tr_mse)
    history['val_mae'].append(vl_mae);   history['val_mse'].append(vl_mse)

    metric_to_step = vl_mae if cfg.scheduler_metric=='val_mae' else vl_mse
    scheduler.step(metric_to_step)

    print(f"[{epoch+1:03d}/{cfg.epochs}] "
          f"train MAE {tr_mae:.4f} | val MAE {vl_mae:.4f} | "
          f"train MSE {tr_mse:.4f} | val MSE {vl_mse:.4f}")

print("Total time:", (time.time()-t0)/60, "min")


#    with torch.no_grad():
#        for z, x, y, mask in val_loader:
#            valid   = mask.view(-1)
#            z_res   = z.view(-1, z.size(2))[valid].to(device)
#            x_res   = x.view(-1, x.size(2), 3)[valid].to(device)
#            y_res   = y.view(-1)[valid].to(device)

                #model
#            feats, centroids = forward_residues(z_res, x_res)         # (R, C)
#            
#            preds = pred_head(feats)       
#            t=preds.unsqueeze(0)
            #preds=protein_egnn(t,centroids)[0]

#            loss  = criterion(preds.flatten(), y_res)
#            vl_losses.append(loss.item())

#    print(f"              |  val L1 = {np.mean(vl_losses):.4f}")
#print(time.time() - t0,"sec")##

# =====================================================================
# 5) save *everything required* to resume later
# =====================================================================
ckpt = {
    'model_state': model.state_dict(),
    'optimizer'  : optimizer.state_dict(),
    'scheduler'  : scheduler.state_dict(),
    'cfg'        : asdict(cfg),
    'history'    : history,
}
name = f"checkpoint_{cfg.runid}.pt"
#torch.save(ckpt, name)
print("Saved", name)

TypeError: __init__() missing 1 required positional argument: 'norm_coors'

In [13]:
# ================================================================
# 1) ── model definition  (uses your *unchanged* architecture.py)
# ================================================================
import torch, torch.nn as nn
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)

class Cfg(dict):
    """dot‑access + dict‑access wrapper (Py 3.6 safe)."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    def as_dict(self):               # for checkpoints
        return dict(self)
cfg = Cfg(
    # ------------- backbone -------------
    dim=12, basis=6, depth=2, hidden_dim=4,
    num_neighbors=8,          # k in EGNN
    dropout=0.02,
    norm_coors=True,          # <── NEW: make it tunable
    N_NEIGHBORS=100,

    # ---------- ablation switches -------
    aggregator='linear', use_rbf=True, use_attn=True,
    use_nconv=False, use_pred_head=True,

    # ---------- training / misc ---------
    loss_type='mae', sched_metric='val_mae',
    lr=5e-3, epochs=5, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)



print("Run‑ID:", cfg.runid)
print(cfg)

# ================================================================
# 0) ── hyper‑parameters / switches  (edit here, nothing else)


class ProteinModel(nn.Module):
    def __init__(self, c):
        super(ProteinModel, self).__init__()
        self.c = c

        # ---------------- EGNN backbone ----------------
        self.egnn = StackedEGNN(
    dim               = cfg.dim,
    depth             = cfg.depth,
    hidden_dim        = cfg.hidden_dim,
    dropout           = cfg.dropout,
    num_positions     = cfg.N_NEIGHBORS,
    num_tokens        = 98,
    num_nearest_neighbors = cfg.num_neighbors,
    norm_coors        = cfg.norm_coors     # <── ONLY extra argument
).to(cfg.device)

        # --------------- optional blocks ----------------
        self.rbf  = TunableBlock(
            LearnableRBF(num_basis=c['basis'], cutoff=10.0).to(c['device']),
            enabled=c['use_rbf']
        )
        self.attn = TunableBlock(
            AttentionBlock(embed_dim=c['dim']+c['basis'],
                           num_heads=c['dim']+c['basis'],
                           hidden_dim=c['hidden_dim']).to(c['device']),
            enabled=c['use_attn']
        )

        if c['aggregator'] in ('nconv', 'nconv+linear'):
            k = c['dim'] + c['basis']
            self.nconv = nn.Conv1d(c['N_NEIGHBORS'], 1, kernel_size=k, padding=0)\
                             .to(c['device'])
            out_dim = 1
        else:
            self.nconv = None
            out_dim = c['dim'] + c['basis']

        self.pred_head = (nn.Linear(out_dim, 1).to(c['device'])
                          if c['use_pred_head'] else nn.Identity())

        # ---- protein‑level EGNN on centroids (update_coors = True) ----
        self.prot_egnn = TunableBlock(
            EGNN(dim=1, update_coors=True, num_nearest_neighbors=3)\
                .to(c['device']),
            enabled=True
        )

    # ---------------------------------------------------
    def forward(self, z, x):                 # z:(R,N)  x:(R,N,3)
        h_list, coords = self.egnn(z, x)      # h:(R,N,dim)
        h = h_list[0] if isinstance(h_list,(list,tuple)) else h_list

        centroids = coords.mean(dim=1).unsqueeze(1)            # (R,1,3)
        rbf = self.rbf(centroids, coords)                      # or passthrough

        h_T = h.transpose(1,2)                                 # (R,dim,N)
        if self.c['use_rbf']:
            r_T = rbf.transpose(1,2)
        else:
            r_T = torch.empty(0, device=h.device)

        tok = torch.cat((r_T, h_T), 1)                         # (R,dim+basis,N)

        tok = tok.permute(2,0,1)                               # (N,R,C)
        tok,_ = self.attn(tok)                                 # (N,R,C) / identity
        tok = tok.permute(1,0,2)                               # (R,N,C)

        # -------- aggregator routes --------
        if self.nconv is not None:
            tok = self.nconv(tok).squeeze(-1)                  # (R,1)
        elif self.c['aggregator'] == 'pool':                   # max‑pool
            tok = tok.max(dim=1).values                        # (R,C)
        else:                                                  # linear
            tok = tok.max(dim=1).values                        # (R,C)

        preds = self.pred_head(tok)                            # (R,1)

        # ---- residue → protein aggregate -----
        preds_ = preds.unsqueeze(0)                            # (1,R,1)
        coords_ = centroids.permute(1,0,2)                     # (1,R,3)
        preds  = self.prot_egnn(preds_, coords_)[0].squeeze(0) # (R,1)

        return preds

    def n_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ── make sure cfg behaves like a dict no matter what it is
if not isinstance(cfg, dict):
    cfg = vars(cfg)          # SimpleNamespace → ordinary dict

model = ProteinModel(cfg)
print("Trainable parameters:", "{:,}".format(model.n_params()))




np.random.seed(0)
all_paths = glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
np.random.shuffle(all_paths)
train_paths, val_paths = all_paths[:5], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# =====================================================================
# 2) dataset / dataloader – *identical* to your earlier code
# =====================================================================
N_NEIGHBORS = cfg.N_NEIGHBORS      # just to keep names aligned
# ...  (pad_collate, InMemoryHoodDataset)  ...

train_loader = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)
val_loader   = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)


train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)


# ================================================================
# 3) ── optimisation bits
# ================================================================
criterion = nn.L1Loss() if cfg['loss_type']=='mae' else nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg['lr'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

scaler = GradScaler(enabled=(cfg['device']=='cuda'))

# ================================================================
# 4) ── train / validate
# ================================================================
def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    mae_acc, mse_acc = [], []
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg['device'])
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg['device'])
        y_r   = y.view(-1)[valid].to(cfg['device'])

        with autocast(enabled=(cfg['device']=='cuda')):
            out  = model(z_r, x_r).flatten()
            loss = criterion(out, y_r)
            mae  = nn.L1Loss()(out, y_r).item()
            mse  = nn.MSELoss()(out, y_r).item()

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        mae_acc.append(mae); mse_acc.append(mse)
    return sum(mae_acc)/len(mae_acc), sum(mse_acc)/len(mse_acc)

for ep in range(cfg['epochs']):
    tr_mae, tr_mse = epoch_loop(train_loader, True)
    vl_mae, vl_mse = epoch_loop(val_loader,   False)

    metric = vl_mae if cfg['sched_metric']=='val_mae' else vl_mse
    scheduler.step(metric)

    print("[{:03d}/{:03d}]  "
          "train MAE {:.4f} | val MAE {:.4f}  ||  "
          "train MSE {:.4f} | val MSE {:.4f}".format(
              ep+1, cfg['epochs'], tr_mae, vl_mae, tr_mse, vl_mse))


# ================================================================
# 5) ── save everything needed to resume
# ================================================================
ckpt = {
    'model_state': model.state_dict(),
    'optim_state': optimizer.state_dict(),
    'sched_state': scheduler.state_dict(),
    'cfg'        : cfg,
}
ckpt_name = "ckpt_{}.pt".format(cfg['runid'])
#torch.save(ckpt, ckpt_name)
print("Saved checkpoint:", ckpt_name)



Run‑ID: 20250726_205300
Trainable parameters: 17,935




[001/020]  train MAE 1.3625 | val MAE 1.2577  ||  train MSE 3.8920 | val MSE 3.8825


KeyboardInterrupt: 

In [14]:
# ================================================================
# 1) ── model definition  (uses your *unchanged* architecture.py)
# ================================================================
import torch, torch.nn as nn
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)

class Cfg(dict):
    """dot‑access + dict‑access wrapper (Py 3.6 safe)."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    def as_dict(self):               # for checkpoints
        return dict(self)
cfg = Cfg(
    # ------------- backbone -------------
    dim=12, basis=6, depth=2, hidden_dim=4,
    num_neighbors=8,          # k in EGNN
    dropout=0.02,
    norm_coors=True,          # <── NEW: make it tunable
    N_NEIGHBORS=100,

    # ---------- ablation switches -------
    aggregator='linear', use_rbf=True, use_attn=True,
    use_nconv=False, use_pred_head=True,

    # ---------- training / misc ---------
    loss_type='mae', sched_metric='val_mae',
    lr=5e-3, epochs=5, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)



print("Run‑ID:", cfg.runid)
print(cfg)

# ================================================================
# 0) ── hyper‑parameters / switches  (edit here, nothing else)


class ProteinModel(nn.Module):
    def __init__(self, c):
        super(ProteinModel, self).__init__()
        self.c = c

        # ---------------- EGNN backbone ----------------
        self.egnn = StackedEGNN(
    dim               = cfg.dim,
    depth             = cfg.depth,
    hidden_dim        = cfg.hidden_dim,
    dropout           = cfg.dropout,
    num_positions     = cfg.N_NEIGHBORS,
    num_tokens        = 98,
    num_nearest_neighbors = cfg.num_neighbors,
    norm_coors        = cfg.norm_coors     # <── ONLY extra argument
).to(cfg.device)

        # --------------- optional blocks ----------------
        self.rbf  = TunableBlock(
            LearnableRBF(num_basis=c['basis'], cutoff=10.0).to(c['device']),
            enabled=c['use_rbf']
        )
        self.attn = TunableBlock(
            AttentionBlock(embed_dim=c['dim']+c['basis'],
                           num_heads=c['dim']+c['basis'],
                           hidden_dim=c['hidden_dim']).to(c['device']),
            enabled=c['use_attn']
        )

        if c['aggregator'] in ('nconv', 'nconv+linear'):
            k = c['dim'] + c['basis']
            self.nconv = nn.Conv1d(c['N_NEIGHBORS'], 1, kernel_size=k, padding=0)\
                             .to(c['device'])
            out_dim = 1
        else:
            self.nconv = None
            out_dim = c['dim'] + c['basis']

        self.pred_head = (nn.Linear(out_dim, 1).to(c['device'])
                          if c['use_pred_head'] else nn.Identity())

        # ---- protein‑level EGNN on centroids (update_coors = True) ----
        self.prot_egnn = TunableBlock(
            EGNN(dim=1, update_coors=True, num_nearest_neighbors=3)\
                .to(c['device']),
            enabled=True
        )

    # ---------------------------------------------------
    def forward(self, z, x):                 # z:(R,N)  x:(R,N,3)
        h_list, coords = self.egnn(z, x)      # h:(R,N,dim)
        h = h_list[0] if isinstance(h_list,(list,tuple)) else h_list

        centroids = coords.mean(dim=1).unsqueeze(1)            # (R,1,3)
        rbf = self.rbf(centroids, coords)                      # or passthrough

        h_T = h.transpose(1,2)                                 # (R,dim,N)
        if self.c['use_rbf']:
            r_T = rbf.transpose(1,2)
        else:
            r_T = torch.empty(0, device=h.device)

        tok = torch.cat((r_T, h_T), 1)                         # (R,dim+basis,N)

        tok = tok.permute(2,0,1)                               # (N,R,C)
        tok,_ = self.attn(tok)                                 # (N,R,C) / identity
        tok = tok.permute(1,0,2)                               # (R,N,C)

        # -------- aggregator routes --------
        if self.nconv is not None:
            tok = self.nconv(tok).squeeze(-1)                  # (R,1)
        elif self.c['aggregator'] == 'pool':                   # max‑pool
            tok = tok.max(dim=1).values                        # (R,C)
        else:                                                  # linear
            tok = tok.max(dim=1).values                        # (R,C)

        preds = self.pred_head(tok)                            # (R,1)

        # ---- residue → protein aggregate -----
        preds_ = preds.unsqueeze(0)                            # (1,R,1)
        coords_ = centroids.permute(1,0,2)                     # (1,R,3)
        preds  = self.prot_egnn(preds_, coords_)[0].squeeze(0) # (R,1)

        return preds

    def n_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ── make sure cfg behaves like a dict no matter what it is
if not isinstance(cfg, dict):
    cfg = vars(cfg)          # SimpleNamespace → ordinary dict

model = ProteinModel(cfg)
print("Trainable parameters:", "{:,}".format(model.n_params()))




np.random.seed(0)
all_paths = glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
np.random.shuffle(all_paths)
train_paths, val_paths = all_paths[:5], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# =====================================================================
# 2) dataset / dataloader – *identical* to your earlier code
# =====================================================================
N_NEIGHBORS = cfg.N_NEIGHBORS      # just to keep names aligned
# ...  (pad_collate, InMemoryHoodDataset)  ...

train_loader = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)
val_loader   = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)


train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)


# ================================================================
# 3) ── optimisation bits
# ================================================================
criterion = nn.L1Loss() if cfg['loss_type']=='mae' else nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg['lr'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

scaler = GradScaler(enabled=(cfg['device']=='cuda'))

# ================================================================
# 4) ── train / validate
# ================================================================
def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    mae_acc, mse_acc = [], []
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg['device'])
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg['device'])
        y_r   = y.view(-1)[valid].to(cfg['device'])

        with autocast(enabled=(cfg['device']=='cuda')):
            out  = model(z_r, x_r).flatten()
            loss = criterion(out, y_r)
            mae  = nn.L1Loss()(out, y_r).item()
            mse  = nn.MSELoss()(out, y_r).item()

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        mae_acc.append(mae); mse_acc.append(mse)
    return sum(mae_acc)/len(mae_acc), sum(mse_acc)/len(mse_acc)

for ep in range(cfg['epochs']):
    tr_mae, tr_mse = epoch_loop(train_loader, True)
    vl_mae, vl_mse = epoch_loop(val_loader,   False)

    metric = vl_mae if cfg['sched_metric']=='val_mae' else vl_mse
    scheduler.step(metric)

    print("[{:03d}/{:03d}]  "
          "train MAE {:.4f} | val MAE {:.4f}  ||  "
          "train MSE {:.4f} | val MSE {:.4f}".format(
              ep+1, cfg['epochs'], tr_mae, vl_mae, tr_mse, vl_mse))


# ================================================================
# 5) ── save everything needed to resume
# ================================================================
ckpt = {
    'model_state': model.state_dict(),
    'optim_state': optimizer.state_dict(),
    'sched_state': scheduler.state_dict(),
    'cfg'        : cfg,
}
ckpt_name = "ckpt_{}.pt".format(cfg['runid'])
#torch.save(ckpt, ckpt_name)
print("Saved checkpoint:", ckpt_name)



Run‑ID: 20250726_205840
{'dim': 12, 'basis': 6, 'depth': 2, 'hidden_dim': 4, 'num_neighbors': 8, 'dropout': 0.02, 'norm_coors': True, 'N_NEIGHBORS': 100, 'aggregator': 'linear', 'use_rbf': True, 'use_attn': True, 'use_nconv': False, 'use_pred_head': True, 'loss_type': 'mae', 'sched_metric': 'val_mae', 'lr': 0.005, 'epochs': 5, 'batch_size': 1, 'device': 'cpu', 'runid': '20250726_205840'}
Trainable parameters: 17,695




[001/005]  train MAE 1.4693 | val MAE 1.4530  ||  train MSE 3.8303 | val MSE 4.6516
[002/005]  train MAE 1.2156 | val MAE 1.2802  ||  train MSE 2.9310 | val MSE 3.9881
[003/005]  train MAE 1.1161 | val MAE 1.2816  ||  train MSE 2.7563 | val MSE 3.7701
[004/005]  train MAE 1.1059 | val MAE 1.2627  ||  train MSE 2.7447 | val MSE 3.9004
[005/005]  train MAE 1.0930 | val MAE 1.2660  ||  train MSE 2.6866 | val MSE 3.9391
Saved checkpoint: ckpt_20250726_205840.pt


In [15]:
# ================================================================
# 1) ── model definition  (uses your *unchanged* architecture.py)
# ================================================================
import torch, torch.nn as nn
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)

class Cfg(dict):
    """dot‑access + dict‑access wrapper (Py 3.6 safe)."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    def as_dict(self):               # for checkpoints
        return dict(self)
cfg = Cfg(
    # ------------- backbone -------------
    dim=12, basis=6, depth=2, hidden_dim=4,
    num_neighbors=8,          # k in EGNN
    dropout=0.02,
    norm_coors=True,          # <── NEW: make it tunable
    N_NEIGHBORS=100,

    # ---------- ablation switches -------
    aggregator='linear', use_rbf=True, use_attn=True,
    use_nconv=False, use_pred_head=True,

    # ---------- training / misc ---------

    loss_type='mae',          # 'mae' or 'mse'  ← primary loss
    sched_metric='val_mae',   # what ReduceLRO sees
    study_metrics=True,       # <── NEW: if False we skip the secondary metric

    lr=5e-3, epochs=5, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)



print("Run‑ID:", cfg.runid)
print(cfg)

# ================================================================
# 0) ── hyper‑parameters / switches  (edit here, nothing else)


class ProteinModel(nn.Module):
    def __init__(self, c):
        super(ProteinModel, self).__init__()
        self.c = c

        # ---------------- EGNN backbone ----------------
        self.egnn = StackedEGNN(
    dim               = cfg.dim,
    depth             = cfg.depth,
    hidden_dim        = cfg.hidden_dim,
    dropout           = cfg.dropout,
    num_positions     = cfg.N_NEIGHBORS,
    num_tokens        = 98,
    num_nearest_neighbors = cfg.num_neighbors,
    norm_coors        = cfg.norm_coors     # <── ONLY extra argument
).to(cfg.device)

        # --------------- optional blocks ----------------
        self.rbf  = TunableBlock(
            LearnableRBF(num_basis=c['basis'], cutoff=10.0).to(c['device']),
            enabled=c['use_rbf']
        )
        self.attn = TunableBlock(
            AttentionBlock(embed_dim=c['dim']+c['basis'],
                           num_heads=c['dim']+c['basis'],
                           hidden_dim=c['hidden_dim']).to(c['device']),
            enabled=c['use_attn']
        )

        if c['aggregator'] in ('nconv', 'nconv+linear'):
            k = c['dim'] + c['basis']
            self.nconv = nn.Conv1d(c['N_NEIGHBORS'], 1, kernel_size=k, padding=0)\
                             .to(c['device'])
            out_dim = 1
        else:
            self.nconv = None
            out_dim = c['dim'] + c['basis']

        self.pred_head = (nn.Linear(out_dim, 1).to(c['device'])
                          if c['use_pred_head'] else nn.Identity())

        # ---- protein‑level EGNN on centroids (update_coors = True) ----
        self.prot_egnn = TunableBlock(
            EGNN(dim=1, update_coors=True, num_nearest_neighbors=3)\
                .to(c['device']),
            enabled=True
        )

    # ---------------------------------------------------
    def forward(self, z, x):                 # z:(R,N)  x:(R,N,3)
        h_list, coords = self.egnn(z, x)      # h:(R,N,dim)
        h = h_list[0] if isinstance(h_list,(list,tuple)) else h_list

        centroids = coords.mean(dim=1).unsqueeze(1)            # (R,1,3)
        rbf = self.rbf(centroids, coords)                      # or passthrough

        h_T = h.transpose(1,2)                                 # (R,dim,N)
        if self.c['use_rbf']:
            r_T = rbf.transpose(1,2)
        else:
            r_T = torch.empty(0, device=h.device)

        tok = torch.cat((r_T, h_T), 1)                         # (R,dim+basis,N)

        tok = tok.permute(2,0,1)                               # (N,R,C)
        tok,_ = self.attn(tok)                                 # (N,R,C) / identity
        tok = tok.permute(1,0,2)                               # (R,N,C)

        # -------- aggregator routes --------
        if self.nconv is not None:
            tok = self.nconv(tok).squeeze(-1)                  # (R,1)
        elif self.c['aggregator'] == 'pool':                   # max‑pool
            tok = tok.max(dim=1).values                        # (R,C)
        else:                                                  # linear
            tok = tok.max(dim=1).values                        # (R,C)

        preds = self.pred_head(tok)                            # (R,1)

        # ---- residue → protein aggregate -----
        preds_ = preds.unsqueeze(0)                            # (1,R,1)
        coords_ = centroids.permute(1,0,2)                     # (1,R,3)
        preds  = self.prot_egnn(preds_, coords_)[0].squeeze(0) # (R,1)

        return preds

    def n_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ── make sure cfg behaves like a dict no matter what it is
if not isinstance(cfg, dict):
    cfg = vars(cfg)          # SimpleNamespace → ordinary dict

model = ProteinModel(cfg)
print("Trainable parameters:", "{:,}".format(model.n_params()))




np.random.seed(0)
all_paths = glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
np.random.shuffle(all_paths)
train_paths, val_paths = all_paths[:5], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# =====================================================================
# 2) dataset / dataloader – *identical* to your earlier code
# =====================================================================
N_NEIGHBORS = cfg.N_NEIGHBORS      # just to keep names aligned
# ...  (pad_collate, InMemoryHoodDataset)  ...

train_loader = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)
val_loader   = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)


train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)


# ================================================================
# 3) ── optimisation bits
# ================================================================
criterion = nn.L1Loss() if cfg['loss_type']=='mae' else nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg['lr'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

scaler = GradScaler(enabled=(cfg['device']=='cuda'))
# pick loss & the “other” metric once, outside the loop
if cfg.loss_type.lower() == 'mae':
    primary_loss_fn   = nn.L1Loss()
    secondary_fn      = nn.MSELoss() if cfg.study_metrics else None
    primary_name      = 'MAE'
else:
    primary_loss_fn   = nn.MSELoss()
    secondary_fn      = nn.L1Loss()  if cfg.study_metrics else None
    primary_name      = 'MSE'

def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    primary_sum, secondary_sum, n = 0.0, 0.0, 0
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg.device)
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg.device)
        y_r   = y.view(-1)[valid].to(cfg.device)

        with autocast(enabled=(cfg.device=='cuda')):
            preds   = model(z_r, x_r).flatten()
            loss    = primary_loss_fn(preds, y_r)
            primary = loss.item()
            if secondary_fn is not None:
                secondary = secondary_fn(preds, y_r).item()
            else:
                secondary = 0.0

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        primary_sum   += primary
        secondary_sum += secondary
        n += 1

    return primary_sum/n, (secondary_sum/n if secondary_fn is not None else None)

# ================================================================
# 4) ── train / validate
# ================================================================
def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    mae_acc, mse_acc = [], []
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg['device'])
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg['device'])
        y_r   = y.view(-1)[valid].to(cfg['device'])

        with autocast(enabled=(cfg['device']=='cuda')):
            out  = model(z_r, x_r).flatten()
            loss = criterion(out, y_r)
            mae  = nn.L1Loss()(out, y_r).item()
            mse  = nn.MSELoss()(out, y_r).item()

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        mae_acc.append(mae); mse_acc.append(mse)
    return sum(mae_acc)/len(mae_acc), sum(mse_acc)/len(mse_acc)

for ep in range(cfg.epochs):
    tr_primary, tr_sec = epoch_loop(train_loader, True)
    vl_primary, vl_sec = epoch_loop(val_loader, False)

    scheduler.step(vl_primary if cfg.sched_metric=='val_'+primary_name.lower()
                   else vl_sec)

    msg = "[{}/{}] train {} {:.4f} | val {} {:.4f}".format(
            ep+1, cfg.epochs, primary_name, tr_primary, primary_name, vl_primary)
    if cfg.study_metrics and tr_sec is not None:
        other = 'MSE' if primary_name=='MAE' else 'MAE'
        msg += "  ||  train {} {:.4f} | val {} {:.4f}".format(
                other, tr_sec, other, vl_sec)
    print(msg)

# ================================================================
# 5) ── save everything needed to resume
# ================================================================
ckpt = {
    'model_state': model.state_dict(),
    'optim_state': optimizer.state_dict(),
    'sched_state': scheduler.state_dict(),
    'cfg'        : cfg,
}
ckpt_name = "ckpt_{}.pt".format(cfg['runid'])
#torch.save(ckpt, ckpt_name)
print("Saved checkpoint:", ckpt_name)



Run‑ID: 20250726_210152
{'dim': 12, 'basis': 6, 'depth': 2, 'hidden_dim': 4, 'num_neighbors': 8, 'dropout': 0.02, 'norm_coors': True, 'N_NEIGHBORS': 100, 'aggregator': 'linear', 'use_rbf': True, 'use_attn': True, 'use_nconv': False, 'use_pred_head': True, 'loss_type': 'mae', 'sched_metric': 'val_mae', 'study_metrics': True, 'lr': 0.005, 'epochs': 5, 'batch_size': 1, 'device': 'cpu', 'runid': '20250726_210152'}
Trainable parameters: 17,695




[1/5] train MAE 1.1944 | val MAE 1.2649  ||  train MSE 2.9880 | val MSE 3.8912
[2/5] train MAE 1.0796 | val MAE 1.2528  ||  train MSE 2.6411 | val MSE 3.8555
[3/5] train MAE 1.0631 | val MAE 1.2542  ||  train MSE 2.5918 | val MSE 3.9007
[4/5] train MAE 1.0472 | val MAE 1.2357  ||  train MSE 2.5451 | val MSE 3.7295
[5/5] train MAE 1.0221 | val MAE 1.2382  ||  train MSE 2.4487 | val MSE 3.8262
Saved checkpoint: ckpt_20250726_210152.pt


In [None]:
# ================================================================
# 1) ── model definition  (uses your *unchanged* architecture.py)
# ================================================================
import torch, torch.nn as nn
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)

class Cfg(dict):
    """dot‑access + dict‑access wrapper (Py 3.6 safe)."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    def as_dict(self):               # for checkpoints
        return dict(self)
cfg = Cfg(
    # ------------- backbone -------------
    dim=12, basis=6, depth=2, hidden_dim=4,
    num_neighbors=8,          # k in EGNN
    dropout=0.02,
    norm_coors=True,          # <── NEW: make it tunable
    N_NEIGHBORS=100,

    # ---------- ablation switches -------
    aggregator='linear', use_rbf=True, use_attn=True,
    use_nconv=False, use_pred_head=True,

    # ---------- training / misc ---------

    loss_type='mae',          # 'mae' or 'mse'  ← primary loss
    sched_metric='val_mae',   # what ReduceLRO sees
    study_metrics=True,       # <── NEW: if False we skip the secondary metric

    lr=5e-3, epochs=5, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")



print("Run‑ID:", cfg.runid)
print(cfg)

# ================================================================
# 0) ── hyper‑parameters / switches  (edit here, nothing else)


class ProteinModel(nn.Module):
    def __init__(self, c):
        super(ProteinModel, self).__init__()
        self.c = c

        # ---------------- EGNN backbone ----------------
        self.egnn = StackedEGNN(
    dim               = cfg.dim,
    depth             = cfg.depth,
    hidden_dim        = cfg.hidden_dim,
    dropout           = cfg.dropout,
    num_positions     = cfg.N_NEIGHBORS,
    num_tokens        = 98,
    num_nearest_neighbors = cfg.num_neighbors,
    norm_coors        = cfg.norm_coors     # <── ONLY extra argument
).to(cfg.device)

        # --------------- optional blocks ----------------
        self.rbf  = TunableBlock(
            LearnableRBF(num_basis=c['basis'], cutoff=10.0).to(c['device']),
            enabled=c['use_rbf']
        )
        self.attn = TunableBlock(
            AttentionBlock(embed_dim=c['dim']+c['basis'],
                           num_heads=c['dim']+c['basis'],
                           hidden_dim=c['hidden_dim']).to(c['device']),
            enabled=c['use_attn']
        )

        if c['aggregator'] in ('nconv', 'nconv+linear'):
            k = c['dim'] + c['basis']
            self.nconv = nn.Conv1d(c['N_NEIGHBORS'], 1, kernel_size=k, padding=0)\
                             .to(c['device'])
            out_dim = 1
        else:
            self.nconv = None
            out_dim = c['dim'] + c['basis']

        self.pred_head = (nn.Linear(out_dim, 1).to(c['device'])
                          if c['use_pred_head'] else nn.Identity())

        # ---- protein‑level EGNN on centroids (update_coors = True) ----
        self.prot_egnn = TunableBlock(
            EGNN(dim=1, update_coors=True, num_nearest_neighbors=3)\
                .to(c['device']),
            enabled=True
        )

    # ---------------------------------------------------
    def forward(self, z, x):                 # z:(R,N)  x:(R,N,3)
        h_list, coords = self.egnn(z, x)      # h:(R,N,dim)
        h = h_list[0] if isinstance(h_list,(list,tuple)) else h_list

        centroids = coords.mean(dim=1).unsqueeze(1)            # (R,1,3)
        rbf = self.rbf(centroids, coords)                      # or passthrough

        h_T = h.transpose(1,2)                                 # (R,dim,N)
        if self.c['use_rbf']:
            r_T = rbf.transpose(1,2)
        else:
            r_T = torch.empty(0, device=h.device)

        tok = torch.cat((r_T, h_T), 1)                         # (R,dim+basis,N)

        tok = tok.permute(2,0,1)                               # (N,R,C)
        tok,_ = self.attn(tok)                                 # (N,R,C) / identity
        tok = tok.permute(1,0,2)                               # (R,N,C)

        # -------- aggregator routes --------
        if self.nconv is not None:
            tok = self.nconv(tok).squeeze(-1)                  # (R,1)
        elif self.c['aggregator'] == 'pool':                   # max‑pool
            tok = tok.max(dim=1).values                        # (R,C)
        else:                                                  # linear
            tok = tok.max(dim=1).values                        # (R,C)

        preds = self.pred_head(tok)                            # (R,1)

        # ---- residue → protein aggregate -----
        preds_ = preds.unsqueeze(0)                            # (1,R,1)
        coords_ = centroids.permute(1,0,2)                     # (1,R,3)
        preds  = self.prot_egnn(preds_, coords_)[0].squeeze(0) # (R,1)

        return preds

    def n_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ── make sure cfg behaves like a dict no matter what it is
if not isinstance(cfg, dict):
    cfg = vars(cfg)          # SimpleNamespace → ordinary dict

model = ProteinModel(cfg)
print("Trainable parameters:", "{:,}".format(model.n_params()))




np.random.seed(0)
all_paths = glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
np.random.shuffle(all_paths)
train_paths, val_paths = all_paths[:5], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# =====================================================================
# 2) dataset / dataloader – *identical* to your earlier code
# =====================================================================
N_NEIGHBORS = cfg.N_NEIGHBORS      # just to keep names aligned
# ...  (pad_collate, InMemoryHoodDataset)  ...

train_loader = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)
val_loader   = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)


train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)


# ================================================================
# 3) ── optimisation bits
# ================================================================
criterion = nn.L1Loss() if cfg['loss_type']=='mae' else nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg['lr'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

scaler = GradScaler(enabled=(cfg['device']=='cuda'))
# pick loss & the “other” metric once, outside the loop
if cfg.loss_type.lower() == 'mae':
    primary_loss_fn   = nn.L1Loss()
    secondary_fn      = nn.MSELoss() if cfg.study_metrics else None
    primary_name      = 'MAE'
else:
    primary_loss_fn   = nn.MSELoss()
    secondary_fn      = nn.L1Loss()  if cfg.study_metrics else None
    primary_name      = 'MSE'

def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    primary_sum, secondary_sum, n = 0.0, 0.0, 0
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg.device)
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg.device)
        y_r   = y.view(-1)[valid].to(cfg.device)

        with autocast(enabled=(cfg.device=='cuda')):
            preds   = model(z_r, x_r).flatten()
            loss    = primary_loss_fn(preds, y_r)
            primary = loss.item()
            if secondary_fn is not None:
                secondary = secondary_fn(preds, y_r).item()
            else:
                secondary = 0.0

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        primary_sum   += primary
        secondary_sum += secondary
        n += 1

    return primary_sum/n, (secondary_sum/n if secondary_fn is not None else None)

# ================================================================
# 4) ── train / validate
# ================================================================
def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    mae_acc, mse_acc = [], []
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg['device'])
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg['device'])
        y_r   = y.view(-1)[valid].to(cfg['device'])

        with autocast(enabled=(cfg['device']=='cuda')):
            out  = model(z_r, x_r).flatten()
            loss = criterion(out, y_r)
            mae  = nn.L1Loss()(out, y_r).item()
            mse  = nn.MSELoss()(out, y_r).item()

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        mae_acc.append(mae); mse_acc.append(mse)
    return sum(mae_acc)/len(mae_acc), sum(mse_acc)/len(mse_acc)

for ep in range(cfg.epochs):
    tr_primary, tr_sec = epoch_loop(train_loader, True)
    vl_primary, vl_sec = epoch_loop(val_loader, False)

    scheduler.step(vl_primary if cfg.sched_metric=='val_'+primary_name.lower()
                   else vl_sec)

    msg = "[{}/{}] train {} {:.4f} | val {} {:.4f}".format(
            ep+1, cfg.epochs, primary_name, tr_primary, primary_name, vl_primary)
    if cfg.study_metrics and tr_sec is not None:
        other = 'MSE' if primary_name=='MAE' else 'MAE'
        msg += "  ||  train {} {:.4f} | val {} {:.4f}".format(
                other, tr_sec, other, vl_sec)
    print(msg)

# ================================================================
# 5) ── save everything needed to resume
# ================================================================
ckpt = {
    'model_state': model.state_dict(),
    'optim_state': optimizer.state_dict(),
    'sched_state': scheduler.state_dict(),
    'cfg'        : cfg,
}
ckpt_name = "ckpt_{}.pt".format(cfg['runid'])
#torch.save(ckpt, ckpt_name)
print("Saved checkpoint:", ckpt_name)



Run‑ID: 20250726_210254
{'dim': 12, 'basis': 6, 'depth': 2, 'hidden_dim': 4, 'num_neighbors': 8, 'dropout': 0.02, 'norm_coors': True, 'N_NEIGHBORS': 100, 'aggregator': 'linear', 'use_rbf': True, 'use_attn': True, 'use_nconv': False, 'use_pred_head': True, 'loss_type': 'mae', 'sched_metric': 'val_mae', 'study_metrics': True, 'lr': 0.005, 'epochs': 5, 'batch_size': 1, 'device': 'cpu', 'runid': '20250726_210254'}
Trainable parameters: 17,695




[1/5] train MAE 1.2831 | val MAE 1.5003  ||  train MSE 3.1863 | val MSE 4.8607


KeyboardInterrupt: 

In [None]:
# ================================================================
# 1) ── model definition  (uses your *unchanged* architecture.py)
# ================================================================
import torch, torch.nn as nn
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)

class Cfg(dict):
    """dot‑access + dict‑access wrapper (Py 3.6 safe)."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    def as_dict(self):               # for checkpoints
        return dict(self)
cfg = Cfg(
    # ------------- backbone -------------
    dim=12, basis=6, depth=2, hidden_dim=4,
    num_neighbors=8,          # k in EGNN
    dropout=0.02,
    norm_coors=True,          # <── NEW: make it tunable
    N_NEIGHBORS=100,

    # ---------- ablation switches -------
    aggregator='linear', use_rbf=True, use_attn=True,
    use_nconv=False, use_pred_head=True,

    # ---------- training / misc ---------

    loss_type='mae',          # 'mae' or 'mse'  ← primary loss
    sched_metric='val_mae',   # what ReduceLRO sees
    study_metrics=True,       # <── NEW: if False we skip the secondary metric

    lr=5e-3, epochs=5, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")



print("Run‑ID:", cfg.runid)
print(cfg)

# ================================================================
# 0) ── hyper‑parameters / switches  (edit here, nothing else)


class ProteinModel(nn.Module):
    def __init__(self, c):
        super(ProteinModel, self).__init__()
        self.c = c

        # ---------------- EGNN backbone ----------------
        self.egnn = StackedEGNN(
    dim               = cfg.dim,
    depth             = cfg.depth,
    hidden_dim        = cfg.hidden_dim,
    dropout           = cfg.dropout,
    num_positions     = cfg.N_NEIGHBORS,
    num_tokens        = 98,
    num_nearest_neighbors = cfg.num_neighbors,
    norm_coors        = cfg.norm_coors     # <── ONLY extra argument
).to(cfg.device)

        # --------------- optional blocks ----------------
        self.rbf  = TunableBlock(
            LearnableRBF(num_basis=c['basis'], cutoff=10.0).to(c['device']),
            enabled=c['use_rbf']
        )
        self.attn = TunableBlock(
            AttentionBlock(embed_dim=c['dim']+c['basis'],
                           num_heads=c['dim']+c['basis'],
                           hidden_dim=c['hidden_dim']).to(c['device']),
            enabled=c['use_attn']
        )

        if c['aggregator'] in ('nconv', 'nconv+linear'):
            k = c['dim'] + c['basis']
            self.nconv = nn.Conv1d(c['N_NEIGHBORS'], 1, kernel_size=k, padding=0)\
                             .to(c['device'])
            out_dim = 1
        else:
            self.nconv = None
            out_dim = c['dim'] + c['basis']

        self.pred_head = (nn.Linear(out_dim, 1).to(c['device'])
                          if c['use_pred_head'] else nn.Identity())

        # ---- protein‑level EGNN on centroids (update_coors = True) ----
        self.prot_egnn = TunableBlock(
            EGNN(dim=1, update_coors=True, num_nearest_neighbors=3)\
                .to(c['device']),
            enabled=True
        )

    # ---------------------------------------------------
    def forward(self, z, x):                 # z:(R,N)  x:(R,N,3)
        h_list, coords = self.egnn(z, x)      # h:(R,N,dim)
        h = h_list[0] if isinstance(h_list,(list,tuple)) else h_list

        centroids = coords.mean(dim=1).unsqueeze(1)            # (R,1,3)
        rbf = self.rbf(centroids, coords)                      # or passthrough

        h_T = h.transpose(1,2)                                 # (R,dim,N)
        if self.c['use_rbf']:
            r_T = rbf.transpose(1,2)
        else:
            r_T = torch.empty(0, device=h.device)

        tok = torch.cat((r_T, h_T), 1)                         # (R,dim+basis,N)

        tok = tok.permute(2,0,1)                               # (N,R,C)
        tok,_ = self.attn(tok)                                 # (N,R,C) / identity
        tok = tok.permute(1,0,2)                               # (R,N,C)

        # -------- aggregator routes --------
        if self.nconv is not None:
            tok = self.nconv(tok).squeeze(-1)                  # (R,1)
        elif self.c['aggregator'] == 'pool':                   # max‑pool
            tok = tok.max(dim=1).values                        # (R,C)
        else:                                                  # linear
            tok = tok.max(dim=1).values                        # (R,C)

        preds = self.pred_head(tok)                            # (R,1)

        # ---- residue → protein aggregate -----
        preds_ = preds.unsqueeze(0)                            # (1,R,1)
        coords_ = centroids.permute(1,0,2)                     # (1,R,3)
        preds  = self.prot_egnn(preds_, coords_)[0].squeeze(0) # (R,1)

        return preds

    def n_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ── make sure cfg behaves like a dict no matter what it is
if not isinstance(cfg, dict):
    cfg = vars(cfg)          # SimpleNamespace → ordinary dict

model = ProteinModel(cfg)
print("Trainable parameters:", "{:,}".format(model.n_params()))




np.random.seed(0)
all_paths = glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
np.random.shuffle(all_paths)
train_paths, val_paths = all_paths[:5], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# =====================================================================
# 2) dataset / dataloader – *identical* to your earlier code
# =====================================================================
N_NEIGHBORS = cfg.N_NEIGHBORS      # just to keep names aligned
# ...  (pad_collate, InMemoryHoodDataset)  ...

train_loader = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)
val_loader   = DataLoader( ... , batch_size=cfg.batch_size, collate_fn=pad_collate)


train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)


# ================================================================
# 3) ── optimisation bits
# ================================================================
criterion = nn.L1Loss() if cfg['loss_type']=='mae' else nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg['lr'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

scaler = GradScaler(enabled=(cfg['device']=='cuda'))
# pick loss & the “other” metric once, outside the loop
if cfg.loss_type.lower() == 'mae':
    primary_loss_fn   = nn.L1Loss()
    secondary_fn      = nn.MSELoss() if cfg.study_metrics else None
    primary_name      = 'MAE'
else:
    primary_loss_fn   = nn.MSELoss()
    secondary_fn      = nn.L1Loss()  if cfg.study_metrics else None
    primary_name      = 'MSE'

def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    primary_sum, secondary_sum, n = 0.0, 0.0, 0
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg.device)
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg.device)
        y_r   = y.view(-1)[valid].to(cfg.device)

        with autocast(enabled=(cfg.device=='cuda')):
            preds   = model(z_r, x_r).flatten()
            loss    = primary_loss_fn(preds, y_r)
            primary = loss.item()
            if secondary_fn is not None:
                secondary = secondary_fn(preds, y_r).item()
            else:
                secondary = 0.0

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        primary_sum   += primary
        secondary_sum += secondary
        n += 1

    return primary_sum/n, (secondary_sum/n if secondary_fn is not None else None)

# ================================================================
# 4) ── train / validate
# ================================================================
def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    mae_acc, mse_acc = [], []
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg['device'])
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg['device'])
        y_r   = y.view(-1)[valid].to(cfg['device'])

        with autocast(enabled=(cfg['device']=='cuda')):
            out  = model(z_r, x_r).flatten()
            loss = criterion(out, y_r)
            mae  = nn.L1Loss()(out, y_r).item()
            mse  = nn.MSELoss()(out, y_r).item()

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        mae_acc.append(mae); mse_acc.append(mse)
    return sum(mae_acc)/len(mae_acc), sum(mse_acc)/len(mse_acc)

for ep in range(cfg.epochs):
    tr_primary, tr_sec = epoch_loop(train_loader, True)
    vl_primary, vl_sec = epoch_loop(val_loader, False)

    scheduler.step(vl_primary if cfg.sched_metric=='val_'+primary_name.lower()
                   else vl_sec)

    msg = "[{}/{}] train {} {:.4f} | val {} {:.4f}".format(
            ep+1, cfg.epochs, primary_name, tr_primary, primary_name, vl_primary)
    if cfg.study_metrics and tr_sec is not None:
        other = 'MSE' if primary_name=='MAE' else 'MAE'
        msg += "  ||  train {} {:.4f} | val {} {:.4f}".format(
                other, tr_sec, other, vl_sec)
    print(msg)

# ================================================================
# 5) ── save everything needed to resume
# ================================================================
ckpt = {
    'model_state': model.state_dict(),
    'optim_state': optimizer.state_dict(),
    'sched_state': scheduler.state_dict(),
    'cfg'        : cfg,
}
ckpt_name = "ckpt_{}.pt".format(cfg['runid'])
#torch.save(ckpt, ckpt_name)
print("Saved checkpoint:", ckpt_name)



TypeError: __init__() missing 1 required positional argument: 'norm_coors'