# Inference Test 😎

In [None]:
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp

@dataclass
class Problem:
    """One rigid-phase scenario, i.e. the arguments passed to ``motion_amplifier``.

    ``n`` is forwarded as ``n_rigid_phases``, ``rigid_cov`` as
    ``rigid_duration_cov`` and ``transition_cov`` as ``transition_cov``.

    The scenarios in ``problems`` were previously written as calls to an older
    helper (presumably the predecessor of ``motion_amplifier``):

        amplify_fun(1, 0.02, 0.1),         # best_run
        amplify_fun(1, 0.25, 0.1),         # long_rigid_phase
        amplify_fun(30, 0.001, 0.0001),    # many_tiny_stops
        amplify_fun(5, 0.02, 0.005),       # some_short_stops
        amplify_fun(3, 0.05, 0.01)         # some_rigid_phases
    """

    # Number of rigid phases.
    n: int
    # Per-phase rigid-duration scale (fraction of the frame count), shape (n,).
    rigid_cov: jnp.ndarray
    # Per-phase transition-length scale (fraction of the frame count), shape (n,).
    transition_cov: jnp.ndarray

    def __post_init__(self):
        # Same contract motion_amplifier asserts: one value per rigid phase.
        assert jnp.shape(self.rigid_cov) == (self.n,) == jnp.shape(self.transition_cov), (
            "Problem: rigid_cov and transition_cov must have shape (n,)")

# A scenario without rigid phases ("no_rigid_phases") is simply the normal,
# unmasked run, so it has no entry here.


def _uniform_problem(n, rigid_cov, transition_cov):
    # Use the same rigid/transition scale for each of the n rigid phases.
    return Problem(n, jnp.array([rigid_cov] * n), jnp.array([transition_cov] * n))


problems = {
    "best_run": _uniform_problem(1, 0.02, 0.1),
    "long_rigid_phase": _uniform_problem(1, 0.25, 0.1),
    "many_tiny_stops": _uniform_problem(30, 0.001, 0.0001),
    "some_short_stops": _uniform_problem(5, 0.02, 0.005),
    "some_rigid_phases": _uniform_problem(3, 0.05, 0.01),
}

### Generate Data

Define System

In [None]:
import x_xy

# x_xy XML description of a three-segment system: a free-jointed root segment
# "seg2" with two child segments. "seg1" hangs off joint type "rsry" and "seg3"
# off joint type "rsrz"; each child carries an IMU body ("imu1", "imu2")
# attached with a "frozen" joint.
# NOTE(review): "rsry"/"rsrz" are the custom rigid-phase joints registered in
# define_joints(); they must be registered before x_xy.io.load_sys_from_str
# parses this string.
# <options> sets gravity "0 0 9.81" and dt="0.01".
three_seg_rigid = r"""
<x_xy model="three_seg_rigid">
    <options gravity="0 0 9.81" dt="0.01"/>
    <worldbody>
        <body name="seg2" joint="free">
            <body name="seg1" joint="rsry">
                <body name="imu1" joint="frozen"/>
            </body>
            <body name="seg3" joint="rsrz">
                <body name="imu2" joint="frozen"/>
            </body>
        </body>
    </worldbody>
</x_xy>
"""

Define Logic

In [None]:
import jax
import jax.numpy as jnp
from jax import random
import x_xy

