In [41]:
# train_gnn.py
import os, glob, json, math
import numpy as np
import torch
from torch.utils.data import random_split
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GraphSAGE, global_max_pool
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import softmax
from torch_geometric.utils import add_self_loops  # add this import
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm

# Upper bound on the number of pipes per graph; individual graphs may be smaller.
MAX_PIPES = 20

# Number of per-receiver statistics that were saved for each receiver.
R_STATS_DIM = 7

# Compute device: the GPU when CUDA is available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [42]:
# ----------  small helpers ---------------------------------------------------
def downsample(ts, T=256):
    """Uniformly down-sample or zero-pad a 1-D sequence to exactly ``T`` samples.

    Parameters
    ----------
    ts : sequence of numbers
        Input series (list, tuple or 1-D array).
    T : int, default 256
        Target length.

    Returns
    -------
    numpy.ndarray
        Float64 array of length ``T``.  When ``len(ts) >= T`` the samples are
        picked at ``T`` evenly spaced (integer-truncated) positions that always
        include the first and last element; otherwise the input is copied to
        the front of a zero-filled array.
    """
    # Convert once to float so both branches return the same dtype; previously
    # the down-sampling branch leaked the input dtype (e.g. int for int lists).
    values = np.asarray(ts, dtype=float)
    n = len(values)
    if n >= T:
        idx = np.linspace(0, n - 1, T, dtype=int)
        return values[idx]
    out = np.zeros(T, dtype=float)
    out[:n] = values
    return out

def encode_receiver(stats):
    """Scale one receiver's statistics by fixed per-feature weights.

    This is a plain element-wise re-weighting, not a learned encoder: each of
    the statistics is multiplied by a hard-coded weight.

    Parameters
    ----------
    stats : sequence of float
        The per-receiver statistics; must have exactly as many entries as
        there are weights (``R_STATS_DIM``, capped at the 7 built-in weights).

    Returns
    -------
    numpy.ndarray
        float32 vector with the same length as the weight vector.

    Raises
    ------
    ValueError
        If ``stats`` does not have the expected length.
    """
    w = np.array([0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1])[:R_STATS_DIM]
    values = np.asarray(stats, dtype=float)
    # Fail loudly instead of letting numpy broadcast a length-1 input across
    # all weights or raise an opaque "operands could not be broadcast" error.
    if values.shape != w.shape:
        raise ValueError(
            f'expected {w.shape[0]} receiver statistics, got shape {values.shape}')
    return (values * w).astype(np.float32)


