In [1]:
import os
import glob
from typing import List, Tuple

import numpy as np
import torch
from torch import nn
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from tqdm import tqdm

# ----------------------------
# CONFIG
# ----------------------------
# Compute device: CUDA when torch reports it available, CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Graph layout.
MAX_PIPES: int = 40  # pipe ids 1-39 plus a spare slot 0
NODE_FEAT_DIM: int = 3  # per-node features: length, radius, num_recv

# Optimisation hyper-parameters.
BATCH_SIZE: int = 32
EPOCHS: int = 50
LR: float = 3e-4
CLS_WEIGHT: float = 5.0  # multiplier applied to the pipe-classification loss


In [None]:
# ---------------------------------------------------------
#   UTILITIES
# ---------------------------------------------------------

def read_pipe_simulation(path: str) -> Tuple[int, int, float, float, float]:
    """Parse the first line of a pipe's ``simulation_data.txt``.

    The line is split on whitespace and read positionally as
    ``pipe_id parent_id length radius [num_receivers]``. Any further lines
    in the file are ignored.

    Returns:
        ``(pipe_id, parent_id, length, radius, num_receivers)``; the last
        entry is ``0.0`` when the line has no fifth field.
    """
    with open(path, "r") as handle:
        fields = handle.readline().split()

    pipe_id, parent_id = int(fields[0]), int(fields[1])
    length, radius = float(fields[2]), float(fields[3])
    # The receiver count is optional in the line format.
    if len(fields) > 4:
        receivers = float(fields[4])
    else:
        receivers = 0.0
    return pipe_id, parent_id, length, radius, receivers


def build_graph_from_run(run_dir: str) -> Tuple[Data, torch.Tensor, torch.Tensor]:
    """Create a PyG ``Data`` graph for one simulation run.

    Every ``pipe<N>`` entry in ``run_dir`` contributes node ``N - 1``; its
    features ``[length, radius, num_recv]`` come from ``simulation_data.txt``
    when that file exists and stay all-zero otherwise. A parent id >= 1 adds
    an edge in both directions between parent and child. Ids that are
    referenced but have no data of their own are padded with zero rows so
    that every edge index points at an existing node.

    Args:
        run_dir: directory holding the ``pipe*`` folders and ``targetOutput.txt``.

    Returns:
        ``(data_graph, cls_target, reg_target)`` where ``cls_target`` is a
        scalar long tensor (0-based emitter pipe index) and ``reg_target`` is
        ``tensor([r, z])``. Inside ``data_graph`` the regression target is
        stored as ``pos_target`` with shape ``(1, 2)`` so that mini-batching
        stacks it to ``(num_graphs, 2)``.

    Raises:
        ValueError: if the run has no pipes, a pipe or target id is < 1, or
            the target line has fewer than two fields.
    """
    node_feats: List[List[float]] = []  # row i holds pipe id i + 1
    edges: List[List[int]] = []  # (src, dst) pairs, 0-based

    def _pad_to(count: int) -> None:
        # Grow node_feats with zero rows until it has at least `count` rows.
        missing = count - len(node_feats)
        node_feats.extend([0.0] * NODE_FEAT_DIM for _ in range(missing))

    # Sorted so the edge order does not depend on the filesystem's listing order.
    for pf in sorted(os.listdir(run_dir)):
        if not pf.startswith("pipe"):
            continue
        try:
            folder_id = int(pf.replace("pipe", ""))  # 1-based
        except ValueError:
            continue  # e.g. "pipes_backup" or "pipe_log.txt": not a pipe folder
        _pad_to(folder_id)

        sim_path = os.path.join(run_dir, pf, "simulation_data.txt")
        if not os.path.exists(sim_path):
            continue
        p_id, parent_id, length, radius, num_recv = read_pipe_simulation(sim_path)
        if p_id < 1:
            # A non-positive id would index from the end of the list and
            # silently overwrite another pipe's features.
            raise ValueError(f"Invalid pipe id {p_id} in {sim_path}; ids are 1-based")
        # The id in the file may exceed the folder's, and the parent may have
        # no folder of its own; make sure both rows exist.
        _pad_to(max(p_id, parent_id))
        node_feats[p_id - 1] = [length, radius, num_recv]
        if parent_id >= 1:
            # convert to 0-based indices; both directions make the edge undirected
            edges.append([parent_id - 1, p_id - 1])
            edges.append([p_id - 1, parent_id - 1])

    if not node_feats:  # no pipes? skip
        raise ValueError(f"No pipe data in {run_dir}")
    x = torch.tensor(node_feats, dtype=torch.float32)
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # An empty Python list would become a 1-D tensor of shape (0,);
        # message passing expects (2, num_edges).
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # -------- target --------
    tgt_path = os.path.join(run_dir, "targetOutput.txt")
    with open(tgt_path, "r") as f:
        tparts = f.readline().split()
    if len(tparts) < 2:
        raise ValueError(f"Malformed target line in {tgt_path}: expected 'pipe r [z]'")
    tgt_pipe = int(tparts[0])  # 1-based
    if tgt_pipe < 1:
        raise ValueError(f"Invalid target pipe id {tgt_pipe} in {tgt_path}; ids are 1-based")
    tgt_r = float(tparts[1])
    tgt_z = float(tparts[2]) if len(tparts) > 2 else 0.0

    cls_target = torch.tensor(tgt_pipe - 1, dtype=torch.long)  # 0-based class
    reg_target = torch.tensor([tgt_r, tgt_z], dtype=torch.float32)

    # PyG concatenates tensor attributes along dim 0 when batching, so the
    # graph-level target needs a leading dimension of 1 to stay (B, 2).
    data = Data(x=x, edge_index=edge_index, y=cls_target, pos_target=reg_target.unsqueeze(0))
    return data, cls_target, reg_target


