In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import utils

import gym
from torch import nn
import matplotlib.pyplot as plt
import copy
import time

# Actor Critic Reinforcement Learning

In [6]:
class Agent(nn.Module):
    """Recurrent Actor Critic agent"""
    def __init__(self, state_dim, n_actions, n_hidden = 32, act_fn = utils.RLReLU()):
        super().__init__()
        self.act_fn = act_fn
        
        # init the hidden state of the LSTM cell
        self.init_hidden = (torch.zeros(1,n_hidden).cuda(),
                            torch.zeros(1,n_hidden).cuda())
        self._reset_hidden_state()
        
        # extract game state info
        self.state_extractor = nn.Linear(state_dim, n_hidden)
        
        # stateful LSTM cell ingesting each new state info.
        self.rnn = nn.LSTMCell(n_hidden, n_hidden)
        
        # picking Action from the output of the LSTM cell
        self.pre_actor = nn.Linear(n_hidden, n_hidden)
        self.actor = nn.Linear(n_hidden, n_actions)
        
        self.norm_actor = nn.LayerNorm(n_hidden)
        self.norm_critic = nn.LayerNorm(n_hidden)
        
        # evaluating Value from the output of the LSTM cell
        self.pre_critic = nn.Linear(n_hidden, n_hidden)
        self.critic = nn.Linear(n_hidden, 1)
        self.apply(utils.atari_initializer)
        
    def forward(self, x):
        # extract embedding from game state
        
        state_embedding = self.act_fn(self.state_extractor(x))
        
        # get current state of the LSTM cell
        hx, cx = self.get_hidden_state()
        hx, cx = self.rnn(state_embedding[None,...], (hx, cx))
        
        # compute action logits from LSTM output
        action_logits = self.act_fn(self.pre_actor(hx.squeeze(0)))
        action_logits = self.actor(self.norm_actor(action_logits))
        
        # compute value from LSTM output
        value = self.act_fn(self.pre_critic(hx.squeeze(0)))
        value = self.critic(self.norm_critic(value))
        return action_logits, value.squeeze(-1), (hx, cx)
    
    def set_hidden_state(self, hidden_state):
        self.hidden_state = hidden_state
        
    def _reset_hidden_state(self):
        self.set_hidden_state((self.init_hidden[0].detach().clone(), 
                               self.init_hidden[1].detach().clone()))
        
    def stop_bptt(self):
        # stop backpropagating through time: dropping 'old'
        # gradients from the autograd computations, otherwise
        # GPU memory explodes, and anyway old gradients vanish to 0
        # so they're not useful.
        hidden_state = self.get_hidden_state()
        self.set_hidden_state((hidden_state[0].detach().clone(), 
                               hidden_state[1].detach().clone()))
        
    def get_hidden_state(self):
        return self.hidden_state
    
    def sample_action(self, state):
        # observes a state and randomly samples an action
        # from it, using the distribution given by the Actor part
        # of the network.
        
        action_logits, value, (hx, cx) = self(state)
        policy = F.softmax(action_logits, dim = -1)
        
        dist = torch.distributions.Categorical(probs=policy)
        action = dist.sample()
        
        log_probs = F.log_softmax(action_logits, dim = -1)[action]
        return action, value, dist.entropy(), log_probs, policy, (hx, cx)
    
    def take_action(self, state):
        # observes a state and randomly samples an action
        # from it, using the distribution given by the Actor part
        # of the network.
        
        action_logits, value, (hx, cx) = self(state)
        action_probs = F.softmax(action_logits, dim = -1)
        action_probs.cpu().detach().numpy()
        return np.argmax(action_probs.cpu().detach().numpy())
    
def t(x): return torch.from_numpy(x).float().cuda()

