## Imports and setup

In [1]:
import math
import random
from abc import ABC, abstractmethod
from collections import namedtuple, deque
from itertools import count
from typing import Any, List, Union

import gymnasium as gym
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from gymnasium import Env
from gymnasium import spaces
# Rebind `tqdm` to the progress-bar class (notebook widget when available):
# the bare module imported above is not callable.
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set up matplotlib: detect whether the active backend is IPython's inline one
is_ipython = matplotlib.get_backend().find('inline') != -1

if is_ipython:
    # The IPython display module is only imported under the inline backend
    from IPython import display

# Interactive mode: figures are shown/redrawn as plotting commands are issued
plt.ion()

<matplotlib.pyplot._IonContext at 0x2549ebff2e0>

## Define Replay Memory

In [3]:
# One environment step as stored in the replay buffer.
Transition = namedtuple(
    'Transition', ['state', 'action', 'next_state', 'reward']
)


class ReplayMemory:
    """Bounded experience-replay buffer.

    Holds at most ``capacity`` transitions; once full, appending a new
    transition discards the oldest one (a ``deque`` with ``maxlen``).
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition built from ``(state, action, next_state, reward)``."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        Raises ValueError (from ``random.sample``) if fewer are stored.
        """
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        return len(self.memory)
        

## Define the Q-function approximator networks

In [4]:
class FeedForwardDQN(nn.Module):
    """Q-network with a single hidden layer of 64 ReLU units.

    Args:
        inp_dim: size of the flat state vector.
        out_dim: number of discrete actions (one Q-value per action).
    """

    def __init__(self, inp_dim, out_dim):
        super().__init__()
        self.hidden = nn.Linear(inp_dim, 64)
        self.output = nn.Linear(64, out_dim)

        # Constructor args, so a trainer can build an identical clone
        # via ``type(net)(*net.inp_args)``.
        self.inp_args = [inp_dim, out_dim]

    def forward(self, state):
        """Return Q-values of shape ``(..., out_dim)`` for ``state`` of shape ``(..., inp_dim)``.

        Tensors are used as-is; anything else (lists, numpy arrays) is converted
        to a float32 tensor on the same device as the network's parameters.
        """
        if isinstance(state, torch.Tensor):
            x = state
        else:
            x = torch.as_tensor(
                state, dtype=torch.float32, device=self.hidden.weight.device
            )

        x = F.relu(self.hidden(x))
        return self.output(x)


In [None]:
class ConvDQN(nn.Module):
    """Q-network for single-channel square images.

    Architecture: 3x3 conv (128 filters) + ReLU -> global max-pool ->
    Linear(128, 64) + ReLU -> Linear(64, out_dim).

    Args:
        inp_size: side length of the square input image (H == W == inp_size).
        out_dim: number of discrete actions.
    """

    def __init__(self, inp_size, out_dim):
        super().__init__()
        kernel_size = 3
        # Spatial size of an unpadded, stride-1 convolution's output
        conv_op_dim = inp_size - kernel_size + 1

        self.conv = nn.Conv2d(1, 128, kernel_size)
        # Pooling window covers the whole feature map -> (N, 128, 1, 1)
        self.maxpool = nn.MaxPool2d(conv_op_dim)

        self.hidden = nn.Linear(128, 64)
        self.output = nn.Linear(64, out_dim)

        # Constructor args, so a trainer can build an identical clone
        self.inp_args = [inp_size, out_dim]

    def forward(self, state):
        """Return Q-values of shape ``(N, out_dim)``.

        ``state`` holds raw pixel values (scaled by 1/255 here) with shape
        ``(H, W)``, ``(N, H, W)`` or ``(N, 1, H, W)``; non-tensors are converted
        with ``torch.as_tensor`` on the network's device.
        """
        if isinstance(state, torch.Tensor):
            x = state
        else:
            x = torch.as_tensor(state, device=self.conv.weight.device)

        # Scale pixel values to [0, 1] in the conv weights' dtype (uint8/float64 safe)
        x = x.to(self.conv.weight.dtype) / 255

        # Add the missing batch and/or channel dims
        if x.dim() == 2:
            x = x.unsqueeze(0)
        if x.dim() == 3:
            x = x.unsqueeze(1)

        x = F.relu(self.conv(x))
        x = self.maxpool(x).flatten(start_dim=1)

        # Non-linearity after the hidden layer (before it is a no-op after ReLU+maxpool)
        x = F.relu(self.hidden(x))

        return self.output(x)