# ---------------------------------------------------------
#   DATASET CLASS
# ---------------------------------------------------------
class PipeGraphDataset(Dataset):
    """PyG Dataset that loads all simulation runs under a directory.

    Each non-hidden sub-directory of ``root`` is one sample; graphs are built
    on access by ``build_graph_from_run``, nothing is cached.
    """

    def __init__(self, root: str):
        super().__init__(None, None, None)
        self.root_dir = root
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> run mapping stable across machines and runs.
        self.run_dirs = sorted(
            os.path.join(root, d)
            for d in os.listdir(root)
            if not d.startswith('.') and os.path.isdir(os.path.join(root, d))
        )

    def len(self):
        """Return the number of run directories found under ``root``."""
        return len(self.run_dirs)

    def get(self, idx):
        """Build and return the ``Data`` graph for run number ``idx``."""
        # The targets are already attached to the graph (y / pos_target).
        data, _, _ = build_graph_from_run(self.run_dirs[idx])
        return data


# ---------------------------------------------------------
#   MODEL
# ---------------------------------------------------------
class PipeGNN(nn.Module):
    """Stack of GCN layers with a mean-pool readout and two linear heads.

    ``cls_head`` emits ``num_classes`` logits per graph and ``reg_head`` emits
    two regression outputs per graph.
    """

    def __init__(self, in_dim: int, hidden: int = 64, num_layers: int = 3, num_classes: int = 39):
        super().__init__()
        # First layer maps in_dim -> hidden; the remaining ones are hidden -> hidden.
        input_widths = [in_dim] + [hidden] * (num_layers - 1)
        self.convs = nn.ModuleList([GCNConv(width, hidden) for width in input_widths])
        self.cls_head = nn.Linear(hidden, num_classes)
        self.reg_head = nn.Linear(hidden, 2)

    def forward(self, data: Data):
        """Return ``(logits, coords)`` for the graphs contained in ``data``."""
        edge_index, batch = data.edge_index, data.batch
        h = data.x
        for layer in self.convs:
            h = torch.relu(layer(h, edge_index))
        # Collapse node embeddings into one vector per graph.
        pooled = global_mean_pool(h, batch)
        return self.cls_head(pooled), self.reg_head(pooled)

In [None]:
# ---------------------------------------------------------
#   TRAIN / VAL / TEST
# ---------------------------------------------------------

def train_epoch(model, loader, opt, cls_crit, reg_crit):
    """Run one optimisation pass over ``loader``.

    Each batch is moved to ``DEVICE`` and the optimised objective is
    ``CLS_WEIGHT * cls_loss + reg_loss``.

    Args:
        model: callable returning ``(logits, coords)`` for a batch.
        loader: iterable of batches exposing ``y``, ``pos_target`` and ``num_graphs``.
        opt: optimiser; zeroed and stepped once per batch.
        cls_crit: loss applied to ``(logits, batch.y)``.
        reg_crit: loss applied to ``(coords, regression target)``.

    Returns:
        ``(mean_cls_loss, mean_reg_loss)``, averaged over graphs (each batch
        weighted by ``num_graphs``). ``(nan, nan)`` if the loader is empty.
    """
    model.train()
    total_cls = total_reg = 0.0
    n_samples = 0
    for data in loader:
        data = data.to(DEVICE)
        opt.zero_grad()
        logits, coords = model(data)
        # PyG's default collation concatenates tensor attributes along dim 0,
        # so a per-graph (2,) target arrives flattened as (2 * B,). Reshape to
        # the prediction's shape; a (B, 2) target passes through unchanged.
        reg_target = data.pos_target.reshape(coords.shape)
        cls_loss = cls_crit(logits, data.y)
        reg_loss = reg_crit(coords, reg_target)
        loss = CLS_WEIGHT * cls_loss + reg_loss
        loss.backward()
        opt.step()

        bs = data.num_graphs
        total_cls += cls_loss.item() * bs
        total_reg += reg_loss.item() * bs
        n_samples += bs

    if n_samples == 0:
        # Nothing was processed; avoid a ZeroDivisionError on the averages.
        return float("nan"), float("nan")
    return total_cls / n_samples, total_reg / n_samples


