In [1]:
# # # !pip install -r requirements.txt
# # !pip install --upgrade jax jaxlib
# # !pip install --upgrade equinox # Add this line to upgrade equinox
# # !pip install -U "datasets>=2.19.0"
# !pip install --upgrade "jax[cuda12]==0.6.0" jax-cuda12-plugin==0.6.0
# #           └─ installs jaxlib 0.6.0 automatically
# !pip install equinox

In [2]:
# import os

# target_dir = "/content/drive/MyDrive/ConDiff-main"

# os.makedirs(target_dir, exist_ok=True)

# os.chdir(target_dir)

# print("Текущая рабочая директория:", os.getcwd())


In [3]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# import optax
# import equinox as eqx
import numpy as np
# import jax.numpy as jnp
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from IPython import display
from functools import partial
# from jax.lax import scan, dot_general
from load_ConDiff import load_ConDiff
# from architectures import UNet
# from jax import config, random, grad, vmap, jit
# from jax.tree_util import tree_map, tree_flatten
import torch, time, pandas as pd
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import numpy as np
import matplotlib.pyplot as plt
from architectures.IAFNO_pt import get_IAFNO_pt
import os
import math
from itertools import product
from copy import deepcopy

CHECKPOINT_PATH = "iafno_poisson64.pth"



In [5]:
def relative_error(pred, targ):
    return ((pred - targ).view(pred.size(0), -1).norm(dim=1) /
            targ.view(targ.size(0), -1).norm(dim=1))

def save_checkpoint(model, optimizer, epoch, history, path=CHECKPOINT_PATH):
    torch.save({
        "epoch":   epoch,
        "model":   model.state_dict(),
        "optim":   optimizer.state_dict(),
        "history": history,
    }, path)

def load_checkpoint(model, optimizer=None, path=CHECKPOINT_PATH, map_location="cpu"):
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt["model"])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt["optim"])
    return ckpt["epoch"], ckpt["history"]

def huber_relative_error(pred, targ, delta=0.01):
    """
    Huber-loss по пикселям, затем нормируем на ‖target‖, как в relative_error.
    delta — порог между L1 и L2; 0.01 работает для карт в диапазоне ~[0,1].
    """
    diff = pred - targ
    abs_diff = diff.abs()
    quad = torch.clamp(abs_diff, max=delta)
    lin  = abs_diff - quad
    huber = 0.5 * quad.pow(2) / delta + lin           # SmoothL1 вручную
    per_sample = huber.view(huber.size(0), -1).sum(dim=1)
    denom = targ.view(targ.size(0), -1).norm(dim=1) + 1e-8
    return per_sample / denom


In [None]:
from tqdm.auto import tqdm
import math, torch

def train_model_pt(model_data, features, targets, spec, device, use_huber=False):
    """обучает модель и возвращает (model, history_train, history_test)"""

    model = model_data["model"]
    feats_tr, feats_ts = features
    targs_tr, targs_ts = targets

    # ---- оптимизатор -----------------------------------------------------
    opt = torch.optim.AdamW(model.parameters(),
                            lr=spec["learning_rate"],
                            weight_decay=spec["weight_decay"])

    # ---- scheduler: 5 % warm-up + косинусный спад ------------------------
    steps_per_epoch = math.ceil(feats_tr.size(0) / spec["batch_size"])
    total_steps     = spec["N_epochs"] * steps_per_epoch
    scheduler = StepLR(opt, step_size=50, gamma=0.5)

    # ---- выбор функции ошибки -------------------------------------------
    loss_fn = huber_relative_error if use_huber else relative_error

    best_test = float("inf")
    hist_train, hist_test = [], []

    global_step = 0
    epoch_bar = tqdm(range(1, spec["N_epochs"] + 1), desc="Epochs", position=0)

    for epoch in epoch_bar:
        model.train()
        idx = torch.randperm(feats_tr.size(0), device=device)

        # ----- батчи -----
        batch_bar = tqdm(range(0, len(idx), spec["batch_size"]),
                         desc=f"e{epoch}", leave=False, position=1)

        for i in batch_bar:
            b = idx[i:i+spec["batch_size"]]
            pred = model(feats_tr[b])
            loss = loss_fn(pred, targs_tr[b]).mean()

            opt.zero_grad()
            loss.backward()
            opt.step()
            global_step += 1
            batch_bar.set_postfix(train_loss=f"{loss.item():.4e}",
                                  lr=scheduler.get_last_lr()[0])

        # ----- валидация -----
        model.eval()
        with torch.no_grad():
            pred_ts   = model(feats_ts)
            test_loss = loss_fn(pred_ts, targs_ts).mean()

        hist_train.append(loss.item())
        hist_test.append(test_loss.item())
        epoch_bar.set_postfix(train=f"{loss.item():.4e}",
                              test=f"{test_loss.item():.4e}",
                              lr=scheduler.get_last_lr()[0])
        scheduler.step()

        # чекпойнт, если улучшилось
        if test_loss < best_test:
            best_test = test_loss
            save_checkpoint(model, opt, epoch,
                            {"train": hist_train, "test": hist_test})

    return model, hist_train, hist_test


