# Imports

In [1]:
import os
import glob
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import math
from functools import partial
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
from functools import partial
import datetime
import json
from PIL import Image
from attr import dataclass

import torch
from torch import nn, Tensor, optim
import torch.nn.functional as F
from torchvision import datasets
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision.transforms import v2
from torchinfo import summary

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping, TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

import pynop
%matplotlib inline


# Model creation

In [2]:
# Factory for the network's building block: pynop.ITBlock with
# compute_ortho_loss switched on and `sampling` fixed to 64 * 64 (= 4096).
it_block = partial(pynop.ITBlock, compute_ortho_loss=True, sampling=64 * 64)

# Two input and two output channels; five hidden stages of 48 channels each.
# NOTE(review): the exact meaning of `modes`, `mlp_*` and the positional-encoding
# flags is defined by pynop.LITNet — confirm against its documentation.
model = pynop.LITNet(
    in_channels=2,
    out_channels=2,
    modes=16,
    hidden_channels=[48] * 5,
    block=it_block,
    mlp_layers=2,
    mlp_dim=64,
    activation=nn.GELU,
    norm=pynop.LayerNorm2d,
    fixed_pos_encoding=True,
    trainable_pos_encoding=True,
    trainable_pos_encoding_dims=8,
)
# Move the network to the GPU before it is summarised and trained.
model = model.to("cuda")

In [3]:
# Print a torchinfo layer-by-layer summary for a dummy input with batch size 1,
# 2 channels (matching in_channels above) and a 128 x 128 spatial grid.
summary(model, input_size=(1,2,128,128))

Layer (type:depth-idx)                        Output Shape              Param #
LITNet                                        [1, 2, 128, 128]          2,048
├─CartesianEmbedding: 1-1                     [1, 4, 128, 128]          --
├─ITDecoder: 1-2                              [1, 8, 128, 128]          16,384
│    └─MLPBlock: 2-1                          [16384, 256]              --
│    │    └─Sequential: 3-1                   [16384, 512]              37,888
│    └─Conv2d: 2-2                            [1, 8, 128, 128]          72
│    └─LayerNorm2d: 2-3                       [1, 8, 128, 128]          16
│    └─GELU: 2-4                              [1, 8, 128, 128]          --
├─Conv2d: 1-3                                 [1, 48, 128, 128]         624
├─ModuleList: 1-4                             --                        --
│    └─ITBlock: 2-5                           [1, 48, 128, 128]         589,824
│    │    └─MLPBlock: 3-2                     [16384, 256]              40,960

# Data loading, creation of the train/val sets and dataloaders

In [4]:
# datapath = Path("F:/Projets/2D_diff-react_NA_NA.h5")
datapath = Path("/media/jlux/SSD2/pdebench/2d_reaction_diffusion/133017.hdf5")

# Options common to both splits. Only `split_type` differs between the two
# datasets, so the ratio and seed are guaranteed to be the same for train and val.
shared_dataset_kwargs = dict(T_unroll=10, step=10, load_in_ram=True, split_ratio=0.8, seed=42)

train_set = pynop.UnrolledH5Dataset(datapath, split_type="train", **shared_dataset_kwargs)
val_set = pynop.UnrolledH5Dataset(datapath, split_type="val", **shared_dataset_kwargs)

Scanning train: 100%|██████████| 800/800 [00:04<00:00, 197.18it/s]


Split train: 800 simulations
Total unrolled windows: 8000
Status: All data loaded in RAM.


Scanning val: 100%|██████████| 200/200 [00:00<00:00, 208.01it/s]

Split val: 200 simulations
Total unrolled windows: 2000
Status: All data loaded in RAM.





In [5]:
batch_size = 6

# Loader options shared by the training and validation loaders.
# persistent_workers keeps the 10 worker processes alive between epochs instead
# of re-creating them each time, and pin_memory places batches in page-locked
# host memory. Previously only the validation loader set these two options;
# the training loader now uses the same settings.
loader_kwargs = dict(batch_size=batch_size, num_workers=10, persistent_workers=True, pin_memory=True)

# Only the training loader shuffles; validation keeps the dataset order.
train_dataloader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(val_set, shuffle=False, **loader_kwargs)

# Preparing the training

In [9]:
# baselogdir = Path("F:/Projets/NLIT")
baselogdir = Path("/media/jlux/SSD2/ReactionDiffusion/NLIT")

# Each run writes into its own directory named after the start time
# (YYYYMMDD-HHMMSS), so successive runs never overwrite each other.
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = baselogdir / now
callbacks = []

# TensorBoard logs go to <logdir>/tb_logs, CSV logs directly under <logdir>.
loggers = [TensorBoardLogger(logdir / "tb_logs", name="NLIT_RD"), CSVLogger(logdir, name="NLIT_RD")]