def eval_epoch(model, loader, cls_crit, reg_crit):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: callable returning ``(logits, coords)`` for a batch.
        loader: iterable of batches exposing ``y``, ``pos_target`` and ``num_graphs``.
        cls_crit: loss applied to ``(logits, batch.y)``.
        reg_crit: loss applied to ``(coords, regression target)``.

    Returns:
        ``(mean_cls_loss, mean_reg_loss, accuracy_percent, mean_position_error)``.
        Losses are averaged over graphs, accuracy is the share of graphs whose
        arg-max class equals ``y``, and the position error is the mean L2 norm
        of ``coords - target``. All four are ``nan`` if the loader is empty.
    """
    model.eval()
    total_cls = total_reg = 0.0
    correct = 0
    pos_errors = []
    n_samples = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(DEVICE)
            logits, coords = model(data)
            # PyG's default collation concatenates tensor attributes along
            # dim 0, so a per-graph (2,) target arrives flattened as (2 * B,).
            # Reshape to the prediction's shape; (B, 2) passes through unchanged.
            reg_target = data.pos_target.reshape(coords.shape)
            cls_loss = cls_crit(logits, data.y)
            reg_loss = reg_crit(coords, reg_target)

            bs = data.num_graphs
            total_cls += cls_loss.item() * bs
            total_reg += reg_loss.item() * bs
            n_samples += bs

            pred_pipe = logits.argmax(dim=1)
            correct += (pred_pipe == data.y).sum().item()
            pos_errors.extend((coords - reg_target).norm(dim=1).cpu().tolist())

    if n_samples == 0:
        # Empty split: every metric is undefined rather than a division error.
        nan = float("nan")
        return nan, nan, nan, nan
    acc = 100.0 * correct / n_samples
    return total_cls / n_samples, total_reg / n_samples, acc, np.mean(pos_errors)


In [None]:
# ---------------------------------------------------------
#   MAIN
# ---------------------------------------------------------

def main(seed: int = 0):
    """Train ``PipeGNN`` on the simulation runs, then report test metrics.

    The dataset is split 70 / 15 / 15 into train / validation / test. After
    every epoch the weights are written to ``best_pipe_gnn.pt`` whenever the
    summed validation losses improve, and the test set is scored with that
    best checkpoint.

    Args:
        seed: seeds both the dataset permutation and torch's global RNG
            (model initialisation, loader shuffling) so runs are repeatable.

    Raises:
        ValueError: if no run directories are found.
    """
    # A notebook kernel does not define __file__; fall back to the working directory.
    try:
        base_dir = os.path.dirname(__file__)
    except NameError:
        base_dir = os.getcwd()
    data_root = os.path.join(base_dir, "../output-processing/Outputs_Copy")
    dataset = PipeGraphDataset(data_root)
    n_total = len(dataset)
    print(f"Total samples: {n_total}")
    if n_total == 0:
        raise ValueError(f"No simulation runs found under {data_root}")

    # Split dataset
    torch.manual_seed(seed)
    rng = np.random.default_rng(seed)
    idxs = rng.permutation(n_total)
    train_split = int(0.7 * n_total)
    val_split = int(0.85 * n_total)
    train_idx, val_idx, test_idx = idxs[:train_split], idxs[train_split:val_split], idxs[val_split:]
    train_ds = dataset.index_select(train_idx.tolist())
    val_ds = dataset.index_select(val_idx.tolist())
    test_ds = dataset.index_select(test_idx.tolist())

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

    model = PipeGNN(in_dim=NODE_FEAT_DIM, hidden=128, num_layers=3, num_classes=39).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    cls_crit = nn.CrossEntropyLoss()
    reg_crit = nn.MSELoss()

    ckpt_path = "best_pipe_gnn.pt"
    best_val = float('inf')
    saved_best = False  # guards against loading a stale checkpoint from an earlier run
    for epoch in range(1, EPOCHS + 1):
        trc, trr = train_epoch(model, train_loader, opt, cls_crit, reg_crit)
        vcc, vcr, vacc, verr = eval_epoch(model, val_loader, cls_crit, reg_crit)
        print(f"Epoch {epoch:02d} | "
              f"train cls {trc:.4f} reg {trr:.4f} | "
              f"val cls {vcc:.4f} reg {vcr:.4f} acc {vacc:.2f}% posErr {verr:.4f}")
        # Selection uses the unweighted sum of the two validation losses.
        val_score = vcc + vcr
        if val_score < best_val:
            best_val = val_score
            torch.save(model.state_dict(), ckpt_path)
            saved_best = True

    # ---- test ----
    # Score the best validation checkpoint, not whatever the last epoch left behind.
    if saved_best:
        model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    tcc, tcr, tacc, terr = eval_epoch(model, test_loader, cls_crit, reg_crit)
    print("Test results | "
          f"cls loss {tcc:.4f} reg loss {tcr:.4f} | "
          f"acc {tacc:.2f}% posErr {terr:.4f}")


if __name__ == "__main__":
    # torch_geometric is already imported at the top of this file, so this
    # guard can only fire when this cell is run on its own; it exists to turn
    # a bare ImportError into an actionable install hint.
    try:
        import torch_geometric  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "torch_geometric is required. Install it with 'pip install torch_geometric' "
            "(see the PyTorch Geometric installation guide for wheels matching "
            "your torch/CUDA build)."
        ) from exc
    main()