In [43]:
# ----------  Dataset class ---------------------------------------------------
class VeinTreeDataset(Dataset):
    """Graph dataset with one sample per simulation-run directory under ``root``.

    Files read for each run (paths relative to the run directory):

    * ``pipe<K>/simulation_data.txt`` - first line, whitespace separated:
      pipe id, parent id (anything non-integer means "no parent"), length,
      radius.
    * ``pipe<K>/#*-Ring type.txt`` - one file per receiver.  First line is
      whitespace separated and fields 1 and 2 are parsed as floats; second line
      is a comma-separated list whose first ``R_STATS_DIM`` values are the
      receiver statistics.
    * ``targetOutput.txt`` - first line, whitespace separated: emitter pipe id,
      r, z.

    Each sample is a ``Data`` object with

    * ``x``          - ``(N, 3 + R_STATS_DIM)`` float32 node features:
      ``[length, radius, has_receiver, mean encoded receiver stats...]``
    * ``edge_index`` - parent/child links in both directions plus self loops
    * ``y_pipe``     - ``(1,)`` row index of the emitter pipe
    * ``y_coord``    - ``(1, 2)`` emitter position as ``[z / length, r / radius]``
    """

    def __init__(self, root):
        super().__init__(root)
        # Only directories are runs; stray non-hidden files (notes, checkpoints,
        # ...) would otherwise be indexed and make ``get`` fail.
        self.run_dirs = sorted(
            d for d in os.listdir(self.root)
            if not d.startswith('.') and os.path.isdir(os.path.join(self.root, d)))

    def len(self):
        """Return the number of run directories found under ``root``."""
        return len(self.run_dirs)

    # -------------------------------------------------------------------------
    def get(self, idx):
        """Load run number ``idx`` from disk and return it as a ``Data`` graph."""
        run_path = os.path.join(self.root, self.run_dirs[idx])
        pipe_folders = sorted([p for p in os.listdir(run_path) if p.startswith('pipe')],
                              key=lambda x: int(x.replace('pipe','')))
        # first pass - collect structural info ---------------------------------
        parent = {}
        length = {}
        radius = {}
        recs   = {}           # pipe_id -> list[(r,z,stats)]
        for pf in pipe_folders:
            pipe_dir = os.path.join(run_path, pf)
            with open(os.path.join(pipe_dir, 'simulation_data.txt')) as f:
                t = f.readline().split()
            # the id stored in the file is authoritative, not the folder name
            pid = int(t[0])
            try:                                    # robust parent-id parsing
                par = int(t[1])
            except ValueError:                      # e.g. 'None' or 'Non'
                par = -1
            L, R = map(float, t[2:4])
            parent[pid]  = par
            length[pid]  = L
            radius[pid]  = R
            # Escape the directory part so glob metacharacters in the path are
            # taken literally; sort so the receiver order does not depend on
            # the operating system's directory listing order.
            pattern = os.path.join(glob.escape(pipe_dir), '#*-Ring type.txt')
            for rf in sorted(glob.glob(pattern)):
                with open(rf) as f:
                    rline = f.readline().split()
                    slist = [float(x.strip()) for x in f.readline().split(',')[:R_STATS_DIM]]
                recs.setdefault(pid, []).append((float(rline[1]), float(rline[2]), slist))
        # second pass - build PyG graph ---------------------------------------
        pipe_ids = sorted(parent.keys())                # local idx 0...N-1
        id2row   = {pid:i for i,pid in enumerate(pipe_ids)}
        # build directed edges parent -> child
        edge_src, edge_dst = [], []
        for pid in pipe_ids:
            par = parent[pid]
            # "No parent" is encoded as -1, which is never a key of id2row, so
            # the membership test is sufficient.  The former `par > 0` check
            # also discarded every child of pipe id 0.  A pipe naming itself as
            # parent is skipped because self loops are added separately below.
            if par != pid and par in id2row:
                edge_src.append(id2row[par])
                edge_dst.append(id2row[pid])

        edge_index = torch.tensor([edge_src + edge_dst,
                                edge_dst + edge_src],          # add reverse
                                dtype=torch.long)

        edge_index, _ = add_self_loops(edge_index, num_nodes=len(pipe_ids))
        # ---------- node features --------------------------------------------
        N = len(pipe_ids)
        x = []
        for pid in pipe_ids:
            has_rec = 1.0 if pid in recs else 0.0
            # aggregate receiver stats (mean) or zeros
            if pid in recs:
                v = np.mean([encode_receiver(s) for _,_,s in recs[pid]], axis=0)
            else:
                v = np.zeros(R_STATS_DIM, dtype=np.float32)
            x.append(np.concatenate([
                    [ length[pid] , radius[pid] , has_rec ],
                    v]))
        x = torch.tensor(np.vstack(x), dtype=torch.float32)
        # ---------- read labels ----------------------------------------------
        with open(os.path.join(run_path,'targetOutput.txt')) as f:
            pid_e, r_e , z_e = f.readline().split()[:3]
        pid_e = int(pid_e)
        emitter_row   = id2row[pid_e]
        y_pipe        = torch.tensor([emitter_row], dtype=torch.long)
        # position is stored relative to the emitter pipe's own dimensions
        y_coord       = torch.tensor([[float(z_e)/length[pid_e],
                                       float(r_e)/radius[pid_e]]],
                                     dtype=torch.float32)
        return Data(x=x,
                    edge_index=edge_index,
                    y_pipe=y_pipe,          # shape (1,)
                    y_coord=y_coord,        # 1x2
                    num_nodes=N)
# -----------------------------------------------------------------------------

In [44]:
## 2.  GNN model
class EmitterGNN(nn.Module):
    """Stack of SAGEConv layers followed by two per-node linear heads.

    Parameters
    ----------
    in_dim : int
        Width of the input node features.
    hidden : int, default 128
        Width of every SAGEConv output and of the head inputs.
    layers : int, default 3
        Number of SAGEConv layers (at least one layer is always created).
    """

    def __init__(self, in_dim, hidden=128, layers=3):
        super().__init__()
        # input width of each conv: in_dim for the first one, hidden afterwards
        widths = [in_dim] + [hidden] * (layers - 1)
        self.convs = nn.ModuleList([SAGEConv(width, hidden) for width in widths])
        self.cls_head = nn.Linear(hidden, 1)   # one logit per node
        self.reg_head = nn.Linear(hidden, 2)   # (z_rel, r_rel) per node

    def forward(self, data):
        """Return ``(logits, coords, embeddings)`` for all nodes of ``data``.

        ``logits`` has one value per node, ``coords`` two values per node and
        ``embeddings`` is the output of the last conv layer (after ReLU),
        returned so callers can pool it if they need to.
        """
        h = data.x
        edges = data.edge_index
        for layer in self.convs:
            h = F.relu(layer(h, edges))
        node_logits = self.cls_head(h).squeeze(-1)
        node_coords = self.reg_head(h)
        return node_logits, node_coords, h

