In [2]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
from bsuite.environments import cartpole

from mbrlax.models import GPModelSpec, SVGP, initialize_gp_model
from mbrlax.utils import ReplayBuffer, Driver, EnvironmentModel
from mbrlax.utils import sample_mvn
from mbrlax.policy import GPPolicy
from mbrlax.transition_model import GPTransitionModel
from mbrlax.optimizers import SGD
from mbrlax.utils.initial_state_model import ParticleInitialStateModel

from gpjax.likelihoods import Gaussian
from gpjax.parameters import build_constrain_params
from gpjax.datasets import CustomDataset, NumpyLoader
from gpjax.config import default_float

from gpflow_pilco.envs import CartPole

import jax
import jax.numpy as jnp

import tensorflow as tf
import optax

  lax_internal._check_user_dtype_supported(dtype, "array")


ImportError: cannot import name 'policy_loss' from 'mbrlax.utils.loss_functions' (/Users/karyam/Desktop/mbrlax/mbrlax/utils/loss_functions.py)

In [13]:
def func(params):
    """Linear toy objective: 2 * params[0] + 3 * params[1]."""
    first, second = params[0], params[1]
    return first * 2 + second * 3


def nested_grad(blah, params):
    """Sum `func(params)` over a 3-step `jax.lax.scan`.

    `blah` is an unused dummy argument; the function is differentiated with
    respect to `params` (argnums=1) in the following cell.
    """
    def body(carry, _):
        # The carry passes through unchanged; every step emits func(carry).
        return carry, func(carry)

    _, per_step = jax.lax.scan(body, params, jnp.zeros((3,)))
    return jnp.sum(per_step)


In [21]:
# Splitting the same PRNG key twice is deterministic: both splits yield identical
# sub-keys, so sampling from them produces identical values.
# A dedicated name avoids clobbering the notebook-wide `key` if this cell is re-run.
demo_key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(demo_key)
key3, key4 = jax.random.split(demo_key)
# Compare keys element-wise. Reducing each key with `.all()` first would only
# compare two booleans and would pass even for different keys.
assert jnp.array_equal(key1, key3) and jnp.array_equal(key2, key4)
x = jax.random.uniform(key1)
y = jax.random.uniform(key3)
print(x)
print(y)

0.48026035987734006
0.48026035987734006


In [14]:
# Value and gradient of `nested_grad` w.r.t. its second argument; "test" fills the unused first one.
params = jnp.array([1.0, 2.0])
value_and_grad_fn = jax.value_and_grad(nested_grad, argnums=1)
value, grad = value_and_grad_fn("test", params)
grad

DeviceArray([6., 9.], dtype=float64)

In [2]:
# Configuration: numeric precision, seeding, learning rate and the real environment.
dtype = default_float()
seed = 42
initial_lr = 1e-3
# Two independent PRNG streams from one seed: `key` for model/inference components,
# `rng` for the data-collection and sampling cells (which reference `rng`).
key, rng = jax.random.split(jax.random.PRNGKey(seed))
cartpole_env = cartpole.Cartpole(seed=seed)
# Discrete action indices derived from the env's action spec instead of being hard-coded.
action_space = jnp.arange(cartpole_env.action_spec().num_values)

## Collect experience (states, actions) and format it as input to the GP.

In [3]:
class RandomPolicy:
    """Uniform random policy over a discrete action space; ignores observations."""

    def __init__(self, key, action_space):
        # PRNG key, advanced on every call to `step` so successive actions differ.
        self.key = key
        self.action_space = action_space

    def step(self, time_step, mode=None):
        """Return one action drawn uniformly from `action_space`.

        Args:
            time_step: current environment time step (unused).
            mode: accepted for interface compatibility with other policies (unused).

        Returns:
            A scalar array holding one element of `action_space`.
        """
        # Split before sampling. Reusing the same key would return the same
        # action on every call, which would not be a random policy.
        self.key, subkey = jax.random.split(self.key)
        return jax.random.choice(subkey, self.action_space)

In [4]:
# Roll out a uniformly random policy in the real environment and keep every
# transition in a replay buffer; the gathered transitions train the GP model.
random_policy = RandomPolicy(rng, action_space)
replay_buffer = ReplayBuffer(5000)
driver = Driver(
    mode="random",
    env=cartpole_env,
    policy=random_policy,
    transition_observers=[replay_buffer.push],
    observers=[],
    max_steps=30,
)
initial_time_step = cartpole_env.reset()
driver.run(initial_time_step)
experience = replay_buffer.gather_all()

