## Initializing Stuff 

In [9]:
import numpy as np
import pandas as pd
import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

from tqdm import tqdm
from itertools import count

## Define the Actor and the  Critic Network

In [10]:
class Actor(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_actions)

    def forward(self, state):
        state = state.unsqueeze(0)
        policy_dist = F.relu(self.fc1(state))
        policy_dist = self.fc2(policy_dist)
        return policy_dist

    def act(self, state):
        state = state.unsqueeze(0)
        policy_logits = self.forward(state)
        m = Categorical(logits=policy_logits)
        action = m.sample()
        return action.item(), m.log_prob(action)

class Critic(nn.Module):
    def __init__(self, num_inputs, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, state):
        state = state.unsqueeze(0)
        value = F.relu(self.fc1(state))
        value = self.fc2(value)
        return value


## Define the Q-Actor-Critic Trainer

In [11]:
def train_q_actor_critic(
    env: gym.Env,
    actor_net: Actor,
    critic_net: Critic,
    optimizer: optim.Optimizer,
    num_episodes: int = int(1e4),
    disc_factor: float = .99,
    quiet: bool = False,
    cuda: bool = False,
    report_every_n_episodes: int = 10
):
    device = torch.device('cuda' if cuda else 'cpu')
    actor_net.to(device)
    critic_net.to(device)

    all_lengths = []
    average_lengths = []
    all_rewards = []
    average_rewards = []

    prog_bar = tqdm(range(num_episodes), desc='Training Episode', disable=quiet)
    for episode in prog_bar:
        log_probs = []
        values = []
        rewards = []

        state, _ = env.reset()
        state = torch.from_numpy(state).float().to(device)
        for steps in count():
            value = critic_net(state)
            action, log_prob = actor_net.act(state)

            new_state, reward, terminated, truncated, _ = env.step(action)
            new_state = torch.from_numpy(new_state).float().to(device)

            rewards.append(reward)
            values.append(value)
            log_probs.append(log_prob)
            state = new_state
            
            if terminated or truncated:
                # Aggregating
                rewards = torch.FloatTensor(rewards)
                values = torch.stack(values).squeeze()
                log_probs = torch.stack(log_probs).squeeze()

                all_rewards.append(rewards.sum().item())
                all_lengths.append(steps)
                average_lengths.append(np.mean(all_lengths[-report_every_n_episodes:]))
                average_rewards.append(np.mean(all_rewards[-report_every_n_episodes:]))
                if episode % report_every_n_episodes == 0:
                    prog_bar.set_postfix_str(
                        f"reward: {all_rewards[-1]}, len: {steps}, "\
                        f"{report_every_n_episodes} episode average reward: {average_rewards[-1]}, "\
                        f"{report_every_n_episodes} average len: {average_lengths[-1]}"
                    )
                break

        # Value of the last state is always 0
        value_t_p_1 = 0
        # Compute Q values for all the states in the trajectory
        q_values = torch.zeros_like(values)
        for t in reversed(range(len(rewards))):
            q_value_t_p_1 = rewards[t] + disc_factor * value_t_p_1
            q_values[t] = q_value_t_p_1
            # The value for time step t+1
            value_t_p_1 = values[t]


        advantage = q_values - values
        actor_loss = (-log_probs * advantage.detach()).mean()
        criterion = nn.SmoothL1Loss()
        critic_loss = criterion(values, q_values)
        ac_loss = actor_loss + critic_loss

        optimizer.zero_grad()
        ac_loss.backward()
        # In-place gradient clipping to prevent exploding gradients.
        torch.nn.utils.clip_grad_value_(actor_net.parameters(), 10)
        torch.nn.utils.clip_grad_value_(critic_net.parameters(), 10)
        optimizer.step()

    return average_lengths, average_rewards


### Agent Evaluation

