In [None]:
import sys
import torch
import gymnasium as gym
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random

In [None]:
# Training environment: the episode loop further down resets and steps this
# instance. No render_mode is passed here; show_game() builds its own
# separate environment with render_mode='human' for visual playback.
env = gym.make('CartPole-v1')

In [None]:
# Run on the GPU when PyTorch reports a usable CUDA device, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class ActorCritic(nn.Module):
    """Actor-critic model built from two independent MLPs.

    The actor maps a state to a probability distribution over actions; the
    critic maps the same state to a single scalar state-value estimate. The
    two heads share no parameters.

    Args:
        n_states: Size of the observation vector (input features).
        n_actions: Number of discrete actions (actor output size).
        lr: Accepted for backward compatibility only. It is not used by the
            module; the optimizer is created by the caller.
    """

    def __init__(self, n_states, n_actions, lr):
        super().__init__()
        # Policy head: state -> unnormalised action scores (logits).
        self.actor = nn.Sequential(
            nn.Linear(n_states, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )
        # Value head: state -> scalar value estimate.
        self.critic = nn.Sequential(
            nn.Linear(n_states, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, state):
        """Evaluate both heads on ``state``.

        Args:
            state: Float tensor whose last dimension has size ``n_states``.
                Both a batch ``(B, n_states)`` and a single unbatched state
                ``(n_states,)`` are accepted.

        Returns:
            Tuple ``(value, probs)``: ``value`` has a trailing dimension of
            size 1, and ``probs`` has a trailing dimension of size
            ``n_actions`` that sums to 1.
        """
        value = self.critic(state)
        # Normalise over the action (last) axis. Using -1 instead of a fixed
        # axis index keeps this valid for unbatched 1-D inputs as well.
        probs = F.softmax(self.actor(state), dim = -1)
        return value, probs

In [None]:
def show_game(model):
    """Play one rendered CartPole episode by sampling actions from ``model``.

    A fresh 'CartPole-v1' environment with ``render_mode='human'`` is created
    for the playback and is always closed again, even if the rollout raises,
    so repeated calls do not leave render windows open.

    Args:
        model: Callable taking a ``(1, n_states)`` float tensor on the
            module-level ``device`` and returning ``(value, probs)``, where
            ``probs`` holds the action probabilities.

    Returns:
        None.
    """
    env2 = gym.make('CartPole-v1', render_mode = 'human')
    try:
        state = env2.reset()[0]
        done = False
        # Pure inference: no autograd graph is needed for playback.
        with torch.no_grad():
            while not done:
                state_t = torch.as_tensor(state, dtype = torch.float, device = device).unsqueeze(0)
                _, probs = model(state_t)
                action = Categorical(probs).sample().item()
                state, _, terminated, truncated, _ = env2.step(action)
                done = terminated or truncated
    finally:
        # Release the render window and its resources.
        env2.close()

In [None]:
# --- Problem dimensions, read from the environment ---
n_states = env.observation_space.shape[0]   # length of the observation vector
n_actions = env.action_space.n              # number of discrete actions

# --- Optimisation settings ---
lr = 3e-4        # optimizer learning rate
GAMMA = 0.99     # discount factor applied to future rewards

# --- Training schedule ---
max_episodes = 9000   # number of training episodes
num_steps = 500       # upper bound on environment steps per episode

# NOTE(review): not referenced by any cell visible in this notebook section;
# kept in case later cells rely on it.
hidden_size = 256

In [None]:
# Advantage actor-critic training on `env`, one gradient update per episode.
#
# For every episode:
#   1. roll out the current policy, keeping per-step log-probabilities,
#      value estimates (with gradients attached) and policy entropies;
#   2. build discounted returns backwards from a bootstrap value;
#   3. minimise  actor_loss + critic_loss - ENTROPY_COEF * entropy.

# Seed every RNG in play so a fresh "Run All" reproduces the same run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

ENTROPY_COEF = 0.001   # weight of the entropy (exploration) bonus
PLOT_EVERY = 500       # plot learning curves / show a game every N episodes

actor_critic = ActorCritic(n_states, n_actions, lr).to(device)
optimizer = optim.AdamW(actor_critic.parameters(), lr = lr, amsgrad = True)

all_lengths = []
average_lengths = []
all_rewards = []
entropy_term = 0

for i_episode in range(max_episodes):
    log_probs = []
    values = []
    rewards = []
    entropies = []

    # Seed the environment once; later resets continue from its own RNG.
    state = env.reset(seed = SEED if i_episode == 0 else None)[0]
    terminated = False
    truncated = False

    for steps in range(num_steps):
        state_t = torch.as_tensor(state, dtype = torch.float, device = device).unsqueeze(0)
        value, probs = actor_critic(state_t)

        action_dist = Categorical(probs)
        action = action_dist.sample()

        next_state, reward, terminated, truncated, _ = env.step(action.item())

        rewards.append(float(reward))
        # Keep the value attached to the graph: the critic can only learn if
        # gradients flow from critic_loss back into its parameters.
        values.append(value.squeeze())
        log_probs.append(action_dist.log_prob(action).squeeze())
        entropies.append(action_dist.entropy().squeeze())
        state = next_state

        if terminated or truncated:
            break

    # Number of environment steps actually taken (the loop index is 0-based).
    episode_length = len(rewards)
    all_rewards.append(np.sum(rewards))
    all_lengths.append(episode_length)
    average_lengths.append(np.mean(all_lengths[-10:]))
    if i_episode % 10 == 0:
        print(f"Episode:{i_episode} | Reward: {np.sum(rewards)} | Total_Length: {episode_length} | Average_Length: {average_lengths[-1]}")

    # Bootstrap value for the state after the last step. A terminal state has
    # no future reward, so its value is 0. If the episode was merely cut off
    # (time-limit truncation or the num_steps cap), estimate the remaining
    # return with the critic instead.
    if terminated:
        Qval = 0.0
    else:
        with torch.no_grad():
            last_state_t = torch.as_tensor(state, dtype = torch.float, device = device).unsqueeze(0)
            Qval = actor_critic(last_state_t)[0].item()

    # Discounted returns, accumulated backwards in time.
    Qvals = [0.0] * episode_length
    for t in reversed(range(episode_length)):
        Qval = rewards[t] + GAMMA * Qval
        Qvals[t] = Qval

    # All per-step quantities are 1-D tensors of shape (T,), so the products
    # below are element-wise rather than an accidental broadcasted outer product.
    Qvals = torch.tensor(Qvals, dtype = torch.float, device = device)
    values = torch.stack(values)
    log_probs = torch.stack(log_probs)
    # Entropy summed over this episode only (reset every episode).
    entropy_term = torch.stack(entropies).sum()

    advantage = Qvals - values
    # The policy gradient treats the advantage as a constant weight.
    actor_loss = (-log_probs * advantage.detach()).mean()
    critic_loss = 0.5 * advantage.pow(2).mean()
    # Subtracting the entropy rewards exploration (higher entropy -> lower loss).
    ac_loss = actor_loss + critic_loss - ENTROPY_COEF * entropy_term

    optimizer.zero_grad()
    ac_loss.backward()
    optimizer.step()

    if (i_episode + 1) % PLOT_EVERY == 0:
        smoothed_rewards = pd.Series(all_rewards).rolling(10).mean()
        plt.plot(all_rewards)
        plt.plot(smoothed_rewards)
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.show()

        plt.plot(all_lengths)
        plt.plot(average_lengths)
        plt.xlabel('Episode')
        plt.ylabel('Episode Length')
        plt.show()

        show_game(actor_critic)