We install the library (including the templates) along with the other project-specific dependencies that are needed.

In [None]:
# Install swirl_dynamics from the tip of its `main` branch, plus two extra
# packages. Nothing here is version-pinned, so a fresh install may pick up
# newer code than an earlier run did.
!pip install git+https://github.com/google-research/swirl-dynamics.git@main
!pip install ml_collections
!pip install ott-jax

We import all the necessary libraries for training.

In [None]:
import json
from os import path as osp

import jax
import ml_collections
import optax
from orbax import checkpoint
import pandas as pd
import tensorflow as tf
from IPython.display import display

We import libraries from `swirl_dynamics`.

In [None]:
from swirl_dynamics.lib.solvers import ode
from swirl_dynamics.projects.ergodic import choices
from swirl_dynamics.projects.ergodic import ks_1d
from swirl_dynamics.projects.ergodic import lorenz63
from swirl_dynamics.projects.ergodic import ns_2d
from swirl_dynamics.projects.ergodic import stable_ar
from swirl_dynamics.projects.ergodic import utils
from swirl_dynamics.templates import callbacks
from swirl_dynamics.templates import train

# Loading the configuration files.
from swirl_dynamics.projects.ergodic.configs import ks_1d as ks_1d_config
from swirl_dynamics.projects.ergodic.configs import lorenz63 as lorenz63_config
from swirl_dynamics.projects.ergodic.configs import ns_2d as ns_2d_config

In [None]:
# Give TensorFlow an empty list of visible GPU devices.
# NOTE(review): presumably this keeps TensorFlow from reserving accelerator
# memory that JAX needs for training -- confirm.
tf.config.experimental.set_visible_devices([], "GPU")

In [None]:
def get_config(
    experiment: str,
    batch_size: int,
    normalize: bool,
    add_noise: bool,
    use_curriculum: bool,
    use_pushfwd: bool,
    measure_dist_type: str,
    measure_dist_lambda: float,
    measure_dist_k_lambda: float,
) -> ml_collections.ConfigDict:
  """Returns an experiment's default config with the given overrides applied.

  Args:
    experiment: Experiment name; converted with `choices.Experiment(...)`.
    batch_size: Stored as `config.batch_size`.
    normalize: Stored as `config.normalize`.
    add_noise: Stored as `config.add_noise`. When true, `config.noise_level`
      is additionally set to 1e-3.
    use_curriculum: Stored as `config.use_curriculum`. When true, the
      per-cycle fields are set to 50_000 training steps and an increase of one
      time step per cycle; otherwise both are set to 0.
    use_pushfwd: Stored as `config.use_pushfwd`.
    measure_dist_type: Stored as `config.measure_dist_type`.
    measure_dist_lambda: Stored as `config.measure_dist_lambda`.
    measure_dist_k_lambda: Stored as `config.measure_dist_k_lambda`.

  Returns:
    The experiment's default `ConfigDict` with the overrides above written in.

  Raises:
    NotImplementedError: If the experiment is a valid `choices.Experiment`
      member that has no config module registered here.
  """
  experiment = choices.Experiment(experiment)
  config_modules = {
      choices.Experiment.L63: lorenz63_config,
      choices.Experiment.KS_1D: ks_1d_config,
      choices.Experiment.NS_2D: ns_2d_config,
  }
  if experiment not in config_modules:
    raise NotImplementedError(f"Unknown experiment: {experiment}")
  config = config_modules[experiment].get_config()

  # The assignment order below is kept stable so that any keys newly added to
  # the config appear in the same order as before.
  config.batch_size = batch_size
  config.normalize = normalize
  config.add_noise = add_noise
  if add_noise:
    config.noise_level = 1e-3
  config.use_curriculum = use_curriculum
  config.use_pushfwd = use_pushfwd
  config.measure_dist_type = measure_dist_type
  config.measure_dist_lambda = measure_dist_lambda
  config.measure_dist_k_lambda = measure_dist_k_lambda
  config.train_steps_per_cycle = 50_000 if use_curriculum else 0
  config.time_steps_increase_per_cycle = 1 if use_curriculum else 0
  config.metric_aggregation_steps = 1  # Log to tqdm bar more frequently
  return config

