In [1]:
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
from jaxtyping import Array
import sys
sys.path.append('..')
from networks import FNO2d, FNO2d_2
from networks.fno2d import Hparams, compute_loss
from utils import *

# Install the loss imported from networks.fno2d on Trainer as a static method,
# so it is looked up through the class without being bound to an instance
# (no implicit `self` argument is passed to it).
# NOTE(review): Trainer is not imported by name in this cell — presumably it
# is provided by `from utils import *`; confirm.
Trainer.compute_loss = staticmethod(compute_loss)

In [2]:
import equinox as eqx
import jax.numpy as jnp
from jax import vmap
import jax_dataloader as jdl
import optuna
# Set optuna's logging verbosity to the WARNING level.
optuna.logging.set_verbosity(optuna.logging.WARNING)
from utils import *

import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as jshard
import argparse
from optax.contrib import reduce_on_plateau

# --- Run configuration -------------------------------------------------------
running_on = "local"   # which machine the notebook runs on: "local" or "idun"
problem = "kdv"        # which dataset to load: "advection" or "kdv"

# (data_path, checkpoint_path, optuna_path) for every supported environment.
_PATHS_BY_ENVIRONMENT = {
    "local": (
        "C:/Users/eirik/OneDrive - NTNU/5. klasse/prosjektoppgave/eirik_prosjektoppgave/data/",
        "C:/Users/eirik/orbax/",
        "",
    ),
    "idun": (
        "/cluster/work/eirikaf/data/",
        "/cluster/work/eirikaf/",
        "phlearn-summer24/eirik_prosjektoppgave/",
    ),
}
if running_on not in _PATHS_BY_ENVIRONMENT:
    raise ValueError("Invalid running_on")
data_path, checkpoint_path, optuna_path = _PATHS_BY_ENVIRONMENT[running_on]

# Archive file name (inside data_path) for every supported problem.
_DATASET_FILE_BY_PROBLEM = {
    "advection": "advection.npz",
    "kdv": "kdv_newest.npz",
}
if problem not in _DATASET_FILE_BY_PROBLEM:
    raise ValueError("Invalid problem")
data = jnp.load(data_path + _DATASET_FILE_BY_PROBLEM[problem])

# --- Split -------------------------------------------------------------------
# NOTE(review): 0.8 / 0.1 / 0.1 are presumably the train / validation / test
# fractions — confirm against utils.train_val_test_split.
train_val_test = train_val_test_split(jnp.array(data['data']), 0.8, 0.1, 0.1)

# --- Scale -------------------------------------------------------------------
# scale_data receives the 'x' and 't' arrays of the archive together with the
# split data and returns a dict keyed by "<name>_s".
scaled_data = scale_data(jnp.array(data['x']), jnp.array(data['t']), train_val_test)

# Pull out the entries used by the rest of the notebook.
a_train_s, u_train_s, a_val_s, u_val_s = (
    scaled_data[key] for key in ("a_train_s", "u_train_s", "a_val_s", "u_val_s")
)
x_train_s, t_train_s = scaled_data["x_train_s"], scaled_data["t_train_s"]

In [3]:
def conjugate_grads_transform():
    """Build a stateless optax transformation that conjugates complex gradients.

    Every leaf of the incoming update tree that is complex-valued is replaced
    by its complex conjugate; all other leaves are passed through untouched.

    NOTE(review): presumably needed because of JAX's gradient convention for
    complex parameters (the descent direction is the conjugate of the returned
    gradient) — confirm against the JAX autodiff documentation.

    Returns:
        An ``optax.GradientTransformation`` whose state is always ``None``.
    """

    def _conj_if_complex(leaf):
        # Real-valued leaves are returned as-is.
        if jnp.iscomplexobj(leaf):
            return jnp.conj(leaf)
        return leaf

    def init_fn(params):
        # The transformation keeps no state between steps.
        return None

    def update_fn(updates, state, params=None):
        conjugated = jax.tree_util.tree_map(_conj_if_complex, updates)
        # The (empty) state is handed back unchanged.
        return conjugated, state

    return optax.GradientTransformation(init_fn, update_fn)

# One hyperparameter set shared by both network variants, so the two models
# compared in this notebook are built from the same configuration.
hparams = Hparams(n_blocks = 3, hidden_dim=100, modes_max = 10)
model_fno2d_2 = FNO2d_2(hparams)
model_fno2d = FNO2d(hparams)

In [4]:
# Inspect the model filtered with `eqx.is_array`. The same expression is what
# the FNO2d_2 optimizer state is initialised from in the training setup cell.
eqx.filter([model_fno2d_2], eqx.is_array)

[FNO2d_2(
   self_adaptive=None,
   is_self_adaptive=None,
   lifting=Conv2d(
     num_spatial_dims=2,
     weight=f64[100,3,1,1],
     bias=f64[100,1,1],
     in_channels=3,
     out_channels=100,
     kernel_size=(1, 1),
     stride=(1, 1),
     padding=((0, 0), (0, 0)),
     dilation=(1, 1),
     groups=1,
     use_bias=True,
     padding_mode='ZEROS'
   ),
   fno_blocks=FNOBlock2d(
     spectral_conv=SpectralConv2d(
       weights1=c128[3,100,100,10,10],
       weights2=c128[3,100,100,10,10],
       in_channels=None,
       out_channels=None,
       modes1=None,
       modes2=None
     ),
     bypass=f64[3,100,100],
     activation=None
   ),
   projection=Conv2d(
     num_spatial_dims=2,
     weight=f64[1,100,1,1],
     bias=f64[1,1,1],
     in_channels=100,
     out_channels=1,
     kernel_size=(1, 1),
     stride=(1, 1),
     padding=((0, 0), (0, 0)),
     dilation=(1, 1),
     groups=1,
     use_bias=True,
     padding_mode='ZEROS'
   )
 )]

