# Toy Tiny Training Loop Visualizer

이 노트북은 **아주 작은 학습 루프**를 M0~M4에 대해 직접 돌리고,
각 epoch마다 시각화해서 **학습 속도**와 **학습 품질**을 비교하는 용도입니다.

비교 항목:
- 학습 속도: epoch 시간, step 시간
- 학습 품질: validation DSM, nearest-real distance (둘 다 작을수록 좋음)
- epoch별 샘플 분포(실데이터 오버레이)
- epoch별 벡터장 + curl 히트맵


In [1]:
from __future__ import annotations

import time
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch


def resolve_repo_root(start: Path) -> Path:
    """Find the repository root by walking up from ``start``.

    A directory qualifies as the root when it contains both a ``src/`` and a
    ``configs/`` subdirectory. ``start`` itself is checked first, then each of
    its parents in order towards the filesystem root.

    Args:
        start: Directory to begin the upward search from (typically the
            notebook's working directory).

    Returns:
        The first directory (closest to ``start``) containing ``src/`` and
        ``configs/``.

    Raises:
        RuntimeError: If no ancestor matches the expected layout.
    """

    def _has_project_layout(path: Path) -> bool:
        # Both markers must be directories; `src` is checked before `configs`.
        return all((path / marker).is_dir() for marker in ('src', 'configs'))

    found = next(
        (candidate for candidate in (start, *start.parents) if _has_project_layout(candidate)),
        None,
    )
    if found is None:
        raise RuntimeError(
            f'Could not locate repo root from {start}. '
            'Open this notebook inside the Advance_score repository.'
        )
    return found


# Locate the repo root once, then prepend it to the module search path (only
# if absent) so the repo-local `src` package can be imported below.
ROOT = resolve_repo_root(Path.cwd().resolve())
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path[:0] = [_root_entry]

from src.data import make_loader, sample_toy_data, unpack_batch
from src.models import build_model, score_fn_from_model
from src.sampling import sample_heun
from src.trainers.train_step_baseline import train_step_baseline
from src.trainers.train_step_m3 import train_step_m3
from src.trainers.train_step_m4 import train_step_m4
from src.trainers.train_step_reg import train_step_reg
from src.trainers.train_step_struct import train_step_struct
from src.trainers.common import compute_dsm_for_score
from src.utils.config import ensure_experiment_defaults, load_config
from src.utils.feature_encoder import build_feature_encoder
from src.utils.seed import seed_everything

# Display/plot defaults for the whole notebook.
pd.set_option('display.max_columns', 200)
plt.style.use('seaborn-v0_8-whitegrid')

# Prefer the GPU when one is visible to torch; otherwise run on CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(f'ROOT={ROOT}')
print(f'DEVICE={DEVICE}')


ROOT=/home/junyeobe/projects/Advance_score
DEVICE=cuda


In [2]:
# ===== Tiny Training Controls =====
# Models trained side by side; each id maps to configs/toy/m<k>.yaml.
MODEL_IDS = ['M0', 'M1', 'M2', 'M3', 'M4']
SEED = 0

# Optimisation budget per model.
EPOCHS, STEPS_PER_EPOCH = 10, 30
BATCH_SIZE = 512
LR = 2e-4

# Per-epoch evaluation budget (real/fake sample counts, DSM repeats, sampler NFE).
NUM_EVAL_REAL = NUM_EVAL_FAKE = 1024
EVAL_DSM_REPEATS = 3
SAMPLE_NFE = 40

# Vector-field / curl snapshot settings.
SIGMA_VIS = 0.2
GRID_POINTS = 27

# Epochs to visualise; None shows every epoch.
EPOCHS_TO_SHOW = None

print('Configured tiny training loop.')


Configured tiny training loop.


## Helper/Trainer 셀 설명 (중요)

아래 코드 셀은 **tiny 학습 실험의 핵심 로직**을 담고 있습니다.

무엇을 하나요:
- `model_id -> config` 매핑(`M0~M4`)
- 모델별 학습 step dispatch (`run_one_train_step`)
- epoch 단위 평가(`val_dsm`, `nn_real_dist`, 샘플, 벡터장, curl)
- 전체 tiny 학습 루프 실행(`run_tiny_training_suite`)

