In [1]:
from tqdm.auto import tqdm
import pickle
from functools import partial
import time
import warnings
from math import pi, inf
import os
import argparse
import json
import jax.debug as jdb

In [20]:
class args_class:
    def __init__(self, seed, M, pnorm_init, p_freq, output_dir, use_x64):
        self.seed = seed
        self.M = M
        self.pnorm_init = pnorm_init
        self.p_freq = p_freq
        self.output_dir = output_dir
        self.use_x64 = use_x64

In [37]:
# Use argparse inside of ipynb
args = args_class(0, 20, 2.2, 100, 'diagnose_training_files', True)

In [39]:
import jax                                          # noqa: E402
import jax.numpy as jnp                             # noqa: E402
from jax.example_libraries import optimizers             # noqa: E402
from dynamics import prior                          # noqa: E402
from utils import (tree_normsq, rk38_step, epoch,   # noqa: E402
                   odeint_fixed_step, random_ragged_spline, spline,
            params_to_cholesky, params_to_posdef)

import jax.debug as jdebug

def convert_p_qbar(p):
    return jnp.sqrt(1/(1 - 1/p) - 1.1)

def convert_qbar_p(qbar):
    return 1/(1 - 1/(1.1 + qbar**2))

# Initialize PRNG key
key = jax.random.PRNGKey(args.seed)

# Hyperparameters
hparams = {
    'seed':        args.seed,     #
    'use_x64':     args.use_x64,  #
    'num_subtraj': args.M,        # number of trajectories to sub-sample

    # For training the model ensemble
    'ensemble': {
        'num_hlayers':    2,     # number of hidden layers in each model
        'hdim':           32,    # number of hidden units per layer
        'train_frac':     0.75,  # fraction of each trajectory for training
        'batch_frac':     0.25,  # fraction of training data per batch
        'regularizer_l2': 1e-4,  # coefficient for L2-regularization
        'learning_rate':  1e-2,  # step size for gradient optimization
        'num_epochs':     1000,  # number of epochs
    },
    # For meta-training
    'meta': {
        'num_hlayers':       2,          # number of hidden layers
        'hdim':              32,         # number of hidden units per layer
        'train_frac':        0.75,       #
        'learning_rate':     1e-2,       # step size for gradient optimization
        'num_steps':         1500,        # maximum number of gradient steps
        'regularizer_l2':    1e-4,       # coefficient for L2-regularization
        'regularizer_ctrl':  1e-3,       #
        'regularizer_error': 0.,         #
        'T':                 5.,         # time horizon for each reference
        'dt':                1e-2,       # time step for numerical integration
        'num_refs':          10,         # reference trajectories to generate
        'num_knots':         6,          # knot points per reference spline
        'poly_orders':       (9, 9, 6),  # spline orders for each DOF
        'deriv_orders':      (4, 4, 2),  # smoothness objective for each DOF
        'min_step':          (-2., -2., -pi/6),    #
        'max_step':          (2., 2., pi/6),       #
        'min_ref':           (-inf, -inf, -pi/3),  #
        'max_ref':           (inf, inf, pi/3),     #
        'p_freq':            args.p_freq,          # frequency for p-norm update
    },
}

In [35]:
def convert_p_qbar(p):
    return jnp.sqrt(1/(1 - 1/p) - 1.1)

def convert_qbar_p(qbar):
    return 1/(1 - 1/(1.1 + qbar**2))

In [6]:
# DATA PROCESSING ########################################################
# Load raw data and arrange in samples of the form
# `(t, x, u, t_next, x_next)` for each trajectory, where `x := (q,dq)`
with open('training_data.pkl', 'rb') as file:
    raw = pickle.load(file)
num_dof = raw['q'].shape[-1]       # number of degrees of freedom
num_traj = raw['q'].shape[0]       # total number of raw trajectories
num_samples = raw['t'].size - 1    # number of transitions per trajectory
t = jnp.tile(raw['t'][:-1], (num_traj, 1))
t_next = jnp.tile(raw['t'][1:], (num_traj, 1))
x = jnp.concatenate((raw['q'][:, :-1], raw['dq'][:, :-1]), axis=-1)
x_next = jnp.concatenate((raw['q'][:, 1:], raw['dq'][:, 1:]), axis=-1)
u = raw['u'][:, :-1]
data = {'t': t, 'x': x, 'u': u, 't_next': t_next, 'x_next': x_next}