## Define experiment configuration.

In [None]:
# Form fields: edit these values (or use the Colab form) to pick the experiment.
experiment = "ns_2d"  #@param choices=['lorenz63', 'ks_1d', 'ns_2d']
batch_size = 512  #@param {type:"integer"}
measure_dist_type = "MMD" #@param choices=['MMD', 'SD']
normalize = False  #@param {type:"boolean"}
add_noise = False  #@param {type:"boolean"}
use_curriculum = True  #@param {type:"boolean"}
use_pushfwd = True  #@param {type:"boolean"}
measure_dist_lambda = 0.0  #@param {type:"number"}
measure_dist_k_lambda = 0.0  #@param {type:"number"}
display_config = True  #@param {type:"boolean"}

# Build the config from the form values above.
config = get_config(
    experiment=experiment,
    batch_size=batch_size,
    normalize=normalize,
    add_noise=add_noise,
    use_curriculum=use_curriculum,
    use_pushfwd=use_pushfwd,
    measure_dist_type=measure_dist_type,
    measure_dist_lambda=measure_dist_lambda,
    measure_dist_k_lambda=measure_dist_k_lambda,
)

# Optionally show the resulting config as a one-column table.
if display_config:
  config_df = pd.DataFrame.from_dict(
      config, orient='index', columns=['Config values']
  )
  display(config_df)

## Run experiment

In [None]:
"<TODO: INSERT WORKDIR HERE>"
# Directory the launch cell writes config.json and checkpoints into. Fill this
# in before launching.
workdir = "" #@param

In [None]:
#@title Launch experiment
print(f'Saving files and checkpoints to {workdir}.')
# Saves config to the workdir, so that it can be loaded later.
if not tf.io.gfile.exists(workdir):
  tf.io.gfile.makedirs(workdir)
with tf.io.gfile.GFile(name=osp.join(workdir, "config.json"), mode="w") as f:
  config_json = config.to_json_best_effort()
  # If the config comes back as an already-serialized JSON string, decode it
  # first. Otherwise `json.dump` would write a doubly-encoded string literal
  # instead of a JSON object, and loading the file later would yield a `str`.
  if isinstance(config_json, str):
    config_json = json.loads(config_json)
  json.dump(config_json, f)

