In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import utils

import gym
from torch import nn
import matplotlib.pyplot as plt
import copy

# Actor Critic Reinforcement Learning

In [2]:
class Agent(nn.Module):
    """Recurrent Actor Critic agent"""
    def __init__(self, state_dim, n_actions, n_hidden = 32, act_fn = utils.RLReLU()):
        super().__init__()
        self.act_fn = act_fn
        
        # init the hidden state of the LSTM cell
        self.init_hidden = (torch.zeros(1,n_hidden),
                            torch.zeros(1,n_hidden))
        self._reset_hidden_state()
        
        # extract game state info
        self.state_extractor = nn.Linear(state_dim, n_hidden)
        
        # stateful LSTM cell ingesting each new state info.
        # internal states are never reset
        self.rnn = nn.LSTMCell(n_hidden, n_hidden)
        
        # picking Action from the output of the LSTM cell
        self.pre_actor = nn.Linear(n_hidden, n_hidden)
        self.actor = nn.Linear(n_hidden, n_actions)
        
        # evaluating Value from the output of the LSTM cell
        self.pre_critic = nn.Linear(n_hidden, n_hidden)
        self.critic = nn.Linear(n_hidden, 1)
        self.apply(utils.atari_initializer)
        
    def forward(self, x):
        # extract embedding from game state
        state_embedding = self.act_fn(self.state_extractor(x))
        
        # get current state of the LSTM cell
        hx, cx = self.get_hidden_state()
        hx, cx = self.rnn(state_embedding[None,...], (hx, cx))
        
        # compute action logits from LSTM output
        action_logits = self.act_fn(self.pre_actor(hx.squeeze(0)))
        action_logits = self.actor(action_logits)
        
        # compute value from LSTM output
        value = self.act_fn(self.pre_critic(hx.squeeze(0)))
        value = self.critic(value)
        return action_logits, value.squeeze(-1), (hx, cx)
    
    def set_hidden_state(self, hidden_state):
        self.hidden_state = hidden_state
        
    def _reset_hidden_state(self):
        self.set_hidden_state((self.init_hidden[0].clone(), 
                               self.init_hidden[1].clone()))
        
    def stop_bptt(self):
        # stop backpropagating through time: dropping 'old'
        # gradients from the autograd computations, otherwise
        # GPU memory explodes, and anyway old gradients vanish to 0
        # so they're not useful.
        hidden_state = self.get_hidden_state()
        self.set_hidden_state((hidden_state[0].detach().clone(), 
                               hidden_state[1].detach().clone()))
        
    def get_hidden_state(self):
        return self.hidden_state
    
    def sample_action(self, state):
        # observes a state and randomly samples an action
        # from it, using the distribution given by the Actor part
        # of the network.
        
        action_logits, value, (hx, cx) = self(state)
        action_probs = F.softmax(action_logits, dim = -1)
        
        dist = torch.distributions.Categorical(probs=action_probs)
        action = dist.sample()
        
        log_probs = dist.log_prob(action)
        entropy = dist.entropy()
        
        return action, value, log_probs, entropy, (hx, cx)
    
def t(x): return torch.from_numpy(x).float()

In [3]:
env = gym.make("CartPole-v0").unwrapped


In [4]:
n_episodes = 700
state_dim = env.observation_space.shape[0]
n_actions = env.action_space.n
agent = Agent(state_dim, n_actions)
adam = torch.optim.Adam(agent.parameters())
cos_lr = torch.optim.lr_scheduler.OneCycleLR(adam, max_lr = 1e-3, 
                                             total_steps = n_episodes, 
                                             final_div_factor = 100)
gamma = 0.99

In [5]:
episode_rewards = []
agent.reset_hidden_state()
for i in range(n_episodes):
    done = False
    total_reward = 0
    state = env.reset()
    av_rewards = []
    entropy_factor = 0.01
    agent.stop_bptt()
    steps_done = 0

    while not done:
        
        # observe state and take action
        action, value, log_probs, entropy, (hx, cx) = agent.sample_action(t(state))
        # get rewarded for action taken
        next_state, reward, done, info = env.step(action.detach().data.numpy())
        
        total_reward += reward
        state = next_state
        agent.set_hidden_state((hx, cx))
        
        # compute advantage: reward + sum of discounted rewards - value
        advantage = reward
        if not done:
            with torch.no_grad():
                _, future_value, _, _, _ = agent.sample_action(t(next_state))
            advantage += gamma * future_value.detach()
        advantage -= value
        
        critic_loss = advantage.pow(2).mean()
        actor_loss = -log_probs * (advantage.detach())
        loss = critic_loss + 0.5 * actor_loss - entropy_factor * entropy
        
        adam.zero_grad()
        loss.backward(retain_graph=True)
        #nn.utils.clip_grad_norm_(agent.parameters(), 1)
        adam.step()
        
        env.render()
        steps_done += 1
        if steps_done % 70 == 0:
            agent.stop_bptt()
    episode_rewards.append(total_reward)
    average_100_games = np.mean(episode_rewards[-100:])
    entropy_factor *= 0.99
    
    if (i + 1) % 100 == 0:
        print(f"Average reward of last 100 games {average_100_games}")
        
    if average_100_games > 195:
        print("VICTORY")
        break
    
    cos_lr.step()

Average reward of last 100 games 13.68
Average reward of last 100 games 55.73
VICTORY


In [None]:
def roll_n_steps(agent, env, n_steps):
    sim_env = copy.deepcopy(env)
    sim_agent = copy.deepcopy(agent)
    with torch.no_grad():
        for i in range(n_steps):
            action, value, log_probs, entropy, (hx, cx) = agent.sample_action(t(state))
            next_state, reward, done, info = env.step(action.detach().data.numpy())
    