In [None]:
#winner
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__


# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime


# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================

class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator=='linear':
            preds = self.agg(tok) .max(1).values                # (R,1)
        elif self.c.aggregator=='nconv':
            preds = self.agg(tok).squeeze(-1)                   # (R,1)
        else:   # pool
            preds = tok.max(1).values.mean(1,keepdim=True)      # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,

    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='nconv',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, epochs=10, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=20, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

  from .autonotebook import tqdm as notebook_tqdm


Run‑ID: 20250727_004152
params: 19477




[1/10]  train mae 2.0621 | mae val 1.2336
     additional metrics:  [1/10]  train 7.2658 | val 2.8827



KeyboardInterrupt: 

In [8]:
#winner
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__


# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime


# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================

class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator=='linear':
            preds = self.agg(tok) .max(1).values                # (R,1)
        elif self.c.aggregator=='nconv':

            preds = self.agg(tok).squeeze(-1)                   # (R,1)
            #print(preds.shape,to)
        else:   # pool
            preds = tok.max(1).values.mean(1,keepdim=True)      # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,

    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, epochs=10, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=20, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

Run‑ID: 20250727_005810
params: 17695




[1/10]  train mae 1.2886 | mae val 1.1465
     additional metrics:  [1/10]  train 3.5299 | val 3.1183



KeyboardInterrupt: 

In [4]:
# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")




[batch 0  rot 0]  max|Δ|=1.550e-06
[batch 0  rot 1]  max|Δ|=1.788e-06
[batch 0  rot 2]  max|Δ|=1.431e-06
[batch 0  rot 3]  max|Δ|=1.788e-06
[batch 0  rot 4]  max|Δ|=1.907e-06
[batch 0  perm‑K ]  max|Δ|=2.034e+00
[batch 0  perm‑R ]  max|Δ|=2.384e-07
-------------------------------------------------------
[batch 1  rot 0]  max|Δ|=2.623e-06
[batch 1  rot 1]  max|Δ|=2.861e-06
[batch 1  rot 2]  max|Δ|=3.457e-06
[batch 1  rot 3]  max|Δ|=2.265e-06
[batch 1  rot 4]  max|Δ|=2.265e-06
[batch 1  perm‑K ]  max|Δ|=1.887e+00
[batch 1  perm‑R ]  max|Δ|=0.000e+00
-------------------------------------------------------
[batch 2  rot 0]  max|Δ|=2.384e-06
[batch 2  rot 1]  max|Δ|=2.146e-06
[batch 2  rot 2]  max|Δ|=2.623e-06
[batch 2  rot 3]  max|Δ|=1.848e-06
[batch 2  rot 4]  max|Δ|=2.623e-06
[batch 2  perm‑K ]  max|Δ|=2.301e+00
[batch 2  perm‑R ]  max|Δ|=0.000e+00
-------------------------------------------------------

----------------  summary (max abs error) ----------------
eqv   : 3.457e-06
permK

In [12]:
#winner
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__


# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime


# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================

class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator=='linear':
            preds = self.agg(tok) .max(1).values                # (R,1)
        elif self.c.aggregator=='nconv':

            preds = self.agg(tok).squeeze(-1)                   # (R,1)
            #print(preds.shape,to)
        else:   # pool
            preds = tok.max(1).values.mean(1,keepdim=True)      # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=1,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =False,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


Run‑ID: 20250727_010144
params: 17695




[1/1]  train mae 1.6055 | mae val 1.2054
     additional metrics:  [1/1]  train 3.9628 | val 2.4540

0.7018108367919922 sec


[batch 0  rot 0]  max|Δ|=9.537e-07
[batch 0  rot 1]  max|Δ|=9.537e-07
[batch 0  rot 2]  max|Δ|=9.537e-07
[batch 0  rot 3]  max|Δ|=1.311e-06
[batch 0  rot 4]  max|Δ|=7.153e-07
[batch 0  perm‑K ]  max|Δ|=5.068e-01
[batch 0  perm‑R ]  max|Δ|=0.000e+00
-------------------------------------------------------

----------------  summary (max abs error) ----------------
eqv   : 1.311e-06
permK : 5.068e-01
permR : 0.000e+00
----------------------------------------------------------
✔  thresholds used: atol=5.0e-04, rtol=5.0e-04



In [15]:
#winner
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__


# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime


# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================

class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator=='linear':
            preds = self.agg(tok) .max(1).values                # (R,1)
        elif self.c.aggregator=='nconv':

            preds = self.agg(tok).squeeze(-1)                   # (R,1)
            #print(preds.shape,to)
        else:   # pool
            preds = tok.max(1).values.mean(1,keepdim=True)      # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=1,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =True,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


Run‑ID: 20250727_010526
params: 16367




[1/1]  train mae 1.5971 | mae val 1.2055
     additional metrics:  [1/1]  train 3.9318 | val 2.4525

0.7177050113677979 sec


[batch 0  rot 0]  max|Δ|=0.000e+00
[batch 0  rot 1]  max|Δ|=0.000e+00
[batch 0  rot 2]  max|Δ|=0.000e+00
[batch 0  rot 3]  max|Δ|=0.000e+00
[batch 0  rot 4]  max|Δ|=0.000e+00
[batch 0  perm‑K ]  max|Δ|=3.877e-01
[batch 0  perm‑R ]  max|Δ|=0.000e+00
-------------------------------------------------------

----------------  summary (max abs error) ----------------
eqv   : 0.000e+00
permK : 3.877e-01
permR : 0.000e+00
----------------------------------------------------------
✔  thresholds used: atol=5.0e-04, rtol=5.0e-04



In [16]:
#winner
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__


# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime


# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================

class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator=='linear':
            preds = self.agg(tok) .max(1).values                # (R,1)
        elif self.c.aggregator=='nconv':

            preds = self.agg(tok).squeeze(-1)                   # (R,1)
            #print(preds.shape,to)
        else:   # pool
            preds = tok.max(1).values.mean(1,keepdim=True)      # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=1,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =True,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


Run‑ID: 20250727_010614
params: 16367




[1/1]  train mae 1.6956 | mae val 1.0550
     additional metrics:  [1/1]  train 4.2723 | val 2.1244

0.9071965217590332 sec


[batch 0  rot 0]  max|Δ|=1.192e-07
[batch 0  rot 1]  max|Δ|=1.192e-07
[batch 0  rot 2]  max|Δ|=1.192e-07
[batch 0  rot 3]  max|Δ|=5.960e-08
[batch 0  rot 4]  max|Δ|=1.192e-07
[batch 0  perm‑K ]  max|Δ|=4.585e-01
[batch 0  perm‑R ]  max|Δ|=0.000e+00
-------------------------------------------------------