모델별 objective(loss) 정리:
- `M0` (Baseline): `L = L_DSM`
- `M1` (Jacobian reg): `L = L_DSM + lambda_sym * R_sym + mu_loop * R_loop`
- `M2` (Struct conservative): `s = grad_x phi`, `L = L_DSM`
- `M3` (Jacobian-free nonlocal): `L = L_DSM + mu1 * R_loop_multi + mu2 * R_cycle`
- `M4` (Hybrid hard-conservative): `L = L_DSM + alpha * R_match + beta * R_cycle_low(optional)`

로그 지표 해석:
- `train_loss_mean`: epoch 내 평균 학습 loss (낮을수록 좋음)
- `val_dsm`: 고정 eval batch에서 DSM 추정 (낮을수록 좋음)
- `nn_real_dist`: 생성 샘플의 최근접 real 거리 평균 (낮을수록 분포 적합)
- `step_time_ms_mean`: step당 평균 시간 — loss 계산·backward·gradient clipping·optimizer step 포함, 데이터 로딩 제외 (낮을수록 빠름)
- `trajectory_length_mean`: 샘플링 궤적 길이 (동역학 비교 지표)


In [3]:
def model_id_to_config_path(model_id: str) -> Path:
    """Return the toy config file for a canonical model id.

    The leading character of the id is dropped and the remainder is used as
    the config index, e.g. ``'M3'`` -> ``<ROOT>/configs/toy/m3.yaml``.

    Args:
        model_id: Canonical model token (`M0`..`M4`).

    Returns:
        Path of the matching YAML file under ``configs/toy``.
    """
    suffix = model_id[1:]
    toy_config_dir = ROOT.joinpath('configs', 'toy')
    return toy_config_dir.joinpath(f'm{suffix}.yaml')


def load_tiny_cfg(model_id: str, batch_size: int, lr: float) -> dict:
    """Load a model's toy config and shrink it for interactive tiny-loop runs.

    Args:
        model_id: Canonical model token (`M0`..`M4`).
        batch_size: Training batch size to force into ``dataset.batch_size``.
        lr: Learning rate to force into ``train.lr``.

    Returns:
        The config dict (with experiment defaults applied) after overriding:
        batch size and ``num_workers=0`` (dataset), AMP off, the given LR and
        grad clipping at 1.0 (train), and ``reg_freq=1`` (loss).
    """
    cfg = ensure_experiment_defaults(load_config(model_id_to_config_path(model_id)))

    # Section -> key overrides, applied in this order.
    tiny_overrides = {
        'dataset': {'batch_size': int(batch_size), 'num_workers': 0},
        'train': {'amp': False, 'lr': float(lr), 'clip_grad_norm': 1.0},
        'loss': {'reg_freq': 1},
    }
    for section, section_overrides in tiny_overrides.items():
        for key, value in section_overrides.items():
            cfg[section][key] = value

    return cfg


def score_requires_grad(model_id: str) -> bool:
    """Tell the sampler whether this model's score needs autograd enabled.

    Args:
        model_id: Canonical model token.

    Returns:
        ``True`` only for ``'M2'`` and ``'M4'``; ``False`` for every other id.
    """
    grad_dependent_ids = frozenset(('M2', 'M4'))
    return model_id in grad_dependent_ids