def motion_amplifier(
        time,
        sampling_rate,
        key_rigid_phases,
        n_rigid_phases=3,
        rigid_duration_cov=jnp.array([0.02] * 3),
        transition_cov=jnp.array([0.1] * 3)
) -> jnp.ndarray:
    """Build a per-frame amplification mask containing randomly placed rigid phases.

    The mask has ``int(time / sampling_rate)`` entries with values in [0, 1]:
    1 during free motion, 0 inside each rigid phase, and linear ramps in between.
    The ramp falls from 1 towards 0 before a rigid phase (slow-down) and rises
    from 0 towards 1 after it (speed-up). Multiplying a per-frame quantity by this
    mask therefore damps it to zero during the rigid phases.

    Args:
        time: Total duration of the sequence.
        sampling_rate: Duration of one frame (the sampling *interval*, e.g. ``Ts``).
            The number of frames is ``int(time / sampling_rate)``.
        key_rigid_phases: PRNG key from which all random draws are derived.
        n_rigid_phases: Number of rigid phases to place.
        rigid_duration_cov: Shape ``(n_rigid_phases,)``. Per-phase scale of the
            rigid half-duration, as a fraction of the number of frames. Despite
            the name it acts as a standard deviation: the covariance diagonal is
            ``(rigid_duration_cov * n_frames) ** 2``.
        transition_cov: Shape ``(n_rigid_phases,)``. Same as
            ``rigid_duration_cov``, but for the slow-down and speed-up
            transition lengths.

    Returns:
        Float array of shape ``(int(time / sampling_rate),)``.

    Raises:
        AssertionError: If the two covariance arrays do not have shape
            ``(n_rigid_phases,)``.
    """
    assert rigid_duration_cov.shape == (n_rigid_phases,) == transition_cov.shape, "motion_amplifier: There must be a variance for each rigid phase!"
    n_frames = int(time / sampling_rate)
    frame_idx = jnp.arange(n_frames)
    key_rigid_means, key_rigid_variances, key_slope_down_variances, key_slope_up_variances = random.split(
        key_rigid_phases, 4)

    # Centre points of the rigid phases in frame units, sorted ascending.
    # The transpose makes this shape (1, n_rigid_phases).
    means = jnp.sort(random.uniform(key_rigid_means, shape=(
        n_rigid_phases, 1), minval=0, maxval=n_frames).T)

    # Half-durations of the rigid phases: |N(0, (rigid_duration_cov * n_frames)^2)|.
    rigid_distances = jnp.abs(random.multivariate_normal(key_rigid_variances, mean=jnp.zeros_like(
        means), cov=jnp.diag((rigid_duration_cov * n_frames)**2)))

    # Lengths of the slow-down (before) and speed-up (after) transitions.
    transition_slowdown_durations = jnp.abs(random.multivariate_normal(
        key_slope_down_variances, mean=jnp.zeros_like(means), cov=jnp.diag((transition_cov * n_frames)**2)))
    transition_speedup_durations = jnp.abs(random.multivariate_normal(
        key_slope_up_variances, mean=jnp.zeros_like(means), cov=jnp.diag((transition_cov * n_frames)**2)))

    # Integer frame boundaries of every phase, each of shape (n_rigid_phases,).
    # They may lie outside [0, n_frames); the window masks below clip naturally.
    rigid_starts = (means - rigid_distances).astype(int).flatten()
    rigid_ends = (means + rigid_distances).astype(int).flatten()
    starts_slowing = (means - rigid_distances -
                      transition_slowdown_durations).astype(int).flatten()
    ends_moving = (means + rigid_distances +
                   transition_speedup_durations).astype(int).flatten()

    def window_complement(start, end):
        """Return an int mask that is 0 on frames in [start, end) and 1 elsewhere."""
        return jnp.where(frame_idx < start, 1, 0) + jnp.where(frame_idx >= end, 1, 0)

    window_masks = jax.vmap(window_complement)
    # 0 inside any rigid phase, 1 elsewhere.
    rigid_mask = jnp.prod(window_masks(rigid_starts, rigid_ends), axis=0)
    # One row per phase, 0 inside that phase's transition window.
    slowdown_masks = window_masks(starts_slowing, rigid_starts).astype(float)
    speedup_masks = window_masks(rigid_ends, ends_moving).astype(float)

    def linsp(window_mask, start, end, begin_val, carry_fun):
        """Replace the zero entries of ``window_mask`` by a linear ramp.

        Entries equal to 0 are emitted as ``1 - carry``, after which the carry is
        advanced with ``carry_fun(carry, end - start)``. Non-zero entries pass
        through unchanged.
        """
        span = end - start

        def in_window(carry, x):
            return carry_fun(carry, span), 1 - carry

        def outside_window(carry, x):
            return carry, x

        def step(carry, x):
            return jax.lax.cond(x == 0, in_window, outside_window, carry, x)

        return jax.lax.scan(step, begin_val, window_mask)[1]

    # NOTE: the ramp always starts from begin_val at the first visible window
    # frame. A transition that begins before frame 0 is therefore truncated
    # rather than offset.

    # Slow-down ramp: carry starts at 0, so the output is 1, 1 - 1/span, ...
    linsp_desc = jax.vmap(lambda m, s1, s2: linsp(
        m, s1, s2, 0.0, lambda carry, span: carry + 1 / span))
    slowdown_mask = jnp.prod(linsp_desc(
        slowdown_masks, starts_slowing, rigid_starts), axis=0)

    # Speed-up ramp: carry starts at 1, so the output is 0, 1/span, ...
    linsp_asc = jax.vmap(lambda m, s1, s2: linsp(
        m, s1, s2, 1.0, lambda carry, span: carry - 1 / span))
    speedup_mask = jnp.prod(
        linsp_asc(speedup_masks, rigid_ends, ends_moving), axis=0)

    # Where phases and transitions overlap, the smallest factor wins.
    return jnp.min(jnp.stack([rigid_mask, slowdown_mask, speedup_mask]), axis=0)


