In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
import gym
import math

In [None]:
EPISODES = 1000   # number of training episodes passed to agent.learn
ROLLOUTS = 350    # max env steps collected per episode (REINFORCE rollout_steps)
SEED = 0          # seed for the torch / numpy RNGs
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Action sampling, dropout and weight initialisation are all random; seed the
# torch and numpy RNGs so "Restart & Run All" reproduces the same run.
# NOTE: the gym environment keeps its own RNG, which is not seeded here.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Utils

In [None]:
class AverageMeter:
    """Keep the most recent value together with a count-weighted running sum and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every tracked statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighting it ``n`` times in the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
class PytorchWrapper(gym.Wrapper):
    """Gym wrapper that hands observations and rewards back as float tensors.

    ``step`` returns a 3-tuple ``(obs, reward, done)``; the ``info`` dict is
    dropped. ``reset`` returns only the initial observation. Both gym APIs are
    accepted:

    * legacy: ``step`` -> ``(obs, reward, done, info)``, ``reset`` -> ``obs``
    * gym>=0.26: ``step`` -> ``(obs, reward, terminated, truncated, info)``,
      ``reset`` -> ``(obs, info)``
    """

    def step(self, action):
        """Apply ``action`` and return ``(obs_tensor, reward_tensor, done)``."""
        result = self.env.step(action)
        if len(result) == 5:
            # New API: the episode is over on either termination or truncation.
            obs, reward, terminated, truncated, _ = result
            done = terminated or truncated
        else:
            obs, reward, done, _ = result
        obs = torch.tensor(obs, dtype=torch.float)
        reward = torch.tensor(reward, dtype=torch.float)
        return obs, reward, done

    def reset(self, **kwargs):
        """Reset the wrapped env and return the initial observation as a float tensor.

        Keyword arguments (e.g. ``seed=`` on newer gym) are forwarded unchanged.
        """
        result = self.env.reset(**kwargs)
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            # New API returns (obs, info); keep only the observation.
            result = result[0]
        return torch.tensor(result, dtype=torch.float)

In [None]:
def make_env(env):
    """Return ``env`` wrapped in ``PytorchWrapper`` so it exchanges torch tensors."""
    return PytorchWrapper(env)

# Policy Network

In [None]:
class PolicyNetwork(nn.Module):
    """MLP mapping observations to a probability distribution over discrete actions.

    Layers: Linear(in, 32) -> ReLU -> Dropout -> Linear(32, 16) -> ReLU ->
    Dropout -> Linear(16, out) -> Softmax. The layer order, and therefore the
    ``action_head.<i>`` state_dict keys, is unchanged, so previously saved
    checkpoints still load.

    Args:
        in_features: size of the last dimension of the observation input.
        out_features: number of discrete actions.
        dropout: dropout probability for both hidden layers. The default of
            0.3 keeps the original behaviour. Pass 0.0 to make the network
            deterministic in train mode; otherwise each forward pass in train
            mode draws a fresh dropout mask.
    """

    def __init__(self, in_features, out_features, dropout=0.3):
        super().__init__()
        self.action_head = nn.Sequential(
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(16, out_features),
            nn.Softmax(dim=-1),  # turn logits into action probabilities
        )

    def forward(self, x):
        """Return action probabilities for ``x``; they sum to 1 over the last dimension."""
        return self.action_head(x)

# Rollout Buffer

In [None]:
class RolloutBuffer:
    """Fixed-capacity store for one episode of (state, reward, action) transitions.

    Args:
        rollout_steps: maximum number of transitions the buffer can hold.
        gamma: discount factor used by ``compute_returns``.
        device: device that the tensors returned by ``get_values`` are moved to.
    """

    def __init__(self, rollout_steps, gamma, device):
        self.rollout_steps = rollout_steps
        self.gamma = gamma
        self.device = device
        self.states = None
        self.rewards = None
        self.actions = None
        self.count = None
        self.reset()

    def reset(self):
        """Discard every stored transition."""
        self.states = [None] * self.rollout_steps
        self.rewards = [None] * self.rollout_steps
        self.actions = [None] * self.rollout_steps
        self.count = 0

    def store(self, state, reward, action):
        """Append one transition. The reward is stored undiscounted.

        Raises:
            IndexError: if the buffer already holds ``rollout_steps`` transitions.
        """
        if self.count >= self.rollout_steps:
            raise IndexError(
                f"RolloutBuffer is full ({self.rollout_steps} transitions); "
                "call get_values() or reset() first"
            )
        self.states[self.count] = state
        self.rewards[self.count] = reward
        self.actions[self.count] = action
        self.count += 1

    def compute_returns(self):
        """Return the discounted return-to-go for every stored step.

        Uses the recursion G_i = r_i + gamma * G_{i+1} (with G_count = 0) in a
        single backward pass. This is O(n) and never divides by ``gamma**i``,
        so ``gamma == 0`` and very small discounts are handled correctly.
        Each element has the type of the stored rewards (tensor or float).
        """
        returns = [None] * self.count
        running = 0.0
        for i in reversed(range(self.count)):
            running = self.rewards[i] + self.gamma * running
            returns[i] = running
        return returns

    def get_values(self):
        """Return ``(states, returns, actions)`` as tensors on ``self.device``, then empty the buffer.

        ``states`` is the stacked stored states, ``returns`` is float32, and
        ``actions`` is int64.

        Raises:
            RuntimeError: if no transition has been stored.
        """
        if self.count == 0:
            raise RuntimeError("RolloutBuffer is empty; store() at least one transition first")
        states = torch.stack(self.states[:self.count]).to(self.device)
        actions = torch.tensor(self.actions[:self.count], dtype=torch.long, device=self.device)
        # as_tensor accepts both 0-d reward tensors and plain Python floats.
        returns = torch.stack(
            [torch.as_tensor(g, dtype=torch.float) for g in self.compute_returns()]
        ).to(self.device)
        self.reset()
        return states, returns, actions

# Policy Gradient Algorithm

In [None]:
class REINFORCE:
    """Monte-Carlo policy-gradient (REINFORCE) agent for discrete action spaces.

    Args:
        obs_size: length of the observation vector fed to the policy.
        action_size: number of discrete actions.
        rollout_steps: maximum number of env steps collected per episode.
        device: torch device for the policy and training tensors.
        gamma: discount factor for returns.
        lr: Adam learning rate.
    """

    def __init__(self, obs_size, action_size, rollout_steps, device="cpu", gamma=0.99, lr=0.001):
        self.buffer = RolloutBuffer(rollout_steps, gamma, device)
        self.policy = PolicyNetwork(obs_size, action_size).to(device)
        self.optim = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.rollout_steps = rollout_steps
        self.device = device

    def policy_loss(self, log_probs, returns):
        """REINFORCE surrogate loss: ``-sum_t log_prob_t * return_t``."""
        return (-log_probs * returns).sum()

    @staticmethod
    def _normalize_returns(returns, eps=1e-6):
        """Standardise returns to zero mean and unit std, as a variance-reducing baseline.

        ``Tensor.std`` is NaN for a single element. That NaN would reach the
        loss and, after ``optim.step()``, every policy weight, so with fewer
        than two returns they are only centred.
        """
        centred = returns - returns.mean()
        if returns.numel() < 2:
            return centred
        return centred / (returns.std() + eps)

    def forward(self, obs, grad=False):
        """Return action probabilities for ``obs``.

        With ``grad=False`` the pass runs under ``torch.no_grad`` so no
        autograd graph is built; the result is only suitable for sampling.
        """
        if grad:
            return self.policy(obs)
        with torch.no_grad():
            return self.policy(obs)

    def save_policy(self):
        """Write the policy weights to ``REINFORCE.bin``, then leave the policy in train mode."""
        self.policy.eval()
        torch.save(self.policy.state_dict(), "REINFORCE.bin")
        self.policy.train()

    def load_policy(self, path=None):
        """Load policy weights from ``path`` (default ``REINFORCE.bin``).

        ``map_location`` remaps tensors onto ``self.device``, so a checkpoint
        saved on GPU also loads on a CPU-only machine.
        """
        if path is None:
            path = "REINFORCE.bin"

        self.policy.load_state_dict(torch.load(path, map_location=self.device))
        print("Successfully loaded")

    def _collect_episode(self, env):
        """Run one episode of at most ``rollout_steps`` steps into ``self.buffer``; return its reward meter."""
        obs = env.reset()
        reward_tracker = AverageMeter()
        for _ in range(self.rollout_steps):
            obs = obs.to(self.device)
            action_prob = self.forward(obs.unsqueeze(0))
            # Sample an action from the predicted distribution.
            action = Categorical(action_prob).sample().item()
            next_obs, reward, done = env.step(action)
            reward_tracker.update(reward.item())
            self.buffer.store(obs.cpu(), reward.squeeze(), action)
            obs = next_obs
            if done:
                break
        return reward_tracker

    def _update_policy(self):
        """Take one gradient step on the episode in ``self.buffer`` (which empties it); return the loss."""
        states, returns, actions = self.buffer.get_values()
        # NOTE(review): the policy is in train mode, so Dropout draws a fresh
        # mask here. These log-probs therefore come from a slightly different
        # distribution than the one that sampled the actions in _collect_episode.
        log_probs = Categorical(self.forward(states, grad=True)).log_prob(actions)
        returns = self._normalize_returns(returns)
        self.optim.zero_grad()
        loss = self.policy_loss(log_probs, returns)
        loss.backward()
        self.optim.step()
        return loss

    def learn(self, env, episodes):
        """Train for ``episodes`` episodes, logging 'Loss'/'Reward' to TensorBoard.

        A checkpoint is saved every 10 episodes and after the final episode.
        """
        writer = SummaryWriter()
        try:
            for eps in range(episodes):
                reward_tracker = self._collect_episode(env)
                loss = self._update_policy()

                writer.add_scalar('Loss', loss.item(), eps)
                writer.add_scalar("Reward", reward_tracker.sum, eps)

                if eps % 10 == 0 or eps == episodes - 1:
                    print(f"Episode: {eps+1}/{episodes}, loss: {loss.item()}, reward: {reward_tracker.sum}")
                    self.save_policy()
        finally:
            # Flush pending events and release the log file even if training fails.
            writer.close()

In [None]:
# Build CartPole wrapped for torch, then read the sizes the agent is built from.
env = make_env(gym.make('CartPole-v1'))
obs_size = env.observation_space.shape[0]
action_size = env.action_space.n

In [None]:
agent = REINFORCE(obs_size, action_size, rollout_steps=ROLLOUTS, device=device)

In [None]:
## Load the TensorBoard notebook extension and display the logs in ./runs.
## agent.learn writes 'Loss' and 'Reward' scalars via SummaryWriter(); with no
## log_dir that writer is assumed to log under ./runs — confirm if logs are missing.
%load_ext tensorboard
%tensorboard --logdir runs

In [None]:
agent.learn(env=env, episodes=EPISODES)