In [123]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gym
from gym import spaces
import copy
import os
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
import torch.nn.functional as F
from torch.autograd import Variable

In [131]:
# The following functions were written with reference to
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail

def init_params(m):
    """Row-normalised Gaussian initialisation for linear layers.

    Meant to be passed to ``nn.Module.apply``. A module whose class name
    contains ``"Linear"`` has its weight matrix redrawn from N(0, 1) and then
    rescaled so that every row has unit L2 norm; its bias, when present, is
    set to zero. Modules with any other class name are left untouched.

    Args:
        m: The module visited by ``apply``.
    """
    if "Linear" not in type(m).__name__:
        return

    weights = m.weight.data
    weights.normal_(0, 1)
    # Divide every row by its own L2 norm (in place, so the parameter sees it).
    row_norms = torch.sqrt(weights.pow(2).sum(1, keepdim=True))
    weights *= 1 / row_norms

    if m.bias is not None:
        m.bias.data.fill_(0)


# this is from https://github.com/p-morais/deep-rl/blob/master/rl/distributions/gaussian.py
class DiagonalGaussian(nn.Module):
    """Gaussian action head whose standard deviation does not depend on the input.

    The input tensor is used directly as the mean. The log standard deviation
    is a single ``(1, num_outputs)`` parameter shared by all inputs.

    Args:
        num_outputs: Number of action dimensions.
        init_std: Initial standard deviation (stored as its logarithm).
        learn_std: Whether the log standard deviation receives gradients.
    """

    def __init__(self, num_outputs, init_std=1, learn_std=True):
        super().__init__()

        initial_logstd = torch.ones(1, num_outputs) * np.log(init_std)
        self.logstd = nn.Parameter(initial_logstd, requires_grad=learn_std)
        self.learn_std = learn_std

    def forward(self, x):
        """Return ``(mean, std)`` where the mean is ``x`` itself."""
        return x, torch.exp(self.logstd)

    def sample(self, x, deterministic):
        """Return an action for mean ``x``.

        A random draw is made only when ``deterministic`` is exactly ``False``;
        for any other value the mean is returned unchanged.
        """
        if deterministic is not False:
            mean, _ = self(x)
            return mean
        return self.evaluate(x).sample()

    def evaluate(self, x):
        """Build the ``torch.distributions.Normal`` for mean ``x``."""
        loc, scale = self(x)
        return torch.distributions.Normal(loc, scale)

    
class ACModel(nn.Module):
    """Actor-critic network: shared trunk, Gaussian policy head, scalar value head.

    The trunk is either a single-layer GRU (``recurrent`` truthy) or one
    ``Linear`` + ``Tanh`` layer. The actor and the critic are both two-layer
    MLPs on top of the trunk output. The actor output is used as the mean of a
    ``DiagonalGaussian``.

    Args:
        num_inputs: Size of one observation vector.
        action_dim: Number of action dimensions produced by the actor.
        recurrent: Selects the GRU trunk instead of the feed-forward one.
        batch_size: Batch dimension of the learnable initial GRU state
            (only used when ``recurrent`` is truthy).
        hidden_size: Width of the trunk and of the hidden layers in both heads.
    """

    def __init__(self, num_inputs, action_dim, recurrent, batch_size, hidden_size=64):
        super().__init__()

        self.recurrent = recurrent
        if not self.recurrent:
            self.fwd = nn.Sequential(nn.Linear(num_inputs, hidden_size), nn.Tanh())
        else:
            self.hidden = self.init_hidden(batch_size, hidden_size)
            self.gru = nn.GRU(num_inputs, hidden_size)
            # GRU parameters: zero biases, orthogonal weight matrices.
            for param_name, tensor in self.gru.named_parameters():
                if 'bias' in param_name:
                    nn.init.constant_(tensor, 0)
                elif 'weight' in param_name:
                    nn.init.orthogonal_(tensor)

        def two_layer_head(out_features):
            # hidden -> hidden -> out_features, with a Tanh in between
            return nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, out_features),
            )

        self.actor = two_layer_head(action_dim)
        self.critic = two_layer_head(1)

        self.dist = DiagonalGaussian(action_dim)

        # Re-initialise every Linear-like submodule via init_params.
        self.apply(init_params)

    def init_hidden(self, batch_size, hidden_size):
        """Create the learnable initial GRU state, all zeros.

        The shape is ``(1, batch_size, hidden_size)``: one layer, one direction.
        """
        initial_state = torch.zeros(1, batch_size, hidden_size)
        return nn.Parameter(initial_state, requires_grad=True)

    def forward(self, obs, rnn_h=None):
        """Run the trunk and both heads.

        Args:
            obs: Observation tensor fed to the trunk.
            rnn_h: Optional GRU state; the learnable initial state is used
                when it is ``None``. Ignored by the feed-forward trunk.

        Returns:
            ``(action_dist, value, rnn_h)`` where ``rnn_h`` is the new GRU
            state, or the ``rnn_h`` argument unchanged for the feed-forward
            trunk.
        """
        if self.recurrent:
            hidden_in = self.hidden if rnn_h is None else rnn_h
            # TODO: dimensionality of the recurrent path is still to be confirmed;
            # a leading axis of size 1 is added before the GRU call.
            features, rnn_h = self.gru(obs.unsqueeze(0), hidden_in)
        else:
            features = self.fwd(obs)

        action_dist = self.dist.evaluate(self.actor(features))
        forward_critic = self.critic(features)

        return action_dist, forward_critic, rnn_h

