# ⚡ Training Workflow (1–14)
이 구간은 데이터 일부만 사용해 end-to-end 학습과 검증을 빠르게 수행하는 스모크 테스트 플로우입니다. 아래 순서대로 각 셀을 실행하세요.

# 1. Import required libraries


In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
import sys, math, json, random, contextlib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torch_geometric.datasets import AirfRANS
from torch_geometric.data import Data, Batch
from matplotlib.tri import Triangulation
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torch_geometric.data import Data
from navier_stokes_physics_loss import NavierStokesPhysicsLoss
from airfrans_utils import prepare_airfrans_graph_for_physics, estimate_node_area, build_bc_masks_airfrans
import contextlib
import wandb  
from torch.cuda.amp import GradScaler, autocast

def get_lr(optim):
    """Return the learning rate of the optimizer's first param group, or None if it has none."""
    head_group = optim.param_groups[0]
    return head_group['lr'] if 'lr' in head_group else None


def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch RNGs (plus all CUDA devices when present)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Choose the compute device once, then seed every RNG before anything random happens.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
set_seed(42)
print('SmokeTest | PyTorch:', torch.__version__, '| CUDA?', torch.cuda.is_available())

SmokeTest | PyTorch: 2.8.0+cu128 | CUDA? True


# 2. Configuration

In [2]:
# 2) Configuration (minimal for smoke)
from dataclasses import dataclass, asdict

@dataclass
class SmokeCfg:
    """Configuration for the smoke-test training run.

    Defaults are the smoke-test values. Fields are grouped by concern: data
    subset, optimizer/scheduler, physics-loss curriculum, global-context
    attention model, W&B artifacts, local checkpoints and W&B logging.
    """
    seed: int = 42
    task: str = 'scarce'  # AirfRANS task name passed to the dataset loader
    root: str = 'Dataset'  # AirfRANS dataset root directory
    # subsample graph count for smoke
    limit_train: int = 180
    limit_val: int = 20

    # training
    batch_size: int = 19
    epochs: int = 100
    hidden: int = 128
    layers: int = 7
    lr: float = 4e-4
    weight_decay: float = 1e-2  # typical AdamW wd
    betas: tuple[float, float] = (0.9, 0.95)
    eps: float = 1e-8
    amp: bool = False  # mixed-precision toggle (read as scfg.amp by run_epoch)

    # lr scheduler: 'cosine', 'cosine_warm_restarts', 'reduce_on_plateau', or None
    lr_scheduler: str | None = 'cosine'
    # cosine params
    # NOTE(review): T_max (80) < epochs (100). CosineAnnealingLR is not clamped at
    # T_max -- the LR climbs back up afterwards -- so confirm this is intended.
    cosine_T_max: int = 80  # epochs
    cosine_eta_min: float = 1e-6
    # warm restarts params
    wr_T_0: int = 10
    wr_T_mult: int = 1
    wr_eta_min: float = 1e-6
    # reduce on plateau params
    rop_factor: float = 0.5
    rop_patience: int = 5
    rop_min_lr: float = 1e-6

    # Physics-Informed Loss Configuration
    # =====================================
    # Curriculum learning schedule
    # NOTE(review): ramp_start_epoch + ramp_epochs = 110 > epochs = 100, so an
    # epoch-based ramp would never reach the target weights -- confirm.
    ramp_start_epoch: int = 30              # Epoch to start ramping physics losses
    ramp_epochs: int = 80                   # Number of epochs to ramp up
    ramp_mode: str = 'linear'               # 'linear' or 'cosine'

    # MSE/Data loss
    data_loss_weight: float = 1.0           # Weight for MSE loss (constant)

    # Continuity equation loss
    continuity_loss_weight: float = 0.05    # Initial continuity weight
    continuity_target_weight: float = 0.10  # Target continuity weight after ramp

    # Momentum equation loss
    momentum_loss_weight: float = 0.05      # Initial momentum weight
    momentum_target_weight: float = 0.10    # Target momentum weight after ramp

    # Boundary condition loss
    bc_loss_weight: float = 0.05            # Weight for boundary condition loss

    # Physics parameters
    chord_length: float = 1.0               # Airfoil chord length
    nu_molecular: float = 1.5e-5            # Molecular viscosity
    dynamic_uref_from_data: bool = True     # Compute reference velocity from data
    dynamic_re_from_data: bool = True       # Compute Reynolds number from data
    uinf_from: str = 'inlet'                # 'inlet', 'farfield', or 'robust'

    # Stability & outlier control
    use_huber_for_physics: bool = True      # Use Huber loss for physics terms
    huber_delta: float = 0.05               # Huber loss delta parameter
    use_perimeter_norm_for_div: bool = True # Normalize divergence by perimeter
    div_area_floor_factor: float = 0.25     # Area floor factor for stability
    div_min_degree: int = 2                 # Minimum node degree for physics loss

    # Debug & monitoring
    physics_debug: bool = False              # Enable physics loss debugging
    physics_debug_level: int = 1            # Debug verbosity (1=summary, 2=detailed)
    physics_debug_every: int = 50           # Log debug info every N steps

    # Global Context & Attention Configuration
    use_global_tokens: bool = True           # Enable/disable global tokens
    num_global_tokens: int = 4               # Number of global tokens
    attention_heads: int = 4                 # Multi-head attention heads
    attention_layers: int = 7               # Number of transformer layers
    attention_dropout: float = 0.0           # Attention dropout rate
    use_cross_attention: bool = True         # Cross-attention between local and global
    global_pooling_type: str = 'attention'   # 'mean', 'max', 'attention', 'set2set'
    positional_encoding: bool = True         # Use positional encoding
    pos_encoding_max_len: int = 50000        # Max sequence length for positional encoding
    # Advanced attention options
    use_residual_attention: bool = True      # Residual connections in attention
    attention_normalization: str = 'layer'   # 'layer', 'batch', 'rms'
    temperature_scaling: bool = True         # Temperature scaling for attention
    attention_bias: bool = False             # Use bias in attention projections

    # W&B artifact management
    use_wandb_artifacts: bool = False        # Whether to use W&B artifacts
    artifact_save_best_only: bool = True     # Upload only the best model
    artifact_save_interval: int = 50         # Periodic save interval (epochs)

    # Checkpoint management
    ckpt_dir: str = "checkpoints"           # Local checkpoint directory
    ckpt_interval: int = 5                  # Local checkpoint save interval

    # W&B settings
    wandb_project: str = "storm"
    wandb_mode: str = "online"              # "online", "offline", "disabled"
    log_every_n_steps: int = -1             # Step-logging interval; train_epoch skips step logging when <= 0
    log_epoch_only: bool = True             # Log at epoch granularity only

# Instantiate the default config and re-seed every RNG from its seed.
scfg = SmokeCfg()
set_seed(seed=scfg.seed)
print('Smoke config:', asdict(scfg))

