In [2]:
import argparse
import json
import os
import sys

# Must run before `from utils import *` so the parent directory is importable.
sys.path.append("..")

import jax
# Enable 64-bit precision before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
import equinox as eqx
import jax.numpy as jnp
from jax import vmap
import jax_dataloader as jdl
from utils import *

# Kept after the star import (as in the original) so these explicit bindings
# win over any same-named symbols that `utils` might export.
import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as jshard
import optax
from optax.contrib import reduce_on_plateau

# RUN CONFIGURATION
problem, network = "advection", "deeponet"
running_on = "local"  # selects one of the path sets below: "local" or "idun"
num_epochs = 5
use_hino = True  # True -> HINO variant, False -> HNO variant (see the network imports)

# Directory prefixes per environment, ordered as
# (data_path, hparams_path, checkpoint_path). Every prefix ends in "/" because
# file names are appended to it by plain string concatenation.
_paths_by_environment = {
    "local": (
        "C:/Users/eirik/OneDrive - NTNU/5. klasse/prosjektoppgave/eirik_prosjektoppgave/data/",
        "C:/Users/eirik/OneDrive - NTNU/5. klasse/prosjektoppgave/eirik_prosjektoppgave/hyperparameters/",
        "C:/Users/eirik/orbax/checkpoints/",
    ),
    "idun": (
        "/cluster/work/eirikaf/data/",
        "/cluster/home/eirikaf/phlearn-summer24/eirik_prosjektoppgave/hyperparameters/",
        "/cluster/work/eirikaf/checkpoints/",
    ),
}
if running_on not in _paths_by_environment:
    raise ValueError("Invalid running_on")
data_path, hparams_path, checkpoint_path = _paths_by_environment[running_on]

# LOAD DATA
# Open the "<problem>_scaled_data.npz" archive and convert the six entries used
# by this notebook into JAX arrays. `scaled_data` stays bound afterwards.
scaled_data = jnp.load(f"{data_path}{problem}_scaled_data.npz")
a_train_s, u_train_s, a_val_s, u_val_s, x_train_s, t_train_s = (
    jnp.array(scaled_data[key])
    for key in (
        "a_train_s",
        "u_train_s",
        "a_val_s",
        "u_val_s",
        "x_train_s",
        "t_train_s",
    )
)

# DATALOADERS
# Training and validation loaders use identical batching options.
# NOTE(review): the validation loader also shuffles and drops the last
# incomplete batch, so the evaluated validation subset may vary between
# passes - confirm this is intended.
_loader_options = dict(batch_size=16, shuffle=True, backend='jax', drop_last=True)

_train_dataset = jdl.ArrayDataset(a_train_s, u_train_s, asnumpy=False)
train_loader = jdl.DataLoader(_train_dataset, **_loader_options)

_val_dataset = jdl.ArrayDataset(a_val_s, u_val_s, asnumpy=False)
val_loader = jdl.DataLoader(_val_dataset, **_loader_options)

# AUTOPARALLELISM
# `create_device_mesh` yields three values: shardings for the `a` and `u`
# batches and a replicated sharding that is later applied to the model.
_mesh_shardings = create_device_mesh()
sharding_a, sharding_u, replicated = _mesh_shardings

# IMPORT WANTED NETWORK ARCHITECTURE
# Step 1: operator network class + its hyperparameter class.
# NOTE(review): the FNO choices bind neither `OperatorNet` nor `OperatorHparams`
# here, yet both are used when the checkpoint is restored below - presumably
# these branches are unfinished; confirm before selecting an FNO network.
_don_networks = ("deeponet", "modified_deeponet")
_fno_networks = ("fno1d", "fno2d", "fno_timestepping")

if network == "deeponet":
    from networks.deeponet import DeepONet as OperatorNet
    from networks.deeponet import Hparams as OperatorHparams
elif network == "modified_deeponet":
    from networks.modified_deeponet import ModifiedDeepONet as OperatorNet
    from networks.modified_deeponet import Hparams as OperatorHparams
elif network not in _fno_networks:
    raise ValueError("Invalid network")

# Step 2: Hamiltonian wrapper. The star imports also bring in the helper names
# used later in this cell (loss/evaluate functions, hyperparameter classes, ...).
if use_hino:
    from networks.hino_DON import *
    from networks.hino_DON import HINO_DON as HamiltonianNet
elif network in _don_networks:
    from networks.hno_DON import *
    from networks.hno_DON import HNO_DON as HamiltonianNet
else:
    # FNO without HINO: only the star import happens; `HamiltonianNet` is not
    # bound explicitly on this path.
    from not_in_use.hno_DON import *

