# Model Testing Notebook

This notebook smoke-tests the models in the `model` folder. It covers initialization, forward dynamics, backward dynamics (KoopmanAE_2d_kan only), grid generation, and parameter counting.

In [1]:
import torch
import numpy as np
import sys
sys.path.append('./model')
from model.custom_pfnn_kan import KoopmanAE_2d_kan
from model.pfnn import KoopmanAE_2d_trans
from model.pfnn_consist_2d import KoopmanAE_2d
from model.restormer_arch import Restormer
from model.koopman_base import dynamics, dynamics_back
# utilities.py may contain helper functions, not a model class


## Load and Initialize Models

Instantiate each model with example parameters. If pre-trained weights are available, they can be loaded here.

In [2]:
# Build one small instance of each model (single input/output channel, tiny
# widths) so the smoke tests below stay fast. No checkpoint is loaded here;
# each model starts from whatever initialization its constructor performs.

# Koopman autoencoder with KAN components
koopman_model = KoopmanAE_2d_kan(
    in_channel=1,
    out_channel=1,
    dim=4,
)

# Transformer-based Koopman autoencoder (formerly referred to as PFNN)
pfnn_model = KoopmanAE_2d_trans(
    in_channel=1,
    out_channel=1,
    dim=4,
)

# Consistent 2D Koopman autoencoder (formerly PFNNConsist2D), one step each way
pfnn_consist_model = KoopmanAE_2d(
    in_channel=1,
    out_channel=1,
    steps=1,
    steps_back=1,
)

# Restormer baseline
# NOTE(review): keyword names/values should be checked against the class signature
restormer_model = Restormer(
    inp_channels=1,
    out_channels=1,
    dim=4,
)


## Test Forward Dynamics

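Pass a sample input tensor through each model's forward dynamics and print the output shapes. KoopmanAE_2d_kan is called explicitly with `mode='forward'`; the other models use their default call.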

In [4]:
# Random input of shape (batch_size=2, channels=1, height=64, width=64).
# A locally seeded generator makes the sample reproducible across kernel
# restarts without changing torch's global RNG state.
dummy_input = torch.randn(2, 1, 64, 64, generator=torch.Generator().manual_seed(0))

# KoopmanAE_2d_kan forward
koopman_out, koopman_out_id = koopman_model(dummy_input, mode='forward')
print('KoopmanAE_2d_kan forward output:', [o.shape for o in koopman_out], 'Identity output:', [o.shape for o in koopman_out_id])

# PFNN forward
pfnn_out, pfnn_out_id = pfnn_model(dummy_input)
print('PFNN forward output:', [o.shape for o in pfnn_out], 'Identity output:', [o.shape for o in pfnn_out_id])

# PFNNConsist2D forward
pfnn_consist_out, pfnn_consist_out_id = pfnn_consist_model(dummy_input)
print('PFNNConsist2D forward output:', [o.shape for o in pfnn_consist_out], 'Identity output:', [o.shape for o in pfnn_consist_out_id])

# Restormer forward
restormer_out = restormer_model(dummy_input)
print('Restormer forward output:', restormer_out.shape)

# All models were built with 1 input and 1 output channel, so every output
# should have exactly the input's shape; fail loudly instead of only printing.
all_forward_outputs = [
    *koopman_out, *koopman_out_id,
    *pfnn_out, *pfnn_out_id,
    *pfnn_consist_out, *pfnn_consist_out_id,
    restormer_out,
]
for out in all_forward_outputs:
    assert out.shape == dummy_input.shape, f'unexpected output shape {tuple(out.shape)}'


KoopmanAE_2d_kan forward output: [torch.Size([2, 1, 64, 64])] Identity output: [torch.Size([2, 1, 64, 64])]
PFNN forward output: [torch.Size([2, 1, 64, 64])] Identity output: [torch.Size([2, 1, 64, 64])]
PFNNConsist2D forward output: [torch.Size([2, 1, 64, 64])] Identity output: [torch.Size([2, 1, 64, 64])]
Restormer forward output: torch.Size([2, 1, 64, 64])


## Test Backward Dynamics

Pass a sample input tensor through KoopmanAE_2d_kan in 'backward' mode and print the output shapes.

In [5]:
# Run the KAN Koopman autoencoder in 'backward' mode on the same sample input.
koopman_back_out, koopman_back_out_id = koopman_model(dummy_input, mode='backward')
back_shapes = [t.shape for t in koopman_back_out]
back_id_shapes = [t.shape for t in koopman_back_out_id]
print('KoopmanAE_2d_kan backward output:', back_shapes, 'Identity output:', back_id_shapes)

KoopmanAE_2d_kan backward output: [torch.Size([2, 1, 64, 64])] Identity output: [torch.Size([2, 1, 64, 64])]


## Test Grid Generation

Call the get_grid method with sample parameters and print the grid shape.

In [6]:
# Ask the model for a grid with S=8 and a batch of 2, placed on the same
# device as the sample input, then report its shape.
grid = koopman_model.get_grid(
    S=8,
    batchsize=2,
    device=dummy_input.device,
)
print('Generated grid shape:', grid.shape)

Generated grid shape: torch.Size([2, 2, 8, 8])


## Count Model Parameters

For each model, use its `count_params` method if it has one; otherwise sum its trainable parameters. Print the resulting totals.

In [7]:
def count_parameters(model: torch.nn.Module, trainable_only: bool = True) -> int:
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Any ``torch.nn.Module``.
        trainable_only: When True (the default, matching the original
            behaviour), count only parameters with ``requires_grad=True``.
            When False, count every parameter, including frozen ones.

    Returns:
        Total element count across the selected parameter tensors.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        # Skip frozen tensors only when the caller asked for trainable ones.
        if p.requires_grad or not trainable_only
    )

# Report a parameter total for every model using the same rule for all of them:
# prefer the model's own count_params() when it defines one, otherwise fall back
# to count_parameters() (trainable parameters only).
models_to_count = {
    'KoopmanAE_2d_kan': koopman_model,
    'KoopmanAE_2d_trans': pfnn_model,
    'KoopmanAE_2d': pfnn_consist_model,
    'Restormer': restormer_model,
}
for model_name, net in models_to_count.items():
    n_params = net.count_params() if hasattr(net, 'count_params') else count_parameters(net)
    print(f'{model_name} parameters:', n_params)

KoopmanAE_2d_kan parameters: 60873
KoopmanAE_2d_trans parameters: 180912
KoopmanAE_2d parameters: 740641
Restormer parameters: 226000