## Define the trainer

In [None]:
class QLearningTrainer:
    """Deep Q-learning (DQN) trainer with experience replay and a soft-updated target network.

    Args:
        env: Gymnasium-style environment. ``reset(seed=...)`` must return
            ``(obs, info)`` and ``step(int)`` the 5-tuple
            ``(obs, reward, terminated, truncated, info)``; observations must be
            convertible with ``torch.as_tensor``.
        policy_net: Q-network to train. It must expose ``inp_args`` (its
            constructor arguments) so an identical target network can be built.
    """

    def __init__(self, env: "Env", policy_net: nn.Module):
        self.env = env
        self.memory = ReplayMemory(int(1e5))

        # Use the GPU when one is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.policy_net = policy_net.to(self.device)
        # Clone the policy net: same architecture, same initial weights
        self.target_net = type(policy_net)(*policy_net.inp_args).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Built on the first optimisation step and reused afterwards so that
        # AdamW's running moment estimates persist between steps.
        self._optimizer = None

        # Cumulative reward of each episode
        self.reward_history: List[float] = list()
        # Number of time steps taken in each episode
        self.episode_durations: Union[List[int], List[float]] = list()

    # ------------------------------------------------------------------ policies

    def ep_greedy_policy(
        self,
        state,
        *,
        random_num: Union[float, None] = None,
        epsilon: float
    ):
        """Epsilon-greedy action selection.

        Args:
            state: state tensor with a leading batch dim of size 1.
            random_num: uniform draw in [0, 1); drawn with ``random.random()``
                when omitted.
            epsilon: probability of taking a uniformly random action.

        Returns:
            ``(1, 1)`` long tensor holding the chosen action.
        """
        if random_num is None:
            random_num = random.random()

        if random_num > epsilon:
            return self.greedy_policy(state)

        return torch.tensor(
            [[self.env.action_space.sample()]],
            device=self.device,
            dtype=torch.long
        )

    def greedy_policy(self, state):
        """Return the action with the highest Q-value as a ``(1, 1)`` long tensor.

        Runs under ``torch.no_grad()``: acting needs no autograd graph.
        """
        with torch.no_grad():
            return self.policy_net(state).argmax(1).reshape(1, 1)

    # ------------------------------------------------------------------ learning

    def _get_optimizer(self, lr: float):
        """Return the persistent AdamW optimizer, creating it on first use.

        On later calls the learning rate of every param group is set to ``lr``.
        """
        optimizer = getattr(self, '_optimizer', None)
        if optimizer is None:
            optimizer = optim.AdamW(
                self.policy_net.parameters(),
                lr=lr,
                amsgrad=True
            )
            self._optimizer = optimizer
        else:
            for group in optimizer.param_groups:
                group['lr'] = lr
        return optimizer

    def _optimize(self, lr: float, gamma: float, batch_size: int):
        """Run one DQN gradient step on a random minibatch from replay memory.

        Does nothing until the memory holds at least ``batch_size`` transitions.

        Args:
            lr: learning rate of the (persistent) AdamW optimizer.
            gamma: discount factor.
            batch_size: number of transitions sampled per step.
        """
        if len(self.memory) < batch_size:
            return

        optimizer = self._get_optimizer(lr)

        transitions = self.memory.sample(batch_size)
        # Transpose the batch: each item unpacks as (state, action, next_state, reward)
        states, actions, next_states, rewards = zip(*transitions)

        # Mask of transitions whose next state is non-terminal
        non_final_mask = torch.tensor(
            [s is not None for s in next_states],
            device=self.device, dtype=torch.bool
        )
        non_final_next_states = [s for s in next_states if s is not None]

        state_batch = torch.cat(states)
        action_batch = torch.cat(actions)
        reward_batch = torch.cat(rewards)

        # Policy network's Q(s, a) for the actions that were actually taken
        st_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # Target network's max_a' Q(s', a'); terminal next states contribute 0
        nx_st_action_values = torch.zeros(len(transitions), device=self.device)
        if non_final_next_states:  # torch.cat([]) would raise on an all-terminal batch
            with torch.no_grad():
                nx_st_action_values[non_final_mask] = \
                    self.target_net(torch.cat(non_final_next_states)).max(1).values

        target_st_action_values = (nx_st_action_values * gamma) + reward_batch

        # Huber loss: behaves like MSE for small errors and like MAE for large
        # ones, which is more stable than plain MSE.
        criterion = nn.SmoothL1Loss()
        loss = criterion(st_action_values, target_st_action_values.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        optimizer.step()

    def _soft_update(self, tau: float):
        """Soft-update the target network: θ′ ← τ θ + (1 − τ) θ′."""
        target_net_state_dict = self.target_net.state_dict()
        policy_net_state_dict = self.policy_net.state_dict()

        for key, value in policy_net_state_dict.items():
            target_net_state_dict[key] = value * tau + target_net_state_dict[key] * (1 - tau)

        self.target_net.load_state_dict(target_net_state_dict)

    # ------------------------------------------------------------------ rollouts

    def _to_state(self, obs):
        """Convert an observation to a float32 tensor on ``self.device`` with a batch dim."""
        return torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)

    def _run_episode(self, select_action, seed=None, on_transition=None):
        """Roll out one episode.

        Args:
            select_action: callable mapping a state tensor to a ``(1, 1)`` action tensor.
            seed: forwarded to ``env.reset``.
            on_transition: optional callable invoked after every step with
                ``(state, action, next_state, reward)``; ``next_state`` is None
                when the episode terminated and ``reward`` is a float32 tensor
                of shape ``(1,)``.

        Returns:
            ``(total_reward, num_steps)``.
        """
        obs, _ = self.env.reset(seed=seed)
        state = self._to_state(obs)
        total_reward = 0.0

        for step in count(1):
            action = select_action(state)
            obs, reward, terminated, truncated, _ = self.env.step(action.item())
            total_reward += float(reward)

            # Truncated (not terminated) episodes still bootstrap from next_state
            next_state = None if terminated else self._to_state(obs)

            if on_transition is not None:
                reward_t = torch.tensor([reward], dtype=torch.float32, device=self.device)
                on_transition(state, action, next_state, reward_t)

            state = next_state

            if terminated or truncated:
                return total_reward, step

    def _record_episode(self, total_reward: float, num_steps: int, plot: bool):
        """Append an episode's stats to the histories and optionally re-plot them."""
        self.episode_durations.append(num_steps)
        self.reward_history.append(total_reward)
        if plot:
            self.plot_durations()
            self.plot_rewards()

    # ------------------------------------------------------------------ plotting

    def plot_durations(self):
        """Plot the length (in steps) of every recorded episode."""
        self._plot_history(self.episode_durations, 'Duration (steps)', fig_num=1)

    def plot_rewards(self):
        """Plot the total reward of every recorded episode."""
        self._plot_history(self.reward_history, 'Total reward', fig_num=2)

    @staticmethod
    def _plot_history(values, ylabel, fig_num):
        """Redraw figure ``fig_num`` with ``values`` against the episode index."""
        fig = plt.figure(fig_num)
        fig.clf()
        ax = fig.add_subplot(111)
        ax.plot(values)
        ax.set(xlabel='Episode', ylabel=ylabel)
        # Ask interactive backends to refresh the figure
        fig.canvas.draw_idle()

    # ------------------------------------------------------------------ public API

    def train(
        self,
        *,
        num_episodes: int = 600,
        lr: float = 1e-4,
        gamma: float = .99,
        ep_min: float = .08,
        ep_max: float = 0.95,
        decay_rate: float = .0008,
        tau: float = 0.005,
        batch_size: int = 16,
        quiet: bool = False,
        seed: Union[int, None] = 88
    ) -> List[float]:
        """Train the agent.

        Args:
            num_episodes (int, optional): Number of episodes. Defaults to 600.
            lr (float, optional): The learning rate for the optimizer. Defaults to 1e-4.
            gamma (float, optional): The discount factor. Defaults to .99.
            ep_min (float, optional): Min exploration probability. Defaults to .08.
            ep_max (float, optional): Max exploration probability. Defaults to .95.
            decay_rate (float, optional): Epsilon decay rate; epsilon is
                ``ep_min + (ep_max - ep_min) * exp(-decay_rate * episode)``.
                Defaults to .0008.
            tau (float, optional): The update rate of the target network. Defaults to .005.
            batch_size (int, optional): Replay minibatch size. Defaults to 16.
            quiet (bool, optional): Hide the progress bar and per-episode plots.
                Defaults to False.
            seed (int | None, optional): Seed for the first ``env.reset`` only
                (later resets continue the env's RNG). Defaults to 88.

        Returns:
            List[float]: Total reward earned in every episode of this call
            (also appended to ``self.reward_history``).
        """
        episode_rewards: List[float] = []

        for episode in tqdm(range(num_episodes), desc='Train Episode', disable=quiet):
            epsilon = ep_min + (ep_max - ep_min) * math.exp(-decay_rate * episode)

            def select_action(state, epsilon=epsilon):
                return self.ep_greedy_policy(
                    state, random_num=random.random(), epsilon=epsilon
                )

            def learn(state, action, next_state, reward):
                # Store the transition, take a gradient step, then track it with the target net
                self.memory.push(state, action, next_state, reward)
                self._optimize(lr, gamma, batch_size)
                self._soft_update(tau)

            total_reward, num_steps = self._run_episode(
                select_action,
                seed=seed if episode == 0 else None,
                on_transition=learn,
            )
            self._record_episode(total_reward, num_steps, plot=not quiet)
            episode_rewards.append(total_reward)

        return episode_rewards

    def evaluate(
        self,
        num_episodes: int = 10**2,
        quiet: bool = False,
        seed: Union[int, None] = 88
    ) -> List[float]:
        """Evaluate the greedy policy.

        Episodes are appended to the same ``reward_history`` / ``episode_durations``
        used by ``train``.

        Args:
            num_episodes (int): Defaults to 10**2.
            quiet (bool): Hide the progress bar and per-episode plots. Defaults to False.
            seed (int | None): Seed for the first ``env.reset`` only. Defaults to 88.

        Returns:
            List[float]: Total reward earned in every episode of this call.
        """
        episode_rewards: List[float] = []

        for episode in tqdm(range(num_episodes), desc='Eval Episode: ', disable=quiet):
            total_reward, num_steps = self._run_episode(
                self.greedy_policy,
                seed=seed if episode == 0 else None,
            )
            self._record_episode(total_reward, num_steps, plot=not quiet)
            episode_rewards.append(total_reward)

        return episode_rewards


