# Imports

- https://github.com/camlab-ethz/ConvolutionalNeuralOperator/tree/main/CNO2d_vanilla_torch_version

In [1]:
from __future__ import annotations

import torch
from neuralop.models import FNO

from operator_aliasing.models.cno2d import CNO2d
from operator_aliasing.models.FNOModules import FNO2d

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Multi-resolution inference

In [2]:
x_data = torch.rand((32, 1, 128, 128)).type(torch.float32).to(device)
y_data = torch.ones((32, 1, 128, 128)).type(torch.float32).to(device)

In [3]:
N_layers = 4
N_res = 4
N_res_neck = 4
channel_multiplier = 16

s = 128

model_args = {
    'in_dim': 1,  # Number of input channels.
    'out_dim': 1,  # Number of output channels.
    'size': s,  # Input and Output spatial size (required )
    'n_layers': N_layers,  # Number of (D) or (U) blocks in the network
    'n_res': N_res,  # Number of (R) blocks per level (except the neck)
    'n_res_neck': N_res_neck,  # Number of (R) blocks in the neck
    'channel_multiplier': channel_multiplier,  # How num channels evolve?
    'use_bn': False,
}
cno = CNO2d(**model_args)
cno = cno.to(device)

In [4]:
max_modes = 16
starting_modes = (max_modes, max_modes)
model = FNO(
    max_n_modes=(max_modes, max_modes),
    n_modes=starting_modes,
    hidden_channels=32,
    in_channels=1,
    out_channels=1,
).to(device)

In [5]:
fno_architecture = {
    'width': 64,
    'modes': 16,
    'FourierF': 0,  # Num Fourier Features in the input channels. Default is 0.
    'n_layers': 4,  # Number of Fourier layers
    'padding': 0,
    'include_grid': 0,
    'retrain': 4,  # Random seed
}

fno = FNO2d(fno_architecture, device=device).to(device)

x_data_transformed = torch.rand((32, 64, 64, 1)).type(torch.float32).to(device)

fno(x_data_transformed).shape

torch.Size([32, 64, 64, 1])

In [6]:
model(x_data).shape

torch.Size([32, 1, 128, 128])

In [7]:
4096 / 32

128.0

In [8]:
cno(x_data)

