In [None]:
import math
import random
import time
from jupyterthemes import jtplot

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline
# Plot theme and gym log verbosity (40 is passed straight to gym.logger.set_level).
jtplot.style()
gym.logger.set_level(40)

# Pick the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# Let cuDNN benchmark convolution algorithms for the observed input shapes.
torch.backends.cudnn.benchmark = True

In [None]:
from common.multiprocessing_env import SubprocVecEnv

# Number of parallel rollout workers and the Gym environment id.
# Alternative environment: env_name = "Pendulum-v0"
num_envs, env_name = 1, "BipedalWalker-v2"

def make_env():
    """Return a zero-argument factory that creates one `env_name` environment.

    The environment is not built here; it is only constructed when the
    returned callable is invoked (SubprocVecEnv is given a list of these).
    """
    def _build():
        return gym.make(env_name)
    return _build

# Environment-specific rollout settings and expert demonstrations.
# Fail fast: an unsupported env_name would otherwise only surface much later
# as a NameError on num_steps / expert_traj / min_buffer_size / batch_size.
if env_name == "BipedalWalker-v2":
    num_steps       = 2000
    expert_traj     = np.load("trajectory/BipedalWalker_100.npy")
    min_buffer_size = 4000
    batch_size      = 500
elif env_name == "Pendulum-v0":
    num_steps       = 100
    expert_traj     = np.load("trajectory/Pendulum50000.npy")
    min_buffer_size = 2000
    batch_size      = 5
else:
    raise ValueError("unsupported env_name: %r" % env_name)

# Vectorised environments for rollouts, plus a plain environment for evaluation.
envs = SubprocVecEnv([make_env() for _ in range(num_envs)])
env = gym.make(env_name)

num_inputs  = envs.observation_space.shape[0]
num_outputs = envs.action_space.shape[0]
num_codes = 2   # number of discrete latent codes (one-hot width)
num_tests = 5

# Expert rows are fed to the discriminator, whose input width is
# num_inputs + num_outputs, so check the file layout up front.
assert expert_traj.shape[1] == num_inputs + num_outputs, \
    "expert_traj rows must have num_inputs + num_outputs columns"

# Hyper params:
a2c_hidden_size       = 128
discrim_hidden_size   = 128
posterior_hidden_size = 128
lr                    = 3e-4
num_G_updates         = 5
num_D_updates         = 1

print(num_inputs, num_outputs)