In [5]:
def transition_model_optimizer_callback(epoch, loss_history, plot_every=20):
    """Plot the training ELBO (negated loss) every `plot_every` epochs.

    Args:
        epoch: current epoch index.
        loss_history: sequence of recorded losses; they are negated for display,
            i.e. they are treated as negative ELBO values.
        plot_every: plotting period in epochs (default 20, as before).
    """
    # Nothing to plot off-period or before any loss has been recorded
    # (an empty history would otherwise produce a NaN title).
    if epoch % plot_every != 0 or len(loss_history) == 0:
        return
    elbo_history = -jnp.array(loss_history)
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(16, 8))
    # Title reports the mean over the most recent 32 entries.
    ax.set_title("Mean ELBO = %.3f" % jnp.mean(elbo_history[-32:]))
    ax.scatter(jnp.arange(len(loss_history)), elbo_history)
    ax.set_xlabel("iteration")
    ax.set_ylabel("ELBO")
    ax.grid()
    plt.show()
    # Release the figure; this callback may fire many times during one training run.
    plt.close(fig)

In [18]:
# Optimizer for the transition model: Adam at `initial_lr`, reporting progress
# through `transition_model_optimizer_callback`.
sgd_optimizer = SGD(
    optimizer=optax.adam(initial_lr), callback=transition_model_optimizer_callback
)

# Sparse variational GP spec: 32 inducing points, Gaussian likelihood.
model_spec = GPModelSpec(
    type=SVGP, num_inducing=32, likelihood=Gaussian(), model_uncertainty=True
)

transition_model = GPTransitionModel(
    gp_model_spec=model_spec,
    inference_strategy=None,
    optimizer=sgd_optimizer,
    reinitialize=True,
)
# GP-formatted (inputs, targets) from the collected experience, then model initialisation.
data = transition_model.get_gp_data(experience)
transition_model.initialize(experience)

dict_keys(['kernel', 'likelihood', 'mean_function', 'inducing_variable', 'q_mu', 'q_sqrt'])


## Train model

In [21]:
# Training hyper-parameters. Reuse the configured learning rate so both names
# stay in sync instead of duplicating the literal 1e-3.
start_learning_rate = initial_lr
batch_size = 60
num_epochs = 900
model_params = transition_model.model.get_params()

In [22]:
# Wrap the GP (inputs, targets) pairs in a shuffled mini-batch loader.
inputs, targets = data
training_data = CustomDataset(inputs, targets)
train_dataloader = NumpyLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
)

### Configure transition model's optimizer based on custom loss and training step

In [None]:
# svgp_transforms = transition_model.model.get_transforms()
# constrain_params = build_constrain_params(svgp_transforms)
# elbo = transition_model.model.build_elbo(constrain_params=constrain_params, num_data=experience[0].shape[0])

# def negative_elbo(params, batch):
#     return - elbo(params, batch)

# adam = optax.adam(start_learning_rate)

# @jax.jit
# def train_step(step_i, params, opt_state, batch):
#     loss, grads = jax.value_and_grad(negative_elbo, argnums=0)(params, batch)
#     updates, opt_state = adam.update(grads, opt_state)
#     return loss, updates, opt_state

# transition_model.optimizer.set_train_step(train_step)

In [3]:

import jax.numpy as jnp
# Episode mask: True for steps strictly before the first `done`, False from the
# terminal step onward. `[:, None]` adds the trailing axis for any sequence length.
dones = jnp.array([0, 0, 1, 0, 0])
ep_mask = (jnp.cumsum(dones) < 1)[:, None]
ep_mask

DeviceArray([[ True],
             [ True],
             [False],
             [False],
             [False]], dtype=bool)

In [None]:
# transition_model.train(experience)

In [None]:
# opt_state = adam.init(model_params)
# params = model_params
# loss_history = []

# for epoch in range(num_epochs):
#     loss, updates, opt_state = train_step(epoch, params, opt_state, data)
#     params = optax.apply_updates(params, updates)
#     loss_history.append(loss)

#     if epoch % 20 == 0:
#         clear_output(True)
#         plt.figure(figsize=[16, 8])
#         plt.subplot(1, 2, 1)
#         plt.title("Mean ELBO = %.3f" % -jnp.mean(jnp.array(loss_history[-32:])))
#         plt.scatter(jnp.arange(len(loss_history)), jnp.array(loss_history)*-1.0)
#         plt.grid()
#         plt.show()

In [23]:
# Predictive distribution of the latent GP at the training inputs.
# NOTE(review): `cov` has the same shape as `mean`, so presumably per-dimension
# marginal variances rather than a full covariance — confirm.
mean, cov = transition_model.model.predict_f(model_params, inputs)
(mean.shape, cov.shape)



((30, 6), (30, 6))

In [24]:

# Draw samples from the predictive (mean, cov) using the `rng` stream; check the result shape.
samples = sample_mvn(rng, mean, cov)
samples.shape

(30, 6)

In [None]:
# Rebuild the GP training pairs from `experience` and inspect their shapes.
inputs, targets = transition_model.get_gp_data(experience)
(inputs.shape, targets.shape)

((90, 7), (90, 6))

## Test policy