In [7]:
def get_results(grid, type_of_pde="poisson", direction_to_save="data"):
    import load_ConDiff
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. датасет
    (feat_tr, targ_tr), (feat_ts, targ_ts) = load_ConDiff.get_datasets(
        direction_to_save, type_of_pde, grid
    )
    feat_tr = torch.tensor(feat_tr, dtype=torch.float32, device=device)
    targ_tr = torch.tensor(targ_tr, dtype=torch.float32, device=device)
    feat_ts = torch.tensor(feat_ts, dtype=torch.float32, device=device)
    targ_ts = torch.tensor(targ_ts, dtype=torch.float32, device=device)

    # 2. модель + оптимизация
    model_data, spec = get_IAFNO(grid, device)
    model, h_train, h_test = train_model_pt(
        model_data,
        [feat_tr, feat_ts],
        [targ_tr, targ_ts],
        spec,
        device,
        use_huber=True
    )

    # 3. итоговая метрика
    model.eval()
    with torch.no_grad():
        pred = model(feat_ts)
        err  = relative_error(pred, targ_ts)
    data = {
        "history_train": h_train,
        "history_test" : h_test,
        "test_error_mean": err.mean().item(),
        "test_error_std" : err.std().item()
    }
    return data, model, model_data, feat_ts, targ_ts


In [8]:
def plot_results(model, history, features_test, targets_test):
    """
    model         : обученная IAFNO (PyTorch)
    history       : dict  {'history_train': [...], 'history_test': [...]}
    features_test : torch.Tensor (N,1,H,W)  – rhs
    targets_test  : torch.Tensor (N,1,H,W)  – solution
    """
    model.eval()
    with torch.no_grad():
        pred = model(features_test).cpu().numpy()[:, 0]   # (N,H,W)

    targ = targets_test.cpu().numpy()[:, 0]

    display.clear_output(wait=True)
    plt.rcParams["font.family"] = "serif"

    fig, ax = plt.subplots(1, 3, figsize=(15, 4))

    # кривая потерь
    ax[0].set_title("Loss")
    ax[0].set_yscale("log")
    ax[0].plot(history["history_train"], "-",  color="red",   label="train")
    ax[0].plot(history["history_test"],  "-.", color="green", label="test")
    ax[0].legend();  ax[0].grid(ls="-.")
    ax[0].spines[["top", "right"]].set_visible(False)

    ax[1].contourf(pred[0]);  ax[1].set_title("Prediction")
    ax[2].contourf(targ[0]);  ax[2].set_title("Target")

    plt.tight_layout();  plt.show()


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
grid   = 64
CHECKPOINT_PATH = "iafno_poisson64.pth"

model_data, spec = get_IAFNO_pt(grid, device)
model = model_data["model"]

# ---------- датасет -------------------------------------------------------
(rhs_tr, x_tr), (rhs_ts, x_ts) = load_ConDiff("data", "poisson", grid)

rhs_tr = rhs_tr.reshape(-1, 1, grid, grid)
x_tr   = x_tr.reshape(-1, 1, grid, grid)
rhs_ts = rhs_ts.reshape(-1, 1, grid, grid)
x_ts   = x_ts.reshape(-1, 1, grid, grid)

feat_tr = torch.tensor(rhs_tr, dtype=torch.float32, device=device)
targ_tr = torch.tensor(x_tr,   dtype=torch.float32, device=device)
feat_ts = torch.tensor(rhs_ts, dtype=torch.float32, device=device)
targ_ts = torch.tensor(x_ts,   dtype=torch.float32, device=device)

