# Testing the PILCO framework

In [1]:
# %load ~/dev/marthaler/header.py
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

%load_ext autoreload
%autoreload 2

In [2]:
# Enable Float64 for more stable matrix inversions.
import jax
from jax import Array, config
import jax.numpy as jnp
import numpy as np
import jax.random as jr
from jaxtyping import ArrayLike, install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt

config.update("jax_enable_x64", True)
import gpjax as gpx


key = jr.key(123)

cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

  torch.utils._pytree._register_pytree_node(cls, tree_flatten, tree_unflatten)


# Functions

In [3]:
# Function to sample from a single mean and covariance
def sample_mvnormal(key, mean, cov, num_samples):
    return jr.multivariate_normal(key, mean, cov, (num_samples,))

# Vectorize the sampling function
vectorized_sample = jax.vmap(sample_mvnormal, in_axes=(None, 0, 0, None))

In [4]:
def cart_pole_cost(
    states_sequence: ArrayLike, 
    target_state: ArrayLike = jnp.array([jnp.pi,0.0]), 
    lengthscales: ArrayLike = jnp.array([3.0,1.0]),
    angle_index:int = 2,
    pos_index:int = 0
)->Array:
    """
    Cost function given by the combination of the saturated distance between |theta| and 'target angle', and between x and 'target position'.
    """
    x = states_sequence[pos_index]
    theta = states_sequence[angle_index]

    target_x = target_state[1]
    target_theta = target_state[0]

    return 1 - jnp.exp(
        -(jnp.square((jnp.abs(theta) - target_theta) / lengthscales[0])) - jnp.square((x - target_x) / lengthscales[1])
    )

## Generate the environment

In [5]:
import gymnasium as gym
env = gym.make("InvertedPendulum-v5")

In [6]:
action_dim = env.action_space.shape[0]

In [7]:
initial_state_exploration, _ = env.reset()
state_dim = initial_state_exploration.shape[0]

In [8]:
# Initialize a random controller
from controllers import RandomController
policy = RandomController(state_dim,action_dim,True,3.0)

In [9]:
explore_timesteps = 10

In [10]:
X = []
Y = []
ep_return_full = 0
ep_return_sampled = 0
key = jr.key(42)
x = initial_state_exploration.copy()
for timestep in range(explore_timesteps):
    key, subkey = jr.split(key)
    u = policy.compute_action(x,timestep,subkey)
    #print(u)
    z = env.step(np.array(u))
    #print(z)
    #x_new, r, done, _, __ = env.step(np.array(u))
    x_new = z[0]
    r = z[1]
    X.append(jnp.hstack((x, u)))
    Y.append(x_new - x)
    ep_return_sampled += r
    x = x_new
X = jnp.array(X)
Y = jnp.array(Y)

In [11]:
D = gpx.Dataset(X=X, y=Y)

In [12]:
from model_learning.mgpr import DynamicalModel

In [13]:
model = DynamicalModel(data=D)

In [14]:
model.optimize()

  0%|          | 0/1000 [00:00<?, ?it/s]

  torch.utils._pytree._register_pytree_node(cls, tree_flatten, tree_unflatten)
  torch.utils._pytree._register_pytree_node(cls, tree_flatten, tree_unflatten)


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [15]:
# Now do a rollout with this model

# Generate an initial state
x0, _ = env.reset()
key, subkey = jr.split(key)
# Generate an initial action
u0 = policy.compute_action(x,timestep,subkey)
initial_state = jnp.hstack((x, u)).reshape(1,-1)
# Compute the moments from the trained GP transition function
predictive_moments = model.predict_all_outputs(initial_state)

  torch.utils._pytree._register_pytree_node(cls, tree_flatten, tree_unflatten)


In [16]:
# initialize some particles
num_particles = 100
init_samples = jnp.squeeze(
    vectorized_sample(
        key, predictive_moments[:,:,0], jax.vmap(jnp.diag)(predictive_moments[:,:,1]), num_particles
    )
)

In [17]:
from gpjax.parameters import (
    Parameter,
)