Smoke config: {'seed': 42, 'task': 'scarce', 'root': 'Dataset', 'limit_train': 180, 'limit_val': 20, 'batch_size': 19, 'epochs': 100, 'hidden': 128, 'layers': 7, 'lr': 0.0004, 'weight_decay': 0.01, 'betas': (0.9, 0.95), 'eps': 1e-08, 'amp': False, 'lr_scheduler': 'cosine', 'cosine_T_max': 80, 'cosine_eta_min': 1e-06, 'wr_T_0': 10, 'wr_T_mult': 1, 'wr_eta_min': 1e-06, 'rop_factor': 0.5, 'rop_patience': 5, 'rop_min_lr': 1e-06, 'ramp_start_epoch': 30, 'ramp_epochs': 80, 'ramp_mode': 'linear', 'data_loss_weight': 1.0, 'continuity_loss_weight': 0.05, 'continuity_target_weight': 0.1, 'momentum_loss_weight': 0.05, 'momentum_target_weight': 0.1, 'bc_loss_weight': 0.05, 'chord_length': 1.0, 'nu_molecular': 1.5e-05, 'dynamic_uref_from_data': True, 'dynamic_re_from_data': True, 'uinf_from': 'inlet', 'use_huber_for_physics': True, 'huber_delta': 0.05, 'use_perimeter_norm_for_div': True, 'div_area_floor_factor': 0.25, 'div_min_degree': 2, 'physics_debug': False, 'physics_debug_level': 1, 'physics_d

# 3. Load dataset indices

In [3]:
from torch_geometric.transforms import BaseTransform

class _PreparePhysics(BaseTransform):
    """Dataset transform that runs ``prepare_airfrans_graph_for_physics`` on each sample.

    ``__call__`` is overridden directly, so the sample object is passed to the
    helper as-is (no copy is made here). Per the original note, when the graph
    already carries ``edge_attr_dxdy`` the helper skips ``build_edge_attr_dxdy``
    and only performs the remaining preparation steps.
    """

    def __call__(self, data):
        prepared = prepare_airfrans_graph_for_physics(data, verbose=False)
        return prepared

# 3) Load dataset indices (train/val split)
# Explicit check instead of `assert`, so it still fires under `python -O`.
if not os.path.isdir(scfg.root):
    raise FileNotFoundError(f"Dataset folder not found: {scfg.root}")
try:
    ds_train = AirfRANS(root=scfg.root, train=True, task=scfg.task, transform=_PreparePhysics())
    ds_test  = AirfRANS(root=scfg.root, train=False, task=scfg.task, transform=_PreparePhysics())
except TypeError:
    # Fallback for an AirfRANS signature without a ``task`` argument
    # (presumably a different torch_geometric version).
    ds_train = AirfRANS(root=scfg.root, train=True, transform=_PreparePhysics())
    ds_test  = AirfRANS(root=scfg.root, train=False, transform=_PreparePhysics())

if scfg.task == 'scarce':
    # Scarce provides a train split only. This cell just shuffles (seeded) and caps
    # the pool of training indices; the 90/10 train/val split is made later, on the
    # prebuilt graphs, so no validation subset is created here.
    n = len(ds_train)
    ids_all = list(range(n))
    random.Random(scfg.seed).shuffle(ids_all)
    ids_train = list(ids_all)
    if scfg.limit_train > 0:
        ids_train = ids_train[:scfg.limit_train + scfg.limit_val]

    train_raw = Subset(ds_train, ids_train)
    val_raw = None

else:
    # Take a pool of the first (limit_train + limit_val) graphs (all graphs when
    # limit_train <= 0) and carve the validation tail off it. Train and val come
    # from disjoint ranges, so they never overlap, even when the dataset holds
    # fewer graphs than requested.
    # NOTE(review): ds_test is loaded above but not used for validation here.
    n_val_req = max(scfg.limit_val, 0)
    if scfg.limit_train > 0:
        n_pool = min(scfg.limit_train + n_val_req, len(ds_train))
    else:
        n_pool = len(ds_train)
    n_val = min(n_val_req, n_pool)
    ids_val = list(range(n_pool - n_val, n_pool))
    ids_train = list(range(n_pool - n_val))
    train_raw = Subset(ds_train, ids_train)
    val_raw   = Subset(ds_train, ids_val) if ids_val else []

print('Loaded subset indices:', len(train_raw), 'train |', len(val_raw) if isinstance(val_raw, Subset) else 0, 'val/test')

Loaded subset indices: 200 train | 0 val/test


# 4. Load prebuilt graphs and ensure features (index-aligned with raw)

In [4]:
# 6) Load prebuilt graphs and ensure features (index-aligned with raw)
import glob, os, re
from utils import with_pos2, prep_graph, validate_edges, _prep_graph_for_norm

USE_PREBUILT = True  # NOTE(review): never consulted below -- prebuilt graphs are always loaded
PREBUILT_ROOT = 'prebuilt_edges/scarce'  # change to your path if different
PREBUILT_TRAIN_DIR = f"{PREBUILT_ROOT}/train"
PREBUILT_TEST_DIR  = f"{PREBUILT_ROOT}/test"
DOWNSAMPLED_ROOT = 'downsampled_graphs/scarce'  # NOTE(review): not used in this cell

_GRAPH_INDEX_RE = re.compile(r'graph_(\d+)\.pt$')


def _graph_file_key(path):
    """Sort key that orders ``graph_<idx>.pt`` files by their numeric index.

    Plain ``sorted`` is lexicographic (graph_10 < graph_2), which breaks the
    index alignment this cell promises unless indices are zero-padded. Files
    that do not match the pattern sort after the numbered ones, by path.
    """
    m = _GRAPH_INDEX_RE.search(os.path.basename(path))
    return (0, int(m.group(1)), path) if m else (1, 0, path)


def _load_prebuilt_graphs(files):
    """Load each prebuilt graph file, coerce it to ``Data`` and run ``prep_graph`` on it.

    SECURITY: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
    objects, which can execute code. Only load files produced by a trusted
    preprocessing run.
    """
    graphs = []
    for path in files:
        obj = torch.load(path, map_location='cpu', weights_only=False)
        if not isinstance(obj, Data):
            obj = Data(**obj)  # presumably saved as a plain mapping of attributes
        graphs.append(prep_graph(obj))
    return graphs


# Load prebuilt edge graphs (numeric order keeps list index == graph index)
train_edge_files = sorted(glob.glob(os.path.join(PREBUILT_TRAIN_DIR, 'graph_*.pt')), key=_graph_file_key)
val_edge_files   = sorted(glob.glob(os.path.join(PREBUILT_TEST_DIR,  'graph_*.pt')), key=_graph_file_key)
print(f"[prebuilt] found: {len(train_edge_files)} train and {len(val_edge_files)} val graphs under {PREBUILT_ROOT}")

# Load tensors and prepare
train_edges = _load_prebuilt_graphs(train_edge_files)
val_edges = _load_prebuilt_graphs(val_edge_files)

_first_graph = train_edges[0] if train_edges else None
_first_x_shape = _first_graph.x.shape if _first_graph is not None else None
_first_edge_attr = getattr(_first_graph, 'edge_attr', None) if _first_graph is not None else None
_first_ea_shape = _first_edge_attr.shape if _first_edge_attr is not None else None
print(f"Graphs prepared. Example dims -> x: {_first_x_shape}  edge_attr: {_first_ea_shape}")

validate_edges(train_edges, 'train_edges')

[prebuilt] found: 200 train and 0 val graphs under prebuilt_edges/scarce
Graphs prepared. Example dims -> x: torch.Size([16124, 5])  edge_attr: torch.Size([95510, 5])
[validate] train_edges: total=200 bad=0


# 5. Normalized datasets


In [5]:
# Split the prebuilt graphs (not the raw AirfRANS samples) into train / val.
if scfg.task == 'scarce':
    # Scarce ships a training split only: hold out ~10% of it for validation
    # with a seeded shuffle, so the split is reproducible for a fixed seed.
    n = len(train_edges)
    ids_all = list(range(n))
    random.Random(scfg.seed).shuffle(ids_all)
    n_train = int(n * 0.9)
    ids_train, ids_val = ids_all[:n_train], ids_all[n_train:]

    train_edges_subset = [train_edges[idx] for idx in ids_train]
    val_edges_subset = [train_edges[idx] for idx in ids_val]
else:
    # Other tasks: train on the prebuilt train graphs, validate on the prebuilt test graphs.
    train_edges_subset, val_edges_subset = train_edges, val_edges

train_prepped = list(map(_prep_graph_for_norm, train_edges_subset))
if isinstance(val_edges_subset, list):
    val_prepped = list(map(_prep_graph_for_norm, val_edges_subset))
else:
    val_prepped = []

# 8b) Fit scalers on train_prepped
if 'StandardScaler' not in globals():
    # Minimal torch-native z-score scaler; defined only when no other
    # ``StandardScaler`` already exists in the notebook namespace.
    class StandardScaler:
        """Standardize a tensor column-wise along dim 0: ``(t - mean) / std``.

        ``std`` is ``torch.std``'s default (unbiased) estimate, floored at 1e-8 so
        constant columns do not divide by zero.
        """

        def __init__(self):
            self.mean = self.std = None

        def fit(self, t: torch.Tensor):
            """Record per-column mean/std of ``t``; returns ``self`` for chaining."""
            self.mean = torch.mean(t, dim=0)
            self.std = torch.clamp(torch.std(t, dim=0), min=1e-8)
            return self

        def transform(self, t: torch.Tensor):
            """Standardize ``t`` with the fitted statistics."""
            centered = t - self.mean
            return centered / self.std

        def inverse(self, t: torch.Tensor):
            """Undo ``transform``: map standardized values back to the original scale."""
            rescaled = t * self.std
            return rescaled + self.mean

# Stack every training graph's node features / targets, then fit one scaler on each.
X_train = torch.cat([g.x for g in train_prepped if getattr(g, 'x', None) is not None], dim=0)
Y_train = torch.cat([g.y for g in train_prepped if getattr(g, 'y', None) is not None], dim=0)

x_scaler, y_scaler = StandardScaler().fit(X_train), StandardScaler().fit(Y_train)

# 8c) Build normalized dataset wrappers
class NormalizedDataset(torch.utils.data.Dataset):
    """Map-style dataset that standardizes graphs on access and attaches BC masks.

    ``__getitem__`` never mutates the stored graphs. It builds a shallow copy,
    replaces ``x`` (and ``y`` when present) with standardized versions, exposes
    ``edge_attr_dxdy`` for the physics loss when it can be found or derived, and
    attaches boolean boundary masks as ``is_<type>`` attributes.

    Args:
        graphs: indexable collection of (unnormalized) PyG ``Data`` graphs.
        x_scaler: fitted scaler whose ``transform`` is applied to ``x``.
        y_scaler: fitted scaler whose ``transform`` is applied to ``y``.
    """

    def __init__(self, graphs, x_scaler, y_scaler):
        self.graphs = graphs
        self.x_scaler = x_scaler
        self.y_scaler = y_scaler

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx: int):
        d = self.graphs[idx]
        # Shallow copy: tensors are shared with ``d``, but reassigning attributes
        # on ``dm`` leaves the stored graph untouched.
        dm = Data(**{k: v for k, v in d})
        dm.x = self.x_scaler.transform(d.x)
        # hasattr() is not a reliable presence test here: PyG ``Data`` exposes
        # y / edge_attr / pos as properties that return None when unset (see PyG
        # Data docs), so check the value for None instead.
        y = getattr(d, 'y', None)
        dm.y = self.y_scaler.transform(y) if y is not None else y

        # Scaler statistics are intentionally NOT attached as graph attributes
        # (per the original note they broke batching); only a flag is set.
        dm.has_norm = True

        # Ensure edge_attr_dxdy is present (needed by the physics loss).
        edge_dxdy = getattr(d, 'edge_attr_dxdy', None)
        edge_attr = getattr(d, 'edge_attr', None)
        if edge_dxdy is not None:
            dm.edge_attr_dxdy = edge_dxdy
        elif edge_attr is not None:
            # NOTE(review): assumes the last two edge_attr columns are (dx, dy), as
            # the original comment states -- confirm against prep_graph.
            if edge_attr.dim() == 2 and edge_attr.size(1) >= 2:
                dm.edge_attr_dxdy = edge_attr[:, -2:]
            dm.edge_attr = edge_attr

        # Build BC masks with the airfrans_utils helper, then expose each as is_<type>.
        from airfrans_utils import build_bc_masks_airfrans
        dm = build_bc_masks_airfrans(dm)

        bc_mask_dict = getattr(dm, 'bc_mask_dict', None)
        if bc_mask_dict is not None:
            for bc_type, mask in bc_mask_dict.items():
                setattr(dm, f'is_{bc_type}', mask)
        else:
            # The helper gave no mask dict: fall back to heuristics on the
            # ORIGINAL (unnormalized) features.
            self._set_fallback_bc_masks(dm, d.x)

        return dm

    @staticmethod
    def _set_fallback_bc_masks(dm, x_orig):
        """Attach heuristic is_wall / is_inlet / is_outlet / is_farfield masks to ``dm``.

        ``x_orig`` must be the unnormalized node features of the same graph.
        """
        num_nodes = dm.x.size(0)

        # Wall nodes: (near-)zero value in column 2 of the unnormalized features,
        # which the original code treats as the wall distance.
        if x_orig.size(1) > 2:
            dm.is_wall = x_orig[:, 2] < 1e-6
        else:
            dm.is_wall = torch.zeros(num_nodes, dtype=torch.bool)

        # No explicit inlet/outlet/farfield labels are available, so start empty ...
        dm.is_inlet = torch.zeros(num_nodes, dtype=torch.bool)
        dm.is_outlet = torch.zeros(num_nodes, dtype=torch.bool)
        dm.is_farfield = torch.zeros(num_nodes, dtype=torch.bool)

        # ... and refine with simple position thresholds when coordinates exist.
        # NOTE(review): thresholds assume ``pos`` is in raw mesh units -- confirm.
        pos = getattr(dm, 'pos', None)
        if pos is not None:
            x_coords, y_coords = pos[:, 0], pos[:, 1]
            # Inlet: leftmost region (x < -1)
            dm.is_inlet = (x_coords < -1.0) & ~dm.is_wall
            # Outlet: rightmost region (x > 2)
            dm.is_outlet = (x_coords > 2.0) & ~dm.is_wall
            # Farfield: top/bottom region (|y| > 1) not already classified
            dm.is_farfield = (torch.abs(y_coords) > 1.0) & ~dm.is_wall & ~dm.is_inlet & ~dm.is_outlet

