In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
import gym
import math
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
# Training configuration
EPISODES = 1000  # number of training episodes handed to agent.learn
ROLLOUTS = 500   # rollout buffer capacity, i.e. the step cap per episode

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Utils

In [None]:
class AverageMeter:
    """Tracks the latest value, weighted sum, sample count and running mean of a metric.

    Attributes:
        val: most recent value passed to ``update``.
        sum: weighted sum of all values seen since the last ``reset``.
        count: total weight (number of samples) seen since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as having been observed ``n`` times.

        Args:
            val: the new measurement.
            n: weight of the measurement (number of samples it stands for).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(x, n=0) on a fresh meter) must not
        # raise ZeroDivisionError; report a mean of 0 until a sample arrives.
        self.avg = self.sum / self.count if self.count else 0

In [None]:
class PytorchWrapper(gym.Wrapper):
    """Gym wrapper that hands observations and rewards back as float torch tensors.

    ``step`` returns ``(obs, reward, done)`` (the info dict is dropped) and
    ``reset`` returns only the observation. Both the classic Gym API
    (4-tuple ``step``, bare-observation ``reset``) and the newer
    Gym >= 0.26 / Gymnasium API (5-tuple ``step``, ``(obs, info)`` ``reset``)
    are accepted from the wrapped environment.
    """

    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        """Advance the wrapped env by one step.

        Args:
            action: action forwarded unchanged to the wrapped environment.

        Returns:
            Tuple ``(obs, reward, done)`` where ``obs`` and ``reward`` are
            ``torch.float`` tensors and ``done`` is the episode-end flag.
        """
        result = self.env.step(action)
        if len(result) == 5:
            # Newer API: episode end is split into terminated / truncated.
            obs, reward, terminated, truncated, _ = result
            done = terminated or truncated
        else:
            obs, reward, done, _ = result
        obs = torch.tensor(obs, dtype=torch.float)
        reward = torch.tensor(reward, dtype=torch.float)
        return obs, reward, done

    def reset(self, **kwargs):
        """Reset the wrapped env and return the first observation as a float tensor.

        Args:
            **kwargs: forwarded unchanged to the wrapped environment's ``reset``.
        """
        result = self.env.reset(**kwargs)
        # Newer API returns (obs, info); keep only the observation.
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            obs = result[0]
        else:
            obs = result
        obs = torch.tensor(obs, dtype=torch.float)
        return obs

In [None]:
def make_env(env):
    """Wrap ``env`` in ``PytorchWrapper`` and return the wrapped environment."""
    return PytorchWrapper(env)

# Actor & Critic Network

In [None]:
class PolicyNetwork(nn.Module):
    """
    Actor network: maps an observation to a probability for each action.

    Two hidden layers (32 and 16 units), each followed by ReLU and
    dropout (p=0.3), then a linear layer with a softmax over the last dim.
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        widths = (in_features, 32, 16)
        layers = []
        # Hidden stack: Linear -> ReLU -> Dropout for each consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(p=0.3))
        # Output head: one logit per action, normalised into probabilities.
        layers.append(nn.Linear(16, out_features))
        layers.append(nn.Softmax(dim=-1))
        self.action_head = nn.Sequential(*layers)

    def forward(self, x):
        """Return the action probabilities predicted for ``x``."""
        return self.action_head(x)

class ValueNetwork(nn.Module):
    """
    Critic network: maps an observation to a single state-value estimate.

    Two hidden layers (32 and 16 units), each followed by ReLU and
    dropout (p=0.3), then a linear layer with one output.
    """
    def __init__(self, in_features):
        super().__init__()
        widths = (in_features, 32, 16)
        layers = []
        # Hidden stack: Linear -> ReLU -> Dropout for each consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(p=0.3))
        # Output head: a single unbounded scalar per input row.
        layers.append(nn.Linear(16, 1))
        self.value_head = nn.Sequential(*layers)

    def forward(self, x):
        """Return the value prediction for ``x``."""
        return self.value_head(x)

# Rollout Buffer

