In [None]:
import jax
from jax import jit, vmap
import jax.numpy as jnp
import flax.linen as nn
from flax.training.train_state import TrainState
import matplotlib.pyplot as plt
import numpy as np

from pncbf.networks.mlp import MLP
from pncbf.networks.ncbf import SingleValueFn
from pncbf.networks.optim import get_default_tx

import dill

### Network Architecture

In [None]:
def mlp():
    """Build the scalar value-function network used for the CBF.

    Returns:
        A ``SingleValueFn`` module whose ``net_cls`` factory constructs an
        ``MLP`` with hidden sizes ``(256, 256, 256)`` and ``tanh`` activation.
    """

    def make_backbone():
        # act_final=True: per the original author's note the activation is
        # also applied on the backbone's output layer -- confirm against
        # pncbf.networks.mlp.MLP.
        return MLP(hid_sizes=(256, 256, 256), act=nn.tanh, act_final=True)

    # SingleValueFn receives a zero-argument factory rather than an instance.
    # It is expected to reduce the backbone output to a single scalar value
    # per input (see pncbf.networks.ncbf) -- TODO confirm.
    return SingleValueFn(net_cls=make_backbone)

### Loss Function

In [None]:
@jit
def pncbf_loss_fn(predicted_values, target_values):
    """Mean squared error between predictions and regression targets.

    Args:
        predicted_values: Array of network outputs.
        target_values: Array of targets; combined elementwise with
            ``predicted_values`` (standard broadcasting applies).

    Returns:
        Scalar array holding the mean of the squared residuals.
    """
    residual = predicted_values - target_values
    return jnp.square(residual).mean()

In [None]:
def train_pncbf(states, max_violations, network, learning_rates=1e-3, batch_size=128, epochs=1000):
    """Fit ``network`` to regress ``max_violations`` from ``states`` with minibatch MSE.

    Each epoch reshuffles the data and runs ``n_samples // batch_size``
    gradient steps; samples that do not fill a complete batch are skipped for
    that epoch.

    Args:
        states: Array of input samples, indexed along the first axis.
        max_violations: Array of regression targets, one per entry of ``states``.
        network: Module exposing ``init(key, x)`` and ``apply(params, x)``.
        learning_rates: Value forwarded to ``get_default_tx`` to build the optimizer.
        batch_size: Samples per gradient step. Clamped to the dataset size so
            that small datasets still train instead of producing NaN losses.
        epochs: Number of passes over the data.

    Returns:
        ``(state, losses)`` -- the final ``TrainState`` and a list holding the
        mean batch loss of every epoch.

    Raises:
        ValueError: If the dataset is empty, ``states`` and ``max_violations``
            differ in length, or ``batch_size`` is smaller than 1.
    """
    n_samples = len(states)
    if n_samples == 0:
        raise ValueError("states must contain at least one sample")
    if len(max_violations) != n_samples:
        raise ValueError(
            f"states and max_violations must have the same length, "
            f"got {n_samples} and {len(max_violations)}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Guarantee at least one step per epoch when the dataset is small.
    batch_size = min(batch_size, n_samples)
    steps_per_epoch = n_samples // batch_size

    # JAX keys must not be reused: derive a dedicated key for initialization
    # and keep `key` as the carry for the per-epoch shuffles.
    key = jax.random.PRNGKey(42)
    key, init_key = jax.random.split(key)
    dummy_input = states[0:1]  # a single sample, used only for shape inference
    params = network.init(init_key, dummy_input)

    # Optimizer comes from the project default (AdamW per the original
    # author's note -- see pncbf.networks.optim).
    tx = get_default_tx(learning_rates)

    state = TrainState.create(
        apply_fn=network.apply,
        params=params,
        tx=tx
    )

    @jit
    def train_step(state, batch_states, batch_values):
        def loss_fn(params):
            predicted_values = state.apply_fn(params, batch_states)
            # NOTE(review): the loss subtracts elementwise, so a shape
            # mismatch such as (B, 1) vs (B,) would broadcast silently and
            # distort the loss -- confirm the network output shape matches
            # batch_values.
            return pncbf_loss_fn(predicted_values, batch_values)

        # compute gradients and update parameters
        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        state = state.apply_gradients(grads=grads)

        return state, loss

    losses = []

    for epoch in range(epochs):
        # Split first, then shuffle with the fresh subkey.
        key, perm_key = jax.random.split(key)
        # Pull the permutation to host once per epoch so the per-step slicing
        # and indexing below do not each trigger a device-to-host conversion.
        perm = np.asarray(jax.random.permutation(perm_key, n_samples))

        epoch_losses = []
        for step in range(steps_per_epoch):
            batch_indices = perm[step * batch_size:(step + 1) * batch_size]
            batch_states = states[batch_indices]
            batch_values = max_violations[batch_indices]

            state, loss = train_step(state, batch_states, batch_values)
            epoch_losses.append(loss)

        # Fetch all step losses in one go, then average across batches.
        avg_loss = np.mean(jax.device_get(epoch_losses))
        losses.append(avg_loss)

        if epoch % 100 == 0:
            print(f"epoch {epoch}, loss: {avg_loss:.6f}")

    return state, losses

In [None]:
# Load the training set. The file is read with allow_pickle=True and
# unwrapped with .item(), then indexed by key, so it presumably holds a single
# pickled dict with 'states' and 'violations' entries.
# NOTE(review): allow_pickle=True can execute arbitrary code during
# unpickling -- only load this file from a trusted source.
data = np.load('segway_training_data_10k.npy', allow_pickle=True).item()
states = data['states']
max_violations = data['violations']

# Build the network and train it for 1200 epochs; the remaining
# hyperparameters use train_pncbf's defaults.
MLPnn = mlp()

trained_state, losses = train_pncbf(states, max_violations, MLPnn, epochs = 1200)

In [None]:
# Report the last entry of the per-epoch loss history returned by train_pncbf.
print(f"final losses: {losses[-1]:.7f}")

In [None]:
# Persist the per-epoch loss history next to the notebook.
np.save('segway_mlp_losses.npy', losses)
print("losses saved!")

# Disabled alternative: serialize the whole TrainState instead of only the
# parameters.
# with open('pncbf_model_segway.pkl', 'wb') as f:
#     dill.dump(trained_state, f)

# Save only the learned parameters; to use them, the network has to be
# rebuilt (e.g. via mlp()) and applied with these params.
# NOTE(review): the file is a dill pickle -- load it only from trusted sources.
params_only = trained_state.params
with open('segway_mlp_model.pkl', 'wb') as f:
    dill.dump(params_only, f)

print("model saved!")