In [5]:
# Data loaders for the FNO2d_2 run; training and validation use the same
# loader settings and differ only in the arrays they wrap.
_loader_kwargs_fno2d_2 = dict(batch_size=16, shuffle=True, backend='jax', drop_last=True)
train_loader_fno2d_2 = jdl.DataLoader(
    jdl.ArrayDataset(a_train_s, u_train_s, asnumpy=False),
    **_loader_kwargs_fno2d_2,
)
val_loader_fno2d_2 = jdl.DataLoader(
    jdl.ArrayDataset(a_val_s, u_val_s, asnumpy=False),
    **_loader_kwargs_fno2d_2,
)

# Optimizer: conjugate the (complex) gradients first, then apply Adam with a
# learning rate of 1e-3.
opt_fno2d_2 = optax.chain(conjugate_grads_transform(), optax.adam(1e-3))

# The optimizer state is initialised from the array leaves of `[model]`.
opt_state_fno2d_2 = opt_fno2d_2.init(eqx.filter([model_fno2d_2], eqx.is_array))

trainer_fno_2d_2 = Trainer(
    model_fno2d_2,
    opt_fno2d_2,
    opt_state_fno2d_2,
    train_loader_fno2d_2,
    val_loader_fno2d_2,
    x=x_train_s,
    t=t_train_s,
)

In [6]:
from time import perf_counter

In [7]:
# Wall-clock timing of the FNO2d_2 training call (about 1167 s in the
# recorded run).
# NOTE(review): the argument 5 is presumably the number of epochs — confirm
# against the Trainer call signature in utils.
start = perf_counter()
trainer_fno_2d_2(5)
end = perf_counter()
print(f"Time FNO2d_2: {end-start}")

Output()

Time FNO2d_2: 1167.2518066000193


In [8]:
# Data loaders for the FNO2d run; training and validation use the same loader
# settings and differ only in the arrays they wrap.
_loader_kwargs_fno2d = dict(batch_size=16, shuffle=True, backend='jax', drop_last=True)
train_loader_fno2d = jdl.DataLoader(
    jdl.ArrayDataset(a_train_s, u_train_s, asnumpy=False),
    **_loader_kwargs_fno2d,
)
val_loader_fno2d = jdl.DataLoader(
    jdl.ArrayDataset(a_val_s, u_val_s, asnumpy=False),
    **_loader_kwargs_fno2d,
)

# Optimizer: plain Adam with a learning rate of 1e-3 (no gradient conjugation
# step is chained in for this model).
opt_fno2d = optax.adam(1e-3)

# The optimizer state is initialised from the array leaves of `[model]`.
opt_state_fno2d = opt_fno2d.init(eqx.filter([model_fno2d], eqx.is_array))

trainer_fno_2d = Trainer(
    model_fno2d,
    opt_fno2d,
    opt_state_fno2d,
    train_loader_fno2d,
    val_loader_fno2d,
    x=x_train_s,
    t=t_train_s,
)

In [9]:
# Wall-clock timing of the FNO2d training call (about 1190 s in the recorded
# run).
# NOTE(review): the argument 5 is presumably the number of epochs — confirm
# against the Trainer call signature in utils.
start = perf_counter()
trainer_fno_2d(5)
end = perf_counter()
print(f"Time FNO2d: {end-start}")

Output()

Time FNO2d: 1189.897722100024


In [11]:
# Training time as tracked by each trainer's `time_trained` attribute
# (FNO2d first, then FNO2d_2). In the recorded run both values are a few
# seconds below the perf_counter timings printed for the training calls.
print(trainer_fno_2d.time_trained)
print(trainer_fno_2d_2.time_trained)

1184.1909269001335
1161.493817999959


In [12]:
# Loss plot for the FNO2d trainer.
# NOTE(review): plot_loss is presumably provided by `from utils import *`.
plot_loss(trainer_fno_2d)

In [15]:
# Compare the two trained models (FNO2d first, FNO2d_2 second) on the scaled
# validation arrays. Note that the x/t arguments are the training-scaled
# arrays `x_train_s` / `t_train_s`.
plot_predictions(u_val_s, a_val_s, x_train_s, t_train_s, trainer_fno_2d.model, trainer_fno_2d_2.model)

In [13]:
# Loss plot for the FNO2d_2 trainer.
# NOTE(review): plot_loss is presumably provided by `from utils import *`.
plot_loss(trainer_fno_2d_2)

In [20]:
# Report the parameter count of each trained model.
# For FNO2d_2 the recorded value (6,030,501) matches the array shapes in the
# model repr shown earlier when every complex entry is counted once:
#   2 * 3*100*100*10*10 (spectral weights) + 3*100*100 (bypass)
#   + 100*3 + 100 (lifting) + 100 + 1 (projection) = 6,030,501.
# NOTE(review): FNO2d reports exactly 6,000,000 more; presumably it stores the
# spectral weights in a real-valued form — confirm in networks.
print("Number of parameters in FNO2d: ")
print(param_count(trainer_fno_2d.model))
print("Number of parameters in FNO2d_2: ")
print(param_count(trainer_fno_2d_2.model))

Number of parameters in FNO2d: 
12030501
Number of parameters in FNO2d_2: 
6030501