In [18]:
type(policy)

controllers.RandomController

In [19]:
from flax import nnx

In [20]:
graphdef, params, *static_state = nnx.split(policy, Parameter, ...)

In [21]:
import typing as tp
from flax import nnx
Model = tp.TypeVar("Model", bound=nnx.Module)
from gpjax.typing import (
    Array,
    KeyArray,
    ScalarFloat,
)

In [22]:
# def scan(f, init, xs, length=None):
#   if xs is None:
#     xs = [None] * length
#   carry = init
#   ys = []
#   for x in xs:
#     carry, y = f(carry, x)
#     ys.append(y)
#   return carry, np.stack(ys)

In [93]:
def rollout_all(
    policy,#: Controller,
    samples: ArrayLike,
    model: Model,
    timesteps: ArrayLike,
    key: KeyArray = jr.PRNGKey(42),
    time_horizon: int=50,
    num_particles:int=100
)->ScalarFloat:
    # Should wrap this in scan so we can jit compile it?
    costs = []
    for t in range(timestep+1,timestep+time_horizon):
        # Now generate some actions to take for these states
        key, *subkeys = jr.split(key,num_particles+1)
        u = jax.vmap(policy.compute_action)(samples,jnp.tile(t,num_particles),jnp.array(subkeys))
        this_state = jnp.hstack((samples, u))
        # should compute rewards here?
        predictive_moments = model.predict_all_outputs(this_state)
        # Generate samples
        key, subkey = jr.split(key)
        samples = jnp.squeeze(
            vectorized_sample(
                key, predictive_moments[:,:,0], jax.vmap(jnp.diag)(predictive_moments[:,:,1]), 1
            )
        )
        costs.append(jnp.sum(jax.vmap(cart_pole_cost)(samples)))
    return jnp.mean(jnp.array(costs))

In [94]:
rollout_all(policy,init_samples,model,timesteps)

Array(52.282467, dtype=float64)

In [50]:
from jax.tree_util import Partial

In [85]:
num_particles=100
def one_rollout_step(carry, t):
    compute_action, predict_all_outputs, key, samples, total_cost = carry
    key, *subkeys = jr.split(key,num_particles+1)
    u = jax.vmap(compute_action)(samples,jnp.tile(t,num_particles),jnp.array(subkeys))
    this_state = jnp.hstack((samples, u))
    predictive_moments = predict_all_outputs(this_state)
    key, subkey = jr.split(key)
    samples = jnp.squeeze(
        vectorized_sample(
            key, predictive_moments[:,:,0], jax.vmap(jnp.diag)(predictive_moments[:,:,1]), 1
        )
    )
    cost = jnp.sum(jax.vmap(cart_pole_cost)(samples))
    return (compute_action,predict_all_outputs,key,samples,total_cost+cost), cost

In [86]:
def rollout(
    policy,#: Controller,
    model,#Model
    init_samples: ArrayLike,
    timesteps: ArrayLike,
    key: KeyArray = jr.PRNGKey(42),
)->ScalarFloat:
    action = Partial(policy.compute_action)
    pao = Partial(model.predict_all_outputs)
    (action,pao,key,samples,total_cost), result = jax.lax.scan(
        one_rollout_step, (action,pao,key,init_samples,0), timesteps
    )
    return total_cost/len(timesteps)

In [87]:
time_horizon = 50
timesteps = jnp.arange(timestep+1,timestep+time_horizon)
cost = rollout(policy,model,init_samples,timesteps) 

In [88]:
cost

Array(52.282467, dtype=float64)

should partial map the above loop so that a given set of controller parameters are passed in.
This would change the "Compute action" function so that we need the parameters at each call.  So maybe the controller won't be a class...

