In [3]:
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

# ================================================================
# 0) dashboard  – tweak here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    num_neighbors=8, hood_k=100,                      # <── hood size
    aggregator='linear', use_rbf=False, use_attn=False,
    use_nconv=False, use_pred_head=True,
    norm_coors=True,

    loss_type='mae', sched_metric='val_mae', study_metrics=True,
    lr=5e-3, epochs=5, batch_size=1, device='cpu',
    seed=0, analysis_mode=False, save_attn=False,      # <── new switch
    num_paths=4,                                   # cap #files
    split_mode='random', split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed);  np.random.seed(cfg.seed);  torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) helper: split + dataset + collate
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class InMemDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k, algorithm='brute')
        for p in paths:
            try:
                d=np.load(p, allow_pickle=True)
                z,pos,sites,y=d['z'],d['pos'],d['sites'],d['pks']
                if len(sites)==0: continue
                nbr.fit(pos); idx=nbr.kneighbors(sites,return_distance=False)
                self.data.append((torch.from_numpy(z[idx]),
                                  torch.from_numpy(pos[idx]),
                                  torch.from_numpy(y)))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,pos,y=self.data[i]
        return (z,pos,y,self.ids[i])

def pad(batch, k, device, return_ids):
    if return_ids: ids=[b[3] for b in batch]
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zs=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pos=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    ys=torch.full((B,S),float('nan'),device=device)
    mask=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zs[b,:s]=z.to(device); pos[b,:s]=p.to(device)
        ys[b,:s]=y.to(device); mask[b,:s]=True
    return (zs,pos,ys,mask,ids) if return_ids else (zs,pos,ys,mask)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model (unchanged architecture.py)
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
class Model(nn.Module):
    def __init__(self,c):
        super().__init__()
        self.egnn=StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                              c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)
        self.rbf =TunableBlock(LearnableRBF(c.basis,10.).to(c.device),c.use_rbf)
        self.attn=TunableBlock(AttentionBlock(c.dim+c.basis,c.dim+c.basis,c.hidden_dim)\
                               .to(c.device),c.use_attn)
        if c.aggregator.startswith('nconv'):
            self.nconv=nn.Conv1d(c.hood_k,1,c.dim+c.basis).to(c.device)
            out=1
        else: self.nconv=None; out=c.dim+c.basis
        self.head=nn.Linear(out,1).to(c.device) if c.use_pred_head else nn.Identity()
        self.prot=EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device)
    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]
        cent=coord.mean(1,keepdim=True)
        r=self.rbf(cent,coord) if cfg.use_rbf else h.new_zeros(h.size(0),cfg.hood_k,cfg.basis)
        tok=torch.cat((r.transpose(1,2),h.transpose(1,2)),1)
        tok,_=self.attn(tok.permute(2,0,1)); tok=tok.permute(1,0,2)
        tok=self.nconv(tok).squeeze(-1) if self.nconv is not None else tok.max(1).values
        p=self.head(tok); p=self.prot(p.unsqueeze(0),cent.permute(1,0,2))[0].squeeze(0)
        return p
model=Model(cfg); print("params",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
tr_ds=InMemDS(tr,cfg.hood_k); val_ds=InMemDS(val,cfg.hood_k)
collate=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(tr_ds,batch_size=cfg.batch_size,shuffle=True,collate_fn=collate)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=collate)

# ================================================================
# 5) training util
# ================================================================
primary_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
secondary_fn= nn.MSELoss() if cfg.study_metrics and cfg.loss_type=='mae' else \
              nn.L1Loss() if cfg.study_metrics else None
pname='MAE' if cfg.loss_type=='mae' else 'MSE'; sname='MSE' if pname=='MAE' else 'MAE'
opt=torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval()
    ps=ss=0;n=0
    for batch in loader:
        if cfg.analysis_mode: z,x,y,m,ids=batch
        else:                 z,x,y,m=batch
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=primary_fn(pred,y)
            sec=secondary_fn(pred,y).item() if secondary_fn else 0.0
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        ps+=loss.item(); ss+=sec; n+=1
    return ps/n, (ss/n if secondary_fn else None)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False)
    sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {pname}:{tr_p:.4f} | val {pname}:{va_p:.4f}"
    if secondary_fn: msg+=f" || train {sname}:{tr_s:.4f} | val {sname}:{va_s:.4f}"
    print(msg)


Run‑ID: 20250726_215148
params 17695




[1/5] train MAE:1.3717 | val MAE:1.0425 || train MSE:3.8119 | val MSE:2.1939
[2/5] train MAE:1.3013 | val MAE:1.1405 || train MSE:3.7406 | val MSE:2.5144
[3/5] train MAE:1.3006 | val MAE:1.0161 || train MSE:3.7640 | val MSE:2.1072
[4/5] train MAE:1.2782 | val MAE:1.0175 || train MSE:3.6054 | val MSE:2.1122
[5/5] train MAE:1.2581 | val MAE:1.0355 || train MSE:3.5746 | val MSE:2.1833


