In [1]:
# 1. Environment & dependencies
import gc
import glob
import math
import os
import random
import sys
import time
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import List, Tuple, Union

import numpy as np
import pyvista as pv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor
import torch_geometric
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import GCNConv, GraphNorm, JumpingKnowledge, TransformerConv
from tqdm.auto import tqdm
import wandb
import pyg_lib

# Data and model settings
DATA_ROOT = Path('data')
TARGET_FIELD = 'static(p)_coeffMean'
USE_NORMALS = True

# Reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

print("pyg-lib ok;", getattr(pyg_lib, "__version__", "ok"))
print('Torch:', torch.__version__)
print('CUDA:', torch.version.cuda)
print('PyVista:', pv.__version__)
print('PyG:', torch_geometric.__version__ if 'torch_geometric' in sys.modules else 'unknown')
print("py  :", sys.version.split()[0])  # ex) 3.11.x

pyg-lib ok; 0.4.0+pt25cu118
Torch: 2.5.1
CUDA: 11.8
PyVista: 0.46.1
PyG: 2.6.1
py  : 3.10.18


## Load Graphs

In [2]:
file_path = "graphs_cache_slim.pt"

# torch로 로드
model_state = torch.load(file_path, weights_only=False)

In [3]:
train_graphs, val_graphs = model_state

In [4]:
def to_pyg(graphs):
    out = []
    for g in tqdm(graphs, desc='to_pyg'):
        if isinstance(g, dict):
            x = torch.tensor(g['x'], dtype=torch.float32)
            edge_index = torch.tensor(g['edge_index'], dtype=torch.long)
            y = torch.tensor(g['y'], dtype=torch.float32)
            out.append(Data(x=x, edge_index=edge_index, y=y))
        else:
            out.append(g)
    return out

train_graphs = to_pyg(train_graphs)
val_graphs   = to_pyg(val_graphs)

# PyG Data로 변환된 뒤에 한 번만 실행
for lst in (train_graphs, val_graphs):
    for d in lst:
        if getattr(d, 'pos', None) is None:
            d.pos = d.x[:, :3].contiguous()   # x = [xyz,(normals...)]

to_pyg:   0%|          | 0/3 [00:00<?, ?it/s]

to_pyg:   0%|          | 0/1 [00:00<?, ?it/s]

## Utilities

In [5]:
# 공용 옵션(줄 폭/갱신간격/TTY 아닌 환경 자동 비활성화)
_TQDM_KW = dict(
    ncols=100,                 # 줄바꿈 방지 (원하면 120~140)
    dynamic_ncols=False,       # 고정 폭이 깔끔
    mininterval=0.25,          # 너무 잦은 갱신 방지
    smoothing=0.1,
    bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt} {postfix}",
    disable=not sys.stdout.isatty()  # 로그 캡처/파일일 땐 자동 끔
)

def _bar(total, desc, position, leave):
    return tqdm(total=total, desc=desc, position=position, leave=leave, **_TQDM_KW)


def _mkbar(total, desc, position=0, leave=False, progress=True):
    if not progress:
        return None  # 진행바 비활성화
    return tqdm(
        total=total, desc=desc, position=position, leave=leave,
        ncols=100, dynamic_ncols=False, mininterval=0.25, smoothing=0.1,
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt} {postfix}"
    )

def _update(bar, n=1, postfix=None):
    if bar is None: return
    if postfix is not None:
        if isinstance(postfix, dict): bar.set_postfix(postfix)
        else: bar.set_postfix_str(str(postfix))
    bar.update(n)

def _close(bar):
    if bar is not None: bar.close()

