In [131]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions.categorical import Categorical
import gymnasium as gym

In [132]:
class Actor(nn.Module):
    def __init__(self, layer_sizes, activation, output_activation) -> None:
        super().__init__()
        layers = []
        num_of_layers = len(layer_sizes)
        for i in range(num_of_layers - 1):
            layer = nn.Linear(layer_sizes[i], layer_sizes[i+1])
            activation_function = activation if i < (num_of_layers - 2) else output_activation
            layers += [layer, activation_function()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        action_probs = self.model(x)
        return Categorical(probs=action_probs)

    def update(self, optimizer, loss):
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()


In [133]:
class Critic(nn.Module):
    def __init__(self, layer_sizes, activation, output_activation) -> None:
        super().__init__()
        layers = []
        num_of_layers = len(layer_sizes)
        for i in range(num_of_layers - 1):
            layer = nn.Linear(layer_sizes[i], layer_sizes[i+1])
            activation_function = activation if i < (num_of_layers - 2) else output_activation
            layers += [layer, activation_function()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

    def update(self, optimizer, loss):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


In [134]:
from statistics import mean

# Agent
class ActorCritic: 
    def __init__(self, env, actor_lr=1e-3, critic_lr=3e-4, gamma=0.99) -> None:
        self.env = env
        self.actor = Actor([env.observation_space.shape[0], 64, 64, env.action_space.n], nn.ReLU, nn.Softmax)
        self.critic = Critic([env.observation_space.shape[0], 64, 64, 1], nn.ReLU, nn.Identity)
        self.actor_optimizer = Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma

    def sample_action(self, obs):
        obs = torch.tensor(obs, dtype=torch.float32)
        return self.actor(obs).sample().item()

    def compute_actor_loss(self, observations, actions, weights):
        observations = torch.stack([torch.tensor(obs, dtype=torch.float32) for obs in observations])
        actions = torch.tensor(actions, dtype=torch.float32)
        logp = self.actor(observations).log_prob(actions)
        return -(logp * weights).mean()
        
    def compute_critic_loss(self, values, rewards):
        rewards = torch.tensor(rewards, dtype=torch.float32)
        values_tensor = torch.stack(values).squeeze()
        return F.mse_loss(rewards, values_tensor)
        
    def compute_action_values(self, rewards, gamma, values):
        action_values = np.zeros_like(rewards)
        for t in reversed(range(len(rewards))):
            action_value = rewards[t] + gamma * values[t]
            action_values[t] = action_value
        return torch.tensor(action_values, dtype=torch.float32)

    def compute_advantage(self, action_values, values):
        values_tensor = torch.stack(values)
        return action_values - values_tensor.squeeze()

    def train(self, epochs=100, episodes=100):
        for epoch in range(epochs):
            returns, lengths = [], []
            for episode in range(episodes):
                observations, actions, values = [], [], []
                obs, info  = self.env.reset()
                terminated = truncated = False
                rewards = []
                while not terminated and not truncated:
                    observations.append(obs)
                    action = self.sample_action(obs)
                    actions.append(action)
                    value = self.critic(torch.tensor(obs.copy(), dtype=torch.float32))
                    values.append(value)
                    obs, reward, terminated, truncated, info = self.env.step(action)
                    rewards.append(reward)
                value = self.critic(torch.tensor(obs.copy(), dtype=torch.float32))
                action_values = self.compute_action_values(rewards, 0.99, values)
                advantages = self.compute_advantage(action_values, values)
                actor_loss = self.compute_actor_loss(observations, actions, advantages)
                critic_loss = self.compute_critic_loss(values, rewards)
                self.actor.update(self.actor_optimizer, actor_loss)
                self.critic.update(self.critic_optimizer, critic_loss)
                ep_return = sum(rewards)
                returns.append(ep_return)
                lengths.append(len(rewards))
            print(f"Epoch: {epoch} Return: {mean(returns)}")

In [135]:
env = gym.make("CartPole-v1")
ac = ActorCritic(env)
ac.train()

  input = module(input)


Epoch: 0 Return: 16.89
Epoch: 1 Return: 12.75
Epoch: 2 Return: 10.52
Epoch: 3 Return: 9.91
Epoch: 4 Return: 9.56
Epoch: 5 Return: 9.69
Epoch: 6 Return: 9.63


KeyboardInterrupt: 