In [4]:
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

# ================================================================
# 0) dashboard  – tweak here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    num_neighbors=8, hood_k=100,                      # <── hood size
    aggregator='linear', use_rbf=True, use_attn=False,
    use_nconv=False, use_pred_head=True,
    norm_coors=True,

    loss_type='mae', sched_metric='val_mae', study_metrics=True,
    lr=5e-3, epochs=5, batch_size=1, device='cpu',
    seed=0, analysis_mode=False, save_attn=False,      # <── new switch
    num_paths=4,                                   # cap #files
    split_mode='random', split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed);  np.random.seed(cfg.seed);  torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) helper: split + dataset + collate
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class InMemDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k, algorithm='brute')
        for p in paths:
            try:
                d=np.load(p, allow_pickle=True)
                z,pos,sites,y=d['z'],d['pos'],d['sites'],d['pks']
                if len(sites)==0: continue
                nbr.fit(pos); idx=nbr.kneighbors(sites,return_distance=False)
                self.data.append((torch.from_numpy(z[idx]),
                                  torch.from_numpy(pos[idx]),
                                  torch.from_numpy(y)))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,pos,y=self.data[i]
        return (z,pos,y,self.ids[i])

def pad(batch, k, device, return_ids):
    if return_ids: ids=[b[3] for b in batch]
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zs=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pos=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    ys=torch.full((B,S),float('nan'),device=device)
    mask=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zs[b,:s]=z.to(device); pos[b,:s]=p.to(device)
        ys[b,:s]=y.to(device); mask[b,:s]=True
    return (zs,pos,ys,mask,ids) if return_ids else (zs,pos,ys,mask)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model (unchanged architecture.py)
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
class Model(nn.Module):
    def __init__(self,c):
        super().__init__()
        self.egnn=StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                              c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)
        self.rbf =TunableBlock(LearnableRBF(c.basis,10.).to(c.device),c.use_rbf)
        self.attn=TunableBlock(AttentionBlock(c.dim+c.basis,c.dim+c.basis,c.hidden_dim)\
                               .to(c.device),c.use_attn)
        if c.aggregator.startswith('nconv'):
            self.nconv=nn.Conv1d(c.hood_k,1,c.dim+c.basis).to(c.device)
            out=1
        else: self.nconv=None; out=c.dim+c.basis
        self.head=nn.Linear(out,1).to(c.device) if c.use_pred_head else nn.Identity()
        self.prot=EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device)
    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]
        cent=coord.mean(1,keepdim=True)
        r=self.rbf(cent,coord) if cfg.use_rbf else h.new_zeros(h.size(0),cfg.hood_k,cfg.basis)
        tok=torch.cat((r.transpose(1,2),h.transpose(1,2)),1)
        tok,_=self.attn(tok.permute(2,0,1)); tok=tok.permute(1,0,2)
        tok=self.nconv(tok).squeeze(-1) if self.nconv is not None else tok.max(1).values
        p=self.head(tok); p=self.prot(p.unsqueeze(0),cent.permute(1,0,2))[0].squeeze(0)
        return p
model=Model(cfg); print("params",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
tr_ds=InMemDS(tr,cfg.hood_k); val_ds=InMemDS(val,cfg.hood_k)
collate=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(tr_ds,batch_size=cfg.batch_size,shuffle=True,collate_fn=collate)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=collate)

# ================================================================
# 5) training util
# ================================================================
primary_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
secondary_fn= nn.MSELoss() if cfg.study_metrics and cfg.loss_type=='mae' else \
              nn.L1Loss() if cfg.study_metrics else None
pname='MAE' if cfg.loss_type=='mae' else 'MSE'; sname='MSE' if pname=='MAE' else 'MAE'
opt=torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval()
    ps=ss=0;n=0
    for batch in loader:
        if cfg.analysis_mode: z,x,y,m,ids=batch
        else:                 z,x,y,m=batch
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=primary_fn(pred,y)
            sec=secondary_fn(pred,y).item() if secondary_fn else 0.0
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        ps+=loss.item(); ss+=sec; n+=1
    return ps/n, (ss/n if secondary_fn else None)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False)
    sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {pname}:{tr_p:.4f} | val {pname}:{va_p:.4f}"
    if secondary_fn: msg+=f" || train {sname}:{tr_s:.4f} | val {sname}:{va_s:.4f}"
    print(msg)


Run‑ID: 20250726_215321
params 17695




ValueError: too many values to unpack (expected 2)

