# Imports

In [None]:
import os
import glob
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import math
from functools import partial
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
from functools import partial
import datetime
import json
from PIL import Image
from attr import dataclass

import torch
from torch import nn, Tensor, optim
import torch.nn.functional as F
from torchvision import datasets
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, StepLR, SequentialLR, LinearLR
from torch.utils.data import  DataLoader
from torchvision.transforms import v2
from torchinfo import summary

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping, TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import pynop

%matplotlib inline


# Model creation


In [2]:
# block = partial(pynop.RITBlock, compute_ortho_loss=False, sampling=int(64 * 64))

# model = pynop.SharedLITNet(
#     in_channels=2,
#     out_channels=2,
#     modes=16,
#     hidden_channels=[128, 128, 128, 128, 128],
#     mlp_layers=2,
#     mlp_dim=256,
#     nonlinear=True,
#     activation=nn.GELU,
#     mlp_act=pynop.Sine,
#     norm=pynop.AdaptiveLayerNorm,
#     fixed_pos_encoding=True,
#     compute_ortho_loss=False,
#     ortho_loss_sampling=64 * 64,
#     dim=3,
# ).to("cuda")

# model = pynop.TLNO(
#         in_channels= 2,
#         out_channels= 2,
#         d_model= 256,
#         modes =  256,
#         num_heads=  4,
#         num_blocks = 4,
#         mlp_layers = 2,
#         mlp_hidden_dim = 256,
#         dropout=0.05,
#         transformer_mlp_factor= 2,
#         activation=nn.GELU,
#         trunk_activation=nn.GELU,
#         nonlinear=True,
#         dim=3,
#     ).to("cuda")

# model = pynop.ITLNO(
#     in_channels=2,
#     out_channels=2,
#     modes=16,
#     num_blocks=4,
#     hidden_channels=256,
#     num_heads=4,
#     linear_kernel=False,
#     mlp_layers=2,
#     mlp_dim=128,
#     mlp_activation=nn.GELU,# pynop.Sine,
#     compute_ortho_loss=False,
#     orth_loss_sampling=2048,
#     dim=3,
# ).to("cuda")


# model = pynop.LITNet(
#     in_channels=2,
#     out_channels=2,
#     modes=16,
#     block=block,  # the partial(pynop.RITBlock, ...) at the top of this cell; bare RITBlock is undefined
#     hidden_channels=[256, 256, 256, 256],
#     nonlinear=True,
#     mlp_layers=2,
#     mlp_dim=128,
#     mlp_act=nn.GELU,# pynop.Sine,
#     norm = pynop.AdaptiveLayerNorm,
#     fixed_pos_encoding=True,
#     dim=3,
# ).to("cuda")


# Active model: a Transolver. The commented-out constructors above are alternative
# architectures; swap one in (and comment this one out) to compare.
model = pynop.Transolver(
    # I/O: one channel per physical variable, same variables in and out
    in_ch=2,
    out_ch=2,
    # width / depth of the network
    n_hidden=256,
    n_layers=4,
    n_head=4,
    mlp_ratio=1,
    slice_num=64,
    activation=nn.GELU,
    # regularisation
    dropout=0,
    # misc.
    dim=3,
    cond_dim=None,
)
model = model.to("cuda")  # nn.Module.to moves the parameters and returns the same module

In [3]:
# summary(model, input_size=(1,2,128,128))
# Push a random dummy batch through the model so torchinfo can report per-layer
# output shapes and parameter counts. The field tensor has shape (2, 2, 128, 128)
# and the time tensor has shape (2, 1).
data = torch.rand(2, 2, 128, 128).to("cuda")
time = torch.rand(2)[:, None].float().to("cuda")
summary(model, input_data=(data, time), depth=4)

Layer (type:depth-idx)                   Output Shape              Param #
Transolver                               [2, 2, 128, 128]          256
├─MLPBlock: 1-1                          [2, 16384, 256]           --
│    └─Sequential: 2-1                   [2, 16384, 256]           --
│    │    └─Linear: 3-1                  [2, 16384, 512]           3,072
│    │    └─GELU: 3-2                    [2, 16384, 512]           --
│    │    └─Linear: 3-3                  [2, 16384, 512]           262,656
│    │    └─GELU: 3-4                    [2, 16384, 512]           --
│    │    └─Linear: 3-5                  [2, 16384, 256]           131,328
├─ModuleList: 1-2                        --                        --
│    └─TransolverBlock: 2-2              [2, 16384, 256]           4
│    │    └─LayerNorm: 3-6               [2, 16384, 256]           512
│    │    └─Linear: 3-7                  [2, 16384, 256]           65,792
│    │    └─Sequential: 3-8              [2, 4, 16384, 1]          