# RESTORE PRETRAINED OPERATOR NETWORK
# The checkpoint directory is named "<network>_<problem>" under `checkpoint_path`.
operator_trainer = Trainer.from_checkpoint(
    f"{checkpoint_path}{network}_{problem}",
    OperatorNet,
    Hparams=OperatorHparams,
    replicated=replicated,
)
operator_net, operator_net_hparams = operator_trainer.model, operator_trainer.hparams

# PATCH THE TRAINER WITH THE LOSS / EVALUATION OF THE CHOSEN VARIANT
# Only the names of the selected variant are looked up, since only one of the
# two modules was star-imported above.
if use_hino:
    _selected_loss, _selected_evaluate = compute_loss_hino, evaluate_hino
else:
    _selected_loss, _selected_evaluate = compute_loss_hno, evaluate_hno

Trainer.compute_loss = staticmethod(_selected_loss)
Trainer.evaluate = eqx.filter_jit(staticmethod(_selected_evaluate), donate="all-except-first")

# IMPORT HYPERPARAMETERS
# The energy-net hyperparameters are read from "energy_net.json" in `hparams_path`.
_energy_hparams_file = f"{hparams_path}energy_net.json"
with open(_energy_hparams_file, "rb") as _hparams_stream:
    hparams_energy_net_dict = json.load(_hparams_stream)
energy_net_hparams = EnergyNetHparams(**hparams_energy_net_dict)

# BUILD THE COMBINED MODEL
# A new energy network is paired with the operator network restored from its checkpoint.
energy_net = EnergyNet(energy_net_hparams)
model = HamiltonianNet(energy_net=energy_net, operator_net=operator_net)
# Apply the replicated sharding only when one was provided.
model = eqx.filter_shard(model, replicated) if replicated else model

hparams = Hparams(energy_net=energy_net_hparams, operator_net=operator_net_hparams)

# INITIALIZE OPTIMIZERS
# Settings shared by every reduce-on-plateau scheduler below.
# NOTE(review): the "epochs" wording comes from the original comments; confirm
# which unit optax counts patience/cooldown in, given ACCUMULATION_SIZE.
PATIENCE = 5  # Number of epochs with no improvement after which learning rate will be reduced
COOLDOWN = 0  # Number of epochs to wait before resuming normal operation after the learning rate reduction
FACTOR = 0.5  # Factor by which to reduce the learning rate
RTOL = 1e-4  # Relative tolerance for measuring the new optimum
ACCUMULATION_SIZE = 200  # Number of iterations to accumulate an average value


def _plateau_scheduler():
    """Return a fresh reduce-on-plateau transform built from the constants above."""
    return reduce_on_plateau(
        patience=PATIENCE,
        cooldown=COOLDOWN,
        factor=FACTOR,
        rtol=RTOL,
        accumulation_size=ACCUMULATION_SIZE,
    )


def _adam_with_plateau(learning_rate, *leading_transforms):
    """Chain optional leading transforms, Adam, and a reduce-on-plateau scheduler.

    Args:
        learning_rate: Learning rate handed to `optax.adam`.
        *leading_transforms: Gradient transformations applied before Adam.

    Returns:
        The chained optax gradient transformation.
    """
    return optax.chain(*leading_transforms, optax.adam(learning_rate), _plateau_scheduler())


# Operator-network parameters (θ). `conjugate_grads_transform` is only looked up
# for the FNO networks, exactly as before.
if network in ["fno1d", "fno2d", "fno_timestepping"]:
    # we have to conjugate the gradients for the FNO networks
    θ_optimizer = _adam_with_plateau(operator_net_hparams.learning_rate, conjugate_grads_transform())
else:
    θ_optimizer = _adam_with_plateau(operator_net_hparams.learning_rate)

# Energy-network parameters (φ).
φ_optimizer = _adam_with_plateau(energy_net_hparams.learning_rate)

if operator_net.is_self_adaptive: # Self-adaptive weights are enabled
    # The trailing scale(-1.) flips the sign of the Adam update for the λ group.
    λ_optimizer = optax.chain(optax.adam(operator_net_hparams.λ_learning_rate), optax.scale(-1.))
    opt = optax.multi_transform({'θ': θ_optimizer, 'λ': λ_optimizer, 'φ': φ_optimizer}, param_labels=param_labels_hno_self_adaptive)
else:
    opt = optax.multi_transform({'θ': θ_optimizer, 'φ': φ_optimizer}, param_labels=param_labels_hno)

# Optimizer state is initialised on the array leaves of `[model]` only.
opt_state = opt.init(eqx.filter([model], eqx.is_array))

Only one device detected. Disabling array sharding and autoparallelism.
