In [None]:
import pickle
from argparse import Namespace
from functools import partial
from pathlib import Path

import jax
import jax.flatten_util
import jax.numpy as jnp
import jax.random as jr
import optax
import scipy
import scipy.signal
from flax.training import train_state
from tqdm import tqdm

import dataset.dynamics.boucwen as dyn
from ae import Encoder, Projector
from dataset.input.signals import multisine_signal
from dataset.simulate import generate_batch
from dataset.simulate import simulate_rk4 as simulate
from lr import create_learning_rate_fn
from neuralss import ss_init, ss_apply

In [None]:
# Make the first GPU the default JAX device for everything that follows.
# NOTE(review): jax.devices("gpu") is indexed unconditionally, so this cell
# presumably fails on a host without a GPU backend — confirm before running on CPU.
jax.config.update("jax_default_device", jax.devices("gpu")[0])

In [None]:
# Configuration
# Built directly as an attribute-style namespace: read as cfg.<name> in the
# cells below and turned back into a dict with vars(cfg) for logging.
cfg = Namespace(
    # Misc
    log_wandb=True,  # log the run to Weights & Biases

    # Meta dataset
    K=2,  # runs simulated per system; preproc_batch reads run 0 and run 1
    nu=1,  # passed to ss_init as nu
    ny=1,  # passed to ss_init as ny
    seq_len=1500,  # samples kept per run once skip_sim has been dropped
    skip_sim=500,  # leading samples of each simulated run dropped in preproc_batch
    fs=750.0,  # sampling frequency (the simulation step is 1 / (fs * upsamp))
    fh=150,  # highest frequency
    upsamp=20,  # upsampling for integration; undone by decimation in preproc_batch
    input_scale=50,  # input amplitude passed to the signal generator, also used to normalize u
    output_scale=7e-4,  # divisor used to normalize y

    # Base learner
    nx=3,  # state dimension of the learned model (size of the zero initial state)
    hidden_f=16,  # NOTE(review): not referenced in the cells shown here — confirm where it is consumed
    hidden_g=16,  # NOTE(review): not referenced in the cells shown here — confirm where it is consumed

    # Inner Loop Optimization
    alpha=0.1,  # step size of the inner gradient-descent updates
    inner_iters=10,  # number of inner gradient-descent updates

    # Optimization
    batch_size=128,  # systems sampled at each meta optimization step
    iters=200_000,  # number of meta optimization steps
    lr=2e-4,  # Adam learning rate
    clip=1.0,  # threshold handed to optax.clip
    warmup_iters=0,  # presumably read by the (currently disabled) learning-rate schedule — TODO confirm
    skip_loss=500,  # skipped from the loss computation, to avoid a more advanced handling of the initial condition
    same_sys=10,  # consecutive meta steps that reuse the same batch of systems
)

In [None]:
# Optional experiment tracking: wandb is only imported when logging is enabled,
# and every later use of it is guarded by the same flag.
if cfg.log_wandb:
    import wandb

    # Store the hyperparameters (cfg as a plain dict) as run metadata.
    wandb.init(project="sysid-parametric-meta", config=vars(cfg))

In [None]:
# Reproducibility: every random stream is derived from this single seed.
seed = 12345
key = jr.key(seed)
# Four independent sub-keys. In the cells shown here dec_key initializes the
# model and data_key drives the data generator; proj_key and train_key are
# not referenced.
(dec_key, proj_key, data_key, train_key) = jr.split(key, num=4)

In [None]:
# Meta dataset definition
# Trajectories are simulated on a time grid `cfg.upsamp` times finer than the
# target sampling frequency `cfg.fs`; preproc_batch later decimates them back.
fs_up = cfg.fs * cfg.upsamp  # sampling frequency of the simulation grid
ts_up = 1.0 / fs_up  # simulation time step (not referenced in the cells shown here)
N = cfg.seq_len + cfg.skip_sim  # samples per run at the target rate, before skip_sim is dropped
N_up = N * cfg.upsamp  # samples per run on the simulation grid

# Input generator with sequence length, sampling frequency, fh and amplitude already bound.
input_fn = partial(multisine_signal, seq_len=N_up, fs=fs_up, fh=cfg.fh, scale=cfg.input_scale)
# Jitted simulator bound to the dynamics `dyn.f_xu` (module dataset.dynamics.boucwen).
simulate_fn = jax.jit(partial(simulate, f_xu=dyn.f_xu))
# NOTE: this rebinds the imported name `generate_batch` to a partial of itself, so
# re-running the cell wraps the previous partial again (with the same keywords).
generate_batch = partial(
    generate_batch,
    init_fn=dyn.init_fn,  # random initial state
    input_fn=input_fn,  # random input
    params_fn=dyn.params_fn,  # random system parameters
    simulate_fn=simulate_fn,  # simulation function
)


