In [27]:
import sys
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

# -------------------------------------------------------------------
# Detect the project root (directory containing src/)
# -------------------------------------------------------------------
# Project modules (`src.*`) are deliberately NOT imported before this point:
# if the notebook lives in a subdirectory, `src` only becomes importable once
# PROJECT_ROOT has been put on sys.path below.

# In a notebook, __file__ does NOT exist → use Path.cwd()
CWD = Path.cwd()

PROJECT_ROOT = None
for parent in [CWD, *CWD.parents]:
    if (parent / "src" / "__init__.py").exists():
        PROJECT_ROOT = parent
        break

if PROJECT_ROOT is None:
    raise RuntimeError(
        f"Could not find project root containing src/. Started search from: {CWD}"
    )

if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

print("PROJECT_ROOT =", PROJECT_ROOT)

# -------------------------------------------------------------------
# Now safe to import project modules (each name imported exactly once).
# `get_freq_dataset` comes from src.data, which is the binding that was in
# effect after the previous duplicated imports.
# -------------------------------------------------------------------

from src.config import GridSpec, PMLConfig, SolverOptions
from src.data import (
    CoordWrapper,
    DeltaTargetWrapper,
    GainWrapper,
    IdentityWrapper,
    OmegaChannelWrapper,
    OmegaCoordWrapper,
    ScaleWrapper,
    get_freq_dataset,
)
from src.ml import (
    AmpNormWrapper,
    SimpleFNO,
    StdNormWrapper,
    build_direct_map,
    compute_input_stats,
    eval_relative_metrics,
    train_model,
)
from src.utils import run_single_pair




PROJECT_ROOT = c:\Users\31624\Documents\MIT\Programming\FreqTransfer


In [24]:
# 2-D, 48x48-cell grid with unit side lengths.
grid = GridSpec(dims=2, shape=(48, 48), lengths=(1.0, 1.0))


In [28]:
# --------- 1. Helper functions for baselines / cropping ----------

def crop_interior_np(u: np.ndarray, margin: int) -> np.ndarray:
    """
    Drop `margin` cells from both ends of the last two (spatial) axes.

    u: (..., H, W) numpy array
    margin: cells removed on each side (e.g. the PML thickness).
        margin <= 0 returns `u` itself unchanged.
    returns: view of `u` with shape (..., H - 2*margin, W - 2*margin).

    Raises ValueError if `u` has fewer than 2 dims, or if the margin would
    leave no interior. Returning an empty array instead would make
    relative-error metrics computed on it silently evaluate to 0.
    """
    if margin <= 0:
        return u
    if u.ndim < 2:
        raise ValueError(
            f"expected an array with at least 2 dims (..., H, W), got shape {u.shape}"
        )
    H, W = u.shape[-2:]
    # u[..., m:-m] is empty once 2*m >= size; refuse rather than return it.
    if 2 * margin >= H or 2 * margin >= W:
        raise ValueError(
            f"margin={margin} leaves no interior for spatial shape ({H}, {W})"
        )
    return u[..., margin:-margin, margin:-margin]


def baseline_metrics_on_raw(raw_ds, kind: str, margin: int | None = None):
    """
    raw_ds: PrecomputedFreqDataset returned by get_freq_dataset
    kind: "zero" (u_pred = 0) or "src" (u_pred = u_src)
    margin: if not None, crop interior by this many cells (PML thickness).
    """
    rels = []
    N = len(raw_ds)

    for i in range(N):
        u_src, u_tgt = raw_ds[i]   # each is torch.Tensor (2, H, W)
        u_src_np = u_src.numpy()
        u_tgt_np = u_tgt.numpy()

        if kind == "zero":
            u_pred = np.zeros_like(u_tgt_np)
        elif kind == "src":
            u_pred = u_src_np
        else:
            raise ValueError("kind must be 'zero' or 'src'")

        if margin is not None:
            u_pred_c = crop_interior_np(u_pred, margin)
            u_tgt_c  = crop_interior_np(u_tgt_np, margin)
        else:
            u_pred_c = u_pred
            u_tgt_c  = u_tgt_np

        num = np.linalg.norm(u_pred_c - u_tgt_c)
        den = np.linalg.norm(u_tgt_c) + 1e-12
        rels.append(num / den)

    rels = np.array(rels, dtype=np.float64)
    return {
        "mean": float(np.mean(rels)),
        "median": float(np.median(rels)),
        "p90": float(np.quantile(rels, 0.90)),
    }


In [30]:
# `metrics_full` is produced by the "Check A" training/evaluation cell, which
# prints the same summary itself. Guard the lookup so a fresh-kernel
# "Restart & Run All" does not crash here with a NameError.
if "metrics_full" in globals():
    print("\nFNO metrics on full domain:")
    for k, v in metrics_full.items():
        print(f"  {k}: {v:.3f}")
else:
    print("metrics_full is not defined yet - run the Check A cell first.")


FNO metrics on full domain:
  rel_L2_mean: 1.221
  rel_L2_median: 1.178
  rel_L2_p90: 1.431
  mag_RMSE: 0.000
  phase_RMSE: 1.811


