In [17]:
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
# Wall-clock reference for the elapsed-time report printed after training.
t0 = time.time()

# Data-pipeline constants.
N_NEIGHBORS = 15                          # default neighbourhood size used by the dataset
BATCH_SIZE = 1                            # proteins per DataLoader batch
PIN_MEMORY = torch.cuda.is_available()    # pin host memory only when CUDA is present

# Fixed torch seed, then pick the compute device.
torch.manual_seed(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mixed precision is enabled only on a GPU.  On CPU a no-op context manager
# replaces `autocast` and no gradient scaler is created.
use_amp = device.type == "cuda"
scaler = GradScaler() if use_amp else None
if not use_amp:
    class DummyCM:
        """No-op context manager standing in for ``autocast`` on CPU."""
        def __enter__(self):
            return None

        def __exit__(self, *args):
            return None

    autocast = DummyCM

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    """
    Build the three trainable sub-modules and move each one to ``device``.

    The EGNN stack is configured entirely from the arguments.  The attention
    block and the RBF layer additionally read the module-level names
    ``basis``, ``num_heads`` and ``cutoff`` at call time, so those must be
    defined before this function is invoked.

    Returns
    -------
    tuple
        ``(net, mha, RBF)`` - a StackedEGNN, an AttentionBlock and a
        LearnableRBF instance, in that order.
    """
    egnn_kwargs = dict(
        dim=dim,
        depth=depth,
        hidden_dim=hidden_dim,
        dropout=dropout,
        num_positions=1000,
        num_tokens=118,
        num_nearest_neighbors=num_neighbors,
        norm_coors=True,
        num_edge_tokens=num_edge_tokens,
        num_global_tokens=num_global_tokens,
    )
    # Construction order (EGNN, attention, RBF) is kept so seeded
    # parameter initialisation stays the same.
    egnn_stack = StackedEGNN(**egnn_kwargs).to(device)
    attention = AttentionBlock(
        embed_dim=dim + basis, num_heads=num_heads, hidden_dim=hidden_dim
    ).to(device)
    radial = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
    return egnn_stack, attention, radial
#net,mha,RBF=init_model
# 3) instantiate everything
# Model hyper-parameters (author note: scale dim/basis to at least 3, 16;
# dim must be divisible by 2).
dim = 2                      # `dim` passed to StackedEGNN
basis = 8                    # `num_basis` passed to LearnableRBF
depth = 2                    # author note: keep at 2 or more
hidden_dim = 3
num_heads = dim + basis      # same value as the attention embed_dim
num_edge_tokens = 256
num_global_tokens = 256
dropout = 0.02
cutoff = 10.0
num_neighbors = 2

# Timestamp string identifying this run (YYYYMMDD_HHMMSS).
runid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, builds a fixed-size atom neighbourhood around every
    site (anchor) and keeps the resulting tensors entirely in RAM.

    Each stored item is a tuple ``(z_hood, pos_hood, y)``.  For a protein
    with S sites the shapes are
        z_hood   : (S, n_neighbors)      int32
        pos_hood : (S, n_neighbors, 3)   float32
        y        : (S,)                  float32

    Parameters
    ----------
    paths       : iterable of .npz paths; each file must provide the arrays
                  "z", "pos", "sites" and "pks".
    n_neighbors : number of nearest atoms gathered around each site.
    pin_memory  : if True, every stored tensor is copied to pinned memory.

    Files with no sites are skipped silently.  Any file that raises while
    being read or processed is reported with a "skipping ..." message and
    left out of the dataset (best-effort loading).
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                # NOTE(security): allow_pickle=True lets np.load unpickle object
                # arrays, which can run arbitrary code - only load trusted files.
                # The context manager closes the archive once the arrays have
                # been copied out, so no file handle is leaked per protein.
                with np.load(p, allow_pickle=True) as dat:
                    z_all   = dat["z"].astype(np.int32)        # (N,)
                    pos_all = dat["pos"].astype(np.float32)    # (N,3)
                    sites   = dat["sites"].astype(np.float32)  # (S,3)
                    y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                # One label per site is required; catching a mismatch here
                # avoids a shape error much later inside the collate function.
                if len(y) != len(sites):
                    raise ValueError(
                        f"{len(y)} labels for {len(sites)} sites"
                    )

                # Indices of the n_neighbors atoms closest to every site.
                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, n_neighbors)

                z_hood   = torch.from_numpy(z_all[idx])            # (S, n_neighbors)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S, n_neighbors, 3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                # Deliberate best-effort: report the file and carry on.
                print(f"skipping {p}: {e}")

    def __len__(self):
        """Number of proteins that were loaded successfully."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(z_hood, pos_hood, y)`` tuple of protein ``idx``."""
        return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pad the variable-length site dimension so a list of proteins can be
    stacked into dense tensors.

    Parameters
    ----------
    batch : list of ``(z, pos, y)`` tuples with shapes (S_i, N), (S_i, N, 3)
            and (S_i,).  N is read from the first item, so it follows
            whatever neighbourhood size the dataset was built with.

    Returns
    -------
    zs   : (B, S_max, N)     int32   - zero padded
    pos  : (B, S_max, N, 3)  float32 - zero padded
    ys   : (B, S_max)        float32 - NaN padded
    mask : (B, S_max)        bool    - True for real sites, False for padding

    The outputs are allocated on the CPU.  Only CPU tensors can be pinned, so
    this keeps the function compatible with ``DataLoader(pin_memory=True)``;
    the caller is responsible for moving the batch to the compute device.
    """
    n_items = len(batch)
    s_max   = max(item[0].shape[0] for item in batch)   # longest protein
    n_nbrs  = batch[0][0].shape[1]                      # neighbourhood width

    zs   = torch.zeros(n_items, s_max, n_nbrs,    dtype=torch.int32)
    pos  = torch.zeros(n_items, s_max, n_nbrs, 3, dtype=torch.float32)
    ys   = torch.full((n_items, s_max), float("nan"), dtype=torch.float32)
    mask = torch.zeros(n_items, s_max, dtype=torch.bool)

    for b, (z, pos_b, y) in enumerate(batch):
        n_sites = z.shape[0]
        zs  [b, :n_sites] = z
        pos [b, :n_sites] = pos_b
        ys  [b, :n_sites] = y
        mask[b, :n_sites] = True

    return zs, pos, ys, mask

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
# Seed NumPy so the shuffle below is repeatable.
np.random.seed(0)
# glob returns files in a filesystem-dependent order; sorting first makes the
# seeded shuffle produce the same train/val split on every machine.
all_paths = sorted(glob.glob("../../../data/pkegnn_INS/inputs/*.npz"))
np.random.shuffle(all_paths)
# First 20 shuffled files are used for training, the next 10 for validation.
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# Only the training loader reshuffles every epoch.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



# ---------------------------------------------------------------------
# 4) model pieces ------------------------------------------------------
# ---------------------------------------------------------------------
# Site-level encoder applied to every atom neighbourhood.
egnn_net = StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                       dropout=dropout, num_positions=1000, num_tokens=118,
                       num_nearest_neighbors=num_neighbors,
                       norm_coors=True,
                       num_edge_tokens=num_edge_tokens,
                       num_global_tokens=num_global_tokens).to(device)

# Radial-basis layer and attention block; the attention width is dim + basis.
rbf_layer = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
mha_layer = AttentionBlock(embed_dim=dim + basis,
                           num_heads=num_heads,
                           hidden_dim=hidden_dim).to(device)
# Linear read-out: (dim + basis) features -> one scalar per site.
pred_head = nn.Linear(dim + basis, 1).to(device)

# Protein-level EGNN.  Like every other module it has to live on `device`;
# otherwise calling it with CUDA tensors fails with a device mismatch.
# In this cell its call inside the training loop is commented out, but its
# parameters are still handed to the optimizer.
protein_egnn = EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(
    list(egnn_net.parameters()) +
    list(rbf_layer.parameters()) +
    list(mha_layer.parameters()) +
    list(pred_head.parameters()) +
    list(protein_egnn.parameters()),
    lr=5e-3
)

epochs = 10  # number of passes over the training set

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    """
    Embed a flat batch of site neighbourhoods.

    Parameters
    ----------
    z_r : atomic numbers; the caller passes (R, N) int32.
    x_r : coordinates;    the caller passes (R, N, 3) float32.

    Returns
    -------
    tuple ``(tok, centroids)``
        tok       : attention output max-pooled over dim 1 after the final
                    permute - presumably (R, dim + basis); TODO confirm
                    against AttentionBlock.
        centroids : mean over dim 1 of the coordinates returned by the EGNN,
                    with that axis kept as size 1.
    """
    # EGNN pass; the feature output may arrive wrapped in a list/tuple.
    egnn_feats, egnn_xyz = egnn_net(z_r, x_r)
    if isinstance(egnn_feats, (list, tuple)):
        node_h = egnn_feats[0]
    else:
        node_h = egnn_feats

    # Radial features between each neighbourhood centroid and the EGNN
    # *output* coordinates (not the raw input x_r).
    centroids = egnn_xyz.mean(dim=1).unsqueeze(1)
    radial = rbf_layer(centroids, egnn_xyz)

    # Swap dims 1 and 2 of both tensors, then concatenate along dim 1 with the
    # radial channels in front of the EGNN channels.
    stacked = torch.cat((radial.transpose(1, 2), node_h.transpose(1, 2)), dim=1)

    # The attention block receives dim 2 of `stacked` as its leading axis; its
    # second return value is discarded.
    attended, _ = mha_layer(stacked.permute(2, 0, 1))
    pooled = attended.permute(1, 0, 2).max(dim=1).values
    return pooled, centroids

# ---------------------------------------------------------------------
# 6) training / validation loop ---------------------------------------
# ---------------------------------------------------------------------
def _batch_loss(z, x, y, mask):
    """
    Compute the criterion for one padded batch.

    z: (B, S, N), x: (B, S, N, 3), y: (B, S), mask: (B, S).  Padding rows are
    removed with the mask before the model is applied, so the padded NaN
    targets never reach the loss.
    """
    keep  = mask.view(-1)
    z_res = z.view(-1, z.size(2))[keep].to(device)
    x_res = x.view(-1, x.size(2), 3)[keep].to(device)
    y_res = y.view(-1)[keep].to(device)

    # The centroids returned alongside the features are not needed here:
    # the protein-level EGNN refinement is not applied in this variant.
    feats, _ = forward_residues(z_res, x_res)
    preds = pred_head(feats)
    return criterion(preds.flatten(), y_res)


for epoch in range(epochs):
    # ======== TRAIN ========
    egnn_net.train(); rbf_layer.train(); mha_layer.train(); pred_head.train()
    tr_losses = []

    for z, x, y, mask in train_loader:                 # z:(B,S,N)  mask:(B,S)
        optimizer.zero_grad()
        loss = _batch_loss(z, x, y, mask)

        # NOTE(review): the forward pass is not wrapped in autocast, so on CUDA
        # the GradScaler only scales the loss; confirm whether AMP is intended.
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); optimizer.step()

        tr_losses.append(loss.item())

    # The criterion is nn.MSELoss, so the reported value is an MSE (not L1).
    print(f"Epoch {epoch:3d} | train MSE = {np.mean(tr_losses):.4f}")

    # ======== VALID ========
    egnn_net.eval(); rbf_layer.eval(); mha_layer.eval(); pred_head.eval()
    vl_losses = []

    with torch.no_grad():
        for z, x, y, mask in val_loader:
            vl_losses.append(_batch_loss(z, x, y, mask).item())

    print(f"              |  val MSE = {np.mean(vl_losses):.4f}")
print(time.time() - t0,"sec")


Epoch   0 | train L1 = 3.7255
              |  val L1 = 3.7098
Epoch   1 | train L1 = 3.3110
              |  val L1 = 3.5907
Epoch   2 | train L1 = 3.1786
              |  val L1 = 3.3921
Epoch   3 | train L1 = 2.9457
              |  val L1 = 2.9334
Epoch   4 | train L1 = 2.7034
              |  val L1 = 2.5983
Epoch   5 | train L1 = 2.6741
              |  val L1 = 3.7105
Epoch   6 | train L1 = 3.1938
              |  val L1 = 3.5580
Epoch   7 | train L1 = 3.3964
              |  val L1 = 3.7247
Epoch   8 | train L1 = 3.2884
              |  val L1 = 3.7554
Epoch   9 | train L1 = 3.1997
              |  val L1 = 3.5579
15.773443937301636 sec


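With protein-level EGNN: same pipeline as above, but the per-site predictions are additionally passed through `protein_egnn` together with the site centroids.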

In [23]:
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
# Wall-clock reference for the elapsed-time report printed after training.
t0 = time.time()

# Data-pipeline constants.
N_NEIGHBORS = 15                          # default neighbourhood size used by the dataset
BATCH_SIZE = 1                            # proteins per DataLoader batch
PIN_MEMORY = torch.cuda.is_available()    # pin host memory only when CUDA is present

# Fixed torch seed, then pick the compute device.
torch.manual_seed(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mixed precision is enabled only on a GPU.  On CPU a no-op context manager
# replaces `autocast` and no gradient scaler is created.
use_amp = device.type == "cuda"
scaler = GradScaler() if use_amp else None
if not use_amp:
    class DummyCM:
        """No-op context manager standing in for ``autocast`` on CPU."""
        def __enter__(self):
            return None

        def __exit__(self, *args):
            return None

    autocast = DummyCM

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    """
    Build the three trainable sub-modules and move each one to ``device``.

    The EGNN stack is configured entirely from the arguments.  The attention
    block and the RBF layer additionally read the module-level names
    ``basis``, ``num_heads`` and ``cutoff`` at call time, so those must be
    defined before this function is invoked.

    Returns
    -------
    tuple
        ``(net, mha, RBF)`` - a StackedEGNN, an AttentionBlock and a
        LearnableRBF instance, in that order.
    """
    egnn_kwargs = dict(
        dim=dim,
        depth=depth,
        hidden_dim=hidden_dim,
        dropout=dropout,
        num_positions=1000,
        num_tokens=118,
        num_nearest_neighbors=num_neighbors,
        norm_coors=True,
        num_edge_tokens=num_edge_tokens,
        num_global_tokens=num_global_tokens,
    )
    # Construction order (EGNN, attention, RBF) is kept so seeded
    # parameter initialisation stays the same.
    egnn_stack = StackedEGNN(**egnn_kwargs).to(device)
    attention = AttentionBlock(
        embed_dim=dim + basis, num_heads=num_heads, hidden_dim=hidden_dim
    ).to(device)
    radial = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
    return egnn_stack, attention, radial
#net,mha,RBF=init_model
# 3) instantiate everything
# Model hyper-parameters (author note: scale dim/basis to at least 3, 16;
# dim must be divisible by 2).
dim = 2                      # `dim` passed to StackedEGNN
basis = 8                    # `num_basis` passed to LearnableRBF
depth = 2                    # author note: keep at 2 or more
hidden_dim = 3
num_heads = dim + basis      # same value as the attention embed_dim
num_edge_tokens = 256
num_global_tokens = 256
dropout = 0.02
cutoff = 10.0
num_neighbors = 2

# Timestamp string identifying this run (YYYYMMDD_HHMMSS).
runid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, builds a fixed-size atom neighbourhood around every
    site (anchor) and keeps the resulting tensors entirely in RAM.

    Each stored item is a tuple ``(z_hood, pos_hood, y)``.  For a protein
    with S sites the shapes are
        z_hood   : (S, n_neighbors)      int32
        pos_hood : (S, n_neighbors, 3)   float32
        y        : (S,)                  float32

    Parameters
    ----------
    paths       : iterable of .npz paths; each file must provide the arrays
                  "z", "pos", "sites" and "pks".
    n_neighbors : number of nearest atoms gathered around each site.
    pin_memory  : if True, every stored tensor is copied to pinned memory.

    Files with no sites are skipped silently.  Any file that raises while
    being read or processed is reported with a "skipping ..." message and
    left out of the dataset (best-effort loading).
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                # NOTE(security): allow_pickle=True lets np.load unpickle object
                # arrays, which can run arbitrary code - only load trusted files.
                # The context manager closes the archive once the arrays have
                # been copied out, so no file handle is leaked per protein.
                with np.load(p, allow_pickle=True) as dat:
                    z_all   = dat["z"].astype(np.int32)        # (N,)
                    pos_all = dat["pos"].astype(np.float32)    # (N,3)
                    sites   = dat["sites"].astype(np.float32)  # (S,3)
                    y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                # One label per site is required; catching a mismatch here
                # avoids a shape error much later inside the collate function.
                if len(y) != len(sites):
                    raise ValueError(
                        f"{len(y)} labels for {len(sites)} sites"
                    )

                # Indices of the n_neighbors atoms closest to every site.
                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, n_neighbors)

                z_hood   = torch.from_numpy(z_all[idx])            # (S, n_neighbors)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S, n_neighbors, 3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                # Deliberate best-effort: report the file and carry on.
                print(f"skipping {p}: {e}")

    def __len__(self):
        """Number of proteins that were loaded successfully."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(z_hood, pos_hood, y)`` tuple of protein ``idx``."""
        return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pad the variable-length site dimension so a list of proteins can be
    stacked into dense tensors.

    Parameters
    ----------
    batch : list of ``(z, pos, y)`` tuples with shapes (S_i, N), (S_i, N, 3)
            and (S_i,).  N is read from the first item, so it follows
            whatever neighbourhood size the dataset was built with.

    Returns
    -------
    zs   : (B, S_max, N)     int32   - zero padded
    pos  : (B, S_max, N, 3)  float32 - zero padded
    ys   : (B, S_max)        float32 - NaN padded
    mask : (B, S_max)        bool    - True for real sites, False for padding

    The outputs are allocated on the CPU.  Only CPU tensors can be pinned, so
    this keeps the function compatible with ``DataLoader(pin_memory=True)``;
    the caller is responsible for moving the batch to the compute device.
    """
    n_items = len(batch)
    s_max   = max(item[0].shape[0] for item in batch)   # longest protein
    n_nbrs  = batch[0][0].shape[1]                      # neighbourhood width

    zs   = torch.zeros(n_items, s_max, n_nbrs,    dtype=torch.int32)
    pos  = torch.zeros(n_items, s_max, n_nbrs, 3, dtype=torch.float32)
    ys   = torch.full((n_items, s_max), float("nan"), dtype=torch.float32)
    mask = torch.zeros(n_items, s_max, dtype=torch.bool)

    for b, (z, pos_b, y) in enumerate(batch):
        n_sites = z.shape[0]
        zs  [b, :n_sites] = z
        pos [b, :n_sites] = pos_b
        ys  [b, :n_sites] = y
        mask[b, :n_sites] = True

    return zs, pos, ys, mask

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
# Seed NumPy so the shuffle below is repeatable.
np.random.seed(0)
# glob returns files in a filesystem-dependent order; sorting first makes the
# seeded shuffle produce the same train/val split on every machine.
all_paths = sorted(glob.glob("../../../data/pkegnn_INS/inputs/*.npz"))
np.random.shuffle(all_paths)
# First 20 shuffled files are used for training, the next 10 for validation.
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# Only the training loader reshuffles every epoch.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



# ---------------------------------------------------------------------
# 4) model pieces ------------------------------------------------------
# ---------------------------------------------------------------------
# Site-level encoder applied to every atom neighbourhood.
egnn_net = StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                       dropout=dropout, num_positions=1000, num_tokens=118,
                       num_nearest_neighbors=num_neighbors,
                       norm_coors=True,
                       num_edge_tokens=num_edge_tokens,
                       num_global_tokens=num_global_tokens).to(device)

# Radial-basis layer and attention block; the attention width is dim + basis.
rbf_layer = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
mha_layer = AttentionBlock(embed_dim=dim + basis,
                           num_heads=num_heads,
                           hidden_dim=hidden_dim).to(device)
# Linear read-out: (dim + basis) features -> one scalar per site.
pred_head = nn.Linear(dim + basis, 1).to(device)

# Protein-level EGNN, called in the training loop on tensors that live on
# `device`.  It therefore has to be moved there like every other module;
# otherwise the call fails with a device mismatch on CUDA.
protein_egnn = EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(
    list(egnn_net.parameters()) +
    list(rbf_layer.parameters()) +
    list(mha_layer.parameters()) +
    list(pred_head.parameters()) +
    list(protein_egnn.parameters()),
    lr=5e-3
)

epochs = 10  # number of passes over the training set

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    """
    Embed a flat batch of site neighbourhoods.

    Parameters
    ----------
    z_r : atomic numbers; the caller passes (R, N) int32.
    x_r : coordinates;    the caller passes (R, N, 3) float32.

    Returns
    -------
    tuple ``(tok, centroids)``
        tok       : attention output max-pooled over dim 1 after the final
                    permute - presumably (R, dim + basis); TODO confirm
                    against AttentionBlock.
        centroids : mean over dim 1 of the coordinates returned by the EGNN,
                    kept as a size-1 axis and then permuted with (1, 0, 2),
                    i.e. the singleton axis comes first.
    """
    # EGNN pass; the feature output may arrive wrapped in a list/tuple.
    egnn_feats, egnn_xyz = egnn_net(z_r, x_r)
    if isinstance(egnn_feats, (list, tuple)):
        node_h = egnn_feats[0]
    else:
        node_h = egnn_feats

    # Radial features between each neighbourhood centroid and the EGNN
    # *output* coordinates (not the raw input x_r).
    centroids = egnn_xyz.mean(dim=1).unsqueeze(1)
    radial = rbf_layer(centroids, egnn_xyz)

    # Swap dims 1 and 2 of both tensors, then concatenate along dim 1 with the
    # radial channels in front of the EGNN channels.
    stacked = torch.cat((radial.transpose(1, 2), node_h.transpose(1, 2)), dim=1)

    # The attention block receives dim 2 of `stacked` as its leading axis; its
    # second return value is discarded.
    attended, _ = mha_layer(stacked.permute(2, 0, 1))
    pooled = attended.permute(1, 0, 2).max(dim=1).values
    return pooled, centroids.permute(1, 0, 2)

# ---------------------------------------------------------------------
# 6) training / validation loop ---------------------------------------
# ---------------------------------------------------------------------
def _batch_loss(z, x, y, mask):
    """
    Compute the criterion for one padded batch.

    z: (B, S, N), x: (B, S, N, 3), y: (B, S), mask: (B, S).  Padding rows are
    removed with the mask before the model is applied, so the padded NaN
    targets never reach the loss.
    """
    keep  = mask.view(-1)
    z_res = z.view(-1, z.size(2))[keep].to(device)
    x_res = x.view(-1, x.size(2), 3)[keep].to(device)
    y_res = y.view(-1)[keep].to(device)

    feats, centroids = forward_residues(z_res, x_res)
    preds = pred_head(feats)
    # Protein-level refinement: the per-site predictions get a leading axis of
    # size 1 and are passed through protein_egnn together with the centroids;
    # element [0] of its return value is used as the refined prediction.
    preds = protein_egnn(preds.unsqueeze(0), centroids)[0]
    return criterion(preds.flatten(), y_res)


for epoch in range(epochs):
    # ======== TRAIN ========
    egnn_net.train(); rbf_layer.train(); mha_layer.train(); pred_head.train()
    protein_egnn.train()   # this module takes part in the forward pass too
    tr_losses = []

    for z, x, y, mask in train_loader:                 # z:(B,S,N)  mask:(B,S)
        optimizer.zero_grad()
        loss = _batch_loss(z, x, y, mask)

        # NOTE(review): the forward pass is not wrapped in autocast, so on CUDA
        # the GradScaler only scales the loss; confirm whether AMP is intended.
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); optimizer.step()

        tr_losses.append(loss.item())

    # The criterion is nn.MSELoss, so the reported value is an MSE (not L1).
    print(f"Epoch {epoch:3d} | train MSE = {np.mean(tr_losses):.4f}")

    # ======== VALID ========
    egnn_net.eval(); rbf_layer.eval(); mha_layer.eval(); pred_head.eval()
    protein_egnn.eval()
    vl_losses = []

    with torch.no_grad():
        for z, x, y, mask in val_loader:
            vl_losses.append(_batch_loss(z, x, y, mask).item())

    print(f"              |  val MSE = {np.mean(vl_losses):.4f}")
print(time.time() - t0,"sec")


Epoch   0 | train L1 = 3.5549
              |  val L1 = 3.8103
Epoch   1 | train L1 = 3.2875
              |  val L1 = 3.4910
Epoch   2 | train L1 = 2.9995
              |  val L1 = 3.0372
Epoch   3 | train L1 = 2.6572
              |  val L1 = 2.6012
Epoch   4 | train L1 = 2.4278
              |  val L1 = 2.4152
Epoch   5 | train L1 = 2.2994
              |  val L1 = 2.2199
Epoch   6 | train L1 = 2.2000
              |  val L1 = 2.0956
Epoch   7 | train L1 = 2.1452
              |  val L1 = 2.0702
Epoch   8 | train L1 = 2.0996
              |  val L1 = 1.9954
Epoch   9 | train L1 = 2.0313
              |  val L1 = 2.0054
16.874836444854736 sec


In [25]:
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
# Wall-clock reference for the elapsed-time report printed after training.
t0 = time.time()

# Data-pipeline constants.
N_NEIGHBORS = 15                          # default neighbourhood size used by the dataset
BATCH_SIZE = 1                            # proteins per DataLoader batch
PIN_MEMORY = torch.cuda.is_available()    # pin host memory only when CUDA is present

# Fixed torch seed, then pick the compute device.
torch.manual_seed(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mixed precision is enabled only on a GPU.  On CPU a no-op context manager
# replaces `autocast` and no gradient scaler is created.
use_amp = device.type == "cuda"
scaler = GradScaler() if use_amp else None
if not use_amp:
    class DummyCM:
        """No-op context manager standing in for ``autocast`` on CPU."""
        def __enter__(self):
            return None

        def __exit__(self, *args):
            return None

    autocast = DummyCM

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    """
    Build the three trainable sub-modules and move each one to ``device``.

    The EGNN stack is configured entirely from the arguments.  The attention
    block and the RBF layer additionally read the module-level names
    ``basis``, ``num_heads`` and ``cutoff`` at call time, so those must be
    defined before this function is invoked.

    Returns
    -------
    tuple
        ``(net, mha, RBF)`` - a StackedEGNN, an AttentionBlock and a
        LearnableRBF instance, in that order.
    """
    egnn_kwargs = dict(
        dim=dim,
        depth=depth,
        hidden_dim=hidden_dim,
        dropout=dropout,
        num_positions=1000,
        num_tokens=118,
        num_nearest_neighbors=num_neighbors,
        norm_coors=True,
        num_edge_tokens=num_edge_tokens,
        num_global_tokens=num_global_tokens,
    )
    # Construction order (EGNN, attention, RBF) is kept so seeded
    # parameter initialisation stays the same.
    egnn_stack = StackedEGNN(**egnn_kwargs).to(device)
    attention = AttentionBlock(
        embed_dim=dim + basis, num_heads=num_heads, hidden_dim=hidden_dim
    ).to(device)
    radial = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
    return egnn_stack, attention, radial
#net,mha,RBF=init_model
# 3) instantiate everything
# Model hyper-parameters (author note: scale dim/basis to at least 3, 16;
# dim must be divisible by 2).
dim = 6                      # `dim` passed to StackedEGNN
basis = 6                    # `num_basis` passed to LearnableRBF
depth = 2                    # author note: keep at 2 or more
hidden_dim = 4
num_heads = dim + basis      # same value as the attention embed_dim
num_edge_tokens = 256
num_global_tokens = 256
dropout = 0.02
cutoff = 10.0
num_neighbors = 2

# Timestamp string identifying this run (YYYYMMDD_HHMMSS).
runid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, builds a fixed-size atom neighbourhood around every
    site (anchor) and keeps the resulting tensors entirely in RAM.

    Each stored item is a tuple ``(z_hood, pos_hood, y)``.  For a protein
    with S sites the shapes are
        z_hood   : (S, n_neighbors)      int32
        pos_hood : (S, n_neighbors, 3)   float32
        y        : (S,)                  float32

    Parameters
    ----------
    paths       : iterable of .npz paths; each file must provide the arrays
                  "z", "pos", "sites" and "pks".
    n_neighbors : number of nearest atoms gathered around each site.
    pin_memory  : if True, every stored tensor is copied to pinned memory.

    Files with no sites are skipped silently.  Any file that raises while
    being read or processed is reported with a "skipping ..." message and
    left out of the dataset (best-effort loading).
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                # NOTE(security): allow_pickle=True lets np.load unpickle object
                # arrays, which can run arbitrary code - only load trusted files.
                # The context manager closes the archive once the arrays have
                # been copied out, so no file handle is leaked per protein.
                with np.load(p, allow_pickle=True) as dat:
                    z_all   = dat["z"].astype(np.int32)        # (N,)
                    pos_all = dat["pos"].astype(np.float32)    # (N,3)
                    sites   = dat["sites"].astype(np.float32)  # (S,3)
                    y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                # One label per site is required; catching a mismatch here
                # avoids a shape error much later inside the collate function.
                if len(y) != len(sites):
                    raise ValueError(
                        f"{len(y)} labels for {len(sites)} sites"
                    )

                # Indices of the n_neighbors atoms closest to every site.
                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, n_neighbors)

                z_hood   = torch.from_numpy(z_all[idx])            # (S, n_neighbors)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S, n_neighbors, 3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                # Deliberate best-effort: report the file and carry on.
                print(f"skipping {p}: {e}")

    def __len__(self):
        """Number of proteins that were loaded successfully."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(z_hood, pos_hood, y)`` tuple of protein ``idx``."""
        return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pad the variable-length site dimension so a list of proteins can be
    stacked into dense tensors.

    Parameters
    ----------
    batch : list of ``(z, pos, y)`` tuples with shapes (S_i, N), (S_i, N, 3)
            and (S_i,).  N is read from the first item, so it follows
            whatever neighbourhood size the dataset was built with.

    Returns
    -------
    zs   : (B, S_max, N)     int32   - zero padded
    pos  : (B, S_max, N, 3)  float32 - zero padded
    ys   : (B, S_max)        float32 - NaN padded
    mask : (B, S_max)        bool    - True for real sites, False for padding

    The outputs are allocated on the CPU.  Only CPU tensors can be pinned, so
    this keeps the function compatible with ``DataLoader(pin_memory=True)``;
    the caller is responsible for moving the batch to the compute device.
    """
    n_items = len(batch)
    s_max   = max(item[0].shape[0] for item in batch)   # longest protein
    n_nbrs  = batch[0][0].shape[1]                      # neighbourhood width

    zs   = torch.zeros(n_items, s_max, n_nbrs,    dtype=torch.int32)
    pos  = torch.zeros(n_items, s_max, n_nbrs, 3, dtype=torch.float32)
    ys   = torch.full((n_items, s_max), float("nan"), dtype=torch.float32)
    mask = torch.zeros(n_items, s_max, dtype=torch.bool)

    for b, (z, pos_b, y) in enumerate(batch):
        n_sites = z.shape[0]
        zs  [b, :n_sites] = z
        pos [b, :n_sites] = pos_b
        ys  [b, :n_sites] = y
        mask[b, :n_sites] = True

    return zs, pos, ys, mask

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
# Seed NumPy so the shuffle below is repeatable.
np.random.seed(0)
# glob returns files in a filesystem-dependent order; sorting first makes the
# seeded shuffle produce the same train/val split on every machine.
all_paths = sorted(glob.glob("../../../data/pkegnn_INS/inputs/*.npz"))
np.random.shuffle(all_paths)
# First 20 shuffled files are used for training, the next 10 for validation.
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# Only the training loader reshuffles every epoch.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



# ---------------------------------------------------------------------
# 4) model pieces ------------------------------------------------------
# ---------------------------------------------------------------------
# Site-level encoder applied to every atom neighbourhood.
egnn_net = StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                       dropout=dropout, num_positions=1000, num_tokens=118,
                       num_nearest_neighbors=num_neighbors,
                       norm_coors=True,
                       num_edge_tokens=num_edge_tokens,
                       num_global_tokens=num_global_tokens).to(device)

# Radial-basis layer and attention block; the attention width is dim + basis.
rbf_layer = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
mha_layer = AttentionBlock(embed_dim=dim + basis,
                           num_heads=num_heads,
                           hidden_dim=hidden_dim).to(device)
# Linear read-out: (dim + basis) features -> one scalar per site.
pred_head = nn.Linear(dim + basis, 1).to(device)

# Protein-level EGNN.  Like every other module it has to live on `device`;
# otherwise calling it with CUDA tensors fails with a device mismatch.
protein_egnn = EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(
    list(egnn_net.parameters()) +
    list(rbf_layer.parameters()) +
    list(mha_layer.parameters()) +
    list(pred_head.parameters()) +
    list(protein_egnn.parameters()),
    lr=5e-3
)

epochs = 10  # number of passes over the training set

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    """
    Embed every residue neighbourhood and pool it into one feature vector.

    z_r : (R, N)       int32   – atomic numbers for R residues
    x_r : (R, N, 3)    float32 – coordinates

    Steps: ``egnn_net`` → RBF features between each neighbourhood centroid
    and the EGNN *output* coordinates → channel-wise concatenation with the
    node features → ``mha_layer`` → max over the N neighbours.

    Returns the tuple ``(tok, centroids)``:
        tok       : pooled per-residue embedding, (R, dim + basis)
                    (assumes rbf_layer returns (R, N, basis) — TODO confirm)
        centroids : mean output coordinate of each neighbourhood, permuted
                    from (R, 1, 3) to (1, R, 3)
    """
    # ---------- EGNN ----------
    node_feats, coords = egnn_net(z_r, x_r)
    if isinstance(node_feats, (list, tuple)):       # some variants wrap the features
        node_feats = node_feats[0]

    # ---------- RBF between centroid and EGNN output coordinates ----------
    centroids = coords.mean(dim=1, keepdim=True)
    radial = rbf_layer(centroids, coords)

    # ---------- concat along channels, attend over neighbours ----------
    channels_first = torch.cat(
        (radial.transpose(1, 2), node_feats.transpose(1, 2)), dim=1
    )
    attended, _ = mha_layer(channels_first.permute(2, 0, 1))   # neighbours first

    # ---------- max-pool over the neighbour axis ----------
    pooled = attended.permute(1, 0, 2).max(dim=1).values
    return pooled, centroids.permute(1, 0, 2)

# ---------------------------------------------------------------------
# 6) training / validation loop ---------------------------------------
# ---------------------------------------------------------------------
# Per epoch: one optimisation pass over train_loader, then a gradient-free
# pass over val_loader.  The reported number is the mean per-batch loss of
# `criterion` (nn.MSELoss), hence the "MSE" label.
for epoch in range(epochs):
    # ======== TRAIN ========
    # protein_egnn is optimised together with the other modules, so it is
    # switched between train/eval mode alongside them.
    egnn_net.train(); rbf_layer.train(); mha_layer.train(); pred_head.train()
    protein_egnn.train()
    tr_losses = []

    for z, x, y, mask in train_loader:                 # z:(B,S,N)  mask:(B,S)
        # compress away padding →  (R, N), (R, N, 3), (R,)
        valid      = mask.view(-1)
        z_res      = z.view(-1, z.size(2))[valid].to(device)
        x_res      = x.view(-1, x.size(2), 3)[valid].to(device)
        y_res      = y.view(-1)[valid].to(device)

        optimizer.zero_grad()

        # per-residue features → scalar prediction → residue-level refinement
        feats, centroids = forward_residues(z_res, x_res)         # (R, C)

        preds = pred_head(feats)
        t = preds.unsqueeze(0)                  # all residues form one graph
        preds = protein_egnn(t, centroids)[0]   # keep the first output (presumably the features)

        loss  = criterion(preds.flatten(), y_res)

        # The scaler only rescales the loss; the forward pass above is not
        # wrapped in autocast.
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); optimizer.step()

        tr_losses.append(loss.item())

    print(f"Epoch {epoch:3d} | train MSE = {np.mean(tr_losses):.4f}")


    # ======== VALID ========
    egnn_net.eval(); rbf_layer.eval(); mha_layer.eval(); pred_head.eval()
    protein_egnn.eval()
    vl_losses = []

    with torch.no_grad():
        for z, x, y, mask in val_loader:
            valid   = mask.view(-1)
            z_res   = z.view(-1, z.size(2))[valid].to(device)
            x_res   = x.view(-1, x.size(2), 3)[valid].to(device)
            y_res   = y.view(-1)[valid].to(device)

            # same forward path as in training
            feats, centroids = forward_residues(z_res, x_res)         # (R, C)

            preds = pred_head(feats)
            t = preds.unsqueeze(0)
            preds = protein_egnn(t, centroids)[0]

            loss  = criterion(preds.flatten(), y_res)
            vl_losses.append(loss.item())

    print(f"              |  val MSE = {np.mean(vl_losses):.4f}")
print(time.time() - t0,"sec")


Epoch   0 | train L1 = 3.6461
              |  val L1 = 3.7602
Epoch   1 | train L1 = 3.3495
              |  val L1 = 3.6736
Epoch   2 | train L1 = 3.1531
              |  val L1 = 3.4099
Epoch   3 | train L1 = 2.6844
              |  val L1 = 2.5964
Epoch   4 | train L1 = 2.3363
              |  val L1 = 2.3687
Epoch   5 | train L1 = 2.1052
              |  val L1 = 2.2530
Epoch   6 | train L1 = 2.1028
              |  val L1 = 2.0454
Epoch   7 | train L1 = 1.9126
              |  val L1 = 2.0342
Epoch   8 | train L1 = 1.8397
              |  val L1 = 1.9076
Epoch   9 | train L1 = 1.8524
              |  val L1 = 1.7860
17.69021701812744 sec


ablation

Ablation run: RBF features with dim = 12, basis = 8, hidden_dim = 3, and no protein-level EGNN (pegnn).

In [26]:
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
# 0) start timer
t0 = time.time()                      # reference point for the runtime printed at the end

# neighbourhood size and batching
N_NEIGHBORS, BATCH_SIZE = 100, 1      # atoms per site; proteins per batch (author note: now safe to increase)
PIN_MEMORY = torch.cuda.is_available()

# reproducibility + device
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mixed precision is enabled only on GPU.  On CPU `autocast` is replaced by a
# do-nothing context manager and no gradient scaler is created.
use_amp = device.type == "cuda"
scaler = GradScaler() if use_amp else None
if not use_amp:
    class DummyCM:
        """Context manager that does nothing on enter and exit."""
        def __enter__(self):
            return None

        def __exit__(self, *args):
            return None

    autocast = DummyCM

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    """
    Build the EGNN stack, the attention block and the RBF layer on ``device``.

    Besides its arguments the function reads the module-level names
    ``basis``, ``num_heads``, ``cutoff`` and ``device``; they must be defined
    by the time it is called.

    Returns
    -------
    tuple
        ``(net, mha, RBF)`` – the StackedEGNN, AttentionBlock and
        LearnableRBF instances, in that order.
    """
    egnn_kwargs = dict(
        dim=dim, depth=depth, hidden_dim=hidden_dim,
        dropout=dropout,
        num_positions=1000, num_tokens=118,
        num_nearest_neighbors=num_neighbors,
        norm_coors=True,
        num_edge_tokens=num_edge_tokens,
        num_global_tokens=num_global_tokens,
    )
    net = StackedEGNN(**egnn_kwargs).to(device)

    # attention works on the concatenated (dim + basis)-wide tokens
    mha = AttentionBlock(
        embed_dim=dim + basis, num_heads=num_heads, hidden_dim=hidden_dim
    ).to(device)

    RBF = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
    return net, mha, RBF
#net,mha,RBF=init_model
# 3) instantiate everything
# Feature widths: EGNN node embedding and number of radial basis functions.
# Author notes: scale to 3,16 at least; dim must be divisible by 2.
dim = 12
basis = 8
depth = 2                    # author note: scale to 2, at least
hidden_dim = 3
num_heads = dim + basis      # equals the attention embed width (dim + basis)

# vocabulary sizes handed to the EGNN stack
num_edge_tokens = 256
num_global_tokens = 256

dropout = 0.02
cutoff = 10.0                # cutoff passed to the RBF layer
num_neighbors = 2            # nearest neighbours used inside the EGNN


# timestamp identifying this run, e.g. 20250101_120000
runid = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, constructs fixed-size neighbourhoods around each
    site (anchor) and stores the result entirely in RAM.

    Every archive is read for the keys ``z``, ``pos``, ``sites`` and ``pks``.
    For a protein with S sites the stored item is the tuple

        z   : (S, n_neighbors)      int32
        pos : (S, n_neighbors, 3)   float32
        y   : (S,)                  float32

    Parameters
    ----------
    paths : iterable of str
        Archive files to load; one dataset item is created per usable file.
    n_neighbors : int
        Number of nearest atoms gathered around every site.
    pin_memory : bool
        If true, the tensors of each item are moved to pinned host memory.

    Files without any site are skipped silently.  Files that raise while
    being processed (e.g. a missing key, a failing neighbour query or a
    site/target count mismatch) are reported on stdout and skipped.
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                # NOTE(review): allow_pickle=True lets np.load unpickle object
                # arrays, which can execute arbitrary code – only load trusted
                # files.  The context manager closes the archive's file handle
                # as soon as the arrays have been copied out.
                with np.load(p, allow_pickle=True) as dat:
                    z_all   = dat["z"].astype(np.int32)        # (N,)
                    pos_all = dat["pos"].astype(np.float32)    # (N,3)
                    sites   = dat["sites"].astype(np.float32)  # (S,3)
                    y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                # Catch inconsistent archives here instead of failing later
                # while a batch is being padded.
                if len(y) != len(sites):
                    raise ValueError(
                        f"{len(sites)} sites but {len(y)} target values"
                    )

                # indices of the n_neighbors atoms closest to every site
                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, n_neighbors)

                z_hood   = torch.from_numpy(z_all[idx])            # (S, n_neighbors)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S, n_neighbors, 3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                # best effort: one bad file must not abort the whole load
                print(f"skipping {p}: {e}")

    def __len__(self):             return len(self.data)
    def __getitem__(self, idx):    return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pads the variable-length site dimension so the batch can be stacked
    into one tensor.  A boolean mask keeps track of which elements are
    real data (True) vs. padding (False).

    Parameters
    ----------
    batch : list of (z, pos, y)
        Per-protein tensors of shape (S, N), (S, N, 3) and (S,), where S may
        differ between proteins and N is the neighbourhood size.

    Returns
    -------
    zs   : (B, S_max, N)     int32   – zero padded
    pos  : (B, S_max, N, 3)  float32 – zero padded
    ys   : (B, S_max)        float32 – NaN padded
    mask : (B, S_max)        bool    – True where a real site is stored

    The tensors are created on the CPU: a DataLoader running with
    ``pin_memory=True`` can only pin CPU tensors, and the consumer moves the
    batch to the compute device itself.  The neighbourhood size is read from
    the data rather than from a global, so it always matches the dataset.
    """
    # batch = list[(z,pos,y), ...]         len = B
    B      = len(batch)
    S_max  = max(item[0].shape[0] for item in batch)   # longest protein
    n_hood = batch[0][0].shape[1]                      # atoms per neighbourhood

    zs   = torch.zeros(B, S_max, n_hood,    dtype=torch.int32)
    pos  = torch.zeros(B, S_max, n_hood, 3, dtype=torch.float32)
    ys   = torch.full((B, S_max), float("nan"), dtype=torch.float32)
    mask = torch.zeros(B, S_max, dtype=torch.bool)

    for b, (z, pos_b, y) in enumerate(batch):
        S = z.shape[0]
        zs  [b, :S] = z
        pos [b, :S] = pos_b
        ys  [b, :S] = y
        mask[b, :S] = True

    return zs, pos, ys, mask

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
np.random.seed(0)
# glob returns files in arbitrary, filesystem-dependent order; sorting before
# the seeded shuffle makes the train/validation split reproducible.
all_paths = sorted(glob.glob("../../../data/pkegnn_INS/inputs/*.npz"))
np.random.shuffle(all_paths)
# first 20 shuffled files train the model, the next 10 validate it
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# Batches are padded per protein by pad_collate; only the training loader
# reshuffles between epochs.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



# ---------------------------------------------------------------------
# 4) model pieces ------------------------------------------------------
# ---------------------------------------------------------------------
# Atom-level EGNN applied to every residue neighbourhood.
egnn_net = StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                       dropout=dropout, num_positions=1000, num_tokens=118,
                       num_nearest_neighbors=num_neighbors,
                       norm_coors=True,
                       num_edge_tokens=num_edge_tokens,
                       num_global_tokens=num_global_tokens).to(device)

# Radial-basis features and attention over (dim + basis)-wide tokens,
# followed by a linear head producing one scalar per residue.
rbf_layer = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
mha_layer = AttentionBlock(embed_dim=dim + basis,
                           num_heads=num_heads,
                           hidden_dim=hidden_dim).to(device)
pred_head = nn.Linear(dim + basis, 1).to(device)

# Residue-level EGNN.  It is created on the compute device so that it can be
# applied to tensors living there; its parameters are handed to the optimizer
# whether or not the training loop calls it.
protein_egnn = EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(
    list(egnn_net.parameters()) +
    list(rbf_layer.parameters()) +
    list(mha_layer.parameters()) +
    list(pred_head.parameters()) +
    list(protein_egnn.parameters()),
    lr=5e-3
)

epochs = 10  # number of passes over the training set

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    """
    Embed every residue neighbourhood and pool it into one feature vector.

    z_r : (R, N)       int32   – atomic numbers for R residues
    x_r : (R, N, 3)    float32 – coordinates

    Steps: ``egnn_net`` → RBF features between each neighbourhood centroid
    and the EGNN *output* coordinates → channel-wise concatenation with the
    node features → ``mha_layer`` → max over the N neighbours.

    Returns the tuple ``(tok, centroids)``:
        tok       : pooled per-residue embedding, (R, dim + basis)
                    (assumes rbf_layer returns (R, N, basis) — TODO confirm)
        centroids : mean output coordinate of each neighbourhood, (R, 1, 3)
    """
    # ---------- EGNN ----------
    node_feats, coords = egnn_net(z_r, x_r)
    if isinstance(node_feats, (list, tuple)):       # some variants wrap the features
        node_feats = node_feats[0]

    # ---------- RBF between centroid and EGNN output coordinates ----------
    centroids = coords.mean(dim=1, keepdim=True)
    radial = rbf_layer(centroids, coords)

    # ---------- concat along channels, attend over neighbours ----------
    channels_first = torch.cat(
        (radial.transpose(1, 2), node_feats.transpose(1, 2)), dim=1
    )
    attended, _ = mha_layer(channels_first.permute(2, 0, 1))   # neighbours first

    # ---------- max-pool over the neighbour axis ----------
    pooled = attended.permute(1, 0, 2).max(dim=1).values
    return pooled, centroids

# ---------------------------------------------------------------------
# 6) training / validation loop ---------------------------------------
# ---------------------------------------------------------------------
# Per epoch: one optimisation pass over train_loader, then a gradient-free
# pass over val_loader.  The reported number is the mean per-batch loss of
# `criterion` (nn.MSELoss), hence the "MSE" label.
# Ablation: the residue-level protein_egnn refinement is NOT applied here;
# the linear head's output is the final prediction.
for epoch in range(epochs):
    # ======== TRAIN ========
    egnn_net.train(); rbf_layer.train(); mha_layer.train(); pred_head.train()
    tr_losses = []

    for z, x, y, mask in train_loader:                 # z:(B,S,N)  mask:(B,S)
        # compress away padding →  (R, N), (R, N, 3), (R,)
        valid      = mask.view(-1)
        z_res      = z.view(-1, z.size(2))[valid].to(device)
        x_res      = x.view(-1, x.size(2), 3)[valid].to(device)
        y_res      = y.view(-1)[valid].to(device)

        optimizer.zero_grad()

        # per-residue features → scalar prediction (centroids are unused here)
        feats, centroids = forward_residues(z_res, x_res)         # (R, C)
        preds = pred_head(feats)

        loss  = criterion(preds.flatten(), y_res)

        # The scaler only rescales the loss; the forward pass above is not
        # wrapped in autocast.
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); optimizer.step()

        tr_losses.append(loss.item())

    print(f"Epoch {epoch:3d} | train MSE = {np.mean(tr_losses):.4f}")


    # ======== VALID ========
    egnn_net.eval(); rbf_layer.eval(); mha_layer.eval(); pred_head.eval()
    vl_losses = []

    with torch.no_grad():
        for z, x, y, mask in val_loader:
            valid   = mask.view(-1)
            z_res   = z.view(-1, z.size(2))[valid].to(device)
            x_res   = x.view(-1, x.size(2), 3)[valid].to(device)
            y_res   = y.view(-1)[valid].to(device)

            # same forward path as in training
            feats, centroids = forward_residues(z_res, x_res)         # (R, C)
            preds = pred_head(feats)

            loss  = criterion(preds.flatten(), y_res)
            vl_losses.append(loss.item())

    print(f"              |  val MSE = {np.mean(vl_losses):.4f}")
print(time.time() - t0,"sec")


Epoch   0 | train L1 = 3.4655
              |  val L1 = 3.7604
Epoch   1 | train L1 = 3.3019
              |  val L1 = 3.5398
Epoch   2 | train L1 = 3.0392
              |  val L1 = 3.4065
Epoch   3 | train L1 = 2.7296
              |  val L1 = 2.8663
Epoch   4 | train L1 = 2.6993
              |  val L1 = 2.8542
Epoch   5 | train L1 = 2.4855
              |  val L1 = 2.6787
Epoch   6 | train L1 = 2.1847
              |  val L1 = 2.0534
Epoch   7 | train L1 = 1.9407
              |  val L1 = 1.8874
Epoch   8 | train L1 = 1.8346
              |  val L1 = 1.9031
Epoch   9 | train L1 = 1.6756
              |  val L1 = 1.8190
178.06058812141418 sec


Now with the protein-level EGNN (pegnn) and RBF features: dim = 12, basis = 8, hidden_dim = 3.

In [27]:
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
# 0) start timer
t0 = time.time()                      # reference point for the runtime printed at the end

# neighbourhood size and batching
N_NEIGHBORS, BATCH_SIZE = 100, 1      # atoms per site; proteins per batch (author note: now safe to increase)
PIN_MEMORY = torch.cuda.is_available()

# reproducibility + device
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mixed precision is enabled only on GPU.  On CPU `autocast` is replaced by a
# do-nothing context manager and no gradient scaler is created.
use_amp = device.type == "cuda"
scaler = GradScaler() if use_amp else None
if not use_amp:
    class DummyCM:
        """Context manager that does nothing on enter and exit."""
        def __enter__(self):
            return None

        def __exit__(self, *args):
            return None

    autocast = DummyCM

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    """
    Build the EGNN stack, the attention block and the RBF layer on ``device``.

    Besides its arguments the function reads the module-level names
    ``basis``, ``num_heads``, ``cutoff`` and ``device``; they must be defined
    by the time it is called.

    Returns
    -------
    tuple
        ``(net, mha, RBF)`` – the StackedEGNN, AttentionBlock and
        LearnableRBF instances, in that order.
    """
    egnn_kwargs = dict(
        dim=dim, depth=depth, hidden_dim=hidden_dim,
        dropout=dropout,
        num_positions=1000, num_tokens=118,
        num_nearest_neighbors=num_neighbors,
        norm_coors=True,
        num_edge_tokens=num_edge_tokens,
        num_global_tokens=num_global_tokens,
    )
    net = StackedEGNN(**egnn_kwargs).to(device)

    # attention works on the concatenated (dim + basis)-wide tokens
    mha = AttentionBlock(
        embed_dim=dim + basis, num_heads=num_heads, hidden_dim=hidden_dim
    ).to(device)

    RBF = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
    return net, mha, RBF
#net,mha,RBF=init_model
# 3) instantiate everything
# Feature widths: EGNN node embedding and number of radial basis functions.
# Author notes: scale to 3,16 at least; dim must be divisible by 2.
dim = 12
basis = 8
depth = 2                    # author note: scale to 2, at least
hidden_dim = 3
num_heads = dim + basis      # equals the attention embed width (dim + basis)

# vocabulary sizes handed to the EGNN stack
num_edge_tokens = 256
num_global_tokens = 256

dropout = 0.02
cutoff = 10.0                # cutoff passed to the RBF layer
num_neighbors = 2            # nearest neighbours used inside the EGNN


# timestamp identifying this run, e.g. 20250101_120000
runid = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, constructs fixed-size neighbourhoods around each
    site (anchor) and stores the result entirely in RAM.

    Every archive is read for the keys ``z``, ``pos``, ``sites`` and ``pks``.
    For a protein with S sites the stored item is the tuple

        z   : (S, n_neighbors)      int32
        pos : (S, n_neighbors, 3)   float32
        y   : (S,)                  float32

    Parameters
    ----------
    paths : iterable of str
        Archive files to load; one dataset item is created per usable file.
    n_neighbors : int
        Number of nearest atoms gathered around every site.
    pin_memory : bool
        If true, the tensors of each item are moved to pinned host memory.

    Files without any site are skipped silently.  Files that raise while
    being processed (e.g. a missing key, a failing neighbour query or a
    site/target count mismatch) are reported on stdout and skipped.
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                # NOTE(review): allow_pickle=True lets np.load unpickle object
                # arrays, which can execute arbitrary code – only load trusted
                # files.  The context manager closes the archive's file handle
                # as soon as the arrays have been copied out.
                with np.load(p, allow_pickle=True) as dat:
                    z_all   = dat["z"].astype(np.int32)        # (N,)
                    pos_all = dat["pos"].astype(np.float32)    # (N,3)
                    sites   = dat["sites"].astype(np.float32)  # (S,3)
                    y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                # Catch inconsistent archives here instead of failing later
                # while a batch is being padded.
                if len(y) != len(sites):
                    raise ValueError(
                        f"{len(sites)} sites but {len(y)} target values"
                    )

                # indices of the n_neighbors atoms closest to every site
                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, n_neighbors)

                z_hood   = torch.from_numpy(z_all[idx])            # (S, n_neighbors)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S, n_neighbors, 3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                # best effort: one bad file must not abort the whole load
                print(f"skipping {p}: {e}")

    def __len__(self):             return len(self.data)
    def __getitem__(self, idx):    return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pads the variable-length site dimension so the batch can be stacked
    into one tensor.  A boolean mask keeps track of which elements are
    real data (True) vs. padding (False).

    Parameters
    ----------
    batch : list of (z, pos, y)
        Per-protein tensors of shape (S, N), (S, N, 3) and (S,), where S may
        differ between proteins and N is the neighbourhood size.

    Returns
    -------
    zs   : (B, S_max, N)     int32   – zero padded
    pos  : (B, S_max, N, 3)  float32 – zero padded
    ys   : (B, S_max)        float32 – NaN padded
    mask : (B, S_max)        bool    – True where a real site is stored

    The tensors are created on the CPU: a DataLoader running with
    ``pin_memory=True`` can only pin CPU tensors, and the consumer moves the
    batch to the compute device itself.  The neighbourhood size is read from
    the data rather than from a global, so it always matches the dataset.
    """
    # batch = list[(z,pos,y), ...]         len = B
    B      = len(batch)
    S_max  = max(item[0].shape[0] for item in batch)   # longest protein
    n_hood = batch[0][0].shape[1]                      # atoms per neighbourhood

    zs   = torch.zeros(B, S_max, n_hood,    dtype=torch.int32)
    pos  = torch.zeros(B, S_max, n_hood, 3, dtype=torch.float32)
    ys   = torch.full((B, S_max), float("nan"), dtype=torch.float32)
    mask = torch.zeros(B, S_max, dtype=torch.bool)

    for b, (z, pos_b, y) in enumerate(batch):
        S = z.shape[0]
        zs  [b, :S] = z
        pos [b, :S] = pos_b
        ys  [b, :S] = y
        mask[b, :S] = True

    return zs, pos, ys, mask

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
np.random.seed(0)
# glob returns files in arbitrary, filesystem-dependent order; sorting before
# the seeded shuffle makes the train/validation split reproducible.
all_paths = sorted(glob.glob("../../../data/pkegnn_INS/inputs/*.npz"))
np.random.shuffle(all_paths)
# first 20 shuffled files train the model, the next 10 validate it
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# Batches are padded per protein by pad_collate; only the training loader
# reshuffles between epochs.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



# ---------------------------------------------------------------------
# 4) model pieces ------------------------------------------------------
# ---------------------------------------------------------------------
# Atom-level EGNN applied to every residue neighbourhood.
egnn_net = StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                       dropout=dropout, num_positions=1000, num_tokens=118,
                       num_nearest_neighbors=num_neighbors,
                       norm_coors=True,
                       num_edge_tokens=num_edge_tokens,
                       num_global_tokens=num_global_tokens).to(device)

# Radial-basis features and attention over (dim + basis)-wide tokens,
# followed by a linear head producing one scalar per residue.
rbf_layer = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
mha_layer = AttentionBlock(embed_dim=dim + basis,
                           num_heads=num_heads,
                           hidden_dim=hidden_dim).to(device)
pred_head = nn.Linear(dim + basis, 1).to(device)

# Residue-level EGNN that refines the scalar predictions.  It has to live on
# the same device as its inputs, otherwise the forward pass fails on CUDA.
protein_egnn = EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(
    list(egnn_net.parameters()) +
    list(rbf_layer.parameters()) +
    list(mha_layer.parameters()) +
    list(pred_head.parameters()) +
    list(protein_egnn.parameters()),
    lr=5e-3
)

epochs = 10  # number of passes over the training set

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    """
    Embed every residue neighbourhood and pool it into one feature vector.

    z_r : (R, N)       int32   – atomic numbers for R residues
    x_r : (R, N, 3)    float32 – coordinates

    Steps: ``egnn_net`` → RBF features between each neighbourhood centroid
    and the EGNN *output* coordinates → channel-wise concatenation with the
    node features → ``mha_layer`` → max over the N neighbours.

    Returns the tuple ``(tok, centroids)``:
        tok       : pooled per-residue embedding, (R, dim + basis)
                    (assumes rbf_layer returns (R, N, basis) — TODO confirm)
        centroids : mean output coordinate of each neighbourhood, permuted
                    from (R, 1, 3) to (1, R, 3)
    """
    # ---------- EGNN ----------
    node_feats, coords = egnn_net(z_r, x_r)
    if isinstance(node_feats, (list, tuple)):       # some variants wrap the features
        node_feats = node_feats[0]

    # ---------- RBF between centroid and EGNN output coordinates ----------
    centroids = coords.mean(dim=1, keepdim=True)
    radial = rbf_layer(centroids, coords)

    # ---------- concat along channels, attend over neighbours ----------
    channels_first = torch.cat(
        (radial.transpose(1, 2), node_feats.transpose(1, 2)), dim=1
    )
    attended, _ = mha_layer(channels_first.permute(2, 0, 1))   # neighbours first

    # ---------- max-pool over the neighbour axis ----------
    pooled = attended.permute(1, 0, 2).max(dim=1).values
    return pooled, centroids.permute(1, 0, 2)

# ---------------------------------------------------------------------
# 6) training / validation loop ---------------------------------------
# ---------------------------------------------------------------------
# Per epoch: one optimisation pass over train_loader, then a gradient-free
# pass over val_loader.  The reported number is the mean per-batch loss of
# `criterion` (nn.MSELoss), hence the "MSE" label.
for epoch in range(epochs):
    # ======== TRAIN ========
    # protein_egnn is optimised together with the other modules, so it is
    # switched between train/eval mode alongside them.
    egnn_net.train(); rbf_layer.train(); mha_layer.train(); pred_head.train()
    protein_egnn.train()
    tr_losses = []

    for z, x, y, mask in train_loader:                 # z:(B,S,N)  mask:(B,S)
        # compress away padding →  (R, N), (R, N, 3), (R,)
        valid      = mask.view(-1)
        z_res      = z.view(-1, z.size(2))[valid].to(device)
        x_res      = x.view(-1, x.size(2), 3)[valid].to(device)
        y_res      = y.view(-1)[valid].to(device)

        optimizer.zero_grad()

        # per-residue features → scalar prediction → residue-level refinement
        feats, centroids = forward_residues(z_res, x_res)         # (R, C)

        preds = pred_head(feats)
        t = preds.unsqueeze(0)                  # all residues form one graph
        preds = protein_egnn(t, centroids)[0]   # keep the first output (presumably the features)

        loss  = criterion(preds.flatten(), y_res)

        # The scaler only rescales the loss; the forward pass above is not
        # wrapped in autocast.
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); optimizer.step()

        tr_losses.append(loss.item())

    print(f"Epoch {epoch:3d} | train MSE = {np.mean(tr_losses):.4f}")


    # ======== VALID ========
    egnn_net.eval(); rbf_layer.eval(); mha_layer.eval(); pred_head.eval()
    protein_egnn.eval()
    vl_losses = []

    with torch.no_grad():
        for z, x, y, mask in val_loader:
            valid   = mask.view(-1)
            z_res   = z.view(-1, z.size(2))[valid].to(device)
            x_res   = x.view(-1, x.size(2), 3)[valid].to(device)
            y_res   = y.view(-1)[valid].to(device)

            # same forward path as in training
            feats, centroids = forward_residues(z_res, x_res)         # (R, C)

            preds = pred_head(feats)
            t = preds.unsqueeze(0)
            preds = protein_egnn(t, centroids)[0]

            loss  = criterion(preds.flatten(), y_res)
            vl_losses.append(loss.item())

    print(f"              |  val MSE = {np.mean(vl_losses):.4f}")
print(time.time() - t0,"sec")


Epoch   0 | train L1 = 3.5314
              |  val L1 = 3.8011
Epoch   1 | train L1 = 3.3841
              |  val L1 = 3.7654
Epoch   2 | train L1 = 3.3483
              |  val L1 = 3.7135
Epoch   3 | train L1 = 3.2191
              |  val L1 = 3.3462
Epoch   4 | train L1 = 2.8690
              |  val L1 = 2.9730
Epoch   5 | train L1 = 2.6776
              |  val L1 = 2.8889
Epoch   6 | train L1 = 2.4746
              |  val L1 = 2.7486
Epoch   7 | train L1 = 2.1433
              |  val L1 = 1.9590
Epoch   8 | train L1 = 1.7988
              |  val L1 = 1.5335
Epoch   9 | train L1 = 1.6244
              |  val L1 = 1.5674
175.53697037696838 sec


HIDDEN DIM 4

With RBF features: dim = 6, basis = 6, hidden_dim = 4.

In [28]:
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
# 0) start timer
t0 = time.time()                      # reference point for the runtime printed at the end

# neighbourhood size and batching
N_NEIGHBORS, BATCH_SIZE = 100, 1      # atoms per site; proteins per batch (author note: not safe to increase)
PIN_MEMORY = torch.cuda.is_available()

# reproducibility + device
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mixed precision is enabled only on GPU.  On CPU `autocast` is replaced by a
# do-nothing context manager and no gradient scaler is created.
use_amp = device.type == "cuda"
scaler = GradScaler() if use_amp else None
if not use_amp:
    class DummyCM:
        """Context manager that does nothing on enter and exit."""
        def __enter__(self):
            return None

        def __exit__(self, *args):
            return None

    autocast = DummyCM

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    """
    Build the EGNN stack, the attention block and the RBF layer on ``device``.

    Besides its arguments the function reads the module-level names
    ``basis``, ``num_heads``, ``cutoff`` and ``device``; they must be defined
    by the time it is called.

    Returns
    -------
    tuple
        ``(net, mha, RBF)`` – the StackedEGNN, AttentionBlock and
        LearnableRBF instances, in that order.
    """
    egnn_kwargs = dict(
        dim=dim, depth=depth, hidden_dim=hidden_dim,
        dropout=dropout,
        num_positions=1000, num_tokens=118,
        num_nearest_neighbors=num_neighbors,
        norm_coors=True,
        num_edge_tokens=num_edge_tokens,
        num_global_tokens=num_global_tokens,
    )
    net = StackedEGNN(**egnn_kwargs).to(device)

    # attention works on the concatenated (dim + basis)-wide tokens
    mha = AttentionBlock(
        embed_dim=dim + basis, num_heads=num_heads, hidden_dim=hidden_dim
    ).to(device)

    RBF = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
    return net, mha, RBF
#net,mha,RBF=init_model
# 3) instantiate everything
# Feature widths: EGNN node embedding and number of radial basis functions.
# Author notes: scale to 3,16 at least; dim must be divisible by 2.
dim = 6
basis = 6
depth = 2                    # author note: scale to 2, at least
hidden_dim = 4
num_heads = dim + basis      # equals the attention embed width (dim + basis)

# vocabulary sizes handed to the EGNN stack
num_edge_tokens = 256
num_global_tokens = 256

dropout = 0.02
cutoff = 10.0                # cutoff passed to the RBF layer
num_neighbors = 2            # nearest neighbours used inside the EGNN


# timestamp identifying this run, e.g. 20250101_120000
runid = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, constructs fixed-size neighbourhoods around each
    site (anchor) and stores the result entirely in RAM.

    Every archive is read for the keys ``z``, ``pos``, ``sites`` and ``pks``.
    For a protein with S sites the stored item is the tuple

        z   : (S, n_neighbors)      int32
        pos : (S, n_neighbors, 3)   float32
        y   : (S,)                  float32

    Parameters
    ----------
    paths : iterable of str
        Archive files to load; one dataset item is created per usable file.
    n_neighbors : int
        Number of nearest atoms gathered around every site.
    pin_memory : bool
        If true, the tensors of each item are moved to pinned host memory.

    Files without any site are skipped silently.  Files that raise while
    being processed (e.g. a missing key, a failing neighbour query or a
    site/target count mismatch) are reported on stdout and skipped.
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                # NOTE(review): allow_pickle=True lets np.load unpickle object
                # arrays, which can execute arbitrary code – only load trusted
                # files.  The context manager closes the archive's file handle
                # as soon as the arrays have been copied out.
                with np.load(p, allow_pickle=True) as dat:
                    z_all   = dat["z"].astype(np.int32)        # (N,)
                    pos_all = dat["pos"].astype(np.float32)    # (N,3)
                    sites   = dat["sites"].astype(np.float32)  # (S,3)
                    y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                # Catch inconsistent archives here instead of failing later
                # while a batch is being padded.
                if len(y) != len(sites):
                    raise ValueError(
                        f"{len(sites)} sites but {len(y)} target values"
                    )

                # indices of the n_neighbors atoms closest to every site
                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, n_neighbors)

                z_hood   = torch.from_numpy(z_all[idx])            # (S, n_neighbors)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S, n_neighbors, 3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                # best effort: one bad file must not abort the whole load
                print(f"skipping {p}: {e}")

    def __len__(self):             return len(self.data)
    def __getitem__(self, idx):    return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pads the variable-length site dimension so the batch can be stacked
    into one tensor.  A boolean mask keeps track of which elements are
    real data (True) vs. padding (False).

    Parameters
    ----------
    batch : list of ``(z, pos, y)`` tuples with shapes (S, N), (S, N, 3)
            and (S,); S may differ between items, N must not.

    Returns
    -------
    zs   : (B, S_max, N)     int32   – zero padded
    pos  : (B, S_max, N, 3)  float32 – zero padded
    ys   : (B, S_max)        float32 – NaN padded
    mask : (B, S_max)        bool    – True where the entry is real data

    All outputs are CPU tensors; the caller moves them to the compute device.
    """
    # batch = list[(z,pos,y), ...]         len = B
    B      = len(batch)
    S_max  = max(item[0].shape[0] for item in batch)   # longest protein
    # Read the neighbourhood size from the data instead of the global
    # N_NEIGHBORS, so a dataset built with a custom n_neighbors still collates.
    n_nbrs = batch[0][0].shape[1]

    # Build on the CPU: the DataLoader's pin_memory option can only pin CPU
    # tensors, and the training loop already calls .to(device) on each batch.
    zs   = torch.zeros((B, S_max, n_nbrs),    dtype=torch.int32)
    pos  = torch.zeros((B, S_max, n_nbrs, 3), dtype=torch.float32)
    ys   = torch.full ((B, S_max), float("nan"), dtype=torch.float32)
    mask = torch.zeros((B, S_max),            dtype=torch.bool)

    for b, (z, pos_b, y) in enumerate(batch):
        S = z.shape[0]
        zs  [b, :S] = z
        pos [b, :S] = pos_b
        ys  [b, :S] = y
        mask[b, :S] = True

    return zs, pos, ys, mask             # shapes – see docstring

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
# Seed NumPy's global RNG so the shuffle below is repeatable.
np.random.seed(0)
# glob returns matches in arbitrary, filesystem-dependent order; sort first so
# the seeded shuffle produces the same train/val split on every machine.
all_paths = sorted(glob.glob("../../../data/pkegnn_INS/inputs/*.npz"))
np.random.shuffle(all_paths)
# First 20 shuffled files train, the next 10 validate (slices simply come up
# short if the directory holds fewer than 30 archives).
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# One batch = BATCH_SIZE proteins, padded to a common site count by pad_collate.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



# ---------------------------------------------------------------------
# 4) model pieces ------------------------------------------------------
# ---------------------------------------------------------------------
# Neighbourhood encoder: called as egnn_net(z, x) in forward_residues.
egnn_net = StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                       dropout=dropout, num_positions=1000, num_tokens=118,
                       num_nearest_neighbors=num_neighbors,
                       norm_coors=True,
                       num_edge_tokens=num_edge_tokens,
                       num_global_tokens=num_global_tokens).to(device)

rbf_layer = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
mha_layer = AttentionBlock(embed_dim=dim + basis,
                           num_heads=num_heads,
                           hidden_dim=hidden_dim).to(device)
# Maps each pooled (dim + basis) feature vector to one scalar per site.
pred_head = nn.Linear(dim + basis, 1).to(device)

# Second EGNN, applied in the training loop to the per-site scalars and the
# site centroids.  It must live on `device` like every other module; it was
# previously left on the CPU, which cannot work with inputs on a GPU.
protein_egnn = EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(device)
criterion = nn.MSELoss()
# A single optimiser over the parameters of all five modules.
optimizer = torch.optim.AdamW(
    list(egnn_net.parameters()) +
    list(rbf_layer.parameters()) +
    list(mha_layer.parameters()) +
    list(pred_head.parameters()) +
    list(protein_egnn.parameters()),
    lr=5e-3
)

epochs = 10  # or whatever you like

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    """
    Encode a flat batch of site neighbourhoods.

    z_r : (R, N)       int32   – atomic numbers for R residues
    x_r : (R, N, 3)    float32 – coordinates

    Returns a pair ``(pooled, centroids)``:
        pooled    – attention output max-reduced over the neighbour axis;
                    expected (R, dim + basis) – TODO confirm against
                    AttentionBlock's output layout
        centroids – mean over dim 1 of the coordinates returned by egnn_net,
                    rearranged with permute(1, 0, 2); (1, R, 3) when those
                    coordinates are (R, N, 3)
    """
    # EGNN may hand back its features wrapped in a list/tuple; unwrap if so.
    node_feats, out_coords = egnn_net(z_r, x_r)
    if isinstance(node_feats, (list, tuple)):
        node_feats = node_feats[0]

    # Radial features are computed from the EGNN *output* coordinates,
    # between each neighbourhood's centroid and its atoms.
    centroids = out_coords.mean(dim=1).unsqueeze(1)
    radial    = rbf_layer(centroids, out_coords)

    # Channel-wise concat (RBF channels first), then attention over neighbours.
    tok = torch.cat(
        (radial.transpose(1, 2), node_feats.transpose(1, 2)),
        dim=1,
    )
    attended, _ = mha_layer(tok.permute(2, 0, 1))
    pooled = attended.permute(1, 0, 2).max(dim=1).values
    return pooled, centroids.permute(1, 0, 2)

# ---------------------------------------------------------------------
# 6) training / validation loop ---------------------------------------
# ---------------------------------------------------------------------
for epoch in range(epochs):
    # ======== TRAIN ========
    # protein_egnn is optimised together with the other modules, so it has to
    # follow the same train()/eval() switching as they do.
    egnn_net.train(); rbf_layer.train(); mha_layer.train(); pred_head.train()
    protein_egnn.train()
    tr_losses = []

    for z, x, y, mask in train_loader:                 # z:(B,S,N)  mask:(B,S)
        # compress away padding →  (R, N), (R, N, 3), (R,)
        valid      = mask.view(-1)
        z_res      = z.view(-1, z.size(2))[valid].to(device)
        x_res      = x.view(-1, x.size(2), 3)[valid].to(device)
        y_res      = y.view(-1)[valid].to(device)

        optimizer.zero_grad()

        # per-site features and centroids from the neighbourhood encoder
        feats, centroids = forward_residues(z_res, x_res)

        # one scalar per site, then a pass of protein_egnn over all R sites
        # at once (treated as a single batch element)
        preds = pred_head(feats)
        t     = preds.unsqueeze(0)
        preds = protein_egnn(t, centroids)[0]

        loss  = criterion(preds.flatten(), y_res)

        if use_amp:
            # NOTE(review): no autocast region is entered in this loop, so the
            # scaler only rescales the loss; the forward pass is not mixed
            # precision.
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); optimizer.step()

        tr_losses.append(loss.item())

    # NOTE: the label reads "L1", but `criterion` is nn.MSELoss, so this is a
    # mean squared error.  The string is kept to match recorded outputs.
    print(f"Epoch {epoch:3d} | train L1 = {np.mean(tr_losses):.4f}")


    # ======== VALID ========
    egnn_net.eval(); rbf_layer.eval(); mha_layer.eval(); pred_head.eval()
    protein_egnn.eval()
    vl_losses = []

    with torch.no_grad():
        for z, x, y, mask in val_loader:
            valid   = mask.view(-1)
            z_res   = z.view(-1, z.size(2))[valid].to(device)
            x_res   = x.view(-1, x.size(2), 3)[valid].to(device)
            y_res   = y.view(-1)[valid].to(device)

            # same forward path as in training, without gradient tracking
            feats, centroids = forward_residues(z_res, x_res)

            preds = pred_head(feats)
            t     = preds.unsqueeze(0)
            preds = protein_egnn(t, centroids)[0]

            loss  = criterion(preds.flatten(), y_res)
            vl_losses.append(loss.item())

    print(f"              |  val L1 = {np.mean(vl_losses):.4f}")
# total wall-clock time since t0 was taken at the top of the cell
print(time.time() - t0,"sec")


Epoch   0 | train L1 = 3.7299
              |  val L1 = 3.8041
Epoch   1 | train L1 = 3.4110
              |  val L1 = 3.8359
Epoch   2 | train L1 = 3.4060
              |  val L1 = 3.9078
Epoch   3 | train L1 = 3.3977
              |  val L1 = 3.7853
Epoch   4 | train L1 = 3.3813
              |  val L1 = 3.8105
Epoch   5 | train L1 = 3.3805
              |  val L1 = 3.7775
Epoch   6 | train L1 = 3.3330
              |  val L1 = 3.7182
Epoch   7 | train L1 = 3.2467
              |  val L1 = 3.4676
Epoch   8 | train L1 = 3.0564
              |  val L1 = 3.0519
Epoch   9 | train L1 = 2.5600
              |  val L1 = 2.8689
128.25800108909607 sec


In [33]:
NO rbf, dim 12

SyntaxError: invalid syntax (<ipython-input-33-f126f83bcf66>, line 1)

In [30]:
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
# 0) start timer
t0 = time.time()            # wall-clock start; reported after the training loop
N_NEIGHBORS = 100           # atoms gathered around every site
BATCH_SIZE  = 1             # not safe to increase
PIN_MEMORY  = torch.cuda.is_available()

# reproducibility + device
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# AMP is enabled only when running on a GPU
use_amp = device.type == "cuda"
if not use_amp:
    # CPU run: swap autocast for a do-nothing context manager, no grad scaler
    class DummyCM:
        def __enter__(self):
            return None
        def __exit__(self, *args):
            return None
    autocast = DummyCM
    scaler   = None
else:
    scaler = GradScaler()

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    """
    Build the three sub-modules of the site encoder and move each to `device`.

    Returns ``(net, mha, RBF)``: a StackedEGNN, an AttentionBlock and a
    LearnableRBF.

    Note: `basis`, `num_heads`, `cutoff` and `device` are not arguments; they
    are read from module-level globals when the function is called.
    """
    egnn_kwargs = dict(
        dim=dim, depth=depth, hidden_dim=hidden_dim,
        dropout=dropout,
        num_positions=1000, num_tokens=118,
        num_nearest_neighbors=num_neighbors,
        norm_coors=True,
        num_edge_tokens=num_edge_tokens,
        num_global_tokens=num_global_tokens,
    )
    net = StackedEGNN(**egnn_kwargs).to(device)

    # attention runs over the concatenated EGNN + RBF channels
    width = dim + basis
    mha = AttentionBlock(embed_dim=width, num_heads=num_heads, hidden_dim=hidden_dim).to(device)
    RBF = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
    return net, mha, RBF
#net,mha,RBF=init_model
# 3) instantiate everything
# Hyper-parameters for this run.
dim   = 12                 # EGNN feature width (original note: must be divisible by 2)
basis = 0                  # num_basis handed to LearnableRBF
depth = 2
hidden_dim = 4
num_heads  = dim + basis   # same value as the AttentionBlock embed_dim
num_edge_tokens = num_global_tokens = 256
dropout = 0.02
cutoff  = 10.0             # cutoff handed to LearnableRBF
num_neighbors = 2          # num_nearest_neighbors handed to StackedEGNN

# timestamp string identifying this run
runid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, constructs fixed-size neighbourhoods around each
    site (anchor) and stores the result entirely in RAM.

    Every archive is read for the keys ``z``, ``pos``, ``sites`` and ``pks``.
    For a protein with S sites one dataset item is the tuple ``(z, pos, y)``:
        z   : (S, n_neighbors)      int32
        pos : (S, n_neighbors, 3)   float32
        y   : (S,)                  float32

    Files with no sites are skipped silently.  Files that fail to load or
    process (including a site/label count mismatch) are skipped with a
    printed message, so one bad file does not abort the whole load.

    Parameters
    ----------
    paths       : iterable of paths to ``.npz`` archives
    n_neighbors : number of nearest atoms collected around each site
    pin_memory  : if true, every stored tensor is page-locked via
                  ``Tensor.pin_memory()``
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                # NOTE(review): allow_pickle=True lets np.load unpickle object
                # arrays, which can execute arbitrary code; only point this at
                # trusted files (or drop the flag if the arrays are numeric).
                # The context manager closes the archive's file handle, which
                # would otherwise stay open once per loaded protein.
                with np.load(p, allow_pickle=True) as dat:
                    z_all   = dat["z"].astype(np.int32)        # (N,)
                    pos_all = dat["pos"].astype(np.float32)    # (N,3)
                    sites   = dat["sites"].astype(np.float32)  # (S,3)
                    y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                # A label count that differs from the site count would only
                # surface later (or broadcast silently) inside the collate
                # function, so reject the file here instead.
                if len(y) != len(sites):
                    raise ValueError(
                        f"{len(sites)} sites but {len(y)} pK values")

                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, n_neighbors)

                z_hood   = torch.from_numpy(z_all[idx])            # (S,n_neighbors)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S,n_neighbors,3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                # Deliberate best-effort: report and move on to the next file.
                print(f"skipping {p}: {e}")

    def __len__(self):
        """Number of proteins that were loaded successfully."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(z, pos, y)`` tuple of protein ``idx``."""
        return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pads the variable-length site dimension so the batch can be stacked
    into one tensor.  A boolean mask keeps track of which elements are
    real data (True) vs. padding (False).

    Parameters
    ----------
    batch : list of ``(z, pos, y)`` tuples with shapes (S, N), (S, N, 3)
            and (S,); S may differ between items, N must not.

    Returns
    -------
    zs   : (B, S_max, N)     int32   – zero padded
    pos  : (B, S_max, N, 3)  float32 – zero padded
    ys   : (B, S_max)        float32 – NaN padded
    mask : (B, S_max)        bool    – True where the entry is real data

    All outputs are CPU tensors; the caller moves them to the compute device.
    """
    # batch = list[(z,pos,y), ...]         len = B
    B      = len(batch)
    S_max  = max(item[0].shape[0] for item in batch)   # longest protein
    # Read the neighbourhood size from the data instead of the global
    # N_NEIGHBORS, so a dataset built with a custom n_neighbors still collates.
    n_nbrs = batch[0][0].shape[1]

    # Build on the CPU: the DataLoader's pin_memory option can only pin CPU
    # tensors, and the training loop already calls .to(device) on each batch.
    zs   = torch.zeros((B, S_max, n_nbrs),    dtype=torch.int32)
    pos  = torch.zeros((B, S_max, n_nbrs, 3), dtype=torch.float32)
    ys   = torch.full ((B, S_max), float("nan"), dtype=torch.float32)
    mask = torch.zeros((B, S_max),            dtype=torch.bool)

    for b, (z, pos_b, y) in enumerate(batch):
        S = z.shape[0]
        zs  [b, :S] = z
        pos [b, :S] = pos_b
        ys  [b, :S] = y
        mask[b, :S] = True

    return zs, pos, ys, mask             # shapes – see docstring

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
# Seed NumPy's global RNG so the shuffle below is repeatable.
np.random.seed(0)
# glob returns matches in arbitrary, filesystem-dependent order; sort first so
# the seeded shuffle produces the same train/val split on every machine.
all_paths = sorted(glob.glob("../../../data/pkegnn_INS/inputs/*.npz"))
np.random.shuffle(all_paths)
# First 20 shuffled files train, the next 10 validate (slices simply come up
# short if the directory holds fewer than 30 archives).
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# One batch = BATCH_SIZE proteins, padded to a common site count by pad_collate.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



# ---------------------------------------------------------------------
# 4) model pieces ------------------------------------------------------
# ---------------------------------------------------------------------
# Neighbourhood encoder: called as egnn_net(z, x) in forward_residues.
egnn_net = StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                       dropout=dropout, num_positions=1000, num_tokens=118,
                       num_nearest_neighbors=num_neighbors,
                       norm_coors=True,
                       num_edge_tokens=num_edge_tokens,
                       num_global_tokens=num_global_tokens).to(device)

# Still constructed (and optimised) in this cell although this cell's
# forward_residues does not call it.
rbf_layer = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
mha_layer = AttentionBlock(embed_dim=dim + basis,
                           num_heads=num_heads,
                           hidden_dim=hidden_dim).to(device)
# Maps each pooled (dim + basis) feature vector to one scalar per site.
pred_head = nn.Linear(dim + basis, 1).to(device)

# Second EGNN, applied in the training loop to the per-site scalars and the
# site centroids.  It must live on `device` like every other module; it was
# previously left on the CPU, which cannot work with inputs on a GPU.
protein_egnn = EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(device)
criterion = nn.MSELoss()
# A single optimiser over the parameters of all five modules.
optimizer = torch.optim.AdamW(
    list(egnn_net.parameters()) +
    list(rbf_layer.parameters()) +
    list(mha_layer.parameters()) +
    list(pred_head.parameters()) +
    list(protein_egnn.parameters()),
    lr=5e-3
)

epochs = 10  # or whatever you like

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    """
    Encode a flat batch of site neighbourhoods (variant without RBF features).

    z_r : (R, N)       int32   – atomic numbers for R residues
    x_r : (R, N, 3)    float32 – coordinates

    Returns a pair ``(pooled, centroids)``:
        pooled    – attention output max-reduced over the neighbour axis;
                    expected (R, dim) – TODO confirm against AttentionBlock's
                    output layout
        centroids – mean over dim 1 of the coordinates returned by egnn_net,
                    rearranged with permute(1, 0, 2); (1, R, 3) when those
                    coordinates are (R, N, 3)
    """
    # EGNN may hand back its features wrapped in a list/tuple; unwrap if so.
    node_feats, out_coords = egnn_net(z_r, x_r)
    if isinstance(node_feats, (list, tuple)):
        node_feats = node_feats[0]

    # Centroids come from the EGNN *output* coordinates; rbf_layer is
    # deliberately not used in this variant.
    centroids = out_coords.mean(dim=1).unsqueeze(1)

    # Only the EGNN channels are fed to the attention block.
    tok = node_feats.transpose(1, 2)
    attended, _ = mha_layer(tok.permute(2, 0, 1))
    pooled = attended.permute(1, 0, 2).max(dim=1).values
    return pooled, centroids.permute(1, 0, 2)

# ---------------------------------------------------------------------
# 6) training / validation loop ---------------------------------------
# ---------------------------------------------------------------------
for epoch in range(epochs):
    # ======== TRAIN ========
    # protein_egnn is optimised together with the other modules, so it has to
    # follow the same train()/eval() switching as they do.
    egnn_net.train(); rbf_layer.train(); mha_layer.train(); pred_head.train()
    protein_egnn.train()
    tr_losses = []

    for z, x, y, mask in train_loader:                 # z:(B,S,N)  mask:(B,S)
        # compress away padding →  (R, N), (R, N, 3), (R,)
        valid      = mask.view(-1)
        z_res      = z.view(-1, z.size(2))[valid].to(device)
        x_res      = x.view(-1, x.size(2), 3)[valid].to(device)
        y_res      = y.view(-1)[valid].to(device)

        optimizer.zero_grad()

        # per-site features and centroids from the neighbourhood encoder
        feats, centroids = forward_residues(z_res, x_res)

        # one scalar per site, then a pass of protein_egnn over all R sites
        # at once (treated as a single batch element)
        preds = pred_head(feats)
        t     = preds.unsqueeze(0)
        preds = protein_egnn(t, centroids)[0]

        loss  = criterion(preds.flatten(), y_res)

        if use_amp:
            # NOTE(review): no autocast region is entered in this loop, so the
            # scaler only rescales the loss; the forward pass is not mixed
            # precision.
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); optimizer.step()

        tr_losses.append(loss.item())

    # NOTE: the label reads "L1", but `criterion` is nn.MSELoss, so this is a
    # mean squared error.  The string is kept to match recorded outputs.
    print(f"Epoch {epoch:3d} | train L1 = {np.mean(tr_losses):.4f}")


    # ======== VALID ========
    egnn_net.eval(); rbf_layer.eval(); mha_layer.eval(); pred_head.eval()
    protein_egnn.eval()
    vl_losses = []

    with torch.no_grad():
        for z, x, y, mask in val_loader:
            valid   = mask.view(-1)
            z_res   = z.view(-1, z.size(2))[valid].to(device)
            x_res   = x.view(-1, x.size(2), 3)[valid].to(device)
            y_res   = y.view(-1)[valid].to(device)

            # same forward path as in training, without gradient tracking
            feats, centroids = forward_residues(z_res, x_res)

            preds = pred_head(feats)
            t     = preds.unsqueeze(0)
            preds = protein_egnn(t, centroids)[0]

            loss  = criterion(preds.flatten(), y_res)
            vl_losses.append(loss.item())

    print(f"              |  val L1 = {np.mean(vl_losses):.4f}")
# total wall-clock time since t0 was taken at the top of the cell
print(time.time() - t0,"sec")


Epoch   0 | train L1 = 3.9070
              |  val L1 = 3.7480
Epoch   1 | train L1 = 3.3117
              |  val L1 = 3.5646
Epoch   2 | train L1 = 3.0381
              |  val L1 = 3.1896
Epoch   3 | train L1 = 2.7943
              |  val L1 = 2.9842
Epoch   4 | train L1 = 2.7332
              |  val L1 = 2.9643
Epoch   5 | train L1 = 2.6919
              |  val L1 = 3.1137
Epoch   6 | train L1 = 2.5600
              |  val L1 = 2.9541
Epoch   7 | train L1 = 2.4171
              |  val L1 = 2.6732
Epoch   8 | train L1 = 2.3921
              |  val L1 = 2.8359
Epoch   9 | train L1 = 2.3996
              |  val L1 = 2.7632
144.48865246772766 sec


With RBF: dim = 12, basis = 6 (hidden_dim = 4)

In [32]:
#centroids.per
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
# 0) start timer
t0 = time.time()            # wall-clock start; reported after the training loop
N_NEIGHBORS = 100           # atoms gathered around every site
BATCH_SIZE  = 1             # not safe to increase
PIN_MEMORY  = torch.cuda.is_available()

# reproducibility + device
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# AMP is enabled only when running on a GPU
use_amp = device.type == "cuda"
if not use_amp:
    # CPU run: swap autocast for a do-nothing context manager, no grad scaler
    class DummyCM:
        def __enter__(self):
            return None
        def __exit__(self, *args):
            return None
    autocast = DummyCM
    scaler   = None
else:
    scaler = GradScaler()

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    """
    Build the three sub-modules of the site encoder and move each to `device`.

    Returns ``(net, mha, RBF)``: a StackedEGNN, an AttentionBlock and a
    LearnableRBF.

    Note: `basis`, `num_heads`, `cutoff` and `device` are not arguments; they
    are read from module-level globals when the function is called.
    """
    egnn_kwargs = dict(
        dim=dim, depth=depth, hidden_dim=hidden_dim,
        dropout=dropout,
        num_positions=1000, num_tokens=118,
        num_nearest_neighbors=num_neighbors,
        norm_coors=True,
        num_edge_tokens=num_edge_tokens,
        num_global_tokens=num_global_tokens,
    )
    net = StackedEGNN(**egnn_kwargs).to(device)

    # attention runs over the concatenated EGNN + RBF channels
    width = dim + basis
    mha = AttentionBlock(embed_dim=width, num_heads=num_heads, hidden_dim=hidden_dim).to(device)
    RBF = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
    return net, mha, RBF
#net,mha,RBF=init_model
# 3) instantiate everything
# Hyper-parameters for this run.
dim   = 12                 # EGNN feature width (original note: must be divisible by 2)
basis = 6                  # num_basis handed to LearnableRBF
depth = 2
hidden_dim = 4
num_heads  = dim + basis   # same value as the AttentionBlock embed_dim
num_edge_tokens = num_global_tokens = 256
dropout = 0.02
cutoff  = 10.0             # cutoff handed to LearnableRBF
num_neighbors = 2          # num_nearest_neighbors handed to StackedEGNN

# timestamp string identifying this run
runid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, constructs fixed-size neighbourhoods around each
    site (anchor) and stores the result entirely in RAM.

    Every archive is read for the keys ``z``, ``pos``, ``sites`` and ``pks``.
    For a protein with S sites one dataset item is the tuple ``(z, pos, y)``:
        z   : (S, n_neighbors)      int32
        pos : (S, n_neighbors, 3)   float32
        y   : (S,)                  float32

    Files with no sites are skipped silently.  Files that fail to load or
    process (including a site/label count mismatch) are skipped with a
    printed message, so one bad file does not abort the whole load.

    Parameters
    ----------
    paths       : iterable of paths to ``.npz`` archives
    n_neighbors : number of nearest atoms collected around each site
    pin_memory  : if true, every stored tensor is page-locked via
                  ``Tensor.pin_memory()``
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                # NOTE(review): allow_pickle=True lets np.load unpickle object
                # arrays, which can execute arbitrary code; only point this at
                # trusted files (or drop the flag if the arrays are numeric).
                # The context manager closes the archive's file handle, which
                # would otherwise stay open once per loaded protein.
                with np.load(p, allow_pickle=True) as dat:
                    z_all   = dat["z"].astype(np.int32)        # (N,)
                    pos_all = dat["pos"].astype(np.float32)    # (N,3)
                    sites   = dat["sites"].astype(np.float32)  # (S,3)
                    y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                # A label count that differs from the site count would only
                # surface later (or broadcast silently) inside the collate
                # function, so reject the file here instead.
                if len(y) != len(sites):
                    raise ValueError(
                        f"{len(sites)} sites but {len(y)} pK values")

                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, n_neighbors)

                z_hood   = torch.from_numpy(z_all[idx])            # (S,n_neighbors)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S,n_neighbors,3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                # Deliberate best-effort: report and move on to the next file.
                print(f"skipping {p}: {e}")

    def __len__(self):
        """Number of proteins that were loaded successfully."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(z, pos, y)`` tuple of protein ``idx``."""
        return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pads the variable-length site dimension so the batch can be stacked
    into one tensor.  A boolean mask keeps track of which elements are
    real data (True) vs. padding (False).

    Parameters
    ----------
    batch : list of ``(z, pos, y)`` tuples with shapes (S, N), (S, N, 3)
            and (S,); S may differ between items, N must not.

    Returns
    -------
    zs   : (B, S_max, N)     int32   – zero padded
    pos  : (B, S_max, N, 3)  float32 – zero padded
    ys   : (B, S_max)        float32 – NaN padded
    mask : (B, S_max)        bool    – True where the entry is real data

    All outputs are CPU tensors; the caller moves them to the compute device.
    """
    # batch = list[(z,pos,y), ...]         len = B
    B      = len(batch)
    S_max  = max(item[0].shape[0] for item in batch)   # longest protein
    # Read the neighbourhood size from the data instead of the global
    # N_NEIGHBORS, so a dataset built with a custom n_neighbors still collates.
    n_nbrs = batch[0][0].shape[1]

    # Build on the CPU: the DataLoader's pin_memory option can only pin CPU
    # tensors, and the training loop already calls .to(device) on each batch.
    zs   = torch.zeros((B, S_max, n_nbrs),    dtype=torch.int32)
    pos  = torch.zeros((B, S_max, n_nbrs, 3), dtype=torch.float32)
    ys   = torch.full ((B, S_max), float("nan"), dtype=torch.float32)
    mask = torch.zeros((B, S_max),            dtype=torch.bool)

    for b, (z, pos_b, y) in enumerate(batch):
        S = z.shape[0]
        zs  [b, :S] = z
        pos [b, :S] = pos_b
        ys  [b, :S] = y
        mask[b, :S] = True

    return zs, pos, ys, mask             # shapes – see docstring

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
# Seed NumPy's global RNG so the shuffle below is repeatable.
np.random.seed(0)
# glob returns matches in arbitrary, filesystem-dependent order; sort first so
# the seeded shuffle produces the same train/val split on every machine.
all_paths = sorted(glob.glob("../../../data/pkegnn_INS/inputs/*.npz"))
np.random.shuffle(all_paths)
# First 20 shuffled files train, the next 10 validate (slices simply come up
# short if the directory holds fewer than 30 archives).
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

# One batch = BATCH_SIZE proteins, padded to a common site count by pad_collate.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



# ---------------------------------------------------------------------
# 4) model pieces ------------------------------------------------------
# ---------------------------------------------------------------------
# Neighbourhood encoder: called as egnn_net(z, x) in forward_residues.
egnn_net = StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                       dropout=dropout, num_positions=1000, num_tokens=118,
                       num_nearest_neighbors=num_neighbors,
                       norm_coors=True,
                       num_edge_tokens=num_edge_tokens,
                       num_global_tokens=num_global_tokens).to(device)

rbf_layer = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
mha_layer = AttentionBlock(embed_dim=dim + basis,
                           num_heads=num_heads,
                           hidden_dim=hidden_dim).to(device)
# Maps each pooled (dim + basis) feature vector to one scalar per site.
pred_head = nn.Linear(dim + basis, 1).to(device)

# Second EGNN, applied in the training loop to the per-site scalars and the
# site centroids.  It must live on `device` like every other module; it was
# previously left on the CPU, which cannot work with inputs on a GPU.
protein_egnn = EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(device)
criterion = nn.MSELoss()
# A single optimiser over the parameters of all five modules.
optimizer = torch.optim.AdamW(
    list(egnn_net.parameters()) +
    list(rbf_layer.parameters()) +
    list(mha_layer.parameters()) +
    list(pred_head.parameters()) +
    list(protein_egnn.parameters()),
    lr=5e-3
)

epochs = 10  # or whatever you like

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    """
    Encode a flat batch of site neighbourhoods.

    z_r : (R, N)       int32   – atomic numbers for R residues
    x_r : (R, N, 3)    float32 – coordinates

    Returns a pair ``(pooled, centroids)``:
        pooled    – attention output max-reduced over the neighbour axis;
                    expected (R, dim + basis) – TODO confirm against
                    AttentionBlock's output layout
        centroids – mean over dim 1 of the coordinates returned by egnn_net,
                    rearranged with permute(1, 0, 2); (1, R, 3) when those
                    coordinates are (R, N, 3)
    """
    # EGNN may hand back its features wrapped in a list/tuple; unwrap if so.
    node_feats, out_coords = egnn_net(z_r, x_r)
    if isinstance(node_feats, (list, tuple)):
        node_feats = node_feats[0]

    # Radial features are computed from the EGNN *output* coordinates,
    # between each neighbourhood's centroid and its atoms.
    centroids = out_coords.mean(dim=1).unsqueeze(1)
    radial    = rbf_layer(centroids, out_coords)

    # Channel-wise concat (RBF channels first), then attention over neighbours.
    tok = torch.cat(
        (radial.transpose(1, 2), node_feats.transpose(1, 2)),
        dim=1,
    )
    attended, _ = mha_layer(tok.permute(2, 0, 1))
    pooled = attended.permute(1, 0, 2).max(dim=1).values
    return pooled, centroids.permute(1, 0, 2)

# ---------------------------------------------------------------------
# 6) training / validation loop ---------------------------------------
# ---------------------------------------------------------------------
for epoch in range(epochs):
    # ======== TRAIN ========
    # protein_egnn is optimised together with the other modules, so it has to
    # follow the same train()/eval() switching as they do.
    egnn_net.train(); rbf_layer.train(); mha_layer.train(); pred_head.train()
    protein_egnn.train()
    tr_losses = []

    for z, x, y, mask in train_loader:                 # z:(B,S,N)  mask:(B,S)
        # compress away padding →  (R, N), (R, N, 3), (R,)
        valid      = mask.view(-1)
        z_res      = z.view(-1, z.size(2))[valid].to(device)
        x_res      = x.view(-1, x.size(2), 3)[valid].to(device)
        y_res      = y.view(-1)[valid].to(device)

        optimizer.zero_grad()

        # per-site features and centroids from the neighbourhood encoder
        feats, centroids = forward_residues(z_res, x_res)

        # one scalar per site, then a pass of protein_egnn over all R sites
        # at once (treated as a single batch element)
        preds = pred_head(feats)
        t     = preds.unsqueeze(0)
        preds = protein_egnn(t, centroids)[0]

        loss  = criterion(preds.flatten(), y_res)

        if use_amp:
            # NOTE(review): no autocast region is entered in this loop, so the
            # scaler only rescales the loss; the forward pass is not mixed
            # precision.
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); optimizer.step()

        tr_losses.append(loss.item())

    # NOTE: the label reads "L1", but `criterion` is nn.MSELoss, so this is a
    # mean squared error.  The string is kept to match recorded outputs.
    print(f"Epoch {epoch:3d} | train L1 = {np.mean(tr_losses):.4f}")


    # ======== VALID ========
    egnn_net.eval(); rbf_layer.eval(); mha_layer.eval(); pred_head.eval()
    protein_egnn.eval()
    vl_losses = []

    with torch.no_grad():
        for z, x, y, mask in val_loader:
            valid   = mask.view(-1)
            z_res   = z.view(-1, z.size(2))[valid].to(device)
            x_res   = x.view(-1, x.size(2), 3)[valid].to(device)
            y_res   = y.view(-1)[valid].to(device)

            # same forward path as in training, without gradient tracking
            feats, centroids = forward_residues(z_res, x_res)

            preds = pred_head(feats)
            t     = preds.unsqueeze(0)
            preds = protein_egnn(t, centroids)[0]

            loss  = criterion(preds.flatten(), y_res)
            vl_losses.append(loss.item())

    print(f"              |  val L1 = {np.mean(vl_losses):.4f}")
# total wall-clock time since t0 was taken at the top of the cell
print(time.time() - t0,"sec")
#mute(1,0,2).shape

Epoch   0 | train L1 = 3.4965
              |  val L1 = 3.8579
Epoch   1 | train L1 = 3.3397
              |  val L1 = 3.6933
Epoch   2 | train L1 = 3.2384
              |  val L1 = 3.6057
Epoch   3 | train L1 = 3.0111
              |  val L1 = 3.4762
Epoch   4 | train L1 = 2.7586
              |  val L1 = 3.1231
Epoch   5 | train L1 = 2.5308
              |  val L1 = 2.8734
Epoch   6 | train L1 = 2.2200
              |  val L1 = 2.2098
Epoch   7 | train L1 = 1.8598
              |  val L1 = 1.6053
Epoch   8 | train L1 = 1.5820
              |  val L1 = 1.7834
Epoch   9 | train L1 = 1.6601
              |  val L1 = 1.5638
179.5386073589325 sec


Now with no pegnn (the `protein_egnn` refinement step is commented out in the next cell).

In [38]:
#centroids.per
import datetime, time
from architecture import *
import torch
import glob, math, time, datetime
import numpy as np
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from egnn_pytorch import EGNN_Network
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
# 0) wall-clock reference for the timing printed at the end of the cell
t0 = time.time()

# neighbourhood size and loader settings
N_NEIGHBORS = 100
BATCH_SIZE = 1             # not safe to increase
PIN_MEMORY = torch.cuda.is_available()

# reproducibility + device: fixed torch seed, GPU when one is available
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# AMP is only enabled on GPU; on CPU `autocast` is rebound to a no-op context
# manager and there is no gradient scaler.
use_amp = device.type == "cuda"
if not use_amp:
    class DummyCM:
        """No-op context manager standing in for autocast on CPU."""
        def __enter__(self):
            return None

        def __exit__(self, *exc_info):
            return None

    autocast = DummyCM
    scaler = None
else:
    scaler = GradScaler()

def init_model(dim,depth,hidden_dim,num_neighbors,num_edge_tokens,num_global_tokens,dropout):
    """
    Build the three model components and move each one to `device`.

    Returns (net, mha, RBF):
        net – StackedEGNN configured from the arguments
        mha – AttentionBlock with embed_dim = dim + basis
        RBF – LearnableRBF with `basis` basis functions

    Besides its arguments, this function reads the module-level names
    `basis`, `num_heads`, `cutoff` and `device` at call time, so they must be
    defined before it is called.
    """
    # Modules are constructed in the order net → mha → RBF.
    egnn_kwargs = dict(
        dim=dim,
        depth=depth,
        hidden_dim=hidden_dim,
        dropout=dropout,
        num_positions=1000,
        num_tokens=118,
        num_nearest_neighbors=num_neighbors,
        norm_coors=True,
        num_edge_tokens=num_edge_tokens,
        num_global_tokens=num_global_tokens,
    )
    net = StackedEGNN(**egnn_kwargs).to(device)

    mha = AttentionBlock(
        embed_dim=dim + basis,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
    ).to(device)

    RBF = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
    return net, mha, RBF
#net,mha,RBF=init_model
# 3) hyper-parameters for this run
dim = 12                  # scale to 3,16 at least; dim must be divisible by 2
basis = 6
depth = 2                 # scale to 2, at least
hidden_dim = 4
num_heads = dim + basis   # one attention head per embedding channel
num_edge_tokens = 256
num_global_tokens = 256
dropout = 0.02
cutoff = 10.0
num_neighbors = 2


# timestamp identifying this run (YYYYMMDD_HHMMSS)
runid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")


from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
import numpy as np, torch, glob

class InMemoryHoodDataset(Dataset):
    """
    Loads *.npz files, constructs fixed-size neighbourhoods around each
    site (anchor) and stores the result entirely in RAM.

    For a protein with S sites the shapes are
        z   : (S, n_neighbors)      int32
        pos : (S, n_neighbors, 3)   float32
        y   : (S,)                  float32

    Parameters
    ----------
    paths       : iterable of .npz file paths; each file must provide the
                  arrays "z", "pos", "sites" and "pks".
    n_neighbors : number of nearest atoms gathered around every site.
    pin_memory  : if True, the stored tensors are placed in pinned memory.

    Files with no sites are skipped silently.  Any file that raises while
    being loaded or processed is skipped with a printed message (best-effort
    loading), so len(dataset) may be smaller than len(paths).
    """
    def __init__(self, paths, n_neighbors=N_NEIGHBORS, pin_memory=PIN_MEMORY):
        super().__init__()
        self.data = []
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute")

        for p in paths:
            try:
                # np.load on an .npz returns a lazily-read archive that keeps
                # the file open; the context manager closes it as soon as the
                # arrays have been copied out via astype().
                # NOTE(security): allow_pickle=True can execute arbitrary code
                # when loading untrusted files – only use on trusted data.
                with np.load(p, allow_pickle=True) as dat:
                    z_all   = dat["z"].astype(np.int32)        # (N,)
                    pos_all = dat["pos"].astype(np.float32)    # (N,3)
                    sites   = dat["sites"].astype(np.float32)  # (S,3)
                    y       = dat["pks"].astype(np.float32)    # (S,)

                if len(sites) == 0:
                    continue  # skip empty entries

                # indices of the n_neighbors atoms closest to every site
                nbrs.fit(pos_all)
                idx = nbrs.kneighbors(sites, return_distance=False)   # (S, n_neighbors)

                z_hood   = torch.from_numpy(z_all[idx])            # (S,n_neighbors)
                pos_hood = torch.from_numpy(pos_all[idx])          # (S,n_neighbors,3)
                y        = torch.from_numpy(y)                     # (S,)

                if pin_memory:
                    z_hood   = z_hood.pin_memory()
                    pos_hood = pos_hood.pin_memory()
                    y        = y.pin_memory()

                self.data.append((z_hood, pos_hood, y))
            except Exception as e:
                # deliberate best-effort: report and move on to the next file
                print(f"skipping {p}: {e}")

    def __len__(self):
        """Number of proteins that were loaded successfully."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (z_hood, pos_hood, y) tuple of protein `idx`."""
        return self.data[idx]

# ---------------------------------------------------------------------
# 2) collate function  -------------------------------------------------
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Pads the variable-length site dimension so the batch can be stacked
    into one tensor.  A boolean mask keeps track of which elements are
    real data (True) vs. padding (False).

    Parameters
    ----------
    batch : list of (z, pos, y) tuples with shapes (S_i, N), (S_i, N, 3)
            and (S_i,); N must be the same for every item.

    Returns
    -------
    zs   : (B, S_max, N)     int32,   zero-padded
    pos  : (B, S_max, N, 3)  float32, zero-padded
    ys   : (B, S_max)        float32, NaN-padded
    mask : (B, S_max)        bool,    True where an entry is real data

    The tensors are built on the CPU.  A DataLoader created with
    pin_memory=True pins whatever the collate function returns, and only
    CPU tensors can be pinned, so moving data to the GPU is left to the
    caller.  The neighbourhood size N is read from the batch itself rather
    than from a global, so it always matches the dataset.
    """
    # batch = list[(z,pos,y), ...]         len = B
    B      = len(batch)
    S_max  = max(item[0].shape[0] for item in batch)   # longest protein
    n_nbrs = batch[0][0].shape[1]                      # neighbours per site

    zs   = torch.zeros(B, S_max, n_nbrs,    dtype=torch.int32)
    pos  = torch.zeros(B, S_max, n_nbrs, 3, dtype=torch.float32)
    ys   = torch.full((B, S_max), float("nan"), dtype=torch.float32)
    mask = torch.zeros(B, S_max,            dtype=torch.bool)

    for b, (z, pos_b, y) in enumerate(batch):
        S = z.shape[0]
        zs  [b, :S] = z
        pos [b, :S] = pos_b
        ys  [b, :S] = y
        mask[b, :S] = True

    return zs, pos, ys, mask             # shapes – see above

# ---------------------------------------------------------------------
# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------

# 0) parameters you might want to expose at the top of the script
# ---------------------------------------------------------------------

# 3) data loaders ------------------------------------------------------
# ---------------------------------------------------------------------
np.random.seed(0)
# glob returns files in arbitrary (filesystem-dependent) order; sorting first
# makes the seeded shuffle – and therefore the train/val split – reproducible
# across machines.  Note this can change which files land in each split
# compared with an unsorted run.
all_paths = sorted(glob.glob("../../../data/pkegnn_INS/inputs/*.npz"))
np.random.shuffle(all_paths)
# first 20 shuffled files for training, next 10 for validation
train_paths, val_paths = all_paths[:20], all_paths[20:30]

train_ds = InMemoryHoodDataset(train_paths)
val_ds   = InMemoryHoodDataset(val_paths)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds  , batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=pad_collate,
                          num_workers=0, pin_memory=PIN_MEMORY)



# ---------------------------------------------------------------------
# 4) model pieces ------------------------------------------------------
# ---------------------------------------------------------------------
egnn_net = StackedEGNN(dim=dim, depth=depth, hidden_dim=hidden_dim,
                       dropout=dropout, num_positions=1000, num_tokens=118,
                       num_nearest_neighbors=num_neighbors,
                       norm_coors=True,
                       num_edge_tokens=num_edge_tokens,
                       num_global_tokens=num_global_tokens).to(device)

rbf_layer = LearnableRBF(num_basis=basis, cutoff=cutoff).to(device)
mha_layer = AttentionBlock(embed_dim=dim + basis,
                           num_heads=num_heads,
                           hidden_dim=hidden_dim).to(device)
# maps each (dim + basis)-wide embedding to one scalar prediction
pred_head = nn.Linear(dim + basis, 1).to(device)

# moved to `device` like every other module so it can be used on GPU inputs
protein_egnn = EGNN(dim=1, update_coors=False, num_nearest_neighbors=3).to(device)
criterion = nn.MSELoss()
# a single optimizer over the parameters of all five modules
optimizer = torch.optim.AdamW(
    list(egnn_net.parameters()) +
    list(rbf_layer.parameters()) +
    list(mha_layer.parameters()) +
    list(pred_head.parameters()) +
    list(protein_egnn.parameters()),
    lr=5e-3
)

epochs = 10  # or whatever you like

# ---------------------------------------------------------------------
# 5) forward for a *compressed* batch (R residues, N neighbours)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# 5) single-path forward – no shape guessing, no branching
# ---------------------------------------------------------------------
def forward_residues(z_r, x_r):
    """
    Embed a compressed batch of residue neighbourhoods.

    z_r : (R, N)       int32   – atomic numbers for R residues
    x_r : (R, N, 3)    float32 – coordinates

    Returns a 2-tuple (pooled, centroids):
        pooled    – attention output max-pooled over the neighbour axis;
                    expected to be (R, dim + basis) – TODO confirm against
                    AttentionBlock
        centroids – mean over axis 1 of the coordinates returned by
                    egnn_net, with the first two axes swapped; (1, R, 3)
                    if those coordinates are (R, N, 3)
    """
    # ---------- EGNN ----------
    egnn_out, coords = egnn_net(z_r, x_r)
    # egnn_net may hand back the features wrapped in a list/tuple
    if isinstance(egnn_out, (list, tuple)):
        node_feats = egnn_out[0]            # presumably (R, N, dim)
    else:
        node_feats = egnn_out

    # ---------- RBF between each centroid and the EGNN *output* coords ----------
    centroids = coords.mean(dim=1).unsqueeze(1)
    radial = rbf_layer(centroids, coords)   # presumably (R, N, basis) – TODO confirm

    # ---------- concat & attention ----------
    feats_t  = node_feats.transpose(1, 2)                 # channels before neighbours
    radial_t = radial.transpose(1, 2)
    tokens   = torch.cat((radial_t, feats_t), dim=1)      # RBF channels first

    # move the last axis to the front for the attention block, then move it
    # back behind the residue axis and max-pool over it
    attended, _ = mha_layer(tokens.permute(2, 0, 1))
    pooled = attended.permute(1, 0, 2).max(dim=1).values
    return pooled, centroids.permute(1, 0, 2)

# ---------------------------------------------------------------------
# 6) training / validation loop ---------------------------------------
# ---------------------------------------------------------------------
for epoch in range(epochs):
    # One epoch = one pass over train_loader (with parameter updates) followed
    # by one pass over val_loader (no gradients).  The mean per-batch loss of
    # each pass is printed; `criterion` is nn.MSELoss, so the logs are
    # labelled "MSE".

    # ======== TRAIN ========
    egnn_net.train(); rbf_layer.train(); mha_layer.train(); pred_head.train()
    tr_losses = []

    for z, x, y, mask in train_loader:                 # z:(B,S,N)  mask:(B,S)
        # compress away padding →  (R, N), (R, N, 3), (R,)
        valid      = mask.view(-1)
        z_res      = z.view(-1, z.size(2))[valid].to(device)
        x_res      = x.view(-1, x.size(2), 3)[valid].to(device)
        y_res      = y.view(-1)[valid].to(device)

        optimizer.zero_grad()

        # per-residue embeddings; the centroids are not used in this variant
        feats, centroids = forward_residues(z_res, x_res)         # (R, C)

        preds = pred_head(feats)
        # `t` is not used by the loss here; it is kept as a module-level
        # variable because a later cell inspects t.shape
        t = preds.unsqueeze(0)
        #preds=protein_egnn(t,centroids)[0]

        loss  = criterion(preds.flatten(), y_res)

        # NOTE(review): no autocast context is entered around the forward
        # pass, so on GPU the scaler only rescales the loss/gradients.
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); optimizer.step()

        tr_losses.append(loss.item())

    print(f"Epoch {epoch:3d} | train MSE = {np.mean(tr_losses):.4f}")


    # ======== VALID ========
    egnn_net.eval(); rbf_layer.eval(); mha_layer.eval(); pred_head.eval()
    vl_losses = []

    with torch.no_grad():
        for z, x, y, mask in val_loader:
            valid   = mask.view(-1)
            z_res   = z.view(-1, z.size(2))[valid].to(device)
            x_res   = x.view(-1, x.size(2), 3)[valid].to(device)
            y_res   = y.view(-1)[valid].to(device)

            # same forward path as in training
            feats, centroids = forward_residues(z_res, x_res)         # (R, C)

            preds = pred_head(feats)
            t = preds.unsqueeze(0)
            #preds=protein_egnn(t,centroids)[0]

            loss  = criterion(preds.flatten(), y_res)
            vl_losses.append(loss.item())

    print(f"              |  val MSE = {np.mean(vl_losses):.4f}")
# wall-clock time since t0 was taken at the top of the cell
print(time.time() - t0,"sec")
#mute(1,0,2).shape

Epoch   0 | train L1 = 3.4942
              |  val L1 = 3.8465
Epoch   1 | train L1 = 3.3362
              |  val L1 = 3.7035
Epoch   2 | train L1 = 3.2498
              |  val L1 = 3.6456
Epoch   3 | train L1 = 3.0285
              |  val L1 = 3.4390
Epoch   4 | train L1 = 2.7406
              |  val L1 = 3.3917
Epoch   5 | train L1 = 2.5255
              |  val L1 = 2.7860
Epoch   6 | train L1 = 2.1905
              |  val L1 = 2.2727
Epoch   7 | train L1 = 1.8957
              |  val L1 = 1.7053
Epoch   8 | train L1 = 1.6098
              |  val L1 = 1.5675
Epoch   9 | train L1 = 1.5538
              |  val L1 = 1.4907
140.04440879821777 sec


In [37]:
# Display the file paths that make up the validation split.
val_paths

['../../../data/pkegnn_INS/inputs/1afl.npz',
 '../../../data/pkegnn_INS/inputs/3n9x.npz',
 '../../../data/pkegnn_INS/inputs/5ftg.npz',
 '../../../data/pkegnn_INS/inputs/4kz7.npz',
 '../../../data/pkegnn_INS/inputs/3f12.npz',
 '../../../data/pkegnn_INS/inputs/6dy0.npz',
 '../../../data/pkegnn_INS/inputs/2lwl.npz',
 '../../../data/pkegnn_INS/inputs/3di1.npz',
 '../../../data/pkegnn_INS/inputs/5i4i.npz',
 '../../../data/pkegnn_INS/inputs/5j3y.npz']

No pegnn (i.e. without the `protein_egnn` step).

In [20]:
# Shape of `t` (the pred_head output with a leading batch axis) from the last batch processed.
t.shape

torch.Size([1, 111, 1])