In [15]:
# ================================================================
# 0) dashboard  – change anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    num_neighbors=8, hood_k=100, norm_coors=True,
    # blocks
    aggregator='linear', use_rbf=False, use_attn=False,
    use_nconv=True, use_pred_head=False,
    # training
    loss_type='mae', study_metrics=True, lr=5e-3, epochs=3, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # reproducibility & misc
    seed=0, analysis_mode=False, save_attn=False,
    num_paths=2,
    split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
# ================================================================
# 3)  fully‑tunable Model (drop‑in replacement)
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn, torch

class Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.c = cfg                          # keep a local copy

        # ---------- residue‑level EGNN backbone -------------------
        self.egnn = StackedEGNN(
            cfg.dim, cfg.depth, cfg.hidden_dim, cfg.dropout,
            cfg.hood_k, 98, cfg.num_neighbors, cfg.get('norm_coors', True)
        ).to(cfg.device)

        # ---------- optional blocks --------------------------------
        self.rbf  = TunableBlock(LearnableRBF(cfg.basis, 10.).to(cfg.device),
                                 cfg.use_rbf)
        self.attn = TunableBlock(AttentionBlock(cfg.dim + cfg.basis,
                                                cfg.dim + cfg.basis,
                                                cfg.hidden_dim).to(cfg.device),
                                 cfg.use_attn)

        # ---------- neighbour → scalar aggregators -----------------
        C = cfg.dim + cfg.basis                        # feature channels
        if cfg.aggregator == 'linear':
            self.agg = nn.Linear(C, 1).to(cfg.device)
        elif cfg.aggregator == 'nconv':
            # in‑channels = hood_k, kernel = C   (N,C) layout
            self.agg = nn.Conv1d(cfg.hood_k, 1, kernel_size=C, padding=0).to(cfg.device)
        elif cfg.aggregator == 'pool':                 # max‑pool *without* linear
            self.agg = None
        else:
            raise ValueError("aggregator must be 'linear', 'nconv', or 'pool'")

        # ---------- optional boost  (pred_head2) -------------------
        self.boost = nn.Linear(1, 1).to(cfg.device) if cfg.get('use_boost', False) else nn.Identity()

        # ---------- protein‑level EGNN & final conv ----------------
        self.use_prot = cfg.get('use_prot_egnn', True)
        self.prot = (EGNN(dim=1, update_coors=True, num_nearest_neighbors=3)
                     .to(cfg.device)) if self.use_prot else nn.Identity()

        self.use_conv = cfg.get('use_conv', False)
        if self.use_conv:
            k = cfg.get('conv_kernel', 7)
            self.conv = nn.Conv1d(1, 1, k, padding=k // 2).to(cfg.device)
        else:
            self.conv = nn.Identity()

    # --------------------------------------------------------------
    def forward(self, z, x):                       # z:(R,N) , x:(R,N,3)
        h, coord = self.egnn(z, x); h = h[0]       # (R,N,dim)

        cent = coord.mean(1, keepdim=True)         # (R,1,3)
        if self.c.use_rbf:
            r = self.rbf(cent, coord).transpose(1, 2)          # (R,basis,N)
        else:
            r = h.new_zeros(h.size(0), self.c.basis, self.c.hood_k)

        tok = torch.cat((r, h.transpose(1, 2)), 1)             # (R,C,N)

        # ---- attention (tuple‑safe) ------------------------------
        out = self.attn(tok.permute(2, 0, 1))
        tok = out[0] if isinstance(out, (tuple, list)) else out
        tok = tok.permute(1, 0, 2)                              # (R,N,C)

        # ---- aggregation paths -----------------------------------
        if self.c.aggregator == 'linear':
            pooled = tok.max(1).values                          # (R,C)
            preds  = self.agg(pooled)                           # (R,1)
        elif self.c.aggregator == 'nconv':
            preds  = self.agg(tok).squeeze(-1)                  # (R,1)
        else:                                                   # 'pool'
            preds  = tok.max(1).values.unsqueeze(-1)            # (R,1)

        # optional boost
        preds = self.boost(preds)

        # protein‑level EGNN
        if self.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1, 0, 2))[0].squeeze(0)

        # final conv (1×k across residues)
        if self.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T  # (R,1)

        return preds

model=Model(cfg)
cfg.update(dict(
    use_rbf       = False,
    use_attn      = False,
    aggregator    = 'nconv',      # 'linear', 'nconv', 'pool'
    use_boost     = False,        # pred_head2
    use_prot_egnn = True,
    use_conv      = False
))
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn=nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
s_fn=nn.MSELoss() if cfg.study_metrics and cfg.loss_type=='mae' else \
     nn.L1Loss() if cfg.study_metrics else None
