In [4]:
# ================================================================
# 1) ── model definition  (uses your *unchanged* architecture.py)
# ================================================================
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

class Cfg(dict):
    """dot‑access + dict‑access wrapper (Py 3.6 safe)."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    def as_dict(self):               # for checkpoints
        return dict(self)
cfg = Cfg(
    # ------------- backbone -------------
    dim=12, basis=6, depth=2, hidden_dim=4,
    num_neighbors=8,          # k in EGNN
    dropout=0.02,
    norm_coors=True,          # <── NEW: make it tunable
    N_NEIGHBORS=100,

    # ---------- ablation switches -------
    aggregator='linear', use_rbf=True, use_attn=True,
    use_nconv=False, use_pred_head=True,

    # ---------- training / misc ---------

    loss_type='mae',          # 'mae' or 'mse'  ← primary loss
    sched_metric='val_mae',   # what ReduceLRO sees
    study_metrics=True,       # <── NEW: if False we skip the secondary metric

    lr=5e-3, epochs=5, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),

    seed        = 0,       # ← set to int for deterministic run, None for random
    save_attn   = False,    # ← flip to True when you need the weights

    num_paths = 2,
    split_mode  = 'random',     # 'random' | 'file'
    split_ratio = 0.8,          # used when split_mode == 'random'
    split_seed  = 0,          # reproducible shuffle
    split_files = {             # used when split_mode == 'file'
        'train': 'train_list.txt',
        'val'  : 'val_list.txt',
    },
)


PIN_MEMORY=False
print("Run‑ID:", cfg.runid)
print(cfg)
# ── reproducibility (incl. cuDNN, DataLoader shuffles, etc.)
if cfg.seed is not None:
    import random, os
    random.seed(cfg.seed);  np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(cfg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark     = False


# ================================================================
# 0) ── hyper‑parameters / switches  (edit here, nothing else)


class ProteinModel(nn.Module):
    def __init__(self, c):
        super(ProteinModel, self).__init__()
        self.c = c

        # ---------------- EGNN backbone ----------------
        self.egnn = StackedEGNN(
    dim               = cfg.dim,
    depth             = cfg.depth,
    hidden_dim        = cfg.hidden_dim,
    dropout           = cfg.dropout,
    num_positions     = cfg.N_NEIGHBORS,
    num_tokens        = 98,
    num_nearest_neighbors = cfg.num_neighbors,
    norm_coors        = cfg.norm_coors     # <── ONLY extra argument
).to(cfg.device)

        # --------------- optional blocks ----------------
        self.rbf  = TunableBlock(
            LearnableRBF(num_basis=c['basis'], cutoff=10.0).to(c['device']),
            enabled=c['use_rbf']
        )
        self.attn = TunableBlock(
            AttentionBlock(embed_dim=c['dim']+c['basis'],
                           num_heads=c['dim']+c['basis'],
                           hidden_dim=c['hidden_dim']).to(c['device']),
            enabled=c['use_attn']
        )

        if c['aggregator'] in ('nconv', 'nconv+linear'):
            k = c['dim'] + c['basis']
            self.nconv = nn.Conv1d(c['N_NEIGHBORS'], 1, kernel_size=k, padding=0)\
                             .to(c['device'])
            out_dim = 1
        else:
            self.nconv = None
            out_dim = c['dim'] + c['basis']

        self.pred_head = (nn.Linear(out_dim, 1).to(c['device'])
                          if c['use_pred_head'] else nn.Identity())

        # ---- protein‑level EGNN on centroids (update_coors = True) ----
        self.prot_egnn = TunableBlock(
            EGNN(dim=1, update_coors=True, num_nearest_neighbors=3)\
                .to(c['device']),
            enabled=True
        )

    # ---------------------------------------------------
    def forward(self, z, x):                 # z:(R,N)  x:(R,N,3)
        h_list, coords = self.egnn(z, x)      # h:(R,N,dim)
        h = h_list[0] if isinstance(h_list,(list,tuple)) else h_list

        centroids = coords.mean(dim=1).unsqueeze(1)            # (R,1,3)
        rbf = self.rbf(centroids, coords)                      # or passthrough

        h_T = h.transpose(1,2)                                 # (R,dim,N)
        if self.c['use_rbf']:
            r_T = rbf.transpose(1,2)
        else:
            r_T = torch.empty(0, device=h.device)

        tok = torch.cat((r_T, h_T), 1)                         # (R,dim+basis,N)

        tok = tok.permute(2,0,1)                               # (N,R,C)
        tok,_ = self.attn(tok)                                 # (N,R,C) / identity
        tok = tok.permute(1,0,2)                               # (R,N,C)

        # -------- aggregator routes --------
        if self.nconv is not None:
            tok = self.nconv(tok).squeeze(-1)                  # (R,1)
        elif self.c['aggregator'] == 'pool':                   # max‑pool
            tok = tok.max(dim=1).values                        # (R,C)
        else:                                                  # linear
            tok = tok.max(dim=1).values                        # (R,C)

        preds = self.pred_head(tok)                            # (R,1)

        # ---- residue → protein aggregate -----
        preds_ = preds.unsqueeze(0)                            # (1,R,1)
        coords_ = centroids.permute(1,0,2)                     # (1,R,3)
        preds  = self.prot_egnn(preds_, coords_)[0].squeeze(0) # (R,1)

        return preds

    def n_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ── make sure cfg behaves like a dict no matter what it is
if not isinstance(cfg, dict):
    cfg = vars(cfg)          # SimpleNamespace → ordinary dict

model = ProteinModel(cfg)
print("Trainable parameters:", "{:,}".format(model.n_params()))


from torch.utils.data import Dataset, DataLoader

from train_utils import (set_seed, InMemoryHoodDataset,
                         pad_collate, one_epoch, make_split)



# ── reproducibility
#set_seed(cfg.seed)
# =====================================================================
# 2) dataset / dataloader – *identical* to your earlier code
# =====================================================================

N_NEIGHBORS = cfg.N_NEIGHBORS      # just to keep names aligned
BATCH_SIZE =cfg.batch_size
PIN_MEMORY = False
# ...  (pad_collate, InMemoryHoodDataset)  ...
# ---------------------------------------------------------------
# 2)  file list with optional cap + reproducible shuffle
# ---------------------------------------------------------------
all_paths = glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
if cfg.split_mode == 'random' and cfg.num_paths is not None:
    all_paths = all_paths[:cfg.num_paths]
train_paths, val_paths = make_split(all_paths, cfg)


# ── datasets and loaders
train_ds = InMemoryHoodDataset(train_paths, n_neighbors=cfg.N_NEIGHBORS,
                               pin_memory=False)
val_ds   = InMemoryHoodDataset(val_paths,   n_neighbors=cfg.N_NEIGHBORS,
                               pin_memory=False)

train_loader = DataLoader(
    train_ds, batch_size=cfg.batch_size, shuffle=True,
    collate_fn=lambda b: pad_collate(b, cfg.N_NEIGHBORS, cfg.device),
    num_workers=0, pin_memory=False)

val_loader = DataLoader(
    val_ds, batch_size=cfg.batch_size, shuffle=False,
    collate_fn=lambda b: pad_collate(b, cfg.N_NEIGHBORS, cfg.device),
    num_workers=0, pin_memory=False)



 

# ================================================================
# 3) ── optimisation bits
# ================================================================
criterion = nn.L1Loss() if cfg['loss_type']=='mae' else nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg['lr'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

scaler = GradScaler(enabled=(cfg['device']=='cuda'))
# pick loss & the “other” metric once, outside the loop
if cfg.loss_type.lower() == 'mae':
    primary_loss_fn   = nn.L1Loss()
    secondary_fn      = nn.MSELoss() if cfg.study_metrics else None
    primary_name      = 'MAE'
else:
    primary_loss_fn   = nn.MSELoss()
    secondary_fn      = nn.L1Loss()  if cfg.study_metrics else None
    primary_name      = 'MSE'

def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    primary_sum, secondary_sum, n = 0.0, 0.0, 0
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg.device)
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg.device)
        y_r   = y.view(-1)[valid].to(cfg.device)

        with autocast(enabled=(cfg.device=='cuda')):
            preds   = model(z_r, x_r).flatten()
            loss    = primary_loss_fn(preds, y_r)
            primary = loss.item()
            if secondary_fn is not None:
                secondary = secondary_fn(preds, y_r).item()
            else:
                secondary = 0.0

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        primary_sum   += primary
        secondary_sum += secondary
        n += 1

    return primary_sum/n, (secondary_sum/n if secondary_fn is not None else None)

# ================================================================
# 4) ── train / validate
# ================================================================
def epoch_loop(loader, train):
    if train: model.train()
    else:     model.eval()

    mae_acc, mse_acc = [], []
    for z,x,y,mask in loader:
        valid = mask.view(-1)
        z_r   = z.view(-1, z.size(2))[valid].to(cfg['device'])
        x_r   = x.view(-1, x.size(2), 3)[valid].to(cfg['device'])
        y_r   = y.view(-1)[valid].to(cfg['device'])

        with autocast(enabled=(cfg['device']=='cuda')):
            out  = model(z_r, x_r).flatten()
            loss = criterion(out, y_r)
            mae  = nn.L1Loss()(out, y_r).item()
            mse  = nn.MSELoss()(out, y_r).item()

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

        mae_acc.append(mae); mse_acc.append(mse)
    return sum(mae_acc)/len(mae_acc), sum(mse_acc)/len(mse_acc)

for ep in range(cfg.epochs):
    tr_primary, tr_sec = epoch_loop(train_loader, True)
    vl_primary, vl_sec = epoch_loop(val_loader, False)

    scheduler.step(vl_primary if cfg.sched_metric=='val_'+primary_name.lower()
                   else vl_sec)

    msg = "[{}/{}] train {} {:.4f} | val {} {:.4f}".format(
            ep+1, cfg.epochs, primary_name, tr_primary, primary_name, vl_primary)
    if cfg.study_metrics and tr_sec is not None:
        other = 'MSE' if primary_name=='MAE' else 'MAE'
        msg += "  ||  train {} {:.4f} | val {} {:.4f}".format(
                other, tr_sec, other, vl_sec)
    print(msg)

# ================================================================
# 5) ── save everything needed to resume
# ================================================================
ckpt = {
    'model_state': model.state_dict(),
    'optim_state': optimizer.state_dict(),
    'sched_state': scheduler.state_dict(),
    'cfg'        : cfg,
}
ckpt_name = "ckpt_{}.pt".format(cfg['runid'])
#torch.save(ckpt, ckpt_name)
print("Saved checkpoint:", ckpt_name)





Run‑ID: 20250726_214548
{'dim': 12, 'basis': 6, 'depth': 2, 'hidden_dim': 4, 'num_neighbors': 8, 'dropout': 0.02, 'norm_coors': True, 'N_NEIGHBORS': 100, 'aggregator': 'linear', 'use_rbf': True, 'use_attn': True, 'use_nconv': False, 'use_pred_head': True, 'loss_type': 'mae', 'sched_metric': 'val_mae', 'study_metrics': True, 'lr': 0.005, 'epochs': 5, 'batch_size': 1, 'device': 'cpu', 'runid': '20250726_214548', 'seed': 0, 'save_attn': False, 'num_paths': 2, 'split_mode': 'random', 'split_ratio': 0.8, 'split_seed': 0, 'split_files': {'train': 'train_list.txt', 'val': 'val_list.txt'}}
Trainable parameters: 17,695


ValueError: too many values to unpack (expected 4)

In [5]:
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader


In [9]:
# ================================================================
# 0) dashboard  – tweak here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    num_neighbors=8, hood_k=100,                      # <── hood size
    aggregator='linear', use_rbf=True, use_attn=True,
    use_nconv=False, use_pred_head=True,
    norm_coors=True,

    loss_type='mae', sched_metric='val_mae', study_metrics=True,
    lr=5e-3, epochs=5, batch_size=1, device='cpu',
    seed=0, analysis_mode=False, save_attn=False,      # <── new switch
    num_paths=4,                                   # cap #files
    split_mode='random', split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed);  np.random.seed(cfg.seed);  torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) helper: split + dataset + collate
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class InMemDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k, algorithm='brute')
        for p in paths:
            try:
                d=np.load(p, allow_pickle=True)
                z,pos,sites,y=d['z'],d['pos'],d['sites'],d['pks']
                if len(sites)==0: continue
                nbr.fit(pos); idx=nbr.kneighbors(sites,return_distance=False)
                self.data.append((torch.from_numpy(z[idx]),
                                  torch.from_numpy(pos[idx]),
                                  torch.from_numpy(y)))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,pos,y=self.data[i]
        return (z,pos,y,self.ids[i])

def pad(batch, k, device, return_ids):
    if return_ids: ids=[b[3] for b in batch]
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zs=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pos=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    ys=torch.full((B,S),float('nan'),device=device)
    mask=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zs[b,:s]=z.to(device); pos[b,:s]=p.to(device)
        ys[b,:s]=y.to(device); mask[b,:s]=True
    return (zs,pos,ys,mask,ids) if return_ids else (zs,pos,ys,mask)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model (unchanged architecture.py)
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
class Model(nn.Module):
    def __init__(self,c):
        super().__init__()
        self.egnn=StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                              c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)
        self.rbf =TunableBlock(LearnableRBF(c.basis,10.).to(c.device),c.use_rbf)
        self.attn=TunableBlock(AttentionBlock(c.dim+c.basis,c.dim+c.basis,c.hidden_dim)\
                               .to(c.device),c.use_attn)
        if c.aggregator.startswith('nconv'):
            self.nconv=nn.Conv1d(c.hood_k,1,c.dim+c.basis).to(c.device)
            out=1
        else: self.nconv=None; out=c.dim+c.basis
        self.head=nn.Linear(out,1).to(c.device) if c.use_pred_head else nn.Identity()
        self.prot=EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device)
    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]
        cent=coord.mean(1,keepdim=True)
        r=self.rbf(cent,coord) if cfg.use_rbf else h.new_zeros(h.size(0),cfg.hood_k,cfg.basis)
        tok=torch.cat((r.transpose(1,2),h.transpose(1,2)),1)
        tok,_=self.attn(tok.permute(2,0,1)); tok=tok.permute(1,0,2)
        tok=self.nconv(tok).squeeze(-1) if self.nconv is not None else tok.max(1).values
        p=self.head(tok); p=self.prot(p.unsqueeze(0),cent.permute(1,0,2))[0].squeeze(0)
        return p
model=Model(cfg); print("params",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
tr_ds=InMemDS(tr,cfg.hood_k); val_ds=InMemDS(val,cfg.hood_k)
collate=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(tr_ds,batch_size=cfg.batch_size,shuffle=True,collate_fn=collate)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=collate)

# ================================================================
# 5) training util
# ================================================================
primary_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
secondary_fn= nn.MSELoss() if cfg.study_metrics and cfg.loss_type=='mae' else \
              nn.L1Loss() if cfg.study_metrics else None
pname='MAE' if cfg.loss_type=='mae' else 'MSE'; sname='MSE' if pname=='MAE' else 'MAE'
opt=torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval()
    ps=ss=0;n=0
    for batch in loader:
        if cfg.analysis_mode: z,x,y,m,ids=batch
        else:                 z,x,y,m=batch
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=primary_fn(pred,y)
            sec=secondary_fn(pred,y).item() if secondary_fn else 0.0
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        ps+=loss.item(); ss+=sec; n+=1
    return ps/n, (ss/n if secondary_fn else None)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False)
    sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {pname}:{tr_p:.4f} | val {pname}:{va_p:.4f}"
    if secondary_fn: msg+=f" || train {sname}:{tr_s:.4f} | val {sname}:{va_s:.4f}"
    print(msg)


Run‑ID: 20250726_214855
params 17695




[1/5] train MAE:1.3717 | val MAE:1.0425 || train MSE:3.8119 | val MSE:2.1939
[2/5] train MAE:1.3013 | val MAE:1.1405 || train MSE:3.7406 | val MSE:2.5144
[3/5] train MAE:1.3006 | val MAE:1.0161 || train MSE:3.7640 | val MSE:2.1072
[4/5] train MAE:1.2782 | val MAE:1.0175 || train MSE:3.6054 | val MSE:2.1122
[5/5] train MAE:1.2581 | val MAE:1.0355 || train MSE:3.5746 | val MSE:2.1833


In [10]:
# ================================================================
# 0) dashboard  – tweak here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    num_neighbors=8, hood_k=100,                      # <── hood size
    aggregator='linear', use_rbf=True, use_attn=True,
    use_nconv=False, use_pred_head=True,
    norm_coors=True,

    loss_type='mae', sched_metric='val_mae', study_metrics=True,
    lr=5e-3, epochs=5, batch_size=1, device='cpu',
    seed=0, analysis_mode=False, save_attn=False,      # <── new switch
    num_paths=4,                                   # cap #files
    split_mode='random', split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed);  np.random.seed(cfg.seed);  torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) helper: split + dataset + collate
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class InMemDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k, algorithm='brute')
        for p in paths:
            try:
                d=np.load(p, allow_pickle=True)
                z,pos,sites,y=d['z'],d['pos'],d['sites'],d['pks']
                if len(sites)==0: continue
                nbr.fit(pos); idx=nbr.kneighbors(sites,return_distance=False)
                self.data.append((torch.from_numpy(z[idx]),
                                  torch.from_numpy(pos[idx]),
                                  torch.from_numpy(y)))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,pos,y=self.data[i]
        return (z,pos,y,self.ids[i])

def pad(batch, k, device, return_ids):
    if return_ids: ids=[b[3] for b in batch]
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zs=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pos=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    ys=torch.full((B,S),float('nan'),device=device)
    mask=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zs[b,:s]=z.to(device); pos[b,:s]=p.to(device)
        ys[b,:s]=y.to(device); mask[b,:s]=True
    return (zs,pos,ys,mask,ids) if return_ids else (zs,pos,ys,mask)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model (unchanged architecture.py)
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
class Model(nn.Module):
    def __init__(self,c):
        super().__init__()
        self.egnn=StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                              c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)
        self.rbf =TunableBlock(LearnableRBF(c.basis,10.).to(c.device),c.use_rbf)
        self.attn=TunableBlock(AttentionBlock(c.dim+c.basis,c.dim+c.basis,c.hidden_dim)\
                               .to(c.device),c.use_attn)
        if c.aggregator.startswith('nconv'):
            self.nconv=nn.Conv1d(c.hood_k,1,c.dim+c.basis).to(c.device)
            out=1
        else: self.nconv=None; out=c.dim+c.basis
        self.head=nn.Linear(out,1).to(c.device) if c.use_pred_head else nn.Identity()
        self.prot=EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device)
    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]
        cent=coord.mean(1,keepdim=True)
        r=self.rbf(cent,coord) if cfg.use_rbf else h.new_zeros(h.size(0),cfg.hood_k,cfg.basis)
        tok=torch.cat((r.transpose(1,2),h.transpose(1,2)),1)
        tok,_=self.attn(tok.permute(2,0,1)); tok=tok.permute(1,0,2)
        tok=self.nconv(tok).squeeze(-1) if self.nconv is not None else tok.max(1).values
        p=self.head(tok); p=self.prot(p.unsqueeze(0),cent.permute(1,0,2))[0].squeeze(0)
        return p
model=Model(cfg); print("params",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
tr_ds=InMemDS(tr,cfg.hood_k); val_ds=InMemDS(val,cfg.hood_k)
collate=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(tr_ds,batch_size=cfg.batch_size,shuffle=True,collate_fn=collate)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=collate)

# ================================================================
# 5) training util
# ================================================================
primary_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
secondary_fn= nn.MSELoss() if cfg.study_metrics and cfg.loss_type=='mae' else \
              nn.L1Loss() if cfg.study_metrics else None
pname='MAE' if cfg.loss_type=='mae' else 'MSE'; sname='MSE' if pname=='MAE' else 'MAE'
opt=torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval()
    ps=ss=0;n=0
    for batch in loader:
        if cfg.analysis_mode: z,x,y,m,ids=batch
        else:                 z,x,y,m=batch
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=primary_fn(pred,y)
            sec=secondary_fn(pred,y).item() if secondary_fn else 0.0
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        ps+=loss.item(); ss+=sec; n+=1
    return ps/n, (ss/n if secondary_fn else None)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False)
    sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {pname}:{tr_p:.4f} | val {pname}:{va_p:.4f}"
    if secondary_fn: msg+=f" || train {sname}:{tr_s:.4f} | val {sname}:{va_s:.4f}"
    print(msg)


Run‑ID: 20250726_214916
params 17695




[1/5] train MAE:1.3717 | val MAE:1.0425 || train MSE:3.8119 | val MSE:2.1939
[2/5] train MAE:1.3013 | val MAE:1.1405 || train MSE:3.7406 | val MSE:2.5144
[3/5] train MAE:1.3006 | val MAE:1.0161 || train MSE:3.7640 | val MSE:2.1072
[4/5] train MAE:1.2782 | val MAE:1.0175 || train MSE:3.6054 | val MSE:2.1122
[5/5] train MAE:1.2581 | val MAE:1.0355 || train MSE:3.5746 | val MSE:2.1833