def run_one_train_step(
    model_id: str,
    model: torch.nn.Module,
    x0: torch.Tensor,
    cfg: dict,
    global_step: int,
    feature_encoder: torch.nn.Module | None,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Dispatch one training objective step by model id.

    Args:
        model_id: Canonical model token (`M0`..`M4`).
        model: Active model instance.
        x0: Clean training batch.
        cfg: Model config dictionary.
        global_step: Global step index used by gated regularizers.
        feature_encoder: Optional frozen encoder for M3/M4 cycle losses.

    Returns:
        Tuple `(loss_tensor, metric_dict)`.

    Model-Loss mapping:
        - M0: DSM only
        - M1: DSM + Jacobian asymmetry/loop regularization
        - M2: DSM with conservative structure (`score = grad phi`)
        - M3: DSM + multi-scale loop + graph cycle
        - M4: DSM + boundary matching + optional low-noise cycle
    """
    if model_id == 'M0':
        return train_step_baseline(model, x0, cfg)
    if model_id == 'M1':
        return train_step_reg(model, x0, cfg, global_step)
    if model_id == 'M2':
        return train_step_struct(model, x0, cfg)
    if model_id == 'M3':
        assert feature_encoder is not None
        return train_step_m3(model, x0, cfg, global_step, feature_encoder)
    if model_id == 'M4':
        assert feature_encoder is not None
        return train_step_m4(model, x0, cfg, global_step, feature_encoder)

    raise ValueError(f'Unsupported model_id: {model_id}')


def evaluate_score_grid(
    score_fn,
    sigma_value: float,
    points: int,
    xlim: tuple[float, float] = (-6.0, 6.0),
    ylim: tuple[float, float] = (-6.0, 6.0),
    device: torch.device = torch.device('cpu'),
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Sample a 2D score field on a regular ``points x points`` grid.

    Every grid node is scored at the same noise level ``sigma_value``. Autograd
    is enabled during the call so score wrappers that differentiate a
    potential can run; the result is detached before conversion to numpy.

    Args:
        score_fn: Callable ``score_fn(x, sigma)`` mapping ``[B, 2]`` points and
            a ``[B]`` sigma tensor to a ``[B, 2]`` score.
        sigma_value: Noise level shared by all grid nodes.
        points: Grid resolution along each axis.
        xlim: ``(min, max)`` of the x axis.
        ylim: ``(min, max)`` of the y axis.
        device: Device the score is evaluated on.

    Returns:
        ``(X, Y, U, V)``: meshgrid coordinates (rows follow y, columns follow x)
        and the x/y score components on that grid, all shaped ``(points, points)``.
    """
    X, Y = np.meshgrid(
        np.linspace(xlim[0], xlim[1], points),
        np.linspace(ylim[0], ylim[1], points),
    )

    flat_points = np.column_stack((X.ravel(), Y.ravel())).astype(np.float32)
    query = torch.from_numpy(flat_points).to(device)
    num_query = query.shape[0]
    sigma_batch = torch.full((num_query,), float(sigma_value), device=device, dtype=torch.float32)

    with torch.enable_grad():
        field_flat = score_fn(query, sigma_batch).detach().cpu().numpy()

    grid_shape = (points, points)
    U, V = (field_flat[:, component].reshape(grid_shape) for component in range(2))
    return X, Y, U, V


def finite_diff_curl(U: np.ndarray, V: np.ndarray, xs: np.ndarray, ys: np.ndarray) -> np.ndarray:
    """Compute finite-difference curl proxy `dV/dx - dU/dy` in 2D.

    Args:
        U: X-component grid array, indexed ``[y, x]`` (``np.meshgrid`` 'xy' layout).
        V: Y-component grid array, same layout as `U`.
        xs: X-axis coordinates (length ``U.shape[1]``).
        ys: Y-axis coordinates (length ``U.shape[0]``).

    Returns:
        Curl array with same shape as `U`/`V`.

    Notes:
        Only the two partials that enter the curl are computed. Second-order
        boundary differences need at least 3 samples along an axis, so coarser
        axes fall back to first-order edges instead of making ``np.gradient``
        raise. An axis with a single sample still cannot be differentiated.
    """
    x_edge_order = 2 if V.shape[1] >= 3 else 1
    y_edge_order = 2 if U.shape[0] >= 3 else 1
    dV_dx = np.gradient(V, xs, axis=1, edge_order=x_edge_order)
    dU_dy = np.gradient(U, ys, axis=0, edge_order=y_edge_order)
    return dV_dx - dU_dy


def evaluate_model_epoch(
    model_id: str,
    model: torch.nn.Module,
    cfg: dict,
    real_eval: torch.Tensor,
    eval_dsm_repeats: int,
    num_eval_fake: int,
    sample_nfe: int,
    sigma_vis: float,
    grid_points: int,
    device: torch.device,
) -> dict:
    """Compute epoch-level quality diagnostics and visual snapshots for one model.

    Args:
        model_id: Canonical model token (`M0`..`M4`).
        model: Trained model at current epoch (switched to eval mode here).
        cfg: Model config dictionary; reads `loss.sigma_min`, `loss.sigma_max`
            and optional `loss.weight_mode` (default `'sigma2'`).
        real_eval: Fixed real toy batch used for epoch-wise eval.
        eval_dsm_repeats: Number of DSM Monte-Carlo repeats to average.
        num_eval_fake: Number of generated fake samples for comparison.
        sample_nfe: NFE for Heun sampler.
        sigma_vis: Sigma value for field visualization snapshots.
        grid_points: Field-grid resolution per axis.
        device: Runtime device.

    Returns:
        Dict with keys:
        - `val_dsm`: mean validation DSM over the repeats
        - `nn_real_dist`: mean distance from each fake sample to its nearest
          sample in the full `real_eval` set
        - `trajectory_length_mean`: sampler trajectory length (NaN if the
          sampler does not report it)
        - `samples_np`: generated samples numpy array
        - `field`: dict containing X/Y/U/V/curl arrays

    How it works:
        Reuses project score wrapper + Heun sampler, then computes a simple
        distribution-fit proxy (`nearest real distance`) and field snapshot.
    """
    model.eval()
    sigma_min = float(cfg['loss']['sigma_min'])
    sigma_max = float(cfg['loss']['sigma_max'])
    weight_mode = str(cfg['loss'].get('weight_mode', 'sigma2'))

    score_fn = score_fn_from_model(model, model_id, create_graph=False)

    dsm_vals = []
    for _ in range(int(eval_dsm_repeats)):
        dsm, _ = compute_dsm_for_score(
            score_fn=score_fn,
            x0=real_eval,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            weight_mode=weight_mode,
        )
        dsm_vals.append(float(dsm.detach().item()))
    val_dsm = float(np.mean(dsm_vals))

    fake, stats = sample_heun(
        score_fn=score_fn,
        shape=(int(num_eval_fake), 2),
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        nfe=int(sample_nfe),
        device=device,
        score_requires_grad=score_requires_grad(model_id),
        return_trajectory=False,
    )

    fake_detached = fake.detach()
    # Compare against the full fixed real set: shrinking the reference set to
    # the fake-sample count would inflate nearest distances and make the metric
    # depend on `num_eval_fake`.
    nn_dist = torch.cdist(fake_detached, real_eval).min(dim=1).values.mean().item()

    X, Y, U, V = evaluate_score_grid(
        score_fn=score_fn,
        sigma_value=float(sigma_vis),
        points=int(grid_points),
        device=device,
    )
    # Take the axis coordinates from the grid actually evaluated rather than
    # re-deriving them from a duplicated hard-coded range.
    xs = X[0, :]
    ys = Y[:, 0]
    curl = finite_diff_curl(U=U, V=V, xs=xs, ys=ys)

    return {
        'val_dsm': float(val_dsm),
        'nn_real_dist': float(nn_dist),
        'trajectory_length_mean': float(stats.get('trajectory_length_mean', float('nan'))),
        'samples_np': fake_detached.cpu().numpy(),
        'field': {'X': X, 'Y': Y, 'U': U, 'V': V, 'curl': curl},
    }


def run_tiny_training_suite(
    model_ids: list[str],
    epochs: int,
    steps_per_epoch: int,
    batch_size: int,
    lr: float,
    num_eval_real: int,
    num_eval_fake: int,
    eval_dsm_repeats: int,
    sample_nfe: int,
    sigma_vis: float,
    grid_points: int,
    device: torch.device,
    seed: int,
):
    """Run tiny training loops for multiple model ids and collect snapshots.

    Args:
        model_ids: Model ids to train (`M0`..`M4`).
        epochs: Number of tiny-loop epochs.
        steps_per_epoch: Number of optimizer steps per epoch.
        batch_size: Training batch size.
        lr: AdamW learning rate.
        num_eval_real: Number of fixed real eval samples.
        num_eval_fake: Number of generated eval samples per epoch.
        eval_dsm_repeats: DSM repeat count for smoother estimate.
        sample_nfe: Heun NFE for epoch eval.
        sigma_vis: Sigma for vector-field snapshots.
        grid_points: Grid resolution for vector-field snapshots.
        device: Runtime torch device.
        seed: Global random seed.

    Returns:
        Tuple `(history_df, sample_snaps, field_snaps, real_eval_np)`.

    How it works:
        Initializes one model/optimizer/loader per model id, trains each
        for `steps_per_epoch`, then logs epoch metrics + visual snapshots.
    """
    seed_everything(int(seed))

    states = {}
    for model_id in model_ids:
        cfg = load_tiny_cfg(model_id=model_id, batch_size=batch_size, lr=lr)

        model = build_model(cfg).to(device)
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=float(cfg['train']['lr']),
            betas=tuple(cfg['train']['betas']),
            weight_decay=float(cfg['train'].get('weight_decay', 0.01)),
        )

        loader = make_loader(cfg, train=True)
        data_iter = iter(loader)

        feature_encoder = None
        if model_id in {'M3', 'M4'}:
            feature_encoder = build_feature_encoder(
                dataset_name='toy',
                channels=int(cfg['dataset'].get('channels', 1)),
                device=device,
            )

        states[model_id] = {
            'cfg': cfg,
            'model': model,
            'optimizer': optimizer,
            'loader': loader,
            'data_iter': data_iter,
            'feature_encoder': feature_encoder,
            'global_step': 0,
        }

    # 고정 평가 배치: epoch 간 비교 분산을 줄이기 위해 공유한다.
    cfg_ref = states[model_ids[0]]['cfg']
    real_eval = sample_toy_data(cfg_ref, num_samples=int(num_eval_real)).to(device)

    history_rows = []
    sample_snaps = {m: [] for m in model_ids}
    field_snaps = {m: [] for m in model_ids}

    for epoch in range(1, int(epochs) + 1):
        print(f'\n[Epoch {epoch}/{epochs}]')
        for model_id in model_ids:
            state = states[model_id]
            cfg = state['cfg']
            model = state['model']
            optimizer = state['optimizer']
            feature_encoder = state['feature_encoder']

            model.train()
            step_times = []
            train_loss_vals = []

            epoch_t0 = time.perf_counter()
            for _ in range(int(steps_per_epoch)):
                state['global_step'] += 1

                try:
                    batch = next(state['data_iter'])
                except StopIteration:
                    state['data_iter'] = iter(state['loader'])
                    batch = next(state['data_iter'])

                x0 = unpack_batch(batch).to(device)

                optimizer.zero_grad(set_to_none=True)
                t0 = time.perf_counter()

                loss, metrics = run_one_train_step(
                    model_id=model_id,
                    model=model,
                    x0=x0,
                    cfg=cfg,
                    global_step=state['global_step'],
                    feature_encoder=feature_encoder,
                )

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float(cfg['train']['clip_grad_norm']))
                optimizer.step()

                step_ms = (time.perf_counter() - t0) * 1000.0
                step_times.append(step_ms)
                train_loss_vals.append(float(metrics['loss_total']))

            epoch_time = time.perf_counter() - epoch_t0

            eval_out = evaluate_model_epoch(
                model_id=model_id,
                model=model,
                cfg=cfg,
                real_eval=real_eval,
                eval_dsm_repeats=eval_dsm_repeats,
                num_eval_fake=num_eval_fake,
                sample_nfe=sample_nfe,
                sigma_vis=sigma_vis,
                grid_points=grid_points,
                device=device,
            )

            sample_snaps[model_id].append(eval_out['samples_np'])
            field_snaps[model_id].append(eval_out['field'])

            history_rows.append(
                {
                    'epoch': int(epoch),
                    'model_id': model_id,
                    'train_loss_mean': float(np.mean(train_loss_vals)),
                    'step_time_ms_mean': float(np.mean(step_times)),
                    'epoch_time_sec': float(epoch_time),
                    'val_dsm': float(eval_out['val_dsm']),
                    'nn_real_dist': float(eval_out['nn_real_dist']),
                    'trajectory_length_mean': float(eval_out['trajectory_length_mean']),
                }
            )

            # Epoch 요약: 모델별 수렴 속도(손실)와 품질(proxy)를 함께 출력한다.
            print(
                f'  {model_id}: '
                f'train_loss={np.mean(train_loss_vals):.4f}, '
                f'val_dsm={eval_out["val_dsm"]:.4f}, '
                f'nn_real_dist={eval_out["nn_real_dist"]:.4f}, '
                f'step_ms={np.mean(step_times):.2f}'
            )

    history_df = pd.DataFrame(history_rows)
    return history_df, sample_snaps, field_snaps, real_eval.detach().cpu().numpy()


