In [1]:
import numpy as np
import flax
from flax import linen as nn
from typing import Any, Callable, Sequence
from jax import random
import jax
from flax import serialization
import jax.numpy as jnp


In [14]:
# Load from file
with open("sim_params.bin", "rb") as f:
    loaded_bytes = f.read()

In [15]:
STATE = np.array([ 2.49944982,  1.60042476, -1.90054772, -0.09470656,  0.61739224,
        0.64762334,  0.17935197,  0.13538294, -0.05906635, 11.50631109,
        0.67154372,  0.71453788,  0.81230723])

INPUT = np.array([ 1.87724153,  2.57925116, -0.31868312,  0.59077944])

In [16]:
class FeedForwardLoad(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
    #   print(f"{i = }, {feat = }")
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x

modelnew = FeedForwardLoad(features=[13, 128, 256, 256, 128, 4])

key1, key2 = random.split(random.key(0), 2)
x = STATE
init_params = modelnew.init(key2, x)


In [17]:
# Restore parameters from bytes
loaded_params = serialization.from_bytes(init_params, loaded_bytes)

In [18]:
sim = True
data_all = np.load('50k_iris_sim.npy' if sim else '50k_holybro.npy' )
print(f"Loaded {'50k_iris_sim.npy' if sim else '50k_holybro.npy'} because sim is {sim}")

Loaded 50k_iris_sim.npy because sim is True


In [19]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
class QuadrotorDataset(Dataset):
  def __init__(self, state_size = 9, ctrlinput_size = 4, training=True, data_all = data_all):
    self.input_size = state_size + ctrlinput_size
    self.output_size = ctrlinput_size
    np.random.shuffle(data_all)
    train_size = int(len(data_all)*0.8)
    if training:
      self.data = data_all[0:train_size]
    else:
      self.data = data_all[train_size:]

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    inputs = self.data[idx, 0: self.input_size]
    outputs = self.data[idx, self.input_size:]
    return torch.FloatTensor(inputs), torch.FloatTensor(outputs).view(4)
  
train_dataset = QuadrotorDataset(training=True)
test_dataset = QuadrotorDataset(training=False)


batch_size = 128
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

print(f"{train_dataloader.dataset.data.shape = }")
print(f"{test_dataloader.dataset.data.shape = }")


train_dataloader.dataset.data.shape = (40000, 17)
test_dataloader.dataset.data.shape = (10000, 17)


In [20]:
test_array = []
@jax.jit
def mse(params, x_batched, y_batched):
    pred = modelnew.apply(params, x_batched)
    SE = (pred - y_batched)**2
    return jnp.mean(SE)

for data in test_dataloader:
    x_samples, y_samples = data
    x_samples = jnp.array(x_samples.numpy())  # Convert to jax.numpy array
    y_samples = jnp.array(y_samples.numpy())  # Convert to jax.numpy array
    loss_val  = mse(loaded_params, x_samples, y_samples)
    print('Test Loss: ', loss_val)
    test_array.append(loss_val)
print('Mean Test Loss: ', np.mean(test_array))

Test Loss:  0.0019400252
Test Loss:  0.0017135364
Test Loss:  0.0014636364
Test Loss:  0.0017905764
Test Loss:  0.0017181349
Test Loss:  0.001804505
Test Loss:  0.0016288267
Test Loss:  0.0018577299
Test Loss:  0.001627598
Test Loss:  0.0017183679
Test Loss:  0.0018232509
Test Loss:  0.0019828733
Test Loss:  0.0019792053
Test Loss:  0.0017084738
Test Loss:  0.0018759171
Test Loss:  0.0018061978
Test Loss:  0.0018558702
Test Loss:  0.0019150736
Test Loss:  0.0018489601
Test Loss:  0.0017837841
Test Loss:  0.0015791996
Test Loss:  0.0015917886
Test Loss:  0.0017853761
Test Loss:  0.00181926
Test Loss:  0.0017519521
Test Loss:  0.0019155019
Test Loss:  0.001638169
Test Loss:  0.002171372
Test Loss:  0.0016295933
Test Loss:  0.0018101349
Test Loss:  0.0019726965
Test Loss:  0.0018853527
Test Loss:  0.001667918
Test Loss:  0.0019204024
Test Loss:  0.0022012452
Test Loss:  0.0017729761
Test Loss:  0.001711145
Test Loss:  0.0015479433
Test Loss:  0.0020177239
Test Loss:  0.0018045647
Test Los

Test Loss:  0.001703923
Test Loss:  0.0019247064
Test Loss:  0.001674507
Test Loss:  0.0019294166
Test Loss:  0.0018371665
Test Loss:  0.0017438885
Test Loss:  0.0017160224
Test Loss:  0.0024745523
Test Loss:  0.0016715796
Test Loss:  0.0019238127
Test Loss:  0.0018227017
Test Loss:  0.0017470929
Test Loss:  0.0018098522
Test Loss:  0.0017143105
Test Loss:  0.0019217604
Test Loss:  0.0016020726
Mean Test Loss:  0.0018065559


# Make Compiled NN Functions


In [21]:
from functools import partial
def model_apply(params, state, ctrl):
    x = jnp.concatenate((state, ctrl))
    return modelnew.apply(params, x)
apply_model = jax.jit(partial(model_apply, loaded_params))

In [22]:
import time
time_log = []
for i in range(16):
    x_samples, y_samples = data
    state1 = x_samples[i, 0:9]
    state1 = jnp.array(state1.numpy())
    ctrl1 = x_samples[i, 9:]
    ctrl1 = jnp.array(ctrl1.numpy())

    t0 = time.time()
    apply_model(state1, ctrl1)
    time_log.append(time.time() - t0)

print('Mean inference time: ', np.mean(time_log))

Mean inference time:  0.007016882300376892


In [23]:
from jax import jacfwd

@jax.jit
def compute_jacobian(state, ctrl):
    return jacfwd(lambda x: apply_model(state, x))(ctrl)

In [24]:
jacobian_time_log = []

for i in range(16):
    x_samples, y_samples = data
    state1 = x_samples[i, 0:9]
    state1 = jnp.array(state1.numpy())
    ctrl1 = x_samples[i, 9:]
    ctrl1 = jnp.array(ctrl1.numpy())

    t0 = time.time()
    jac = compute_jacobian(state1, ctrl1)
    jacobian_time_log.append(time.time() - t0)

print('Mean jacobian computation time: ', np.mean(jacobian_time_log))

Mean jacobian computation time:  0.007231995463371277


In [25]:
jac

Array([[ 0.00417626,  0.09946538, -0.36449325, -0.09978978],
       [ 0.01697065,  0.30175763,  0.06594227, -0.00204584],
       [-0.15662749,  0.04095466, -0.02830148, -0.010667  ],
       [-0.0039165 ,  0.02387397,  0.03821919,  0.8509513 ]],      dtype=float32)