In [45]:
## 3.  Loss
def emitter_loss(batch, logits, coords, λ=10.0):
    """Combined emitter-localisation loss for a batch of graphs.

    * classification: negative log-likelihood of the true emitter node under a
      softmax taken over the nodes *of each graph separately*
    * regression: MSE between the predicted and the target coordinates, read
      at the true emitter node of each graph

    Parameters
    ----------
    batch : batched graph object providing ``batch``, ``ptr``, ``y_pipe`` and
        ``y_coord``
    logits : tensor with one logit per node of the whole batch
    coords : tensor with one coordinate row per node of the whole batch
    λ : float, default 10.0
        Weight of the regression term.

    Returns
    -------
    tuple
        ``(total_loss, cls_loss_value, reg_loss_value)`` - the first entry is a
        tensor that can be back-propagated, the other two are Python floats.
    """
    # 1) node-wise probability restricted to each graph
    probs = softmax(logits, batch.batch)                 # one value per node

    # Global row of every graph's emitter: graph offset + local row.
    # reshape(-1) keeps the batch dimension even for a single-graph batch,
    # where a bare squeeze() would produce a 0-d tensor.
    emitter_idx = batch.ptr[:-1] + batch.y_pipe.reshape(-1)
    loss_cls = (-torch.log(probs[emitter_idx] + 1e-12)).mean()

    # 2) coordinate regression on the emitter rows
    pred_coord = coords[emitter_idx]                     # (B, 2)
    # Explicit (B, 2) target: squeeze() turned a single-graph batch into shape
    # (2,), which made mse_loss warn about mismatching input/target sizes.
    target_coord = batch.y_coord.reshape(-1, pred_coord.size(-1))
    loss_reg   = F.mse_loss(pred_coord, target_coord)

    return loss_cls + λ * loss_reg, loss_cls.item(), loss_reg.item()
# ---------------------------------------------------------------------------