train_norm = NormalizedDataset(train_prepped, x_scaler, y_scaler)
val_norm   = NormalizedDataset(val_prepped, x_scaler, y_scaler) if isinstance(val_prepped, list) and len(val_prepped) > 0 else []

# Debug BC mask creation for a single sample (skipped when there is no training data).
if len(train_norm) > 0:
    test_single = train_norm[0]
    print("Single sample BC check:")
    print(f"  Total nodes: {test_single.x.size(0)}")

    # Check original (unnormalized) features that drive the fallback wall mask
    x_orig = train_prepped[0].x
    print(f"  Original x shape: {x_orig.shape}")
    if x_orig.size(1) > 2:
        wall_dist = x_orig[:, 2]
        print(f"  Wall distance range: [{wall_dist.min():.3e}, {wall_dist.max():.3e}]")
        print(f"  Nodes with wall_dist < 1e-6: {(wall_dist < 1e-6).sum().item()}")

    # Check the BC masks
    for bc_type in ['wall', 'inlet', 'outlet', 'farfield']:
        mask_name = f'is_{bc_type}'
        mask = getattr(test_single, mask_name, None)
        if mask is None:
            continue
        n_on = mask.sum().item()
        pct = n_on / len(mask) * 100 if len(mask) > 0 else 0.0
        print(f"  {mask_name}: {n_on} nodes ({pct:.1f}%)")

    # Position ranges. ``pos`` can be None on a PyG Data (hasattr is always True),
    # so test the value itself.
    pos = getattr(test_single, 'pos', None)
    if pos is not None:
        print(f"\n  Position ranges:")
        print(f"    x: [{pos[:, 0].min():.2f}, {pos[:, 0].max():.2f}]")
        print(f"    y: [{pos[:, 1].min():.2f}, {pos[:, 1].max():.2f}]")



