In [352]:
import argparse
import os
import pickle
import time
from itertools import product, islice
import math

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
import jax.debug as jdb

import numpy as np

from tqdm.auto import tqdm

from dynamics import prior, disturbance, plant
from utils import params_to_posdef
from utils import random_ragged_spline, spline
from utils import (tree_normsq, rk38_step, epoch,   # noqa: E402
                   odeint_fixed_step, odeint_ckpt, random_ragged_spline, spline,
            params_to_cholesky, params_to_posdef, vee, hat,
            quaternion_to_rotation_matrix, flat_rotation_matrix_to_quaternion, 
            quaternion_multiply)

from functools import partial
import matplotlib.pyplot as plt
import csv


# Large default font so figures stay legible when embedded in slides/papers.
plt.rcParams['font.size'] = 24

In [353]:
def convert_p_qbar(p):
    """Map a p-norm exponent `p` to the unconstrained parameter `qbar`.

    `q = 1 / (1 - 1/p)` is the Hölder conjugate of `p` (so 1/p + 1/q = 1), and
    `qbar` is defined by `q = 1.1 + qbar**2`, which is the exponent the adaptation
    law in `test_simulate` rebuilds from `params['pnorm']`. Real-valued only
    for `1 < p <= 11`; outside that range the square root yields NaN.
    """
    conjugate = 1/(1 - 1/p)
    return np.sqrt(conjugate - 1.1)

def convert_qbar_p(qbar):
    """Inverse of `convert_p_qbar`: recover the p-norm exponent from `qbar`.

    Builds the dual exponent `q = 1.1 + qbar**2` and returns its Hölder
    conjugate `p = 1 / (1 - 1/q)`.
    """
    q = 1.1 + qbar**2
    return 1/(1 - 1/q)

In [354]:
# Support functions for generating loop reference trajectory
# def reference(t):
#     T = 5.            # loop period
#     d = 2.             # displacement along `x` from `t=0` to `t=T`
#     w = 2.             # loop width
#     h = 3.             # loop height
#     ϕ_max = jnp.pi/3   # maximum roll angle (achieved at top of loop)

#     x = (w/2)*jnp.sin(2*jnp.pi * t/T) + d*(t/T)
#     y = (h/2)*(1 - jnp.cos(2*jnp.pi * t/T))
#     ϕ = 4*ϕ_max*(t/T)*(1-t/T)
#     r = jnp.array([x, y, ϕ])
#     return r

def odeint_fixed_step(func, x0, t0, t1, step_size, *args, num_steps=10):
    """Integrate `func` from `t0` to `t1` on a uniform grid of `num_steps` intervals.

    This definition shadows the `odeint_fixed_step` imported from `utils`.

    Args:
        func: ODE right-hand side, called by `odeint_ckpt` as `func(x, t, *args)`
            (presumably; see `ode` in `test_simulate`).
        x0: initial state (any pytree accepted by `odeint_ckpt`).
        t0, t1: start and end times. Inside `test_simulate`'s `lax.scan` loop
            these are traced values.
        step_size: accepted for signature compatibility but IGNORED. Because
            `t0`/`t1` may be traced, `(t1 - t0) / step_size` cannot be turned
            into a static step count at trace time.
        *args: extra arguments forwarded to `func` via `odeint_ckpt`.
        num_steps: static (Python int) number of grid intervals.

    Returns:
        (xs, ts): the output of `odeint_ckpt` on the grid (the caller takes
        `xs[...][-1]` as the state at `t1`), and the grid `ts` itself of
        length `num_steps + 1`.
    """
    del step_size  # intentionally unused; see docstring
    ts = jnp.linspace(t0, t1, num_steps + 1)
    xs = odeint_ckpt(func, x0, ts, *args)
    return xs, ts

In [355]:
# Axis-aligned bounds: the reference position in `test_simulate` is clipped to
# these, and `eval_single_model` uses them to bound random spline knots.
xmin_, xmax_ = -4.25, 4.5
ymin_, ymax_ = -3.5, 4.25
zmin_, zmax_ = 0.0, 2.0

# Diagonal body inertia matrix used by `plant_attitude` (units not stated here).
J = jnp.diag(jnp.asarray([0.03, 0.03, 0.09]))

