<a href="https://colab.research.google.com/github/jk673/grapinnformer/blob/main/test_gnn_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import packages and modules

In [1]:
import torch, os, sys, subprocess

def in_colab():
    """Return True when the ``google.colab`` module is importable (i.e. inside Colab)."""
    try:
        import google.colab  # noqa: F401  Colab-only module
    except ImportError:
        return False
    return True

# Build the PyG wheel index URL matching the installed torch / CUDA build.
torch_ver = torch.__version__.split('+')[0]           # e.g. '2.8.0'
cuda_ver  = torch.version.cuda or ''
cu_tag    = f"cu{cuda_ver.replace('.','')}" if torch.cuda.is_available() and cuda_ver else "cpu"
url = f"https://data.pyg.org/whl/torch-{torch_ver}+{cu_tag}.html"

print("Torch:", torch.__version__, "| CUDA:", cuda_ver or "cpu", "| PYG wheel index:", url)


def _pip_install(*args):
    """Run ``pip install *args`` with the interpreter that runs this kernel.

    Using ``sys.executable -m pip`` (instead of the ``!pip`` shell magic) makes the
    packages land in the environment the notebook imports from and keeps this
    cell valid plain Python. Raises ``subprocess.CalledProcessError`` on failure.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])


# Installation: only performed when running inside Colab.
if in_colab():
    _pip_install("pyvista")
    _pip_install("torch_geometric")
    _pip_install(
        "-U",
        "pyg-lib", "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv",
        "-f", url,
    )
    # A restart is only needed after something was actually installed.
    print("✅ Done. Please restart runtime/kernel, then re-run your code.")
else:
    print("Not running in Colab; skipping package installation.")


Torch: 2.8.0+cu126 | CUDA: 12.6 | PYG wheel index: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Collecting pyvista
  Downloading pyvista-0.46.1-py3-none-any.whl.metadata (15 kB)
Collecting vtk!=9.4.0 (from pyvista)
  Downloading vtk-9.5.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.5 kB)
Downloading pyvista-0.46.1-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m23.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading vtk-9.5.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (112.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m112.1/112.1 MB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: vtk, pyvista
Successfully installed pyvista-0.46.1 vtk-9.5.0
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.9 MB/s[0

In [2]:
# 1. Environment & dependencies
import gc
import glob
import math
import os
import random
import sys
import time
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import List, Tuple, Union

import numpy as np
import pyvista as pv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor
import torch_geometric
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import GCNConv, GraphNorm, JumpingKnowledge, TransformerConv
from tqdm.auto import tqdm
import wandb
import pyg_lib

# Data and model settings
DATA_ROOT = Path("data")
TARGET_FIELD = "static(p)_coeffMean"
USE_NORMALS = True

# Reproducibility: seed Python, NumPy and torch (and every CUDA device if present).
_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(_SEED)

# Environment report. torch_geometric is imported in this cell, so its version is
# always available.
_version_report = (
    ("pyg-lib ok;", getattr(pyg_lib, "__version__", "ok")),
    ("Torch:", torch.__version__),
    ("CUDA:", torch.version.cuda),
    ("PyVista:", pv.__version__),
    ("PyG:", torch_geometric.__version__),
    ("py  :", sys.version.split()[0]),  # ex) 3.11.x
)
for _label, _value in _version_report:
    print(_label, _value)

pyg-lib ok; 0.4.0+pt28cu126
Torch: 2.8.0+cu126
CUDA: 12.6
PyVista: 0.46.1
PyG: 2.6.1
py  : 3.12.11


# Model

In [None]:
class FourierPosEnc(nn.Module):
    """Sinusoidal (Fourier-feature) encoding of coordinates.

    With frequencies ``f_k = 2**k * pi`` for ``k = 0 .. num_frequencies - 1`` the
    output is ``[sin(f_0 x), cos(f_0 x), sin(f_1 x), cos(f_1 x), ...]`` concatenated
    along the last axis, so ``out_dim = in_ch * 2 * num_frequencies``.
    """

    def __init__(self, in_ch=3, num_frequencies=2):  # number of pos frequencies halved
        super().__init__()
        self.register_buffer("freqs", 2.0**torch.arange(num_frequencies) * torch.pi)
        self.out_dim = in_ch*(2*num_frequencies)

    def forward(self, pos):
        """Encode ``pos`` of shape ``(..., in_ch)`` into ``(..., out_dim)``.

        Vectorised over frequencies. With ``num_frequencies=0`` an empty
        ``(..., 0)`` tensor is returned instead of failing in ``torch.cat``.
        """
        freqs = self.freqs
        # Follow pos's float dtype (e.g. fp16), as a 0-d scalar product would.
        if pos.is_floating_point():
            freqs = freqs.to(pos.dtype)
        # (..., 1, C) * (F, 1) -> (..., F, C)
        ang = pos.unsqueeze(-2) * freqs.unsqueeze(-1)
        # (..., F, 2, C): for each frequency the sin block precedes the cos block.
        pe = torch.stack((torch.sin(ang), torch.cos(ang)), dim=-2)
        return pe.flatten(start_dim=-3)

class FiLM(nn.Module):
    """Feature-wise linear modulation: ``h * (1 + gamma) + beta``.

    ``gamma`` and ``beta`` (each ``hidden`` wide) are predicted from a condition
    vector by a two-layer MLP.
    """

    def __init__(self, cond_dim, hidden):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(cond_dim, hidden*2), nn.GELU(), nn.Linear(hidden*2, hidden*2)
        )

    def forward(self, h, cond, batch=None):
        """Modulate node features ``h`` (``[N, hidden]``) with ``cond``.

        Args:
            h: node features.
            cond: ``[cond_dim]`` or ``[B, cond_dim]`` condition, or ``None`` (no-op).
            batch: optional ``[N]`` graph id per node. When given and ``cond`` has
                more than one row, each node uses its own graph's ``gamma``/``beta``.
                Without it, a single-row condition is broadcast to all nodes.
        """
        if cond is None:
            return h
        if cond.dim() == 1:
            cond = cond.unsqueeze(0)
        gamma, beta = torch.chunk(self.mlp(cond), 2, dim=-1)
        if batch is not None and gamma.size(0) > 1:
            # Per-graph conditioning: gather each node's graph row.
            gamma, beta = gamma[batch], beta[batch]
        else:
            gamma, beta = gamma.squeeze(0), beta.squeeze(0)
        return h * (1 + gamma) + beta

class BoundaryGraphNet(nn.Module):
    """Residual graph-transformer that regresses per-node surface quantities.

    Pipeline:
      1. Optional Fourier encoding of ``data.pos`` is concatenated to ``data.x``.
      2. An MLP embedding maps the node features to ``hidden`` channels.
      3. ``layers`` residual blocks follow, each applying ``TransformerConv`` over
         encoded edge features, FiLM on ``data.global_cond`` (if present),
         ``GraphNorm``, GELU and dropout.
      4. An optional max Jumping-Knowledge aggregation combines the block outputs.
      5. Output heads produce the predictions.

    ``forward`` returns the ``p`` head output, or ``(p, tau)`` when
    ``predict_tau=True``.
    """

    def __init__(self, in_dim, hidden=96, layers=4, dropout=0.15,   # defaults reduced: hidden 192->96, layers 6->4
                 heads=2, use_pos_enc=True, pos_freqs=2, cond_dim=4, # pos_freqs 4->2
                 out_dim_p=1, out_dim_tau=3, jk_mode='last', use_checkpoint=False,
                 predict_tau=False):
        super().__init__()
        self.use_pos_enc = use_pos_enc
        self.pos_enc = FourierPosEnc(3, pos_freqs) if use_pos_enc else None
        self.predict_tau = predict_tau
        pe_dim = self.pos_enc.out_dim if use_pos_enc else 0

        self.x_enc = nn.Sequential(
            nn.Linear(in_dim + pe_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden)
        )

        # Edge features [dx, dy, dz, |d|] (4 values) embedded into a small 8-dim space.
        # NOTE(review): a precomputed data.edge_attr must also be 4-wide — confirm.
        self.edge_encoder = nn.Sequential(
            nn.Linear(4, 16), nn.GELU(), nn.Linear(16, 8)
        )
        self.edge_dim = 8

        self.layers = nn.ModuleList()
        self.norms  = nn.ModuleList()
        for _ in range(layers):
            # concat=False averages the heads, so the output stays `hidden` wide and
            # the residual connection lines up.
            self.layers.append(TransformerConv(
                in_channels=hidden,
                out_channels=hidden,
                heads=heads,
                concat=False,
                dropout=dropout,
                edge_dim=self.edge_dim,
                beta=True,
            ))
            self.norms.append(GraphNorm(hidden))

        # JK defaults to 'last' (saves memory); 'max' keeps every layer's output.
        self.jk_mode = jk_mode
        if jk_mode == 'max':
            self.jk = JumpingKnowledge(mode='max')
        else:
            self.jk = None

        self.p_head   = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden, out_dim_p))
        self.tau_head = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden, out_dim_tau))

        self.film = FiLM(cond_dim, hidden) if cond_dim>0 else None
        self.dropout = dropout
        self.use_checkpoint = use_checkpoint  # enable only when memory is tight

    @staticmethod
    def _make_edge_attr(pos, edge_index):
        """Edge features ``[pos_j - pos_i, |pos_j - pos_i|]`` of shape ``[E, 4]``."""
        i, j = edge_index
        rij = pos[j] - pos[i]
        dij = torch.norm(rij, dim=1, keepdim=True).clamp_min(1e-12)
        return torch.cat([rij, dij], dim=1)  # [E,4]

    def _block(self, h_in, conv, norm, edge_index, e, cond, batch):
        """One message-passing block: conv -> FiLM -> GraphNorm -> GELU -> dropout.

        ``conv``/``norm`` are explicit arguments (not closure variables) so that
        checkpoint recomputation during backward uses the same layer as forward.
        """
        h_mid = conv(h_in, edge_index, edge_attr=e)
        if self.film is not None and cond is not None:
            h_mid = self.film(h_mid, cond)
        # Passing the batch vector normalises each graph of a batch separately.
        h_mid = norm(h_mid, batch)
        h_mid = F.gelu(h_mid)
        return F.dropout(h_mid, p=self.dropout, training=self.training)

    def forward(self, data):
        """Predict node outputs for a PyG ``Data``/``Batch``.

        Reads ``x``, ``pos``, ``edge_index`` and, when present, ``edge_attr``,
        ``global_cond`` and ``batch`` from ``data``.
        """
        x, pos, edge_index = data.x, data.pos, data.edge_index
        pe = self.pos_enc(pos) if self.use_pos_enc else None
        h = self.x_enc(torch.cat([x, pe], dim=-1) if pe is not None else x)

        eattr = getattr(data, 'edge_attr', None)
        if eattr is None:
            eattr = self._make_edge_attr(pos, edge_index)
        e = self.edge_encoder(eattr)

        cond = getattr(data, 'global_cond', None)
        batch = getattr(data, 'batch', None)

        hs = [] if self.jk_mode == 'max' else None

        for conv, norm in zip(self.layers, self.norms):
            h_res = h
            if self.use_checkpoint and torch.is_grad_enabled():
                # Optional checkpointing (less memory, more compute). The
                # non-reentrant variant handles non-tensor args and tensors that
                # are only used inside the block (e.g. `e`).
                h = cp.checkpoint(self._block, h, conv, norm, edge_index, e, cond, batch,
                                  use_reentrant=False)
            else:
                h = self._block(h, conv, norm, edge_index, e, cond, batch)
            h = h + h_res
            if hs is not None:
                hs.append(h)

        if self.jk_mode == 'max':
            h = self.jk(hs)
        # With 'last' the final residual output is used as-is.

        p_pred = self.p_head(h)
        if not self.predict_tau:
            return p_pred  # skip the unused tau head entirely
        return p_pred, self.tau_head(h)

# Training Loop

In [None]:
# Feature / target / condition dimensions of the surface graphs.
in_dim = 8
target_dim = 1
inferred_cond_dim = 2  # NOTE(review): the model cell passes cond_dim=4 instead — confirm which is right

# --------- hyper-parameters ---------
BATCH_SIZE, LR, WEIGHT_DECAY = 1, 1e-4, 1e-5
EPOCHS, PRINT_EVERY = 75, 5
GRAD_CLIP = 5.0

# Surface-only loss weights: data MSE, edge TV/L2 smoothness, graph-Laplacian penalty.
LOSS_WEIGHTS = dict(data=1.0, tv=0.05, lap=0.01)
USE_TV = True  # True: |dy| (total variation) edge term; False: squared differences
EPS = 1e-12

# --------- device ----------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Device:', device)


# --------- model definition / device placement ----------
# forward() returns only the p head here (predict_tau defaults to False), and
# out_dim_p keeps its default single channel, matching target_dim.
model = BoundaryGraphNet(
    in_dim=in_dim,
    hidden=256,
    layers=4,
    dropout=0.1,
    heads=2,
    use_pos_enc=True,
    pos_freqs=2,
    # e.g. [Uinf, rho, nu, A_ref] when supplied as data.global_cond.
    # NOTE(review): `inferred_cond_dim` in the config cell is 2 — confirm.
    cond_dim=4,
).to(device)

# --------- optimizer + cosine warm-restart schedule ----------
# Adam's weight_decay is coupled L2; AdamW would decouple it if that is intended.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer, T_0=10, T_mult=2, eta_min=0.1 * LR,
)

# --------- W&B init ----------
# Has to be set before wandb.init() to take effect.
os.environ["WANDB_DISABLE_SYMLINKS"] = "true"

_wandb_config = dict(
    epochs=EPOCHS,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    batch_size=BATCH_SIZE,
    grad_clip=GRAD_CLIP,
    loss_weights=LOSS_WEIGHTS,
    use_tv=USE_TV,
    feat_dim=in_dim,
    target_dim=target_dim,
    model=type(model).__name__,
    optimizer="Adam",
)
wandb_run = wandb.init(
    project="gnn-pinn-250819",
    name=f"surface_only_{int(time.time())}",
    config=_wandb_config,
)
# Parameter/gradient logging is expensive, so it stays disabled:
# wandb.watch(model, log="all", log_freq=50)

# --------- 기하 도우미/손실 ---------
def edge_geo_terms(pos: Tensor, edge_index: Tensor):
    """Per-edge geometry for ``edge_index`` over node positions ``pos``.

    Returns:
        ``(i, j, t_ij, len_ij)``: source and target index rows, the unit direction
        ``(pos[j] - pos[i]) / len_ij`` and the edge length ``[E, 1]``. The length
        is clamped to the dtype's smallest positive normal so the division is safe.
    """
    src, dst = edge_index
    delta = pos[dst] - pos[src]
    floor = torch.finfo(pos.dtype).tiny
    length = torch.norm(delta, dim=1, keepdim=True).clamp_min(floor)
    return src, dst, delta / length, length

def graph_edge_weights(pos: Tensor, edge_index: Tensor, mode="invlen", clamp=10.0):
    """Per-edge weights ``[E]`` for the smoothness losses.

    ``mode="invlen"`` gives ``1 / (len + EPS)``, capped at ``clamp`` unless
    ``clamp`` is None; any other mode gives uniform weights of one.
    """
    *_, edge_len = edge_geo_terms(pos, edge_index)
    if mode != "invlen":
        return torch.ones(edge_index.size(1), device=pos.device, dtype=pos.dtype)
    weights = (1.0 / (edge_len + EPS)).squeeze(1)
    return weights if clamp is None else weights.clamp_max(clamp)

def edge_tv_or_l2(y_pred: Tensor, edge_index: Tensor, edge_w: Tensor, use_tv=True):
    """Edge-weighted smoothness penalty on node predictions.

    For each edge ``(i, j)`` with difference ``d = y_pred[i] - y_pred[j]`` this
    averages ``w_ij * |d|`` (``use_tv=True``, total variation) or ``w_ij * d**2``
    over all edges and channels.

    Args:
        y_pred: ``[N]`` or ``[N, C]`` node predictions.
        edge_index: ``[2, E]`` edge list.
        edge_w: ``[E]`` (or ``[E, 1]``) per-edge weights.
        use_tv: L1 (True) or squared (False) differences.

    Returns:
        0-dim tensor; zero when there are no edges (instead of NaN from an empty mean).
    """
    i, j = edge_index
    diff = y_pred[i] - y_pred[j]  # [E] or [E, C]
    if diff.numel() == 0:
        out_dtype = torch.promote_types(diff.dtype, edge_w.dtype)
        return torch.zeros((), device=diff.device, dtype=out_dtype)
    # Shape the weights as [E, 1, ...] for any prediction rank; unsqueeze(-1) would
    # broadcast a 1-D prediction (or [E, 1] weights) into an [E, E] product.
    w = edge_w.reshape(-1, *([1] * (diff.dim() - 1)))
    penalty = diff.abs() if use_tv else diff.pow(2)
    return (penalty * w).mean()

def laplacian_reg(y_pred, edge_index, num_nodes, edge_w=None, eps=1e-12):
    """Squared norm of the combinatorial graph Laplacian applied to predictions.

    Computes ``L y = D y - W y`` with ``D_uu = sum_v w_uv`` and
    ``(W y)_u = sum_v w_uv * y_v``, both summed over edges whose source (row 0)
    is ``u``; no row/symmetric normalisation is applied. The result is
    ``mean_u ||(L y)_u||^2``: the square is summed over channels and averaged
    over nodes.

    Args:
        y_pred: ``[N, C]`` node predictions.
        edge_index: ``[2, E]`` integer edge list.
        num_nodes: ``N``.
        edge_w: optional ``[E]`` edge weights; ones when None.
        eps: unused; kept for backward compatibility.

    Returns:
        0-dim tensor cast back to ``y_pred.dtype``; zero when ``num_nodes == 0``.
    """
    device = y_pred.device
    # fp16/bf16 inputs are upcast to fp32 for the scatter sums; fp64 inputs stay
    # fp64 instead of being truncated.
    dtype = torch.promote_types(y_pred.dtype, torch.float32)

    if num_nodes == 0:
        # The mean over an empty node set would be NaN.
        return torch.zeros((), device=device, dtype=y_pred.dtype)

    # Indices as long on the same device.
    src, dst = edge_index
    src = src.to(device=device, dtype=torch.long)
    dst = dst.to(device=device, dtype=torch.long)
    num_edges = src.numel()
    num_ch = y_pred.size(1)

    if edge_w is None:
        edge_w = torch.ones(num_edges, device=device, dtype=dtype)
    else:
        edge_w = edge_w.to(device=device, dtype=dtype)

    y = y_pred.to(dtype)

    # deg[u] = sum_v w_uv
    deg = torch.zeros(num_nodes, device=device, dtype=dtype)
    deg.scatter_add_(0, src, edge_w)

    # wy[u] = sum_v w_uv * y_v
    wy = torch.zeros((num_nodes, num_ch), device=device, dtype=dtype)
    wy.scatter_add_(0, src.unsqueeze(-1).expand(-1, num_ch),
                    edge_w.unsqueeze(-1) * y[dst])

    lap_y = deg.unsqueeze(-1) * y - wy
    reg = lap_y.pow(2).sum(dim=1).mean()

    # NOTE: casting back to fp16 can overflow to inf for very large penalties.
    return reg.to(y_pred.dtype)


def surface_only_loss(batch, pred: Tensor, loss_weights=LOSS_WEIGHTS):
    """Weighted surface loss: data MSE + edge TV/L2 smoothness + graph-Laplacian penalty.

    Args:
        batch: PyG ``Data``/``Batch`` with ``x``, ``y``, ``edge_index`` and optionally
            ``pos`` (falls back to ``x[:, :3]``) and ``batch`` (graph id per node).
        pred: ``[N, C]`` predictions for the nodes of ``batch``.
        loss_weights: dict with keys ``"data"``, ``"tv"`` and ``"lap"``.

    Returns:
        ``(loss, parts)``, where ``parts`` holds the detached ``loss_data``,
        ``loss_tv`` and ``loss_lap`` terms.
    """
    x = batch.x
    y = batch.y
    edge_index = batch.edge_index
    pos = batch.pos if getattr(batch, 'pos', None) is not None else x[:, :3]

    # A flat [N] target against an [N, 1] prediction would broadcast to [N, N]
    # inside mse_loss; align shapes when the element counts agree.
    if y.shape != pred.shape and y.numel() == pred.numel():
        y = y.reshape(pred.shape)
    data_loss = nn.functional.mse_loss(pred, y)

    # Inverse-edge-length weights shared by the TV and Laplacian terms.
    w_e = graph_edge_weights(pos, edge_index, mode="invlen")
    tv_loss = edge_tv_or_l2(pred, edge_index, w_e, use_tv=USE_TV)

    graph_ids = getattr(batch, 'batch', None)
    if graph_ids is None:
        lap_loss = laplacian_reg(pred, edge_index, x.size(0), w_e)
    else:
        # Per-graph Laplacian penalty, averaged over graphs. The accumulator is a
        # tensor so `.detach()` below also works when no graph has any edges.
        lap_loss = pred.new_zeros(())
        uniq = graph_ids.unique()
        # Global -> local node index map, allocated once; only the current
        # graph's entries are read for its edges.
        old2new = torch.full((graph_ids.size(0),), -1, device=graph_ids.device, dtype=torch.long)
        for g_id in uniq:
            mask = graph_ids == g_id
            node_idx = torch.nonzero(mask, as_tuple=False).squeeze(1)
            e_mask = mask[edge_index[0]] & mask[edge_index[1]]
            if not e_mask.any():
                continue  # edgeless graph: its Laplacian penalty is zero
            old2new[node_idx] = torch.arange(node_idx.size(0), device=mask.device)
            sub_e = old2new[edge_index[:, e_mask]]
            lap_loss = lap_loss + laplacian_reg(pred[mask], sub_e, node_idx.size(0), w_e[e_mask])
        lap_loss = lap_loss / (uniq.numel() + EPS)

    loss = (loss_weights["data"] * data_loss +
            loss_weights["tv"]   * tv_loss  +
            loss_weights["lap"]  * lap_loss)

    return loss, {"loss_data": data_loss.detach(),
                  "loss_tv":   tv_loss.detach(),
                  "loss_lap":  lap_loss.detach()}

# ======================
# Training / Validation
# ======================
best_val = float('inf')
best_path = "best_surface_only.pt"
best_epoch = None
torch.cuda.empty_cache()

# Mixed precision only on CUDA; with enabled=False the scaler/autocast are no-ops.
use_amp = device.type == 'cuda'
# One scaler for the whole run so its loss-scale estimate carries across epochs.
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
# Optimizer steps per epoch, for fractional-epoch scheduler stepping.
# NOTE(review): assumes train_loader is a collection of sized loaders, as the
# chained iteration below implies — confirm.
len_epoch = max(1, sum(len(ld) for ld in train_loader))


def _model_pred(batch):
    """Run the model and return its (first) prediction as a 2-D ``[N, C]`` tensor."""
    out = model(batch)
    pred = out[0] if isinstance(out, tuple) else out
    return pred.unsqueeze(1) if pred.dim() == 1 else pred


def _train_one_epoch(epoch):
    """One pass over every training loader; returns node-weighted mean losses."""
    model.train()
    sums = {"loss": 0.0, "loss_data": 0.0, "loss_tv": 0.0, "loss_lap": 0.0}
    n_nodes = 0
    # chain.from_iterable opens each loader lazily instead of creating every
    # loader iterator up front.
    for step, batch in enumerate(chain.from_iterable(train_loader), start=1):
        batch = batch.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            pred = _model_pred(batch)
            loss, parts = surface_only_loss(batch, pred)

        # Order: backward -> unscale -> clip -> step -> update.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        if GRAD_CLIP and GRAD_CLIP > 0:
            nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()

        # CosineAnnealingWarmRestarts accepts fractional epochs: step per iteration.
        scheduler.step(epoch - 1 + step / len_epoch)

        n = batch.x.size(0)
        n_nodes += n
        sums["loss"] += loss.item() * n
        for key in ("loss_data", "loss_tv", "loss_lap"):
            sums[key] += parts[key].item() * n

    denom = max(n_nodes, 1)
    return {k: v / denom for k, v in sums.items()}


def _validate():
    """One no-grad pass over every validation loader; returns node-weighted means.

    ``"loss"`` is the data (MSE) term only; the regularisers are reported but
    not used for model selection.
    """
    model.eval()
    sums = {"loss_data": 0.0, "loss_tv": 0.0, "loss_lap": 0.0}
    n_nodes = 0
    with torch.no_grad():
        for batch in chain.from_iterable(val_loader):
            batch = batch.to(device, non_blocking=True)
            pred = _model_pred(batch)
            _, parts = surface_only_loss(batch, pred)
            n = batch.x.size(0)
            n_nodes += n
            for key in sums:
                sums[key] += parts[key].item() * n
    denom = max(n_nodes, 1)
    means = {k: v / denom for k, v in sums.items()}
    means["loss"] = means["loss_data"]
    return means


try:
    for epoch in range(1, EPOCHS + 1):
        tr = _train_one_epoch(epoch)
        va = _validate()
        train_loss, val_loss = tr["loss"], va["loss"]

        # -------- W&B logging --------
        current_lr = optimizer.param_groups[0]["lr"]
        wandb.log({
            "epoch": epoch,
            "lr": current_lr,
            "train/loss": tr["loss"],
            "train/loss_data": tr["loss_data"],
            "train/loss_tv": tr["loss_tv"],
            "train/loss_lap": tr["loss_lap"],
            "val/loss": va["loss"],
            "val/loss_data": va["loss_data"],
            "val/loss_tv": va["loss_tv"],
            "val/loss_lap": va["loss_lap"],
        })

        if epoch % PRINT_EVERY == 0 or epoch == 1:
            print(f"[{epoch:03d}] train {train_loss:.6f} | val {val_loss:.6f} | lr {current_lr:.2e}")

        # -------- keep the best checkpoint (checked every epoch) --------
        if val_loss < best_val:
            best_val = val_loss
            best_epoch = epoch
            torch.save({"model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                        "val_loss": best_val}, best_path)
finally:
    # Upload the best checkpoint once and close the run after training (not
    # inside the epoch loop, which ended the run after the first epoch).
    if best_epoch is not None:
        artifact = wandb.Artifact(
            name="best_model",
            type="model",
            metadata={"val_loss": float(best_val)},
        )
        artifact.add_file(best_path, name="best_surface_only.pt")
        wandb.log_artifact(artifact)
    wandb.finish()


# Validation/Visualization

In [None]:
import pyvista as pv
import numpy as np
import torch
import re

# Switch the model to inference mode (train(False) is what eval() does).
model.train(False)

# Fixed colour range for the plots; set both to None to derive it from the data.
vmin = -0.5
vmax = 0.0

def _to_numpy(x):
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)

def _collect_per_node_scalars(sample, device, N_expected):
    """Collect the per-node scalar fields of ``sample`` as 1-D NumPy arrays.

    A value qualifies when it is a tensor/ndarray of shape ``(N_expected,)`` or
    ``(N_expected, 1)``. ``sample.y`` is handled first: a 1-D ``y`` is stored as
    ``'y'`` and a 2-D ``y`` is split into ``'y_col_<i>'`` entries. The remaining
    attributes (except ``x`` and ``y``) are stored under their own name, and
    dict-valued attributes are scanned one level deep as ``'<attr>.<key>'``.

    Args:
        sample: PyG ``Data``-like object. Its attributes are read via ``to_dict()``
            when available (``vars()`` on a PyG ``Data`` only exposes its internal
            ``_store``), otherwise via ``vars()``.
        device: unused; kept for backward compatibility. Values are converted
            straight to CPU NumPy arrays.
        N_expected: node count a candidate array must match.

    Returns:
        dict mapping candidate name -> 1-D ``np.ndarray`` of length ``N_expected``.
    """
    def _as_node_scalar(val):
        """Return ``val`` as a length-``N_expected`` 1-D array, or None if it is not one."""
        if not (torch.is_tensor(val) or isinstance(val, np.ndarray)):
            return None
        arr = _to_numpy(val)
        if arr.ndim == 1 and arr.shape[0] == N_expected:
            return arr
        if arr.ndim == 2 and arr.shape[0] == N_expected and arr.shape[1] == 1:
            return arr[:, 0]
        return None

    cands = {}

    # y first: a 2-D target is split column by column.
    y = getattr(sample, 'y', None)
    if y is not None:
        y_np = _to_numpy(y)
        if y_np.ndim == 1 and y_np.shape[0] == N_expected:
            cands['y'] = y_np
        elif y_np.ndim == 2 and y_np.shape[0] == N_expected:
            for col in range(y_np.shape[1]):
                cands[f'y_col_{col}'] = y_np[:, col]

    to_dict = getattr(sample, 'to_dict', None)
    attrs = to_dict() if callable(to_dict) else vars(sample)

    for name, val in attrs.items():
        if name in ('y', 'x'):  # y handled above; x is not a target candidate
            continue
        arr = _as_node_scalar(val)
        if arr is not None:
            cands[name] = arr
        # dict-style attributes (node_data etc.)
        if isinstance(val, dict):
            for key, sub in val.items():
                sub_arr = _as_node_scalar(sub)
                if sub_arr is not None:
                    cands[f'{name}.{key}'] = sub_arr

    return cands

def _select_gt_pressure(cands, p_pred):
    """
    이름 매칭(press|pressure|^p$) 우선,
    없으면 |corr| 최대 후보 선택.
    """
    if not cands:
        raise RuntimeError("샘플에서 노드별 스칼라 후보를 찾지 못했습니다.")

    # 1) 이름 우선 매칭
    pat = re.compile(r'(?:^|[_.-])(p|press|pressure)(?:$|[_.-])', re.IGNORECASE)
    name_matches = [k for k in cands.keys() if pat.search(k)]
    if len(name_matches) == 1:
        k = name_matches[0]
        print(f"GT pressure chosen by name: '{k}'")
        return cands[k], k
    elif len(name_matches) > 1:
        # 이름 후보가 여러 개라면, corr로 좁힘
        name_matches = name_matches

    # 2) 상관계수 최대(|r|)
    keys = name_matches if name_matches else list(cands.keys())
    def _corr(a, b):
        a = np.asarray(a).reshape(-1)
        b = np.asarray(b).reshape(-1)
        if a.std() < 1e-12 or b.std() < 1e-12:
            return 0.0
        return float(np.corrcoef(a, b)[0, 1])

    scores = [(k, _corr(cands[k], p_pred)) for k in keys]
    scores_sorted = sorted(scores, key=lambda x: abs(x[1]), reverse=True)
    best_k, best_r = scores_sorted[0]
    print("Candidate correlations with p_pred (top 5):")
    for k, r in scores_sorted[:5]:
        print(f"  {k:>24s} : r = {r:+.4f}")
    print(f"GT pressure chosen by correlation: '{best_k}' (|r|={abs(best_r):.4f})")
    return cands[best_k], best_k


In [None]:
import pyvista as pv
import numpy as np
import torch
import re
import time
from torch_geometric.data import Batch

# Use existing globals if set, otherwise compute defaults later
# Reuse colour limits set by an earlier cell if present; None means "derive from the data".
vmin, vmax = globals().get('vmin'), globals().get('vmax')

# Helper to convert tensors/arrays to numpy
def _to_numpy(x):
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# --- MODIFIED INFERENCE CODE ---

# Assuming val_concat was created from a list of Data objects (patches)
# We need to perform inference on these patches individually or in batches.
# Reuse the val_loader created earlier, which already handles batching.

def _infer_pressure(batch):
    """Run ``model`` on ``batch`` and return its first (pressure) output as ``[N, C]``."""
    out = model(batch)
    pred = out[0] if isinstance(out, tuple) else out
    return pred.unsqueeze(1) if pred.dim() == 1 else pred


# --- Batched inference over all validation patches ---
# NOTE(review): the training cell iterates `val_loader` as a collection of loaders,
# while this cell iterates it as a single loader of batches. Only one of these can
# be right — confirm against the data-preparation cell.
all_preds = []
all_nodes = []

print("Starting batched inference on validation patches...")
start = time.time()

model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device, non_blocking=True)
        # Move results to CPU right away to keep GPU memory flat.
        all_preds.append(_infer_pressure(batch).cpu())
        all_nodes.append(batch.pos.detach().cpu())

elapsed = time.time() - start
print(f'Batched inference time: {elapsed:.3f}s')

# Predictions/positions in loader order. Overlapping patches repeat nodes, so
# these arrays are not aligned with any single mesh.
p_pred = torch.cat(all_preds, dim=0).numpy()[:, 0]
coords = torch.cat(all_nodes, dim=0).numpy()
N = coords.shape[0]  # total number of (patch) nodes across all batches


# --- Reassemble and plot the first validation graph from its patches ---
# Patch predictions are mapped back through `original_nodes` (global node ids of
# each patch) and averaged where patches overlap.
if not val_graphs:
    print("No validation graphs available to visualize.")
else:
    first_val_graph = val_graphs[0]
    first_graph_patches = [p for p in val_patches if getattr(p, 'graph_id', None) == 0]

    if not first_graph_patches:
        print("No patches found for the first validation graph (ID 0). Skipping visualization.")
    else:
        print(f"Visualizing results for the first validation graph (ID 0) with {len(first_graph_patches)} patches...")

        # NOTE(review): `pin_memory` must be defined by an earlier (data) cell.
        first_graph_loader = DataLoader(
            first_graph_patches, batch_size=BATCH_SIZE, shuffle=False,
            pin_memory=pin_memory, num_workers=0, persistent_workers=False
        )

        if not hasattr(first_graph_patches[0], 'original_nodes'):
            print("WARNING: Patches do not have 'original_nodes' attribute. Visualization mapping might be incorrect.")
            print("Consider adding 'p.original_nodes = subset' in make_khop_patches.")

        num_nodes_full = first_val_graph.num_nodes
        collected_preds = None  # per-node sum of predictions, allocated on first batch
        node_counts = torch.zeros(num_nodes_full, 1)
        mapping_ok = True

        model.eval()
        with torch.no_grad():
            for i, batch in enumerate(first_graph_loader):
                batch = batch.to(device, non_blocking=True)
                pred = _infer_pressure(batch).float().cpu()
                if not hasattr(batch, 'original_nodes'):
                    print(f"Warning: Batch {i} missing 'original_nodes'. Cannot map predictions to original graph for visualization.")
                    mapping_ok = False
                    break
                idx = batch.original_nodes.detach().cpu().long().reshape(-1)
                if collected_preds is None:
                    collected_preds = torch.zeros(num_nodes_full, pred.size(1))
                # index_add_ accumulates repeated indices (overlapping patches in one
                # batch); `t[idx] += v` would keep only one of the duplicates.
                collected_preds.index_add_(0, idx, pred)
                node_counts.index_add_(0, idx, torch.ones(idx.numel(), 1))

        if mapping_ok and collected_preds is not None:
            covered = (node_counts[:, 0] > 0).numpy()
            # Average overlapping predictions; uncovered nodes are NaN, not 0.
            p_pred_full_graph = (collected_preds / node_counts.clamp_min(1)).numpy()[:, 0]
            p_pred_full_graph[~covered] = np.nan
            coords_full_graph = first_val_graph.pos.detach().cpu().numpy()
            # NOTE(review): assumes y[:, 0] is the pressure channel — confirm.
            p_true_full_graph = first_val_graph.y.detach().cpu().numpy()[:, 0]

            n_covered = int(covered.sum())
            print(f"[Full Graph 0] nodes covered by patches: {n_covered}/{num_nodes_full}")

            if n_covered == 0:
                print("Skipping visualization for Graph 0: no node received a prediction.")
            else:
                # Metrics only over nodes that actually received a prediction.
                err = p_pred_full_graph[covered] - p_true_full_graph[covered]
                mae_full = float(np.mean(np.abs(err)))
                rmse_full = float(np.sqrt(np.mean(err ** 2)))
                print(f"[Full Graph 0] MAE: {mae_full:.6f}, RMSE: {rmse_full:.6f}")

                # Colour limits: fixed if both given, else data range plus 5 % padding.
                if vmin is None or vmax is None:
                    _min = float(min(np.nanmin(p_pred_full_graph), p_true_full_graph.min()))
                    _max = float(max(np.nanmax(p_pred_full_graph), p_true_full_graph.max()))
                    pad = 0.05 * (_max - _min + 1e-8)
                    vmin_, vmax_ = _min - pad, _max + pad
                else:
                    vmin_, vmax_ = vmin, vmax

                cloud_true_full = pv.PolyData(coords_full_graph.copy())
                cloud_true_full['pressure_true'] = p_true_full_graph
                cloud_pred_full = pv.PolyData(coords_full_graph.copy())
                cloud_pred_full['pressure_pred'] = p_pred_full_graph

                # 1x2 side-by-side plot, saved off-screen.
                pl = pv.Plotter(shape=(1, 2), off_screen=True, border=True)
                pl.subplot(0, 0)
                pl.add_text("Ground Truth (p) [Graph 0]", font_size=12)
                pl.add_mesh(cloud_true_full, scalars='pressure_true', cmap='viridis', clim=(vmin_, vmax_), point_size=6)

                pl.subplot(0, 1)
                pl.add_text("Prediction (p) [Graph 0]", font_size=12)
                pl.add_mesh(cloud_pred_full, scalars='pressure_pred', cmap='viridis', clim=(vmin_, vmax_), point_size=6)

                pl.link_views()
                pl.view_isometric()
                out_png = "pred_vs_true_pressure_graph0.png"
                pl.show(screenshot=out_png)
                print(f"Saved side-by-side comparison to: {out_png}")

                # (Optional) Save for ParaView
                # cloud_true_full.save("pressure_true_graph0.vtp")
                # cloud_pred_full.save("pressure_pred_graph0.vtp")
        else:
            print("Skipping visualization for Graph 0 due to missing original node mapping.")


# Plotting one large concatenated graph ran out of memory, so only the first full
# validation graph is reassembled from its patches and plotted here.