In [None]:
def init_weights(m):
    """Initialise an nn.Linear in place: weights ~ N(0, 0.1^2), biases = 0.1.

    Meant to be passed to `Module.apply`; any non-Linear module is left as is.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.normal_(m.weight, mean=0., std=0.1)
    nn.init.constant_(m.bias, 0.1)
        
class Actor(nn.Module):
    """Gaussian policy conditioned on a state and a one-hot latent code.

    The state goes through two tanh hidden layers and the code through one.
    The two feature vectors are concatenated and mapped linearly to the
    action mean. The log standard deviation is a free parameter that does
    not depend on the state; it is initialised to ``std``.
    """
    def __init__(self, num_inputs, num_outputs, num_codes, hidden_size, std=0.0):
        super(Actor, self).__init__()

        self.linear1      = nn.Linear(num_inputs, hidden_size)
        self.linear2      = nn.Linear(hidden_size, hidden_size)
        self.linear_code  = nn.Linear(num_codes, hidden_size)
        self.linear_actor = nn.Linear(hidden_size * 2, num_outputs)
        self.log_std = nn.Parameter(torch.ones(1, num_outputs) * std)
        self.apply(init_weights)

    def forward(self, x, c):
        """Return a Normal action distribution for states ``x`` and codes ``c``.

        ``x`` and ``c`` are concatenated along dim 1 after encoding, so they
        must have the same batch size.
        """
        # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid.
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        c = torch.tanh(self.linear_code(c))
        mu = self.linear_actor(torch.cat([x, c], 1))
        std = self.log_std.exp().expand_as(mu)
        return Normal(mu, std)

class Critic(nn.Module):
    """State-value network V(s, c) conditioned on a one-hot latent code.

    ``num_outputs`` and ``std`` are accepted but unused; they only keep the
    constructor signature parallel to ``Actor``.
    """
    def __init__(self, num_inputs, num_outputs, num_codes, hidden_size, std=0.0):
        super(Critic, self).__init__()

        self.linear1       = nn.Linear(num_inputs, hidden_size)
        self.linear2       = nn.Linear(hidden_size, hidden_size)
        self.linear_code   = nn.Linear(num_codes, hidden_size)
        self.linear_critic = nn.Linear(hidden_size * 2, 1)
        self.apply(init_weights)

    def forward(self, x, c):
        """Return one value per row: tensor of shape (batch, 1)."""
        # torch.tanh replaces the deprecated F.tanh.
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        c = torch.tanh(self.linear_code(c))
        value = self.linear_critic(torch.cat([x, c], 1))
        return value
    
class Discriminator(nn.Module):
    """MLP that maps a state-action vector to a probability in (0, 1).

    Two tanh hidden layers are followed by a single sigmoid output unit.
    """
    def __init__(self, num_inputs, hidden_size):
        super(Discriminator, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        # Shrink the output weights and zero the bias so initial outputs sit near 0.5.
        self.linear3.weight.data.mul_(0.1)
        self.linear3.bias.data.mul_(0.0)

    def forward(self, x):
        """Return probabilities of shape (batch, 1) for a 2-D input ``x``."""
        # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid.
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        prob = torch.sigmoid(self.linear3(x))
        return prob
    
def expert_reward(state, action, model=None):
    """Imitation reward ``-log D(s, a)`` as a NumPy array of shape (batch, 1).

    In ``discriminator_update`` the discriminator is trained with policy
    samples labelled 1 and expert samples labelled 0. The reward is therefore
    large when D rates the pair as expert-like.

    Args:
        state: 2-D state tensor (batch, num_inputs).
        action: 2-D action tensor (batch, num_outputs).
        model: discriminator to query; defaults to the global ``discriminator``.
    """
    if model is None:
        model = discriminator
    state_action = torch.cat((state, action), 1)
    # Reward computation needs no gradient; avoid building an autograd graph.
    with torch.no_grad():
        prob = model(state_action)
    # Guard against log(0) = inf when the sigmoid saturates.
    return -np.log(prob.clamp(min=1e-8).cpu().numpy())

class Posterior(nn.Module):
    """Approximate posterior Q(c | s, a) over the discrete latent codes.

    Two tanh hidden layers are followed by a softmax over ``num_outputs``
    codes. The output is probabilities, not logits.
    """
    def __init__(self, num_inputs, num_outputs, hidden_size):
        super(Posterior, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, num_outputs)
        # Small output weights and zero bias keep the initial Q roughly uniform.
        self.linear3.weight.data.mul_(0.1)
        self.linear3.bias.data.mul_(0.0)

    def forward(self, x):
        """Return probabilities whose last dimension sums to 1."""
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        # Explicit dim: implicit-dim softmax is deprecated and selects dim 0
        # for 3-D input, which would normalise across the batch.
        Q = F.softmax(self.linear3(x), dim=-1)
        return Q

def posterior_reward(state, action, code, model=None):
    """Return log Q(code | s, a) for the first row of the batch, as a scalar.

    Args:
        state: 2-D state tensor; only row 0 contributes to the result.
        action: 2-D action tensor matching ``state``.
        code: integer index of the latent code.
        model: posterior network; defaults to the global ``posterior``.
    """
    if model is None:
        model = posterior
    state_action = torch.cat((state, action), 1)
    # Reward computation needs no gradient; avoid building an autograd graph.
    with torch.no_grad():
        Q = model(state_action)
    # Clamp so an underflowed probability yields a finite log instead of -inf.
    return np.log(Q.clamp(min=1e-8).cpu().numpy()[0][code])


In [None]:
def compute_gae(next_value, rewards, masks, values, gamma=0.995, tau=0.97):
    """Generalised Advantage Estimation; returns per-step value targets.

    Args:
        next_value: value estimate for the state after the last step.
        rewards, masks, values: per-step lists of equal length; a mask of 0
            stops bootstrapping across that step.
        gamma: discount factor.
        tau: GAE lambda.

    Returns:
        A new list where element t is advantage_t + values[t].
        ``values`` itself is not modified.
    """
    extended = values + [next_value]
    running = 0
    targets = []
    for t in range(len(rewards) - 1, -1, -1):
        td_error = rewards[t] + gamma * extended[t + 1] * masks[t] - extended[t]
        running = td_error + gamma * tau * masks[t] * running
        targets.append(running + extended[t])
    targets.reverse()
    return targets

In [None]:
# Mini-batch sampling
def sample_batch(states, actions, codes, log_probs, returns, advantages):
    """Shuffle the rollout once and yield full mini-batches of the global `batch_size`.

    Trailing rows that do not fill a whole mini-batch are dropped. ``codes``
    is sliced along dim 0 only; the other arrays are sliced as ``[rows, :]``.
    """
    total = states.shape[0]
    order = np.arange(total)
    np.random.shuffle(order)
    states, actions, codes, log_probs, returns, advantages = [
        arr[order] for arr in (states, actions, codes, log_probs, returns, advantages)
    ]
    usable = (total // batch_size) * batch_size
    for start in range(0, usable, batch_size):
        rows = slice(start, min(start + batch_size, total))
        yield (states[rows, :], actions[rows, :], codes[rows],
               log_probs[rows, :], returns[rows, :], advantages[rows, :])


def ppo_update(states, actions, codes, log_probs, returns, advantages, clip_param=0.2):
    """Run PPO updates on the global actor and critic from one rollout buffer.

    Makes ``num_G_updates`` passes over shuffled mini-batches from
    ``sample_batch``. Each mini-batch takes one critic step on
    ``critic_criterion(value, return_)``, then one actor step on the clipped
    surrogate objective minus a 0.001-weighted entropy bonus.

    Args:
        states, actions, log_probs, returns, advantages: rollout tensors with
            one sample per row (sliced as ``[rows, :]`` by sample_batch).
        codes: integer latent-code index per sample (tensor or array).
        clip_param: PPO ratio clipping epsilon.
    """
    for _ in range(num_G_updates):
        for state, action, code, old_log_probs, return_, advantage in sample_batch(
                states, actions, codes, log_probs, returns, advantages):
            # Build the one-hot codes with torch directly on `device`. Indexing
            # np.eye with a tensor forces a NumPy conversion, which raises for
            # CUDA tensors.
            code_idx = torch.as_tensor(code, dtype=torch.long, device=device)
            onehot_code = torch.eye(num_codes, device=device)[code_idx]

            dist = actor(state, onehot_code)
            value = critic(state, onehot_code)
            entropy = dist.entropy().mean()
            new_log_probs = dist.log_prob(action)

            ratio = (new_log_probs - old_log_probs).exp()
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage

            actor_loss = -torch.min(surr1, surr2).mean() - 0.001 * entropy
            critic_loss = critic_criterion(value, return_)

            optimizer_critic.zero_grad()
            critic_loss.backward()
            optimizer_critic.step()

            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()
            
def discriminator_update():
    """Take one BCE step on the discriminator and return its (real, fake) outputs.

    Policy pairs come from the global ``state_actions`` and are labelled 1.
    An equal number of expert rows, drawn uniformly with replacement from
    ``expert_traj``, are labelled 0.
    """
    n_policy = states.shape[0]
    picks = np.random.randint(0, expert_traj.shape[0], n_policy)
    expert_batch = torch.FloatTensor(expert_traj[picks, :]).to(device)

    fake = discriminator(state_actions)
    real = discriminator(expert_batch)

    fake_target = torch.ones((n_policy, 1)).to(device)
    real_target = torch.zeros((expert_batch.size(0), 1)).to(device)
    loss = discrim_criterion(fake, fake_target) + discrim_criterion(real, real_target)

    optimizer_discrim.zero_grad()
    loss.backward()
    optimizer_discrim.step()
    return real, fake

In [None]:
def plot(frame_idx, rewards, D_fake, D_real):
    """Redraw the training curves: test reward (left) and mean D outputs (right)."""
    clear_output(True)
    fig, (ax_reward, ax_disc) = plt.subplots(1, 2, figsize=(20, 8))
    ax_reward.set_title('frame %s. reward: %s' % (frame_idx, rewards[-1]))
    ax_reward.plot(rewards)
    ax_disc.set_title('D_fake %s  D_real %s' % (D_fake[-1], D_real[-1]))
    ax_disc.plot(D_fake)
    ax_disc.plot(D_real)
    plt.show()
    
def test_env(vis=False, code=0):
    """Run one evaluation episode on the global ``env`` and return its summed reward.

    Actions are the policy mean (no sampling), conditioned on latent ``code``.

    Args:
        vis: render every step when True.
        code: index of the latent code to condition the actor on.
    """
    state = env.reset()
    if vis:
        env.render()
    done = False
    total_reward = 0

    # One row: each state is unsqueezed to a batch of one below, so the code
    # batch must also have exactly one row, whatever num_envs is.
    onehot_code = torch.zeros([1, num_codes]).to(device)
    onehot_code[:, code] = 1

    while not done:
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        # Evaluation only; no autograd graph is needed.
        with torch.no_grad():
            dist = actor(state, onehot_code)
        next_state, reward, done, _ = env.step(dist.mean.cpu().numpy()[0])
        state = next_state
        if vis:
            env.render()
        total_reward += reward
    return total_reward

In [None]:
# Networks, all placed on the selected device.
actor = Actor(num_inputs, num_outputs, num_codes, a2c_hidden_size).to(device)
critic = Critic(num_inputs, num_outputs, num_codes, a2c_hidden_size).to(device)
discriminator = Discriminator(num_inputs + num_outputs, discrim_hidden_size).to(device)
posterior = Posterior(num_inputs + num_outputs, num_codes, posterior_hidden_size).to(device)

# Loss modules: use .to(device) rather than .cuda() so the setup follows the
# same device selection as the networks instead of hard-coding CUDA.
discrim_criterion = nn.BCELoss().to(device)
critic_criterion = nn.MSELoss().to(device)
# NOTE(review): nn.CrossEntropyLoss expects unnormalised logits, whereas
# Posterior.forward returns softmax probabilities.
posterior_criterion = nn.CrossEntropyLoss().to(device)

# One Adam optimizer per network, all sharing the same learning rate.
optimizer_actor     = optim.Adam(actor.parameters(), lr=lr)
optimizer_critic    = optim.Adam(critic.parameters(), lr=lr)
optimizer_discrim   = optim.Adam(discriminator.parameters(), lr=lr)
optimizer_posterior = optim.Adam(posterior.parameters(), lr=lr)

In [None]:
# Training history, appended once per outer epoch, and the total frame budget.
test_rewards, D_real, D_fake, Qs = [], [], [], []
max_frames = 5 * 10 ** 7
frame_idx = 0

In [None]:
done = False

# Epoch loop: run until the frame budget max_frames is exhausted.
while frame_idx < max_frames:
    log_probs = []
    values    = []
    states    = []
    actions   = []
    codes     = []
    rewards   = []
    masks     = []
    returns   = []

    # Rollouts need no autograd graph: log-probs, values and returns are all
    # detached before the updates below, so no_grad only avoids holding a
    # per-step graph in memory for the whole buffer.
    with torch.no_grad():
        # Episode loop: collect whole episodes until the buffer holds at
        # least min_buffer_size steps.
        while len(states) < min_buffer_size:
            state = envs.reset()
            rewards_ = []
            masks_ = []
            values_ = []
            episode_reward = 0

            # One latent code is sampled per episode and held fixed.
            code = np.random.randint(num_codes)
            onehot_code = torch.zeros([num_envs, num_codes]).to(device)
            onehot_code[:, code] = 1

            # Time-step loop: until the episode ends or num_steps is reached.
            # Reaching num_steps simply exits the for-loop. `done` keeps the
            # environment's value there, so the mask stays 1 and GAE
            # bootstraps from next_value (a time-limit cut, not a true
            # terminal state). The former `done == True` line was a no-op
            # comparison and has been removed.
            for i in range(1, num_steps + 1):
                state = torch.FloatTensor(state).to(device)
                dist = actor(state, onehot_code)
                value = critic(state, onehot_code)
                action = dist.sample()
                action_numpy = action.cpu().numpy()
                next_state, env_reward, done, _ = envs.step(action_numpy)
                episode_reward += env_reward
                # Count every environment step, including the terminal one.
                frame_idx += 1
                # Learned reward: discriminator term plus log Q(code | s, a).
                reward = expert_reward(state, action) * 1.0 + posterior_reward(state, action, code)
                log_prob = dist.log_prob(action).to(device)

                # Store the transition.
                states.append(state)
                actions.append(action)
                log_probs.append(log_prob)

                rewards_.append(torch.FloatTensor(reward).to(device))
                masks_.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))
                values_.append(value)
                codes.append(code)

                if done:
                    break
                state = next_state

            # Episode finished: bootstrap from the last next_state and compute GAE returns.
            print("done time:", i, "episode reward:", episode_reward)
            next_state = torch.FloatTensor(next_state).to(device)
            next_value = critic(next_state, onehot_code)
            returns_ = compute_gae(next_value, rewards_, masks_, values_)
            returns.extend(returns_)
            rewards.extend(rewards_)
            masks.extend(masks_)
            values.extend(values_)

    returns   = torch.cat(returns).detach().to(device)
    log_probs = torch.cat(log_probs).detach().to(device)
    values    = torch.cat(values).detach().to(device)
    states    = torch.cat(states).to(device)
    actions   = torch.cat(actions).to(device)
    codes     = torch.LongTensor(codes).to(device)
    advantage = returns - values

    # Discriminator
    state_actions = torch.cat([states, actions], 1)
    for _ in range(num_D_updates):
        real, fake = discriminator_update()

    # Posterior: maximise log Q(code | s, a). Posterior.forward returns softmax
    # probabilities, so take their log and use NLL; CrossEntropyLoss would
    # apply a second softmax on top of the probabilities.
    Q = posterior(state_actions)
    posterior_loss = F.nll_loss(torch.log(Q.clamp(min=1e-8)), codes)
    optimizer_posterior.zero_grad()
    posterior_loss.backward()
    optimizer_posterior.step()

    # Generator
    ppo_update(states, actions, codes, log_probs, returns, advantage)

    # Test
    test_reward = np.mean([test_env() for _ in range(3)])
    test_rewards.append(test_reward)
    D_real.append(torch.mean(real).data.cpu().numpy())
    D_fake.append(torch.mean(fake).data.cpu().numpy())
    plot(frame_idx, test_rewards, D_fake, D_real)
    print("code", code, " Q(s,a)", Q.data[-1])

In [None]:
# Render one evaluation episode conditioned on latent code 0.
test_env(vis=True, code=0)

In [None]:
# Model save
import os

MODEL_PATH_ACTOR = 'asset/infoGAIL/infoGAIL_actor.pth'
MODEL_PATH_CRITIC = 'asset/infoGAIL/infoGAIL_critic.pth'
MODEL_PATH_DISCRIMINATOR = 'asset/infoGAIL/infoGAIL_discriminator.pth'
MODEL_PATH_POSTERIOR = 'asset/infoGAIL/infoGAIL_posterior.pth'

# torch.save does not create missing parent directories, so create them first.
for _model, _path in ((actor, MODEL_PATH_ACTOR),
                      (critic, MODEL_PATH_CRITIC),
                      (discriminator, MODEL_PATH_DISCRIMINATOR),
                      (posterior, MODEL_PATH_POSTERIOR)):
    _dir = os.path.dirname(_path)
    if _dir:
        os.makedirs(_dir, exist_ok=True)
    torch.save(_model.state_dict(), _path)