pname='MAE' if cfg.loss_type=='mae' else 'MSE'; sname='MSE' if pname=='MAE' else 'MAE'
opt=torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval(); P=S=0;n=0
    for batch in loader:
        z,x,y,m,*rest=batch        # rest = [ids] or []
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y); sec=s_fn(pred,y).item() if s_fn else 0.
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        P+=loss.item(); S+=sec; n+=1
    return P/n, (S/n if s_fn else None)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False); sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {pname}:{tr_p:.4f} | val {pname}:{va_p:.4f}"
    if s_fn: msg+=f" || train {sname}:{tr_s:.4f} | val {sname}:{va_s:.4f}"
    print(msg)


Run‑ID: 20250726_221026
params: 17695




RuntimeError: mat1 and mat2 shapes cannot be multiplied (270x201 and 3x6)

In [13]:
# ================================================================
# 0) dashboard  – change anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    num_neighbors=8, hood_k=100, norm_coors=True,
    # blocks
    aggregator='linear', use_rbf=False, use_attn=False,
    use_nconv=True, use_pred_head=False,
    # training
    loss_type='mae', study_metrics=True, lr=5e-3, epochs=3, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # reproducibility & misc
    seed=0, analysis_mode=False, save_attn=False,
    num_paths=2,
    split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
# ================================================================
# 3)  fully‑tunable Model
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.c = cfg                    # keep a local handle

        # ── residue‑level EGNN backbone
        self.egnn = StackedEGNN(cfg.dim, cfg.depth, cfg.hidden_dim, cfg.dropout,
                                cfg.hood_k, 98, cfg.num_neighbors,
                                cfg.get('norm_coors', True)
                               ).to(cfg.device)

        # ── optional blocks
        self.rbf  = TunableBlock(LearnableRBF(cfg.basis, 10.).to(cfg.device),
                                 cfg.use_rbf)

        self.attn = TunableBlock(AttentionBlock(cfg.dim+cfg.basis,
                                                cfg.dim+cfg.basis,
                                                cfg.hidden_dim).to(cfg.device),
                                 cfg.use_attn)

        # ── neighbours → scalar aggregator
        if cfg.aggregator.startswith('nconv'):
            self.nconv = nn.Conv1d(cfg.hood_k, 1,
                                   kernel_size=cfg.dim+cfg.basis,
                                   padding=0).to(cfg.device)
            agg_out = 1
        else:
            self.nconv = None
            agg_out    = cfg.dim + cfg.basis

        if cfg.use_pred_head:
            self.head = nn.Linear(agg_out, 1).to(cfg.device)
        else:
            self.head = nn.Identity()

        # ── protein‑level geometry & optional 1‑D conv
        self.use_prot = cfg.get('use_prot_egnn', True)
        self.prot = EGNN(dim=1, update_coors=True,
                         num_nearest_neighbors=3).to(cfg.device) \
                         if self.use_prot else nn.Identity()

        self.use_conv = cfg.get('use_conv', False)
        if self.use_conv:
            k = cfg.get('conv_kernel', 7)
            self.conv = nn.Conv1d(1, 1, k, padding=k//2).to(cfg.device)
        else:
            self.conv = nn.Identity()

    # -----------------------------------------------------------
    def forward(self, z, x):                       # z:(R,N)  x:(R,N,3)
        h, coord = self.egnn(z, x); h = h[0]       # h:(R,N,dim)

        cent = coord.mean(1, keepdim=True)         # (R,1,3)
        if self.c.use_rbf:
            r = self.rbf(cent, coord).transpose(1, 2)   # (R, basis, N)
        else:
            r = h.new_zeros(h.size(0), self.c.basis, self.c.hood_k)

        tok = torch.cat((r, h.transpose(1, 2)), 1)      # (R, C, N)

        # ---- attention (handles tuple vs tensor automatically)
        attn_out = self.attn(tok.permute(2, 0, 1))
        tok = attn_out[0] if isinstance(attn_out, (tuple, list)) else attn_out
        tok = tok.permute(1, 0, 2)                      # (R, N, C)

        # ---- aggregator paths
        if self.nconv is not None:                      # nconv or nconv+linear
            tok = self.nconv(tok).squeeze(-1)           # (R, 1)
        elif self.c.aggregator == 'pool':               # max‑pool
            tok = tok.max(1).values                     # (R, C)
        else:                                           # linear
            tok = tok.max(1).values                     # (R, C)

        preds = self.head(tok)                          # (R, 1) or unchanged

        # ---- protein‑level EGNN (optional)
        if self.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1, 0, 2))[0].squeeze(0)

        # ---- final 1‑D conv over residues (optional)
        if self.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T  # (R,1)

        return preds

print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn=nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
s_fn=nn.MSELoss() if cfg.study_metrics and cfg.loss_type=='mae' else \
     nn.L1Loss() if cfg.study_metrics else None
