In [1]:
import numpy as np
import torch

from functools import partial

from kusanagi.shell import cartpole
from kusanagi.base import ExperienceDataset, apply_controller
from kusanagi.ghost.control import RandPolicy

from prob_mbrl import utils, models, algorithms, losses, train_regressor
torch.set_num_threads(2)

In [2]:
def forward(states, actions, dynamics, **kwargs):
    """Propagate states one step through a learned dynamics model.

    Args:
        states: current states.
        actions: actions applied in ``states``.
        dynamics: callable model, invoked as
            ``dynamics((states, actions), return_samples=True,
            separate_outputs=True, **kwargs)``; it must return a
            ``(deltas, rewards)`` pair.
        **kwargs: forwarded unchanged to ``dynamics``.

    Returns:
        ``(next_states, rewards)`` where ``next_states = states + deltas``.
    """
    # Use the model passed in, not the notebook-global ``dyn``: otherwise the
    # ``dynamics`` argument (bound via ``partial(forward, dynamics=dyn)``) is
    # silently ignored and this function breaks if ``dyn`` is absent/renamed.
    deltas, rewards = dynamics((states, actions), return_samples=True,
                               separate_outputs=True, **kwargs)
    next_states = states + deltas
    return next_states, rewards

In [3]:
# Model sizes: number of mixture components in the dynamics output density,
# and hidden-layer counts for the dynamics and policy networks.
dyn_components = 4
dyn_layers = 2
pol_layers = 2

env = cartpole.Cartpole()
# Goal state; dimension 3 (listed in angle_dims below) is an angle, set to pi.
# NOTE(review): the meaning of dims 0-2 is defined by kusanagi's Cartpole -- confirm.
target = torch.tensor([0,0,0,np.pi]).float()
D = target.shape[-1]  # raw state dimension (4)
U = 1  # action dimension
# False: rewards come from the analytic reward_func defined below rather than
# being predicted by the dynamics model.
learn_reward = False
# Action limit shared by the learned policy and the random exploration policy
# (presumably a symmetric +/-10 bound -- confirm against models.Policy).
maxU = np.array([10.0])     
angle_dims = torch.tensor([3]).long()
# Re-express the target with the angle dims converted by to_complex (presumably
# sin/cos features); Da is the resulting, wider state width that the networks
# and Q operate on.
target = utils.to_complex(target, angle_dims)
Da = target.shape[-1]
# Weight matrix for losses.quadratic_saturating_loss on the converted state.
# NOTE(review): the 1 / l / l**2 pattern looks like the squared distance of the
# pole tip from its target, coupling dim 0 with the last two (angle-derived)
# dims; which of -2/-1 is sin vs cos depends on to_complex -- TODO confirm.
Q = torch.zeros(Da, Da).float()
Q[0, 0] = 1
Q[0, -2] = env.l
Q[-2, 0] = env.l
Q[-2, -2] = env.l**2
Q[-1, -1] = env.l**2
# Scale every weight by 10 (presumably narrows the loss's saturation width).
Q /= 0.1
def reward_func(states, target, Q, angle_dims):
    """Score ``states`` against ``target`` with a saturating quadratic loss.

    The dimensions listed in ``angle_dims`` are converted with
    ``utils.to_complex`` first, so that ``states`` lives in the same
    (converted) space as ``target`` and the weight matrix ``Q``.

    Despite the name, the value is a loss; the policy search in this notebook
    calls ``mc_pilco`` with ``maximize=False`` accordingly.
    """
    converted_states = utils.to_complex(states, angle_dims)
    return losses.quadratic_saturating_loss(converted_states, target, Q)

# Per-component output width of the dynamics network: two values per predicted
# dimension (state deltas, plus the reward when it is learned).
# NOTE(review): presumably a mean and a scale each -- confirm with MixtureDensity.
dynE = 2*(D+1) if learn_reward else 2*D
# Without a learned reward, hand the model the analytic reward function with
# target/Q/angle_dims bound. This rebinds the name `reward_func`; re-running
# this cell re-executes the `def reward_func` above first, so it stays valid.
reward_func = (None if learn_reward else
               partial(reward_func, target=target, Q=Q, angle_dims=angle_dims))