print('Prepared normalized datasets:', len(train_norm), 'train |', len(val_norm), 'val')
if len(train_prepped) > 0:
    _edge_attr0 = getattr(train_prepped[0], 'edge_attr', None)
    print('Example dims -> x:', tuple(train_prepped[0].x.shape), '| edge_attr:', (tuple(_edge_attr0.shape) if _edge_attr0 is not None else None))


Single sample BC check:
  Total nodes: 16059
  Original x shape: torch.Size([16059, 7])
  Wall distance range: [0.000e+00, 3.563e+00]
  Nodes with wall_dist < 1e-6: 1026
  is_wall: 1026 nodes (6.4%)
  is_inlet: 1645 nodes (10.2%)
  is_outlet: 1833 nodes (11.4%)
  is_farfield: 2289 nodes (14.3%)

  Position ranges:
    x: [-2.16, 4.23]
    y: [-1.63, 1.62]
Prepared normalized datasets: 180 train | 20 val
Example dims -> x: (16059, 7) | edge_attr: (94686, 5)


In [6]:
data = train_norm[0]
row, col = data.edge_index

# Degree = appearances as edge source plus appearances as edge target; nodes with
# degree 0 are isolated.
n_nodes = data.num_nodes
deg = torch.bincount(row, minlength=n_nodes).add(torch.bincount(col, minlength=n_nodes))
print("deg==0:", int(deg.eq(0).sum()), " / ", n_nodes)
print("deg<2 :", int(deg.lt(2).sum()))


deg==0: 88  /  16059
deg<2 : 88


# 6. DataLoaders


In [7]:
# Use true batching with PyG Batch.from_data_list so batch_size>1 works correctly

def collate_pyg(batch):
    """Merge a list of graphs into a single PyG ``Batch``, dropping ``None`` entries.

    Returns ``None`` when nothing is left; the train/val loops skip such batches.
    """
    graphs = [item for item in batch if item is not None]
    return Batch.from_data_list(graphs) if graphs else None

# Single-process loaders; collate_pyg merges each list of graphs into one disjoint Batch.
train_loader = DataLoader(
    train_norm,
    batch_size=scfg.batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_pyg,
)
if isinstance(val_norm, NormalizedDataset):
    val_loader = DataLoader(
        val_norm,
        batch_size=scfg.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_pyg,
    )
else:
    val_loader = []
print('Loaders ready:', len(train_norm), 'train samples | batch_size =', scfg.batch_size)
print('Loaders ready:', len(val_norm), 'val samples | batch_size =', scfg.batch_size)

Loaders ready: 180 train samples | batch_size = 19
Loaders ready: 20 val samples | batch_size = 19


# 7. Learning rate scheduler

In [8]:
# Learning Rate Scheduler 설정
def create_lr_scheduler(optimizer, config):
    """Configuration에 따라 적절한 LR scheduler를 생성합니다."""
    
    if config.lr_scheduler is None:
        print("🚫 Learning rate scheduler: None (constant LR)")
        return None
    
    elif config.lr_scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config.cosine_T_max,
            eta_min=config.cosine_eta_min
        )
        print(f"📊 Learning rate scheduler: CosineAnnealingLR")
        print(f"   T_max: {config.cosine_T_max}, eta_min: {config.cosine_eta_min}")
        return scheduler
    
    elif config.lr_scheduler == 'cosine_warm_restarts':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=config.wr_T_0,
            T_mult=config.wr_T_mult,
            eta_min=config.wr_eta_min
        )
        print(f"🔄 Learning rate scheduler: CosineAnnealingWarmRestarts")
        print(f"   T_0: {config.wr_T_0}, T_mult: {config.wr_T_mult}, eta_min: {config.wr_eta_min}")
        return scheduler
    
    elif config.lr_scheduler == 'reduce_on_plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',  # validation loss를 minimize
            factor=config.rop_factor,
            patience=config.rop_patience,
            min_lr=config.rop_min_lr,
        )
        print(f"📉 Learning rate scheduler: ReduceLROnPlateau")
        print(f"   factor: {config.rop_factor}, patience: {config.rop_patience}, min_lr: {config.rop_min_lr}")
        return scheduler
    
    else:
        print(f"❌ Unknown scheduler: {config.lr_scheduler}, using None")
        return None


# 8. Loss and train/val epoch routines

In [9]:
# 13) Enhanced Train/Val epoch routines with Physics Loss

mse_loss_fn = nn.MSELoss()

def compute_loss_with_physics(predictions, targets, data, loss_fn=None, *, step: int | None = None):
    """Compute loss using physics-informed loss function or fallback to MSE
    Returns a differentiable scalar loss tensor for backward as first value,
    and a lightweight dict of float metrics for logging as second value.
    """
    if loss_fn is not None:
        try:
            # Always let the physics loss handle batched Data (PyG batches are a big disjoint graph)
            loss_dict = loss_fn(predictions, targets, data=data, step=step)

            # Ensure total_loss is a Tensor usable for backward
            total_loss = loss_dict.get('total_loss')
            if not isinstance(total_loss, torch.Tensor):
                total_loss = torch.as_tensor(total_loss, dtype=predictions.dtype, device=predictions.device)

            # Prepare a logging-friendly dict (floats only) to avoid holding graph refs
            log_dict = {}
            for k, v in loss_dict.items():
                if isinstance(v, torch.Tensor):
                    try:
                        log_dict[k] = float(v.detach().item())
                    except Exception:
                        # Fallback if it's not 0-dim
                        log_dict[k] = float(v.detach().mean().item())
                else:
                    log_dict[k] = float(v)

            return total_loss, log_dict
        except Exception as e:
            print(f"Warning: Physics loss failed ({e}), falling back to MSE")
            mse_loss = mse_loss_fn(predictions, targets)
            return mse_loss, {
                'mse_loss': float(mse_loss.detach().item()), 
                'continuity_loss': 0.0, 
                'momentum_loss': 0.0,
                'bc_loss': 0.0,  # ← BC loss 추가
                'total_loss': float(mse_loss.detach().item())
            }
    else:
        # Fallback to simple MSE
        mse_loss = mse_loss_fn(predictions, targets)
        return mse_loss, {
            'mse_loss': float(mse_loss.detach().item()), 
            'bc_loss': 0.0,  # ← BC loss 추가
            'total_loss': float(mse_loss.detach().item())
        }