# Data loading, creation of the train/val sets and dataloaders

In [4]:
# datapath = Path("F:/Projets/2D_diff-react_NA_NA.h5")
datapath = Path("/media/jlux/SSD2/pdebench/2d_reaction_diffusion/133017.hdf5")

# Options shared by both splits. The train/val partition is selected by `split_type`;
# presumably it is derived from split_ratio / n_samples / seed, so both datasets must
# receive identical values. Keeping them in one dict prevents the two calls from drifting
# apart, which would make the splits overlap or leave simulations out.
dataset_kwargs = dict(
    T_unroll=10,
    step=10,
    load_in_ram=True,
    split_ratio=0.9,
    n_samples=100,
    seed=42,
)
train_set = pynop.PDEBenchDataSet(datapath, split_type="train", **dataset_kwargs)
val_set = pynop.PDEBenchDataSet(datapath, split_type="val", **dataset_kwargs)

Scanning train: 100%|██████████| 90/90 [00:00<00:00, 200.55it/s]


Split train: 90 simulations
Total unrolled windows: 900
Status: All data loaded in RAM.


Scanning val: 100%|██████████| 10/10 [00:00<00:00, 207.93it/s]

Split val: 10 simulations
Total unrolled windows: 100
Status: All data loaded in RAM.





In [5]:
batch_size = 4

# Loader options shared by train and validation.
# - persistent_workers=True keeps the 10 worker processes alive between epochs instead of
#   re-spawning them (and re-copying the in-RAM dataset) at every epoch.
# - pin_memory=True speeds up the host -> GPU copy of each batch.
# Previously only the validation loader had these two options.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=10,
    persistent_workers=True,
    pin_memory=True,
)
train_dataloader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(val_set, shuffle=False, **loader_kwargs)

# Preparing the training
**Notes:**

- It seems better to use `pynop.nL2Loss` or `pynop.nMSELoss` as the loss, especially when the variables have different orders of magnitude: these classes compute the normalized MSE or RMSE *per channel* (i.e. for each physical variable).

- The RMSE metric logged during training is averaged over all spatial, channel and batch dimensions. It is reported here because it is the metric most often given in papers.

In [None]:
# Training schedule passed to the Lightning wrapper; options grouped by name.
train_config = pynop.TrainingSchedule(
    # loss
    loss_fn=pynop.nMSELoss(),  # alternative: torch.nn.MSELoss()
    # autoregressive unrolling
    start_autoregressive=0,
    final_autoregressive=50,
    min_autoregressive_steps=4,
    max_autoregressive_steps=8,
    detach_grad_steps=4,
    # orthogonality regularisation (switched off)
    orthogonality_loss=False,
    ortho_weight=0.0,
    ortho_mode="model",
    # other options
    noise_level=5e-4,
    time_normalization=100,
    residual=True,  # also adds the "_res" suffix to the log folder below
)



# Log directory layout: <runs root>/<ModelClass>[_res]/<run id>/
model_name = model.__class__.__name__
# Bare class name of the loss, e.g. "nMSELoss" (use __name__ rather than parsing the
# "<class '...'>" repr, which leaves a trailing "'>"). Not currently part of the folder name.
loss_name = type(train_config.loss_fn).__name__
folder = Path(f"{model_name}_res" if train_config.residual else model_name)  # _{loss_name}
# baselogdir = Path("F:/Projets/NLIT")
baselogdir = Path("/media/jlux/SSD2/ReactionDiffusion/runs") / folder

# Run id. With a fixed id, logs and checkpoints go into that existing run directory
# (Lightning then warns that the checkpoint directory is not empty). Set it to None to
# start a fresh, timestamped run.
RESUME_RUN_ID = "20260123-140431"
now = RESUME_RUN_ID or datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = baselogdir / now

callbacks = []
loggers = [TensorBoardLogger(logdir / "tb_logs"), CSVLogger(logdir)]

# Best models (2 kept) according to the validation loss ...
callbacks.append(
    ModelCheckpoint(
        monitor="val_loss", filename=os.path.join(logdir, "best_val_loss"), mode="min", save_top_k=2, save_last=False
    )
)
# ... and according to the training loss.
callbacks.append(
    ModelCheckpoint(
        monitor="loss",
        filename=os.path.join(logdir, "best_train_loss"),
        mode="min",
        save_top_k=2,
        save_last=False,
    )
)
# No `monitor`: keeps only the most recent checkpoint, refreshed every epoch.
# `mode` only matters when a metric is monitored, so "max" has no effect here.
callbacks.append(
    ModelCheckpoint(
        dirpath=logdir,
        filename="last",
        every_n_epochs=1,
        save_top_k=1,
        mode="max",
    )
)