In [14]:
from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

In [4]:
from mbrlax.inference_strategy import ConditionalSamplingStrategy
from mbrlax.utils import sample_mvn
# Sampling-based inference strategy, configured to draw samples with `sample_mvn`.
inference_strategy = ConditionalSamplingStrategy(
    key=key,
    sampling_strategy=sample_mvn,
)

In [None]:
def long_term_cost(params, max_steps, time_step):
    """Accumulate `gaussian_objective` over a `max_steps`-long rollout.

    Starting from `time_step`, each iteration scores the current observation,
    asks `policy` for a planning action and advances via
    `agent.environment_model.step`. Returns the summed cost (0 when max_steps == 0).

    NOTE(review): `params` is never used, so the result does not depend on the
    parameters being optimised. `gaussian_objective` and `agent` are not defined
    in the visible notebook cells. TODO: thread `params` into the policy and
    step the `environment_model` built later in the notebook.
    """
    total_cost = 0
    for _ in range(max_steps):
        total_cost += gaussian_objective(time_step.observation)
        planned_action = policy.step(time_step, mode="plan")
        time_step = agent.environment_model.step(planned_action)
    return total_cost

In [None]:


# Policy optimizer (CMAOptimizer — presumably CMA-ES over the policy parameters).
# Every argument is passed by keyword: a positional argument may not follow `key=key`.
# TODO: `CMAOptimizer`, `fitness_function`, `num_generations`, `pop_size` and
# `num_params` are not imported/defined in the visible cells. The keyword names
# assume CMAOptimizer's parameters are named after these variables — confirm.
policy_optimizer = CMAOptimizer(
    key=key,
    fitness_function=fitness_function,
    num_generations=num_generations,
    pop_size=pop_size,
    num_params=num_params,
    callback=None,
)

In [15]:
# Squashing link for the policy GP output. tfb.Chain applies bijectors last-to-first:
# NormalCDF maps to (0, 1), Shift(-0.5) centres that on 0, and Scale stretches it to
# roughly (-10, 10). Constants are cast to the notebook's default float dtype.
# NOTE(review): this range does not obviously match the discrete action indices — confirm.
invlink = tfb.Chain(bijectors=[
    tfb.Scale(scale=jnp.asarray(20 - 1e-5, dtype=dtype)),
    tfb.Shift(shift=jnp.asarray(-0.5, dtype=dtype)),
    tfb.NormalCDF(),
])

policy_model_spec = GPModelSpec(
    type=SVGP,
    num_inducing=32,
    likelihood=Gaussian(),
    prior=None,
    mean_function="default",
    model_uncertainty=False,
    invlink=invlink,
)

# Reuse the configured action space rather than re-hard-coding it here.
policy = GPPolicy(
    action_space=action_space,
    gp_model_spec=policy_model_spec,
    optimizer=policy_optimizer,
    inference_strategy=inference_strategy,
)

NameError: name 'policy_optimizer' is not defined

In [None]:
# Initial-state distribution for imagined rollouts: Gaussian around
# [0, pi, 0, 0] (presumably pole pointing down — confirm state ordering)
# with independent std 0.1 per dimension.
num_particles = 128
state_loc = jnp.array([0.0, jnp.pi, 0.0, 0.0], dtype=dtype)
# Diagonal lower-triangular scale, i.e. independent noise per state dimension.
state_scale = jnp.diag(jnp.full(state_loc.shape, 0.1, dtype=dtype))
initial_state_distribution = tfd.MultivariateNormalTriL(loc=state_loc, scale_tril=state_scale)
initial_state_model = ParticleInitialStateModel(initial_state_distribution)
initial_obs = initial_state_model.sample(rng, num_particles)
assert initial_obs.shape == (num_particles, state_loc.shape[0]), initial_obs.shape

In [None]:
# Learned ("virtual") environment built from the GP transition model, a reward
# model and the particle initial-state model.
# TODO: `reward_model` is not defined in the visible notebook cells; define it first.
environment_model = EnvironmentModel(
    transition_model=transition_model,
    reward_model=reward_model,
    initial_state_model=initial_state_model,
)

# Buffer receiving imagined transitions. It was referenced but never created;
# the capacity mirrors the real-data buffer.
virtual_replay_buffer = ReplayBuffer(5000)

virtual_driver = Driver(
    mode="plan",
    env=environment_model,
    policy=policy,
    transition_observers=[virtual_replay_buffer.push],
    observers=[],
    max_steps=30,
)

In [None]:
# Initialise the policy from the real-environment experience, roll it out in the
# learned model, then train it on the last `max_steps` imagined transitions.
policy.initialize(experience)
initial_virtual_step = environment_model.reset()
virtual_driver.run(initial_virtual_step)
virtual_experience = virtual_replay_buffer.get_last_n(virtual_driver.max_steps)
result = policy.train(virtual_experience)