In [None]:
def fit(  # noqa: PLR0913
    *,
    model: Model,
    objective: Objective,
    train_data: Dataset,
    optim: ox.GradientTransformation,
    params_bijection: tp.Union[dict[Parameter, Bijector], None] = DEFAULT_BIJECTION,
    key: KeyArray = jr.PRNGKey(42),
    num_iters: int = 100,
    batch_size: int = -1,
    log_rate: int = 10,
    verbose: bool = True,
    unroll: int = 1,
    safe: bool = True,
) -> tuple[Model, jax.Array]:
    r"""Train a Module model with respect to a supplied objective function.
    Optimisers used here should originate from Optax.

    Example:
    ```pycon
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> import optax as ox
        >>> import gpjax as gpx
        >>> from gpjax.parameters import PositiveReal, Static
        >>>
        >>> # (1) Create a dataset:
        >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
        >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape)
        >>> D = gpx.Dataset(X, y)
        >>> # (2) Define your model:
        >>> class LinearModel(nnx.Module):
        >>>     def __init__(self, weight: float, bias: float):
        >>>         self.weight = PositiveReal(weight)
        >>>         self.bias = Static(bias)
        >>>
        >>>     def __call__(self, x):
        >>>         return self.weight.value * x + self.bias.value
        >>>
        >>> model = LinearModel(weight=1.0, bias=1.0)
        >>>
        >>> # (3) Define your loss function:
        >>> def mse(model, data):
        >>>     pred = model(data.X)
        >>>     return jnp.mean((pred - data.y) ** 2)
        >>>
        >>> # (4) Train!
        >>> trained_model, history = gpx.fit(
        >>>     model=model, objective=mse, train_data=D, optim=ox.sgd(0.001), num_iters=1000
        >>> )
    ```

    Args:
        model (Model): The model Module to be optimised.
        objective (Objective): The objective function that we are optimising with
            respect to.
        train_data (Dataset): The training data to be used for the optimisation.
        optim (GradientTransformation): The Optax optimiser that is to be used for
            learning a parameter set.
        num_iters (int): The number of optimisation steps to run. Defaults
            to 100.
        batch_size (int): The size of the mini-batch to use. Defaults to -1
            (i.e. full batch).
        key (KeyArray): The random key to use for the optimisation batch
            selection. Defaults to jr.PRNGKey(42).
        log_rate (int): How frequently the objective function's value should
            be printed. Defaults to 10.
        verbose (bool): Whether to print the training loading bar. Defaults
            to True.
        unroll (int): The number of unrolled steps to use for the optimisation.
            Defaults to 1.

    Returns:
        A tuple comprising the optimised model and training history.
    """
    if safe:
        # Check inputs.
        _check_model(model)
        _check_train_data(train_data)
        _check_optim(optim)
        _check_num_iters(num_iters)
        _check_batch_size(batch_size)
        _check_log_rate(log_rate)
        _check_verbose(verbose)

    # Model state filtering

    graphdef, params, *static_state = nnx.split(model, Parameter, ...)

    # Parameters bijection to unconstrained space
    if params_bijection is not None:
        params = transform(params, params_bijection, inverse=True)

    # Loss definition
    def loss(params: nnx.State, batch: Dataset) -> ScalarFloat:
        params = transform(params, params_bijection)
        model = nnx.merge(graphdef, params, *static_state)
        return objective(model, batch)

    # Initialise optimiser state.
    opt_state = optim.init(params)

    # Mini-batch random keys to scan over.
    iter_keys = jr.split(key, num_iters)

    # Optimisation step.
    def step(carry, key):
        params, opt_state = carry

        if batch_size != -1:
            batch = get_batch(train_data, batch_size, key)
        else:
            batch = train_data

        loss_val, loss_gradient = jax.value_and_grad(loss)(params, batch)
        updates, opt_state = optim.update(loss_gradient, opt_state, params)
        params = ox.apply_updates(params, updates)

        carry = params, opt_state
        return carry, loss_val

    # Optimisation scan.
    scan = vscan if verbose else jax.lax.scan

    # Optimisation loop.
    (params, _), history = scan(step, (params, opt_state), (iter_keys), unroll=unroll)

    # Parameters bijection to constrained space
    if params_bijection is not None:
        params = transform(params, params_bijection)

    # Reconstruct model
    model = nnx.merge(graphdef, params, *static_state)

    return model, history
