In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import utils

import fastai
from fastai.layers import AdaptiveConcatPool2d, conv_layer, Flatten
import time
import gym
from torch import nn
import matplotlib.pyplot as plt
import copy

# Actor Critic Reinforcement Learning

In [2]:
class Agent(nn.Module):
    """Recurrent Actor Critic agent"""
    def __init__(self, state_dim, n_actions, state_extractor, 
                 n_hidden = 64, act_fn = utils.RLReLU()):
        super().__init__()
        self.act_fn = act_fn
        
        # init the hidden state of the LSTM cell
        self.init_hidden = (torch.zeros(1,n_hidden).cuda(),
                            torch.zeros(1,n_hidden).cuda())
        self._reset_hidden_state()
        
        # extract game state info
        self.state_extractor = state_extractor#nn.Linear(state_dim, n_hidden)
        
        # stateful LSTM cell ingesting each new state info.
        # internal states are never reset
        self.rnn = nn.LSTMCell(state_dim, n_hidden)
        
        # picking Action from the output of the LSTM cell
        self.pre_actor = nn.Linear(n_hidden, n_hidden)
        self.actor = nn.Linear(n_hidden, n_actions)
        
        # evaluating Value from the output of the LSTM cell
        self.pre_critic = nn.Linear(n_hidden, n_hidden)
        self.critic = nn.Linear(n_hidden, 1)
        self.apply(utils.atari_initializer)
        
    def forward(self, x):
        # extract embedding from game state

        state_embedding = self.act_fn(self.state_extractor(x))
        # get current state of the LSTM cell
        hx, cx = self.get_hidden_state()
        hx, cx = self.rnn(state_embedding, (hx, cx))
        # compute action logits from LSTM output
        action_logits = self.act_fn(self.pre_actor(hx.squeeze(0)))
        action_logits = self.actor(action_logits)
        
        # compute value from LSTM output
        value = self.act_fn(self.pre_critic(hx.squeeze(0)))
        value = self.critic(value)
        return action_logits, value.squeeze(-1), (hx, cx)
    
    def set_hidden_state(self, hidden_state):
        self.hidden_state = hidden_state
        
    def _reset_hidden_state(self):
        self.set_hidden_state((self.init_hidden[0].clone(), 
                               self.init_hidden[1].clone()))
        
    def stop_bptt(self):
        # stop backpropagating through time: dropping 'old'
        # gradients from the autograd computations, otherwise
        # GPU memory explodes, and anyway old gradients vanish to 0
        # so they're not useful.
        hidden_state = self.get_hidden_state()
        self.set_hidden_state((hidden_state[0].detach().clone(), 
                               hidden_state[1].detach().clone()))
        
    def get_hidden_state(self):
        return self.hidden_state
    
    def sample_action(self, state):
        # observes a state and randomly samples an action
        # from it, using the distribution given by the Actor part
        # of the network.
        
        action_logits, value, (hx, cx) = self(state)
        action_probs = F.softmax(action_logits, dim = -1)
        
        dist = torch.distributions.Categorical(probs=action_probs)
        action = dist.sample()
        
        log_probs = dist.log_prob(action)
        entropy = dist.entropy()
        
        return action, value, log_probs, entropy, (hx, cx)
    
def t(x): return torch.from_numpy(x).float()


def roll_n_steps(agent, env_state, env_system_state, n_steps, state):
    rewards = 0
    sim_env.ale.restoreState(env_state)
    sim_env.ale.restoreSystemState(env_system_state)
    #sim_env = copy.copy(env)
    with torch.no_grad():
        if n_steps > 0:
            for i in range(n_steps):
                action, value, log_probs, entropy, (hx, cx) = agent.sample_action(state)
                _, reward, done, info = sim_env.step(action.cpu().detach().data.numpy())
                next_state = utils.get_screen(sim_env)
                if done: return rewards
                rewards += reward * (gamma ** (i+1))
                state = next_state
                agent.set_hidden_state((hx, cx))
        _, future_reward, _, _, _ = agent.sample_action(state)
    return rewards + future_reward.detach() * (gamma ** (n_steps + 1))

In [3]:
env = gym.make("Breakout-v0").unwrapped
sim_env = gym.make("Breakout-v0").unwrapped

In [4]:
layers = [conv_layer(1, 8, stride = 2, ks = 4, norm_type = None), 
          conv_layer(8, 16, stride = 2, ks = 4, norm_type = None), 
          conv_layer(16, 32, stride = 2, ks = 4, norm_type = None),
          conv_layer(32, 64, stride = 2, ks = 4, norm_type = None),
          nn.AdaptiveMaxPool2d(1), 
          Flatten()]

cnn = nn.Sequential(*layers).cuda()

In [5]:
n_episodes = 2000
state_dim = env.observation_space.shape[0]
n_actions = env.action_space.n
agent = Agent(64, n_actions, cnn).cuda()
adam = torch.optim.Adam(agent.parameters())
cos_lr = torch.optim.lr_scheduler.OneCycleLR(adam, max_lr = 1e-3, 
                                             total_steps = n_episodes, 
                                             final_div_factor = 100)
gamma = 0.99

In [None]:
episode_rewards = []
for i in range(n_episodes):
    done = False
    total_reward = 0
    env.reset()
    state = utils.get_screen(env)
    av_rewards = []
    entropy_factor = 0.01
    agent.stop_bptt()
    steps_done = 0
    while not done:
        # observe state and take action
        action, value, log_probs, entropy, (hx, cx) = agent.sample_action(state)
        
        # get rewarded for action taken
        _, reward, done, info = env.step(action.cpu().detach().data.numpy())
        next_state = utils.get_screen(env)
        
        
        total_reward += reward
        state = next_state
        agent.set_hidden_state((hx, cx))
        
        # compute advantage: reward + sum of discounted rewards - value
        advantage = reward
        ohx = hx.clone()
        ocx = cx.clone()
        
        #print("before", start_state)
        if not done:
            sys_state = env.ale.cloneSystemState()
            st = env.ale.cloneState()
            future_reward = roll_n_steps(agent, st, sys_state, 2, next_state)
            advantage += future_reward
        #print("after", start_state)
        
        advantage -= value
        agent.set_hidden_state((ohx, ocx))
        
        critic_loss = advantage.pow(2).mean()
        actor_loss = -log_probs * (advantage.detach())
        loss = critic_loss + 0.5 * actor_loss - entropy_factor * entropy
        
        adam.zero_grad()
        loss.backward(retain_graph=True)
        #nn.utils.clip_grad_norm_(agent.parameters(), 1)
        adam.step()
        
        env.render()
        steps_done += 1
        if steps_done % 50 == 0:
            agent.stop_bptt()
    episode_rewards.append(total_reward)
    average_100_games = np.mean(episode_rewards[-10:])
    entropy_factor *= 0.99
    
    if (i + 1) % 10 == 0:
        print(f"Average reward of last 100 games {average_100_games}")
        
    """if average_100_games > 195:
        print("VICTORY")
        break"""
    
    cos_lr.step()

Average reward of last 100 games 2.0
Average reward of last 100 games 2.2
Average reward of last 100 games 2.6
Average reward of last 100 games 1.8