# Shuffle and sub-sample trajectories
if hparams['num_subtraj'] > num_traj:
    warnings.warn('Cannot sub-sample {:d} trajectories! '
                    'Capping at {:d}.'.format(hparams['num_subtraj'],
                                            num_traj))
    hparams['num_subtraj'] = num_traj

key, subkey = jax.random.split(key, 2)
shuffled_idx = jax.random.permutation(subkey, num_traj)
hparams['subtraj_idx'] = shuffled_idx[:hparams['num_subtraj']]
data = jax.tree_util.tree_map(
    lambda a: jnp.take(a, hparams['subtraj_idx'], axis=0),
    data
)

In [7]:
# MODEL ENSEMBLE TRAINING ################################################
# Loss function along a trajectory
def ode(x, t, u, params, prior=prior):
    """TODO: docstring."""
    num_dof = x.size // 2
    q, dq = x[:num_dof], x[num_dof:]
    H, C, g, B = prior(q, dq)

    # Each model in the ensemble is a feed-forward neural network
    # with zero output bias
    f = x
    for W, b in zip(params['W'], params['b']):
        f = jnp.tanh(W@f + b)
    f = params['A'] @ f
    ddq = jax.scipy.linalg.solve(H, B@u + f - C@dq - g, assume_a='pos')
    dx = jnp.concatenate((dq, ddq))
    return dx

def loss(params, regularizer, t, x, u, t_next, x_next, ode=ode):
    """TODO: docstring."""
    num_samples = t.size
    dt = t_next - t
    x_next_est = jax.vmap(rk38_step, (None, 0, 0, 0, 0, None))(
        ode, dt, x, t, u, params
    )
    loss = (jnp.sum((x_next_est - x_next)**2)
            + regularizer*tree_normsq(params)) / num_samples
    return loss

# Parallel updates for each model in the ensemble
@partial(jax.jit, static_argnums=(4, 5))
@partial(jax.vmap, in_axes=(None, 0, None, 0, None, None))
def step(idx, opt_state, regularizer, batch, get_params, update_opt,
            loss=loss):
    """TODO: docstring."""
    params = get_params(opt_state)
    grads = jax.grad(loss, argnums=0)(params, regularizer, **batch)
    opt_state = update_opt(idx, grads, opt_state)
    return opt_state

@jax.jit
@jax.vmap
def update_best_ensemble(old_params, old_loss, new_params, batch):
    """TODO: docstring."""
    new_loss = loss(new_params, 0., **batch)  # do not regularize
    best_params = jax.tree_util.tree_map(
        lambda x, y: jnp.where(new_loss < old_loss, x, y),
        new_params,
        old_params
    )
    best_loss = jnp.where(new_loss < old_loss, new_loss, old_loss)
    return best_params, best_loss, new_loss

# Initialize model parameters
num_models = hparams['num_subtraj']  # one model per trajectory
num_hlayers = hparams['ensemble']['num_hlayers']
hdim = hparams['ensemble']['hdim']
if num_hlayers >= 1:
    shapes = [(hdim, 2*num_dof), ] + (num_hlayers-1)*[(hdim, hdim), ]
else:
    shapes = []
key, *subkeys = jax.random.split(key, 1 + 2*num_hlayers + 1)
keys_W = subkeys[:num_hlayers]
keys_b = subkeys[num_hlayers:-1]
key_A = subkeys[-1]
ensemble = {
    # hidden layer weights
    'W': [0.1*jax.random.normal(keys_W[i], (num_models, *shapes[i]))
            for i in range(num_hlayers)],
    # hidden layer biases
    'b': [0.1*jax.random.normal(keys_b[i], (num_models, shapes[i][0]))
            for i in range(num_hlayers)],
    # last layer weights
    'A': 0.1*jax.random.normal(key_A, (num_models, num_dof, hdim))
}