def plant_attitude(R_flatten, Omega, u_d, params):
    """Rotational dynamics under a geometric (SO(3)) attitude controller.

    The desired attitude is built from the commanded force `u_d`. Its third
    column is `u_d / |u_d|`, and its first column is steered toward the world
    x-axis `[1, 0, 0]`. The desired body rate and its derivative are zero.
    The control moment is
    `M = -k_R e_R - k_Omega e_Omega + Omega x (J Omega) - J (...)`.
    It is applied through Euler's equation `J dOmega = M - Omega x (J Omega)`.

    Args:
        R_flatten: current rotation matrix, flattened (reshaped to 3x3 here).
        Omega: current body angular velocity, a 3-vector.
        u_d: desired force 3-vector. The result is NaN if `u_d` is zero or
            parallel to the world x-axis (normalisation by a zero norm).
        params: dict providing gains 'k_R' and 'k_Omega'. They are multiplied
            elementwise, so scalars or length-3 arrays both work.

    Returns:
        (dR_flatten, dOmega): time derivative of the flattened rotation
        (`R @ hat(Omega)`) and of the body angular velocity.
    """
    gain_R = params['k_R']
    gain_Omega = params['k_Omega']
    rot = R_flatten.reshape((3, 3))

    # Desired frame: z along the commanded force, y orthogonal to both z and
    # world x, and x = y cross z completing the right-handed triad.
    z_des = u_d / jnp.linalg.norm(u_d)
    x_hint = jnp.array([1, 0, 0])
    y_raw = jnp.cross(z_des, x_hint)
    y_des = y_raw / jnp.linalg.norm(y_raw)
    rot_des = jnp.column_stack((jnp.cross(y_des, z_des), y_des, z_des))

    # Regulation about the desired frame: zero desired rate and rate derivative.
    omega_des = jnp.array([0, 0, 0])
    omega_des_dot = jnp.array([0, 0, 0])

    # Attitude and angular-velocity tracking errors.
    err_R = 0.5 * vee(rot_des.T@rot - rot.T@rot_des)
    err_Omega = Omega - rot.T@rot_des@omega_des

    # Gyroscopic term: added to the moment here and removed again in Euler's
    # equation below, so the two cancel in dOmega.
    gyro = jnp.cross(Omega, J@Omega)
    moment = (- gain_R*err_R
              - gain_Omega*err_Omega
              + gyro
              - J@(hat(Omega)@rot.T@rot_des@omega_des - rot.T@rot_des@omega_des_dot))

    # J is symmetric positive definite (positive diagonal), hence assume_a='pos'.
    dOmega = jax.scipy.linalg.solve(J, moment - gyro, assume_a='pos')
    dR_flatten = (rot@hat(Omega)).flatten()
    return (dR_flatten, dOmega)