def generate_batches(key, batch_size=cfg.batch_size, K=cfg.K):
    """Endless generator of freshly simulated batches.

    Args:
        key: JAX random key; it is split once per batch so that every batch is
            drawn with a different sub-key.
        batch_size: Forwarded to ``generate_batch`` as ``systems``. The default
            is read from ``cfg`` when this function is defined.
        K: Forwarded to ``generate_batch`` as ``runs``. The default is read
            from ``cfg`` when this function is defined.

    Yields:
        Whatever the jitted ``generate_batch`` returns for each sub-key.
    """
    # Bind the batch layout once so the jitted sampler is built a single time.
    sample_batch = jax.jit(partial(generate_batch, systems=batch_size, runs=K))
    rng = key
    while True:
        rng, batch_key = jr.split(rng)
        yield sample_batch(batch_key)


def preproc_batch(batch):
    """Turn a raw simulated batch into normalized, decimated sequence pairs.

    Args:
        batch: Tuple ``(u, x, t, params)``; only ``u`` and ``x`` are used. The
            indexing below implies a (system, run, time, channel) layout with
            at least two runs per system — TODO confirm against ``generate_batch``.

    Returns:
        ``(y1, u1, y2, u2)``: output and input of run 0, then output and input
        of run 1, each with the first ``cfg.skip_sim`` samples removed.
    """
    inputs, states, _times, _sys_params = batch
    # The output is the first state component; the list index keeps the channel axis.
    outputs = states[..., [0]]

    # Normalize both signals with the fixed scales from the configuration.
    inputs /= cfg.input_scale
    outputs /= cfg.output_scale

    if cfg.upsamp > 1:
        # Bring the finely simulated signals back to the target rate (time is axis -2).
        inputs, outputs = (
            scipy.signal.decimate(sig, q=cfg.upsamp, axis=-2) for sig in (inputs, outputs)
        )

    # Discard the first cfg.skip_sim samples of every run.
    keep = slice(cfg.skip_sim, None)
    return outputs[:, 0, keep], inputs[:, 0, keep], outputs[:, 1, keep], inputs[:, 1, keep]

In [None]:
# Initialize data loader
train_dl = generate_batches(data_key)
# Draw and preprocess one batch up front; it is reused by the smoke test below.
batch = next(train_dl)
batch_y1, batch_u1, batch_y2, batch_u2 = preproc_batch(batch)
# Show the four array shapes as the cell output.
tuple(arr.shape for arr in (batch_y1, batch_u1, batch_y2, batch_u2))

In [None]:
# Initialize state-space model
# NOTE(review): cfg.hidden_f / cfg.hidden_g are not passed to ss_init here —
# confirm how ss_init chooses the hidden sizes.
params_ss = ss_init(dec_key, nu=cfg.nu, ny=cfg.ny, nx=cfg.nx)
# Flat parameter vector plus the function that restores the pytree structure.
params_ss_flat, unflatten_dec = jax.flatten_util.ravel_pytree(params_ss)
n_params = params_ss_flat.shape[0]  # total number of scalar parameters
# Fixed scaling constants handed to ss_apply next to the parameters; they are not
# part of the trained params. NOTE(review): the meaning of the "lin"/"nl" entries
# is defined in neuralss — confirm there.
scalers = {"f": {"lin": 1e-2, "nl": 1e-2}, "g": {"lin": 1e0, "nl": 1e0}}

In [None]:
# Mean Squared Error loss function
def mse_loss_fn(p, y, u):
    """Mean squared simulation error of the model on a single sequence.

    Args:
        p: Model parameters passed to ``ss_apply``.
        y: Measured output sequence (time along the first axis).
        u: Input sequence fed to the model.

    Returns:
        Scalar mean of the squared error, computed after dropping the first
        ``cfg.skip_loss`` samples.
    """
    # The model always starts from a zero state.
    initial_state = jnp.zeros((cfg.nx,))
    y_sim = ss_apply(p, scalers, initial_state, u)
    residual = (y - y_sim)[cfg.skip_loss:]
    return jnp.mean(residual ** 2)

In [None]:
# Inner update function (GD) on a single instance
def inner_update_fn(p, y, u, alpha=0.1, iters=1):
    """Adapt parameters to one sequence with plain gradient descent.

    Args:
        p: Initial parameter pytree.
        y: Output sequence used for adaptation.
        u: Input sequence used for adaptation.
        alpha: Gradient-descent step size.
        iters: Number of gradient steps; the Python loop is unrolled.

    Returns:
        The parameter pytree after ``iters`` steps (``p`` itself if ``iters`` is 0).
    """
    loss_grad = jax.grad(mse_loss_fn)

    def gd_step(g, w):
        # One gradient-descent step on a single leaf.
        return w - alpha * g

    adapted = p
    for _ in range(iters):
        adapted = jax.tree_util.tree_map(gd_step, loss_grad(adapted, y, u), adapted)
    return adapted