# 1) Добавляем двумерные координаты (x,y) в каналы:
xs = torch.linspace(0, 1, grid, device=device)
ys = torch.linspace(0, 1, grid, device=device)
X, Y = torch.meshgrid(xs, ys, indexing="xy")
# coords.shape = (2, grid, grid) → (B,2,grid,grid)
coords = torch.stack([X, Y], dim=0) \
            .unsqueeze(0) \
            .repeat(feat_tr.size(0), 1, 1, 1)
feat_tr = torch.cat([feat_tr, coords], dim=1)  # (B,3,grid,grid)
feat_ts = torch.cat(
    [feat_ts, coords[: feat_ts.size(0)]], dim=1
)

# 2) Z-score нормировка forcing и target:
f_mean, f_std = feat_tr.mean(), feat_tr.std()
feat_tr = (feat_tr - f_mean) / f_std
feat_ts = (feat_ts - f_mean) / f_std

u_mean, u_std = targ_tr.mean(), targ_tr.std()
targ_tr = (targ_tr - u_mean) / u_std
targ_ts = (targ_ts - u_mean) / u_std


# ---------- обучение или загрузка ----------------------------------------
if os.path.isfile(CHECKPOINT_PATH):
    _, hist = load_checkpoint(model, path=CHECKPOINT_PATH, map_location=device)
    history = hist
    print("Чекпойнт найден — обучение пропущено.")
else:
    model, h_tr, h_ts = train_model_pt(
        model_data,
        [feat_tr, feat_ts],
        [targ_tr, targ_ts],
        spec,
        device
    )
    history = {"history_train": h_tr, "history_test": h_ts}
    print("Обучение завершено.")

# ---------- оценка и график ----------------------------------------------
model.eval()
with torch.no_grad():
    pred = model(feat_ts)
    err  = relative_error(pred, targ_ts)

print(f"IAFNO  test error: {err.mean():.3f} ± {err.std():.3f}")

plot_results(model, history, feat_ts, targ_ts)


Epochs:   0%|          | 0/400 [00:00<?, ?it/s]

e1:   0%|          | 0/63 [00:00<?, ?it/s]

e2:   0%|          | 0/63 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
print(f"IAFNO  test error: {err.mean():.3f} ± {err.std():.3f}")


IAFNO  test error: 1.283 ± 0.306


In [None]:

search_space = {
    "width" : [64, 96],
    "depth" : [8, 10],
    "modes" : [grid // 4],
    "n_imp" : [2, 3],
    "lr"    : [3e-4, 1e-4, 1e-5],
    "wd"    : [1e-6],
    "batch" : [16]
}
configs = [dict(zip(search_space, v)) for v in product(*search_space.values())]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
grid   = 64
(rhs_tr, x_tr), (rhs_ts, x_ts) = load_ConDiff("data","poisson",grid)

tr = torch.tensor(rhs_tr.reshape(-1,1,grid,grid), dtype=torch.float32, device=device)
yr = torch.tensor(x_tr.reshape (-1,1,grid,grid), dtype=torch.float32, device=device)
ts = torch.tensor(rhs_ts.reshape(-1,1,grid,grid), dtype=torch.float32, device=device)
ys = torch.tensor(x_ts.reshape (-1,1,grid,grid), dtype=torch.float32, device=device)

results = []

for cfg_id, hp in enumerate(configs, 1):
    print(f"\n>>> run {cfg_id}/{len(configs)}  {hp}")

    model_data, spec = get_IAFNO_pt(grid, device, **hp)
    t0 = time.time()
    model, h_tr, h_ts = train_model_pt(
        model_data,
        [tr, ts],
        [yr, ys],
        spec,
        device,
        use_huber=True
    )
    dt = time.time() - t0

    model.eval()
    with torch.no_grad():
        err = relative_error(model(ts), ys)
    mean_err = err.mean().item()
    std_err  = err.std().item()

    row = deepcopy(hp)
    row.update({"mean": mean_err, "std": std_err, "time(min)": dt/60})
    results.append(row)

    print(f"err={mean_err:.3f} ± {std_err:.3f}  |  time {dt/60:.1f} min")

# -- сводная таблица и лучшая конфигурация --
df = pd.DataFrame(results).sort_values("mean")
display(df.head())
best = df.iloc[0].to_dict()
print("\nBEST:", best)