In [135]:
# Smoke test setup: one observation with 4 features, and a feed-forward model
# built as ACModel(num_inputs=4, action_dim=2, recurrent=False, batch_size=1, hidden_size=64).
obs = torch.FloatTensor([[1, 2, 3, 4]])
acmodel = ACModel(4, 2, False, 1, 64)

In [136]:
# Forward pass: action distribution, critic value, and the RNN state (discarded here).
dist, value, _ = acmodel(obs)

In [137]:
# Draw one action from the policy distribution; the cell output shows the sample.
dist.sample()

tensor([[-0.1347, -1.6704]])

In [138]:
class RolloutBuffer:
    """Collects one episode of experience and derives GAE advantages from it.

    Args:
        acmodel: Callable mapping an observation to a tuple whose first two
            items are an action distribution (with ``sample()``) and a value
            estimate. Extra items (e.g. an RNN state) are ignored.
        env: Environment exposing ``episode_length``, ``reset()`` and
            ``step(action)`` returning ``(obs, reward, done, info)``.
        discount: Reward discount factor (gamma).
        gae_lambda: GAE smoothing factor (lambda).
        device: Device on which the buffers are allocated.

    After ``collect_experience`` the attributes ``actions``, ``values``,
    ``rewards``, ``log_probs``, ``gaes``, ``returns`` (1-D tensors) and
    ``obss`` (list) all have one entry per collected step.
    """

    def __init__(self, acmodel, env, discount=0.995, gae_lambda=0.95, device=None):
        # Keep the environment on the instance so that collection does not
        # depend on a notebook-global ``env``.
        self.env = env
        self.episode_length = env.episode_length
        self.device = device
        self.acmodel = acmodel
        self.discount = discount
        self.gae_lambda = gae_lambda

        self.actions = None
        self.values = None
        self.rewards = None
        self.log_probs = None
        self.obss = None
        self.gaes = None
        self.returns = None

        self.reset()

    def reset(self):
        """Re-allocate full-length (``episode_length``) buffers and clear derived data."""
        self.actions = torch.zeros(self.episode_length, device=self.device)
        self.values = torch.zeros(self.episode_length, device=self.device)
        self.rewards = torch.zeros(self.episode_length, device=self.device)
        self.log_probs = torch.zeros(self.episode_length, device=self.device)
        self.obss = [None] * self.episode_length
        self.gaes = None
        self.returns = None

    def process_obs(self, obs):
        """Hook for observation formatting; currently the identity."""
        # TODO: formatting stuff
        return obs

    def collect_experience(self):
        """Roll out one episode with the current policy and fill the buffers.

        Collection stops when the environment reports ``done`` or after
        ``episode_length`` steps, whichever comes first. All buffers are then
        truncated to the number of collected steps, and ``gaes`` and
        ``returns`` (``gaes + values``) are computed.

        Returns:
            The undiscounted sum of rewards of the episode.
        """
        # A previous call leaves the buffers truncated, so start from fresh
        # full-length ones; otherwise a longer episode would index out of range.
        self.reset()

        obs = self.env.reset()
        total_return = 0
        T = 0

        # The buffers hold at most ``episode_length`` steps.
        while T < self.episode_length:
            with torch.no_grad():
                # The model may return more than (dist, value), e.g. an RNN state.
                dist, value, *_ = self.acmodel(obs)

            action = dist.sample()

            self.obss[T] = obs

            obs, reward, done, _ = self.env.step(action.item())

            total_return += reward

            self.actions[T] = action
            self.values[T] = value
            self.rewards[T] = reward
            # TODO: log-probabilities are not recorded yet; ``log_probs`` stays zero.

            T += 1
            if done:
                break

        self.obss = self.obss[:T]
        self.actions = self.actions[:T]
        self.values = self.values[:T]
        self.rewards = self.rewards[:T]
        self.log_probs = self.log_probs[:T]
        self.gaes = self.compute_advantage_gae(T)
        # Value-function regression targets.
        self.returns = self.gaes + self.values

        return total_return

    def compute_advantage_gae(self, T):
        """Compute Generalized Advantage Estimates from ``rewards`` and ``values``.

        Uses ``delta_t = r_t + discount * V_{t+1} - V_t`` with the value after
        the last stored step taken as 0, and the backward recursion
        ``A_t = delta_t + discount * gae_lambda * A_{t+1}``.

        Args:
            T: Number of leading entries of the result to return.

        Returns:
            Tensor holding the first ``T`` advantages.
        """
        advantages = torch.zeros_like(self.values)
        n = self.values.shape[0]
        decay = self.gae_lambda * self.discount

        running = 0.0
        # Single backward pass instead of re-summing the tail for every t.
        for t in reversed(range(n)):
            next_value = self.values[t + 1] if t + 1 < n else 0.0
            # The -V_t term applies on the last step too; only the bootstrap is 0.
            delta = self.rewards[t] + self.discount * next_value - self.values[t]
            running = delta + decay * running
            advantages[t] = running

        return advantages[:T]
        