In [356]:
@partial(jax.vmap, in_axes=(None, None, 0, 0))
def test_simulate(ts, params, t_knots, coefs, NGD_flag=True, controller_type='adaptive',
                plant=plant, prior=prior, ):
    """Closed-loop simulation of the adaptive controller along a spline reference.

    Wrapped in `jax.vmap` over `t_knots` and `coefs` (axis 0), so one simulation
    runs per reference trajectory while `ts` and `params` are shared. Note that
    `jax.vmap` maps keyword arguments over axis 0, so in practice the defaults of
    the trailing parameters are what get used.

    Args:
        ts: 1-D array of sample times. `ts[0]` is the initial time, and the
            scan loop steps through `ts[1:]`.
        params: dict of parameters. Keys read here: 'W', 'b' (regressor
            network), 'Λ', 'K', 'P', 'ensemble' (with 'W', 'b', 'A'), and 'pnorm'
            when `NGD_flag` is true. `plant_attitude` additionally reads
            'k_R' and 'k_Omega'.
        t_knots, coefs: spline knot times and (x, y, z) coefficient sets for one
            reference trajectory.
        NGD_flag: if true, integrate the adaptation law in the dual variable `pA`
            and map it to `A` with exponent `qn = 1.1 + pnorm**2`. Otherwise
            integrate `A` directly.
        controller_type: only 'adaptive' is implemented.
        plant, prior: translational dynamics and nominal model callables.

    Returns:
        dict of time histories, each with `len(ts)` samples (initial condition
        prepended): 'q', 'dq', 'R_flatten', 'Omega', 'u', 'u_d', 'τ', 'r', 'dr',
        'R_d_flatten', 'f_hat', 'f_ext', 'y', and 'A' (flattened per step).

    Raises:
        ValueError: if `controller_type` is not 'adaptive'.
    """
    def reference(t):
        """Spline reference position (x, y, z), with z offset by +1 and clipped to bounds."""
        x_coefs, y_coefs, z_coefs = coefs
        x = spline(t, t_knots, x_coefs)
        y = spline(t, t_knots, y_coefs)
        z = spline(t, t_knots, z_coefs) + 1.
        x = jnp.clip(x, xmin_, xmax_)
        y = jnp.clip(y, ymin_, ymax_)
        z = jnp.clip(z, zmin_, zmax_)
        r = jnp.array([x, y, z])
        return r

    def ref_derivatives(t):
        """Reference position, velocity and acceleration via forward-mode autodiff."""
        ref_vel = jax.jacfwd(reference)
        ref_acc = jax.jacfwd(ref_vel)
        r = reference(t)
        dr = ref_vel(t)
        ddr = ref_acc(t)
        return r, dr, ddr

    def ensemble_disturbance(q, dq, R_flatten, Omega, W, b, A):
        """Disturbance force applied to the plant, from member `test_i` of the ensemble.

        A tanh network over the full state `(q, dq, R_flatten, Omega)` followed by
        a linear read-out `A`, scaled by 5. Every layer is indexed with `[test_i]`.
        Presumably the leading axis enumerates ensemble members.
        """
        test_i = 0
        f_ext = jnp.concatenate((q, dq, R_flatten, Omega), axis=0)
        for W_layer, b_layer in zip(W, b):
            f_ext = jnp.tanh(W_layer[test_i]@f_ext + b_layer[test_i])
        f_ext = A[test_i] @ f_ext
        return 5*f_ext

    def adaptation_law(q, dq, R_flatten, Omega, r, dr, params=params, NGD_flag=NGD_flag):
        """Return the adaptation rate `dA` and the regressor features `y`.

        `y` is the learned tanh network (params 'W', 'b') evaluated on the full
        state. `s = de + Λ e` is the composite tracking error. With `NGD_flag` the
        gain `P` multiplies from the right (feature space); otherwise it
        multiplies from the left.
        """
        y = jnp.concatenate((q, dq, R_flatten, Omega))
        for W, b in zip(params['W'], params['b']):
            y = jnp.tanh(W@y + b)

        Λ, P = params['Λ'], params['P']
        e, de = q - r, dq - dr
        s = de + Λ@e

        if NGD_flag:
            dA = jnp.outer(s, y) @ P
        else:
            dA = P @ jnp.outer(s, y)
        return dA, y

    def controller(q, dq, r, dr, ddr, f_hat, params=params, ctrl_type='adaptive'):
        """Return the desired force `u_d` and generalized force `τ`.

        Adaptive law: `τ = H dv + C v + g - f_hat - K s` and `u_d = B^{-1} τ`,
        with `(H, C, g, B) = prior(q, dq)`, `v = dr - Λ e` and `s = de + Λ e`.
        """
        if ctrl_type != 'adaptive':
            # Fail loudly rather than implicitly returning None.
            raise ValueError(f"Unsupported controller_type: {ctrl_type!r}")

        Λ, K = params['Λ'], params['K']
        e, de = q - r, dq - dr
        s = de + Λ@e
        v, dv = dr - Λ@e, ddr - Λ@de

        H, C, g, B = prior(q, dq)
        τ = H@dv + C@v + g - f_hat - K@s
        u_d = jnp.linalg.solve(B, τ)
        return u_d, τ
        # Disabled PID variant, kept for reference:
        # elif type == 'pid':
        #     e, edot = r - q, dr - dq
        #     maxPosErr = jnp.array([1.0, 1.0, 1.5])
        #     maxVelErr = jnp.array([3.0, 3.0, 3.0])
        #     Kp = jnp.array([2.0, 2.0, 2.0])
        #     Kd = jnp.array([3.0, 3.0, 3.0])
        #     mass = 1.3
        #     e = jnp.minimum(jnp.maximum(e, -maxPosErr), maxPosErr)
        #     edot = jnp.minimum(jnp.maximum(edot, -maxVelErr), maxVelErr)
        #     a_fb = jnp.multiply(Kp, e) + jnp.multiply(Kd, edot)
        #     f_d = mass * (ddr + a_fb - jnp.array([0, 0, -9.80665]))
        #     xi = f_d / mass
        #     abc = xi / jnp.linalg.norm(xi)
        #     a, b, c = abc
        #     psi = 0
        #     invsqrt21pc = 1 / jnp.sqrt(2 * (1 + c))
        #     quaternion0 = jnp.array([invsqrt21pc*(1+c), invsqrt21pc*(-b), invsqrt21pc*a, 0.0])
        #     quaternion1 = jnp.array([jnp.cos(psi/2.), 0.0, 0.0, jnp.sin(psi/2.)])
        #     q_ref = quaternion_multiply(quaternion0, quaternion1)
        #     q_ref = q_ref / jnp.linalg.norm(q_ref)
        #     R_d = quaternion_to_rotation_matrix(q_ref)
        #     e_3 = jnp.array([0, 0, 1])
        #     u_d = f_d*R_d@e_3
        #     τ = u_d
        #     return u_d, τ

    def ode(x, t, u, u_d):
        """Closed-loop ODE for `x = (q, dq, R_flatten, Omega)` with `u`, `u_d` held fixed."""
        q, dq, R_flatten, Omega = x
        f_ext = ensemble_disturbance(q, dq, R_flatten, Omega, params['ensemble']['W'], params['ensemble']['b'], params['ensemble']['A'])
        ddq = plant(q, dq, u, f_ext)
        dR_flatten, dOmega, = plant_attitude(R_flatten, Omega, u_d, params)
        dx = (dq, ddq, dR_flatten, dOmega)
        return dx

    def loop(carry, input_slice, params=params):
        """One `lax.scan` step: integrate to `t`, adapt `A`, and recompute the control.

        `u_prev`/`u_d_prev` are held constant over `[t_prev, t]` (zero-order hold).
        """
        t_prev, q_prev, dq_prev, R_flatten_prev, Omega_prev, u_prev, u_d_prev, A_prev, dA_prev, pA_prev = carry
        t = input_slice
        zs, _ = odeint_fixed_step(ode, (q_prev, dq_prev, R_flatten_prev, Omega_prev), t_prev, t, 2e-3,
                                  u_prev, u_d_prev)
        qs, dqs, R_flattens, Omegas = zs
        q = qs[-1]
        dq = dqs[-1]
        R_flatten = R_flattens[-1]
        Omega = Omegas[-1]

        r, dr, ddr = ref_derivatives(t)

        # Integrate the adaptation law with the trapezoidal rule.
        dA, y = adaptation_law(q, dq, R_flatten, Omega, r, dr)
        if NGD_flag:
            qn = 1.1 + params['pnorm']**2
            pA = pA_prev + (t - t_prev)*(dA_prev + dA)/2
            P = params['P']
            # Elementwise mirror map |pA|^(qn-1) * sign(pA). Magnitudes are clamped
            # at 1e-6, and entries within 1e-6 of zero are zeroed out.
            mag = jnp.maximum(jnp.abs(pA), 1e-6 * jnp.ones_like(pA))
            keep = jnp.ones_like(pA) - jnp.isclose(pA, 0, atol=1e-6)
            A = (mag**(qn-1) * jnp.sign(pA) * keep) @ P
        else:
            A = A_prev + (t - t_prev)*(dA_prev + dA)/2
            pA = pA_prev  # dual variable unused in this mode; carried unchanged

        # Force estimate and control input.
        f_hat = A @ y
        u_d, τ = controller(q, dq, r, dr, ddr, f_hat, ctrl_type=controller_type)

        # Desired attitude from u_d (same construction as in `plant_attitude`), logged.
        b_3d = u_d / jnp.linalg.norm(u_d)
        b_1d = jnp.array([1, 0, 0])
        cross = jnp.cross(b_3d, b_1d)
        b_2d = cross / jnp.linalg.norm(cross)
        R_d = jnp.column_stack((jnp.cross(b_2d, b_3d), b_2d, b_3d))
        R_d_flatten = R_d.flatten()

        # Applied force: magnitude |u_d| along the current body z-axis R e_3.
        f_d = jnp.linalg.norm(u_d)
        R = R_flatten.reshape((3,3))
        e_3 = jnp.array([0, 0, 1])
        u = f_d*R@e_3

        f_ext = ensemble_disturbance(q, dq, R_flatten, Omega, params['ensemble']['W'], params['ensemble']['b'], params['ensemble']['A'])

        carry = (t, q, dq, R_flatten, Omega, u, u_d, A, dA, pA)
        output_slice = (q, dq, R_flatten, Omega, u, u_d, τ, r, dr, R_d_flatten, f_hat, f_ext, y, A.flatten())
        return carry, output_slice

    # Initial conditions: start on the reference position/velocity, with identity
    # attitude and zero body rate.
    t0 = ts[0]
    r0, dr0, ddr0 = ref_derivatives(t0)
    q0, dq0 = r0, dr0
    R0 = jnp.array(
        [[ 1.,  0.,  0.],
        [ 0., 1.,  0.],
        [ 0.,  0., 1.]]
    )
    R_flatten0 = R0.flatten()
    # Identity placeholder for the first logged desired attitude (not derived from u_d0).
    R_d0 = jnp.array(
        [[ 1.,  0.,  0.],
        [ 0., 1.,  0.],
        [ 0.,  0., 1.]]
    )
    R_d_flatten0 = R_d0.flatten()
    Omega0 = jnp.zeros(3)
    dA0, y0 = adaptation_law(q0, dq0, R_flatten0, Omega0, r0, dr0)
    A0 = jnp.zeros((q0.size, y0.size))
    pA0 = jnp.zeros((q0.size, y0.size))
    f0 = A0 @ y0
    u_d0, τ0 = controller(q0, dq0, r0, dr0, ddr0, f0, ctrl_type=controller_type)
    f_d0 = jnp.linalg.norm(u_d0)
    e_3 = jnp.array([0, 0, 1])
    u0 = f_d0*R0@e_3
    f_ext0 = ensemble_disturbance(q0, dq0, R_flatten0, Omega0, params['ensemble']['W'], params['ensemble']['b'], params['ensemble']['A'])
    flat_A0 = A0.flatten()

    # Run the simulation loop over ts[1:].
    carry = (t0, q0, dq0, R_flatten0, Omega0, u0, u_d0, A0, dA0, pA0)
    carry, output = jax.lax.scan(loop, carry, ts[1:])
    q, dq, R_flatten, Omega, u, u_d, τ, r, dr, R_d_flatten, f_hat, f_ext, y, flat_A = output

    # Prepend initial conditions so every history has len(ts) rows.
    q = jnp.vstack((q0, q))
    dq = jnp.vstack((dq0, dq))
    R_flatten = jnp.vstack((R_flatten0, R_flatten))
    R_d_flatten = jnp.vstack((R_d_flatten0, R_d_flatten))
    Omega = jnp.vstack((Omega0, Omega))
    u = jnp.vstack((u0, u))
    u_d = jnp.vstack((u_d0, u_d))
    τ = jnp.vstack((τ0, τ))
    r = jnp.vstack((r0, r))
    dr = jnp.vstack((dr0, dr))
    f_hat = jnp.vstack((f0, f_hat))
    f_ext = jnp.vstack((f_ext0, f_ext))
    flat_A = jnp.vstack((flat_A0, flat_A))
    y = jnp.vstack((y0, y))

    sim = {"q": q, "dq": dq, "R_flatten": R_flatten, "Omega": Omega, "u": u, "u_d": u_d, "τ": τ, "r": r, "dr": dr, "R_d_flatten": R_d_flatten,"f_hat": f_hat, "f_ext": f_ext, "y": y, "A": flat_A}

    return sim

