In [1]:
# Import general needed libraries
import jax
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_num_cpu_devices', 30)

import jax.numpy as jnp
import jax.random as jr
import time

%matplotlib notebook
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

# Import necessary libraries for optimization procedure
from typing import NamedTuple
from diffilqrax import ilqr
from diffilqrax.utils import initialise_stable_dynamics, keygen
from diffilqrax.ilqr import ilqr_solver
from diffilqrax.typs import (
    iLQRParams,
    Theta,
    System,
    ModelDims,
)
from src.arm_model import *


# Generate keys
# NOTE: `key` and `skeys` are re-created with a different seed in the
# model-setup cell further down, so the sub-keys produced here are only
# relevant to cells that run before that point.
key = jr.PRNGKey(seed=0)
key, skeys = keygen(key, 5)

In [2]:
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7),
 CpuDevice(id=8),
 CpuDevice(id=9),
 CpuDevice(id=10),
 CpuDevice(id=11),
 CpuDevice(id=12),
 CpuDevice(id=13),
 CpuDevice(id=14),
 CpuDevice(id=15),
 CpuDevice(id=16),
 CpuDevice(id=17),
 CpuDevice(id=18),
 CpuDevice(id=19),
 CpuDevice(id=20),
 CpuDevice(id=21),
 CpuDevice(id=22),
 CpuDevice(id=23),
 CpuDevice(id=24),
 CpuDevice(id=25),
 CpuDevice(id=26),
 CpuDevice(id=27),
 CpuDevice(id=28),
 CpuDevice(id=29)]

In [3]:
# Initialize Task
# init_radial_task() supplies the initial joint angles, the candidate target
# angles and the targets used for plotting.
init_thetas, target_angles, targets = init_radial_task()
init_thetas = init_thetas.squeeze()

# Arm state = initial angles followed by two zero entries
# (presumably the joint velocities, i.e. the arm starts at rest).
angle_column = init_thetas[:, None]
zero_column = jnp.zeros((2, 1))
init_state = jnp.vstack([angle_column, zero_column]).squeeze()

In [4]:
# Test update of state when squeezed
update_state(init_state, jnp.zeros((2,)))

Array([0.72972766, 1.68213734, 0.        , 0.        ], dtype=float64)

## iLQR Solver Below (Non-Functioning)

In [5]:
# Define dynamics step
def m1_arm_step(t, x, u, theta):
    """Advance the stacked M1-network / arm state by one time step.

    The state vector ``x`` is laid out as ``[m1_state, arm_state]``: the last
    four entries are the arm state and everything before them is the neural
    (M1) state.

    Args:
        t: Current time index (not used by the dynamics).
        x: Stacked 1-D state of length ``n_neurons + 4``.
        u: Input vector, mapped into the neural state by ``theta.Wh``.
        theta: Parameters exposing ``Uh`` (recurrent weights), ``Wh`` (input
            weights), ``h`` (bias) and ``C`` (neural-to-torque readout).

    Returns:
        The next stacked state, same shape as ``x``.
    """
    m1_state, arm_state = x[:-4], x[-4:]
    rates = jax.nn.relu(m1_state)

    # Forward (Euler) step of the M1 dynamics (Tau is 150 ms).
    # NOTE(review): `dt` is a module-level name that is not defined in the
    # visible cells (presumably star-imported from src.arm_model); it must be
    # expressed in the same time unit as the 150 time constant — confirm.
    m1_state_dot = (-m1_state + theta.Uh @ rates + theta.h + theta.Wh @ u) / 150
    n_m1_state = m1_state + dt * m1_state_dot

    # Drive the arm with the torques read out from the current firing rates.
    # The arm state used to be copied through unchanged, so no input could
    # ever move the arm and the target cost did not depend on `u`.
    # NOTE(review): assumes update_state(arm_state, torques) returns the NEXT
    # arm state as a length-4 vector — confirm against src.arm_model.
    torques = theta.C @ rates
    n_arm_state = update_state(arm_state, torques)

    return jnp.concatenate([n_m1_state, n_arm_state])