In [29]:
# --------- 2. One concrete check for (ω_src, ω_tgt) = (2, 4) ----------
# NOTE(review): the FNO is trained AND evaluated on the same `freq_ds` (no
# held-out split), so `metrics_full` measures fit on the training data, not
# generalisation.

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Seed the RNGs that drive weight init and DataLoader shuffling so the run
# is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

omega_src = 2.0
omega_tgt = 4.0

grid = GridSpec(dims=2, shape=(48, 48), lengths=(1.0, 1.0))
pml_cfg = PMLConfig(thickness=16, m=2, sigma_max=10.0)

N_samples  = 200
epochs     = 50
batch_size = 16
lr         = 1e-3

print(f"\n=== Check A: ω_src = {omega_src}, ω_tgt = {omega_tgt}, PML T = {pml_cfg.thickness} ===")

# 1) Load raw dataset (u_src, u_tgt) for baselines
raw_ds = get_freq_dataset(
    grid=grid,
    pml=pml_cfg,
    omega_src=omega_src,
    omega_tgt=omega_tgt,
    N_samples=N_samples,
    omega_to_k=lambda om: om,   # omega_to_k is the identity here
    overwrite=False,
)

# 2) Baselines on full domain
base0_full   = baseline_metrics_on_raw(raw_ds, "zero", margin=None)
basesrc_full = baseline_metrics_on_raw(raw_ds, "src",  margin=None)

# 3) Baselines on interior (crop out PML with margin = pml_cfg.thickness)
margin = pml_cfg.thickness
base0_int   = baseline_metrics_on_raw(raw_ds, "zero", margin=margin)
basesrc_int = baseline_metrics_on_raw(raw_ds, "src",  margin=margin)

print("\nBaseline metrics:")
for label, stats in (
    ("baseline0_full", base0_full),
    ("baseline_src_full", basesrc_full),
    ("baseline0_interior", base0_int),
    ("baseline_src_interior", basesrc_int),
):
    print(f"  {label}: mean={stats['mean']:.3f}, "
          f"median={stats['median']:.3f}, p90={stats['p90']:.3f}")

# 4) Wrap for FNO: add ω-channels + coords
freq_ds = OmegaChannelWrapper(raw_ds, omega_src=omega_src, omega_tgt=omega_tgt)
freq_ds = CoordWrapper(freq_ds, grid=grid, normalise=True)

x0, y0 = freq_ds[0]
in_ch = x0.shape[0]
print("\nWrapped sample shape:", x0.shape, y0.shape, f"(in_ch={in_ch})")

# 5) Build and train FNO on wrapped freq-transfer dataset
model = SimpleFNO(
    in_ch=in_ch,          # 6 in the recorded run
    width=48,
    modes=(12, 12),
    layers=4,
    out_ch=2,
    use_global_skip=False,
).to(device)

# Quick custom training loop (train_model is the library alternative).
train_loader = DataLoader(freq_ds, batch_size=batch_size, shuffle=True)

opt = torch.optim.Adam(model.parameters(), lr=lr)
# NOTE(review): in the recorded output the train MSE reaches ~1e-8 while
# rel_L2 stays > 1. The absolute MSE may be dominated by the target scale;
# consider normalising targets or using a relative loss — TODO.
loss_fn = torch.nn.MSELoss()

for ep in range(1, epochs + 1):
    model.train()
    tr_losses = []
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        opt.zero_grad()
        yhat = model(x)
        loss = loss_fn(yhat, y)
        loss.backward()
        opt.step()

        tr_losses.append(loss.item())

    print(f"[{ep:03d}/{epochs}] train={np.mean(tr_losses):.4e}")

# 6) Evaluate FNO on full (wrapped) dataset.
# Switch out of training mode explicitly in case the metric helper doesn't.
model.eval()
metrics_full = eval_relative_metrics(model, freq_ds, batch_size=batch_size, device=device)

print("\nFNO metrics on full domain:")
for k, v in metrics_full.items():
    print(f"  {k}: {v:.3f}")

Using device: cpu

=== Check A: ω_src = 2.0, ω_tgt = 4.0, PML T = 16 ===
[get_freq_dataset] Loading cached dataset from: C:\Users\31624\Documents\MIT\Programming\FreqTransfer\data\freq_transfer_cached\wsrc2p000_wtgt4p000_N200_grid48x48_pmlT16_m2_sig10p00
  Loaded N=200 samples, ω_src=2.0, ω_tgt=4.0

Baseline metrics:
  baseline0_full: mean=1.000, median=1.000, p90=1.000
  baseline_src_full: mean=1.454, median=1.327, p90=1.895
  baseline0_interior: mean=1.000, median=1.000, p90=1.000
  baseline_src_interior: mean=2.275, median=2.141, p90=2.793

Wrapped sample shape: torch.Size([6, 48, 48]) torch.Size([2, 48, 48]) (in_ch=6)
[001/50] train=4.2170e-03
[002/50] train=2.2259e-04
[003/50] train=4.6795e-05
[004/50] train=1.4201e-05
[005/50] train=3.8861e-06
[006/50] train=1.2069e-06
[007/50] train=3.5516e-07
[008/50] train=8.3554e-08
[009/50] train=2.8427e-08
[010/50] train=1.9278e-08
[011/50] train=1.3878e-08
[012/50] train=1.2733e-08
[013/50] train=1.2305e-08
[014/50] train=1.2307e-08
[015/5