def random_angles_with_rigid_phases_over_time(
    key_t,
    key_ang,
    T,
    Ts,
    key_rigid_phases,
    n_rigid_phases=3,
    rigid_duration_cov=jnp.array([0.02] * 3),
    transition_cov=jnp.array([0.1] * 3),
    ANG_0=0.0,
    dang_min=0.01,
    dang_max=0.05,
    t_min=0.1,
    t_max=0.5,
    randomized_interpolation=False,
    range_of_motion=False,
    range_of_motion_method="uniform"
):
    mask = motion_amplifier(T, Ts, key_rigid_phases,
                            n_rigid_phases, rigid_duration_cov, transition_cov)

    qs = x_xy.algorithms.random_angle_over_time(
        key_t,
        key_ang,
        ANG_0,
        dang_min,
        dang_max,
        t_min,
        t_max,
        T,
        Ts,
        randomized_interpolation,
        range_of_motion,
        range_of_motion_method
    )

    # derivate qs
    qs_diff = jnp.diff(qs, axis=0)

    # mulitply with motion amplifier
    qs_diff = qs_diff * mask[:-1]

    # integrate qs_diff
    qs_rigid_phases = jnp.concatenate((qs[0:1], jnp.cumsum(qs_diff, axis=0)))
    return qs_rigid_phases


In [None]:
import jax
from x_xy import maths
from x_xy import base

@dataclass
class ExtendedConfig(x_xy.algorithms.RCMG_Config):
    """``RCMG_Config`` extended with the rigid-phase parameters of ``motion_amplifier``.

    ``cov_rigid_durations`` and ``cov_transitions`` default to ``None``. They are
    then filled in ``__post_init__`` with 0.02 and 0.1 per phase, sized by the
    instance's own ``n_rigid_phases``.
    """

    # Number of rigid phases drawn per joint trajectory.
    n_rigid_phases: int = 3
    # Per-phase rigid-duration scale, shape (n_rigid_phases,); None -> 0.02 each.
    cov_rigid_durations: Optional[jax.Array] = None
    # Per-phase transition-length scale, shape (n_rigid_phases,); None -> 0.1 each.
    cov_transitions: Optional[jax.Array] = None

    def __post_init__(self):
        # Defaults are built here rather than in the class body. A class-body
        # default is sized by the *class* default of n_rigid_phases (always 3),
        # and array defaults may be rejected as mutable by newer dataclasses.
        if self.cov_rigid_durations is None:
            self.cov_rigid_durations = jnp.array([0.02] * self.n_rigid_phases)
        if self.cov_transitions is None:
            self.cov_transitions = jnp.array([0.1] * self.n_rigid_phases)
        # Fail early with the same contract motion_amplifier asserts later.
        assert self.cov_rigid_durations.shape == (self.n_rigid_phases,) == self.cov_transitions.shape, (
            "ExtendedConfig: cov_rigid_durations and cov_transitions must have shape (n_rigid_phases,)")
        