def roll_n_steps(agent, env, n_steps, state_0):
    # get 'n_steps' of true rewards in a simulated env
    sim_env = copy.copy(env)
    with torch.no_grad():
        ag = copy.copy(agent)
        state = copy.copy(state_0)
        rewards = 0

        for i in range(n_steps):
            action, value, _, log_probs, probs, (hx, cx) = ag.sample_action(t(state))
            next_state, reward, done, info = sim_env.step(action.cpu().detach().data.numpy())
            if done: 
                return rewards
            rewards += reward * (gamma ** (i+1))
            state = next_state
            ag.set_hidden_state((hx, cx))

        _, future_reward, _, _, _, _ = ag.sample_action(t(state))
    return rewards + future_reward.detach() * (gamma ** (n_steps + 1))

In [7]:
env = gym.make("CartPole-v0").unwrapped

In [8]:
n_episodes = 600

actor_factor = 0.5

# discount factor
gamma = 0.99

# KL-divergence regularization ensuring smooth policy updates
# each episode starts with a uniform policy
kl_factor = 0.01

# Entropy regularization, ensuring that the policy doesnt
# collapse to its mode or something. Helps exploring states
entropy_factor = -0.005

# n-step rollout in the future. High n_step = low bias high variance
# low n_step = high bias low variance
n_steps = 1

state_dim = env.observation_space.shape[0]
n_actions = env.action_space.n
agent = Agent(state_dim, n_actions).cuda()
adam = torch.optim.Adam(agent.parameters())
cos_lr = torch.optim.lr_scheduler.OneCycleLR(adam, max_lr = 1e-3, 
                                             total_steps = n_episodes, 
                                             final_div_factor = 100)


In [None]:
episode_rewards = []

for i in range(n_episodes):
    
    # init a policy with uniform dist 
    prev_policy = torch.ones(n_actions, device = 'cuda') / n_actions
    
    done = False
    total_reward = 0
    state = env.reset()
    av_rewards = []
    
    agent._reset_hidden_state()
    steps_done = 0
    
    # Exploration noise https://arxiv.org/abs/1706.01905
    # helps exploration
    with torch.no_grad():
        agent.pre_actor.weight.data.add_((0.005)*torch.randn(agent.pre_actor.weight.data.size()
                                                             , device = 'cuda'))
    while not done:
        # observe state and take action
        action, value, entropy, log_probs, policy, (hx, cx) = agent.sample_action(t(state))
        # get rewarded for action taken
        next_state, reward, done, info = env.step(action.cpu().detach().data.numpy())
        
        total_reward += reward
        state = next_state
        
        # rnn hidden state
        agent.set_hidden_state((hx, cx))
        
        # compute Advantage: reward + sum of discounted rewards - value
        advantage = reward
        
        # doing n_step rollout to estimate future rewards
        ohx = hx.clone()
        ocx = cx.clone()
        if not done:
            future_reward = roll_n_steps(agent, env, n_steps, next_state)
            advantage += future_reward
        advantage -= value
        agent.set_hidden_state((ohx, ocx))
        
        # solving Advantage Actor Critic equation (1) from
        # alphastar https://arxiv.org/pdf/1708.04782.pdf
        critic_loss = advantage.pow(2).mean()
        actor_loss = -log_probs * (advantage.detach())
        kl_loss = F.kl_div(policy.log(), prev_policy, reduction = 'sum')
        loss = (critic_loss 
                + (actor_factor * actor_loss) 
                + (kl_factor * kl_loss) 
                + (entropy_factor * entropy))
        
        prev_policy = policy.detach().clone()
        
        adam.zero_grad()
        loss.backward(retain_graph=True)
        adam.step()
        
        env.render()
        steps_done += 1
            
        if steps_done % 70 == 0:
            # need to implement truncated backprop through time
            # this poorman's solution seems to work ok
            agent.stop_bptt()
    episode_rewards.append(total_reward)
    average_100_games = np.mean(episode_rewards[-100:])
    entropy_factor *= 0.95
    if (i + 1) % 100 == 0:
        print(f"Average reward of last 100 games {average_100_games}")
        
    if average_100_games > 195:
        print("VICTORY")
        break
    
    cos_lr.step()

Average reward of last 100 games 27.79
Average reward of last 100 games 85.87