In [None]:
class PPO:
    """Clipped-surrogate PPO updater for an actor-critic model.

    Args:
        acmodel: ``nn.Module`` whose call returns a tuple starting with
            ``(action_distribution, value)``; extra items are ignored.
        clip_ratio: Half-width of the probability-ratio clipping interval.
        entropy_coef: Stored on the instance; not applied to the loss yet.
        lr: Adam learning rate.
        target_kl: Updates stop early once the approximate KL exceeds
            ``1.5 * target_kl``.
        train_iters: Maximum number of gradient steps per ``update`` call.
    """

    def __init__(self,
                 acmodel,
                 clip_ratio=0.2,
                 entropy_coef=0.01,
                 lr=1e-3,
                 target_kl=0.01,
                 train_iters=5):

        self.acmodel = acmodel
        self.clip_ratio = clip_ratio
        # NOTE(review): no entropy bonus is added to the loss at the moment.
        self.entropy_coef = entropy_coef
        self.target_kl = target_kl
        self.train_iters = train_iters

        self.optimizer = torch.optim.Adam(acmodel.parameters(), lr=lr)

    def update(self, rollouts):
        """Run up to ``train_iters`` PPO gradient steps on one rollout.

        Args:
            rollouts: Object with ``obss``, ``actions`` and ``gaes``, plus
                either ``returns`` or ``values`` (in which case the returns
                are taken as ``gaes + values``).

        Returns:
            ``(policy_loss, value_loss)`` as floats, both evaluated before
            any parameter update.
        """
        obs = self._batch_obs(rollouts.obss)
        actions = rollouts.actions
        advantages = rollouts.gaes.detach().reshape(-1)

        returns = getattr(rollouts, "returns", None)
        if returns is None:
            returns = rollouts.gaes + rollouts.values
        returns = returns.detach().reshape(-1)

        # The behaviour-policy log-probs are constants of the optimisation, so
        # they must not carry a graph (this also removes the need for
        # retain_graph in backward()).
        with torch.no_grad():
            old_logp = self._log_prob(obs, actions)
            policy_loss, _ = self._compute_policy_loss_ppo(obs, old_logp, actions, advantages)
            value_loss = self._compute_value_loss(obs, returns)

        for _ in range(self.train_iters):
            self.optimizer.zero_grad()
            pi_loss, approx_kl = self._compute_policy_loss_ppo(obs, old_logp, actions, advantages)

            # Stop before stepping once the policy has moved too far.
            if approx_kl > 1.5 * self.target_kl:
                break

            v_loss = self._compute_value_loss(obs, returns)
            loss = v_loss + pi_loss

            loss.backward()
            self.optimizer.step()

        return policy_loss.item(), value_loss.item()

    @staticmethod
    def _batch_obs(obss):
        """Turn the stored observations into one batch tensor.

        A tensor is returned unchanged. Otherwise ``None`` placeholders are
        dropped and the remaining observations are stacked along a new
        leading axis.
        """
        if torch.is_tensor(obss):
            return obss
        # assumes each observation is a flat feature vector once reshaped —
        # TODO confirm against the environment's observation format
        steps = [
            torch.as_tensor(o, dtype=torch.float32).reshape(-1)
            for o in obss
            if o is not None
        ]
        if not steps:
            raise ValueError("rollout contains no observations")
        return torch.stack(steps)

    def _log_prob(self, obs, actions):
        """Per-step log-probability of ``actions`` under the current policy."""
        dist, *_ = self.acmodel(obs)
        # Align the stored actions with the distribution's sample shape so that
        # (T,) actions are not broadcast against a (T, 1) mean.
        logp = dist.log_prob(actions.reshape(dist.batch_shape + dist.event_shape))
        # Independent action dimensions: sum their log-probs per step.
        return logp.reshape(logp.shape[0], -1).sum(dim=-1)

    def _compute_policy_loss_ppo(self, obs, old_logp, actions, advantages):
        """Clipped surrogate policy loss and the approximate KL to the old policy.

        Returns:
            ``(policy_loss, approx_kl)`` as 0-dim tensors.
        """
        new_logp = self._log_prob(obs, actions)

        ratio = torch.exp(new_logp - old_logp)
        clipped_ratio = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio)
        surrogate = torch.minimum(ratio * advantages, clipped_ratio * advantages)

        policy_loss = -torch.mean(surrogate)
        approx_kl = (old_logp - new_logp).mean()

        return policy_loss, approx_kl

    def _compute_value_loss(self, obs, returns):
        """Mean squared error between the critic output and ``returns``."""
        _, values, *_ = self.acmodel(obs)
        # Flatten both sides so (T, 1) values are not broadcast against (T,).
        value_loss = torch.mean((returns.reshape(-1) - values.reshape(-1)) ** 2)

        return value_loss



In [5]:
# config params from the pset

class Config:
    """Plain container for the training hyperparameters of the problem set.

    Every constructor argument is stored unchanged as an attribute of the
    same name.
    """

    def __init__(
        self,
        score_threshold=0.93,
        discount=0.995,
        lr=1e-3,
        max_grad_norm=0.5,
        log_interval=10,
        max_episodes=2000,
        gae_lambda=0.95,
        use_critic=False,
        clip_ratio=0.2,
        target_kl=0.01,
        train_ac_iters=5,
        use_discounted_reward=False,
        entropy_coef=0.01,
        use_gae=False,
    ):
        hyperparams = {
            "score_threshold": score_threshold,
            "discount": discount,
            "lr": lr,
            "max_grad_norm": max_grad_norm,
            "log_interval": log_interval,
            "max_episodes": max_episodes,
            "use_critic": use_critic,
            "clip_ratio": clip_ratio,
            "target_kl": target_kl,
            "train_ac_iters": train_ac_iters,
            "gae_lambda": gae_lambda,
            "use_discounted_reward": use_discounted_reward,
            "entropy_coef": entropy_coef,
            "use_gae": use_gae,
        }
        # Mirror each setting onto the instance under its own name.
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)