In [None]:
import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy.stats.norm as norm
from autograd import grad
from autograd.misc import flatten

from autograd.misc.optimizers import sgd, adam

import gym
import roboschool

In [None]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from multiprocessing_env import SubprocVecEnv

# Number of environment factories created for the vectorised environment.
batch_size = 32
# Alternative task id, currently disabled:
# env_name = "RoboschoolInvertedPendulumSwingup-v1"
# Gym id used for every environment built in this notebook.
env_name = "RoboschoolInvertedPendulum-v1"


def make_env():
    """Return a zero-argument factory that builds a gym environment for ``env_name``."""
    def _create():
        # env_name is a global, so it is looked up when the factory is called
        return gym.make(env_name)

    return _create

# Vectorised environment built from `batch_size` independent factories.
envs = SubprocVecEnv([make_env() for _ in range(batch_size)])

# A separate single environment, used for evaluation rollouts.
env = gym.make(env_name)

# Length of the observation and action vectors.
num_states  = envs.observation_space.shape[0]
num_actions = envs.action_space.shape[0]

In [None]:
def init_param(scale, layer_sizes, rs=None):
    """Draw Gaussian-initialised weights and biases for a fully connected net.

    Args:
        scale: multiplier applied to every standard-normal draw; ``0.`` gives
            all-zero parameters.
        layer_sizes: sequence of layer widths, input first. Each consecutive
            pair ``(insize, outsize)`` defines one layer.
        rs: optional random source exposing ``randn`` (for example
            ``npr.RandomState(seed)``) so the initialisation can be
            reproduced. When ``None`` the module-level ``npr`` generator is
            used, as before.

    Returns:
        A list with one ``(W, b)`` tuple per layer, ``W`` of shape
        ``(insize, outsize)`` and ``b`` of shape ``(outsize,)``.
    """
    rng = npr if rs is None else rs
    return [(rng.randn(insize, outsize) * scale,
             rng.randn(outsize) * scale)
            for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:])]

def nonlin(x):
    """Hidden-layer activation: element-wise maximum of ``x`` and zero."""
    floor = 0.
    return np.maximum(x, floor)

def _mlp_forward(layers, x):
    """Apply a list of ``(W, b)`` layers to ``x``; ``nonlin`` follows every layer except the last."""
    out = x
    for W, b in layers[:-1]:
        out = nonlin(np.dot(out, W) + b)
    W, b = layers[-1]
    return np.dot(out, W) + b


def actor_critic(params, x): # policy function
    """Evaluate the Gaussian policy and the value function on ``x``.

    Args:
        params: dict with the keys ``'mean_params'``, ``'std_params'`` and
            ``'value_params'``; each entry is a list of ``(W, b)`` layers.
        x: network input; it is fed unchanged to all three networks.

    Returns:
        ``(mean, std, value)`` where ``std`` is the exponential of the
        ``'std_params'`` network output, so it is always positive.
    """
    mean = _mlp_forward(params['mean_params'], x)
    log_std = _mlp_forward(params['std_params'], x)
    value = _mlp_forward(params['value_params'], x)

    return mean, np.exp(log_std), value


# Layer widths for the three networks: input, two hidden layers of 128 units, output.
mean_policy_sizes = [num_states] + [128, 128] + [num_actions]
std_policy_sizes = [num_states] + [128, 128] + [num_actions]
value_sizes = [num_states] + [128, 128] + [1]

# The std network is initialised with scale 0., so its output starts at 0
# and the initial standard deviation is exp(0) = 1.
params= {
    'mean_params' : init_param(0.1, mean_policy_sizes),
    'std_params' : init_param(0., std_policy_sizes),
    'value_params' : init_param(0.1, value_sizes)
}
# params.update({'std_params' : 
#                [[np.ones((num_states, num_actions))*0., np.zeros(num_actions)]],
#               })

In [None]:
def compute_gae(next_value, rewards, masks, values, 
                        gamma=0.99, tau=0.95):
    """Generalised advantage estimation over a single trajectory.

    Args:
        next_value: value estimate for the state that follows the last step.
        rewards: per-step rewards.
        masks: per-step continuation flags; ``0`` at a step cuts both the
            bootstrap and the accumulated advantage at that step.
        values: per-step value estimates, same length as ``rewards``. Any
            sequence is accepted and it is not modified.
        gamma: discount factor.
        tau: GAE decay factor.

    Returns:
        A list, in time order, holding ``advantage + value`` for every step.
    """
    # list(...) lets tuples and arrays be extended the same way as lists
    values = list(values) + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * tau * masks[step] * gae
        returns.append(gae + values[step])
    # built back to front; one reverse avoids repeated front insertions
    returns.reverse()
    return returns

def test_compute_gae(final_value, rewards, masks, values, 
                        gamma=0.99, tau=0.95):
    gae = 0
    returns = np.zeros_like(rewards)
    for t in reversed(range(tf)):
        
        delta = rewards[:, t] \
                + (gamma * values[:,t+1] * masks[:,t] - values[:,t]
                   if t+1 < tf else final_value)
        gae = delta + gamma * tau * masks[:,t] * gae
        returns[:, t] = gae + values[:, t]
    return returns


