In [None]:
import os 

import gym

import torch 
import torch.nn as nn

import numpy as np
from collections import deque

In [None]:
# --- Training schedule ---
episodes = 10_000       # total episodes run by the training loop
update_frequency = 10   # episodes between agent.learn() calls; also bounds Agent's loss queue

# --- Policy network shape ---
observation_space = 4   # input width of PolicyNet (length of one state vector)
n_actions = 2           # output width of PolicyNet (number of discrete actions)
hidden_size = 64        # width of PolicyNet's hidden layers

# --- Optimisation / exploration ---
learning_rate = 1e-3    # Adam learning rate used by PolicyNet
gamma = 1               # reward discount factor; 1 means rewards are not discounted
sigma = 0.1             # std of the Gaussian noise PolicyNet adds to its action probabilities

In [None]:
class PolicyNet(torch.nn.Module):
    """Three-layer MLP policy that maps a state to action probabilities.

    Every constructor argument is optional. One left as ``None`` falls back to the
    notebook-level config constant (``observation_space``, ``n_actions``,
    ``hidden_size``, ``learning_rate``, ``sigma``), so ``PolicyNet()`` builds the same
    network as before.

    Args:
        observation_size: length of the input state vector.
        n_outputs: number of discrete actions (output width).
        hidden: width of the two hidden layers.
        lr: Adam learning rate.
        noise_sigma: initial std of the Gaussian exploration noise that
            ``choose_action`` adds to the action probabilities. It is stored as
            ``self.sigma`` so callers can anneal it.
    """

    def __init__(self, observation_size=None, n_outputs=None, hidden=None, lr=None, noise_sigma=None):
        super().__init__()
        # Hyperparams to be pushed into a config.
        self.sigma = sigma if noise_sigma is None else noise_sigma
        self.learning_rate = learning_rate if lr is None else lr
        hidden_width = hidden_size if hidden is None else hidden

        self.device = 'cpu'
        self.input_shape = observation_space if observation_size is None else observation_size
        self.output_shape = n_actions if n_outputs is None else n_outputs

        self.input_layer = nn.Linear(self.input_shape, hidden_width)
        self.hidden_layer = nn.Linear(hidden_width, hidden_width)
        self.output_layer = nn.Linear(hidden_width, self.output_shape)

        # Place the parameters first, then build the optimizer over them.
        self.to(self.device)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def forward(self, x):
        """Return softmax action probabilities over the last dimension of ``x``.

        Works for a single state of shape ``(input_shape,)`` or a batch of shape
        ``(batch, input_shape)``.
        """
        x = torch.nn.functional.relu(self.input_layer(x))
        x = torch.nn.functional.relu(self.hidden_layer(x))
        return torch.nn.functional.softmax(self.output_layer(x), dim=-1)

    def choose_action(self, state):
        """Sample an action for one state from noise-perturbed probabilities.

        Args:
            state: a single observation (array-like). It is converted to float32
                and given a leading batch dimension of 1.

        Returns:
            ``(action, log_prob)``: the sampled action as a Python int, and its
            log-probability under the perturbed distribution. The log-probability
            is a tensor of shape ``(1,)`` still attached to the autograd graph.
        """
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
        clean = self.forward(state)

        # Read self.sigma, not the global, so external annealing takes effect.
        # Size the noise from the output so it works for any action count.
        noisy = (clean + torch.randn_like(clean) * self.sigma).clamp(min=0.0)
        # If the noise pushed every probability to zero there is nothing to
        # sample from. Fall back to the unperturbed policy instead of crashing.
        if float(noisy.sum()) <= 0.0:
            noisy = clean

        distribution = torch.distributions.Categorical(probs=noisy)
        action = distribution.sample()
        return action.item(), distribution.log_prob(action)

In [None]:
class Trajectory:
    """Per-episode buffers of rewards and action log-probabilities.

    Both buffers are bounded deques. Once ``max_size`` items are stored,
    appending another silently discards the oldest one.
    NOTE(review): the default of 500 presumably matches CartPole-v1's episode
    step limit; confirm before reusing with other environments.
    """

    def __init__(self, max_size=500):
        def bounded_buffer():
            return deque(maxlen=max_size)

        self.rewards = bounded_buffer()
        self.log_probabilities = bounded_buffer()