@torch.no_grad()
def run_epoch(loader, model, device, scaler=None, desc: str = 'val', loss_fn=None):
    """Evaluate ``model`` over ``loader`` without gradients and average the loss terms.

    Args:
        loader: iterable of collated batches (``None`` entries are skipped), or an
            empty list / ``None`` when there is no evaluation data.
        model: module called as ``model(batch)``; switched to eval mode here.
        device: device each batch is moved to.
        scaler: unused; kept for signature parity with ``train_epoch``.
        desc: progress-bar label.
        loss_fn: optional physics loss forwarded to ``compute_loss_with_physics``.

    Returns:
        ``(mean_total_loss, averages)``. ``averages`` holds the mean of each loss
        term (NaN when nothing was evaluated), plus the mean physics weights used
        when the loss reports them. An empty or None loader returns ``(nan, {})``.
    """
    model.eval()
    total_losses = []; mse_losses = []; continuity_losses = []; momentum_losses = []
    bc_losses = []
    cont_w_used_hist, mom_w_used_hist = [], []

    if loader is None or (isinstance(loader, list) and len(loader)==0):
        return float('nan'), {}

    use_amp = scfg.amp and torch.cuda.is_available()
    pbar = tqdm(total=len(loader), desc=desc, leave=False)
    try:
        for batch in loader:
            try:
                if batch is None:
                    # The finally clause advances the bar exactly once.
                    continue

                b = batch.to(device)
                # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
                amp_ctx = (torch.autocast(device_type='cuda', enabled=use_amp)
                           if torch.cuda.is_available() else contextlib.nullcontext())
                with amp_ctx:
                    out = model(b)
                    _, loss_dict = compute_loss_with_physics(out, b.y, b, loss_fn=loss_fn, step=None)

                total_losses.append(loss_dict['total_loss'])
                mse_losses.append(loss_dict['mse_loss'])
                continuity_losses.append(loss_dict.get('continuity_loss', 0.0))
                momentum_losses.append(loss_dict.get('momentum_loss', 0.0))
                bc_losses.append(loss_dict.get('bc_loss', 0.0))
                if 'cont_weight_used' in loss_dict: cont_w_used_hist.append(loss_dict['cont_weight_used'])
                if 'mom_weight_used'  in loss_dict: mom_w_used_hist.append(loss_dict['mom_weight_used'])

                postfix = {"total": f"{loss_dict['total_loss']:.4e}"}
                if 'continuity_loss' in loss_dict: postfix["cont"] = f"{loss_dict['continuity_loss']:.4e}"
                if 'momentum_loss' in loss_dict:   postfix["momentum"] = f"{loss_dict['momentum_loss']:.4e}"
                if 'bc_loss' in loss_dict:         postfix["bc"] = f"{loss_dict['bc_loss']:.4e}"
                pbar.set_postfix(postfix)

            finally:
                pbar.update(1)
    finally:
        # Close the bar even if a batch raises.
        pbar.close()

    avg_losses = {
        'total_loss': np.mean(total_losses) if total_losses else float('nan'),
        'mse_loss': np.mean(mse_losses) if mse_losses else float('nan'),
        'continuity_loss': np.mean(continuity_losses) if continuity_losses else float('nan'),
        'momentum_loss': np.mean(momentum_losses) if momentum_losses else float('nan'),
        'bc_loss': np.mean(bc_losses) if bc_losses else float('nan'),
    }
    if cont_w_used_hist: avg_losses['cont_weight_used'] = float(np.mean(cont_w_used_hist))
    if mom_w_used_hist:  avg_losses['mom_weight_used']  = float(np.mean(mom_w_used_hist))
    return avg_losses['total_loss'], avg_losses



def train_epoch(loader, model, optim, device, scaler, desc: str = 'train',
                loss_fn=None, global_step_start: int = 0, scheduler=None, scheduler_step_mode: str = "epoch",
                log_every_n_steps: int = -1):  # -1 disables step-level logging
    """Run one training epoch and return the averaged loss terms.

    Args:
        loader: iterable of collated batches. ``None`` entries are skipped but
            still count as one step.
        model: module called as ``model(batch)``; switched to train mode here.
        optim: optimizer stepped once per non-empty batch (grad norm clipped to 1.0).
        device: device each batch is moved to.
        scaler: optional AMP ``GradScaler``, used only when ``scaler.is_enabled()``.
        desc: progress-bar label.
        loss_fn: optional physics loss forwarded to ``compute_loss_with_physics``.
        global_step_start: step-counter value at the start of the epoch.
        scheduler: optional LR scheduler, stepped per batch only when
            ``scheduler_step_mode == "step"`` (epoch-level stepping is presumably
            done by the caller).
        scheduler_step_mode: ``"step"`` or ``"epoch"``.
        log_every_n_steps: log per-step metrics to W&B every N batches; <= 0 disables.

    Returns:
        ``(mean_total_loss, averages, global_step)``, where ``global_step`` equals
        ``global_step_start`` plus the number of items the loader yielded.
    """
    model.train()
    total_losses, mse_losses, continuity_losses, momentum_losses = [], [], [], []
    bc_losses = []
    cont_w_used_hist, mom_w_used_hist = [], []

    global_step = global_step_start
    # The scaler's enabled state does not change within an epoch, so decide once.
    use_scaler = (scaler is not None) and getattr(scaler, "is_enabled", lambda: False)()
    step_scheduler_per_batch = scheduler is not None and scheduler_step_mode == "step"

    pbar = tqdm(total=len(loader), desc=desc, leave=False)
    try:
        for batch_idx, batch in enumerate(loader):
            try:
                if batch is None:
                    # The finally clause advances the bar and global_step exactly once.
                    continue

                b = batch.to(device)
                optim.zero_grad(set_to_none=True)

                if use_scaler:
                    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
                    with torch.autocast(device_type='cuda', enabled=torch.cuda.is_available()):
                        out = model(b)
                        loss, loss_dict = compute_loss_with_physics(out, b.y, b, loss_fn=loss_fn, step=global_step)
                    scaler.scale(loss).backward()
                    scaler.unscale_(optim)  # unscale first so clipping sees the true gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optim)
                    scaler.update()
                else:
                    out = model(b)
                    loss, loss_dict = compute_loss_with_physics(out, b.y, b, loss_fn=loss_fn, step=global_step)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optim.step()

                if step_scheduler_per_batch:
                    try:
                        scheduler.step()
                    except TypeError:
                        # e.g. ReduceLROnPlateau.step() requires a metric; skip per-batch stepping.
                        pass

                # Aggregate per-batch metrics
                total_losses.append(loss_dict['total_loss'])
                mse_losses.append(loss_dict['mse_loss'])
                continuity_losses.append(loss_dict.get('continuity_loss', 0.0))
                momentum_losses.append(loss_dict.get('momentum_loss', 0.0))
                bc_losses.append(loss_dict.get('bc_loss', 0.0))
                if 'cont_weight_used' in loss_dict: cont_w_used_hist.append(loss_dict['cont_weight_used'])
                if 'mom_weight_used'  in loss_dict: mom_w_used_hist.append(loss_dict['mom_weight_used'])

                lr_now = get_lr(optim)

                # Optional step-level W&B logging (disabled when log_every_n_steps <= 0)
                if log_every_n_steps > 0 and batch_idx % log_every_n_steps == 0:
                    log_payload = {
                        "step": global_step,
                        "train/total": loss_dict['total_loss'],
                        "train/mse": loss_dict['mse_loss'],
                        "train/continuity": loss_dict.get('continuity_loss', 0.0),
                        "train/momentum": loss_dict.get('momentum_loss', 0.0),
                        "train/bc": loss_dict.get('bc_loss', 0.0),
                    }
                    if 'cont_weight_used' in loss_dict: log_payload["weight/cont_used"] = loss_dict['cont_weight_used']
                    if 'mom_weight_used'  in loss_dict: log_payload["weight/mom_used"]  = loss_dict['mom_weight_used']
                    if lr_now is not None:
                        log_payload["lr"] = lr_now
                    wandb.log(log_payload, step=global_step, commit=False)

                postfix = {"total": f"{loss_dict['total_loss']:.4e}",
                           "lr": f"{lr_now:.2e}" if lr_now is not None else "n/a"}
                if 'continuity_loss' in loss_dict: postfix["cont"] = f"{loss_dict['continuity_loss']:.4e}"
                if 'momentum_loss' in loss_dict:   postfix["momentum"] = f"{loss_dict['momentum_loss']:.4e}"
                if 'bc_loss' in loss_dict:         postfix["bc"] = f"{loss_dict['bc_loss']:.4e}"
                pbar.set_postfix(postfix)

            finally:
                pbar.update(1)
                global_step += 1
    finally:
        # Close the bar even if a batch raises.
        pbar.close()

    avg_losses = {
        'total_loss': np.mean(total_losses) if total_losses else float('nan'),
        'mse_loss': np.mean(mse_losses) if mse_losses else float('nan'),
        'continuity_loss': np.mean(continuity_losses) if continuity_losses else float('nan'),
        'momentum_loss': np.mean(momentum_losses) if momentum_losses else float('nan'),
        'bc_loss': np.mean(bc_losses) if bc_losses else float('nan'),
    }
    if cont_w_used_hist: avg_losses['cont_weight_used'] = float(np.mean(cont_w_used_hist))
    if mom_w_used_hist:  avg_losses['mom_weight_used']  = float(np.mean(mom_w_used_hist))

    return avg_losses['total_loss'], avg_losses, global_step

