# Import Packages for Actor-Critic

In [1]:
!pip install gymnasium

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import gymnasium as gym

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

from tqdm import tqdm
import matplotlib.pyplot as plt

# We'll use OmegaConf to manage hyperparameters!

In [3]:
!pip install omegaconf

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [4]:
from omegaconf import OmegaConf

# Environment
- CartPole-v1

In [5]:
# Build the CartPole environment used throughout the notebook; the bare `env`
# on the last line displays its wrapper stack as the cell output.
env = gym.make(id='CartPole-v1')
env

<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>

# Hyperparameters

In [6]:
AC_config = OmegaConf.create({
    # discount factor used in the one-step TD target
    'gamma': 0.99,

    # network parameters
    # fall back to CPU when no CUDA device is available (a hard-coded 'cuda:0'
    # makes the notebook crash on CPU-only machines)
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
    'hidden_dim': 64,
    'state_dim': env.observation_space.shape[0],
    'action_dim': int(env.action_space.n),

    # learning parameters
    'learning_rate': 0.0001,
    # maximum number of environment steps collected between two updates
    'n_rollout': 10,
})

In [7]:
class ActorCritic(nn.Module):
    """One-step (TD(0)) actor-critic agent with separate actor and critic MLPs.

    Transitions are accumulated with ``put_data`` and consumed by ``update``,
    which performs a single gradient step on the combined actor + critic loss
    and then clears the buffer.

    Args:
        config: hyperparameter container exposing ``state_dim``, ``hidden_dim``,
            ``action_dim``, ``device``, ``learning_rate`` and ``gamma`` as
            attributes (an OmegaConf config in this notebook).
    """

    # Rewards are divided by this constant before training (presumably to keep
    # TD targets in a small range — the value comes from the original notebook).
    REWARD_SCALE = 100.0
    # Lower bound on action probabilities before taking the log, so a saturated
    # softmax cannot yield log(0) = -inf and NaN gradients.
    PROB_EPS = 1e-8

    def __init__(self, config):
        super().__init__()
        self.data = []  # buffer of (state, action, reward, next_state, done) tuples
        self.config = config

        # actor network: state -> probability of each discrete action
        self.actor = nn.Sequential(
            nn.Linear(self.config.state_dim, self.config.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.config.hidden_dim, self.config.action_dim),
            nn.Softmax(dim=-1)
        )

        # critic network: state -> scalar state-value estimate V(s)
        self.critic = nn.Sequential(
            nn.Linear(self.config.state_dim, self.config.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.config.hidden_dim, 1)
        )

        # move all parameters to the configured device
        self.to(self.config.device)

        # a single optimizer over actor and critic parameters
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.config.learning_rate)

    def actor_forward(self, state):
        """Return the actor's action probabilities for ``state``."""
        return self.actor(state)

    def critic_forward(self, state):
        """Return the critic's value estimate for ``state``."""
        return self.critic(state)

    def put_data(self, transition):
        """Append one ``(state, action, reward, next_state, done)`` tuple to the buffer.

        ``done`` should be True only for true terminal states (not time-limit
        truncations) so that the TD target still bootstraps from V(s') there.
        """
        self.data.append(transition)

    def make_batch(self):
        """Convert the buffered transitions to tensors on ``config.device`` and clear the buffer.

        Returns:
            Tuple ``(states, actions, rewards, next_states, not_done)`` where
            ``actions`` is an int64 ``(N, 1)`` tensor, ``rewards`` and
            ``not_done`` are float ``(N, 1)`` tensors, and the state tensors have
            one row per transition. Rewards are divided by ``REWARD_SCALE``;
            ``not_done`` is 0.0 for transitions flagged done and 1.0 otherwise.
        """
        device = self.config.device
        state_list, action_list, reward_list, next_state_list, not_done_list = [], [], [], [], []
        for state, action, reward, next_state, done in self.data:
            state_list.append(list(state))
            # int() accepts both Python ints and 0-dim tensors (e.g. the output of
            # Categorical.sample()), so we never build a tensor out of (possibly
            # CUDA) tensors.
            action_list.append([int(action)])
            reward_list.append([reward / self.REWARD_SCALE])
            next_state_list.append(list(next_state))
            not_done_list.append([0.0 if done else 1.0])

        state_batch = torch.tensor(state_list, dtype=torch.float, device=device)
        action_batch = torch.tensor(action_list, dtype=torch.long, device=device)
        reward_batch = torch.tensor(reward_list, dtype=torch.float, device=device)
        next_state_batch = torch.tensor(next_state_list, dtype=torch.float, device=device)
        not_done_batch = torch.tensor(not_done_list, dtype=torch.float, device=device)

        # clear buffer
        self.data = []

        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch

    def update(self):
        """Perform one gradient step on the buffered transitions and clear the buffer.

        The critic is regressed (smooth L1) towards the one-step TD target
        ``r + gamma * V(s') * not_done``. The actor minimises
        ``-log pi(a|s) * td_error``, with the TD error treated as a constant
        advantage estimate. Does nothing when the buffer is empty.
        """
        if not self.data:
            return

        states, actions, rewards, next_states, not_done = self.make_batch()

        # TD target; the bootstrap term is zeroed for done transitions and the
        # whole target is excluded from the graph (semi-gradient TD).
        td_target = (rewards + self.config.gamma * self.critic(next_states) * not_done).detach()

        # evaluate V(s) once and reuse it for the TD error and the critic loss
        values = self.critic(states)
        td_error = td_target - values

        # actor loss: policy gradient weighted by the (non-differentiated) TD error
        action_probs = self.actor(states).gather(1, actions)
        log_probs = torch.log(action_probs.clamp(min=self.PROB_EPS))
        actor_loss = -(log_probs * td_error.detach()).mean()

        # critic loss
        critic_loss = F.smooth_l1_loss(values, td_target)

        # aggregate
        loss = actor_loss + critic_loss

        # backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [None]:
