In [1]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import torch
import jax.numpy as jnp

sim = True


In [2]:
data_all = np.load('50k_iris_sim_expanded_more4.npy' if sim else '50k_holybro_expanded_more2.npy' )
print(f"Loaded {'50k_iris_sim_expanded_more4.npy' if sim else '50k_holybro_expanded_more2.npy'} because sim is {sim}")

Loaded 50k_iris_sim_expanded_more4.npy because sim is True


In [3]:
class QuadrotorDataset(Dataset):
  def __init__(self, state_size = 9, ctrlinput_size = 4, training=True, data_all = data_all):
    self.input_size = state_size + ctrlinput_size
    self.output_size = ctrlinput_size
    np.random.shuffle(data_all)
    train_size = int(len(data_all)*0.8)
    if training:
      self.data = data_all[0:train_size]
    else:
      self.data = data_all[train_size:]

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    inputs = self.data[idx, 0: self.input_size]
    outputs = self.data[idx, self.input_size:]
    return torch.FloatTensor(inputs), torch.FloatTensor(outputs).view(4)
  
train_dataset = QuadrotorDataset(training=True)
test_dataset = QuadrotorDataset(training=False)


batch_size = 128
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

print(f"{train_dataloader.dataset.data.shape = }")
print(f"{test_dataloader.dataset.data.shape = }")


train_dataloader.dataset.data.shape = (40000, 17)
test_dataloader.dataset.data.shape = (10000, 17)


# Flax NNs

In [4]:
import flax
from flax import linen as nn
from typing import Any, Callable, Sequence
from jax import random
import jax

In [5]:
fake_train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
)
for data in fake_train_dataloader:
    input_data, output_data = data
    print(input_data.shape)
    print(output_data.shape)
    break    

torch.Size([1, 13])
torch.Size([1, 4])


In [6]:
for data in fake_train_dataloader:
    print(f"{data = }")
    input_data, output_data = data
    print(input_data.shape)
    print(output_data.shape)
    break

data = [tensor([[ 0.2288,  0.2288,  1.1144,  0.1144,  0.1144,  0.1144,  0.7189,  0.7189,
          0.7189, 15.0446,  0.1030,  0.1030,  0.1030]]), tensor([[-1.9420,  0.7672,  2.4488,  0.8013]])]
torch.Size([1, 13])
torch.Size([1, 4])