SyntaxError: f-string: unmatched '[' (4291461890.py, line 397)

In [None]:
# Train every configured model side by side and collect per-epoch logs plus
# sample / vector-field snapshots for the plotting cells below.
_suite_kwargs = dict(
    model_ids=MODEL_IDS,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    batch_size=BATCH_SIZE,
    lr=LR,
    num_eval_real=NUM_EVAL_REAL,
    num_eval_fake=NUM_EVAL_FAKE,
    eval_dsm_repeats=EVAL_DSM_REPEATS,
    sample_nfe=SAMPLE_NFE,
    sigma_vis=SIGMA_VIS,
    grid_points=GRID_POINTS,
    device=DEVICE,
    seed=SEED,
)
history_df, sample_snaps, field_snaps, real_eval_np = run_tiny_training_suite(**_suite_kwargs)

display(history_df.head())


In [None]:
# ============================
# Speed / Quality Curves
# ============================

# (history column, panel title) for each subplot, left to right.
_curve_panels = (
    ('train_loss_mean', 'Train Loss (mean/epoch)'),
    ('val_dsm', 'Validation DSM'),
    ('nn_real_dist', 'Nearest Real Distance'),
    ('step_time_ms_mean', 'Step Time (ms)'),
)

fig, axes = plt.subplots(1, 4, figsize=(22, 4))