class Agent:
    """REINFORCE-style learner that batches per-episode losses into periodic updates.

    ``calculate_loss`` turns the finished episode held in ``self.trajectory`` into
    one scalar loss and queues it. ``learn`` averages the queued losses, takes one
    optimizer step on ``self.policy_network`` and anneals its exploration noise.
    """

    def __init__(self):
        # Hyperparams to be pushed into a config.
        self.update_frequency = update_frequency
        self.gamma = gamma

        self.policy_network = PolicyNet()
        self.trajectory = Trajectory()
        # Holds at most one update's worth of per-episode losses.
        self.loss = deque(maxlen=self.update_frequency)

    def calculate_loss(self):
        """Queue the policy-gradient loss for the episode in ``self.trajectory``.

        Every step's log-probability is weighted by the episode's total discounted
        return ``sum_t gamma**t * r_t``, and the negated products are summed. The
        trajectory buffers are always emptied. An empty trajectory queues nothing.
        """
        rewards = list(self.trajectory.rewards)
        log_probs = list(self.trajectory.log_probabilities)
        self.trajectory.rewards.clear()
        self.trajectory.log_probabilities.clear()
        if not log_probs:
            return

        # A plain Python float avoids mixing NumPy and torch scalars below.
        discounted_total_reward = float(sum(self.gamma ** t * r for t, r in enumerate(rewards)))
        # reshape(-1) accepts both (1,)-shaped and 0-d log-probabilities.
        episode_log_prob = torch.cat([lp.reshape(-1) for lp in log_probs]).sum()
        self.loss.append(-episode_log_prob * discounted_total_reward)

    def learn(self, sigma_decay=0.996):
        """Take one optimizer step on the mean of the queued episode losses.

        The queue is emptied, so no loss (or its autograd graph) is reused by a
        later update. The policy's exploration ``sigma`` is then multiplied by
        ``sigma_decay``. Does nothing when no losses are queued.
        """
        if not self.loss:
            return
        loss = torch.stack(list(self.loss)).mean()
        # backward() frees these graphs, so they must not be stepped through again
        # or kept alive in memory.
        self.loss.clear()

        optimizer = self.policy_network.optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        self.policy_network.sigma *= sigma_decay  # Try annealing the exploration.

In [None]:
env = gym.make('CartPole-v1')
agent = Agent()


def _reset_state(environment):
    """Reset ``environment`` and return only the initial observation.

    gym>=0.26 returns ``(observation, info)`` from ``reset``, while older gym
    returns the observation alone. Both forms are accepted.
    """
    result = environment.reset()
    return result[0] if isinstance(result, tuple) else result


def _step(environment, chosen_action):
    """Step ``environment`` and return ``(next_state, reward, done)``.

    Accepts both the older 4-tuple ``(obs, reward, done, info)`` API and the
    gym>=0.26 5-tuple ``(obs, reward, terminated, truncated, info)`` API.
    """
    result = environment.step(chosen_action)
    if len(result) == 5:
        next_state, step_reward, terminated, truncated, _ = result
        return next_state, step_reward, terminated or truncated
    next_state, step_reward, finished, _ = result
    return next_state, step_reward, finished


try:
    for episode in range(episodes):
        state = _reset_state(env)
        cumulative_reward = 0

        while True:
            action, log_probability = agent.policy_network.choose_action(state)
            state_next, reward, done = _step(env, action)
            agent.trajectory.rewards.append(reward)
            agent.trajectory.log_probabilities.append(log_probability)
            state = state_next
            cumulative_reward += reward

            if done:
                agent.calculate_loss()
                # Learn once every `update_frequency` completed episodes, so each
                # update averages a full batch of episode losses.
                if (episode + 1) % agent.update_frequency == 0:
                    agent.learn()
                    # 1-based count of completed episodes and the return of the
                    # episode that just finished.
                    print("Episode {} Reward {}".format(episode + 1, cumulative_reward))
                break
finally:
    env.close()