# Imports

In [2]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import math
from functools import partial
from typing import Optional, Any, Dict, List
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
from functools import partial
import datetime
import json

import torch
from torch import nn, Tensor, optim
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import Dataset
from torchvision.transforms import v2
from torchinfo import summary

from PIL import Image
import torchview

import pynop
%matplotlib inline


In [8]:
# Scratch check: moving the last axis to the front turns shape (3, 5, 4) into (4, 3, 5).
a = np.random.random(size=(3, 5, 4))
np.moveaxis(a, source=-1, destination=0).shape

(4, 3, 5)

# FNO model
Here is an example of a model using FNO and U-FNO blocks as in [1] or [2].  
In these works, the number of channels is quite small (36), as they do not use tucker factorization. Likewise, the number of fourier modes is only 10.  
Therefore, this model is quite small: it has only 1,175,241 parameters, thanks to the Tucker factorization.

These values could be increased, for example to 64 or 128 channels and 16 or even 32 Fourier modes. Specifying a spectral compression ratio of 2 can also drastically reduce the number of parameters, but the effect of the compression must be tested.

It is possible to add trainable positional encodings to the input data (set `trainable_pos_encoding=True` in the FNO class). You have to define the number of channels of this embedding (`trainable_pos_encoding_dims`) and its spectral representation (`trainable_pos_encoding_modes`).  
Constant positional encodings are also possible (`fixed_pos_encoding=True`). This adds the normalized x and y coordinates to the input (+ 2 channels).

[1] G. Wen, Z. Li, K. Azizzadenesheli, A. Anandkumar, and S. M. Benson. U-FNO—an enhanced
fourier neural operator-based deep-learning model for multiphase flow. Advances in Water
Resources, 163:104180, 2022  
[2] https://doi.org/10.1016/j.ress.2024.110392

In [3]:
# Small FNO / U-FNO model: 36 hidden channels and 10 Fourier modes per spatial
# dimension, with three FNO blocks followed by three U-FNO blocks.
# NOTE(review): hidden_channels has six entries and blocks has six entries —
# presumably one channel width per block; confirm against pynop.FNO.
# NOTE(review): spectral_compression_factor=(1, 1, 1) presumably means no
# spectral compression — confirm against pynop.FNO.
FNOmodel = pynop.FNO(
    in_channels=1,
    out_channels=1,
    modes=(10, 10),
    hidden_channels=(36, 36, 36, 36, 36, 36),
    blocks=["FNO","FNO","FNO","UFNO","UFNO","UFNO"],
    spectral_compression_factor=(1, 1, 1),
    trainable_pos_encoding=True,
    fixed_pos_encoding=True,
    trainable_pos_encoding_modes=(16, 16),  # Only useful if trainable_pos_encoding is True: modes of the embedding
    trainable_pos_encoding_dims=8,  # Only useful if trainable_pos_encoding is True: number of channels of the embedding
)

In [4]:
# Layer-by-layer torchinfo report for the FNO model, using an input of size (2, 1, 128, 128).
summary(
    FNOmodel,
    input_size=(2, 1, 128, 128),
    depth=10,
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
)

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

# CoDANO
This is a CoDANO model [3] using FNO blocks to project the input variables to a latent space, where they are processed as tokens by a multi-head attention module.
In this implementation, each variable is represented by a channel in the input tensor (e.g. the velocity field has 2 channels, the pressure field has 1 channel, etc.).
After adding an optional static field and the positional embeddings, the input tensor is lifted to a higher-dimensional space using a linear layer (to hidden_variable_codimension channels).
In each CoDANO layer, the input tensor is projected to the K, V and Q matrices using FNO blocks. Each variable is processed by the same FNO block (shared block) to obtain the corresponding token.
The Q, V, K tokens are then processed by a multi-head attention module, and the output is projected back to the original input space using a linear layer.

Note that if a static field is defined when the class is instantiated, it must be provided when calling the model. You have to add this in the training loop!  

This architecture can rapidly lead to very large models as the number of heads increases, even when using a small latent dimension. 

[3] https://papers.nips.cc/paper_files/paper/2024/file/bc75fa9843a7905bbed9d83895a88f7f-Paper-Conference.pdf