def define_joints():
    """Register the rigid-phase rotation joints ``rsrx``, ``rsry`` and ``rsrz``.

    Each joint rotates about a fixed unit axis (x, y or z). Its random motion is
    drawn by ``random_angles_with_rigid_phases_over_time``, so the generated
    trajectories contain rigid phases.
    """

    def _draw_sometimes_rigid(
            config: ExtendedConfig, key_t: jax.random.PRNGKey, key_value: jax.random.PRNGKey
    ) -> jax.Array:
        # Derive a dedicated key for the rigid-phase mask from the time key.
        time_key, rigid_key = jax.random.split(key_t)
        return random_angles_with_rigid_phases_over_time(
            key_t=time_key,
            key_ang=key_value,
            T=config.T,
            Ts=config.Ts,
            key_rigid_phases=rigid_key,
            n_rigid_phases=config.n_rigid_phases,
            rigid_duration_cov=config.cov_rigid_durations,
            transition_cov=config.cov_transitions,
            ANG_0=0,
            dang_min=config.dang_min,
            dang_max=config.dang_max,
            t_min=config.t_min,
            t_max=config.t_max,
            randomized_interpolation=config.randomized_interpolation,
            range_of_motion=config.range_of_motion_hinge,
            range_of_motion_method=config.range_of_motion_hinge_method
        )

    def _rxyz_transform(q, _, axis):
        angle = jnp.squeeze(q)
        return base.Transform.create(rot=maths.quat_rot_axis(axis, angle))

    def _fixed_axis_transform(axis_values):
        # Factory instead of a loop lambda, so each joint keeps its own axis.
        def transform(q, _):
            return _rxyz_transform(q, _, jnp.array(axis_values))
        return transform

    axis_table = {
        "rsrx": [1.0, 0, 0],
        "rsry": [0, 1.0, 0],
        "rsrz": [0, 0, 1.0],
    }
    joint_models = {
        joint_name: x_xy.algorithms.JointModel(
            _fixed_axis_transform(axis_values), [None], rcmg_draw_fn=_draw_sometimes_rigid
        )
        for joint_name, axis_values in axis_table.items()
    }
    # The trailing 1 is presumably the joint's number of coordinates (1-DoF).
    for joint_name, joint_model in joint_models.items():
        x_xy.algorithms.register_new_joint_type(joint_name, joint_model, 1)
    
import warnings

# Register the custom joints before parsing the model that uses them.
try:
    define_joints()
except AssertionError as err:
    # NOTE(review): presumably raised by register_new_joint_type when the joint
    # types already exist (e.g. this cell is re-run) — confirm. Warn instead of
    # silently passing so that genuine failures stay visible.
    warnings.warn(
        f"define_joints() raised AssertionError ({err!r}); "
        "assuming the rsrx/rsry/rsrz joint types are already registered."
    )

# NOTE: `sys` shadows the stdlib module name; later cells use this name.
sys = x_xy.io.load_sys_from_str(three_seg_rigid)

Select Problem

Generate random data

In [None]:
from jax import random
from x_xy.algorithms import random_angle_over_time

KEYNUM = 6

problem = problems["some_rigid_phases"]

# Split one seed into independent keys. Seeding every stream with the same
# PRNGKey(KEYNUM) made the time, angle and rigid-phase draws share randomness.
key_t, key_ang, key_rigid_phases = random.split(random.PRNGKey(KEYNUM), 3)
T = 60  # total duration (s)
Ts = 0.1  # sampling interval: duration of one frame (s)

std_config = x_xy.algorithms.RCMG_Config()
extended_config = ExtendedConfig(n_rigid_phases=problem.n, cov_rigid_durations=problem.rigid_cov, cov_transitions=problem.transition_cov)

# Standalone copy of the mask, for plotting. It must use key_rigid_phases, the
# key random_angles_with_rigid_phases_over_time passes to motion_amplifier;
# otherwise the plotted mask would not match the one applied to `angles`.
mask = motion_amplifier(T, Ts, key_rigid_phases, problem.n, problem.rigid_cov, problem.transition_cov)
angles = random_angles_with_rigid_phases_over_time(key_t=key_t, key_ang=key_ang, key_rigid_phases=key_rigid_phases, T=T, Ts=Ts, n_rigid_phases=problem.n, rigid_duration_cov=problem.rigid_cov, transition_cov=problem.transition_cov)
# angles_normal = random_angle_over_time(key_t=key_t, key_ang=key_ang, T=T, Ts=Ts, ANG_0=0, dang_min=std_config.dang_min, dang_max=std_config.dang_max, t_min=std_config.t_min, t_max=std_config.t_max)