# 9. Optuna configuration

In [10]:
# Add this cell after imports

# Install optuna if not already installed
# !pip install optuna optuna-dashboard

import optuna
from optuna.trial import TrialState
from optuna.visualization import plot_optimization_history, plot_param_importances
import joblib
from global_context_processor import EnhancedCFDModelWithGlobalContext


def _evaluate_trial(model, loader, loss_fn, metric_key):
    """Return the mean of ``metric_key`` over ``loader`` (``inf`` if no batch yields a value).

    For each batch the first non-None entry among ``metric_key``, ``'total_loss'``
    and ``'loss'`` in the loss dict is used; ``None`` batches and batches with
    none of those keys are skipped.
    """
    model.eval()
    values = []
    with torch.no_grad():
        for batch in loader:
            if batch is None:  # collate_pyg can return None
                continue
            batch = batch.to(device)
            predictions = model(batch)
            _, loss_dict = compute_loss_with_physics(
                predictions, batch.y, batch, loss_fn=loss_fn
            )
            value = None
            for key in (metric_key, 'total_loss', 'loss'):
                if loss_dict.get(key) is not None:
                    value = loss_dict[key]
                    break
            if value is None:
                continue
            # Entries may be 0-d tensors or plain Python numbers.
            values.append(float(value.item()) if torch.is_tensor(value) else float(value))
    return float(np.mean(values)) if values else float('inf')