In [6]:
def cost(t, x, u, theta):
    """Running cost for one time step of the delayed-reach task.

    Three terms are summed:

    * target cost -- from ``t = 300`` onwards, the squared distance of the
      first two arm-state entries from ``theta.target``, weighted by
      ``((t - 300) / 900) ** 2``;
    * null cost -- while ``t < 300``, the squared deviation of the first two
      arm-state entries from ``theta.init_thetas``, plus the squared last two
      arm-state entries, plus the squared torques
      ``theta.C @ relu(neural state)``;
    * effort cost -- mean squared input scaled by ``5e-7``.

    Args:
        t: Time index of the step (compared against the 300 delay boundary).
        x: Stacked 1-D state; the last four entries are the arm state and
            everything before them is the neural state.
        u: 1-D input vector.
        theta: Parameters exposing ``target``, ``init_thetas`` and ``C``.

    Returns:
        Scalar cost for this time step.
    """
    delay_end = 300.   # delay period (300 ms)
    reach_span = 900.  # T (900 ms)
    alpha_null = 1.
    alpha_effort = 5.e-7

    # Split by position from the end so the cost works for any network size
    # (previously the neural slice was hard-coded to the first 200 entries).
    neur_x, arm_x = x[:-4], x[-4:]
    target = theta.target
    arm_x0 = theta.init_thetas

    # target cost
    # jnp.where promotes dtypes between the branches, unlike lax.select which
    # requires both branches to have exactly the same dtype.
    t_cost = jnp.where(
        t >= delay_end,
        ((t - delay_end)**2 / reach_span**2) * jnp.sum((arm_x[:2] - target)**2),
        0.,
    )

    # Calculate torques
    torques = theta.C @ jax.nn.relu(neur_x)

    # null cost
    n_cost = alpha_null * jnp.where(
        t < delay_end,
        (
            jnp.sum((arm_x[:2] - arm_x0)**2) +
            jnp.sum(arm_x[2:]**2) +
            jnp.sum(torques**2)
        ),
        0.,
    )

    # effort cost
    #NOTE: Want to make this a penalty on the input derivative!
    e_cost = (alpha_effort / u.shape[0]) * jnp.sum(u**2)

    return t_cost + n_cost + e_cost

def costf(x, theta):
    """Terminal cost: squared error between the final arm angles and the target.

    Args:
        x: Final stacked 1-D state; entries ``x[-4:-2]`` (the first two of the
            four arm-state entries) are compared with the target.
        theta: Parameters exposing ``target``.

    Returns:
        Scalar sum of squared differences.
    """
    angle_error = x[-4:-2] - theta.target
    return jnp.sum(angle_error**2)
    
    

In [7]:
class M1ArmParams():
    """Plain container for the parameters handed to the dynamics and cost.

    Attributes (stored exactly as passed in, no validation or copying):
        Uh: Recurrent weight matrix applied to the rectified neural state.
        Wh: Input weight matrix applied to the control input.
        sigma: Stored but not read by the dynamics/cost in this notebook.
        Q: Stored but not read by the dynamics/cost in this notebook.
        C: Readout matrix mapping rectified neural state to torques.
        h: Bias added to the neural dynamics.
        target: Target the arm angles are driven towards.
        init_thetas: Initial arm angles used by the null-period cost.
    """

    def __init__(self, Uh, Wh, sigma, Q, C, h, target, init_thetas):
        """Store every argument as an attribute of the same name."""
        # Add arguments as attributes
        self.Uh = Uh
        self.Wh = Wh
        self.sigma = sigma
        self.Q = Q
        self.C = C
        self.h = h
        self.target = target
        self.init_thetas = init_thetas

    def __repr__(self):
        # Every field is read from `self`; the last field used to be read from
        # a global `init_thetas`, which printed the wrong value or raised
        # NameError when no such global existed.
        return f"M1ArmParams(Uh={self.Uh}, Wh={self.Wh}, sigma={self.sigma}, Q={self.Q}, C={self.C}, h={self.h}, target={self.target}, init_thetas={self.init_thetas})"

# Register the class with JAX's pytree system
def tree_flatten(obj):
    """Flatten a parameter object into pytree leaves plus auxiliary data.

    Every field is treated as a leaf; there is no static metadata, so the
    auxiliary slot is ``None``.

    Returns:
        ``(children, None)`` where ``children`` is the tuple
        ``(Uh, Wh, sigma, Q, C, h, target, init_thetas)`` read from ``obj``.
    """
    field_order = ("Uh", "Wh", "sigma", "Q", "C", "h", "target", "init_thetas")
    leaves = tuple(getattr(obj, field) for field in field_order)
    return leaves, None

def tree_unflatten(aux_data, children):
    """Rebuild an ``M1ArmParams`` from flattened leaves.

    Args:
        aux_data: Auxiliary data from flattening; not used.
        children: Exactly eight leaves in the order
            ``(Uh, Wh, sigma, Q, C, h, target, init_thetas)``.

    Returns:
        A new ``M1ArmParams`` holding the given leaves.
    """
    # Explicit unpacking: a wrong number of leaves fails right here.
    Uh, Wh, sigma, Q, C, h, target, init_thetas = children
    fields = {
        "Uh": Uh,
        "Wh": Wh,
        "sigma": sigma,
        "Q": Q,
        "C": C,
        "h": h,
        "target": target,
        "init_thetas": init_thetas,
    }
    return M1ArmParams(**fields)