In [None]:
# plot angle data
import matplotlib.pyplot as plt
import numpy as np

# Build the time axes from the sample counts. np.arange(0, T, Ts) with a float
# step can return one element more or fewer than the data, which makes
# ax.plot fail with a shape mismatch.
angle_time = np.arange(len(angles)) * Ts
mask_time = np.arange(len(mask)) * Ts

fig, ax = plt.subplots()
ax.plot(angle_time, angles)
# The mask lies in [0, 1]; scale it down so it is visible next to the angle curve.
ax.plot(mask_time, mask * 0.05, color="red")
ax.set(xlabel='time (s)', ylabel='angle', title='Angle over time')
ax.legend(["Angle", "Mask"])
ax.grid()
plt.show()


### Generate X and y

In [None]:
from x_xy import algorithms, utils

seed = jax.random.PRNGKey(1)

# Presumably build_generator hands this config to each joint's rcmg_draw_fn,
# so the rsr* joints use the selected problem's rigid-phase parameters.
config = extended_config
generator = algorithms.build_generator(sys, config)
# Several generators can also be combined, e.g.:
# generator = algorithms.batch_generator([generator, generator], [32, 16])

X, y = generator(seed)

In [None]:
# Inspect the generated network input X (fed to network.init/apply below).
print(X)

Load trained network parameters from a pickle file

In [None]:
def load_pickle_params(problem_key, base_dir="/data/uk16ural/prism_params"):
    """Load trained network parameters stored as ``<base_dir>/<problem_key>.pickle``.

    Args:
        problem_key: Name of the stored run, e.g. ``"best_run"``.
        base_dir: Directory containing the pickle files. The default is the
            previously hard-coded location, so existing calls are unchanged.

    Returns:
        Whatever ``neural_networks.io_params.load`` returns for that file; below
        it is passed as ``params`` to ``network.apply``.
    """
    # NOTE(review): io_params.load presumably unpickles the file. Unpickling can
    # execute arbitrary code, so only load files you produced yourself.
    from neural_networks import io_params
    pickle_file = f"{base_dir}/{problem_key}.pickle"
    return io_params.load(pickle_file)

Run inference

In [None]:
from neural_networks.rnno import rnno_v2

params = load_pickle_params("best_run")
network = rnno_v2(sys)

# init() is only used for its second output, the initial network state; the
# freshly initialised parameters it also returns are replaced by `params`.
state = network.init(key_ang, X)[1]

# Forward pass with the loaded parameters; the second output is discarded.
yhat, _ = network.apply(params, state, X)

In [48]:
# Compare the network prediction for "imu1" with the generator's target rotations.
# NOTE(review): in the captured output yhat["imu1"] is (frames, 4), while y.rot
# has an extra per-body axis — confirm which body index corresponds to imu1
# before comparing them directly.
print(yhat["imu1"])
print(y.rot)

[[ 9.9599540e-01 -3.7002610e-04  1.1567527e-03  8.9395858e-02]
 [ 9.9778342e-01 -1.5109566e-04  7.3970878e-04  6.6540517e-02]
 [ 9.9803782e-01 -1.0796409e-04  5.4305728e-04  6.2611155e-02]
 ...
 [ 7.2989804e-01  4.6010475e-05 -3.5368663e-04 -6.8355602e-01]
 [ 7.2990096e-01  4.6007754e-05 -3.5368433e-04 -6.8355298e-01]
 [ 7.2990382e-01  4.6005007e-05 -3.5368092e-04 -6.8354976e-01]]
[[[ 1.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 1.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 1.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 1.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 1.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]]

 [[ 9.9999893e-01 -8.0802810e-04 -1.2521901e-03 -1.6683455e-04]
  [ 9.9999738e-01 -8.0787390e-04 -2.1744401e-03 -1.6757968e-04]
  [ 9.9999738e-01 -8.0787390e-04 -2.1744401e-03 -1.6757968e-04]
  [ 9.9999893e-01 -8.0802810e-04 -1.2521901e-03 -1.6683455e-04]
  [ 9.9999893e-01 -8.0802810e-04 -1.25