def objective(trial, metric_key='mse_loss'):
    """Optuna objective: run one short training and return its best validation score.

    Samples architecture, optimizer, physics-loss, curriculum, global-context
    and LR-scheduler hyperparameters from ``trial``, builds a ``SmokeCfg``,
    trains ``EnhancedCFDModelWithGlobalContext`` on ``train_norm`` and
    validates on ``val_norm`` after every epoch (early stopping after 5
    epochs without improvement).

    The score is the validation mean of ``loss_dict[metric_key]`` (falling
    back to ``'total_loss'``, then ``'loss'``, when the key is absent). The
    default is ``'mse_loss'`` -- presumably the data-fit term reported by
    ``compute_loss_with_physics``; verify the key name -- rather than
    ``'total_loss'``: the physics loss weights that make up the total are
    themselves searched here, so comparing weighted totals across trials
    would reward trials that merely shrink those weights.

    Args:
        trial: Optuna trial used to sample hyperparameters and report progress.
        metric_key: Loss-dict key whose validation mean is minimized.

    Returns:
        The lowest per-epoch validation score, or ``inf`` if no validation
        batch produced a usable value.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial, or
            when the training loss becomes non-finite.
    """
    # 1) Model architecture
    hidden_dim = trial.suggest_categorical('hidden_dim', [64, 128, 256, 512])
    num_layers = trial.suggest_int('num_layers', 3, 10)
    dropout_p = trial.suggest_float('dropout', 0.0, 0.5, step=0.05)

    # 2) Training
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-1, log=True)
    batch_size = trial.suggest_categorical('batch_size', [1, 2, 4, 8])

    # 3) Optimizer
    beta1 = trial.suggest_float('beta1', 0.8, 0.99)
    beta2 = trial.suggest_float('beta2', 0.9, 0.999)
    eps = trial.suggest_float('eps', 1e-9, 1e-6, log=True)

    # 4) Physics loss weights; each target weight is sampled in
    #    [start weight, 1.0], so target >= start.
    continuity_weight = trial.suggest_float('continuity_weight', 0.001, 0.5, log=True)
    continuity_target_weight = trial.suggest_float(
        'continuity_target_weight', continuity_weight, 1.0
    )
    momentum_weight = trial.suggest_float('momentum_weight', 0.001, 0.5, log=True)
    momentum_target_weight = trial.suggest_float(
        'momentum_target_weight', momentum_weight, 1.0
    )
    bc_loss_weight = trial.suggest_float('bc_loss_weight', 0.001, 0.2, log=True)

    # 5) Curriculum (in epochs; converted to steps for the loss below)
    ramp_start_epoch = trial.suggest_int('ramp_start_epoch', 5, 20)
    ramp_epochs = trial.suggest_int('ramp_epochs', 10, 20)

    # 6) Global context / attention (only sampled when global tokens are used)
    use_global_tokens = trial.suggest_categorical('use_global_tokens', [True, False])
    if use_global_tokens:
        # Categorical choices keep the candidate values explicit.
        num_global_tokens = trial.suggest_categorical('num_global_tokens', [2, 4, 8])
        attention_heads = trial.suggest_categorical('attention_heads', [2, 4, 8])
        attention_layers = trial.suggest_categorical('attention_layers', [2, 4, 8])
        use_cross_attention = trial.suggest_categorical('use_cross_attention', [True, False])
        positional_encoding = trial.suggest_categorical('positional_encoding', [True, False])
        global_pooling_type = trial.suggest_categorical(
            'global_pooling_type', ['mean', 'max', 'attention', 'set2set']
        )
    else:
        # Fallback values so the config is fully populated even when unused.
        num_global_tokens = 0
        attention_heads = 4
        attention_layers = 2
        use_cross_attention = False
        positional_encoding = False
        global_pooling_type = 'mean'

    # 7) Learning-rate scheduler
    lr_scheduler_type = trial.suggest_categorical(
        'lr_scheduler', ['cosine', 'cosine_warm_restarts', 'reduce_on_plateau', None]
    )

    # ---- Config with the sampled parameters ----
    config = SmokeCfg(
        hidden=hidden_dim,
        layers=num_layers,
        lr=lr,
        weight_decay=weight_decay,
        batch_size=batch_size,
        betas=(beta1, beta2),
        eps=eps,
        continuity_loss_weight=continuity_weight,
        continuity_target_weight=continuity_target_weight,
        momentum_loss_weight=momentum_weight,
        momentum_target_weight=momentum_target_weight,
        bc_loss_weight=bc_loss_weight,
        ramp_start_epoch=ramp_start_epoch,
        ramp_epochs=ramp_epochs,
        use_global_tokens=use_global_tokens,
        num_global_tokens=num_global_tokens,
        attention_heads=attention_heads,
        attention_layers=attention_layers,
        use_cross_attention=use_cross_attention,
        positional_encoding=positional_encoding,
        global_pooling_type=global_pooling_type,
        lr_scheduler=lr_scheduler_type,
        epochs=20,              # Shorter for hyperparameter search
        wandb_mode='disabled'   # Disable wandb during search
    )

    # ---- DataLoaders (batch size from the trial) ----
    train_loader_trial = DataLoader(
        train_norm,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_pyg
    )
    val_loader_trial = DataLoader(
        val_norm,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_pyg
    )
    steps_per_epoch = max(1, len(train_loader_trial))

    model_trial = optimizer_trial = scheduler_trial = None
    try:
        # ---- Model ----
        model_trial = EnhancedCFDModelWithGlobalContext(
            node_feat_dim=7,
            edge_feat_dim=5,
            hidden_dim=config.hidden,
            output_dim=4,
            num_mp_layers=config.layers,
            dropout_p=dropout_p,
            config=config
        ).to(device)

        # ---- Optimizer ----
        optimizer_trial = torch.optim.AdamW(
            model_trial.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay,
            betas=config.betas,
            eps=config.eps
        )

        # ---- LR scheduler ----
        scheduler_trial = create_lr_scheduler(optimizer_trial, config)

        # ---- Physics loss (curriculum expressed in optimizer steps) ----
        loss_fn_trial = NavierStokesPhysicsLoss(
            data_loss_weight=getattr(config, 'data_loss_weight', 1.0),
            continuity_loss_weight=config.continuity_loss_weight,
            continuity_target_weight=config.continuity_target_weight,
            momentum_loss_weight=config.momentum_loss_weight,
            momentum_target_weight=config.momentum_target_weight,
            curriculum_ramp_steps=config.ramp_epochs * steps_per_epoch,
            ramp_start_step=config.ramp_start_epoch * steps_per_epoch,
            bc_loss_weight=config.bc_loss_weight,
            chord_length=getattr(config, 'chord_length', 1.0),
            dynamic_uref_from_data=getattr(config, 'dynamic_uref_from_data', False),
            dynamic_re_from_data=getattr(config, 'dynamic_re_from_data', False),
            nu_molecular=getattr(config, 'nu_molecular', 1.5e-5),
            use_huber_for_physics=getattr(config, 'use_huber_for_physics', False),
            huber_delta=getattr(config, 'huber_delta', 1.0),
            debug=False
        )

        # ---- Training loop ----
        best_val_loss = float('inf')
        patience_counter = 0
        max_patience = 5

        for epoch in range(config.epochs):
            model_trial.train()
            global_step = epoch * steps_per_epoch

            for batch in train_loader_trial:
                if batch is None:
                    continue
                batch = batch.to(device)
                optimizer_trial.zero_grad()

                predictions = model_trial(batch)
                loss, _ = compute_loss_with_physics(
                    predictions, batch.y, batch,
                    loss_fn=loss_fn_trial, step=global_step
                )
                if not torch.isfinite(loss).all():
                    # Training diverged (e.g. lr near the top of its range);
                    # stop now instead of burning the remaining patience on
                    # NaN/inf weights.
                    raise optuna.exceptions.TrialPruned(
                        f"non-finite training loss at epoch {epoch}, step {global_step}"
                    )

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model_trial.parameters(), 1.0)
                optimizer_trial.step()
                global_step += 1

            avg_val_loss = _evaluate_trial(
                model_trial, val_loader_trial, loss_fn_trial, metric_key
            )

            # Plateau schedulers need the monitored value; others step per epoch.
            if scheduler_trial is not None:
                if isinstance(scheduler_trial, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler_trial.step(avg_val_loss)
                else:
                    scheduler_trial.step()

            # Early stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= max_patience:
                    break

            # Report to Optuna so the pruner can stop unpromising trials.
            trial.report(avg_val_loss, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        return best_val_loss
    finally:
        # Runs for completed, pruned and failed trials alike. Drop the
        # references explicitly: a raised exception's traceback keeps this
        # frame (and its locals) alive, which would block freeing GPU memory.
        del model_trial, optimizer_trial, scheduler_trial
        torch.cuda.empty_cache()


[OK] Enhanced Global Context & Attention Mechanism loaded!
Features: Multi-head attention, cross-attention, positional encoding, advanced pooling


In [11]:
# Run the hyperparameter optimization

# Create or load the study (persisted in a local SQLite file so it survives restarts)
study_name = "airfrans_gnn_optimization"
storage_name = f"sqlite:///{study_name}.db"

# Median pruning + seeded TPE sampler
study = optuna.create_study(
    study_name=study_name,
    storage=storage_name,
    direction='minimize',
    pruner=optuna.pruners.MedianPruner(
        n_startup_trials=5,
        n_warmup_steps=5,
        interval_steps=1
    ),
    sampler=optuna.samplers.TPESampler(seed=42),
    load_if_exists=True
)

# Hand-picked defaults evaluated as the first trial (optional)
default_params = {
    'hidden_dim': 128,
    'num_layers': 7,
    'dropout': 0.1,
    'lr': 4e-4,
    'weight_decay': 1e-2,
    'batch_size': 2,
    'beta1': 0.9,
    'beta2': 0.95,
    'eps': 1e-8,
    'continuity_weight': 0.05,
    'continuity_target_weight': 0.10,
    'momentum_weight': 0.05,
    'momentum_target_weight': 0.10,
    'bc_loss_weight': 0.05,
    'ramp_start_epoch': 10,
    'ramp_epochs': 10,
    'use_global_tokens': True,
    'num_global_tokens': 4,
    'attention_heads': 4,
    'lr_scheduler': 'cosine'
}
# load_if_exists=True reuses the persisted study, so enqueueing
# unconditionally would add another copy of the baseline to the queue on every
# re-run of this cell. Only seed it while nothing has completed or is queued.
if not study.get_trials(deepcopy=False, states=(TrialState.COMPLETE, TrialState.WAITING)):
    study.enqueue_trial(default_params)

# Run optimization
n_trials = 100  # Trials to run in this call (added to any already in the study)
study.optimize(
    objective,
    n_trials=n_trials,
    timeout=None,        # Can set a timeout in seconds
    n_jobs=1,            # Use 1 for GPU; can increase for CPU-only runs
    gc_after_trial=True,
    show_progress_bar=True
)

completed = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))

