In [None]:
import sys
import math
import random
import time
from jupyterthemes import jtplot

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline
# Raise gym's logger threshold to numeric level 40 (presumably ERROR — confirm
# against gym.logger) and apply the jupyterthemes matplotlib style.
gym.logger.set_level(40)
jtplot.style()

# Turn on cuDNN's benchmark mode.
torch.backends.cudnn.benchmark = True

# Pick the compute device once; every tensor/model below is moved to `device`.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
from common.multiprocessing_env import SubprocVecEnv

# Gym environment id used for every environment created in this notebook.
# ("Pendulum-v0" is the alternative that was previously tried here.)
env_name = "BipedalWalker-v2"

# How many environment thunks are handed to SubprocVecEnv.
num_envs = 1

def make_env():
    """Return a zero-argument factory that builds the `env_name` gym environment.

    The environment is not created here; `gym.make(env_name)` only runs when
    the returned callable is invoked.
    """
    def _thunk():
        return gym.make(env_name)

    return _thunk

# One thunk per requested environment, wrapped in a SubprocVecEnv.
envs = SubprocVecEnv([make_env() for _ in range(num_envs)])
# A separate, plain environment; test_env() rolls out and renders on this one.
env = gym.make(env_name)

# Network dimensions: first axis of the observation / action space shapes,
# plus the number of discrete latent codes the actor is conditioned on.
num_inputs = envs.observation_space.shape[0]
num_outputs = envs.action_space.shape[0]
num_codes = 2

# Hidden width passed to Actor.
a2c_hidden_size = 128

In [None]:
def init_weights(m):
    """Initialise an `nn.Linear` layer in place; ignore every other module type.

    Written to be passed to `nn.Module.apply`. Weights are drawn from a normal
    distribution with mean 0 and std 0.1; the bias, when present, is filled
    with 0.1.

    Args:
        m: any module. Only `nn.Linear` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.normal_(m.weight, mean=0., std=0.1)
    # nn.Linear(..., bias=False) stores bias as None; filling it would raise.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.1)
        
class Actor(nn.Module):
    """Gaussian policy conditioned on a state and a latent code.

    The state goes through two tanh-activated linear layers and the code
    through one; the two hidden vectors are concatenated along dim 1 and mapped
    linearly to the mean of a `Normal`. The standard deviation does not depend
    on the input: it is `exp(log_std)`, a learned parameter broadcast to the
    shape of the mean.

    Args:
        num_inputs: input width of `linear1` (the state branch).
        num_outputs: output width of `linear_actor` and of `log_std`.
        num_codes: input width of `linear_code` (the code branch).
        hidden_size: width of each hidden layer.
        std: initial value of every entry of `log_std`. Despite the name this
            is the initial *log* standard deviation (0.0 gives a std of 1).
    """

    def __init__(self, num_inputs, num_outputs, num_codes, hidden_size, std=0.0):
        super(Actor, self).__init__()

        # Attribute names double as state_dict keys, so they must stay stable
        # for saved checkpoints to load.
        self.linear1  = nn.Linear(num_inputs, hidden_size)
        self.linear2  = nn.Linear(hidden_size, hidden_size)
        self.linear_code  = nn.Linear(num_codes, hidden_size)
        # Consumes the concatenation of the state and code hidden vectors.
        self.linear_actor  = nn.Linear(hidden_size*2, num_outputs)
        self.log_std = nn.Parameter(torch.ones(1, num_outputs) * std)
        self.apply(init_weights)

    def forward(self, x, c):
        """Return the action distribution for states `x` and codes `c`.

        Args:
            x: state batch; its last dimension must match `linear1`.
            c: code batch; its last dimension must match `linear_code`. It is
                concatenated with the state features along dim 1, so it needs
                the same leading (batch) size as `x`.

        Returns:
            `torch.distributions.Normal` whose mean is the network output and
            whose std is `exp(log_std)` expanded to the mean's shape.
        """
        # torch.tanh replaces the deprecated F.tanh; the result is identical.
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        c = torch.tanh(self.linear_code(c))
        mu = self.linear_actor(torch.cat([x, c], 1))
        std = self.log_std.exp().expand_as(mu)
        return Normal(mu, std)

In [None]:
def test_env(vis=False, code=0):
    """Run one episode of the global `env` with the global `actor`.

    The actor is conditioned on a fixed one-hot latent code and acts with the
    mean of its output distribution at every step.

    Args:
        vis: when true, render the environment after the reset and after
            every step.
        code: index of the one-hot code entry set to 1; must be a valid index
            into `num_codes`.

    Returns:
        The sum of the rewards returned by `env.step` until it reports done.
    """
    state = env.reset()
    if vis:
        env.render()

    # Only the single `env` is stepped here and the state is given a batch
    # dimension of 1 below, so the code batch must be 1 as well. Sizing it by
    # `num_envs` breaks the concatenation inside the actor when num_envs != 1.
    onehot_code = torch.zeros(1, num_codes, device=device)
    onehot_code[:, code] = 1

    total_reward = 0
    done = False
    # Pure inference: no autograd graph is needed for the rollout.
    with torch.no_grad():
        while not done:
            state = torch.FloatTensor(state).unsqueeze(0).to(device)
            dist = actor(state, onehot_code)
            # Act with the distribution mean; use dist.sample() instead for a
            # stochastic rollout.
            action = dist.mean.cpu().numpy()[0]
            state, reward, done, _ = env.step(action)
            if vis:
                env.render()
            total_reward += reward
    return total_reward

In [None]:
# Build the policy with the sizes defined above and move it to `device`.
actor = Actor(num_inputs, num_outputs, num_codes, a2c_hidden_size).to(device)
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved from CUDA cannot be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
param = torch.load('asset/infoGAIL/4000-500-5-1/infoGAIL_actor.pth', map_location=device)
actor.load_state_dict(param)
# Render one episode conditioned on latent code 1; the cell shows its total reward.
test_env(True,1)