# Shuffle samples in time along each trajectory, then split each
# trajectory into training and validation sets (i.e., for each model)
key, *subkeys = jax.random.split(key, 1 + num_models)
subkeys = jnp.asarray(subkeys)
shuffled_data = jax.tree_util.tree_map(
    lambda a: jax.vmap(jax.random.permutation)(subkeys, a),
    data
)
num_train_samples = int(hparams['ensemble']['train_frac'] * num_samples)
ensemble_train_data = jax.tree_util.tree_map(
    lambda a: a[:, :num_train_samples],
    shuffled_data
)
ensemble_valid_data = jax.tree_util.tree_map(
    lambda a: a[:, num_train_samples:],
    shuffled_data
)

# Initialize gradient-based optimizer (ADAM)
learning_rate = hparams['ensemble']['learning_rate']
batch_size = int(hparams['ensemble']['batch_frac'] * num_train_samples)
num_batches = num_train_samples // batch_size
init_opt, update_opt, get_params = optimizers.adam(learning_rate)
opt_states = jax.vmap(init_opt)(ensemble)
get_ensemble = jax.jit(jax.vmap(get_params))
step_idx = 0
best_idx = jnp.zeros(num_models)

# Pre-compile before training
print('ENSEMBLE TRAINING: Pre-compiling ... ', end='', flush=True)
start = time.time()
batch = next(epoch(key, ensemble_train_data, batch_size,
                    batch_axis=1, ragged=False))
_ = step(step_idx, opt_states, hparams['ensemble']['regularizer_l2'],
            batch, get_params, update_opt)
inf_losses = jnp.broadcast_to(jnp.inf, (num_models,))
best_ensemble, best_losses, _ = update_best_ensemble(ensemble,
                                                        inf_losses,
                                                        ensemble,
                                                        ensemble_valid_data)
_ = get_ensemble(opt_states)
end = time.time()
print('done ({:.2f} s)!'.format(end - start))

# Do gradient descent
for _ in tqdm(range(hparams['ensemble']['num_epochs'])):
    key, subkey = jax.random.split(key, 2)
    for batch in epoch(subkey, ensemble_train_data, batch_size,
                        batch_axis=1, ragged=False):
        opt_states = step(step_idx, opt_states,
                            hparams['ensemble']['regularizer_l2'],
                            batch, get_params, update_opt)
        new_ensemble = get_ensemble(opt_states)
        old_losses = best_losses
        best_ensemble, best_losses, valid_losses = update_best_ensemble(
            best_ensemble, best_losses, new_ensemble, batch
        )
        step_idx += 1
        best_idx = jnp.where(old_losses == best_losses,
                                best_idx, step_idx)

ENSEMBLE TRAINING: Pre-compiling ... done (1.27 s)!


  0%|          | 0/1000 [00:00<?, ?it/s]

In [40]:
from jax import config
config.update("jax_debug_nans", True)