# Tell JAX how to take M1ArmParams apart and put it back together, so that
# instances can be passed through transformed functions as pytrees.
jax.tree_util.register_pytree_node(M1ArmParams, tree_flatten, tree_unflatten)

In [8]:
# Problem dimensions: n = 204 state entries (dims.n-4 = 200 neural units plus
# the 4 arm-state entries), m = 200 inputs, 1200 time steps.
dims = ModelDims(horizon=1200, n=204, m=200, dt=0.001)

# Fresh key stream for this cell; this replaces the `key`/`skeys` created in
# the import cell. The order of the next(skeys) calls below determines which
# sub-key each quantity receives.
key = jr.PRNGKey(seed=234)
key, skeys = keygen(key, 10)

# TODO: Design Dynamics matrix using Schur Decomp
# Start with stable dynamics to see if even trainable
# Only element [0] of the returned value is kept.
# NOTE(review): presumably the helper returns one matrix per time step — confirm.
Uh = initialise_stable_dynamics(next(skeys), dims.n-4, dims.horizon, 0.6)[0]
# Identity for Input Matrix, B
Wh = jnp.identity(dims.m)
# Initialize Dynamics Parameters
# NOTE(review): the readout scale uses sqrt(dims.n) (204), not the number of
# neural units (dims.n-4) — confirm this is intended.
C = (.05/jnp.sqrt(dims.n)) * jr.normal(next(skeys), shape=(2, dims.n-4))
h = 5. + 5. * jr.normal(next(skeys), shape=(dims.n-4,))

# Randomly pick one entry of target_angles and flatten it to 1-D.
target = jr.choice(next(skeys), target_angles).reshape(-1,)

# Positional order follows M1ArmParams.__init__:
# (Uh, Wh, sigma, Q, C, h, target, init_thetas); sigma is all zeros and Q is
# the identity.
theta = M1ArmParams(
    Uh, Wh, jnp.zeros(dims.n), jnp.eye(dims.n), C, h, target, init_thetas
)

# Create iLQR Params
# NOTE: this overwrites the 4-entry `init_state` from the task-initialisation
# cell with the full stacked state (random neural state, then arm state).
init_state = jnp.vstack([
    jr.normal(next(skeys), (dims.n-4,1)), # Neural State
    jnp.vstack([theta.init_thetas[:, None], jnp.zeros((2,1))]) # Arm State
]).squeeze()
params = iLQRParams(x0=init_state, theta=theta)

# Initial control sequence: all zeros, one row per time step.
Us = jnp.zeros((dims.horizon, dims.m))   
# define linesearch hyper parameters
ls_kwargs = {
    "beta":0.8,
    "max_iter_linesearch":16,
    "tol":1e0,
    "alpha_min":0.0001,
    }

# Bundle running cost, terminal cost, dynamics and dimensions for the solver.
model = System(cost, costf, m1_arm_step, dims)

In [9]:
jax.config.update("jax_debug_nans", True)

In [10]:
# Evaluate model.quad_cost at t = 0 for the initial state and a zero input
# (presumably the second-derivative blocks of the running cost).
zero_input = jnp.zeros((dims.m,))
(Cxx, Cxu), (_, Cuu) = model.quad_cost(0., init_state, zero_input, theta)

In [11]:
Cxu

Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float64)

In [12]:
# Test dynamics function: one step from an all-zero state with zero input;
# the result should have the same length as the state.
zero_state = jnp.zeros((dims.n,))
zero_control = jnp.zeros((dims.m,))
res = m1_arm_step(t=0., x=zero_state, u=zero_control, theta=theta)
res.shape

(204,)

In [None]:
# Time the full solve with a monotonic clock.
start_time = time.perf_counter()
(opt_xs, opt_us, opt_lambdas), ilqr_fcost, cost_log = ilqr.ilqr_solver(
    model,
    params,
    Us,
    max_iter=70,
    convergence_thresh=1e-10,
    alpha_init=1.,
    verbose=True,
    use_linesearch=True,
    **ls_kwargs
)
# JAX dispatches work asynchronously: wait until the results are actually
# computed before reading the clock, otherwise only the dispatch is timed.
jax.block_until_ready((opt_xs, opt_us, opt_lambdas, ilqr_fcost, cost_log))
opt_time = time.perf_counter() - start_time

In [None]:
cost_log