----------------  summary (max abs error) ----------------
eqv   : 1.192e-07
permK : 4.585e-01
permR : 0.000e+00
----------------------------------------------------------
✔  thresholds used: atol=5.0e-04, rtol=5.0e-04



In [18]:
def sanity_check_knn(dataset, n_samples=5):
    for i in range(min(n_samples, len(dataset))):
        z, pos, y = dataset.data[i]          # shapes: (S,K), (S,K,3)
        sites = pos[:, 0]                       # central atoms (K=hood_k)
        for s in range(len(sites)):
            d = ((pos[s] - sites[s]).norm(dim=1))   # K distances
            if not torch.all(d[1:] >= d[:-1]):      # monotonic?
                print(f"❌ sample {i} site {s} not sorted!")
                return
    print("✔ KNN neighbours are distance‑sorted (checked {:d} samples)".format(n_samples))

sanity_check_knn(train_ds)


✔ KNN neighbours are distance‑sorted (checked 5 samples)


In [19]:
#winner
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__


# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime


# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================

class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator=='linear':
            preds = self.agg(tok) .max(1).values                # (R,1)
        elif self.c.aggregator=='nconv':

            preds = self.agg(tok).squeeze(-1)                   # (R,1)
            #print(preds.shape,to)
        else:   # pool
            preds = tok.max(1).values.mean(1,keepdim=True)      # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=1,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='nconv',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


Run‑ID: 20250727_011231
params: 19479




[1/1]  train mae 1.2278 | mae val 4.1253
     additional metrics:  [1/1]  train 3.2103 | val 18.9056

0.9689278602600098 sec


[batch 0  rot 0]  max|Δ|=2.384e-06
[batch 0  rot 1]  max|Δ|=2.384e-06
[batch 0  rot 2]  max|Δ|=1.907e-06
[batch 0  rot 3]  max|Δ|=2.384e-06
[batch 0  rot 4]  max|Δ|=2.384e-06
[batch 0  perm‑K ]  max|Δ|=5.460e-01
[batch 0  perm‑R ]  max|Δ|=0.000e+00
-------------------------------------------------------

----------------  summary (max abs error) ----------------
eqv   : 2.384e-06
permK : 5.460e-01
permR : 0.000e+00
----------------------------------------------------------
✔  thresholds used: atol=5.0e-04, rtol=5.0e-04



m_ij  = φ_e(h_i, h_j, ‖x_i – x_j‖)         # edge message
m_i   = Σ_j m_ij                            # SUM over neighbours
h_i'  = φ_h(h_i, m_i)                       # new features
x_i'  = x_i + φ_x(m_i) (co‑ordinate update)


In [1]:
#winner
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__


# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime


# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================

class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator == 'linear':
            preds = self.agg(tok).max(1).values                 # (R,1)
        elif self.c.aggregator == 'nconv':
            preds = self.agg(tok).squeeze(-1)                   # (R,1)
        else:                           # pool
            # THIS must be max over dim=2 (neighbours!), not dim=1
            preds = tok.max(2).values.mean(1, keepdim=True)     # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=1,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='pool',

    # block switches
    use_rbf      =False,
    use_attn     =False,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)

# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


  from .autonotebook import tqdm as notebook_tqdm


Run‑ID: 20250727_012131
params: 16348




[1/1]  train mae 2.0103 | mae val 1.6925
     additional metrics:  [1/1]  train 5.4504 | val 3.9141

0.729222297668457 sec


[batch 0  rot 0]  max|Δ|=1.192e-07
[batch 0  rot 1]  max|Δ|=1.192e-07
[batch 0  rot 2]  max|Δ|=1.192e-07
[batch 0  rot 3]  max|Δ|=0.000e+00
[batch 0  rot 4]  max|Δ|=1.192e-07
[batch 0  perm‑K ]  max|Δ|=8.344e-02
[batch 0  perm‑R ]  max|Δ|=0.000e+00
-------------------------------------------------------

----------------  summary (max abs error) ----------------
eqv   : 1.192e-07
permK : 8.344e-02
permR : 0.000e+00
----------------------------------------------------------
✔  thresholds used: atol=5.0e-04, rtol=5.0e-04



In [1]:
#winner
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__


# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime


# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================

class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator == 'linear':
            preds = self.agg(tok).max(1).values                 # (R,1)
        elif self.c.aggregator == 'nconv':
            preds = self.agg(tok).squeeze(-1)                   # (R,1)
        else:                           # pool
            # THIS must be max over dim=2 (neighbours!), not dim=1
            preds = tok.max(2).values.mean(1, keepdim=True)     # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=1,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='pool',

    # block switches
    use_rbf      =False,
    use_attn     =False,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

  from .autonotebook import tqdm as notebook_tqdm


Run‑ID: 20250727_012459
params: 16348




[1/1]  train mae 2.0103 | mae val 1.6925
     additional metrics:  [1/1]  train 5.4504 | val 3.9141

0.9247465133666992 sec


[batch 0  rot 0]  max|Δ|=1.192e-07
[batch 0  rot 1]  max|Δ|=1.192e-07
[batch 0  rot 2]  max|Δ|=1.192e-07
[batch 0  rot 3]  max|Δ|=0.000e+00
[batch 0  rot 4]  max|Δ|=1.192e-07
[batch 0  perm‑K ]  max|Δ|=8.344e-02
[batch 0  perm‑R ]  max|Δ|=0.000e+00
-------------------------------------------------------

----------------  summary (max abs error) ----------------
eqv   : 1.192e-07
permK : 8.344e-02
permR : 0.000e+00
----------------------------------------------------------
✔  thresholds used: atol=5.0e-04, rtol=5.0e-04



In [5]:
#winner
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn
import numpy as np
from egnn_pytorch import EGNN
from architecture import (StackedEGNN,
                          LearnableRBF,
                          AttentionBlock,
                          TunableBlock)
import time, datetime
import glob
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ================================================================
# 0) dashboard – flip anything here
# ================================================================
class Cfg(dict):
    __getattr__ = dict.__getitem__; __setattr__ = dict.__setitem__


# ================================================================
# 1) reproducibility
# ================================================================
import random, os, numpy as np, torch, glob, datetime


# ================================================================
# 2) dataset helpers
# ================================================================
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
class HoodDS(Dataset):
    def __init__(self, paths, k):
        self.data=[]; self.ids=[]
        nbr=NearestNeighbors(k,algorithm='brute')
        for p in paths:
            try:
                d=np.load(p,allow_pickle=True)
                if len(d['sites'])==0: continue
                nbr.fit(d['pos']); idx=nbr.kneighbors(d['sites'],return_distance=False)
                self.data.append((torch.from_numpy(d['z'][idx]),
                                  torch.from_numpy(d['pos'][idx]),
                                  torch.from_numpy(d['pks'])))
                self.ids.append(os.path.splitext(os.path.basename(p))[0])
            except Exception as e: print("skip",p,e)
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        z,p,y=self.data[i]; return z,p,y,self.ids[i]