In [46]:
## 4.  Training loop ----------------------------------------------------------
def train(root, epochs=200, lr=2e-3, batch_size=24, ckpt_path='best_gnn.pt'):
    """Train an ``EmitterGNN`` on the runs under ``root`` and report test metrics.

    The dataset is split 70 / 15 / 15 into train / validation / test with a
    fixed seed.  After every epoch the weights with the lowest validation loss
    so far are written to ``ckpt_path``; those weights are reloaded for the
    final evaluation on the test split.

    Parameters
    ----------
    root : str
        Directory containing one sub-directory per simulation run.
    epochs : int, default 200
    lr : float, default 2e-3
        Adam learning rate.
    batch_size : int, default 24
        Number of graphs per batch.
    ckpt_path : str, default 'best_gnn.pt'
        Where the best weights are stored.

    Raises
    ------
    ValueError
        If there are too few runs for a non-empty train and validation split.
    """
    ds = VeinTreeDataset(root)
    n = len(ds)
    train_len = int(0.7*n)
    val_len   = int(0.15*n)
    test_len  = n - train_len - val_len
    # An empty validation split used to surface only after the first epoch as
    # a ZeroDivisionError; report the real problem up front.
    if train_len == 0 or val_len == 0:
        raise ValueError(
            f'need at least 7 runs for a 70/15/15 split, found {n} in {root!r}')
    train_ds, val_ds, test_ds = random_split(ds,[train_len,val_len,test_len],
                                             generator=torch.Generator().manual_seed(42))
    loader   = DataLoader(train_ds, batch_size, shuffle=True)
    val_lo   = DataLoader(val_ds, batch_size)
    test_lo  = DataLoader(test_ds, batch_size)
    model = EmitterGNN(in_dim=3+R_STATS_DIM).to(DEVICE)
    opt   = torch.optim.Adam(model.parameters(), lr=lr)
    best = math.inf
    saved = False            # did THIS run write a checkpoint?
    for epoch in range(1, epochs+1):
        model.train(); tl=0
        for batch in loader:
            batch = batch.to(DEVICE)
            opt.zero_grad()
            logits, coords,_ = model(batch)
            loss,_,_ = emitter_loss(batch, logits, coords)
            loss.backward(); opt.step()
            # weight by graph count so the epoch mean is per graph, not per batch
            tl += loss.item()*batch.num_graphs
        # ---- val ------------------------------------------------------------
        model.eval(); vl=0
        with torch.no_grad():
            for batch in val_lo:
                batch=batch.to(DEVICE)
                logits,coords,_ = model(batch)
                loss,_,_ = emitter_loss(batch,logits,coords)
                vl += loss.item()*batch.num_graphs
        vl /= len(val_ds); tl/=len(train_ds)
        if vl<best:
            best=vl; torch.save(model.state_dict(), ckpt_path)
            saved = True
        print(f'E{epoch:03d}  train {tl:.4f}  val {vl:.4f}')
    # ------------- test ------------------------------------------------------
    if saved:
        # map_location keeps the reload working whatever device wrote the file
        model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    else:
        # e.g. epochs == 0 or a NaN validation loss in every epoch: do not
        # silently evaluate a stale checkpoint left over from an earlier run
        print('warning: no checkpoint was saved in this run; '
              'evaluating the final weights')
    model.eval()
    correct=0; err=[]
    with torch.no_grad():
        for batch in test_lo:
            batch=batch.to(DEVICE)
            logits,coords,_=model(batch)
            # reshape(-1) keeps the batch dimension for single-graph batches
            y_pipe = batch.y_pipe.reshape(-1)
            # predicted emitter = highest-scoring node inside each graph
            pred = []
            for i in range(batch.num_graphs):
                beg, end = batch.ptr[i], batch.ptr[i+1]
                pred.append(torch.argmax(logits[beg:end]).item())
            pred = torch.tensor(pred, device=DEVICE)
            correct += (pred==y_pipe).sum().item()
            # coordinate error is measured at the TRUE emitter node
            emitter_idx = batch.ptr[:-1]+y_pipe
            target = batch.y_coord.reshape(-1, coords.size(-1))
            err.extend(torch.sqrt(((coords[emitter_idx]-target)**2).sum(1)).cpu().numpy())
    # the error is Euclidean in the 2-D normalised (z/L, r/R) space
    print(f'\nTest  pipe accuracy {correct/len(test_ds)*100:5.1f}%   '
          f'avg 2-D error {np.mean(err):.4f}')
# -----------------------------------------------------------------------------

In [47]:
if __name__=='__main__':
    # The data location can be overridden with the VEIN_DATA_ROOT environment
    # variable; the original machine-specific path is kept as the default so
    # existing runs behave exactly as before.
    train(os.environ.get(
        'VEIN_DATA_ROOT',
        '/Users/daghanerdonmez/Desktop/molecular-simulation-mlp/output-processing/Outputs_Copy'))

E001  train 4.4252  val 4.3470
E002  train 4.3241  val 4.3500
E003  train 4.3147  val 4.2805
E004  train 4.3106  val 4.3258
E005  train 4.3080  val 4.2711
E006  train 4.2968  val 4.3352
E007  train 4.3139  val 4.3041
E008  train 4.3001  val 4.2908
E009  train 4.3031  val 4.3232
E010  train 4.3006  val 4.3206
E011  train 4.3010  val 4.2707
E012  train 4.3088  val 4.3664
E013  train 4.3199  val 4.3118
E014  train 4.3026  val 4.2873
E015  train 4.2968  val 4.2924
E016  train 4.2934  val 4.2907
E017  train 4.2981  val 4.3291
E018  train 4.2932  val 4.2868
E019  train 4.3048  val 4.3276
E020  train 4.3029  val 4.2993
E021  train 4.2987  val 4.2827
E022  train 4.3224  val 4.3407
E023  train 4.3000  val 4.3124
E024  train 4.3160  val 4.2875
E025  train 4.3028  val 4.2981
E026  train 4.2911  val 4.3232
E027  train 4.3004  val 4.3080
E028  train 4.2977  val 4.2932
E029  train 4.2971  val 4.3323
E030  train 4.3080  val 4.2800
E031  train 4.2967  val 4.2861
E032  train 4.2919  val 4.2914
E033  tr