callbacks.append(LearningRateMonitor(logging_interval="epoch"))
# callbacks.append(EarlyStopping(monitor="val_loss", min_delta=0.00, patience=10, verbose=False, mode="min"))


# print values in scientific format
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that displays float metrics in scientific notation (3 decimals)."""

    def get_metrics(self, trainer, model):
        raw_metrics = super().get_metrics(trainer, model)
        formatted = {}
        for name, value in raw_metrics.items():
            # Only floats are reformatted; any other value is shown unchanged.
            formatted[name] = f"{value:.3e}" if isinstance(value, float) else value
        return formatted


# callbacks.append(TQDMProgressBar(leave=True))
callbacks += [CustomProgressBar(leave=True)]  # in-place extend of the callbacks list

# Optimisation hyper-parameters
max_epochs = 200
lr = 4e-4

# scheduler_config = [
#     {
#         "scheduler": ReduceLROnPlateau,
#         "mode": "min",
#         "patience": 15,
#         "factor": 0.5,
#         "monitor": "MSE",
#         "interval": "epoch",
#         "frequency": 1,
#         "cooldown": 5,
#     },
# ]

# Active schedule: multiply the learning rate by 0.5 every 25 epochs.
scheduler_config = [
    dict(scheduler=StepLR, step_size=25, gamma=0.5, interval="epoch"),
]

# scheduler_config = [
#     {
#         "scheduler": CosineAnnealingLR,
#         "T_max": max_epochs * int(len(train_set) / batch_size),
#         "eta_min":lr*1e-2,
#         "interval": "step",
#         "frequency": 1,
#     },
# ]

# AdamW restricted to the parameters that require gradients.
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=5e-5)

# Wrap network, training schedule, optimizer and LR schedule in the pynop Lightning module.
# lightning_model = pynop.ITModel(
#     model=model, train_config=train_config, optimizer=optimizer, scheduler_config=scheduler_config
# )
lightning_model = pynop.LNOModel(
    model=model, train_config=train_config, optimizer=optimizer, scheduler_config=scheduler_config
)

# Warm start from the best-training-loss checkpoint of the current run directory (the
# `best_train_loss` ModelCheckpoint callback targets this path). Deriving the path from
# `logdir` keeps it in sync with the model / residual / run-id settings instead of a
# hard-coded copy. Only the weights are restored: optimizer state, LR schedule and epoch
# counter start afresh (uncomment the optimizer line to also restore its state).
# map_location="cpu" avoids materialising the checkpoint tensors on the GPU;
# load_state_dict then copies the weights into the model's existing CUDA parameters.
init_ckpt_path = logdir / "best_train_loss.ckpt"
checkpoint = torch.load(init_ckpt_path, map_location="cpu")
lightning_model.load_state_dict(checkpoint["state_dict"])
# optimizer.load_state_dict(checkpoint['optimizer_states'][0])

# "highest" (PyTorch's default) keeps full FP32 matmuls; "high" allows TF32 on Tensor Cores.
torch.set_float32_matmul_precision("highest")
# torch.set_float32_matmul_precision("high")

trainer = pl.Trainer(
    max_epochs=max_epochs,
    # limit_train_batches=0.5,
    callbacks=callbacks,
    accelerator="gpu",
    logger=loggers,
    num_sanity_val_steps=0,
    log_every_n_steps=50,
    # gradient_clip_val=1,
    # gradient_clip_algorithm="norm",
)

print("Log dir:", logdir)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Log dir: /media/jlux/SSD2/ReactionDiffusion/runs/Transolver_res/20260123-140431


# Tensorboard

In [None]:
# Load the TensorBoard notebook extension. Uncomment the next line to open the dashboard
# on every run stored under `baselogdir` (all runs of the current model/residual setting).
%load_ext tensorboard
# %tensorboard --logdir $baselogdir


# Training the model

In [None]:
# Start optimisation. The weights were warm-started from a checkpoint in the setup cell,
# but no `ckpt_path` is given, so optimizer / scheduler state and the epoch counter are
# not restored. Pass ckpt_path=... to fully resume a run instead.
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/jlux/miniforge3/envs/pytorch/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /media/jlux/SSD2/ReactionDiffusion/runs/Transolver_res/20260123-140431 exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type       | Params | Mode 
------------------------------------------------------
0 | model          | Transolver | 1.9 M  | train
1 | loss_fn        | nMSELoss   | 0      | train
2 | train_loss_avg | MeanMetric | 0      | train
------------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.

Training: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
# torch.save(model.state_dict(), baselogdir / Path("last.pth"))

# END