In [1]:
import os, math
import gym
import pickle

import numpy as np
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical, Normal

from tensorboardX import SummaryWriter 

from utils.utils import *

import matplotlib.pyplot as plt

print('you are using PyTorch version ',torch.__version__)

# Pick the first GPU when CUDA is usable, otherwise fall back to the CPU.
# NOTE(review): the models built later in this notebook are never moved to
# `device` in the visible cells, so they appear to run on the CPU regardless.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
    print("you have", torch.cuda.device_count(), "GPUs")
    print(device)
else:
    print('no GPUs detected')

%load_ext autoreload
%autoreload 2
%matplotlib inline

you are using PyTorch version  1.4.0
you have 2 GPUs
cuda:0


## Variational Adversarial Inverse Reinforcement Learning

When a long stream of inputs $x_t$ arrives in sequence and we cannot afford to store every $x_t$, we can still maintain their running mean $M_t$ and variance $V_t$ incrementally (Welford's algorithm). `ZFilter` folds each new input into these running estimates and then returns the input's z-score, $(x_t - M_t)/\sqrt{V_t}$.

Initialize with the first input:
$M_1 = x_1$, $\quad S_1 = 0$

For $t \ge 2$:

$M_t = M_{t-1} + \frac{x_t - M_{t-1}}{t}$

$S_t = S_{t-1} + (x_t - M_{t-1})(x_t - M_t)$

and the running (sample) variance is $V_t = \frac{S_t}{t-1}$.


In [9]:
env = gym.make('BipedalWalker-v3')  # alternative used earlier: gym.make('Hopper-v2')
# Seed torch (policy sampling, weight init) and the environment for reproducibility.
torch.manual_seed(0)
env.seed(0)

# Flat observation / action widths of the chosen environment.
num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
print("state space", env.observation_space, "action space", env.action_space)

# Running mean/std normalizer for observations.
# NOTE(review): the expert demos loaded later are printed as (50000, 14), while
# this env yields num_inputs + num_actions state-action features — confirm the
# demos were recorded on this environment.
running_state = ZFilter((num_inputs,), clip=5)

state space Box(24,) action space Box(4,)


In [10]:
def get_action(mu, std):
    """Draw one sample from N(mu, std) element-wise and return it as a NumPy array.

    The sample is detached from autograd before conversion.
    """
    sample = torch.normal(mu, std)
    return sample.detach().numpy()

def get_entropy(mu, std):
    """Differential entropy of N(mu, std), averaged over every element."""
    return Normal(mu, std).entropy().mean()

def log_prob_density(x, mu, std):
    """Log-density of a diagonal Gaussian, summed over dim 1.

    Per element: log N(x | mu, std^2) = -(x - mu)^2 / (2 std^2) - log(std) - 0.5 log(2 pi).
    The ``-log(std)`` term vanishes when std == 1 (as produced by ``Actor``),
    but is required for any other std.

    Args:
        x: sample tensor; assumed 2-D (batch, dim) since the sum is over dim 1.
        mu: mean tensor, broadcastable to ``x``.
        std: standard-deviation tensor (> 0), broadcastable to ``x``.

    Returns:
        Tensor of shape (batch, 1) with the joint log-density of each row.
    """
    per_dim = (-(x - mu).pow(2) / (2 * std.pow(2))
               - torch.log(std)
               - 0.5 * math.log(2 * math.pi))
    return per_dim.sum(1, keepdim=True)

def get_reward(vdb, state, action, eps=1e-8):
    """Imitation reward ``-log D(s, a)`` from the discriminator.

    ``train_vdb`` labels learner pairs 1 and expert pairs 0, so a probability
    near 0 ("looks like the expert") gives a large reward.

    Args:
        vdb: callable whose result is a tuple with the discriminator
            probability tensor first (as ``VDB.forward`` returns).
        state: 1-D state (array-like).
        action: 1-D action (array-like).
        eps: lower clamp on the probability. A sigmoid output that underflows
            to exactly 0 would otherwise make ``math.log`` raise ``ValueError``.

    Returns:
        The reward as a Python float.
    """
    state_action = torch.cat([torch.Tensor(state), torch.Tensor(action)])
    with torch.no_grad():
        prob = vdb(state_action)[0].item()
    return -math.log(max(prob, eps))

def kl_divergence(mu, logvar):
    """Per-row KL( N(mu, exp(logvar)) || N(0, I) ), summed over dim 1.

    ``logvar`` is the log-variance, so ``exp(logvar)`` is the variance.
    """
    per_dim = mu.pow(2) + torch.exp(logvar) - logvar - 1
    return 0.5 * per_dim.sum(dim=1)

def save_checkpoint(state, filename):
    """Serialize ``state`` (any object ``torch.save`` accepts) to the path ``filename``."""
    torch.save(obj=state, f=filename)

class Actor(nn.Module):
    """Gaussian policy: two tanh hidden layers give the action mean; std is fixed at 1."""

    def __init__(self, num_inputs, num_outputs, hidden_size):
        super().__init__()
        # Creation order fc1 -> fc2 -> fc3 is kept so seeded initialisation is unchanged.
        self.fc1 = nn.Linear(num_inputs, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_outputs)

        # Shrink the output weights and zero the bias so initial means start near 0.
        with torch.no_grad():
            self.fc3.weight.mul_(0.1)
            self.fc3.bias.mul_(0.0)

    def forward(self, x):
        hidden = torch.tanh(self.fc2(torch.tanh(self.fc1(x))))
        mu = self.fc3(hidden)
        # log-std is identically zero, i.e. std == 1 for every action dimension.
        std = torch.ones_like(mu)
        return mu, std


class Critic(nn.Module):
    """State-value network: two tanh hidden layers followed by a scalar output."""

    def __init__(self, num_inputs, hidden_size):
        super().__init__()
        # Creation order fc1 -> fc2 -> fc3 is kept so seeded initialisation is unchanged.
        self.fc1 = nn.Linear(num_inputs, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

        # Small output weights and zero bias so initial value estimates start near 0.
        with torch.no_grad():
            self.fc3.weight.mul_(0.1)
            self.fc3.bias.mul_(0.0)

    def forward(self, x):
        hidden = torch.tanh(self.fc2(torch.tanh(self.fc1(x))))
        return self.fc3(hidden)


class VDB(nn.Module):
    """Variational discriminator bottleneck.

    Encodes an input into a diagonal Gaussian (mu, logvar), samples a latent
    ``z`` with the reparameterization trick and maps ``z`` to a probability
    through a sigmoid discriminator.
    """

    def __init__(self, num_inputs, hidden_size, z_size):
        super().__init__()
        # Creation order fc1..fc5 is kept so seeded initialisation is unchanged.
        self.fc1 = nn.Linear(num_inputs, hidden_size)  # encoder trunk
        self.fc2 = nn.Linear(hidden_size, z_size)      # encoder mean head
        self.fc3 = nn.Linear(hidden_size, z_size)      # encoder log-variance head
        self.fc4 = nn.Linear(z_size, hidden_size)      # discriminator hidden layer
        self.fc5 = nn.Linear(hidden_size, 1)           # discriminator logit

        # Small output weights, zero bias: initial probabilities start near 0.5.
        with torch.no_grad():
            self.fc5.weight.mul_(0.1)
            self.fc5.bias.mul_(0.0)

    def encoder(self, x):
        trunk = torch.tanh(self.fc1(x))
        mu = self.fc2(trunk)
        logvar = self.fc3(trunk)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with sigma = exp(logvar / 2), eps ~ N(0, I).
        sigma = torch.exp(logvar / 2)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def discriminator(self, z):
        return torch.sigmoid(self.fc5(torch.tanh(self.fc4(z))))

    def forward(self, x):
        mu, logvar = self.encoder(x)
        prob = self.discriminator(self.reparameterize(mu, logvar))
        return prob, mu, logvar
    
def train_vdb(vdb, memory, vdb_optim, demonstrations, beta, 
              vdb_update_num, I_c, alpha_beta, return_beta=False):
    """Run a few gradient steps on the variational discriminator bottleneck.

    Learner state-action pairs are labelled 1 and expert demonstrations 0
    (binary cross-entropy). The information-bottleneck penalty
    ``beta * (KL - I_c)`` is added, where KL is the mean of the learner and
    expert encoder KLs. ``beta`` is updated by dual ascent:
    ``beta <- max(0, beta + alpha_beta * (KL - I_c))``.

    Args:
        vdb: discriminator module; ``vdb(x)`` returns ``(prob, mu, logvar)``.
        memory: non-empty sequence of transitions whose entries 0 and 1 are the
            state and action arrays (e.g. ``[state, action, reward, mask]``).
        vdb_optim: optimizer over ``vdb``'s parameters.
        demonstrations: 2-D array-like of expert state-action rows; the width
            must equal state width + action width.
        beta: current Lagrange multiplier for the bottleneck constraint.
        vdb_update_num: number of gradient steps to take.
        I_c: target bound for the mean KL.
        alpha_beta: step size of the ``beta`` update.
        return_beta: if True, also return the updated ``beta`` so the caller
            can carry it into the next call.

    Returns:
        ``(expert_acc, learner_acc)`` as 0-d tensors. These are the fraction of
        expert rows with prob < 0.5 and of learner rows with prob > 0.5. With
        ``return_beta=True``, the updated ``beta`` (float) is appended.

    Raises:
        ValueError: if ``memory`` is empty or the demonstration width does not
            match the learner's state-action width.
    """
    if len(memory) == 0:
        raise ValueError("train_vdb: memory is empty")

    # Stack the per-transition arrays directly. np.array(memory) on rows holding
    # differently-sized arrays is ragged, and NumPy >= 1.24 rejects it.
    states = torch.Tensor(np.vstack([transition[0] for transition in memory]))
    actions = torch.Tensor(np.vstack([transition[1] for transition in memory]))
    learner_inputs = torch.cat([states, actions], dim=1)

    # Convert the demonstrations once instead of on every update step.
    expert_inputs = torch.Tensor(demonstrations)
    if expert_inputs.dim() != 2 or expert_inputs.shape[1] != learner_inputs.shape[1]:
        raise ValueError(
            "train_vdb: demonstrations have shape {} but learner state-action "
            "rows have width {}".format(tuple(expert_inputs.shape), learner_inputs.shape[1]))

    learner_labels = torch.ones((learner_inputs.shape[0], 1))
    expert_labels = torch.zeros((expert_inputs.shape[0], 1))
    criterion = torch.nn.BCELoss()

    for _ in range(vdb_update_num):
        learner, l_mu, l_logvar = vdb(learner_inputs)
        expert, e_mu, e_logvar = vdb(expert_inputs)

        l_kld = kl_divergence(l_mu, l_logvar).mean()
        e_kld = kl_divergence(e_mu, e_logvar).mean()
        kld = 0.5 * (l_kld + e_kld)
        bottleneck_loss = kld - I_c

        # beta is a dual variable: update it from the detached value so it stays
        # a plain float. Otherwise it would carry gradients into the loss and
        # chain graphs across iterations, which is what forced retain_graph=True.
        beta = max(0.0, beta + alpha_beta * bottleneck_loss.item())

        vdb_loss = (criterion(learner, learner_labels)
                    + criterion(expert, expert_labels)
                    + beta * bottleneck_loss)

        vdb_optim.zero_grad()
        vdb_loss.backward()
        vdb_optim.step()

    # Evaluation passes need no autograd graph.
    with torch.no_grad():
        expert_acc = (vdb(expert_inputs)[0] < 0.5).float().mean()
        learner_acc = (vdb(learner_inputs)[0] > 0.5).float().mean()

    if return_beta:
        return expert_acc, learner_acc, beta
    return expert_acc, learner_acc

In [11]:
learning_rate = 3e-4
l2_rate = 1e-3  # L2 weight decay, applied to the critic only

# Construction order (actor, critic, vdb) is kept: each constructor draws its
# initial weights from the seeded torch RNG.
actor = Actor(num_inputs, num_actions, hidden_size=128)
critic = Critic(num_inputs, hidden_size=128)
vdb = VDB(num_inputs + num_actions, hidden_size=128, z_size=4)

actor_optim = optim.Adam(params=actor.parameters(), lr=learning_rate)
critic_optim = optim.Adam(params=critic.parameters(), lr=learning_rate, weight_decay=l2_rate)
vdb_optim = optim.Adam(params=vdb.parameters(), lr=learning_rate)

In [12]:
# load demonstrations
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open('expert_demo.p', "rb") as demo_file:
    expert_demo, _ = pickle.load(demo_file)
demonstrations = np.array(expert_demo)
# NOTE(review): the VDB above expects num_inputs + num_actions columns (the
# rollout check below prints width 28 for BipedalWalker), but these demos have
# 14 columns, presumably recorded on Hopper. train_vdb cannot consume them as-is.
print("demonstrations.shape", demonstrations.shape) # (50000, 14)
print(demonstrations[:2])

demonstrations.shape (50000, 14)
[[-0.48224705 -1.18786003  1.84605944  0.62223241 -0.39152268 -3.21709328
   0.05523458 -0.0175782   0.14056332  0.08432692  0.01398241  2.57012254
   2.16022653  1.25368368]
 [-0.48457226 -1.11279922  1.86942212  0.62266743 -0.38204572 -3.11900995
  -0.0192135   0.62860545  0.72578217  0.08719617  0.25489085  2.5566931
   2.40988924  1.14469644]]


In [13]:
episodes = 0
train_discrim_flag = True  # not read in this cell
max_iter_num = 4
total_sample_size = 512    # environment steps to collect per iteration
num_steps = 100            # maximum steps per episode

# Loop variable is `iteration`, not `iter`, so the builtin iter() is not
# shadowed for every later cell of the notebook.
for iteration in range(max_iter_num):
    actor.eval()
    critic.eval()
    memory = deque()
    steps = 0
    scores = []

    while steps < total_sample_size:
        state = env.reset()
        score = 0

        state = running_state(state)

        for _ in range(num_steps):
            steps += 1

            # Acting needs no gradients; avoid building an autograd graph per step.
            with torch.no_grad():
                mu, std = actor(torch.Tensor(state).unsqueeze(0))
            action = get_action(mu, std)[0]
            next_state, reward, done, _info = env.step(action)
            # The discriminator reward is what gets stored; the env reward only feeds `score` here.
            irl_reward = get_reward(vdb, state, action)

            mask = 0 if done else 1  # 0 marks the end of an episode
            memory.append([state, action, irl_reward, mask])

            next_state = running_state(next_state)
            state = next_state

            score += reward

            if done:
                break

        episodes += 1
        scores.append(score)

        # DEBUG: stop after a single episode so the cells below inspect one rollout.
        break
    # DEBUG: stop after the first iteration.
    break

In [14]:
# Stack the rollout's states and actions directly. This avoids np.array(memory)
# on ragged rows, which NumPy >= 1.24 rejects. `memory` is left untouched so
# this cell can be re-run.
states = torch.Tensor(np.vstack([transition[0] for transition in memory]))
actions = torch.Tensor(np.vstack([transition[1] for transition in memory]))
state_actions = torch.cat([states, actions], dim=1)

criterion = torch.nn.BCELoss()
print(state_actions.shape) # torch.Size([70, 28])
learner, l_mu, l_logvar = vdb(state_actions)

torch.Size([70, 28])


In [26]:
# Bottleneck settings: initial dual variable beta, number of VDB gradient
# steps, KL target I_c, and beta's dual-ascent step size alpha_beta.
beta = 0
vdb_update_num = 3
I_c = 0.5
alpha_beta = 1e-4
expert_acc, learner_acc = train_vdb(
    vdb, memory, vdb_optim, demonstrations,
    beta=beta, vdb_update_num=vdb_update_num, I_c=I_c, alpha_beta=alpha_beta,
)