# META-TRAINING ##########################################################
def ode(z, t, meta_params, pnorm_param, params, reference, prior=prior):
    """TODO: docstring."""
    x, pA, c = z
    num_dof = x.size // 2
    q, dq = x[:num_dof], x[num_dof:]
    r = reference(t)
    dr = jax.jacfwd(reference)(t)
    ddr = jax.jacfwd(jax.jacfwd(reference))(t)

    # Regressor features
    y = x
    for W, b in zip(meta_params['W'], meta_params['b']):
        y = jnp.tanh(W@y + b)

    # Parameterized control and adaptation gains
    gains = jax.tree_util.tree_map(
        lambda x: params_to_posdef(x),
        meta_params['gains']
    )
    Λ, K, P = gains['Λ'], gains['K'], gains['P']

    qn = 1.1 + pnorm_param['pnorm']**2

    A = (jnp.maximum(jnp.abs(pA), 1e-6 * jnp.ones_like(pA))**(qn-1) * jnp.sign(pA) * (jnp.ones_like(pA) - jnp.isclose(pA, 0, atol=1e-6))) @ P
    # A = (jnp.maximum(jnp.abs(pA), 1e-6 * jnp.ones_like(pA))**(qn-1) * jnp.sign(pA) * jnp.isclose(pA, 0, atol=1e-6)) @ P
    #A = pA

    # Auxiliary signals
    e, de = q - r, dq - dr
    v, dv = dr - Λ@e, ddr - Λ@de
    s = de + Λ@e

    # Controller and adaptation law
    H, C, g, B = prior(q, dq)
    f_hat = A@y
    τ = H@dv + C@v + g - f_hat - K@s
    u = jnp.linalg.solve(B, τ)
    dpA = jnp.outer(s, y) @ P

    # Apply control to "true" dynamics
    f = x
    for W, b in zip(params['W'], params['b']):
        f = jnp.tanh(W@f + b)
    f = params['A'] @ f
    ddq = jax.scipy.linalg.solve(H, τ + f - C@dq - g, assume_a='pos')
    dx = jnp.concatenate((dq, ddq))

    # Estimation loss
    # chol_P = params_to_cholesky(meta_params['gains']['P'])
    # f_error = f_hat - f
    # loss_est = f_error@jax.scipy.linalg.cho_solve((chol_P, True),
    #                                               f_error)

    # Integrated cost terms
    dc = jnp.array([
        e@e + de@de,                # tracking loss
        u@u,                        # control loss
        (f_hat - f)@(f_hat - f),    # estimation loss
    ])

    # Assemble derivatives
    dz = (dx, dpA, dc)
    return dz

# Simulate adaptive control loop on each model in the ensemble
def ensemble_sim(meta_params, pnorm_param, ensemble_params, reference, T, dt, ode=ode):
    """TODO: docstring."""
    # Initial conditions
    r0 = reference(0.)
    dr0 = jax.jacfwd(reference)(0.)
    num_dof = r0.size
    num_features = meta_params['W'][-1].shape[0]
    x0 = jnp.concatenate((r0, dr0))
    A0 = jnp.zeros((num_dof, num_features))
    c0 = jnp.zeros(3)
    z0 = (x0, A0, c0)

    # Integrate the adaptive control loop using the meta-model
    # and EACH model in the ensemble along the same reference
    in_axes = (None, None, None, None, None, None, None, 0)
    ode = partial(ode, reference=reference)
    z, t = jax.vmap(odeint_fixed_step, in_axes)(ode, z0, 0., T, dt,
                                                meta_params, pnorm_param,
                                                ensemble_params)
    x, A, c = z
    return t, x, A, c

# Initialize meta-model parameters
num_hlayers = hparams['meta']['num_hlayers']
hdim = hparams['meta']['hdim']
if num_hlayers >= 1:
    shapes = [(hdim, 2*num_dof), ] + (num_hlayers-1)*[(hdim, hdim), ]
else:
    shapes = []