def pad(batch,k,device,ret_ids):
    ids=[b[3] for b in batch] if ret_ids else None
    B=len(batch); S=max(b[0].shape[0] for b in batch)
    zt=torch.zeros(B,S,k,dtype=torch.int32,device=device)
    pt=torch.zeros(B,S,k,3,dtype=torch.float32,device=device)
    yt=torch.full((B,S),float('nan'),device=device); mt=torch.zeros(B,S,dtype=torch.bool,device=device)
    for b,(z,p,y,_) in enumerate(batch):
        s=z.shape[0]; zt[b,:s]=z; pt[b,:s]=p; yt[b,:s]=y; mt[b,:s]=True
    return (zt,pt,yt,mt,ids) if ret_ids else (zt,pt,yt,mt)

def split(paths):
    if cfg.num_paths: paths=paths[:cfg.num_paths]
    rng=np.random.RandomState(cfg.split_seed)
    idx=rng.permutation(len(paths)); cut=int(len(paths)*cfg.split_ratio)
    return [paths[i] for i in idx[:cut]], [paths[i] for i in idx[cut:]]

# ================================================================
# 3) model
# ================================================================

class Model(nn.Module):
    def __init__(self,c):
        super().__init__(); self.c=c
        C = c.dim + c.basis

        self.egnn = StackedEGNN(c.dim,c.depth,c.hidden_dim,c.dropout,
                                c.hood_k,98,c.num_neighbors,c.norm_coors).to(c.device)

        self.rbf  = TunableBlock(LearnableRBF(c.basis,10.).to(c.device), c.use_rbf)
        self.attn = TunableBlock(AttentionBlock(C,C,c.hidden_dim).to(c.device), c.use_attn)

        if c.aggregator=='linear':
            self.agg = nn.Linear(C,1).to(c.device)
        elif c.aggregator=='nconv':
            self.agg = nn.Conv1d(c.hood_k,1,kernel_size=C,padding=0).to(c.device)
        elif c.aggregator=='pool':
            self.agg = None
        else: raise ValueError("aggregator must be 'linear' | 'nconv' | 'pool'")

        self.boost = nn.Linear(1,1).to(c.device) if c.use_boost else nn.Identity()
        self.prot  = EGNN(dim=1,update_coors=True,num_nearest_neighbors=3).to(c.device) \
                     if c.use_prot else nn.Identity()
        self.conv  = nn.Conv1d(1,1,c.conv_kernel,padding=c.conv_kernel//2).to(c.device) \
                     if c.use_conv else nn.Identity()

    def forward(self,z,x):
        h,coord=self.egnn(z,x); h=h[0]                # (R,N,dim)
        cent=coord.mean(1,keepdim=True)               # (R,1,3)

        # --- build token ----------------------------------------------------------------
        r = self.rbf(cent,coord).transpose(1,2) if self.c.use_rbf else \
            h.new_zeros(h.size(0),self.c.basis,self.c.hood_k)
        tok = torch.cat((r,h.transpose(1,2)),1)       # (R,C,N)

        att = self.attn(tok.permute(2,0,1))
        tok = att[0] if isinstance(att,(tuple,list)) else att
        tok = tok.permute(1,0,2)                      # (R,N,C)

        # --- aggregation ----------------------------------------------------------------
        if self.c.aggregator == 'linear':
            preds = self.agg(tok).max(1).values                 # (R,1)
        elif self.c.aggregator == 'nconv':
            preds = self.agg(tok).squeeze(-1)                   # (R,1)
        else:                           # pool
            # THIS must be max over dim=2 (neighbours!), not dim=1
            preds = tok.max(2).values.mean(1, keepdim=True)     # (R,1)

        preds = self.boost(preds)

        if self.c.use_prot:
            preds = self.prot(preds.unsqueeze(0),
                              cent.permute(1,0,2))[0].squeeze(0)

        if self.c.use_conv:
            preds = self.conv(preds.T.unsqueeze(0)).squeeze(0).T

        return preds

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='pool',

    # block switches
    use_rbf      =False,
    use_attn     =False,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_013022
params: 16348
std of all pos_emb weights: [0.0]




[1/10]  train mae 2.0679 | mae val 1.6369
     additional metrics:  [1/10]  train 5.7051 | val 3.7122

[2/10]  train mae 1.9043 | mae val 1.5044
     additional metrics:  [2/10]  train 5.0144 | val 3.2630

[3/10]  train mae 1.7744 | mae val 1.4418
     additional metrics:  [3/10]  train 4.5309 | val 3.0694

[4/10]  train mae 1.7134 | mae val 1.3815
     additional metrics:  [4/10]  train 4.3170 | val 2.8967

[5/10]  train mae 1.6554 | mae val 1.3670
     additional metrics:  [5/10]  train 4.1240 | val 2.8574

[6/10]  train mae 1.6412 | mae val 1.3597
     additional metrics:  [6/10]  train 4.0755 | val 2.8380

[7/10]  train mae 1.6349 | mae val 1.3280
     additional metrics:  [7/10]  train 4.0553 | val 2.7526

[8/10]  train mae 1.6060 | mae val 1.3082
     additional metrics:  [8/10]  train 3.9623 | val 2.6997

[9/10]  train mae 1.5867 | mae val 1.2715
     additional metrics:  [9/10]  train 3.9030 | val 2.6055

[10/10]  train mae 1.5501 | mae val 1.2268
     additional metrics:  [10/

In [6]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =False,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_013112
params: 16367
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.3678 | mae val 1.0182
     additional metrics:  [1/10]  train 3.3569 | val 2.1280

[2/10]  train mae 1.2794 | mae val 1.0285
     additional metrics:  [2/10]  train 3.1727 | val 2.1822

[3/10]  train mae 1.2337 | mae val 1.0614
     additional metrics:  [3/10]  train 3.1268 | val 2.2614

[4/10]  train mae 1.2274 | mae val 1.0644
     additional metrics:  [4/10]  train 3.1727 | val 2.2682

[5/10]  train mae 1.2321 | mae val 1.0470
     additional metrics:  [5/10]  train 3.2208 | val 2.2196

[6/10]  train mae 1.2248 | mae val 1.0359
     additional metrics:  [6/10]  train 3.1890 | val 2.1889

[7/10]  train mae 1.2200 | mae val 1.0254
     additional metrics:  [7/10]  train 3.1695 | val 2.1593

[8/10]  train mae 1.2162 | mae val 1.0156
     additional metrics:  [8/10]  train 3.1511 | val 2.1338

[9/10]  train mae 1.2144 | mae val 1.0106
     additional metrics:  [9/10]  train 3.1356 | val 2.1154

[10/10]  train mae 1.2138 | mae val 1.0083
     additional metrics:  [10/

In [7]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =False,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_013210
params: 16367
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.3885 | mae val 1.0231
     additional metrics:  [1/10]  train 3.3832 | val 2.1185

[2/10]  train mae 1.2816 | mae val 1.0145
     additional metrics:  [2/10]  train 3.1687 | val 2.1319

[3/10]  train mae 1.2249 | mae val 1.0412
     additional metrics:  [3/10]  train 3.1034 | val 2.1972

[4/10]  train mae 1.2246 | mae val 1.0401
     additional metrics:  [4/10]  train 3.1733 | val 2.1897

[5/10]  train mae 1.2235 | mae val 1.0248
     additional metrics:  [5/10]  train 3.1738 | val 2.1414

[6/10]  train mae 1.2154 | mae val 1.0156
     additional metrics:  [6/10]  train 3.1317 | val 2.1062

[7/10]  train mae 1.2105 | mae val 1.0125
     additional metrics:  [7/10]  train 3.1010 | val 2.0948

[8/10]  train mae 1.2097 | mae val 1.0116
     additional metrics:  [8/10]  train 3.0911 | val 2.0895

[9/10]  train mae 1.2081 | mae val 1.0119
     additional metrics:  [9/10]  train 3.0841 | val 2.0881

[10/10]  train mae 1.2059 | mae val 1.0132
     additional metrics:  [10/

In [8]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =True,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_013329
params: 16367
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.5507 | mae val 1.0079
     additional metrics:  [1/10]  train 3.8115 | val 2.0867

[2/10]  train mae 1.2471 | mae val 1.1769
     additional metrics:  [2/10]  train 3.1513 | val 2.5985

[3/10]  train mae 1.3027 | mae val 1.2052
     additional metrics:  [3/10]  train 3.4073 | val 2.6954

[4/10]  train mae 1.3337 | mae val 1.1403
     additional metrics:  [4/10]  train 3.5148 | val 2.5015

[5/10]  train mae 1.2920 | mae val 1.0555
     additional metrics:  [5/10]  train 3.3721 | val 2.2421

[6/10]  train mae 1.2332 | mae val 1.0267
     additional metrics:  [6/10]  train 3.1390 | val 2.1562

[7/10]  train mae 1.2295 | mae val 1.0116
     additional metrics:  [7/10]  train 3.1168 | val 2.1080

[8/10]  train mae 1.2206 | mae val 1.0090
     additional metrics:  [8/10]  train 3.0837 | val 2.0860

[9/10]  train mae 1.2360 | mae val 1.0097
     additional metrics:  [9/10]  train 3.1314 | val 2.0811

[10/10]  train mae 1.2409 | mae val 1.0091
     additional metrics:  [10/

In [3]:
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())


std of all pos_emb weights: [0.0]
perm‑K  max|Δ| = 3.5762786865234375e-07


In [9]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='nconv',

    # block switches
    use_rbf      =False,
    use_attn     =False,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_013423
params: 18149
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.2849 | mae val 1.9432
     additional metrics:  [1/10]  train 3.0402 | val 5.3374

[2/10]  train mae 1.9140 | mae val 1.3153
     additional metrics:  [2/10]  train 5.3635 | val 2.6428

[3/10]  train mae 1.4063 | mae val 1.3795
     additional metrics:  [3/10]  train 2.8902 | val 2.8052

[4/10]  train mae 1.4163 | mae val 0.9753
     additional metrics:  [4/10]  train 2.8889 | val 1.9173

[5/10]  train mae 0.9356 | mae val 1.1540
     additional metrics:  [5/10]  train 1.7292 | val 2.5166

[6/10]  train mae 1.0839 | mae val 1.2083
     additional metrics:  [6/10]  train 2.0315 | val 2.7017

[7/10]  train mae 1.0911 | mae val 1.0455
     additional metrics:  [7/10]  train 2.0044 | val 2.1617

[8/10]  train mae 0.8662 | mae val 1.0175
     additional metrics:  [8/10]  train 1.3532 | val 1.9758

[9/10]  train mae 0.7540 | mae val 1.0620
     additional metrics:  [9/10]  train 1.1012 | val 2.0516

[10/10]  train mae 0.7726 | mae val 1.0604
     additional metrics:  [10/

In [10]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =False,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_014018
params: 16367
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.3678 | mae val 1.0182
     additional metrics:  [1/10]  train 3.3569 | val 2.1280

[2/10]  train mae 1.2794 | mae val 1.0285
     additional metrics:  [2/10]  train 3.1727 | val 2.1822

[3/10]  train mae 1.2337 | mae val 1.0614
     additional metrics:  [3/10]  train 3.1268 | val 2.2614

[4/10]  train mae 1.2274 | mae val 1.0644
     additional metrics:  [4/10]  train 3.1727 | val 2.2682

[5/10]  train mae 1.2321 | mae val 1.0470
     additional metrics:  [5/10]  train 3.2208 | val 2.2196

[6/10]  train mae 1.2248 | mae val 1.0359
     additional metrics:  [6/10]  train 3.1890 | val 2.1889

[7/10]  train mae 1.2200 | mae val 1.0254
     additional metrics:  [7/10]  train 3.1695 | val 2.1593

[8/10]  train mae 1.2162 | mae val 1.0156
     additional metrics:  [8/10]  train 3.1511 | val 2.1338

[9/10]  train mae 1.2144 | mae val 1.0106
     additional metrics:  [9/10]  train 3.1356 | val 2.1154

[10/10]  train mae 1.2138 | mae val 1.0083
     additional metrics:  [10/

In [11]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =True,
    use_boost    =False,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_014057
params: 16367
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.5507 | mae val 1.0079
     additional metrics:  [1/10]  train 3.8115 | val 2.0867

[2/10]  train mae 1.2471 | mae val 1.1769
     additional metrics:  [2/10]  train 3.1513 | val 2.5985

[3/10]  train mae 1.3027 | mae val 1.2052
     additional metrics:  [3/10]  train 3.4073 | val 2.6954

[4/10]  train mae 1.3337 | mae val 1.1403
     additional metrics:  [4/10]  train 3.5148 | val 2.5015

[5/10]  train mae 1.2920 | mae val 1.0555
     additional metrics:  [5/10]  train 3.3721 | val 2.2421

[6/10]  train mae 1.2332 | mae val 1.0267
     additional metrics:  [6/10]  train 3.1390 | val 2.1562

[7/10]  train mae 1.2295 | mae val 1.0116
     additional metrics:  [7/10]  train 3.1168 | val 2.1080

[8/10]  train mae 1.2206 | mae val 1.0090
     additional metrics:  [8/10]  train 3.0837 | val 2.0860

[9/10]  train mae 1.2360 | mae val 1.0097
     additional metrics:  [9/10]  train 3.1314 | val 2.0811

[10/10]  train mae 1.2409 | mae val 1.0091
     additional metrics:  [10/

In [12]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_014149
params: 16369
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.3911 | mae val 1.0238
     additional metrics:  [1/10]  train 3.7992 | val 2.1272

[2/10]  train mae 1.2327 | mae val 1.0621
     additional metrics:  [2/10]  train 3.2020 | val 2.1575

[3/10]  train mae 1.2938 | mae val 1.0441
     additional metrics:  [3/10]  train 3.2189 | val 2.1229

[4/10]  train mae 1.2640 | mae val 1.0159
     additional metrics:  [4/10]  train 3.1567 | val 2.0934

[5/10]  train mae 1.2257 | mae val 1.0162
     additional metrics:  [5/10]  train 3.1480 | val 2.1503

[6/10]  train mae 1.2413 | mae val 1.0351
     additional metrics:  [6/10]  train 3.2236 | val 2.2005

[7/10]  train mae 1.2562 | mae val 1.0373
     additional metrics:  [7/10]  train 3.2569 | val 2.1939

[8/10]  train mae 1.2633 | mae val 1.0280
     additional metrics:  [8/10]  train 3.2703 | val 2.1653

[9/10]  train mae 1.2534 | mae val 1.0189
     additional metrics:  [9/10]  train 3.2261 | val 2.1390

[10/10]  train mae 1.2480 | mae val 1.0105
     additional metrics:  [10/

In [13]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =False,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_014222
params: 17697
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.8071 | mae val 1.1899
     additional metrics:  [1/10]  train 5.4303 | val 2.6614

[2/10]  train mae 1.3854 | mae val 1.0166
     additional metrics:  [2/10]  train 3.7577 | val 2.1186

[3/10]  train mae 1.2363 | mae val 1.0548
     additional metrics:  [3/10]  train 3.1941 | val 2.1412

[4/10]  train mae 1.2867 | mae val 1.0817
     additional metrics:  [4/10]  train 3.1528 | val 2.1956

[5/10]  train mae 1.3186 | mae val 1.0695
     additional metrics:  [5/10]  train 3.2172 | val 2.1704

[6/10]  train mae 1.3116 | mae val 1.0436
     additional metrics:  [6/10]  train 3.2500 | val 2.1172

[7/10]  train mae 1.2843 | mae val 1.0298
     additional metrics:  [7/10]  train 3.1945 | val 2.0954

[8/10]  train mae 1.2538 | mae val 1.0191
     additional metrics:  [8/10]  train 3.1881 | val 2.0842

[9/10]  train mae 1.2483 | mae val 1.0114
     additional metrics:  [9/10]  train 3.1594 | val 2.0878

[10/10]  train mae 1.2266 | mae val 1.0092
     additional metrics:  [10/

In [14]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =False,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_014251
params: 16369
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.3937 | mae val 1.0243
     additional metrics:  [1/10]  train 3.8035 | val 2.1324

[2/10]  train mae 1.2349 | mae val 1.0654
     additional metrics:  [2/10]  train 3.2066 | val 2.1650

[3/10]  train mae 1.2997 | mae val 1.0471
     additional metrics:  [3/10]  train 3.2265 | val 2.1290

[4/10]  train mae 1.2703 | mae val 1.0156
     additional metrics:  [4/10]  train 3.1630 | val 2.0852

[5/10]  train mae 1.2241 | mae val 1.0124
     additional metrics:  [5/10]  train 3.1328 | val 2.1259

[6/10]  train mae 1.2329 | mae val 1.0388
     additional metrics:  [6/10]  train 3.1838 | val 2.2051

[7/10]  train mae 1.2619 | mae val 1.0506
     additional metrics:  [7/10]  train 3.2795 | val 2.2364

[8/10]  train mae 1.2756 | mae val 1.0411
     additional metrics:  [8/10]  train 3.3172 | val 2.2054

[9/10]  train mae 1.2694 | mae val 1.0226
     additional metrics:  [9/10]  train 3.2686 | val 2.1484

[10/10]  train mae 1.2465 | mae val 1.0138
     additional metrics:  [10/

In [18]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_014627
params: 17697
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.8114 | mae val 1.1944
     additional metrics:  [1/10]  train 5.4368 | val 2.6812

[2/10]  train mae 1.3855 | mae val 1.0176
     additional metrics:  [2/10]  train 3.7635 | val 2.1219

[3/10]  train mae 1.2359 | mae val 1.0543
     additional metrics:  [3/10]  train 3.1869 | val 2.1408

[4/10]  train mae 1.2889 | mae val 1.0848
     additional metrics:  [4/10]  train 3.1580 | val 2.2013

[5/10]  train mae 1.3212 | mae val 1.0738
     additional metrics:  [5/10]  train 3.2142 | val 2.1788

[6/10]  train mae 1.3132 | mae val 1.0469
     additional metrics:  [6/10]  train 3.2433 | val 2.1241

[7/10]  train mae 1.2861 | mae val 1.0330
     additional metrics:  [7/10]  train 3.1852 | val 2.1020

[8/10]  train mae 1.2504 | mae val 1.0217
     additional metrics:  [8/10]  train 3.1760 | val 2.0898

[9/10]  train mae 1.2464 | mae val 1.0120
     additional metrics:  [9/10]  train 3.1592 | val 2.0907

[10/10]  train mae 1.2272 | mae val 1.0113
     additional metrics:  [10/

In [None]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =True,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_014447
params: 17705
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.3635 | mae val 1.0526
     additional metrics:  [1/10]  train 3.6311 | val 2.2125

[2/10]  train mae 1.2415 | mae val 1.0090
     additional metrics:  [2/10]  train 3.2189 | val 2.0988

[3/10]  train mae 1.2353 | mae val 1.0189
     additional metrics:  [3/10]  train 3.1706 | val 2.0925

[4/10]  train mae 1.2461 | mae val 1.0235
     additional metrics:  [4/10]  train 3.1474 | val 2.0977

[5/10]  train mae 1.2537 | mae val 1.0192
     additional metrics:  [5/10]  train 3.1591 | val 2.0929

[6/10]  train mae 1.2430 | mae val 1.0136
     additional metrics:  [6/10]  train 3.1418 | val 2.0911

[7/10]  train mae 1.2286 | mae val 1.0118
     additional metrics:  [7/10]  train 3.1184 | val 2.0915

[8/10]  train mae 1.2383 | mae val 1.0107
     additional metrics:  [8/10]  train 3.1703 | val 2.0946

[9/10]  train mae 1.2172 | mae val 1.0107
     additional metrics:  [9/10]  train 3.0995 | val 2.0992

[10/10]  train mae 1.2251 | mae val 1.0110
     additional metrics:  [10/

In [17]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =True,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_014558
params: 17705
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.3635 | mae val 1.0526
     additional metrics:  [1/10]  train 3.6311 | val 2.2125

[2/10]  train mae 1.2415 | mae val 1.0090
     additional metrics:  [2/10]  train 3.2189 | val 2.0988

[3/10]  train mae 1.2353 | mae val 1.0189
     additional metrics:  [3/10]  train 3.1706 | val 2.0925

[4/10]  train mae 1.2461 | mae val 1.0235
     additional metrics:  [4/10]  train 3.1474 | val 2.0977

[5/10]  train mae 1.2537 | mae val 1.0192
     additional metrics:  [5/10]  train 3.1591 | val 2.0929

[6/10]  train mae 1.2430 | mae val 1.0136
     additional metrics:  [6/10]  train 3.1418 | val 2.0911

[7/10]  train mae 1.2286 | mae val 1.0118
     additional metrics:  [7/10]  train 3.1184 | val 2.0915

[8/10]  train mae 1.2383 | mae val 1.0107
     additional metrics:  [8/10]  train 3.1703 | val 2.0946

[9/10]  train mae 1.2172 | mae val 1.0107
     additional metrics:  [9/10]  train 3.0995 | val 2.0992

[10/10]  train mae 1.2251 | mae val 1.0110
     additional metrics:  [10/

In [25]:
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =True,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_021450
params: 17705
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.3635 | mae val 1.0526
     additional metrics:  [1/10]  train 3.6311 | val 2.2125

[2/10]  train mae 1.2415 | mae val 1.0090
     additional metrics:  [2/10]  train 3.2189 | val 2.0988

[3/10]  train mae 1.2353 | mae val 1.0189
     additional metrics:  [3/10]  train 3.1706 | val 2.0925

[4/10]  train mae 1.2461 | mae val 1.0235
     additional metrics:  [4/10]  train 3.1474 | val 2.0977



KeyboardInterrupt: 

In [21]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =True,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_014842
params: 17705




[1/10]  train mae 1.4138 | mae val 1.0971
     additional metrics:  [1/10]  train 3.8091 | val 2.3469

[2/10]  train mae 1.2822 | mae val 1.0316
     additional metrics:  [2/10]  train 3.3433 | val 2.1616

[3/10]  train mae 1.2361 | mae val 1.0102
     additional metrics:  [3/10]  train 3.1918 | val 2.0975

[4/10]  train mae 1.2304 | mae val 1.0139
     additional metrics:  [4/10]  train 3.1469 | val 2.0943

[5/10]  train mae 1.2397 | mae val 1.0137
     additional metrics:  [5/10]  train 3.1532 | val 2.0958

[6/10]  train mae 1.2378 | mae val 1.0125
     additional metrics:  [6/10]  train 3.1455 | val 2.1010

[7/10]  train mae 1.2327 | mae val 1.0133
     additional metrics:  [7/10]  train 3.1303 | val 2.1114

[8/10]  train mae 1.2354 | mae val 1.0144
     additional metrics:  [8/10]  train 3.1471 | val 2.1183

[9/10]  train mae 1.2249 | mae val 1.0170
     additional metrics:  [9/10]  train 3.1272 | val 2.1264

[10/10]  train mae 1.2271 | mae val 1.0205
     additional metrics:  [10/

In [None]:

t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =True,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

conclusion: dont modify egnn at all, and just use as plannned, maybe even Conv after I study that more.
but dont use nconv aggregation

In [26]:
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_021523
params: 17697
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.8114 | mae val 1.1944
     additional metrics:  [1/10]  train 5.4368 | val 2.6812

[2/10]  train mae 1.3855 | mae val 1.0176
     additional metrics:  [2/10]  train 3.7635 | val 2.1219

[3/10]  train mae 1.2359 | mae val 1.0543
     additional metrics:  [3/10]  train 3.1869 | val 2.1408

[4/10]  train mae 1.2889 | mae val 1.0848
     additional metrics:  [4/10]  train 3.1580 | val 2.2013

[5/10]  train mae 1.3212 | mae val 1.0738
     additional metrics:  [5/10]  train 3.2142 | val 2.1788

[6/10]  train mae 1.3132 | mae val 1.0469
     additional metrics:  [6/10]  train 3.2433 | val 2.1241

[7/10]  train mae 1.2861 | mae val 1.0330
     additional metrics:  [7/10]  train 3.1852 | val 2.1020

[8/10]  train mae 1.2504 | mae val 1.0217
     additional metrics:  [8/10]  train 3.1760 | val 2.0898

[9/10]  train mae 1.2464 | mae val 1.0120
     additional metrics:  [9/10]  train 3.1592 | val 2.0907

[10/10]  train mae 1.2272 | mae val 1.0113
     additional metrics:  [10/

In [27]:
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_021540
params: 17697
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.8114 | mae val 1.1944
     additional metrics:  [1/10]  train 5.4368 | val 2.6812

[2/10]  train mae 1.3855 | mae val 1.0176
     additional metrics:  [2/10]  train 3.7635 | val 2.1219

[3/10]  train mae 1.2359 | mae val 1.0543
     additional metrics:  [3/10]  train 3.1869 | val 2.1408

[4/10]  train mae 1.2889 | mae val 1.0848
     additional metrics:  [4/10]  train 3.1580 | val 2.2013

[5/10]  train mae 1.3212 | mae val 1.0738
     additional metrics:  [5/10]  train 3.2142 | val 2.1788

[6/10]  train mae 1.3132 | mae val 1.0469
     additional metrics:  [6/10]  train 3.2433 | val 2.1241

[7/10]  train mae 1.2861 | mae val 1.0330
     additional metrics:  [7/10]  train 3.1852 | val 2.1020

[8/10]  train mae 1.2504 | mae val 1.0217
     additional metrics:  [8/10]  train 3.1760 | val 2.0898

[9/10]  train mae 1.2464 | mae val 1.0120
     additional metrics:  [9/10]  train 3.1592 | val 2.0907

[10/10]  train mae 1.2272 | mae val 1.0113
     additional metrics:  [10/

In [28]:
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=2, hidden_dim=4, dropout=0.02,
    hood_k=200, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_021629
params: 18897
std of all pos_emb weights: [0.0]




[1/10]  train mae 2.1616 | mae val 1.5883
     additional metrics:  [1/10]  train 6.9524 | val 3.9480

[2/10]  train mae 1.7354 | mae val 1.1779
     additional metrics:  [2/10]  train 5.0136 | val 2.6048

[3/10]  train mae 1.4026 | mae val 1.0163
     additional metrics:  [3/10]  train 3.8120 | val 2.1245

[4/10]  train mae 1.2580 | mae val 1.0395
     additional metrics:  [4/10]  train 3.2324 | val 2.1198

[5/10]  train mae 1.2830 | mae val 1.0608
     additional metrics:  [5/10]  train 3.1843 | val 2.1595

[6/10]  train mae 1.3125 | mae val 1.0429
     additional metrics:  [6/10]  train 3.2623 | val 2.1247

[7/10]  train mae 1.2936 | mae val 1.0191
     additional metrics:  [7/10]  train 3.2217 | val 2.0902

[8/10]  train mae 1.2443 | mae val 1.0125
     additional metrics:  [8/10]  train 3.1486 | val 2.0946

[9/10]  train mae 1.2287 | mae val 1.0155
     additional metrics:  [9/10]  train 3.1403 | val 2.1211

[10/10]  train mae 1.2256 | mae val 1.0284
     additional metrics:  [10/

IndexError: index 175 is out of bounds for dimension 0 with size 100

In [30]:
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=3, hidden_dim=3, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_021904
params: 21021
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.2914 | mae val 1.0407
     additional metrics:  [1/10]  train 3.3821 | val 2.1940

[2/10]  train mae 1.2388 | mae val 1.0100
     additional metrics:  [2/10]  train 3.1879 | val 2.0907

[3/10]  train mae 1.2260 | mae val 1.0169
     additional metrics:  [3/10]  train 3.1014 | val 2.0809

[4/10]  train mae 1.2442 | mae val 1.0157
     additional metrics:  [4/10]  train 3.1324 | val 2.0809

[5/10]  train mae 1.2401 | mae val 1.0114
     additional metrics:  [5/10]  train 3.1413 | val 2.0842

[6/10]  train mae 1.2345 | mae val 1.0098
     additional metrics:  [6/10]  train 3.1357 | val 2.0982

[7/10]  train mae 1.2262 | mae val 1.0138
     additional metrics:  [7/10]  train 3.1246 | val 2.1178

[8/10]  train mae 1.2271 | mae val 1.0202
     additional metrics:  [8/10]  train 3.1438 | val 2.1376

[9/10]  train mae 1.2299 | mae val 1.0253
     additional metrics:  [9/10]  train 3.1584 | val 2.1510

[10/10]  train mae 1.2310 | mae val 1.0274
     additional metrics:  [10/

In [31]:
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=3, hidden_dim=10, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_021958
params: 27783
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.2294 | mae val 1.0131
     additional metrics:  [1/10]  train 3.1543 | val 2.0882

[2/10]  train mae 1.2434 | mae val 1.0121
     additional metrics:  [2/10]  train 3.1482 | val 2.1107

[3/10]  train mae 1.2344 | mae val 1.0259
     additional metrics:  [3/10]  train 3.1724 | val 2.1541

[4/10]  train mae 1.2295 | mae val 1.0230
     additional metrics:  [4/10]  train 3.1638 | val 2.1457

[5/10]  train mae 1.2238 | mae val 1.0151
     additional metrics:  [5/10]  train 3.1421 | val 2.1242

[6/10]  train mae 1.2300 | mae val 1.0109
     additional metrics:  [6/10]  train 3.1514 | val 2.1060

[7/10]  train mae 1.2328 | mae val 1.0096
     additional metrics:  [7/10]  train 3.1447 | val 2.0985

[8/10]  train mae 1.2257 | mae val 1.0097
     additional metrics:  [8/10]  train 3.1478 | val 2.0981

[9/10]  train mae 1.2255 | mae val 1.0109
     additional metrics:  [9/10]  train 3.1317 | val 2.1049

[10/10]  train mae 1.2294 | mae val 1.0133
     additional metrics:  [10/

In [32]:
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=10, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())

Run‑ID: 20250727_022044
params: 52017
std of all pos_emb weights: [0.0]




[1/10]  train mae 1.4989 | mae val 1.0153
     additional metrics:  [1/10]  train 3.6681 | val 2.1066

[2/10]  train mae 1.2259 | mae val 1.1107
     additional metrics:  [2/10]  train 3.1391 | val 2.4289

[3/10]  train mae 1.2940 | mae val 1.1227
     additional metrics:  [3/10]  train 3.4053 | val 2.4612

[4/10]  train mae 1.2926 | mae val 1.0794
     additional metrics:  [4/10]  train 3.3601 | val 2.3229

[5/10]  train mae 1.2658 | mae val 1.0327
     additional metrics:  [5/10]  train 3.2685 | val 2.1752

[6/10]  train mae 1.2344 | mae val 1.0169
     additional metrics:  [6/10]  train 3.2075 | val 2.1208

[7/10]  train mae 1.2253 | mae val 1.0127
     additional metrics:  [7/10]  train 3.1318 | val 2.0895

[8/10]  train mae 1.2237 | mae val 1.0131
     additional metrics:  [8/10]  train 3.1557 | val 2.0788

[9/10]  train mae 1.2247 | mae val 1.0134
     additional metrics:  [9/10]  train 3.1175 | val 2.0783

[10/10]  train mae 1.2323 | mae val 1.0132
     additional metrics:  [10/

In [None]:
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
t0=time.time()
cfg = Cfg(
    # backbone
    dim=12, basis=6, depth=1, hidden_dim=4, dropout=0.02,
    hood_k=100, num_neighbors=8, norm_coors=True,
    epochs=10,
    # aggregation: 'linear' | 'nconv' | 'pool'
    aggregator   ='linear',

    # block switches
    use_rbf      =True,
    use_attn     =True,
    use_boost    =True,     # Linear(1→1) after aggregator
    use_prot     =True,      # protein‑level EGNN
    use_conv     =False,     # 1‑D conv after prot EGNN
    conv_kernel  =7,

    # training
    loss_type='mae',
    study_metrics=True,
    lr=5e-3, batch_size=1, #batchsize not safw to inc

    # misc
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0, analysis_mode=False,
    num_paths=2, split_ratio=0.5, split_seed=0,
    runid=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
)
print("Run‑ID:", cfg.runid)
random.seed(cfg.seed); np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

model=Model(cfg); 
print("params:",sum(p.numel() for p in model.parameters()))

# ================================================================
# 4) loaders
# ================================================================
allp=glob.glob("../../../data/pkegnn_INS/inputs/*.npz")
tr,val=split(allp)
train_ds=HoodDS(tr,cfg.hood_k); val_ds=HoodDS(val,cfg.hood_k)
coll=lambda b: pad(b,cfg.hood_k,cfg.device,cfg.analysis_mode)
tr_loader=DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True ,collate_fn=coll)
va_loader=DataLoader(val_ds,batch_size=cfg.batch_size,shuffle=False,collate_fn=coll)

# ================================================================
# 5) training utils
# ================================================================
p_fn = nn.L1Loss() if cfg.loss_type=='mae' else nn.MSELoss()
v_fn = nn.MSELoss() if cfg.loss_type=='mae' else nn.L1Loss()
opt  = torch.optim.AdamW(model.parameters(),lr=cfg.lr)
sch  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,'min',0.5,3)
from torch.cuda.amp import GradScaler, autocast
scaler=GradScaler(enabled=(cfg.device=='cuda'))

def run(cfg, loader,train):
    model.train() if train else model.eval(); loss_sum=0;n=0;oloss_sum=0
    for z,x,y,m,*_ in loader:
        v=m.view(-1); z=z.view(-1,z.size(2))[v].to(cfg.device)
        x=x.view(-1,x.size(2),3)[v].to(cfg.device); y=y.view(-1)[v].to(cfg.device)
        with autocast(enabled=(cfg.device=='cuda')):
            pred=model(z,x).flatten(); loss=p_fn(pred,y)
        if train:
            opt.zero_grad(); scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        if cfg.study_metrics:
            other_loss = v_fn(pred,y)
            oloss_sum +=other_loss.item(); 
        loss_sum+=loss.item(); n+=1
        
    if not cfg.study_metrics:
        return loss_sum/n
    else:
        return (loss_sum/n, oloss_sum/n)
# run once, before the checker, to blank them all
with torch.no_grad():
    for m in model.modules():
        if hasattr(m, "pos_emb"):          # catches every EGNN_Network
            m.pos_emb.weight.zero_()
    print("std of all pos_emb weights:",
      [m.pos_emb.weight.std().item() for m in model.modules()
       if hasattr(m,"pos_emb")])


# ================================================================
# 6) train
# ================================================================
for e in range(cfg.epochs):
    tr=run(cfg,tr_loader,True)
    va=run(cfg,va_loader,False)
    if not cfg.study_metrics:
        sch.step(va)
        print(f"[{e+1}/{cfg.epochs}]  train {tr:.4f} | val {va:.4f}")
    else:
        sch.step(va[0])
        print(f"[{e+1}/{cfg.epochs}]  train {cfg.loss_type} {tr[0]:.4f} | {cfg.loss_type} val {va[0]:.4f}")
        print("     additional metrics: ",f"[{e+1}/{cfg.epochs}]  train {tr[1]:.4f} | val {va[1]:.4f}")
        print("")

print(time.time() - t0,"sec")

# ================================================================
# 7)  SE(3)‑equivariance & permutation‑invariance checker
# ================================================================
import torch, math, itertools, random
from collections import defaultdict
torch.set_printoptions(precision=3, sci_mode=True)

# ---------- helpers ----------------------------------------------------------
def _random_rotation(device):
    """Draw a random 3×3 rotation matrix from a unit quaternion."""
    q = torch.randn(4, device=device); q /= q.norm()
    w, x, y, z = q
    return torch.tensor([[1-2*(y*y+z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                         [2*(x*y + z*w), 1-2*(x*x+z*z), 2*(y*z - x*w)],
                         [2*(x*z - y*w), 2*(y*z + x*w), 1-2*(x*x+y*y)]],
                        device=device)

def _prep_batch(batch, cfg):
    """Flatten (B,S,K,…) → (R,K,…) and drop padded rows."""
    z, x, _, m, *_ = batch
    mask  = m.view(-1)                             # (R,)
    z = z.view(-1, z.size(2))[mask].to(cfg.device)           # (R,K)
    x = x.view(-1, x.size(2), 3)[mask].to(cfg.device)        # (R,K,3)
    return z, x                                     # R = Σ valid residues

# ---------- core checks ------------------------------------------------------
@torch.no_grad()
def run_invariance_suite(model, loader, cfg,
                         max_batches=4, rot_trials=3,
                         atol=5e-4, rtol=5e-4, verbose=True):
    """
    For `max_batches` mini‑batches:
      • SE(3) equivariance   (rot + trans)
      • neighbour perm‑inv   (perm K)
      • residue  perm‑eqv    (perm R)
    Returns dict with max abs‑errors.
    """
    stats = defaultdict(float)
    model.eval()

    for b_id, batch in enumerate(loader):
        if b_id >= max_batches: break
        z, x = _prep_batch(batch, cfg)             # (R,K), (R,K,3)

        # --- baseline prediction -------------------------------------------------
        base = model(z, x).flatten()               # (R,)

        # -- 1) SE(3) equivariance -----------------------------------------------
        for t in range(rot_trials):
            R = _random_rotation(x.device)
            tvec = torch.randn(1,1,3, device=x.device)
            x_rt = (x @ R.T) + tvec
            p_rt = model(z, x_rt).flatten()
            err  = (base - p_rt).abs().max().item()
            stats['eqv'] = max(stats['eqv'], err)
            if verbose:
                print(f"[batch {b_id}  rot {t}]  max|Δ|={err:.3e}")

        # -- 2) neighbour‑perm invariance ----------------------------------------
        K = z.size(1)
        permK = torch.randperm(K, device=x.device)
        pK = model(z[:, permK], x[:, permK]).flatten()
        errK = (base - pK).abs().max().item()
        stats['permK'] = max(stats['permK'], errK)
        if verbose:
            print(f"[batch {b_id}  perm‑K ]  max|Δ|={errK:.3e}")

        # -- 3) residue‑perm equivariance ----------------------------------------
        Rn = z.size(0)
        permR = torch.randperm(Rn, device=x.device)
        zR, xR  = z[permR], x[permR]
        pR = model(zR, xR).flatten()
        # undo permutation on prediction
        pR = pR[permR.argsort()]
        errR = (base - pR).abs().max().item()
        stats['permR'] = max(stats['permR'], errR)
        if verbose:
            print(f"[batch {b_id}  perm‑R ]  max|Δ|={errR:.3e}")
            print("-"*55)

    return stats

# ---------- run the suite on real inputs -------------------------------------
print("\n================  INVARIANCE SUITE  ================\n")
stats = run_invariance_suite(model, tr_loader, cfg,
                             max_batches=3, rot_trials=5,
                             atol=5e-4, rtol=5e-4, verbose=True)

print("\n----------------  summary (max abs error) ----------------")
for k,v in stats.items():
    print(f"{k:6}: {v:.3e}")
print("----------------------------------------------------------")
print("✔  thresholds used: atol={:.1e}, rtol={:.1e}".format(5e-4,5e-4))
print("==========================================================\n")


err = (model(z, x) -
model(z[:, torch.randperm(cfg.hood_k)],
        x[:, torch.randperm(cfg.hood_k)])).abs().max()
print("perm‑K  max|Δ| =", err.item())