pname='MAE' if cfg.loss_type=='mae' else 'MSE'; sname='MSE' if pname=='MAE' else 'MAE'
opt=torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval(); P=S=0;n=0
    for batch in loader:
        z,x,y,m,*rest=batch        # rest = [ids] or []
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y); sec=s_fn(pred,y).item() if s_fn else 0.
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        P+=loss.item(); S+=sec; n+=1
    return P/n, (S/n if s_fn else None)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False); sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {pname}:{tr_p:.4f} | val {pname}:{va_p:.4f}"
    if s_fn: msg+=f" || train {sname}:{tr_s:.4f} | val {sname}:{va_s:.4f}"
    print(msg)


Run‑ID: 20250726_220313
params: 17676




RuntimeError: mat1 and mat2 shapes cannot be multiplied (270x37 and 3x6)

In [None]:
model=Model(cfg)
cfg.update(dict(
    use_rbf       = False,
    use_attn      = False,
    aggregator    = 'nconv',      # 'linear', 'nconv', 'pool'
    use_boost     = False,        # pred_head2
    use_prot_egnn = True,
    use_conv      = False
))

In [1]:
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader


  from .autonotebook import tqdm as notebook_tqdm


In [22]:
# ================================================================
# 0) dashboard  – change anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    num_neighbors=8, hood_k=100, norm_coors=True,
    # blocks
    aggregator='linear', use_rbf=False, use_attn=False,
    use_nconv=True, use_pred_head=False,
    # training
    loss_type='mae', study_metrics=True, lr=5e-3, epochs=3, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # reproducibility & misc
    seed=0, analysis_mode=False, save_attn=False,
    num_paths=2,
    split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
# ================================================================
# 3)  fully‑tunable Model (drop‑in replacement)
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn, torch

class Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.c = cfg                          # keep a local copy

        # ---------- residue‑level EGNN backbone -------------------
        self.egnn = StackedEGNN(
            cfg.dim, cfg.depth, cfg.hidden_dim, cfg.dropout,
            cfg.hood_k, 98, cfg.num_neighbors, cfg.get('norm_coors', True)
        ).to(cfg.device)

        # ---------- optional blocks --------------------------------
        self.rbf  = TunableBlock(LearnableRBF(cfg.basis, 10.).to(cfg.device),
                                 cfg.use_rbf)
        self.attn = TunableBlock(AttentionBlock(cfg.dim + cfg.basis,
                                                cfg.dim + cfg.basis,
                                                cfg.hidden_dim).to(cfg.device),
                                 cfg.use_attn)

        # ---------- neighbour → scalar aggregators -----------------
        C = cfg.dim + cfg.basis                        # feature channels
        if cfg.aggregator == 'linear':
            self.agg = nn.Linear(C, 1).to(cfg.device)
        elif cfg.aggregator == 'nconv':
            # in‑channels = hood_k, kernel = C   (N,C) layout
            self.agg = nn.Conv1d(cfg.hood_k, 1, kernel_size=C, padding=0).to(cfg.device)
        elif cfg.aggregator == 'pool':                 # max‑pool *without* linear
            self.agg = None
        else:
            raise ValueError("aggregator must be 'linear', 'nconv', or 'pool'")

        # ---------- optional boost  (pred_head2) -------------------
        self.boost = nn.Linear(1, 1).to(cfg.device) if cfg.get('use_boost', False) else nn.Identity()

        # ---------- protein‑level EGNN & final conv ----------------
        self.use_prot = cfg.get('use_prot_egnn', True)
        self.prot = (EGNN(dim=1, update_coors=True, num_nearest_neighbors=3)
                     .to(cfg.device)) if self.use_prot else nn.Identity()

        self.use_conv = cfg.get('use_conv', False)
        if self.use_conv:
            k = cfg.get('conv_kernel', 7)
            self.conv = nn.Conv1d(1, 1, k, padding=k // 2).to(cfg.device)
        else:
            self.conv = nn.Identity()

    # --------------------------------------------------------------
    def forward(self, z, x):                       # z:(R,N) , x:(R,N,3)
        h, coord = self.egnn(z, x); h = h[0]       # (R,N,dim)

        cent = coord.mean(1, keepdim=True)         # (R,1,3)
        if self.c.use_rbf:
            r = self.rbf(cent, coord).transpose(1, 2)          # (R,basis,N)
        else:
            r = h.new_zeros(h.size(0), self.c.basis, self.c.hood_k)

        tok = torch.cat((r, h.transpose(1, 2)), 1)             # (R,C,N)

        # ---- attention (tuple‑safe) ------------------------------
        out = self.attn(tok.permute(2, 0, 1))
        tok = out[0] if isinstance(out, (tuple, list)) else out
        tok = tok.permute(1, 0, 2)                              # (R,N,C)

        # ---- aggregation paths -----------------------------------
        if self.c.aggregator == 'linear':
            pooled = tok.max(1).values                          # (R,C)
            preds  = self.agg(pooled)                           # (R,1)
        elif self.c.aggregator == 'nconv':
            preds  = self.agg(tok).squeeze(-1)                  # (R,1)
        else:                                                   # 'pool'
            preds  = tok.max(1).values.unsqueeze(-1)            # (R,1)

        # optional boost
        preds = self.boost(preds)

        # protein‑level EGNN
        if self.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1, 0, 2))[0].squeeze(0)

        # final conv (1×k across residues)
        if self.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T  # (R,1)

        return preds