In [12]:
def eval_q_actor_critic(
    env: gym.Env,
    actor_net: Actor,
    num_episodes: int = 100,
    quiet: bool = False,
    cuda: bool = False,
    report_every_n_episodes: int = 10
):
    device = torch.device('cuda' if cuda else 'cpu')
    actor_net.to(device)

    all_lengths = []
    average_lengths = []
    all_rewards = []
    average_rewards = []

    prog_bar = tqdm(range(num_episodes), desc='Evaluation Episode', disable=quiet)
    for episode in prog_bar:
        rewards = []

        state, _ = env.reset()
        state = torch.from_numpy(state).float().to(device)
        for steps in count():
            with torch.no_grad():
                action, _ = actor_net.act(state)

                new_state, reward, terminated, truncated, _ = env.step(action)
                new_state = torch.from_numpy(new_state).float().to(device)

                rewards.append(reward)
                state = new_state

            if terminated or truncated:
                # Aggregating
                rewards = torch.FloatTensor(rewards)

                all_rewards.append(rewards.sum().item())
                all_lengths.append(steps)
                average_lengths.append(np.mean(all_lengths[-report_every_n_episodes:]))
                average_rewards.append(np.mean(all_rewards[-report_every_n_episodes:]))
                if episode % report_every_n_episodes == 0:
                    prog_bar.set_postfix_str(
                        f"reward: {all_rewards[-1]}, len: {steps}, "\
                        f"{report_every_n_episodes} episode average reward: {average_rewards[-1]}, "\
                        f"{report_every_n_episodes} average len: {average_lengths[-1]}"
                    )
                break

    return average_lengths, average_rewards
            

## Part 1: CartPole-v1

### Definitions

In [13]:
# Hyperparams
hidden_size = 128
learning_rate = 1e-3

cartenv = gym.make('CartPole-v1')

num_inputs = cartenv.observation_space.shape[0]
num_outputs = cartenv.action_space.n

actor = Actor(num_inputs, num_outputs, hidden_size)
critic = Critic(num_inputs, hidden_size)

params = list(actor.parameters()) + list(critic.parameters())
optimizer = optim.AdamW(params, lr=learning_rate)

### Train the agent

In [14]:
avg_len, avg_rwd = train_q_actor_critic(cartenv, actor, critic, optimizer, int(4.096e3))

Training Episode: 100%|██████████| 4096/4096 [10:42<00:00,  6.37it/s, reward: 500.0, len: 499, average reward: 429.2 , average len: 428.2]


### Evaluate the agent

In [21]:
eval_avg_len, eval_avg_rwd = eval_q_actor_critic(cartenv, actor)

Evaluation Episode: 100%|██████████| 100/100 [00:15<00:00,  6.45it/s, reward: 500.0, len: 499, average reward: 396.4 , average len: 395.4]


In [22]:
# torch.save(actor.state_dict(),'m44_saileshr_assignment3_a2c_actor_cartpolev1.pth')
# torch.save(critic.state_dict(),'m44_saileshr_assignment3_a2c_critic_cartpolev1.pth')

## Part 2.1: LunarLander-v2

### Definitions

In [None]:
# Hyperparams
hidden_size = 128
learning_rate = 1e-3

lunarenv = gym.make("LunarLander-v2")

num_inputs = lunarenv.observation_space.shape[0]
num_outputs = lunarenv.action_space.n

actor = Actor(num_inputs, num_outputs, hidden_size)
critic = Critic(num_inputs, hidden_size)

params = list(actor.parameters()) + list(critic.parameters())
optimizer = optim.AdamW(params, lr=learning_rate)

### Train the agent

In [None]:
avg_len, avg_rwd = train_q_actor_critic(lunarenv, actor, critic, optimizer, int(1e4))

### Evaluate the agent

In [None]:
eval_avg_len, eval_avg_rwd = eval_q_actor_critic(lunarenv, actor)

In [None]:
# torch.save(actor,'m44_saileshr_assignment3_a2c_actor_lunarlander.pth')
# torch.save(critic,'m44_saileshr_assignment3_a2c_critic_lunarlander.pth')