# Epoch-sorted history per model, skipping models with no logged rows.
_per_model_history = []
for model_id in MODEL_IDS:
    sub = history_df[history_df['model_id'] == model_id].sort_values('epoch')
    if not sub.empty:
        _per_model_history.append((model_id, sub))

for ax, (column, title) in zip(axes, _curve_panels):
    for model_id, sub in _per_model_history:
        ax.plot(sub['epoch'], sub[column], marker='o', label=model_id)
    ax.set_title(title)
    ax.set_xlabel('epoch')
    ax.legend()

plt.tight_layout()
plt.show()


In [None]:
# One row per model (its last logged epoch), ranked by validation DSM and then
# by nearest-real distance — both lower is better.
_last_epoch_rows = history_df.sort_values('epoch').groupby('model_id', as_index=False).tail(1)
_summary_columns = [
    'model_id',
    'epoch',
    'train_loss_mean',
    'val_dsm',
    'nn_real_dist',
    'step_time_ms_mean',
    'epoch_time_sec',
]
final_table = _last_epoch_rows[_summary_columns].sort_values(['val_dsm', 'nn_real_dist'])
display(final_table)


In [None]:
def get_epochs_to_show(max_epoch: int, requested: list[int] | None):
    """Resolve epoch list for visualization loops.

    Args:
        max_epoch: Maximum available epoch index.
        requested: Optional user-requested epoch list.

    Returns:
        Valid epoch list in ascending order.
    """
    if requested is None:
        return list(range(1, max_epoch + 1))
    out = sorted(set(int(e) for e in requested if 1 <= int(e) <= max_epoch))
    return out if out else [max_epoch]