cfg.update(dict(
    use_rbf       = False,
    use_attn      = False,
    aggregator    = 'linear',   # ← decide before instantiating
    use_boost     = False,
    use_prot_egnn = True,
    use_conv      = False,
))

model=Model(cfg)

print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn=nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
s_fn=nn.MSELoss() if cfg.study_metrics and cfg.loss_type=='mae' else \
     nn.L1Loss() if cfg.study_metrics else None
pname='MAE' if cfg.loss_type=='mae' else 'MSE'; sname='MSE' if pname=='MAE' else 'MAE'
opt=torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval(); P=S=0;n=0
    for batch in loader:
        z,x,y,m,*rest=batch        # rest = [ids] or []
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y); sec=s_fn(pred,y).item() if s_fn else 0.
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        P+=loss.item(); S+=sec; n+=1
    return P/n, (S/n if s_fn else None)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False); sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {pname}:{tr_p:.4f} | val {pname}:{va_p:.4f}"
    if s_fn: msg+=f" || train {sname}:{tr_s:.4f} | val {sname}:{va_s:.4f}"
    print(msg)


Run‑ID: 20250726_221357
params: 17695




[1/3] train MAE:1.5816 | val MAE:1.2734 || train MSE:4.3669 | val MSE:2.8816
[2/3] train MAE:1.4030 | val MAE:1.1154 || train MSE:3.7351 | val MSE:2.4087
[3/3] train MAE:1.2910 | val MAE:1.0252 || train MSE:3.3240 | val MSE:2.1439


In [None]:
# ================================================================
# 0) dashboard  – change anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    num_neighbors=8, hood_k=100, norm_coors=True,
    # blocks
    aggregator='pool', use_rbf=False, use_attn=False,
    use_nconv=True, use_pred_head=False,
    # training
    loss_type='mae', study_metrics=True, lr=5e-3, epochs=3, batch_size=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # reproducibility & misc
    seed=0, analysis_mode=False, save_attn=False,
    num_paths=2,
    split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
# ================================================================
# 3)  fully‑tunable Model (drop‑in replacement)
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn, torch

class Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.c = cfg                          # keep a local copy

        # ---------- residue‑level EGNN backbone -------------------
        self.egnn = StackedEGNN(
            cfg.dim, cfg.depth, cfg.hidden_dim, cfg.dropout,
            cfg.hood_k, 98, cfg.num_neighbors, cfg.get('norm_coors', True)
        ).to(cfg.device)

        # ---------- optional blocks --------------------------------
        self.rbf  = TunableBlock(LearnableRBF(cfg.basis, 10.).to(cfg.device),
                                 cfg.use_rbf)
        self.attn = TunableBlock(AttentionBlock(cfg.dim + cfg.basis,
                                                cfg.dim + cfg.basis,
                                                cfg.hidden_dim).to(cfg.device),
                                 cfg.use_attn)

        # ---------- neighbour → scalar aggregators -----------------
        C = cfg.dim + cfg.basis   
        
                             # feature channels
        # ---- aggregation paths -----------------------------------------
        if self.c.aggregator == 'linear':
            # (R, C, N)  →  (R, N, C) so Linear sees last‑dim = C
            per_neigh = self.agg(tok.permute(0, 2, 1))     # (R, N, 1)
            preds     = per_neigh.max(1).values            # (R, 1)

        elif self.c.aggregator == 'nconv':
            # nconv expects (R, N, C) layout inside Conv1d
            preds = self.agg(tok).squeeze(-1)              # (R, 1)

        elif self.c.aggregator == 'pool':                  # ablation
            preds = tok.max(1).values.unsqueeze(-1)        # (R, 1)

        else:
            raise ValueError("aggregator must be 'linear', 'nconv', or 'pool'")

        if cfg.aggregator == 'linear':
            self.agg = nn.Linear(C, 1).to(cfg.device)
        elif cfg.aggregator == 'nconv':
            # in‑channels = hood_k, kernel = C   (N,C) layout
            self.agg = nn.Conv1d(cfg.hood_k, 1, kernel_size=C, padding=0).to(cfg.device)
        elif cfg.aggregator == 'pool':                 # max‑pool *without* linear
            self.agg = None
        else:
            raise ValueError("aggregator must be 'linear', 'nconv', or 'pool'")

        # ---------- optional boost  (pred_head2) -------------------
        self.boost = nn.Linear(1, 1).to(cfg.device) if cfg.get('use_boost', False) else nn.Identity()

        # ---------- protein‑level EGNN & final conv ----------------
        self.use_prot = cfg.get('use_prot_egnn', True)
        self.prot = (EGNN(dim=1, update_coors=True, num_nearest_neighbors=3)
                     .to(cfg.device)) if self.use_prot else nn.Identity()

        self.use_conv = cfg.get('use_conv', False)
        if self.use_conv:
            k = cfg.get('conv_kernel', 7)
            self.conv = nn.Conv1d(1, 1, k, padding=k // 2).to(cfg.device)
        else:
            self.conv = nn.Identity()

    # --------------------------------------------------------------
    def forward(self, z, x):                       # z:(R,N) , x:(R,N,3)
        h, coord = self.egnn(z, x); h = h[0]       # (R,N,dim)

        cent = coord.mean(1, keepdim=True)         # (R,1,3)
        if self.c.use_rbf:
            r = self.rbf(cent, coord).transpose(1, 2)          # (R,basis,N)
        else:
            r = h.new_zeros(h.size(0), self.c.basis, self.c.hood_k)

        tok = torch.cat((r, h.transpose(1, 2)), 1)             # (R,C,N)

        # ---- attention (tuple‑safe) ------------------------------
        out = self.attn(tok.permute(2, 0, 1))
        tok = out[0] if isinstance(out, (tuple, list)) else out
        tok = tok.permute(1, 0, 2)                              # (R,N,C)

        # ---- aggregation paths -----------------------------------
        if self.c.aggregator == 'linear':
            pooled = tok.max(1).values                          # (R,C)
            preds  = self.agg(pooled)                           # (R,1)
        elif self.c.aggregator == 'nconv':
            preds  = self.agg(tok).squeeze(-1)                  # (R,1)
        else:                                                   # 'pool'
            preds  = tok.max(1).values.unsqueeze(-1)            # (R,1)

        # optional boost
        preds = self.boost(preds)

        # protein‑level EGNN
        if self.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1, 0, 2))[0].squeeze(0)

        # final conv (1×k across residues)
        if self.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T  # (R,1)

        return preds