tensor([[[[-0.0449, -0.0808, -0.0779,  ..., -0.0749, -0.0727, -0.0596],
          [-0.0289, -0.0633, -0.0559,  ..., -0.0554, -0.0579, -0.0491],
          [-0.0268, -0.0659, -0.0611,  ..., -0.0675, -0.0663, -0.0533],
          ...,
          [-0.0233, -0.0571, -0.0546,  ..., -0.0547, -0.0623, -0.0539],
          [-0.0226, -0.0519, -0.0505,  ..., -0.0555, -0.0605, -0.0544],
          [-0.0169, -0.0266, -0.0284,  ..., -0.0346, -0.0408, -0.0457]]],


        [[[-0.0435, -0.0793, -0.0748,  ..., -0.0761, -0.0730, -0.0590],
          [-0.0275, -0.0635, -0.0554,  ..., -0.0589, -0.0594, -0.0478],
          [-0.0274, -0.0701, -0.0648,  ..., -0.0660, -0.0660, -0.0521],
          ...,
          [-0.0226, -0.0578, -0.0576,  ..., -0.0532, -0.0627, -0.0549],
          [-0.0211, -0.0520, -0.0554,  ..., -0.0514, -0.0591, -0.0540],
          [-0.0164, -0.0273, -0.0308,  ..., -0.0294, -0.0394, -0.0462]]],


        [[[-0.0441, -0.0808, -0.0786,  ..., -0.0751, -0.0736, -0.0593],
          [-0.0290, -0.065

In [9]:
x_data_small = torch.rand((32, 1, 64, 64)).type(torch.float32).to(device)
y_data_small = torch.ones((32, 1, 64, 64)).type(torch.float32).to(device)
cno(x_data_small).shape

# model(x_data_small).shape

torch.Size([32, 1, 128, 128])

In [10]:
x_data_large = torch.rand((32, 1, 256, 256)).type(torch.float32).to(device)
y_data_large = torch.ones((32, 1, 256, 256)).type(torch.float32).to(device)
cno(x_data_large).shape

torch.Size([32, 1, 128, 128])

In [17]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

# --------------------------------------
# REPLACE THIS PART BY YOUR DATALOADER
# --------------------------------------

n_train = 100  # number of training samples

x_data = (
    torch.rand((256, 1, 128, 128))
    .type(torch.float32)
    .reshape(256, 128, 128, 1)
)
y_data = (
    torch.ones((256, 1, 128, 128))
    .type(torch.float32)
    .reshape(256, 128, 128, 1)
)

input_function_train = x_data[:n_train, :]
output_function_train = y_data[:n_train, :]
input_function_test = x_data[n_train:, :]
output_function_test = y_data[n_train:, :]

batch_size = 10

training_set = DataLoader(
    TensorDataset(input_function_train, output_function_train),
    batch_size=batch_size,
    shuffle=True,
)
testing_set = DataLoader(
    TensorDataset(input_function_test, output_function_test),
    batch_size=batch_size,
    shuffle=False,
)


# ---------------------
# Define the hyperparameters and the model:
# ---------------------

learning_rate = 0.001
epochs = 50
step_size = 15
gamma = 0.5

N_layers = 4
N_res = 4
N_res_neck = 4
channel_multiplier = 16

s = 128

cno = CNO2d(**model_args)

# -----------
# TRAIN:
# -----------

cno = cno.to(device)

# cno = model.to(device)

cno = fno.to(device)

optimizer = AdamW(cno.parameters(), lr=learning_rate, weight_decay=1e-8)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=step_size, gamma=gamma
)

loss = nn.L1Loss()
freq_print = 1


for epoch in range(epochs):
    train_mse = 0.0
    for _step, (input_data, output_data) in enumerate(training_set):
        input_batch = input_data.to(device)
        output_batch = output_data.to(device)
        print(input_batch.shape)
        optimizer.zero_grad()
        output_pred_batch = cno(input_batch)
        loss_f = loss(output_pred_batch, output_batch)
        loss_f.backward()
        optimizer.step()
        train_mse += loss_f.item()
    train_mse /= len(training_set)

    scheduler.step()

    with torch.no_grad():
        cno.eval()
        test_relative_l2 = 0.0
        for _step, (input_data, output_data) in enumerate(testing_set):
            input_batch = input_data.to(device)
            output_batch = output_data.to(device)
            output_pred_batch = cno(input_batch)
            loss_f = (
                torch.mean(abs(output_pred_batch - output_batch))
                / torch.mean(abs(output_batch))
            ) ** 0.5 * 100
            test_relative_l2 += loss_f.item()
        test_relative_l2 /= len(testing_set)

    if epoch % freq_print == 0:
        print(
            '######### Epoch:',
            epoch,
            ' ######### Train Loss:',
            train_mse,
            ' ######### Relative L1 Test Norm:',
            test_relative_l2,
        )

torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
######### Epoch: 0  ######### Train Loss: 0.054802651389036325  ######### Relative L1 Test Norm: 15.84524949391683
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
######### Epoch: 1  ######### Train Loss: 0.017819717072416096  ######### Relative L1 Test Norm: 14.050152778625488
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128, 256, 1])
torch.Size([10, 128

In [18]:
def get_model_preds(
    test_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    device: torch.device,
    # data_transform: DataProcessor,
) -> torch.Tensor:
    """Return model predictions."""
    model_preds = []
    for _idx, sample in enumerate(test_loader):  # resolution 128
        model_input = sample['x'].to(device)
        with torch.no_grad():
            out = model(model_input)
            model_preds.append(out)
    return torch.cat(model_preds)

In [None]:
for _idx, sample in enumerate(test_loader):  # resolution 128
    model_input = sample['x'].to(device)
    with torch.no_grad():
        out = model(model_input)
        model_preds.append(out)
return torch.cat(model_preds)