In [5]:
# CoDANO model for two input variables ("ux", "uy") plus one static channel.
# NOTE(review): the first positional argument is presumably the list of variable
# names, one per input channel — confirm against pynop.CoDANO.
CoDANO_model = pynop.CoDANO(
    ["ux", "uy"],
    n_heads=2,  # number of attention heads in each layer
    n_layers=4, # number of CoDANO layers (multi-head attention + projection)
    modes=(24, 24), # the number of modes in the FNO layers
    static_channel_dim=1, # the number of static channels (e.g. the diffusivity field)
    hidden_lifting_channels=64, # number of channels in the intermediate layers of the lifting module (a series of 1x1 conv)
    hidden_variable_codimension=32, # the latent dimension of each variable (after lifting)
    positional_encoding_dim=8,  # the dimension of the positional encoding (before lifting)
    positional_encoding_modes = (16, 16), # the number of modes in the positional encoding
    per_channel_attention=False,    # if True, the attention is computed per channel (it decreases the number of parameters)
    spectral_compression_factor=(2, 2, 2), # the compression factor for the spectral compression (in the FNO layers).
    activation=partial(nn.GELU, approximate="tanh"), # the activation function (must be a class, not an instance)
    norm=partial(nn.InstanceNorm2d, affine=True), # the normalization layer (must be a class, not an instance)
    ndim=2, # the number of dimensions (only 2D)
)
# NOTE(review): leftover scratch call below; `model` is not defined in this notebook.
# model(torch.rand((2, 2, 128, 128))).shape

In [6]:
# The static field must be passed to the model's forward method as well, so two
# input sizes are given: [input_tensor, static_field].
summary(
    CoDANO_model,
    input_size=[
        (2, 2, 128, 128),  # input tensor
        (2, 1, 128, 128),  # static field
    ],
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
    depth=10,
)

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

In [None]:
# serialize config dict
def _serialize_value(value):
    """Convert one config value into something ``json.dumps`` can handle.

    * ``functools.partial`` objects become a dict holding the name of the
      wrapped callable plus its encoded positional and keyword arguments
      (a partial has no ``__name__`` of its own).
    * dicts, lists and tuples are encoded element by element.
    * anything exposing ``__name__`` (classes, functions) becomes that name.
    * every other value is returned unchanged.
    """
    if isinstance(value, partial):
        return {
            "partial": _serialize_value(value.func),
            "args": [_serialize_value(arg) for arg in value.args],
            "keywords": {key: _serialize_value(val) for key, val in value.keywords.items()},
        }
    if isinstance(value, dict):
        return {key: _serialize_value(val) for key, val in value.items()}
    if isinstance(value, (list, tuple)):
        encoded = [_serialize_value(item) for item in value]
        # Keep the container kind so plain configs round-trip unchanged.
        return encoded if isinstance(value, list) else tuple(encoded)
    # Classes and functions are stored by name; plain values pass through.
    return getattr(value, "__name__", value)


def serialize(dictionnary):
    """Return a serializable copy of a configuration dictionary.

    Classes and functions are replaced by their ``__name__``. Values without
    a ``__name__`` are kept as they are, except that ``functools.partial``
    objects and nested lists/tuples/dicts are encoded recursively, so that
    configs containing e.g. ``partial(Block, norm=[None, SomeNorm])`` can be
    dumped to JSON.

    Args:
        dictionnary: mapping of config keys to arbitrary values.

    Returns:
        A new dict with the same keys and encoded values; the input is not
        modified.
    """
    return {key: _serialize_value(value) for key, value in dictionnary.items()}

# Encoder-decoder
An example of building an encoder-decoder model using residual blocks.

In [None]:
# Residual block factory shared by the encoder and the decoder.
# NOTE(review): the two-element lists presumably configure the two successive
# layers inside the ResBlock — confirm against pynop.core.ResBlock.
base_block = partial(
    pynop.core.ResBlock,
    norm=[None, pynop.LayerNorm2d],
    activation=[nn.SiLU, None],
    use_bias=[True, False],
)

# Upsampling factory for the decoder: interpolate-then-conv layer with a 1x1 kernel and factor 2.
upblock = partial(
    pynop.InterpolateConvUpSampleLayer,
    kernel_size=1,
    factor=2,
)

In [None]:
# Encoder settings.
enc_config = dict(
    in_channels=3,
    latent_dim=32,
    block=base_block,
    # if downblock is None, the downsampling is done in the first block of each stage using a stride 2
    downblock=None,
    depths=[2, 2, 2],
    dims=[64, 64, 64],
    stem_kernel_size=3,
    stem_norm=pynop.RMSNorm2d,
    stem_activation=nn.SiLU,
)