## Define the environment wrappers

In [6]:
from a1_env import DeterministicGridEnvironment

class A1EnvWrapper(DeterministicGridEnvironment):
    """DeterministicGridEnvironment that emits flat float32 tensors as observations.

    Each observation dict is encoded as the one-hot flat index of
    ``obs['agent']`` (length ``self.size ** 2``) concatenated with
    ``obs['reward']`` (the "reward situation") flattened to 1-D.
    """

    def _encode_obs(self, obs):
        """Encode a raw observation dict as a 1-D float32 tensor.

        Assumes ``obs['agent']`` is a (row, col) pair on a square
        ``self.size x self.size`` grid — TODO confirm against a1_env.
        """
        row, col = obs['agent'][0], obs['agent'][1]
        # Row-major flat index: curr_row * num_cols + curr_col.
        # Explicit int64 because F.one_hot only accepts int64 index tensors.
        flat_idx = torch.tensor(int(row) * self.size + int(col), dtype=torch.long)
        position = F.one_hot(flat_idx, num_classes=self.size ** 2).to(torch.float32)

        # The reward situation, cast to match the one-hot dtype before concatenating
        situation = torch.as_tensor(obs['reward'], dtype=torch.float32).flatten()

        return torch.cat((position, situation))

    def reset(self, seed=None, options=None):
        # Keyword args: Gymnasium's Env.reset takes seed/options as keyword-only
        obs, info = super().reset(seed=seed, options=options)
        return self._encode_obs(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        return self._encode_obs(obs), reward, terminated, truncated, info