In [6]:
def build_csr_from_edge_index_no_torch_sparse(edge_index: Tensor, num_nodes: int,
                                              make_undirected: bool=True) -> Tuple[Tensor, Tensor]:
    assert edge_index.dtype == torch.long, f"dtype must be torch.long, got {edge_index.dtype}"
    assert edge_index.dim() == 2 and edge_index.size(0) == 2, f"edge_index shape must be [2, E], got {tuple(edge_index.shape)}"

    row0 = edge_index[0].contiguous()
    col0 = edge_index[1].contiguous()

    if make_undirected:
        row = torch.cat([row0, col0], dim=0)
        col = torch.cat([col0, row0], dim=0)
    else:
        row, col = row0, col0

    assert row.numel() == col.numel(), f"row({row.numel()}) and col({col.numel()}) must match"

    # 행-열 키로 정렬
    stride = num_nodes
    key = row * stride + col
    perm = torch.argsort(key)
    row = row[perm]
    col = col[perm]

    deg = torch.bincount(row, minlength=num_nodes)
    rowptr = torch.empty(num_nodes + 1, dtype=torch.long)
    rowptr[0] = 0
    rowptr[1:] = torch.cumsum(deg, dim=0)
    assert col.numel() == rowptr[-1].item(), "CSR col length must equal rowptr[-1]"
    return rowptr, col.contiguous()



from typing import List

@torch.no_grad()
def batched_khop_cover_pure_with_progress(
    rowptr, col, num_nodes, *,
    min_nodes=4096, k0=3, k_max=5, seeds_per_iter=512,
    early_cut_ratio=1.10, overlap_ratio=0.2, seed=1234,
    progress=True, gid=0
):
    g = torch.Generator(device='cpu'); g.manual_seed(seed)
    uncovered = torch.ones(num_nodes, dtype=torch.bool)
    order = torch.randperm(num_nodes, generator=g)
    ptr = 0; patches = []

    # 같은 라인 번호 고정(깨끗한 화면 유지)
    bar_scan = _mkbar(num_nodes, f"[g{gid}] scan", position=2, leave=False, progress=progress)
    bar_cov  = _mkbar(num_nodes, f"[g{gid}] covered", position=1, leave=False, progress=progress)
    prev_cov = 0

    def neighbors_of(frontier_idx):
        starts = rowptr[frontier_idx]; ends = rowptr[frontier_idx+1]
        counts = (ends - starts)
        if counts.sum().item() == 0: return torch.empty(0, dtype=torch.long)
        off = torch.empty_like(counts); off[0]=0
        if counts.numel()>1: off[1:] = torch.cumsum(counts[:-1], dim=0)
        total = counts.sum().item()
        buf = torch.empty(total, dtype=torch.long)
        for s,e,o in zip(starts.tolist(), ends.tolist(), off.tolist()):
            if e>s: buf[o:o+(e-s)] = col[s:e]
        return torch.unique(buf)

    while ptr < num_nodes and uncovered.any():
        batch = []; scanned=0
        while len(batch) < seeds_per_iter and ptr < num_nodes:
            u = order[ptr].item(); ptr += 1; scanned += 1
            if uncovered[u]: batch.append(u)
        if scanned: _update(bar_scan, scanned)
        if not batch: break
        batch = torch.tensor(batch, dtype=torch.long)

        mask = torch.zeros(num_nodes, dtype=torch.bool); frontier = batch; k=0
        while True:
            mask[frontier] = True
            if k>=k0 and mask.sum().item() >= int(min_nodes*early_cut_ratio): break
            if k==k_max: break
            neighs = neighbors_of(frontier)
            if neighs.numel()==0: break
            cand = neighs[~mask[neighs]]
            if cand.numel()==0: break
            frontier = cand; k += 1

        chosen = mask.nonzero(as_tuple=False).view(-1)
        if chosen.numel() > min_nodes:
            perm = torch.randperm(chosen.numel(), generator=g)[:min_nodes]
            chosen = chosen[perm]
        patches.append(chosen)

        if chosen.numel()>0:
            take = torch.rand(chosen.size(0), generator=g) > overlap_ratio
            uncovered[chosen[take]] = False

        now_cov = (num_nodes - uncovered.sum().item())
        _update(bar_cov, now_cov - prev_cov, postfix=f"patches={len(patches)} k={k}")
        prev_cov = now_cov

    _close(bar_scan); _close(bar_cov)
    return patches




from torch_geometric.data import Data