def plot_sample_epoch(epoch: int) -> None:
    """Show one scatter panel per model comparing fake samples with real data.

    Reads the notebook globals `MODEL_IDS`, `sample_snaps` and `real_eval_np`.

    Args:
        epoch: 1-based epoch whose sample snapshot is drawn.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    snap_index = int(epoch) - 1
    n_panels = len(MODEL_IDS)
    fig, ax_grid = plt.subplots(
        1,
        n_panels,
        figsize=(4.5 * n_panels, 4),
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    panels = ax_grid[0]  # always a 1-D sequence, even for a single model

    for model_id, ax in zip(MODEL_IDS, panels):
        generated = sample_snaps[model_id][snap_index]

        # Real data as a faint background layer, generated samples on top.
        ax.scatter(real_eval_np[:, 0], real_eval_np[:, 1], s=4, alpha=0.18, label='real', color='tab:gray')
        ax.scatter(generated[:, 0], generated[:, 1], s=6, alpha=0.6, label='fake', color='tab:blue')

        ax.set_title(f'{model_id} - epoch {epoch}')
        ax.set_xlim(-6, 6)
        ax.set_ylim(-6, 6)
        ax.set_aspect('equal')

    panels[0].legend(loc='upper right')
    plt.tight_layout()
    plt.show()


# Draw the sample-scatter figure for each selected epoch, in ascending order.
epochs_to_show = get_epochs_to_show(max_epoch=EPOCHS, requested=EPOCHS_TO_SHOW)
for shown_epoch in epochs_to_show:
    plot_sample_epoch(shown_epoch)


In [None]:
def plot_field_epoch(epoch: int) -> None:
    """Plot vector field and curl heatmap per model for one epoch.

    Reads the notebook globals `MODEL_IDS` and `field_snaps`; each snapshot is
    a dict with `X`/`Y` meshgrid coordinates, `U`/`V` score components and a
    `curl` array on the same grid.

    Args:
        epoch: 1-based epoch index.

    Returns:
        None. Displays matplotlib figure.
    """
    idx = int(epoch) - 1
    n_models = len(MODEL_IDS)
    # squeeze=False keeps `axes` 2-D ([n_models, 2]) even for a single model.
    fig, axes = plt.subplots(n_models, 2, figsize=(12, 4.2 * n_models), squeeze=False)

    for row, model_id in enumerate(MODEL_IDS):
        snap = field_snaps[model_id][idx]
        X, Y = snap['X'], snap['Y']
        U, V = snap['U'], snap['V']
        curl = snap['curl']

        # Follow the stored grid instead of a hard-coded [-6, 6] box.
        x_lo, x_hi = float(np.min(X)), float(np.max(X))
        y_lo, y_hi = float(np.min(Y)), float(np.max(Y))

        ax0 = axes[row, 0]
        ax0.quiver(X, Y, U, V, angles='xy', scale_units='xy', scale=10)
        ax0.set_title(f'{model_id} field @ epoch {epoch}')
        ax0.set_xlim(x_lo, x_hi)
        ax0.set_ylim(y_lo, y_hi)
        ax0.set_aspect('equal')

        # 'coolwarm' is diverging: pin symmetric limits so the white midpoint
        # means zero curl and colour encodes the rotation sign. Autoscaling
        # would shift the midpoint to (min + max) / 2.
        finite_abs = np.abs(curl[np.isfinite(curl)])
        vlim = float(finite_abs.max()) if finite_abs.size and finite_abs.max() > 0 else 1.0

        ax1 = axes[row, 1]
        im = ax1.imshow(
            curl,
            extent=[x_lo, x_hi, y_lo, y_hi],
            origin='lower',
            cmap='coolwarm',
            vmin=-vlim,
            vmax=vlim,
        )
        ax1.set_title(f'{model_id} curl @ epoch {epoch}')
        ax1.set_aspect('equal')
        fig.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()


# Draw the vector-field / curl figure for each selected epoch.
for shown_epoch in epochs_to_show:
    plot_field_epoch(shown_epoch)


## 빠른 사용 가이드
- 속도 위주 비교: `STEPS_PER_EPOCH`를 줄이고 `EPOCHS`를 5~8로 둡니다.
- 품질 위주 비교: `NUM_EVAL_FAKE`와 `SAMPLE_NFE`를 늘립니다.
- GPU 메모리가 부족하면 `BATCH_SIZE`와 `NUM_EVAL_FAKE`를 줄입니다.
- 특정 epoch만 보고 싶으면 `EPOCHS_TO_SHOW = [1, 3, 5, EPOCHS]`처럼 지정합니다.