key, *subkeys = jax.random.split(key, 1 + 2*num_hlayers + 3)
subkeys_W = subkeys[:num_hlayers]
subkeys_b = subkeys[num_hlayers:-3]
subkeys_gains = subkeys[-3:]
meta_params = {
    # hidden layer weights
    'W': [0.1*jax.random.normal(subkeys_W[i], shapes[i])
            for i in range(num_hlayers)],
    # hidden layer biases
    'b': [0.1*jax.random.normal(subkeys_b[i], (shapes[i][0],))
            for i in range(num_hlayers)],
    'gains': {  # vectorized control and adaptation gains
        'Λ': 0.1*jax.random.normal(subkeys_gains[0],
                                    ((num_dof*(num_dof + 1)) // 2,)),
        'K': 0.1*jax.random.normal(subkeys_gains[1],
                                    ((num_dof*(num_dof + 1)) // 2,)),
        'P': 0.1*jax.random.normal(subkeys_gains[2],
                                    ((hdim*(hdim + 1)) // 2,)),
    },
}
# In the bash script, we always specify p-norm desried initial values
# Note that the program always uses the q_bar parameter as the p-norm parameterization
# However, the printing function should log the final results in p-norm
pnorm_param = {'pnorm': convert_p_qbar(args.pnorm_init)}
print("Initialize pnorm as {:.2f}".format(convert_qbar_p(pnorm_param['pnorm'])))

# Initialize spline coefficients for each reference trajectory
num_refs = hparams['meta']['num_refs']
key, *subkeys = jax.random.split(key, 1 + num_refs)
subkeys = jnp.vstack(subkeys)
in_axes = (0, None, None, None, None, None, None, None, None)
min_ref = jnp.asarray(hparams['meta']['min_ref'])
max_ref = jnp.asarray(hparams['meta']['max_ref'])
t_knots, knots, coefs = jax.vmap(random_ragged_spline, in_axes)(
    subkeys,
    hparams['meta']['T'],
    hparams['meta']['num_knots'],
    hparams['meta']['poly_orders'],
    hparams['meta']['deriv_orders'],
    jnp.asarray(hparams['meta']['min_step']),
    jnp.asarray(hparams['meta']['max_step']),
    0.7*min_ref,
    0.7*max_ref,
)
# x_coefs, y_coefs, θ_coefs = coefs
# x_knots, y_knots, θ_knots = knots
r_knots = jnp.dstack(knots)

# Simulate the adaptive control loop for each model in the ensemble and
# each reference trajectory (i.e., spline coefficients)
@partial(jax.vmap, in_axes=(None, None, None, 0, 0, None, None))
def simulate(meta_params, pnorm_param, ensemble_params, t_knots, coefs, T, dt, min_ref=min_ref, max_ref=max_ref):
    """TODO: docstring."""
    # Define a reference trajectory in terms of spline coefficients
    def reference(t):
        r = jnp.array([spline(t, t_knots, c) for c in coefs])
        r = jnp.clip(r, min_ref, max_ref)
        return r
    t, x, A, c = ensemble_sim(meta_params, pnorm_param, ensemble_params,
                                reference, T, dt)
    return t, x, A, c

@partial(jax.jit, static_argnums=(5, 6))
def loss(meta_params, pnorm_param, ensemble_params, t_knots, coefs, T, dt,
            regularizer_l2, regularizer_ctrl, regularizer_error):
    """TODO: docstring."""
    # Simulate on each model for each reference trajectory
    t, x, A, c = simulate(meta_params, pnorm_param, ensemble_params, t_knots,
                            coefs, T, dt)

    # Sum final costs over reference trajectories and ensemble models
    # Note `c` has shape (`num_refs`, `num_models`, `T // dt`, 3)
    c_final = jnp.sum(c[:, :, -1, :], axis=(0, 1))

    # Form a composite loss by weighting the different cost integrals,
    # and normalizing by the number of models, number of reference
    # trajectories, and time horizon
    num_refs = c.shape[0]
    num_models = c.shape[1]
    normalizer = T * num_refs * num_models
    tracking_loss, control_loss, estimation_loss = c_final
    reg_gain = jnp.linalg.norm(meta_params['gains']['P'])**2
    l2_penalty = tree_normsq((meta_params['W'], meta_params['b']))
    loss = (tracking_loss
            + regularizer_ctrl*control_loss
            + regularizer_error*estimation_loss
            + regularizer_l2*l2_penalty
            + 5 * reg_gain) / normalizer
    aux = {
        # for each model in ensemble
        'tracking_loss':   jnp.sum(c[:, :, -1, 0], axis=0) / num_refs,
        'control_loss':    jnp.sum(c[:, :, -1, 1], axis=0) / num_refs,
        'estimation_loss': jnp.sum(c[:, :, -1, 2], axis=0) / num_refs,
        'l2_penalty':      l2_penalty,
        'eigs_Λ':
            jnp.diag(params_to_cholesky(meta_params['gains']['Λ']))**2,
        'eigs_K':
            jnp.diag(params_to_cholesky(meta_params['gains']['K']))**2,
        'eigs_P':
            jnp.diag(params_to_cholesky(meta_params['gains']['P']))**2,
        'pnorm': pnorm_param['pnorm']
    }
    return loss, aux

# Shuffle and split ensemble into training and validation sets
train_frac = hparams['meta']['train_frac']
num_train_models = int(train_frac * num_models)
key, subkey = jax.random.split(key, 2)
model_idx = jax.random.permutation(subkey, num_models)
train_model_idx = model_idx[:num_train_models]
valid_model_idx = model_idx[num_train_models:]
train_ensemble = jax.tree_util.tree_map(lambda x: x[train_model_idx],
                                        best_ensemble)
valid_ensemble = jax.tree_util.tree_map(lambda x: x[valid_model_idx],
                                        best_ensemble)

# Split reference trajectories into training and validation sets
num_train_refs = int(train_frac * num_refs)
train_t_knots = jax.tree_util.tree_map(lambda a: a[:num_train_refs],
                                        t_knots)
train_coefs = jax.tree_util.tree_map(lambda a: a[:num_train_refs], coefs)
valid_t_knots = jax.tree_util.tree_map(lambda a: a[num_train_refs:],
                                        t_knots)
valid_coefs = jax.tree_util.tree_map(lambda a: a[num_train_refs:], coefs)

# Initialize gradient-based optimizer (ADAM)
learning_rate = hparams['meta']['learning_rate']
init_opt, update_opt, get_params = optimizers.adam(learning_rate)
# Update meta_params and pnorm_param separately
opt_meta = init_opt(meta_params)
opt_pnorm = init_opt(pnorm_param)
step_meta_idx = 0
step_pnorm_idx = 0
best_idx_meta = 0
best_idx_pnorm = 0
best_loss = jnp.inf
best_meta_params = meta_params
best_pnorm_param = pnorm_param

@partial(jax.jit, static_argnums=(6, 7))
def step_meta(idx, opt_state, pnorm_param, ensemble_params, t_knots, coefs, T, dt, regularizer_l2, regularizer_ctrl, regularizer_error):
    """This function only updates the meta_params in an iteration"""
    meta_params = get_params(opt_state)
    grads, aux = jax.grad(loss, argnums=0, has_aux=True)(
        meta_params, pnorm_param, ensemble_params, t_knots, coefs, T, dt,
        regularizer_l2, regularizer_ctrl, regularizer_error
    )
    #isnan_Wgrad = jnp.any(jnp.isnan(grads['W'][0]))
    #isnan_bgrad = jnp.any(jnp.isnan(grads['b'][0]))
    # jdb.print('{isnan_Wgrad}', isnan_Wgrad=isnan_Wgrad)
    # jdb.print('{isnan_bgrad}', isnan_bgrad=isnan_bgrad)
    # if isnan_Wgrad:
    #     jdb.print("W[0] gradient is nan!")
    #     raise ValueError("W[0] gradient is nan!")
    # elif isnan_bgrad:
    #     jdb.print("b[0] gradient is nan!")
    #     raise ValueError("b[0] gradient is nan!")
    opt_state = update_opt(idx, grads, opt_state)
    return opt_state, aux, grads

# @partial(jax.jit, static_argnums=(6, 7))
def step_pnorm(idx, meta_params, opt_state, ensemble_params, t_knots, coefs, T, dt, regularizer_l2, regularizer_ctrl, regularizer_error):
    """This function only updates the meta_params in an iteration"""
    pnorm_param = get_params(opt_state)
    grads, aux = jax.grad(loss, argnums=1, has_aux=True)(
        meta_params, pnorm_param, ensemble_params, t_knots, coefs, T, dt,
        regularizer_l2, regularizer_ctrl, regularizer_error
    )
    opt_state = update_opt(idx, grads, opt_state)
    return opt_state, aux, grads

# Pre-compile before training
print('META-TRAINING: Pre-compiling ... ', end='', flush=True)
dt = hparams['meta']['dt']
T = hparams['meta']['T']
regularizer_l2 = hparams['meta']['regularizer_l2']
regularizer_ctrl = hparams['meta']['regularizer_ctrl']
regularizer_error = hparams['meta']['regularizer_error']
start = time.time()
_ = step_meta(0, opt_meta, pnorm_param, train_ensemble, train_t_knots, train_coefs, T, dt, regularizer_l2, regularizer_ctrl, regularizer_error)
_ = step_pnorm(0, meta_params, opt_pnorm, train_ensemble, train_t_knots, train_coefs, T, dt, regularizer_l2, regularizer_ctrl, regularizer_error)
_ = loss(meta_params, pnorm_param, valid_ensemble, valid_t_knots, valid_coefs, T, dt,
            0., 0., 0.)
end = time.time()
print('done ({:.2f} s)! Now training ...'.format(
        end - start))
start = time.time()

output_name = "seed={:d}_M={:d}_pinit={:f}_pfreq={:f}_reg_gain".format(hparams['seed'], num_models, args.pnorm_init, hparams['meta']['p_freq'])
print(output_name)

# Do gradient descent
for i in tqdm(range(hparams['meta']['num_steps'])):
    opt_meta, train_aux_meta, grads_meta = step_meta(
        step_meta_idx, opt_meta, pnorm_param, train_ensemble, train_t_knots, train_coefs,
        T, dt, regularizer_l2, regularizer_ctrl, regularizer_error
    )
    # print(train_aux_meta)
    new_meta_params = get_params(opt_meta)

    # Update p-norm parameter
    # The i+1 is to make sure not to update p-norm at step 0
    if (i+1) % hparams['meta']['p_freq'] == 0:
        opt_pnorm, train_aux_pnorm, grads_pnorm = step_pnorm(
            step_pnorm_idx, new_meta_params, opt_pnorm, train_ensemble, train_t_knots, train_coefs,
            T, dt, regularizer_l2, regularizer_ctrl, regularizer_error
        )
        step_pnorm_idx += 1
        # jdebug.print('{grad_pnorm}', grad_pnorm=grads_pnorm)
        # jdebug.print('{meta_grad}', meta_grad=grads_meta)
        print("Update p-norm to {:.2f} at step {:d}".format(convert_qbar_p(get_params(opt_pnorm)['pnorm']), step_meta_idx))
    new_pnorm_param = get_params(opt_pnorm)

    valid_loss, valid_aux = loss(
        new_meta_params, new_pnorm_param, valid_ensemble, valid_t_knots, valid_coefs,
        T, dt, 0., 0., 0.
    )
    #jdb.print(f'{valid_loss}')

    # Only update best_meta_params when the loss is decreasing
    if valid_loss < best_loss:
        best_meta_params = new_meta_params
        best_pnorm_param = new_pnorm_param
        best_loss = valid_loss
        best_idx_meta = step_meta_idx
        best_idx_pnorm = step_pnorm_idx
    step_meta_idx += 1
    # print(step_meta_idx)

# Save hyperparameters, ensemble, model, and controller

results = {
    'best_step_meta': best_idx_meta,
    'best_step_pnorm': best_idx_pnorm,
    'hparams': hparams,
    'ensemble': best_ensemble,
    'model': {
        'W': best_meta_params['W'],
        'b': best_meta_params['b'],
    },
    'controller': best_meta_params['gains'],
    'pnorm': convert_qbar_p(best_pnorm_param['pnorm']),
}
output_path = os.path.join(args.output_dir, output_name + '.pkl')
with open(output_path, 'wb') as file:
    pickle.dump(results, file)

end = time.time()
print("Meta-training completes with p-norm chosen as {:.2f}".format(results['pnorm']))
print('done ({:.2f} s)! Best step index for meta params: {}'.format(end - start, best_idx_meta))

Initialize pnorm as 2.20
META-TRAINING: Pre-compiling ... done (12.28 s)! Now training ...
seed=0_M=10_pinit=2.200000_pfreq=100.000000_reg_gain


  0%|          | 0/1500 [00:00<?, ?it/s]

Update p-norm to 2.23 at step 99
Update p-norm to 2.25 at step 199
Update p-norm to 2.28 at step 299
Update p-norm to 2.30 at step 399
Update p-norm to 2.32 at step 499
Update p-norm to 2.34 at step 599
Update p-norm to 2.36 at step 699
Update p-norm to 2.38 at step 799
Update p-norm to 2.40 at step 899
Update p-norm to 2.41 at step 999
Update p-norm to 2.42 at step 1099
Update p-norm to 2.43 at step 1199
Update p-norm to 2.44 at step 1299
Update p-norm to 2.45 at step 1399
Update p-norm to 2.46 at step 1499
Meta-training completes with p-norm chosen as 2.46
done (1135.59 s)! Best step index for meta params: 1499