In [None]:
class RolloutBuffer:
    """Fixed-capacity store for one rollout, used to build returns and advantages.

    Args:
        rollout_steps: maximum number of transitions the buffer can hold.
        gamma: discount factor applied when computing returns.
        device: device the tensors returned by ``get_values`` are moved to.

    Attributes:
        states, rewards, value_preds, actions: per-step lists of length
            ``rollout_steps``; only the first ``count`` entries are valid.
            ``rewards`` holds the raw (undiscounted) rewards.
        count: number of transitions stored since the last ``reset``.
    """

    def __init__(self, rollout_steps, gamma, device):
        self.rollout_steps = rollout_steps
        self.gamma = gamma
        self.device = device
        self.states = None
        self.rewards = None
        self.value_preds = None
        self.actions = None
        self.count = None
        self.reset()

    def reset(self):
        """Drop all stored transitions and start filling from index 0 again."""
        self.states = [None] * self.rollout_steps
        self.rewards = [None] * self.rollout_steps
        self.value_preds = [None] * self.rollout_steps
        self.actions = [None] * self.rollout_steps
        self.count = 0

    def store(self, state, reward, value_pred, action):
        """Append one transition.

        Args:
            state: observation tensor for this step.
            reward: scalar reward tensor received for this step.
            value_pred: scalar critic estimate for ``state``.
            action: action taken (converted to a long tensor in ``get_values``).

        Raises:
            IndexError: if more than ``rollout_steps`` transitions are stored
                without an intervening ``reset`` / ``get_values``.
        """
        self.states[self.count] = state
        # Keep the raw reward. Pre-scaling by gamma**t and dividing it back out
        # later breaks for gamma == 0 (0/0) and underflows for long rollouts.
        self.rewards[self.count] = reward
        self.value_preds[self.count] = value_pred
        self.actions[self.count] = action
        self.count += 1

    def compute_advantages(self):
        """Compute discounted returns and advantages for the stored steps.

        Uses the backward recursion ``G_t = r_t + gamma * G_{t+1}`` (with the
        return after the last stored step taken as 0), which is linear in the
        number of steps.

        Returns:
            Tuple ``(returns, advantages)`` of lists with ``count`` entries
            each, where ``advantages[t] = returns[t] - value_preds[t]``.
        """
        returns = [None] * self.count
        running_return = 0.0
        for i in reversed(range(self.count)):
            running_return = self.rewards[i] + self.gamma * running_return
            returns[i] = running_return
        # zip stops at len(returns) == count, so unused value_preds slots are ignored.
        advantages = [ret - value for ret, value in zip(returns, self.value_preds)]
        return returns, advantages

    def get_values(self):
        """Return the stored rollout as batched tensors and reset the buffer.

        Returns:
            Tuple ``(states, returns, advantages, actions)`` of tensors moved
            to ``self.device``; ``actions`` is converted to ``long``.
        """
        states = torch.stack(self.states[:self.count]).to(self.device)
        actions = torch.tensor(self.actions[:self.count]).to(self.device).long()
        returns, advantages = self.compute_advantages()
        returns = torch.stack(returns).to(self.device)
        advantages = torch.stack(advantages).to(self.device)
        self.reset()
        return states, returns, advantages, actions

# Advantage Actor-Critic (A2C)

