In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import gym
from torch import nn
import matplotlib.pyplot as plt

In [2]:
class Agent(nn.Module):
    """Actor-critic network with a shared state encoder and two heads.

    One linear layer embeds the observation. The embedding feeds two
    independent heads, each made of a hidden layer and an output layer:
    the actor head emits one logit per action and the critic head emits a
    single state-value estimate.

    Args:
        state_dim: Size of the last dimension of the observation tensor.
        n_actions: Number of discrete actions (width of the actor output).
        n_hidden: Width of the embedding and of both heads' hidden layers.
    """

    def __init__(self, state_dim, n_actions, n_hidden = 32):
        super().__init__()
        # Shared encoder used by both heads.
        self.state_extractor = nn.Linear(state_dim, n_hidden)

        # Policy head.
        self.pre_actor = nn.Linear(n_hidden, n_hidden)
        self.actor = nn.Linear(n_hidden, n_actions)

        # Value head.
        self.pre_critic = nn.Linear(n_hidden, n_hidden)
        self.critic = nn.Linear(n_hidden, 1)

    def forward(self, x):
        """Run the network on a float tensor whose last dimension is `state_dim`.

        Returns:
            A tuple `(action_logits, value, state_embedding)`. `action_logits`
            are unnormalised scores with `n_actions` entries in the last
            dimension, `value` is the critic output with its trailing size-1
            dimension squeezed away, and `state_embedding` is the shared
            encoder activation.
        """
        state_embedding = F.leaky_relu(self.state_extractor(x))

        actor_hidden = F.leaky_relu(self.pre_actor(state_embedding))
        action_logits = self.actor(actor_hidden)

        critic_hidden = F.leaky_relu(self.pre_critic(state_embedding))
        value = self.critic(critic_hidden)
        return action_logits, value.squeeze(-1), state_embedding

    def sample_action(self, state):
        """Sample an action from the current policy for `state`.

        Returns:
            A tuple `(action, value, log_probs, entropy, state_embedding)`:
            the sampled action index, the critic's value estimate, the
            log-probability of the sampled action, the entropy of the action
            distribution, and the shared state embedding.
        """
        action_logits, value, state_embedding = self(state)
        # Hand the logits straight to the distribution instead of a manual
        # softmax: Categorical then normalises in log space, so log_prob and
        # entropy do not go through log(softmax(x)) on saturated
        # probabilities.
        dist = torch.distributions.Categorical(logits=action_logits)
        action = dist.sample()
        log_probs = dist.log_prob(action)
        entropy = dist.entropy()
        return action, value, log_probs, entropy, state_embedding
    
def t(x):
    """Turn a NumPy array into a float32 torch tensor."""
    as_tensor = torch.from_numpy(x)
    return as_tensor.to(torch.float32)

In [3]:
# Create the CartPole-v0 task, then rebind `env` to its `.unwrapped`
# attribute so the rest of the notebook talks to that object directly.
# NOTE(review): presumably this bypasses gym's wrappers (e.g. the episode
# step limit) -- confirm against the installed gym version.
env = gym.make("CartPole-v0")
env = env.unwrapped


In [4]:
# Scalar hyper-parameters.
gamma = 0.99        # discount applied to the bootstrapped next-state value
n_episodes = 500    # episode budget; also the length of the LR schedule

# Problem dimensions are read off the environment's spaces.
n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]

# Model and optimiser; the scheduler below drives the learning rate.
agent = Agent(state_dim, n_actions)
adam = torch.optim.Adam(agent.parameters())

# One-cycle learning-rate schedule, stepped once per episode by the training
# loop, peaking at 1e-3.
cos_lr = torch.optim.lr_scheduler.OneCycleLR(
    adam,
    max_lr = 1e-3,
    total_steps = n_episodes,
    final_div_factor = 100,
)

In [5]:
# Online one-step actor-critic training loop: the network is updated after
# every environment step. Training stops early once the mean return over the
# last 100 episodes exceeds 195.
episode_rewards = []

# Weight of the entropy bonus. It is initialised once, outside the episode
# loop, so the per-episode decay at the bottom of the loop accumulates
# instead of being reset to 0.01 at the start of every episode.
entropy_factor = 0.01

for i in range(n_episodes):
    done = False
    total_reward = 0
    state = env.reset()

    while not done:

        # Observe the state and sample an action from the current policy.
        action, value, log_probs, entropy, _ = agent.sample_action(t(state))
        # Act in the environment; `.item()` passes a plain Python int.
        next_state, reward, done, _ = env.step(action.item())

        total_reward += reward
        state = next_state

        # One-step TD target: reward + gamma * V(next_state), with the
        # bootstrap term set to zero on the terminal transition. It is
        # evaluated under no_grad so the target acts as a constant and the
        # critic loss only back-propagates through V(state). A plain forward
        # pass is enough here -- no action needs to be sampled.
        with torch.no_grad():
            if done:
                future_value = torch.zeros(())
            else:
                _, future_value, _ = agent(t(next_state))
        advantage = reward + gamma * future_value - value

        # Critic regresses V(state) onto the target; the actor follows the
        # policy gradient weighted by the (detached) advantage; the entropy
        # term rewards exploration.
        critic_loss = advantage.pow(2).mean()
        actor_loss = -log_probs * (advantage.detach())
        loss = critic_loss + 0.5 * actor_loss - entropy_factor * entropy

        adam.zero_grad()
        loss.backward()
        adam.step()

        # Draw the environment after every step.
        env.render()

    episode_rewards.append(total_reward)
    average_100_games = np.mean(episode_rewards[-100:])
    # Anneal the exploration bonus once per episode.
    entropy_factor *= 0.99

    if i % 100 == 0:
        print(f"Average reward of last 100 games {average_100_games}")

    if average_100_games > 195:
        print("VICTORY")
        break

    # Advance the learning-rate schedule once per completed episode.
    cos_lr.step()

Average reward of last 100 games 24.0
Average reward of last 100 games 20.33
Average reward of last 100 games 44.38
Average reward of last 100 games 124.76
VICTORY