In [None]:
# Plot the optimised inputs. `opt_us[]` was a syntax error; plotting the whole
# array draws one line per column (presumably one per input channel, with
# rows being time steps like `Us`).
fig, ax = plt.subplots()
ax.plot(opt_us)
ax.set_xlabel("time step")
ax.set_ylabel("input")
ax.set_title("Optimised inputs")
plt.show()

In [None]:
opt_time/60 #min

In [None]:
def calc_pos(thetas):
    """Compute hand and elbow positions of the two-link arm from its angles.

    The first link has length ``L1`` and is rotated by ``theta1``; the second
    has length ``L2`` and is rotated by ``theta1 + theta2``. ``L1`` and ``L2``
    are module-level names not defined in the visible cells (presumably
    provided by src.arm_model).

    Args:
        thetas: The two angles ``(theta1, theta2)``.

    Returns:
        Array with the hand position ``(x, y)`` in row 0 and the elbow
        position ``(x, y)`` in row 1.
    """
    # Extract necessary state vars
    theta1, theta2 = thetas

    # Elbow: end of the first link. (A leftover debug print of its shape was
    # removed; under jax.vmap it only fired at trace time.)
    elbow_pos = jnp.array([L1 * jnp.cos(theta1), L1 * jnp.sin(theta1)])

    # Hand: elbow plus the second link.
    forearm = jnp.array([
        L2 * jnp.cos(theta1 + theta2),
        L2 * jnp.sin(theta1 + theta2),
    ])
    hand_pos = elbow_pos + forearm

    return jnp.vstack([hand_pos, elbow_pos])

In [None]:
# Map the position function over time: each row opt_xs[i, -4:-2] (the first two
# arm-state entries of step i) is converted independently.
# NOTE(review): this calls `calc_arm_pos`, which is not defined in the visible
# cells (presumably it comes from src.arm_model), while the `calc_pos` helper
# defined in the previous cell is never used — confirm which one is intended.
pos_traj = jax.vmap(calc_arm_pos)(opt_xs[:, -4:-2])

In [None]:
def plot_task_state(pos, targets, ax=None):
    """Draw the two-link arm and the reach targets.

    Args:
        pos: Arm positions with the hand ``(x, y)`` in row 0 and the elbow
            ``(x, y)`` in row 1; multiplied by 100 for display in cm
            (presumably given in metres).
        targets: Target positions, one ``(x, y)`` row per target, scaled the
            same way.
        ax: Axes to draw on. An existing axes is cleared and reused; when
            ``None`` a new figure is created and shown.

    Returns:
        The axes that was drawn on.
    """
    # Remember whether the figure is ours: `ax` is reassigned below, so
    # testing it again at the end could never detect the "no axes" case.
    created_here = ax is None
    if created_here:
        fig, ax = plt.subplots()
    else:
        ax.clear()

    # Convert targets and positions to cm
    targets_cm = targets * 100 #cm
    pos_cm = pos*100

    # Target locations
    ax.scatter(targets_cm[:, 0], targets_cm[:, 1], marker='s', color='green', s=200) #TODO, figure out how big these need to be

    # Joint locations
    ax.scatter(0., 0., color='red') # Shoulder
    ax.scatter(pos_cm[0, 0], pos_cm[0, 1], color='blue') # Hand
    ax.scatter(pos_cm[1, 0], pos_cm[1, 1], color='green') # Elbow

    # Make arms by connecting joints
    ax.plot([0., pos_cm[1, 0]], [0., pos_cm[1, 1]], color='blue') # Upper arm: shoulder -> elbow
    ax.plot([pos_cm[0, 0], pos_cm[1, 0]], [pos_cm[0, 1], pos_cm[1, 1]], color='blue') # Lower arm: elbow -> hand

    # Set axis boundaries
    ax.set_xlim([-60, 60])
    ax.set_ylim([-30, 60])

    if created_here:
        plt.show()
    return ax

In [None]:
fig, ax = plt.subplots()

# Sample evenly spaced time steps from the trajectory for the animation.
num_frames = 500
frame_idx = jnp.linspace(0, len(pos_traj)-1, num_frames, dtype=int)
sampled_pos = pos_traj[frame_idx]

# blit=False: plot_task_state clears and redraws the whole axes and returns
# the Axes rather than a sequence of artists, which blitting requires.
# Keep a reference in `ani` so the animation is not garbage-collected.
ani = FuncAnimation(
    fig,
    lambda f: plot_task_state(sampled_pos[f], targets, ax),
    frames=num_frames,
    interval=1,
    repeat=False,
    blit=False,
)
plt.show()

In [None]:
target

In [None]:
pos_traj[-1]