In [1]:
import os
from collections import deque
# NOTE(review): distutils appears unused in the visible cells and was removed from the stdlib in Python 3.12; consider dropping it.
import distutils.spawn

import gym
import numpy as np
import torch
import torch.nn as nn


In [2]:
# --- Training schedule -------------------------------------------------------
episodes = 10_000
update_frequency = 10  # episodes between gradient updates (see training loop)

# --- Network shapes (CartPole-v1: 4-dim observation, 2 discrete actions) -----
observation_space = 4
n_actions = 2
hidden_size = 64

# --- Optimisation / RL hyperparameters ---------------------------------------
learning_rate = 1e-3
gamma = 1      # discount factor; 1 means undiscounted returns
sigma = 0.1    # std of the Gaussian noise added to action probabilities
depth = 100    # n-step horizon used for bootstrapped returns

In [3]:
class PolicyNet(torch.nn.Module):
    """Actor network: maps an observation to a softmax distribution over actions.

    Layer sizes and hyperparameters are read from the notebook-level constants
    ``observation_space``, ``hidden_size``, ``n_actions``, ``learning_rate`` and
    ``sigma``. Exploration adds Gaussian noise with std ``self.sigma`` to the action
    probabilities before sampling, so callers can anneal ``self.sigma`` in place.
    """

    def __init__(self):
        super().__init__()
        # Hyperparams to be pushed into a config.
        self.sigma = sigma
        self.learning_rate = learning_rate

        self.device = 'cpu'
        self.input_shape = observation_space
        self.output_shape = n_actions

        self.input_layer = nn.Linear(self.input_shape, hidden_size)
        self.hidden_layer = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, self.output_shape)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        self.to(self.device)

    def forward(self, x):
        """Return action probabilities (softmax over dim 1) for a batched 2-D input ``x``."""
        x = torch.nn.functional.relu(self.input_layer(x))
        x = torch.nn.functional.relu(self.hidden_layer(x))
        x = self.output_layer(x)
        return torch.nn.functional.softmax(x, dim=1)

    def choose_action(self, state):
        """Sample an action for a single numpy observation.

        Returns:
            (action, log_prob): the sampled action index as a Python int and its
            log-probability tensor under the noise-perturbed distribution.
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        probabilities = self.forward(state)

        # Use the instance's sigma (not the module global) so external annealing
        # of self.sigma actually changes the exploration noise. randn_like also
        # matches the real number of actions instead of a hard-coded (1, 2).
        noise = torch.randn_like(probabilities) * self.sigma
        noisy_probabilities = torch.clamp(probabilities + noise, min=0.0)

        # If the noise pushed every probability to zero, Categorical cannot
        # normalise the row; fall back to the un-noised policy in that case.
        if not bool(torch.all(noisy_probabilities.sum(dim=-1) > 0)):
            noisy_probabilities = probabilities

        model = torch.distributions.Categorical(noisy_probabilities)
        action = model.sample()
        return action.item(), model.log_prob(action)
    
class CriticNet(torch.nn.Module):
    """State-value network: maps an observation to a single unbounded scalar.

    Layer sizes and the learning rate come from the notebook-level constants
    ``observation_space``, ``hidden_size`` and ``learning_rate``.
    """

    def __init__(self):
        super().__init__()
        # Hyperparameters (candidates for a config file).
        self.learning_rate = learning_rate
        self.device = 'cpu'
        self.input_shape, self.output_shape = observation_space, 1

        # Two ReLU layers followed by a linear scalar head.
        self.input_layer = nn.Linear(self.input_shape, hidden_size)
        self.hidden_layer = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, self.output_shape)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        self.to(self.device)

    def forward(self, x):
        """Return the raw (no output activation) value head for a batched input ``x``."""
        relu = torch.nn.functional.relu
        features = relu(self.hidden_layer(relu(self.input_layer(x))))
        return self.output_layer(features)

    def critic_value(self, state):
        """Evaluate one numpy observation; the result keeps a leading batch dim of 1."""
        batch = torch.from_numpy(state).float().unsqueeze(0)
        return self.forward(batch.to(self.device))

In [4]:
class Trajectory:
    """Rollout buffer for one episode: rewards, action log-probs and critic values.

    Each field is an independent ``deque`` capped at ``max_size`` entries; once a
    buffer is full, appending silently discards its oldest entry.
    NOTE(review): the default of 500 presumably matches CartPole-v1's episode
    cap — longer episodes would lose their earliest steps. Confirm if reused.
    """

    def __init__(self, max_size=500):
        make_buffer = lambda: deque(maxlen=max_size)
        self.rewards = make_buffer()
        self.log_probabilities = make_buffer()
        self.values = make_buffer()

class Agent:
    """Actor-critic agent combining a PolicyNet (actor) and a CriticNet (critic).

    Steps of the current episode are collected in ``self.trajectory``.
    ``calculate_loss`` turns one finished episode into an actor and a critic loss
    and queues them. ``learn`` averages the queued losses, takes one optimizer
    step per network and then empties the queues.
    """

    def __init__(self):
        # These want to come from an argument parser.
        self.bootstrap = True          # n-step bootstrapped targets instead of full Monte-Carlo returns
        self.baseline_subtract = True  # actor uses (target - V(s)) as the advantage

        # And these hyperparams from a config.
        self.gamma = gamma
        self.depth = depth             # n-step horizon used when bootstrapping
        self.update_frequency = update_frequency

        self.policy_network = PolicyNet()
        self.critic_network = CriticNet()
        self.loss_function = nn.MSELoss()

        self.trajectory = Trajectory()

        # Per-episode losses waiting for the next learn() call.
        self.actor_loss = deque(maxlen=self.update_frequency)
        self.critic_loss = deque(maxlen=self.update_frequency)

    def calculate_discounted_rewards(self):
        """Return per-step return targets for the stored trajectory as plain floats.

        With ``self.bootstrap`` the target at step t is the ``depth``-step return
        sum_{i<T} gamma^i * r_{t+i}. When the window ends before the episode does,
        gamma^T * V(s_{t+T}) is added from the stored critic values. Without
        bootstrapping it is the full discounted return to the end of the episode.
        """
        rewards = [float(r) for r in self.trajectory.rewards]
        n_steps = len(rewards)

        if not self.bootstrap:
            # G_t = r_t + gamma * G_{t+1}, accumulated backwards in O(n).
            returns = [0.0] * n_steps
            running = 0.0
            for t in reversed(range(n_steps)):
                running = rewards[t] + self.gamma * running
                returns[t] = running
            return returns

        values = self.trajectory.values
        returns = []
        for t in range(n_steps):
            horizon = min(self.depth, n_steps - t)
            target = float(sum((self.gamma ** i) * r for i, r in enumerate(rewards[t:t + horizon])))
            if t + horizon < n_steps:
                # Bootstrap from the stored critic estimate. Detach it and convert it
                # to a float so every target is a plain number (no autograd graph,
                # no mixed tensor/float list handed to torch.tensor).
                bootstrap_value = values[t + horizon]
                if torch.is_tensor(bootstrap_value):
                    bootstrap_value = bootstrap_value.detach().item()
                target += (self.gamma ** horizon) * float(bootstrap_value)
            returns.append(target)
        return returns

    def calculate_loss(self):
        """Turn the finished episode in ``self.trajectory`` into actor/critic losses.

        Appends one scalar loss each to ``self.actor_loss`` and ``self.critic_loss``
        and clears the trajectory buffers. Does nothing if the trajectory is empty.
        """
        if len(self.trajectory.rewards) == 0:
            return  # torch.cat would fail on an empty list

        discounted_rewards = torch.tensor(self.calculate_discounted_rewards(), dtype=torch.float32,
                                          device=self.policy_network.device)
        # reshape(-1) instead of squeeze(): a one-step episode stays shape (1,), not 0-d.
        values = torch.cat(list(self.trajectory.values)).reshape(-1)
        log_probs = torch.cat(list(self.trajectory.log_probabilities)).reshape(-1)

        if self.baseline_subtract:
            # Detach the baseline so the actor loss does not push gradients into the critic.
            advantages = discounted_rewards - values
            actor_loss = -torch.sum(log_probs * advantages.detach())
        else:
            actor_loss = -torch.sum(log_probs * discounted_rewards)

        # MSELoss(input, target): regress critic predictions onto the return targets.
        critic_loss = self.loss_function(values, discounted_rewards)

        self.actor_loss.append(actor_loss)
        self.critic_loss.append(critic_loss)

        self.trajectory.rewards.clear()
        self.trajectory.log_probabilities.clear()
        self.trajectory.values.clear()

    def learn(self):
        """Average the queued losses, step both optimizers, anneal exploration, empty the queues."""
        if self.actor_loss:
            actor_loss = torch.stack(list(self.actor_loss)).mean()
            self.policy_network.optimizer.zero_grad()
            actor_loss.backward()
            self.policy_network.optimizer.step()
            self.policy_network.sigma *= 0.996  # Try annealing the exploration.

        if self.critic_loss:
            critic_loss = torch.stack(list(self.critic_loss)).mean()
            self.critic_network.optimizer.zero_grad()
            critic_loss.backward()
            self.critic_network.optimizer.step()

        # The queued losses hold graphs over parameters that were just updated in
        # place. Reusing them in a later learn() would backprop through stale or
        # freed graphs, so drop them.
        self.actor_loss.clear()
        self.critic_loss.clear()

In [24]:
def _reset(environment):
    """Reset ``environment`` and return only the observation.

    Handles both a bare-observation return and an ``(obs, info)`` tuple, so the
    loop works across gym API versions.
    """
    result = environment.reset()
    return result[0] if isinstance(result, tuple) else result


def _step(environment, chosen_action):
    """Step ``environment`` and return ``(next_obs, reward, done)``.

    Accepts both the 4-tuple ``(obs, reward, done, info)`` and the 5-tuple
    ``(obs, reward, terminated, truncated, info)`` step APIs.
    """
    result = environment.step(chosen_action)
    if len(result) == 5:
        next_obs, step_reward, terminated, truncated, _ = result
        return next_obs, step_reward, terminated or truncated
    next_obs, step_reward, step_done, _ = result
    return next_obs, step_reward, step_done


env = gym.make('CartPole-v1')
agent = Agent()
cumulative_rewards = []
print_freq = 50

try:
    for episode in range(episodes):
        state = _reset(env)
        cumulative_reward = 0
        done = False

        # Roll out one episode, recording reward, action log-prob and critic value per step.
        while not done:
            action, log_probability = agent.policy_network.choose_action(state)
            value = agent.critic_network.critic_value(state)
            state, reward, done = _step(env, action)
            agent.trajectory.rewards.append(reward)
            agent.trajectory.log_probabilities.append(log_probability)
            agent.trajectory.values.append(value)
            cumulative_reward += reward

        agent.calculate_loss()
        cumulative_rewards.append(cumulative_reward)

        if episode > 0 and episode % agent.update_frequency == 0:
            agent.learn()
        # Report independently of the update cadence so any print_freq works.
        if episode > 0 and episode % print_freq == 0:
            print(f'Episode {episode} Reward {round(np.mean(cumulative_rewards[-print_freq:]))}')
finally:
    # Release the environment even if training is interrupted (e.g. KeyboardInterrupt).
    env.close()

Episode 50 Reward 21
Episode 100 Reward 22
Episode 150 Reward 29
Episode 200 Reward 30
Episode 250 Reward 37
Episode 300 Reward 35
Episode 350 Reward 42
Episode 400 Reward 49
Episode 450 Reward 59
Episode 500 Reward 65
Episode 550 Reward 77
Episode 600 Reward 77
Episode 650 Reward 99
Episode 700 Reward 118
Episode 750 Reward 145
Episode 800 Reward 161
Episode 850 Reward 219
Episode 900 Reward 210
Episode 950 Reward 242
Episode 1000 Reward 223
Episode 1050 Reward 264
Episode 1100 Reward 273
Episode 1150 Reward 326
Episode 1200 Reward 383
Episode 1250 Reward 438
Episode 1300 Reward 426
Episode 1350 Reward 312
Episode 1400 Reward 426
Episode 1450 Reward 467
Episode 1500 Reward 412
Episode 1550 Reward 443
Episode 1600 Reward 466
Episode 1650 Reward 468
Episode 1700 Reward 448
Episode 1750 Reward 469
Episode 1800 Reward 462
Episode 1850 Reward 463
Episode 1900 Reward 476
Episode 1950 Reward 477
Episode 2000 Reward 379
Episode 2050 Reward 350
Episode 2100 Reward 352
Episode 2150 Reward 413
E

KeyboardInterrupt: 