print("\n" + "="*50)
print("Optimization Complete!")
print("="*50)
# study.trials includes failed/pruned/running trials, not only finished ones.
print(f"Number of trials in study: {len(study.trials)} ({len(completed)} completed)")
if completed:
    print(f"Best trial value: {study.best_value:.6f}")
    print("\nBest parameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")
else:
    # study.best_value / best_params raise when no trial has completed
    # (e.g. every trial was pruned).
    print("No completed trials yet; best value and parameters are unavailable.")

[I 2025-09-20 23:07:21,205] A new study created in RDB with name: airfrans_gnn_optimization


  0%|          | 0/100 [00:00<?, ?it/s]

📊 Learning rate scheduler: CosineAnnealingLR
   T_max: 80, eta_min: 1e-06


  scaler = GradScaler(enabled=False)


[W 2025-09-20 23:07:56,636] Trial 0 failed with parameters: {'hidden_dim': 128, 'num_layers': 7, 'dropout': 0.1, 'lr': 0.0004, 'weight_decay': 0.01, 'batch_size': 2, 'beta1': 0.9, 'beta2': 0.95, 'eps': 1e-08, 'continuity_weight': 0.05, 'continuity_target_weight': 0.1, 'momentum_weight': 0.05, 'momentum_target_weight': 0.1, 'bc_loss_weight': 0.05, 'ramp_start_epoch': 10, 'ramp_epochs': 10, 'use_global_tokens': True, 'num_global_tokens': 4, 'attention_heads': 4, 'attention_layers': 4, 'use_cross_attention': True, 'positional_encoding': True, 'global_pooling_type': 'mean', 'lr_scheduler': 'cosine'} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\Kim\.conda\envs\pyg5090\Lib\site-packages\optuna\study\_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "C:\Users\Kim\AppData\Local\Temp\ipykernel_7816\2118916917.py", line 184, in objective
    loss.backward()
  File "c:\Users

KeyboardInterrupt: 

In [None]:
# Analyze and visualize results
#
# Only COMPLETE trials are ranked: Optuna stores a pruned trial's last reported
# intermediate value as its value, so pruned (partially trained) trials would
# otherwise compete with fully trained ones in the "top trials" table.

completed_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
pruned_trials = study.get_trials(deepcopy=False, states=(TrialState.PRUNED,))

if not completed_trials:
    # study.best_trial / best_value raise when nothing has completed.
    print("No completed trials yet - nothing to analyze.")
    print(f"  Pruned trials: {len(pruned_trials)}")
else:
    # 1. Optimization History
    fig = optuna.visualization.plot_optimization_history(study)
    fig.show()

    # 2. Parameter Importance (needs at least two completed trials)
    if len(completed_trials) >= 2:
        fig = optuna.visualization.plot_param_importances(study)
        fig.show()

    # 3. Parallel Coordinate Plot
    fig = optuna.visualization.plot_parallel_coordinate(
        study,
        params=['hidden_dim', 'num_layers', 'lr', 'continuity_weight', 'momentum_weight']
    )
    fig.show()

    # 4. Slice Plot for specific parameters
    fig = optuna.visualization.plot_slice(
        study,
        params=['lr', 'hidden_dim', 'num_layers', 'batch_size']
    )
    fig.show()

    # 5. Statistics
    print("Statistics:")
    print(f"  Completed trials: {len(completed_trials)}")
    print(f"  Pruned trials: {len(pruned_trials)}")
    print(f"  Best trial: #{study.best_trial.number}")
    print(f"  Best value: {study.best_value:.6f}")

    # 6. Top 5 completed trials
    df = study.trials_dataframe()
    df_sorted = df[df['state'] == TrialState.COMPLETE.name].sort_values('value').head(5)
    print("\nTop 5 trials:")
    print(df_sorted[['number', 'value', 'params_hidden_dim', 'params_lr', 'params_num_layers']])

In [None]:
# Train final model with best hyperparameters

def train_with_best_params(study, epochs=100):
    """Build a ``SmokeCfg`` for a full training run from the study's best trial.

    Despite its name, this only prepares the configuration; feed the returned
    config into the existing training code. Every parameter that ``objective``
    writes into its ``SmokeCfg`` is copied. For trials that disabled global
    tokens, the attention-related parameters were never sampled; they fall
    back to the same values ``objective`` used, so the final run matches what
    the best trial actually trained with.

    ``dropout`` is passed straight to the model in ``objective`` (it is not a
    config field there), so take it from ``study.best_params['dropout']``
    when constructing the model.

    Args:
        study: Optuna study with at least one completed trial.
        epochs: Number of epochs for the final run.

    Returns:
        The populated ``SmokeCfg`` (wandb enabled).
    """
    best_trial = study.best_trial  # raises if no trial has completed
    best_params = best_trial.params
    print(f"Training with best parameters from trial #{best_trial.number}")

    # Same fallbacks as objective() uses when use_global_tokens is False.
    global_context_fallbacks = {
        'num_global_tokens': 0,
        'attention_heads': 4,
        'attention_layers': 2,
        'use_cross_attention': False,
        'positional_encoding': False,
        'global_pooling_type': 'mean',
    }
    global_context = {
        name: best_params.get(name, fallback)
        for name, fallback in global_context_fallbacks.items()
    }

    final_config = SmokeCfg(
        hidden=best_params['hidden_dim'],
        layers=best_params['num_layers'],
        lr=best_params['lr'],
        weight_decay=best_params['weight_decay'],
        batch_size=best_params['batch_size'],
        betas=(best_params['beta1'], best_params['beta2']),
        eps=best_params['eps'],
        continuity_loss_weight=best_params['continuity_weight'],
        continuity_target_weight=best_params['continuity_target_weight'],
        momentum_loss_weight=best_params['momentum_weight'],
        momentum_target_weight=best_params['momentum_target_weight'],
        bc_loss_weight=best_params['bc_loss_weight'],
        ramp_start_epoch=best_params['ramp_start_epoch'],
        ramp_epochs=best_params['ramp_epochs'],
        use_global_tokens=best_params['use_global_tokens'],
        num_global_tokens=global_context['num_global_tokens'],
        attention_heads=global_context['attention_heads'],
        attention_layers=global_context['attention_layers'],
        use_cross_attention=global_context['use_cross_attention'],
        positional_encoding=global_context['positional_encoding'],
        global_pooling_type=global_context['global_pooling_type'],
        lr_scheduler=best_params['lr_scheduler'],
        epochs=epochs,
        wandb_mode='online'  # Enable wandb for final training
    )
    print(f"  dropout (pass to the model as dropout_p): {best_params['dropout']}")

    # Create the model and train with final_config using the existing training code.
    return final_config

# Train final model
final_config = train_with_best_params(study, epochs=100)