In [None]:
class A2C:
    """Advantage Actor-Critic agent with separate actor and critic networks.

    Each network has its own Adam optimizer. One call to ``learn`` runs a
    number of episodes; after every episode the collected rollout is used for
    one critic update and one actor update.

    Args:
        obs_size: number of observation features fed to both networks.
        action_size: number of discrete actions the actor chooses between.
        rollout_steps: rollout buffer capacity and step cap per episode.
        device: device the networks and training tensors live on.
        gamma: discount factor passed to the rollout buffer.
        lr: learning rate used by both optimizers.
    """

    def __init__(self, obs_size, action_size, rollout_steps, device="cpu", gamma=0.99, lr=0.001):
        self.buffer = RolloutBuffer(rollout_steps, gamma, device)
        self.actor = PolicyNetwork(obs_size, action_size).to(device)
        self.critic = ValueNetwork(obs_size).to(device)
        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=lr)
        self.rollout_steps = rollout_steps
        self.device = device

    def critic_loss(self, pred, target):
        """Summed smooth-L1 loss between predicted values and target returns."""
        return F.smooth_l1_loss(pred, target, reduction="sum")

    def actor_loss(self, log_probs, advantages):
        """Policy-gradient loss: the sum of ``-log_prob * advantage`` over the rollout."""
        loss = -log_probs * advantages
        loss = loss.sum()
        return loss

    def forward(self, obs, grad=False):
        """Run the actor and the critic on ``obs``.

        Args:
            obs: batch of observations.
            grad: when False (default) the outputs carry no autograd history.

        Returns:
            Tuple ``(action_prob, value_pred)``.
        """
        if grad:
            return self.actor(obs), self.critic(obs)
        # Skip graph construction entirely instead of building it and detaching.
        with torch.no_grad():
            return self.actor(obs), self.critic(obs)

    def _collect_rollout(self, env):
        """Play one episode (capped at ``rollout_steps``) into the buffer; return its total reward."""
        obs = env.reset()
        reward_tracker = AverageMeter()
        for _ in range(self.rollout_steps):
            obs = obs.to(self.device)
            action_prob, value_pred = self.forward(obs.unsqueeze(0))
            action = Categorical(action_prob).sample().item()
            next_obs, reward, done = env.step(action)
            reward_tracker.update(reward.item())
            self.buffer.store(obs.cpu(), reward.squeeze(), value_pred.squeeze(), action)
            obs = next_obs
            if done:
                break
        return reward_tracker.sum

    def _update(self):
        """Fit critic then actor on the buffered rollout; return the summed loss values."""
        states, returns, advantages, actions = self.buffer.get_values()
        action_prob, value_pred = self.forward(states, grad=True)
        log_probs = Categorical(action_prob).log_prob(actions)

        # Fit critic. squeeze(-1) keeps the prediction shaped like `returns`
        # even for a single-step rollout, where squeeze() would give a 0-dim tensor.
        self.critic_optim.zero_grad()
        critic_loss = self.critic_loss(value_pred.squeeze(-1), returns)
        critic_loss.backward()
        self.critic_optim.step()

        # Fit actor. Advantages come from the buffer and carry no gradient.
        self.actor_optim.zero_grad()
        actor_loss = self.actor_loss(log_probs, advantages)
        actor_loss.backward()
        self.actor_optim.step()

        return critic_loss.item() + actor_loss.item()

    def learn(self, env, episodes):
        """Train for ``episodes`` episodes on ``env``.

        Logs the 'Loss' and 'Reward' scalars to TensorBoard once per episode
        and prints a progress line every 10 episodes.

        Args:
            env: environment whose ``reset()`` returns an observation tensor
                and whose ``step(action)`` returns ``(obs, reward, done)``.
            episodes: number of episodes to run.
        """
        writer = SummaryWriter()
        try:
            for eps in range(episodes):
                episode_reward = self._collect_rollout(env)
                tot_loss = self._update()

                writer.add_scalar('Loss', tot_loss, eps)
                writer.add_scalar("Reward", episode_reward, eps)

                if eps % 10 == 0:
                    print(f"Episode: {eps+1}/{episodes}, loss: {tot_loss}, reward: {episode_reward}")
        finally:
            # Close the writer even if training is interrupted, so pending
            # events are written out and the file handle is released.
            writer.close()

In [None]:
# Build CartPole and wrap it so observations and rewards come back as tensors.
env = make_env(gym.make('CartPole-v1'))

# Network input/output sizes are read from the environment's spaces.
obs_size, action_size = env.observation_space.shape[0], env.action_space.n

In [None]:
# One agent for the whole run; ROLLOUTS is both the buffer capacity and the per-episode step cap.
agent = A2C(obs_size, action_size, rollout_steps=ROLLOUTS, device=device)

In [None]:
## Load the TensorBoard notebook extension and open it on the `runs` log directory
## to follow the 'Loss' and 'Reward' scalars that A2C.learn logs each episode
%load_ext tensorboard
%tensorboard --logdir runs

In [None]:
# Train the agent for the configured number of episodes.
agent.learn(env, episodes=EPISODES)