# Checkpoints are written into `logdir` through the documented `dirpath`
# argument, with `filename` holding only the file stem. The previous code passed
# an absolute path as `filename`, which only worked because os.path.join drops
# the directory part when given an absolute component. The resulting file
# locations are unchanged (<logdir>/best_val_loss*.ckpt, <logdir>/best_train_loss*.ckpt).

# Keep the two checkpoints with the lowest "val_loss".
callbacks.append(
    ModelCheckpoint(
        dirpath=logdir, monitor="val_loss", filename="best_val_loss", mode="min", save_top_k=2, save_last=False
    )
)

# Keep the two checkpoints with the lowest "loss".
# NOTE(review): presumably "loss" is the training loss logged by pynop.Model — confirm.
callbacks.append(
    ModelCheckpoint(
        dirpath=logdir,
        monitor="loss",
        filename="best_train_loss",
        mode="min",
        save_top_k=2,
        save_last=False,
    )
)

# Log the learning rate once per epoch.
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
# callbacks.append(EarlyStopping(monitor="val_loss", min_delta=0.00, patience=10, verbose=False, mode="min"))


# Progress bar that prints float metrics in scientific notation
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar whose float metrics are rendered as ``%.3e`` strings."""

    def get_metrics(self, trainer, model):
        """Return the parent's metrics dict with every float value formatted.

        Values that are not Python ``float`` instances are passed through
        unchanged; float values are replaced by their ``.3e`` string form.
        """
        metrics = super().get_metrics(trainer, model)
        formatted = {}
        for name, value in metrics.items():
            # Apply the scientific format to float entries only.
            formatted[name] = f"{value:.3e}" if isinstance(value, float) else value
        return formatted


# callbacks.append(TQDMProgressBar(leave=True))
callbacks.append(CustomProgressBar(leave=True))

# Global optimisation hyper-parameters.
max_epochs = 50
lr = 1e-4

# Training schedule handed to pynop.Model.
# NOTE(review): the semantics of the autoregressive / detach / noise settings are
# defined by pynop.TrainingSchedule — confirm against its documentation.
train_config = pynop.TrainingSchedule(
    start_autoregressive=0,
    final_autoregressive=50,
    min_autoregressive_steps=5,
    max_autoregressive_steps=9,
    detach_grad_steps=4,
    loss_fn=nn.MSELoss(),
    force_orthogonality=True,
    ortho_weight=1,
    noise_level=1e-4,
)

# One scheduler entry: ReduceLROnPlateau with factor 0.5 and patience 10,
# monitoring the metric named "MSE" and applied once per epoch.
scheduler_config = [
    dict(
        scheduler=ReduceLROnPlateau,
        mode="min",
        patience=10,
        factor=0.5,
        monitor="MSE",
        interval="epoch",
        frequency=1,
        cooldown=0,
    ),
]

# scheduler_config = [
#     {
#         "scheduler": CosineAnnealingLR,
#         "T_max": max_epochs * int(len(train_set) / batch_size),
#         "eta_min":lr*1e-2,
#         "interval": "step",
#         "frequency": 1,
#     },
# ]


# Optimise only the parameters that require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=lr, weight_decay=1e-4)

# From scratch
lightning_model = pynop.Model(
    model=model,
    train_config=train_config,
    optimizer=optimizer,
    scheduler_config=scheduler_config,
)
# checkpoint = torch.load(Path("/media/jlux/SSD2/ReactionDiffusion/NLIT/20260111-145353/best_train_loss.ckpt"))
# lightning_model.load_state_dict(checkpoint['state_dict'])

# Float32 matmul precision is set to "high" rather than "highest".
# torch.set_float32_matmul_precision('highest')
torch.set_float32_matmul_precision("high")

# Each epoch runs on 34% of the training batches (limit_train_batches=0.34);
# one validation batch is run as a sanity check before training starts.
trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=max_epochs,
    limit_train_batches=0.34,
    num_sanity_val_steps=1,
    logger=loggers,
    callbacks=callbacks,
)
# gradient_clip_val=1,
# gradient_clip_algorithm="norm")

print("Log dir:", logdir)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Log dir: /media/jlux/SSD2/ReactionDiffusion/NLIT/20260113-093234


# Training the model

In [None]:
# Run the training loop: optimise on train_dataloader and validate on
# valid_dataloader, using the callbacks and loggers configured on `trainer`.
trainer.fit(lightning_model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type       | Params | Mode 
------------------------------------------------------
0 | model          | LITNet     | 3.2 M  | train
1 | loss_fn        | MSELoss    | 0      | train
2 | train_loss_avg | MeanMetric | 0      | train
------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.940    Total estimated model params size (MB)
90        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]