In [2]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as distributions
from tqdm import tqdm

# Actor Model
class Actor(nn.Module):
    """Policy network: maps a state vector to a probability per action.

    The model is a two-hidden-layer MLP with ReLU activations whose final
    layer is passed through a softmax over the last dimension, so each
    output row is non-negative and sums to one.

    Args:
        state_size: Number of input features.
        action_size: Number of output probabilities.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, state_size, action_size, hidden_size):
        super().__init__()
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the module layout.
        layers = [
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
            nn.Softmax(dim=-1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state):
        """Return the action probabilities for ``state``."""
        action_probs = self.network(state)
        return action_probs


# Critic Model
class Critic(nn.Module):
    """Value network: maps a state vector to a single scalar estimate.

    The model is a two-hidden-layer MLP with ReLU activations and a linear
    output of width one (no output activation).

    Args:
        state_size: Number of input features.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, state_size, hidden_size):
        super().__init__()
        widths = [state_size, hidden_size, hidden_size]
        layers = []
        # One Linear + ReLU pair per consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_size, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, state):
        """Return the scalar estimate for ``state`` (last dimension is 1)."""
        value = self.network(state)
        return value


# A2C Agent
class A2CAgent:
    """One-step actor-critic agent for a discrete action space.

    The agent owns an ``Actor`` (policy) and a ``Critic`` (value) network,
    each with its own Adam optimizer, and updates both from a single
    transition at a time.

    Args:
        env: Environment handle; stored on the instance but not used by the
            methods of this class.
        state_size: Number of features in an observation.
        action_size: Number of discrete actions.
        hidden_size: Hidden-layer width for both networks.
        lr: Learning rate used for both optimizers.
        gamma: Discount factor applied to the next-state value.
    """

    def __init__(self, env, state_size, action_size, hidden_size,
                 lr=7e-3, gamma=0.99):
        self.env = env
        self.actor = Actor(state_size, action_size, hidden_size)
        self.critic = Critic(state_size, hidden_size)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        self.gamma = gamma

    @staticmethod
    def _as_batch(state):
        """Return ``state`` as a float32 tensor with a leading batch axis of 1.

        Going through ``np.asarray`` first avoids building a tensor from a
        Python list of ndarrays, which PyTorch warns is extremely slow.
        """
        return torch.tensor(np.asarray(state), dtype=torch.float).unsqueeze(0)

    def get_action(self, state):
        """Sample an action index from the current policy.

        Args:
            state: A single observation (array-like of length ``state_size``).

        Returns:
            The sampled action as a Python ``int``.
        """
        state = self._as_batch(state)
        # Acting never backpropagates, so skip building an autograd graph.
        with torch.no_grad():
            probs = self.actor(state)
        dist = distributions.Categorical(probs)
        return dist.sample().item()

    def train_step(self, state, action, reward, next_state, done):
        """Update critic and actor from one transition.

        Args:
            state: Observation before the action.
            action: Index of the action that was taken.
            reward: Scalar reward received.
            next_state: Observation after the action.
            done: Truthy if ``next_state`` should not be bootstrapped from.

        Returns:
            Tuple ``(actor_loss, critic_loss)`` as Python floats.
        """
        state = self._as_batch(state)
        next_state = self._as_batch(next_state)
        action = torch.tensor([action], dtype=torch.long)
        reward = torch.tensor([reward], dtype=torch.float)
        done = torch.tensor([done], dtype=torch.float)

        # Drop the critic's trailing unit axis so every quantity below has
        # shape (1,) instead of relying on (1, 1) vs (1,) broadcasting.
        curr_Q = self.critic(state).squeeze(-1)
        with torch.no_grad():
            # The bootstrap target is treated as a constant.
            next_Q = self.critic(next_state).squeeze(-1)
            expected_Q = reward + self.gamma * next_Q * (1 - done)
        # TD error computed from the pre-update critic; used as the
        # advantage weight for the policy gradient.
        TD = (expected_Q - curr_Q).detach()

        critic_loss = nn.MSELoss()(curr_Q, expected_Q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        probs = self.actor(state)
        dist = distributions.Categorical(probs)
        log_prob = dist.log_prob(action)
        actor_loss = -(log_prob * TD).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        return actor_loss.item(), critic_loss.item()


# Training Loop
env_name = "CartPole-v0"
env = gym.make(env_name)

# Seed the environment AND torch: network initialisation and action sampling
# both draw from torch's RNG, so seeding the env alone is not reproducible.
seed = 2000
env.seed(seed)
torch.manual_seed(seed)

state_size = env.observation_space.shape[0]
action_size = env.action_space.n
hidden_size = 32
max_episodes = 300

agent = A2CAgent(env, state_size, action_size, hidden_size)

# Progress bar over episodes; the postfix shows the latest episode's return.
progress_bar = tqdm(range(max_episodes), desc="Training Progress")

for episode in progress_bar:
    state = env.reset()
    episode_reward = 0
    done = False

    while not done:
        action = agent.get_action(state)
        next_state, reward, done, info = env.step(action)
        # Gym's TimeLimit wrapper flags a step-limit cutoff in `info`. Such a
        # cutoff ends the episode but is not a failure state, so the agent
        # should still bootstrap from next_state instead of treating its
        # value as zero.
        terminal = done and not info.get("TimeLimit.truncated", False)
        aloss, closs = agent.train_step(state, action, reward, next_state, terminal)
        state = next_state
        episode_reward += reward

    # The while loop exits exactly when the episode is done.
    progress_bar.set_postfix({'Episode': episode + 1, 'Reward': episode_reward})



  logger.warn(
  deprecation(
  deprecation(
  deprecation(
Training Progress: 100%|██████████| 300/300 [00:22<00:00, 13.37it/s, Episode=300, Reward=9]
