In [None]:
# !pip installs

In [None]:
import flax.linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import orbax.checkpoint
import tensorflow as tf

from swirl_dynamics.data import tfgrain_transforms as transforms
from swirl_dynamics.lib.networks import encoders
from swirl_dynamics.lib.solvers import ode
from swirl_dynamics.projects.evolve_smoothly import ansatzes
from swirl_dynamics.projects.evolve_smoothly import batch_decode
from swirl_dynamics.projects.evolve_smoothly import data_pipelines
from swirl_dynamics.projects.evolve_smoothly import encode_decode
from swirl_dynamics.projects.evolve_smoothly import latent_dynamics
from swirl_dynamics.templates import callbacks
from swirl_dynamics.templates import train

In [None]:
# Hide all GPUs from TensorFlow. TF is not used for model computation in this
# notebook (JAX is); without this, TF may reserve GPU memory that JAX needs.
tf.config.set_visible_devices([], "GPU")

### Batch Decode Training

This stage fits the same ansatz to a large number of snapshots. The resulting error indicates whether the selected ansatz is expressive enough for the problem considered.

We first set up the data pipeline using grain. This dataset contains a large number of snapshots, and its batch dimension runs over grid points. In other words, a random batch contains all snapshots evaluated at the sampled collocation points.

In [None]:
# HDF5 file holding the trajectories; the pipelines below read its
# "train/u", "train/x", "train/t" and "eval/u", "eval/x", "eval/t" fields.
hdf5_file_path = "/swirl_dynamics/hdf5/pde/1d/ks_trajectories.hdf5"  #@param
# Number of snapshots fitted jointly in the batch-decode stage.
num_snapshots = 5000  #@param

In [None]:
# Rescale the spatial grid from [0, L) with L = 64 to [-1, 1).
grid_transforms = [
    transforms.LinearRescale(
        feature_name="x", input_range=(0, 64), output_range=(-1, 1)),
]

# Batches run over collocation points: each batch holds all
# `num_snapshots` training snapshots evaluated at the sampled grid points.
train_dataloader = data_pipelines.create_batch_decode_pipeline(
    hdf5_file_path=hdf5_file_path,
    snapshot_field="train/u",
    grid_field="train/x",
    num_snapshots_to_train=num_snapshots,
    transformations=grid_transforms,
    seed=42,
    batch_size=32,
)

Next we instantiate the model, which takes an ansatz model (wrapped to provide easy access to things like parameter shapes and structures) and the number of snapshots in the dataset.

In [None]:
# Nonlinear Fourier ansatz: a small (8, 8) network with sine activations,
# wrapped so parameter shapes and structure are easy to access.
fourier_net = ansatzes.nonlinear_fourier.NonLinearFourier(
    features=(8, 8),
    num_freqs=3,
    act_fn=jnp.sin,
    zero_freq=False,
    dyadic=False,
)
ansatz = ansatzes.NonLinearFourier(model=fourier_net)

In [None]:
# The batch-decode model fits the same ansatz to every one of the
# `num_snapshots` training snapshots.
model = batch_decode.BatchDecode(ansatz=ansatz, num_snapshots=num_snapshots)

Next, we instantiate the trainer, which takes the model, a random seed, and an optimizer.

In [None]:
# Fixed PRNG seed for reproducibility (presumably used to initialize the
# per-snapshot ansatz weights); plain Adam at a constant learning rate.
batch_decode_rng = jax.random.PRNGKey(42)
trainer = batch_decode.BatchDecodeTrainer(
    model=model,
    rng=batch_decode_rng,
    optimizer=optax.adam(1e-3),
)

We can now run training and monitor progress. A fairly low training loss at the end is a promising sign: the ansatz we adopted has enough expressive power to represent a wide range of snapshots.

In [None]:
# Output directory for this stage. NOTE(review): this is a relative path,
# unlike the absolute /tmp/... workdirs of the later stages, whose checkpoint
# callbacks require a full path. No checkpoint callback is used here.
workdir = "batch_decode/"  #@param
# Total number of optimizer steps for batch-decode training.
num_train_steps = 20000  #@param

In [None]:
# Fit the ansatz to all snapshots, aggregating metrics over 50-step windows.
# No eval dataloader and no checkpoint callback are registered for this stage.
train.run(
    train_dataloader=train_dataloader,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=num_train_steps,
    metric_aggregation_steps=50,
    callbacks=[
        callbacks.TqdmProgressBar(
            total_train_steps=num_train_steps,
            train_monitors=("train_loss", "train_loss_std"),
        ),
    ],
)

### Encode Decode Training

This stage uses an encoder network to output the weights of an ansatz that parametrizes the snapshots. Adding the consistency loss yields smooth weight trajectories, which proves beneficial when training the dynamics later on.

In [None]:
# Rescale the spatial grid from [0, L) with L = 64 to [-1, 1).
# The list gets its own name: assigning it to `transforms` would shadow the
# `tfgrain_transforms` module alias imported at the top. Any later
# `transforms.LinearRescale(...)`, including a re-run of this cell, would then
# fail with AttributeError.
grid_transforms = [
    transforms.LinearRescale(
        feature_name="x", input_range=(0, 64), output_range=(-1, 1)),
]

# Train and eval pipelines differ only in the HDF5 split they read.
train_dataloader = data_pipelines.create_encode_decode_pipeline(
    hdf5_file_path=hdf5_file_path,
    snapshot_field="train/u",
    grid_field="train/x",
    transformations=grid_transforms,
    seed=42,
    batch_size=32,
)

eval_dataloader = data_pipelines.create_encode_decode_pipeline(
    hdf5_file_path=hdf5_file_path,
    snapshot_field="eval/u",
    grid_field="eval/x",
    transformations=grid_transforms,
    seed=42,
    batch_size=32,
)

In [None]:
# Nonlinear Fourier ansatz with the same configuration as the batch-decode
# stage.
fourier_net = ansatzes.nonlinear_fourier.NonLinearFourier(
    features=(8, 8),
    num_freqs=3,
    act_fn=jnp.sin,
    zero_freq=False,
    dyadic=False,
)
ansatz = ansatzes.NonLinearFourier(model=fourier_net)

# The encoder outputs `ansatz.num_params` values per snapshot, i.e. one per
# ansatz weight.
encoder = encoders.EncoderResNet(
    filters=4,
    dim_out=ansatz.num_params,
    num_levels=4,
    num_resnet_blocks=2,
    act_fn=jnp.sin,
)

# `consistency_weight` scales the consistency loss that yields smooth weight
# trajectories.
model = encode_decode.EncodeDecode(
    ansatz=ansatz,
    encoder=encoder,
    # NOTE(review): presumably (batch, grid points, channels) — confirm.
    snapshot_dims=(1, 512, 1),
    consistency_weight=10.0,
)

Define a learning-rate schedule: linear warmup followed by cosine decay.

In [None]:
# Linear warmup from 0 to 1e-4 over the first 1000 steps, then cosine decay
# down to 1e-6.
# NOTE: in optax, `decay_steps` is the TOTAL schedule length and already
# includes the warmup. It is set to the full 100000-step training run
# (`num_train_steps` for this stage). Using 100000 - 1000 = 99000 would finish
# the decay 1000 steps before training ends.
lr = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-4,
    warmup_steps=1000,
    decay_steps=100000,
    end_value=1e-6,
)

For the optimizer, we use Adam with gradient clipping by global norm.

In [None]:
# Clip gradients to a global norm of 1, then apply Adam driven by `lr`.
grad_clip = optax.clip_by_global_norm(max_norm=1.0)
optimizer = optax.chain(grad_clip, optax.adam(lr))

In [None]:
trainer = encode_decode.EncodeDecodeTrainer(
    model=model,
    rng=jax.random.PRNGKey(1),  # fixed seed for reproducibility
    optimizer=optimizer,
)

Run training

In [None]:
# Absolute path: the checkpoint callback in the training cell requires one.
workdir = "/tmp/encode_decode/"  #@param
# Total optimizer steps. NOTE: the learning-rate schedule defined above
# hard-codes its length; keep the two in sync.
num_train_steps = 100000  #@param

In [None]:
# Save the train state every 1000 steps, keeping at most 5 checkpoints.
checkpoint_cb = callbacks.TrainStateCheckpoint(
    base_dir=workdir,  # NOTE: this must be a full path
    options=orbax.checkpoint.CheckpointManagerOptions(
        save_interval_steps=1000,
        max_to_keep=5,
    ),
)
# Progress bar showing the training loss and the eval relative L2
# reconstruction error.
progress_cb = callbacks.TqdmProgressBar(
    total_train_steps=num_train_steps,
    train_monitors=("train_loss",),
    eval_monitors=("eval_reconstruction_rel_l2",),
)

# Evaluate on 10 eval batches every 2000 steps.
train.run(
    train_dataloader=train_dataloader,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=num_train_steps,
    metric_aggregation_steps=50,
    eval_dataloader=eval_dataloader,
    eval_every_steps=2000,
    num_batches_per_eval=10,
    callbacks=[checkpoint_cb, progress_cb],
)

Check inference

In [None]:
# Bind the trained state to the encoder. The resulting callable maps a batch
# of snapshots `u` to ansatz weights (used as latent codes below).
encode_fn = encode_decode.EncodeDecodeTrainer.build_inference_fn(
    trainer.train_state,
    encoder,
)

In [None]:
# Encode one eval batch into ansatz weights. Then decode each sample by
# evaluating the ansatz with its own weights on its own grid.
eval_batch = next(iter(eval_dataloader))
encoding = encode_fn(eval_batch["u"])
decode_per_sample = jax.vmap(ansatz.batch_evaluate, in_axes=(0, 0))
reconstruction = decode_per_sample(encoding, eval_batch["x"])
reconstruction.shape

In [None]:
# Overlay the reconstruction and the ground-truth snapshot for the first
# sample of the eval batch.
x_grid = eval_batch["x"][0, :, 0]
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(x_grid, reconstruction[0, :, 0], label="reconstruction")
ax.plot(x_grid, eval_batch["u"][0, :, 0], label="true")
ax.set(
    xlabel="x (rescaled to [-1, 1))",
    ylabel="u",
    title="Encode-decode reconstruction",
)
ax.legend()
plt.show()

### Latent Dynamics Training

After encoder training, we train a latent dynamical model on the resulting latent trajectories (frozen encoder).

In [None]:
# Re-bind the module alias so this cell runs even if `transforms` was
# reassigned earlier in the session. The list gets its own name so the alias
# stays usable afterwards, including when this cell is re-run.
from swirl_dynamics.data import tfgrain_transforms as transforms

# Rescale the spatial grid from [0, L) with L = 64 to [-1, 1).
grid_transforms = [
    transforms.LinearRescale(
        feature_name="x", input_range=(0, 64), output_range=(-1, 1)),
]

# Each example presumably holds `num_time_steps` snapshots of one trajectory,
# plus their time stamps and grid (see how "u", "t" and "x" are used below).
train_dataloader = data_pipelines.create_latent_dynamics_pipeline(
    hdf5_file_path=hdf5_file_path,
    snapshot_field="train/u",
    tspan_field="train/t",
    grid_field="train/x",
    num_time_steps=2,
    transformations=grid_transforms,
    seed=42,
    batch_size=32,
)

eval_dataloader = data_pipelines.create_latent_dynamics_pipeline(
    hdf5_file_path=hdf5_file_path,
    snapshot_field="eval/u",
    tspan_field="eval/t",
    grid_field="eval/x",
    num_time_steps=2,
    transformations=grid_transforms,
    seed=42,
    batch_size=32,
)

In [None]:
# Dynamics network acting on the ansatz-weight (latent) space.
latent_dynamics_model = latent_dynamics.create_hyperunet_dynamics_model(
    ansatz=ansatz,
    embed_dims=(4, 256, 1024),
    act_fn=nn.swish,
    use_layernorm=True,
)
# Runge-Kutta (RK4) ODE integrator used to step the latent dynamics.
integrator = ode.RungeKutta4()

# Loss weights. A zero reconstruction weight presumably switches off the
# physical-space reconstruction term; latent and consistency terms are
# weighted equally.
loss_weights = dict(
    reconstruction_weight=0.0,
    latent_weight=1.0,
    consistency_weight=1.0,
)
model = latent_dynamics.LatentDynamics(
    encoder=encode_fn,
    ansatz=ansatz,
    latent_dynamics_model=latent_dynamics_model,
    integrator=integrator,
    **loss_weights,
)

In [None]:
# Global-norm gradient clipping followed by Adam, with linear warmup
# (0 -> 1e-4 over 1000 steps) and then cosine decay to 1e-6.
# NOTE: optax's `decay_steps` is the total schedule length INCLUDING warmup,
# so it equals the full 100000-step run rather than 100000 - 1000.
latent_lr = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-4,
    warmup_steps=1000,
    decay_steps=100000,
    end_value=1e-6,
)
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=1.),
    optax.adam(latent_lr),
)
trainer = latent_dynamics.LatentDynamicsTrainer(
    model=model, rng=jax.random.PRNGKey(1), optimizer=optimizer
)

In [None]:
# Absolute path: the checkpoint callback in the training cell requires one.
workdir = "/tmp/latent_dynamics/"  #@param
# Total optimizer steps. NOTE: the learning-rate schedule in the optimizer
# cell hard-codes its length; keep the two in sync.
num_train_steps = 100000  #@param

In [None]:
# Save the train state every 5000 steps, keeping at most 5 checkpoints.
checkpoint_cb = callbacks.TrainStateCheckpoint(
    base_dir=workdir,  # NOTE: this must be a full path
    options=orbax.checkpoint.CheckpointManagerOptions(
        save_interval_steps=5000,
        max_to_keep=5,
    ),
)
# Progress bar showing the training loss and the eval relative L2 latent
# error.
progress_cb = callbacks.TqdmProgressBar(
    total_train_steps=num_train_steps,
    train_monitors=("train_loss",),
    eval_monitors=("eval_latent_rel_l2",),
)

# Evaluate on 10 eval batches every 2500 steps.
train.run(
    train_dataloader=train_dataloader,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=num_train_steps,
    metric_aggregation_steps=50,
    eval_dataloader=eval_dataloader,
    eval_every_steps=2500,
    num_batches_per_eval=10,
    callbacks=[checkpoint_cb, progress_cb],
)

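Here we check inference. Starting from the initial snapshot, we integrate the trained latent dynamical model over the evaluation time span (`num_time_steps=2`), then compare the prediction at the final time against the ground truth.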

In [None]:
# Build the inference function from the trained latent-dynamics state.
# `dm(u0=..., tspan=..., grid=...)` presumably encodes `u0` with the frozen
# encoder, integrates the latent dynamics over `tspan`, and decodes the result
# with the ansatz on `grid` — confirm against LatentDynamicsTrainer.
dm = latent_dynamics.LatentDynamicsTrainer.build_inference_fn(
    trainer.train_state,
    encoder=encode_fn,
    ansatz=ansatz,
    latent_dynamics_model=latent_dynamics_model,
    integrator=integrator,
)

In [None]:
# Roll the model out from each sample's first snapshot over its time span.
eval_batch = next(iter(eval_dataloader))
evolution = dm(
    u0=eval_batch["u"][:, 0],
    tspan=eval_batch["t"],
    grid=eval_batch["x"],
)
evolution.shape

In [None]:
# Overlay the predicted and true snapshots at the last time of the span for
# the first sample in the eval batch.
x_grid = eval_batch["x"][0, :, 0]
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(x_grid, evolution[0, -1, :, 0], label="predicted")
ax.plot(x_grid, eval_batch["u"][0, -1, :, 0], label="true")
ax.set(
    xlabel="x (rescaled to [-1, 1))",
    ylabel="u",
    title="Latent dynamics: final time",
)
ax.legend()
plt.show()