def fast_make_subgraph_pure(g: Data, node_idx: Tensor) -> Data:
    node_idx = node_idx.unique()
    mask = torch.zeros(g.num_nodes, dtype=torch.bool)
    mask[node_idx] = True

    ei = g.edge_index
    e_mask = mask[ei[0]] & mask[ei[1]]
    sub_edge = ei[:, e_mask]

    new_id = torch.full((g.num_nodes,), -1, dtype=torch.long)
    new_id[mask] = torch.arange(mask.sum(), dtype=torch.long)
    sub_edge = new_id[sub_edge]

    sub = Data()
    # 원본 E(엣지 수)
    E = ei.size(1)

    for k, v in g.items():
        if k == "edge_index":
            sub.edge_index = sub_edge
            continue

        if torch.is_tensor(v):
            # 1) node-level: [N, ...]
            if v.size(0) == g.num_nodes:
                sub[k] = v[mask]
                continue
            # 2) edge-level: [E, ...]
            if v.size(0) == E:
                sub[k] = v[e_mask]
                continue
            # 3) graph-level scalar/벡터 (크기 1 또는 배치 없음)
            if v.numel() == v.shape[0] and v.shape[0] == 1:
                sub[k] = v
                continue
            # 4) 그 외(예: y_graph 같은 그래프 레벨 텐서)
            #   - shape이 [G, ...]로 오면 그대로 둡니다(단일 그래프 가정).
            sub[k] = v
        else:
            sub[k] = v
            
    return sub


from os import PathLike
from typing import Union, List
import torch
from torch import Tensor
AnyPath = Union[str, bytes, PathLike]

def save_indices_packed(patches: List[Tensor], save_path: AnyPath, *, progress=True, position=3):
    sizes = [p.numel() for p in patches]
    ptr = torch.zeros(len(patches) + 1, dtype=torch.long)
    if sizes: ptr[1:] = torch.cumsum(torch.tensor(sizes, dtype=torch.long), dim=0)
    nodes = torch.empty(ptr[-1].item(), dtype=torch.long)
    off = 0

    bar = _mkbar(len(patches), "[cache] save patches", position=position, leave=False, progress=progress)
    for p in patches:
        n = p.numel(); nodes[off:off+n] = p; off += n
        _update(bar, 1)
    _close(bar)
    torch.save({"ptr": ptr, "nodes": nodes}, save_path)

def load_indices_packed(save_path: AnyPath, *, progress=True, position=3) -> List[Tensor]:
    obj = torch.load(save_path, map_location='cpu')
    ptr, nodes = obj["ptr"], obj["nodes"]
    out: List[Tensor] = []
    bar = _mkbar(ptr.numel()-1, "[cache] load patches", position=position, leave=False, progress=progress)
    for i in range(ptr.numel()-1):
        lo, hi = ptr[i].item(), ptr[i+1].item()
        out.append(nodes[lo:hi].clone())
        _update(bar, 1)
    _close(bar)
    return out