# Experiment-specific choices: the figure-plotting callback class, the shape of
# the model state, and the optimizer.
experiment = choices.Experiment(config.experiment)
if experiment == choices.Experiment.L63:
  fig_callback_cls = lorenz63.Lorenz63PlotFigures
  state_dims = (3 // config.spatial_downsample_factor,)
  optimizer = optax.adam(learning_rate=config.lr)
elif experiment == choices.Experiment.KS_1D:
  fig_callback_cls = ks_1d.KS1DPlotFigures
  state_dims = (
      512 // config.spatial_downsample_factor,
      config.num_lookback_steps,
  )
  optimizer = optax.adam(learning_rate=config.lr)
elif experiment == choices.Experiment.NS_2D:
  fig_callback_cls = ns_2d.NS2dPlotFigures
  # Two equal spatial extents followed by the lookback dimension.
  state_dims = (64 // config.spatial_downsample_factor,) * 2 + (
      config.num_lookback_steps,
  )
  # Unlike the other experiments, the learning rate here is halved in discrete
  # jumps every 72_000 steps.
  optimizer = optax.adam(
      learning_rate=optax.exponential_decay(
          init_value=config.lr,
          transition_steps=72_000,
          decay_rate=0.5,
          staircase=True,
      )
  )
else:
  raise NotImplementedError(f"Unknown experiment: {config.experiment}")

# Dataloaders. Arguments common to the train and eval loaders are collected
# once and reused for both calls.
_shared_loader_kwargs = dict(
    time_stride=config.time_stride,
    seed=config.seed,
    dataset_path=config.dataset_path,
    normalize=config.normalize,
    spatial_downsample_factor=config.spatial_downsample_factor,
)
train_loader, normalize_stats = utils.create_loader_from_hdf5(
    num_time_steps=config.num_time_steps,
    batch_size=config.batch_size,
    split="train",
    normalize_stats=None,
    **_shared_loader_kwargs,
)
# The eval loader reuses the normalization statistics returned alongside the
# train loader.
# NOTE(review): the -1 values presumably mean "no limit" for the number of
# time steps and the batch size -- confirm in utils.create_loader_from_hdf5.
eval_loader, _ = utils.create_loader_from_hdf5(
    num_time_steps=-1,
    batch_size=-1,
    split="eval",
    normalize_stats=normalize_stats,
    **_shared_loader_kwargs,
)

# Model: resolve the configurable components first, then assemble the config.
measure_dist_fn = choices.MeasureDistance(config.measure_dist_type).dispatch()
dynamics_model = choices.Model(config.model).dispatch(config)
integrator = choices.Integrator(config.integrator)

model_config = stable_ar.StableARModelConfig(
    state_dimension=state_dims,
    dynamics_model=dynamics_model,
    integrator=integrator,
    measure_dist=measure_dist_fn,
    use_pushfwd=config.use_pushfwd,
    add_noise=config.add_noise,
    noise_level=config.noise_level,
    measure_dist_lambda=config.measure_dist_lambda,
    measure_dist_k_lambda=config.measure_dist_k_lambda,
    num_lookback_steps=config.num_lookback_steps,
    normalize_stats=normalize_stats,
)
model = stable_ar.StableARModel(conf=model_config)

# Trainer: resolve the rollout weighting, then build the trainer config and the
# trainer itself.
rollout_weighting = choices.RolloutWeighting(config.rollout_weighting).dispatch(
    config
)
trainer_config = stable_ar.StableARTrainerConfig(
    rollout_weighting=rollout_weighting,
    num_rollout_steps=config.num_rollout_steps,
    num_lookback_steps=config.num_lookback_steps,
    add_noise=config.add_noise,
    use_curriculum=config.use_curriculum,
    train_steps_per_cycle=config.train_steps_per_cycle,
    time_steps_increase_per_cycle=config.time_steps_increase_per_cycle,
)
# The trainer's PRNG key is derived from the same seed passed to the loaders.
trainer = stable_ar.StableARTrainer(
    model=model,
    conf=trainer_config,
    rng=jax.random.PRNGKey(config.seed),
    optimizer=optimizer,
)

# Checkpointing options: how often to save and how many checkpoints to retain.
ckpt_options = checkpoint.CheckpointManagerOptions(
    save_interval_steps=config.save_interval_steps,
    max_to_keep=config.max_checkpoints_to_keep,
)

# Callbacks handed to the training loop, in the order they are listed.
train_callbacks = [
    callbacks.TrainStateCheckpoint(base_dir=workdir, options=ckpt_options),
    callbacks.ProgressReport(num_train_steps=config.train_steps),
    callbacks.TqdmProgressBar(
        total_train_steps=config.train_steps,
        train_monitors=[
            "rollout",
            "loss",
            "measure_dist",
            "measure_dist_k",
            "max_rollout_decay",
        ],
        eval_monitors=["sd"],
    ),
    fig_callback_cls(),
]

# Run training; evaluation uses the same step interval as checkpoint saving.
train.run(
    train_dataloader=train_loader,
    eval_dataloader=eval_loader,
    eval_every_steps=config.save_interval_steps,
    num_batches_per_eval=1,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=config.train_steps,
    metric_aggregation_steps=config.metric_aggregation_steps,
    callbacks=train_callbacks,
)