In [357]:
def eval_single_model(model_dir, filename, T, dt, num_traj, pnorm_flag=True, visual_verbose=0, save_dir=None, fixed_P=None, controller_type='adaptive'):
    """Load one trained model, simulate it on random spline references, and score it.

    Args:
        model_dir: directory containing the training pickle.
        filename: pickle file name. Its stem is split on '_' into 'key=value'
            tokens, which are parsed into `test_results['train_params']`
            (values converted to float when possible). Tokens without '=' are ignored.
        T: duration (s) of each random reference trajectory and of the simulation.
        dt: sample period of the simulation grid `jnp.arange(0, T, dt)`.
        num_traj: number of random reference trajectories (fixed seed 0).
        pnorm_flag: whether the pickle holds p-norm training info ('pnorm',
            'best_step_meta', 'valid_loss_history', ...). When false, p = 2 is assumed.
        visual_verbose: currently unused.
        save_dir: if given, `{'sim', 'params', 'results'}` is pickled to
            `save_dir/test_<filename>`.
        fixed_P: scalar multiplier for the identity adaptation gain used when
            the pickle stores no learned 'P' (default: plain identity).
        controller_type: currently unused. `test_simulate` is vmapped, so its
            keyword arguments cannot be forwarded; the adaptive controller is used.

    Returns:
        (sim, test_params, test_results): vmapped simulation histories, the
        parameter dict passed to `test_simulate`, and a dict with parsed
        training metadata, 'final_p', 'tracking_err' and 'estimation_err'.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted training outputs.
    model_pkl_loc = os.path.join(model_dir, filename)
    with open(model_pkl_loc, 'rb') as f:
        train_results = pickle.load(f)

    test_results = {}
    test_params = {}

    # Training hyper-parameters encoded in the filename, e.g. 'seed=0_M=50_...'.
    stem = filename[:-len('.pkl')] if filename.endswith('.pkl') else filename
    train_params = {}
    for part in stem.split('_'):
        name, sep, value = part.partition('=')
        if not sep:
            continue  # not a key=value token
        try:
            train_params[name] = float(value)
        except ValueError:
            train_params[name] = value
    test_results['train_params'] = train_params

    if not pnorm_flag:
        test_results['train_info'] = {
            'ensemble': train_results['ensemble'],
        }
    else:
        test_results['train_info'] = {
            'best_step_meta': train_results['best_step_meta'],
            'ensemble': train_results['ensemble'],
            'valid_loss_history': train_results['valid_loss_history'],
            'pnorm_history': train_results['pnorm_history'],
            'ensemble_loss': train_results['ensemble_loss']
        }

    # The pickle stores the actual p, while test_simulate expects params['pnorm']
    # as qbar. test_simulate runs its NGD branch by default and always reads
    # params['pnorm'], so it must be set even when p = 2 is assumed.
    if pnorm_flag:
        test_results['final_p'] = train_results['pnorm']
    else:
        test_results['final_p'] = 2.0
    test_params['pnorm'] = convert_p_qbar(test_results['final_p'])

    # Model and controller parameters.
    controller_params = train_results['controller']
    test_params['W'] = train_results['model']['W']
    test_params['b'] = train_results['model']['b']
    test_params['Λ'] = params_to_posdef(controller_params['Λ'])
    test_params['K'] = params_to_posdef(controller_params['K'])
    if 'k_R' in controller_params:
        test_params['k_R'] = controller_params['k_R']
        test_params['k_Omega'] = controller_params['k_Omega']
    else:
        # Default attitude gains when the pickle does not store them.
        test_params['k_R'] = jnp.array([1.4, 1.4, 1.26])
        test_params['k_Omega'] = jnp.array([0.330, 0.330, 0.300])

    if 'P' in controller_params:
        test_params['P'] = params_to_posdef(controller_params['P'])
    else:
        # No learned gain: use an identity sized to the regressor output, which
        # is the output dimension of the last layer of the feature network.
        num_features = jnp.shape(test_params['W'][-1])[0]
        P_eye = jnp.eye(num_features)
        test_params['P'] = P_eye if fixed_P is None else fixed_P * P_eye

    test_params['ensemble'] = {
        'W': train_results['ensemble']['W'],
        'b': train_results['ensemble']['b'],
        'A': train_results['ensemble']['A']
    }

    # Random reference trajectories (fixed seed so every model sees the same set).
    seed = 0
    key = jax.random.PRNGKey(seed)

    num_knots = 6
    poly_orders = (9, 9, 9)
    deriv_orders = (4, 4, 4)
    min_step = jnp.array([-2, -2, -0.25])
    max_step = jnp.array([2, 2, 0.25])
    # z knots are shifted by -1 because `reference` in test_simulate adds +1 to z.
    min_knot = jnp.array([xmin_, ymin_, zmin_-1])
    max_knot = jnp.array([xmax_, ymax_, zmax_-1])

    key, *subkeys = jax.random.split(key, 1 + num_traj)
    subkeys = jnp.vstack(subkeys)
    in_axes = (0, None, None, None, None, None, None, None, None)
    t_knots, knots, coefs = jax.vmap(random_ragged_spline, in_axes)(
        subkeys, T, num_knots, poly_orders, deriv_orders,
        min_step, max_step, min_knot, max_knot
    )

    # Test on new trajectories.
    ts = jnp.arange(0, T, dt)
    sim = test_simulate(ts, test_params, t_knots, coefs)

    # Mean (over trajectories and time) of the Euclidean error norms.
    sim_e = sim['q'] - sim['r']
    tracking_error = jnp.mean(jnp.linalg.norm(sim_e, axis=2))
    sim_ftilde = sim['f_hat'] - sim['f_ext']
    estimation_error = jnp.mean(jnp.linalg.norm(sim_ftilde, axis=2))

    test_results['tracking_err'] = tracking_error
    test_results['estimation_err'] = estimation_error

    sim['ts'] = ts
    test = {
            'sim': sim,
            'params': test_params,
            'results': test_results,
            }
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, 'test_' + filename), 'wb') as f:
            pickle.dump(test, f)

    return sim, test_params, test_results

In [358]:
# Evaluation settings.
T_ = 30
dt_ = 0.02
num_traj = 10
pnorm_flag = True
visual_verbose = 6
controller_type = 'adaptive'
train_results_dir = 'train_results'

# trial_name -> (tracking_error, best_epoch, k_R, k_Omega)
model_results = {}

# Evaluate every trained-model pickle under `train_results_dir`. Trials are keyed
# by their directory path relative to `train_results_dir`; relpath works for any
# nesting depth and on any OS path separator.
for subdir, _dirs, files in os.walk(train_results_dir):
    trial_name = os.path.relpath(subdir, train_results_dir)
    model_dir = subdir
    for filename in files:
        if not filename.endswith('.pkl'):
            continue

        sim, test_params, test_results = eval_single_model(model_dir, filename, T_, dt_, num_traj, pnorm_flag=pnorm_flag, visual_verbose=visual_verbose, controller_type=controller_type)

        Λ = test_params['Λ']
        K = test_params['K']
        P = test_params['P']
        k_R = test_params['k_R']
        k_Omega = test_params['k_Omega']
        eig_Ps = jnp.linalg.eigh(test_params['P'])[0]
        tracking_error = test_results['tracking_err']
        best_epoch = test_results['train_info']['best_step_meta']

        print(trial_name)
        model_results[trial_name] = (tracking_error, best_epoch, k_R, k_Omega)

reg_P_2e-3_constant_Kr
reg_P_2e-3_reg_k_R_2e-3_k_R_z_1
reg_P_1_kROmega
z_up_reg_P_1
reg_P_1e-1_reg_k_R_0_k_R_scale_1.1_k_R_z_1.4
reg_P_5e-1_reg_k_R_0_k_R_scale_1.5
reg_P_2e-3_reg_k_R_0_k_R_scale_1.35
reg_P_1_constant_Kr_vector
reg_P_2e-3_reg_k_R_2e-3_k_R_z_1.26_depth_3
reg_P_2e-3_reg_k_R_2e-3_k_R_z_0.3
reg_P_10_constant_Kr
reg_P_2e-3_reg_k_R_1e-6
reg_P_1_reg_Kr_1e-3
reg_P_2e-3_reg_k_R_0_k_R_scale_2
reg_P_5000_constant_Kr_eigs
reg_P_5e-1_reg_k_R_0_k_R_scale_2.5
reg_P_1e-3_constant_Kr
reg_P_1_constant_Kr
reg_P_10000_constant_Kr
reg_P_2e-3_reg_k_R_2e-3_k_R_z_1.0
z_up_reg_P_10
reg_P_1e-2_reg_k_R_2e-3_k_R_scale_3
reg_P_5e-5_constant_Kr
reg_P_1e-1_reg_k_R_1e-6
reg_P_1_reg_k_R_2e-3_k_R_scale_3
reg_P_1_reg_k_R_1e-1
reg_P_1e-2_reg_k_R_0_k_R_scale_1.1_k_R_z_1.4
reg_P_2e-3_reg_k_R_0_k_R_scale_1.1
hardware_16
reg_P_1e-2_reg_k_R_0
reg_P_100_constant_Kr
reg_P_2e-3_Kr_reg_k_R_2e-3
reg_P_2e-3_reg_k_R_2e-3_k_R_z_0.1
z_up_reg_P_5e-1_kRz
reg_P_1e-1_constant_Kr
reg_P_1e-2_reg_k_R_0_k_R_scale_1_k_R_z_1.4
r

In [359]:
# update model_results
for subdir, dirs, files in os.walk(train_results_dir):
    for filename in files:
        if filename.endswith('.pkl'):
            _, trial_name = subdir.split('/')
            model_dir = os.path.join(train_results_dir, trial_name)

            if trial_name not in model_results.keys():
                sim, test_params, test_results = eval_single_model(model_dir, filename, T_, dt_, num_traj, pnorm_flag=pnorm_flag, visual_verbose=visual_verbose, controller_type=controller_type)

                Λ = test_params['Λ']
                K = test_params['K']
                P = test_params['P']
                k_R = test_params['k_R']
                k_Omega = test_params['k_Omega']
                eig_Ps = jnp.linalg.eigh(test_params['P'])[0]
                tracking_error = test_results['tracking_err']
                best_epoch = test_results['train_info']['best_step_meta']

                print(trial_name)
                model_results[trial_name] = (tracking_error, best_epoch, k_R, k_Omega)

In [365]:
def sort_key(item):
    """Sort key for `(trial_name, (tracking_error, ...))` items.

    Orders by ascending tracking error. NaN errors (diverged runs) map to +inf,
    so they sort LAST rather than comparing unpredictably.
    """
    value = float(item[1][0])
    return float('inf') if math.isnan(value) else value

# Rank trials by tracking error (NaN last). Report up to `top_n` of them whose
# best epoch is after 200 and whose k_R gains are all at most 2.
sorted_model_results = dict(sorted(model_results.items(), key=sort_key))

top_n = 20
counter = 0
for trial_name, (tracking_error, best_epoch, k_R, k_Omega) in sorted_model_results.items():
    if best_epoch > 200 and not np.any(k_R > 2):
        model_dir = os.path.join(train_results_dir, trial_name)
        # Sorted so the reported pickle is deterministic if a trial has several.
        filename = sorted(file for file in os.listdir(model_dir) if file.endswith('.pkl'))[0]
        print(f'trial_name = \'{trial_name}\'')
        print(f'filename = \'{filename}\'')
        print(f'\taverage tracking error: {tracking_error}')
        print(f'\tbest epoch: {best_epoch}')
        print(f'\tKr: [{k_R[0]}, {k_R[1]}, {k_R[2]}]')
        print(f'\tKomega: [{k_Omega[0]}, {k_Omega[1]}, {k_Omega[2]}]')
        counter += 1

    if counter >= top_n:
        break

trial_name = 'reg_P_2e-3_reg_k_R_2e-3_k_R_z_1.4'
filename = 'seed=0_M=50_E=1000_pinit=2.00_pfreq=2000_regP=0.0020.pkl'
	average tracking error: 0.07914063334465027
	best epoch: 999
	Kr: [1.4627842903137207, 1.467042326927185, 1.280902624130249]
	Komega: [0.2666648030281067, 0.2610495686531067, 0.39015090465545654]
trial_name = 'reg_P_1e-1_reg_k_R_0_k_R_scale_1_k_R_z_1.4'
filename = 'seed=0_M=50_E=1000_pinit=2.00_pfreq=2000_regP=0.1000.pkl'
	average tracking error: 0.07956773042678833
	best epoch: 999
	Kr: [1.462704062461853, 1.4668426513671875, 1.2790851593017578]
	Komega: [0.26672986149787903, 0.26166218519210815, 0.39088571071624756]
trial_name = 'reg_P_1_reg_Kr_1e-3'
filename = 'seed=0_M=50_E=1000_pinit=2.00_pfreq=2000_regP=1.0000.pkl'
	average tracking error: 0.09446663409471512
	best epoch: 999
	Kr: [1.4619311094284058, 1.4632068872451782, 0.8545881509780884]
	Komega: [0.2674632668495178, 0.27009400725364685, 0.3497861325740814]
trial_name = 'reg_P_1e-2_reg_k_R_0_k_R_scale_1.1_k_R

In [361]:
# pkl = f'{trial_name}/{filename}'
# with open(f'train_results/{pkl}', 'rb') as file:
#     raw = pickle.load(file)

# print('Best epoch:', raw['best_step_meta'])

# validate_model(raw, 10, mystery_i=0, print_=True, plot=True)
    
# E = int(raw['hparams']['ensemble']['num_epochs'])
# plot_losses(raw)
# plot_losses(raw, xlim=[0,5])
# plot_losses(raw, ylim=[0,100])