In [7]:
def cluster_patch_loader_pure(
    graphs: List[Data],
    min_nodes=4096, save_root=None,
    batch_size=4, shuffle=True, num_workers=0,
    k0=3, k_max=5, seeds_per_iter=1024,
    early_cut_ratio=1.10, overlap_ratio=0.2, seed=42,
    *, progress=True, return_list=False
) -> Tuple[iter, int]:

    loaders = []
    len_epoch = 0
    bar_graphs = _mkbar(len(graphs), "[graphs]", position=0, leave=True, progress=progress)

    for gid, g in enumerate(graphs):
        cache_dir = Path(save_root) / f"graph_{gid}" if save_root else None
        packed = cache_dir / "patch_indices.pth" if cache_dir else None

        if cache_dir and packed.exists():
            patches = load_indices_packed(packed, progress=progress, position=3)
        else:
            # CSR (여긴 표시 1칸만 점유)
            bar_csr = _mkbar(1, f"[g{gid}] build CSR", position=3, leave=False, progress=progress)
            rowptr, col = build_csr_from_edge_index_no_torch_sparse(
                g.edge_index.cpu().long(), g.num_nodes, make_undirected=True
            )
            _update(bar_csr, 1); _close(bar_csr)

            patches = batched_khop_cover_pure_with_progress(
                rowptr=rowptr, col=col, num_nodes=g.num_nodes,
                min_nodes=min_nodes, k0=k0, k_max=k_max, seeds_per_iter=seeds_per_iter,
                early_cut_ratio=early_cut_ratio, overlap_ratio=overlap_ratio, seed=seed,
                progress=progress, gid=gid
            )
            if cache_dir:
                cache_dir.mkdir(parents=True, exist_ok=True)
                save_indices_packed(patches, packed, progress=progress, position=3)

        # Subgraph build (한 줄)
        bar_subs = _mkbar(len(patches), f"[g{gid}] make subgraphs", position=3, leave=False, progress=progress)
        subs = []
        for idx in patches:
            subs.append(fast_make_subgraph_pure(g, idx))
            _update(bar_subs, 1)
        _close(bar_subs)

        ld = DataLoader(subs, batch_size=batch_size, shuffle=shuffle,
                        num_workers=num_workers, pin_memory=True)
        loaders.append(ld)
        _update(bar_graphs, 1)
        len_epoch += len(ld)

    _close(bar_graphs)

    if return_list:
        return loaders, len_epoch
    
    else:
        return chain(*loaders), len_epoch

In [8]:
# Clear CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("CUDA cache cleared.")

# Collect garbage
gc.collect()
print("Garbage collected.")

CUDA cache cleared.
Garbage collected.


In [10]:
train_loader, len_epoch = cluster_patch_loader_pure(
    train_graphs, save_root="cache/train",
    min_nodes=3072, batch_size=24, num_workers=4,
    k0=3, k_max=5, seeds_per_iter=1024,
    early_cut_ratio=1.08, overlap_ratio=0.25, seed=42,
    progress=True, return_list=True
)

val_loader, len_val  = cluster_patch_loader_pure(
    val_graphs,
    save_root="cache/val",
    min_nodes=3072, batch_size=24, num_workers=4,
    k0=3, k_max=5, seeds_per_iter=1024,
    early_cut_ratio=1.08, overlap_ratio=0.25, seed=42,
    progress=True, return_list=True
)


print("len_epoch:", len_epoch, "len_val:", len_val)


[graphs]:   0%|                                                                         | 0/3 ?it/s 

[g0] build CSR:   0%|                                                                   | 0/1 ?it/s 

[g0] scan:   0%|                                                                  | 0/1292868 ?it/s 

[g0] covered:   0%|                                                               | 0/1292868 ?it/s 

[cache] save patches:   0%|                                                           | 0/663 ?it/s 

[g0] make subgraphs:   0%|                                                            | 0/663 ?it/s 

[g1] build CSR:   0%|                                                                   | 0/1 ?it/s 

[g1] scan:   0%|                                                                   | 0/962658 ?it/s 

[g1] covered:   0%|                                                                | 0/962658 ?it/s 

[cache] save patches:   0%|                                                           | 0/494 ?it/s 

[g1] make subgraphs:   0%|                                                            | 0/494 ?it/s 

[g2] build CSR:   0%|                                                                   | 0/1 ?it/s 

[g2] scan:   0%|                                                                  | 0/1268335 ?it/s 

[g2] covered:   0%|                                                               | 0/1268335 ?it/s 

[cache] save patches:   0%|                                                           | 0/651 ?it/s 

[g2] make subgraphs:   0%|                                                            | 0/651 ?it/s 

[graphs]:   0%|                                                                         | 0/1 ?it/s 

[g0] build CSR:   0%|                                                                   | 0/1 ?it/s 

[g0] scan:   0%|                                                                  | 0/1131049 ?it/s 

[g0] covered:   0%|                                                               | 0/1131049 ?it/s 

[cache] save patches:   0%|                                                           | 0/580 ?it/s 

[g0] make subgraphs:   0%|                                                            | 0/580 ?it/s 

len_epoch: 77 len_val: 25
