In [1]:
from jax import numpy as jnp
from jax import random, grad, vmap, jit, tree_multimap
from jax.experimental import stax, optimizers
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax

import keras

import gym
from gym import wrappers

import io
import base64
from IPython.display import HTML

from utils import pwl

import numpy as np
import pandas as pd
import sys, os, math, time, pickle, itertools
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
## verify Jax is using the GPU
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [3]:
## Instantiate global rng for reproducability
rng = random.PRNGKey(0)

In [11]:
num_rollouts = 5

expert_model_path = "./LunarLander-v2-config.json"
expert_weights_path = "./LunarLander-v2-weights.h5"
data_path = ""
loss_type = 'celoss'

mean_rewards = []
stds = []
main_returns = []

In [5]:
# Define a compiled update step
@jit
def step(i, opt_state, batch):
    x1, y1 = batch
    p = get_params(opt_state)
    g = grad(loss)(p, x1, y1)
    return opt_update(i, g, opt_state)

In [6]:
print('loading and building expert policy')

with open(expert_model_path, 'r') as f:
    policy_fn = keras.models.model_from_json(f.read())
policy_fn.load_weights(expert_weights_path)   

print('loaded and built')

#task_data = load_task_data("{}.pkl".format(data_path))
env = gym.make('LunarLander-v2')


loading and building expert policy
Instructions for updating:
If using Keras pass *_constraint arguments to layers.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.


loaded and built


In [7]:
init_obs = env.reset()
init_action = policy_fn.predict(np.array([init_obs]))
obs_data = jnp.array([init_obs])
act_data = jnp.array([init_action])

act_data = act_data.reshape(act_data.shape[0], act_data.shape[2])







In [8]:
## set up network
net_init, net_apply, net_walk = pwl(
    Dense(1024), Relu,
    Dense(4),
)
in_shape = (-1,) + (obs_data.shape[1],)

In [9]:
def mseloss(params, inputs, targets):
    # Computes average loss for the batch
    predictions = net_apply(params, inputs)
    return np.mean((targets - predictions)**2)

def celoss(params, inputs, targets):
    logits = net_apply(params, inputs)
    logits = stax.logsoftmax(logits)  # log normalize
    return -np.mean(np.sum(logits * targets, axis=1))  # cross entropy loss

def fit_policy(params, opt_state, data, batch_size, epochs):
    X_tr, y_tr = data
    num_train = X_tr.shape[0]
    
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)
    
    def data_stream():
        while True:
            rng = np.random.RandomState(0)
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield X_tr[batch_idx], y_tr[batch_idx]
    batches = data_stream()
    
    itercount = itertools.count()
    
    for epoch in tqdm(range(epochs),desc='tr policy', position=2, leave=False):
        for _ in range(num_batches):
            ii = next(itercount)
            opt_state = step(ii, opt_state, next(batches))
            
    params = get_params(opt_state)
    return params

In [12]:
if loss_type == 'celoss':
    loss = celoss
elif loss_type == 'mseloss':
    loss = mseloss

for j in tqdm(range(5)): #Dagger main loop    
    out_shape, init_params = net_init(rng, in_shape)
    opt_init, opt_update, get_params = optimizers.adam(0.00001, b1=0.9, b2=0.999, eps=1e-08)
    opt_state = opt_init(init_params)
    params = fit_policy(init_params, opt_state, (obs_data, act_data), batch_size=64, epochs=30)
    pickle.dump(params, open( './models/{}_dagger_model.h5'.format(j), "wb" ))
    
    env = gym.make('LunarLander-v2')
    max_steps = env.spec.max_episode_steps

    returns = []
    new_observations = []
    new_actions = []
    for i in tqdm(range(num_rollouts), position=1, desc='rollout', leave=False):
        obs = env.reset()

        done = False
        totalr = 0.0

        for _ in range(max_steps):
            expert_action = policy_fn.predict(obs[None,:])
            predicted_action = net_apply(params, obs[None, :])
            predicted_action = np.argmax( predicted_action )
            new_observations.append(obs)
            new_actions.append(expert_action)

            obs, r, done, info = env.step(np.array(predicted_action))
            totalr += r

        returns.append(totalr)

    print('returns', returns)
    print('mean return', np.mean(returns))
    print('std of return', np.std(returns))

    #record returns
    main_returns.append(returns)
    mean_rewards.append(np.mean(returns))
    stds.append(np.std(returns))

    #data aggregation
    obs_data = np.concatenate((obs_data, np.array(new_observations)))
    new_actions = np.array(new_actions)
    act_data = np.concatenate((act_data, np.array(new_actions.reshape(new_actions.shape[0], new_actions.shape[2]))))


returns [-91772.60817740177, -94933.25811119816, -86810.43591820316, -90357.6450017454, -88601.23341839312]
mean return -90495.03612538832
std of return 2775.3071341413283


returns [44648.950548796376, 74894.2250263559, 66025.43937408968, -92725.65219444488, 67142.5821480634]
mean return 31997.108980572095
std of return 63165.07004471768


returns [-89363.19227564016, -94910.5392808261, -86923.95073915925, -94311.31192433376, -93021.00934428454]
mean return -91706.00071284876
std of return 3070.593264091286


returns [-430.3144051627751, 72933.74443018195, -11530.66102431075, 67258.28688707371, 54858.51071045614]
mean return 36617.913319647654
std of return 35443.61116951529


returns [-86721.0606240035, 71093.59780871749, 41120.96879092928, 71822.88683635945, 67697.54020033425]
mean return 33002.78660246739
std of return 60928.38498514253



In [13]:
print(mean_rewards)

[-90495.03612538832, 31997.108980572095, -91706.00071284876, 36617.913319647654, 33002.78660246739]


In [15]:
final_model = './models/{}_dagger_model.h5'.format(j)
!xvfb-run -s "-screen 0 600x400x24" python3.6 render.py --mpath $final_model

 10%|████▏                                   | 105/1000 [00:01<00:15, 59.15it/s]


In [16]:
!ls ./gym-results

openaigym.episode_batch.0.14997.stats.json
openaigym.manifest.0.14997.manifest.json
openaigym.video.0.14997.video000000.meta.json
openaigym.video.0.14997.video000000.mp4


In [17]:
for file in os.listdir("./gym-results/"):
    if file.endswith(".mp4"):   
        mp4name = os.path.join("./gym-results", file)
        print(mp4name)

./gym-results/openaigym.video.0.14997.video000000.mp4


In [18]:
video = io.open(mp4name, 'r+b').read()
encoded = base64.b64encode(video)
HTML(data='''
    <video width="360" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>'''
.format(encoded.decode('ascii')))