# Meta loss (MAML) for one instance
def instance_loss_fn(p1, y1, u1, y2, u2):
    """Loss on the second sequence after adapting on the first one.

    Args:
        p1: Meta-parameters, i.e. the starting point of the inner loop.
        y1, u1: Output/input sequence used for the inner adaptation.
        y2, u2: Output/input sequence on which the adapted model is scored.

    Returns:
        Scalar MSE of the adapted parameters on ``(y2, u2)``.
    """
    adapted = inner_update_fn(p1, y1, u1, alpha=cfg.alpha, iters=cfg.inner_iters)
    return mse_loss_fn(adapted, y2, u2)


# Smoke test: evaluate the meta loss on the first system of the batch drawn above.
instance_loss_fn(params_ss, batch_y1[0], batch_u1[0], batch_y2[0], batch_u2[0])

In [None]:
# batched loss
def loss_fn(*args):
    """Average meta loss over a batch of systems.

    Args:
        *args: ``(params, y1, u1, y2, u2)``. The parameters are shared across
            the batch; the four arrays are mapped over their leading axis.

    Returns:
        Scalar mean of the per-system meta losses.
    """
    batched_instance_loss = jax.vmap(instance_loss_fn, in_axes=(None, 0, 0, 0, 0))
    per_system = batched_instance_loss(*args)
    return jnp.mean(per_system)

In [None]:
# The learning-rate schedule is currently disabled; Adam runs with the constant cfg.lr.
#lr_scheduler = create_learning_rate_fn(cfg)

# Meta optimizer: gradient clipping followed by Adam.
# NOTE(review): optax.clip is presumably element-wise value clipping rather than
# global-norm clipping — confirm this is intended.
opt = optax.chain(
  optax.clip(cfg.clip),
  optax.adam(learning_rate=cfg.lr),
)
# The batched meta loss is stored as apply_fn so make_step can differentiate it via the state.
state = train_state.TrainState.create(apply_fn=loss_fn, params=params_ss, tx=opt)

@jax.jit
def make_step(state, y1, u1, y2, u2):
    """Run one jitted meta-optimization step.

    Args:
        state: Train state holding the meta-parameters, the optimizer state and
            the loss function (as ``apply_fn``).
        y1, u1: Batched output/input sequences used for the inner adaptation.
        y2, u2: Batched output/input sequences used for the meta loss.

    Returns:
        ``(loss, new_state)``: the meta loss evaluated at the incoming
        parameters and the state after applying the gradients. The caller
        decides whether to keep ``new_state``.
    """
    meta_loss_and_grad = jax.value_and_grad(state.apply_fn)
    loss, grads = meta_loss_and_grad(state.params, y1, u1, y2, u2)
    return loss, state.apply_gradients(grads=grads)

In [None]:
# Meta-training loop.
LOSS = []  # loss of every meta step, including steps whose update was rejected
loss = jnp.array(jnp.nan)

for itr in (pbar := tqdm(range(cfg.iters))):

    # Draw fresh systems only every cfg.same_sys steps (some speed up).
    if itr % cfg.same_sys == 0:
        batch = next(iter(train_dl))
        batch_y1, batch_u1, batch_y2, batch_u2 = preproc_batch(batch)

    loss, new_state = make_step(state, batch_y1, batch_u1, batch_y2, batch_u2)
    # Fetch the scalar from the device once and reuse it for the checks and logging below.
    loss_value = loss.item()

    # Accept the update only for a non-NaN loss below 2.0 (fix some instability
    # issues in training). The comparison is False for NaN, so it covers both cases.
    if loss_value < 2.0:
        state = new_state

    LOSS.append(loss_value)
    if itr % 10 == 0:
        pbar.set_postfix_str(f"loss:{loss_value:.4f}")

    if cfg.log_wandb:
        wandb.log({"loss": loss_value})

    # Periodic checkpoint of the accepted parameters and the loss history so far.
    if itr % 5000 == 0:
        ckpt = {
            "cfg": cfg,
            "params": state.params,
            "scalers": scalers,
            "LOSS": jnp.array(LOSS),
        }
        ckpt_path = Path("tmp") / f"maml_{itr}.p"
        ckpt_path.parent.mkdir(exist_ok=True, parents=True)
        # Context manager guarantees the file is flushed and closed.
        with open(ckpt_path, "wb") as ckpt_file:
            pickle.dump(ckpt, ckpt_file)

In [None]:
# Save the final checkpoint
ckpt = {
    "cfg": cfg,
    "params": state.params,
    "scalers": scalers,
    "LOSS": jnp.array(LOSS),
}

ckpt_path = Path("out") / "maml.p"
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
# Context manager guarantees the file is flushed and closed.
with open(ckpt_path, "wb") as ckpt_file:
    pickle.dump(ckpt, ckpt_file)

In [None]:
# Close the wandb run; only needed when logging was enabled in the configuration.
if cfg.log_wandb:
    wandb.finish()