cfg.update(dict(
    use_rbf       = False,
    use_attn      = False,
    aggregator    = 'linear',   # ← decide before instantiating
    use_boost     = False,
    use_prot_egnn = True,
    use_conv      = False,
))

model=Model(cfg)

print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn=nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
s_fn=nn.MSELoss() if cfg.study_metrics and cfg.loss_type=='mae' else \
     nn.L1Loss() if cfg.study_metrics else None
pname='MAE' if cfg.loss_type=='mae' else 'MSE'; sname='MSE' if pname=='MAE' else 'MAE'
opt=torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval(); P=S=0;n=0
    for batch in loader:
        z,x,y,m,*rest=batch        # rest = [ids] or []
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y); sec=s_fn(pred,y).item() if s_fn else 0.
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        P+=loss.item(); S+=sec; n+=1
    return P/n, (S/n if s_fn else None)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False); sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {pname}:{tr_p:.4f} | val {pname}:{va_p:.4f}"
    if s_fn: msg+=f" || train {sname}:{tr_s:.4f} | val {sname}:{va_s:.4f}"
    print(msg)


Run‑ID: 20250726_221734
params: 17695




[1/3] train MAE:1.5816 | val MAE:1.2734 || train MSE:4.3669 | val MSE:2.8816
[2/3] train MAE:1.4030 | val MAE:1.1154 || train MSE:3.7351 | val MSE:2.4087
[3/3] train MAE:1.2910 | val MAE:1.0252 || train MSE:3.3240 | val MSE:2.1439