num_epis, epi_rews = 10000, []
SEED = 0
# seed torch before building the agent so weight init and action sampling are reproducible
torch.manual_seed(SEED)
agent = ActorCritic(AC_config)

for n_epi in tqdm(range(num_epis)):
    # seed only the first reset; later resets continue the env's own RNG stream
    state, _ = env.reset(seed=SEED if n_epi == 0 else None)
    terminated, truncated = False, False
    epi_rew = 0

    while not (terminated or truncated):
        # collect up to n_rollout steps, stopping early when the episode ends
        for _ in range(AC_config.n_rollout):
            # sample an action from the actor; no autograd graph is needed here
            # because update() re-evaluates the actor on the stored states
            with torch.no_grad():
                action_probs = agent.actor(torch.from_numpy(state).float().to(AC_config.device))
            action = Categorical(action_probs).sample().item()

            # step
            next_state, reward, terminated, truncated, _ = env.step(action)

            # store only `terminated` as the done flag: a time-limit truncation is
            # not a true terminal state, so the TD target must still bootstrap
            # from V(next_state) for it
            agent.put_data((state, action, reward, next_state, terminated))

            # state transition
            state = next_state

            # record reward
            epi_rew += reward

            if terminated or truncated:
                break

        # one gradient step on the collected transitions
        agent.update()

    # record
    epi_rews.append(epi_rew)

env.close()

 15%|█▌        | 1516/10000 [01:16<13:10, 10.73it/s]

In [None]:
# Plot per-episode returns together with a trailing moving average so the
# learning trend is visible through episode-to-episode variation.
window = 100
cumsum = [0.0]
for ret in epi_rews:
    cumsum.append(cumsum[-1] + ret)
# mean of the last min(i + 1, window) returns at each episode i
moving_avg = [
    (cumsum[i + 1] - cumsum[max(0, i + 1 - window)]) / min(i + 1, window)
    for i in range(len(epi_rews))
]

fig, ax = plt.subplots(figsize=(20, 10), dpi=300)
ax.plot(epi_rews, alpha=0.4, label='episode returns')
ax.plot(moving_avg, label=f'{window}-episode moving average')
ax.set_title('Actor-Critic on CartPole-v1', fontsize=20)
ax.set_xlabel('episode', fontsize=16)
ax.set_ylabel('return', fontsize=16)
ax.legend(fontsize=20)
plt.show()
plt.close(fig)