dyn = models.DynamicsModel(
        models.dropout_mlp(
            Da+U, (dynE+1)*dyn_components, [200]*dyn_layers,
            nonlin=torch.nn.ReLU,
            # One dropout module per hidden layer: `[m]*n` would repeat the
            # *same* instance n times, tying the layers' dropout parameters
            # and state together.
            dropout_layers=[models.modules.CDropout(0.5, 0.1)
                            for _ in range(dyn_layers)]
        ),
        reward_func=reward_func, angle_dims=angle_dims,
        # Floor division keeps this an int (`dynE/2` is a float on Python 3).
        output_density=models.MixtureDensity(dynE // 2, dyn_components)
    ).float()

pol = models.Policy(
    models.dropout_mlp(
        Da, U, [200]*pol_layers,
        nonlin=torch.nn.ReLU,
        output_nonlin=torch.nn.Tanh,
        # Distinct dropout instance per hidden layer (see dynamics model above).
        dropout_layers=[models.modules.BDropout(0.5)
                        for _ in range(pol_layers)]),
    maxU, angle_dims=angle_dims).float()
randpol = RandPolicy(maxU)
exp = ExperienceDataset()
# Optimise only trainable policy parameters. A list (not a one-shot `filter`
# iterator) keeps the global `params` reusable after Adam consumes it.
params = [p for p in pol.parameters() if p.requires_grad]
opt = torch.optim.Adam(params, 1e-3, amsgrad=True)

# Bind the learned dynamics model as forward's `dynamics` argument.
forward_fn = partial(forward, dynamics=dyn)


[2018-07-11 19:15:51.740464] Experience > Initialising new experience dataset


In [4]:
#%matplotlib qt
def cb(*args, **kwargs):
    """Render the current environment frame; all arguments are ignored.

    Intended as the ``callback`` argument of ``apply_controller``. It is unused
    at present (the calls pass ``callback=None``); pass ``cb`` instead to
    visualise episodes.
    """
    render_frame = env.render
    render_frame()

# Episode length in control steps, used for data collection and model rollouts.
H = 25
# Number of particles: passed to train_regressor and used as the number of
# initial states sampled for rollouts in the policy-search loop.
N_particles = 100

# Seed the experience dataset with episode(s) from the random policy before any
# model or policy training.
n_random_episodes = 1
for rand_it in range(n_random_episodes):
    ret = apply_controller(env, randpol, H, callback=None)
    exp.append_episode(*ret)

[2018-07-11 19:15:51.758118] apply_controller > Starting run
[2018-07-11 19:15:51.759424] apply_controller > Running for 2.500000 seconds
[2018-07-11 19:15:51.902988] apply_controller > Done. Stopping robot. Value of run [24.976482]
[2018-07-11 19:15:51.904354] Cartpole > Stopping robot


In [None]:
# Policy-search loop: each iteration collects one episode with the current
# policy, refits the dynamics model on all experience gathered so far, then
# optimises the policy against the learned model with MC-PILCO.
for ps_it in range(100):
    # apply policy: run the current policy on the system and store the episode
    ret = apply_controller(env, pol, H, callback=None)
    exp.append_episode(*ret)

    # train dynamics on (state, action) -> state-delta targets from all episodes
    device = dyn.X.device
    X, Y = exp.get_dynmodel_dataset(deltas=True, return_costs=learn_reward)
    dyn.set_dataset(torch.tensor(X).to(device).float(),
                    torch.tensor(Y).to(device).float())
    train_regressor(dyn, 5000, N_particles, True,
                    log_likelihood=losses.gaussian_mixture_log_likelihood)

    # rollout start states: N_particles recorded timestep-0 states, jittered
    # by 1% of each dimension's spread
    x0 = torch.tensor(exp.sample_states(N_particles, timestep=0)).to(device).float()
    x0 += 1e-2 * x0.std(0) * torch.randn_like(x0)
    utils.plot_rollout(x0, forward_fn, pol, H)

    # train policy
    # print(...) with a single argument prints the same text on Python 2 and 3;
    # the bare print statement is a SyntaxError on Python 3.
    print("Policy search iteration %d" % (ps_it + 1))
    algorithms.mc_pilco(x0, forward_fn, dyn, pol, H, opt, opt_iters=1000,
                        exp=exp, maximize=False, pegasus=False,
                        mm_states=False, mm_rewards=False, mpc=False,
                        max_steps=25)
    utils.plot_rollout(x0, forward_fn, pol, H)
    

[2018-07-11 19:15:52.006581] apply_controller > Starting run
[2018-07-11 19:15:52.008157] apply_controller > Running for 2.500000 seconds
[2018-07-11 19:15:52.137672] apply_controller > Done. Stopping robot. Value of run [24.993557]
[2018-07-11 19:15:52.138901] Cartpole > Stopping robot


log-likelihood of data: 0.126448:  37%|███▋      | 1835/5000 [00:14<00:24, 129.67it/s] 