In [29]:
# ================================================================
# 0) dashboard – flip flags here, nothing else
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,

    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =False,
    use_boost    =False,      # Linear(1→1) after aggregator
    use_prot     =True,       # protein‑level EGNN
    use_conv     =False,      # 1‑D conv after prot EGNN
    conv_kernel  =7,          # kernel size for the conv

    # training
    loss_type='mae', study_metrics=True,
    lr=5e-3, epochs=3, batch_size=1,

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis                       # per‑atom channel count

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = (nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device)
                      if c.use_conv else nn.Identity())

    # -----------------------------------------------------------
    def forward(self,z,x):                      # z:(R,N)  x:(R,N,3)
        h,coord=self.egnn(z,x); h=h[0]          # h:(R,N,dim)
        cent=coord.mean(1,keepdim=True)         # (R,1,3)

        if self.c.use_rbf:
            r=self.rbf(cent,coord).transpose(1,2)            # (R,basis,N)
        else:
            r=h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)

        tok=torch.cat((r,h.transpose(1,2)),1)                # (R,C,N)

        attn_out=self.attn(tok.permute(2,0,1))
        tok     = attn_out[0] if isinstance(attn_out,(list,tuple)) else attn_out
        tok     = tok.permute(1,0,2)                         # (R,N,C)

        # --- aggregation --------------------------------------
        if self.c.aggregator=='linear':
            per_neigh=self.agg(tok)                          # (R,N,1)
            preds    = per_neigh.max(1).values               # (R,1)
        elif self.c.aggregator=='nconv':
            preds=self.agg(tok.transpose(1,2)).squeeze(-1)   # (R,1)
        else:  # 'pool'
            preds=tok.max(1).values.unsqueeze(-1)            # (R,1)

        preds=self.boost(preds)                              # optional boost

        if self.c.use_prot:
            preds=self.prot(preds.unsqueeze(0),
                            cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds=self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds
model=Model(cfg); print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn=nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
s_fn=nn.MSELoss() if cfg.study_metrics and cfg.loss_type=='mae' else \
     nn.L1Loss() if cfg.study_metrics else None
opt=torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval(); P=S=0;n=0
    for batch in loader:
        z,x,y,m,*_ = batch
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
            sec=s_fn(pred,y).item() if s_fn else 0.
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        P+=loss.item(); S+=sec; n+=1
    return P/n, (S/n if s_fn else None)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False); sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {cfg.loss_type.upper()}:{tr_p:.4f} | " \
        f"val {cfg.loss_type.upper()}:{va_p:.4f}"
    if s_fn: msg+=f" || train other:{tr_s:.4f} | val other:{va_s:.4f}"
    print(msg)


Run‑ID: 20250726_222332
params: 17695




[1/3] train MAE:1.5971 | val MAE:1.2026 || train other:3.9319 | val other:2.4462
[2/3] train MAE:1.4604 | val MAE:1.0831 || train other:3.5514 | val other:2.2081
[3/3] train MAE:1.3527 | val MAE:1.0277 || train other:3.2811 | val other:2.0960


In [31]:
cfg.aggregator

'linear'

In [None]:
cf.

In [35]:
# n‑conv without RBF / Attn
cfg.aggregator = 'linear'
cfg.use_rbf = cfg.use_attn = False
cfg.use_boost = False
# re‑instantiate:
model = Model(cfg)


for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False); sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {cfg.loss_type.upper()}:{tr_p:.4f} | " \
        f"val {cfg.loss_type.upper()}:{va_p:.4f}"
    if s_fn: msg+=f" || train other:{tr_s:.4f} | val other:{va_s:.4f}"
    print(msg)




[1/3] train MAE:1.2685 | val MAE:1.0217 || train other:3.1997 | val other:2.0759
[2/3] train MAE:1.2685 | val MAE:1.0217 || train other:3.1997 | val other:2.0759
[3/3] train MAE:1.2685 | val MAE:1.0217 || train other:3.1997 | val other:2.0759


In [34]:
# n‑conv without RBF / Attn
cfg.aggregator = 'linear'
cfg.use_rbf = cfg.use_attn = True
cfg.use_boost = True
# re‑instantiate:
model = Model(cfg)


for e in range(cfg.epochs):
    tr_p,tr_s=run(tr_loader,True)
    va_p,va_s=run(va_loader,False); sch.step(va_p)
    msg=f"[{e+1}/{cfg.epochs}] train {cfg.loss_type.upper()}:{tr_p:.4f} | " \
        f"val {cfg.loss_type.upper()}:{va_p:.4f}"
    if s_fn: msg+=f" || train other:{tr_s:.4f} | val other:{va_s:.4f}"
    print(msg)




[1/3] train MAE:1.2776 | val MAE:1.0904 || train other:3.3217 | val other:2.3500
[2/3] train MAE:1.2782 | val MAE:1.0904 || train other:3.3261 | val other:2.3500
[3/3] train MAE:1.2790 | val MAE:1.0904 || train other:3.3255 | val other:2.3500


In [36]:
# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,

    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =False,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae', study_metrics=False,
    lr=5e-3, epochs=3, batch_size=1,

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.8, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)

# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================
from architecture import StackedEGNN, LearnableRBF, AttentionBlock, TunableBlock
from egnn_pytorch import EGNN
import torch.nn as nn
class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator=='linear':
            preds = self.agg(tok) .max(1).values                # (R,1)
        elif self.c.aggregator=='nconv':
            preds = self.agg(tok).squeeze(-1)                   # (R,1)
        else:   # pool
            preds = tok.max(1).values.mean(1,keepdim=True)      # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds
model=Model(cfg); print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        loss_sum+=loss.item(); n+=1
    return loss_sum/n

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(tr_loader,True)
    va=run(va_loader,False); sch.step(va)
    print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")


Run‑ID: 20250726_222839
params: 17695




[1/3]  train 1.5971 | val 1.2026
[2/3]  train 1.4604 | val 1.0831
[3/3]  train 1.3527 | val 1.0277