In [None]:
def logprob_mse(x, states, actions, rtgs, advantage):
    """Actor-critic loss: advantage-weighted negative log-likelihood plus value error.

    Args:
        x: parameter dict accepted by ``actor_critic``; ``grad_logprob_mse``
            differentiates with respect to this argument.
        states: inputs passed to ``actor_critic``.
        actions: actions whose Gaussian log-density is evaluated.
        rtgs: regression targets for the value output.
        advantage: weights applied to the log-density; reshaped to a column.

    Returns:
        ``-mean(logpdf * advantage) + 0.5 * mean((rtgs - value) ** 2)``.
    """
    mean, std, value = actor_critic(x, states)
    value = np.concatenate(value).reshape((-1,1))
    # log-density of the actions under N(mean, std), in (x, loc, scale) order
    logpdf = norm.logpdf(actions, mean, std)
    logpdf = np.concatenate(logpdf)
    value_err = np.mean(np.square(rtgs - value))
    advantage = np.concatenate(advantage).reshape((-1,1))
    return -np.mean(logpdf * advantage) \
            + 0.5 * value_err
# Gradient of the loss with respect to its first argument (the parameters).
grad_logprob_mse = grad(logprob_mse)

In [None]:
def callback(x, i, g):
    """Optimizer callback: periodically evaluate the policy mean and report progress.

    Whenever ``i * tf`` is a multiple of 1000, one episode of at most ``tf``
    steps is run in ``env`` acting on the policy mean, with rendering. The
    episode return is appended to ``reward_history``, plotted, and printed
    together with the norm of the flattened gradient.

    Args:
        x: current parameters, passed to ``actor_critic``.
        i: iteration index supplied by the optimizer.
        g: current gradient; only its norm is reported.
    """
    if i*tf % 1000 != 0:
        return

    g_flat, _ = flatten(g)
    rew = 0.
    state = env.reset()
    for t in range(tf):
        mean, std, value = actor_critic(x, state)
        state, reward, done, _ = env.step(mean)
        env.render()
        # add the reward first so the terminating step is counted too
        rew += reward
        if done : break
    reward_history.append(rew)
    plot(i*tf, reward_history)
    print('iter : {}, tot rew : {}, grad norm : {}'.format(i*tf, rew, np.linalg.norm(g_flat)))
def plot(frame_idx, rewards):
    """Clear the cell output and redraw the reward curve.

    Args:
        frame_idx: value shown in the title as the frame count.
        rewards: non-empty sequence of rewards; the last entry is shown in
            the title and the whole sequence is plotted.
    """
    clear_output(True)
    fig = plt.figure(figsize=(20,5))
    # left-most panel of a 1x3 grid
    ax = fig.add_subplot(131)
    ax.set_title('frame %s. reward: %s' % (frame_idx, rewards[-1]))
    ax.plot(rewards)
    plt.show()

In [None]:
# Number of steps per rollout; also the step cap of the evaluation episode in `callback`.
tf = 200
# NOTE(review): not read by any cell visible here — confirm before removing.
frame_idx = 0
# Episode returns appended by `callback`, one per evaluation.
reward_history = []

def update(params, k):
    """Collect one batch of rollouts and return the loss gradient.

    Runs ``tf`` steps in the vectorised ``envs`` with actions sampled from
    the current policy, turns the rewards into GAE returns and evaluates
    ``grad_logprob_mse`` on the collected data.

    Args:
        params: parameter dict accepted by ``actor_critic``.
        k: iteration index supplied by the optimizer; unused.

    Returns:
        Gradient of ``logprob_mse`` with respect to ``params``.
    """
    actions = np.zeros((batch_size, tf, num_actions))
    states = np.zeros((batch_size, tf, num_states))
    rewards = np.zeros((batch_size, tf))
    values = np.zeros((batch_size, tf))
    masks = np.zeros((batch_size, tf))

    state = envs.reset()
    for t in range(tf):
        mean, std, value = actor_critic(params, state)
        sampled_action = npr.normal(mean, std)

        # value has one row per environment; value[0] would copy the first
        # environment's estimate into every row
        values[:, t] = value[:, 0]
        actions[:, t, :] = sampled_action
        states[:, t, :] = state

        state, reward, done, _ = envs.step(sampled_action)
        rewards[:, t] = reward 
        masks[:, t] = 1-done

    # value of the state reached after the last step, one per environment
    _, _, final_value = actor_critic(params, state)

    # compute cost-to-go
    rtgs = test_compute_gae(final_value[:, 0], rewards, masks, values)
    advantage = rtgs - values

    # flatten (env, step) into one column
    advantage = np.concatenate(advantage).reshape((-1,1))
    rtgs = np.concatenate(rtgs).reshape((-1,1))

    return grad_logprob_mse(params, states, actions, rtgs, advantage)


# Train with autograd's `sgd`: `update` returns the gradient for each
# iteration and `callback` reports progress. Rebinds `params` to the result.
params = sgd(
    update,
    params,
    callback=callback,
    num_iters=500,
    step_size=1e-2,
)


In [None]:
# Roll out one episode of at most `tf` steps, acting on the policy mean.
state = env.reset()
for t in range(tf):
    mean, std, value = actor_critic(params, state)
    # NOTE(review): the sample is not used by the step below, which follows the mean
    sampled_action = npr.normal(mean, std)
    state, reward, done, _ = env.step(mean)
    # stop at the end of the episode, as `callback` does
    if done:
        break


In [None]:
# Reset the single environment and evaluate the policy on the first observation.
state = env.reset()

mean, std, value = actor_critic(params, state)
# Draw one action from the policy's Gaussian; it is not sent to the environment here.
sampled_action = npr.normal(mean, std)
#     state, reward, done, _ = env.step(mean)

In [None]:
# Gaussian log-density with the scale multiplied by 1e7; the value is shown as the cell output.
norm.logpdf(mean, sampled_action, 10000000*std)