# Decoder settings: same latent size, depths and widths as the encoder.
dec_config = dict(
    out_channels=3,
    latent_dim=32,
    block=base_block,
    upblock=upblock,
    depths=[2, 2, 2],
    dims=[64, 64, 64],
)

# Build both halves from their configuration dictionaries.
encoder = pynop.models.Encoder(**enc_config)
decoder = pynop.models.Decoder(**dec_config)

In [None]:
# NOTE(review): the commented-out line below is leftover code; `dc_ae_f32c32`
# is not defined or imported anywhere in this notebook.
# cfg = dc_ae_f32c32(name="dc-ae-f32c32-in-1.0", pretrained_path=None)
# Combine the encoder and decoder built above into a single autoencoder model.
generator = pynop.models.AutoEncoder(encoder, decoder)

In [30]:
# Layer-by-layer torchinfo report for the autoencoder, using an input of size (2, 3, 96, 96).
summary(
    generator,
    input_size=(2, 3, 96, 96),
    depth=8,
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
)

Layer (type:depth-idx)                                  Input Shape               Output Shape              Param #                   Kernel Shape
AutoEncoder                                             [2, 3, 96, 96]            [2, 3, 96, 96]            --                        --
├─Encoder: 1-1                                          [2, 3, 96, 96]            [2, 32, 12, 12]           --                        --
│    └─ConvLayer: 2-1                                   [2, 3, 96, 96]            [2, 64, 96, 96]           --                        --
│    │    └─ModuleList: 3-1                             --                        --                        64                        --
│    │    │    └─ZeroPad2d: 4-1                         [2, 3, 96, 96]            [2, 3, 98, 98]            --                        --
│    │    │    └─Conv2d: 4-2                            [2, 3, 98, 98]            [2, 64, 96, 96]           1,728                     [3, 3]
│    │    └─RMSNorm2d: 3-2 

# UNET

In [2]:
from pynop.core.norm import LayerNorm2d
from pynop.models import unet

# Two candidate building blocks. They used to be assigned to the same name
# `block` back to back, so the residual variant was silently overwritten and
# never used; they now have distinct names.
res_block = partial(
    pynop.core.ResBlock,
    norm=[None, pynop.LayerNorm2d],
    activation=[nn.GELU, None],
    use_bias=[True, False],
)
conv_block = partial(pynop.ops.ConvLayer, norm=LayerNorm2d, activation=nn.GELU, stride=1)

# The U-Net below is built from the plain conv layer (this is what the cell
# effectively used before). `block` keeps pointing at it for later cells.
# NOTE(review): `res_block` is presumably a drop-in alternative — untested here.
block = conv_block

unetmodel = unet(
    in_channels=1,
    out_channels=1,
    block=block,
    filters=(32, 64, 128, 256),
    repeats=(2, 2, 2, 2),
    stem_stride=2,
    stem_kernel_size=3,
    stem_norm=pynop.LayerNorm2d,
)

In [3]:
# Layer-by-layer torchinfo report for the U-Net, using an input of size (2, 1, 128, 128).
summary(
    unetmodel,
    input_size=(2, 1, 128, 128),
    depth=8,
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
)



Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape
unet                                          [2, 1, 128, 128]          [2, 1, 128, 128]          --                        --
├─ConvLayer: 1-1                              [2, 1, 128, 128]          [2, 32, 64, 64]           --                        --
│    └─ModuleList: 2-1                        --                        --                        64                        --
│    │    └─ZeroPad2d: 3-1                    [2, 1, 128, 128]          [2, 1, 130, 130]          --                        --
│    │    └─Conv2d: 3-2                       [2, 1, 130, 130]          [2, 32, 64, 64]           288                       [3, 3]
│    └─LayerNorm2d: 2-2                       [2, 32, 64, 64]           [2, 32, 64, 64]           64                        --
│    └─GELU: 2-3                              [2, 32, 64, 64]           [2, 32, 64, 64]          

In [9]:
# Sanity check: build one block (1 -> 32 channels) with stride=2 overriding the
# partial's default stride, and look at the output shape for a 128x128 input.
block(1, 32, stride=2)(torch.rand(2, 1, 128, 128)).shape

torch.Size([2, 32, 64, 64])