In [7]:
class FeedForward(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
    #   print(f"{i = }, {feat = }")
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x

model = FeedForward(features=[13, 128, 256, 256, 128, 4])

key1, key2 = random.split(random.key(0), 2)
x = fake_train_dataloader.dataset.data[0][0:13]
params = model.init(key1, x)
y = model.apply(params, x)
print('pred:\n', y[0:5])
print('true:\n', train_dataloader.dataset.data[0][13:])

pred:
 [ 1.2459718  -0.860127    0.13101679  2.3912866 ]
true:
 [-1.94202399  0.76719505  2.44880295  0.8012557 ]


## Flax Training

In [8]:
learning_rate = 0.003  # Gradient step size.

# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
    pred = model.apply(params, x_batched)
    SE = (pred - y_batched)**2
    return jnp.mean(SE)
    
@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

loss_grad_fn = jax.value_and_grad(mse)


In [9]:
import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

In [10]:
for i in range(21):
  for data in train_dataloader:
      x_samples, y_samples = data
      x_samples = jnp.array(x_samples.numpy())  # Convert to jax.numpy array
      y_samples = jnp.array(y_samples.numpy())  # Convert to jax.numpy array
      loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
      updates, opt_state = tx.update(grads, opt_state)
      params = optax.apply_updates(params, updates)

  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

Loss step 0:  3.3695362e-09
Loss step 10:  4.6185278e-14
Loss step 20:  6.899459e-11


In [13]:
x = fake_train_dataloader.dataset.data[0][0:13]
y = model.apply(params, x)
print('pred:\n', y[0:5])
print('true:\n', train_dataloader.dataset.data[0][13:])

pred:
 [-1.9420134   0.7671896   2.4487896   0.80125034]
true:
 [-1.94202399  0.76719505  2.44880295  0.8012557 ]


In [14]:
test_array = []
for data in test_dataloader:
    x_samples, y_samples = data
    x_samples = jnp.array(x_samples.numpy())  # Convert to jax.numpy array
    y_samples = jnp.array(y_samples.numpy())  # Convert to jax.numpy array
    loss_val, _ = loss_grad_fn(params, x_samples, y_samples)
    print('Test Loss: ', loss_val)
    test_array.append(loss_val)
print('Mean Test Loss: ', np.mean(test_array))

Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11


# Save Params

In [15]:
from flax import serialization
dict_output = serialization.to_state_dict(params)
bytes_output = serialization.to_bytes(params)

In [16]:
params['params']['layers_1']['kernel'][0,:][15:30]

Array([ 0.23633464,  0.08877068, -0.56622684, -0.44590765,  0.41151434,
        0.37425488, -0.37730283,  0.07838764,  0.17399848, -0.3088262 ,
        0.18105392,  0.4158016 ,  0.13710581,  0.2363445 , -0.5415333 ],      dtype=float32)

In [17]:
dict_output['params']['layers_1']['kernel'][0,:][15:30]

Array([ 0.23633464,  0.08877068, -0.56622684, -0.44590765,  0.41151434,
        0.37425488, -0.37730283,  0.07838764,  0.17399848, -0.3088262 ,
        0.18105392,  0.4158016 ,  0.13710581,  0.2363445 , -0.5415333 ],      dtype=float32)

In [18]:
with open("sim_params_expanded3.bin", "wb") as f:
    f.write(bytes_output)

## Load Params

In [19]:
class FeedForwardLoad(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
    #   print(f"{i = }, {feat = }")
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x

modelnew = FeedForwardLoad(features=[13, 128, 256, 256, 128, 4])

key1, key2 = random.split(random.key(0), 2)
x = fake_train_dataloader.dataset.data[0][0:13]
init_params = modelnew.init(key2, x)


In [20]:
fake_train_dataloader.dataset.data[0][13:]

array([-1.94202399,  0.76719505,  2.44880295,  0.8012557 ])

In [21]:
init_params['params']['layers_3']['kernel'][0,:][15:30]

Array([ 0.01019191, -0.00296605, -0.08019079,  0.05023021,  0.04959011,
       -0.02081057, -0.13012654, -0.05989638, -0.04309102,  0.0091062 ,
       -0.0193517 , -0.09684327,  0.04341459,  0.00183117, -0.02519199],      dtype=float32)

In [34]:
# Load from file
with open("sim_params_expanded3.bin", "rb") as f:
    loaded_bytes = f.read()



In [35]:
# Restore parameters from bytes
loaded_params = serialization.from_bytes(init_params, loaded_bytes)

In [36]:
loaded_params['params']['layers_1']['kernel'][0,:][15:30]

array([ 0.23633464,  0.08877068, -0.56622684, -0.44590765,  0.41151434,
        0.37425488, -0.37730283,  0.07838764,  0.17399848, -0.3088262 ,
        0.18105392,  0.4158016 ,  0.13710581,  0.2363445 , -0.5415333 ],
      dtype=float32)

In [37]:
@jax.jit
def mse(params, x_batched, y_batched):
    pred = modelnew.apply(params, x_batched)
    SE = (pred - y_batched)**2
    return jnp.mean(SE)

In [38]:
test_array = []
for data in test_dataloader:
    x_samples, y_samples = data
    x_samples = jnp.array(x_samples.numpy())  # Convert to jax.numpy array
    y_samples = jnp.array(y_samples.numpy())  # Convert to jax.numpy array
    loss_val  = mse(loaded_params, x_samples, y_samples)
    print('Test Loss: ', loss_val)
    test_array.append(loss_val)
print('Mean Test Loss: ', np.mean(test_array))

Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11
Test Loss:  8.901041e-11


In [27]:
@jax.jit
def model_apply(params, x):
    return modelnew.apply(params, x)

import time

In [28]:
time_log = []
for i in range(100):
    t0 = time.time()
    pred = model_apply(params, x)
    time_log.append(time.time() - t0)


In [29]:
np.array(time_log).mean()

np.float64(0.0019253134727478028)

In [30]:
@jax.jit
def model_apply(params, x):
    return modelnew.apply(params, x)

from functools import partial
apply_model = partial(model_apply, params)


import time
position_now = np.array([0.2342, .682, -1.2342,   0.0, 0.0, 0.0,   0.0, 0.0, 0.0])
input_data = np.array([0.324, 0.8953, 0.131, 0.09213])

dataNN = np.concatenate((position_now, input_data))

t0 = time.time()
outputNN = apply_model(dataNN)
print('outputNN: ', outputNN)


outputNN:  [-0.4263151   0.1881113   0.5163594   0.12628472]


In [31]:
for i in range(100):
    t0 = time.time()
    outputNN = apply_model(dataNN)
    print(time.time() - t0)

0.0018281936645507812
0.0004107952117919922
8.916854858398438e-05
0.0017545223236083984
0.00025010108947753906
0.00013685226440429688
6.413459777832031e-05
4.410743713378906e-05
4.673004150390625e-05
4.8160552978515625e-05
6.222724914550781e-05
5.1021575927734375e-05
5.745887756347656e-05
5.435943603515625e-05
5.698204040527344e-05
5.2928924560546875e-05
5.3882598876953125e-05
4.744529724121094e-05
3.552436828613281e-05
3.218650817871094e-05
3.24249267578125e-05
3.1948089599609375e-05
3.6716461181640625e-05
3.886222839355469e-05
3.457069396972656e-05
5.0067901611328125e-05
4.1484832763671875e-05
4.649162292480469e-05
4.172325134277344e-05
4.220008850097656e-05
3.600120544433594e-05
3.62396240234375e-05
4.887580871582031e-05
4.458427429199219e-05
4.124641418457031e-05
4.076957702636719e-05
4.5299530029296875e-05
3.8623809814453125e-05
4.00543212890625e-05
4.00543212890625e-05
3.790855407714844e-05
3.528594970703125e-05
3.5762786865234375e-05
3.600120544433594e-05
3.5762786865234375e-05


In [32]:
# Define a function that takes only the inputs you want the Jacobian with respect to
jac_fn = jax.jit(jacfwd(lambda x: predict_outputs(position_now, x,)))

@jax.jit
def model_output_wrt_input_data(input_data):
    # Concatenate position_now and input_data to form the full input
    dataNN = jnp.concatenate((position_now, input_data))
    return apply_model(dataNN)


# Compute the Jacobian of the output with respect to input_data
jacobian_fn = jax.jit(jax.jacrev(model_output_wrt_input_data))
time_log = []
for i in range(100):
    t0 = time.time()
    jacobian = jacobian_fn(input_data)
    time_log.append(time.time() - t0)


print("Jacobian:\n", jacobian)
print("Computation Time:", np.array(time_log).mean())

NameError: name 'jacfwd' is not defined

In [33]:
        """ Predicts the system output state using a feedforward neural network. """
        position_now = self.state_vector.T.tolist()[0]
        curr_thrust = -last_input[0][0]
        curr_rolldot = last_input[1][0]
        curr_pitchdot = last_input[2][0]
        curr_yawdot = last_input[3][0]
        input_data = [curr_thrust, curr_rolldot, curr_pitchdot, curr_yawdot]
        # print(f"position_now: {position_now}")
        # print(f"input_data: {input_data}")

        # Concatenate state vector and input vector
        state_vector = torch.tensor(position_now, dtype=torch.float32)
        input_vector = torch.tensor(input_data, dtype=torch.float32, requires_grad=True)
        dataNN = torch.cat([state_vector, input_vector])
        # print(f"dataNN: {dataNN}")

        # print("Feed Forward NN")
        # t1 = time.time()
        outputNN = self.NN(dataNN)
        # print(f"outputNN: {outputNN}")

        # Compute Jacobian
        jacobian = torch.zeros((4, 4))
        # print(input_vector.grad)
        # input_vector.grad.zero_()

        for i in range(4):
            self.NN.zero_grad()  # Reset gradients to zero
            if outputNN.grad is not None:
                outputNN.grad.zero_()
            outputNN[i].backward(retain_graph=True)  # Compute gradients
            jacobian[i] = input_vector.grad

        # print("Jacobian Matrix:\n", jacobian)

        inv_jac = np.linalg.inv(jacobian)
        # print(f"inv_jac: {inv_jac}")
        inv_jac[:, 2] = -inv_jac[:, 2]
        # print(f"inv_jac: {inv_jac}")

        self.jac_inv = inv_jac

        outputNN = outputNN.detach().numpy()
        outputNN = np.array([[outputNN[0], outputNN[1], outputNN[2], outputNN[3]]]).T

NameError: name 'self' is not defined