In [47]:
import gymnasium as gym
import torch
import torch.nn as nn
from skimage.color import rgb2gray
from skimage.transform import resize
import numpy as np
from collections import deque, namedtuple
import random

In [48]:
# Create the Pong environment used for training and read the size of its
# discrete action space.
env = gym.make("Pong-v4")
n_actions = env.action_space.n

In [49]:
# Run on Apple's Metal (MPS) backend when it is actually usable at runtime,
# otherwise fall back to the CPU. `torch.has_mps` is deprecated and only says
# whether PyTorch was *built* with MPS support; `is_available()` also checks
# that the current machine can use it.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [50]:
# Draw one random observation from the environment's observation space.
# NOTE(review): `obs_sample` is not referenced by any other cell shown here —
# presumably a leftover shape check; confirm before removing.
obs_sample = env.observation_space.sample()

In [51]:
def convert_observation(observation):
    """Turn a raw RGB frame into an 84x84 float32 grayscale tensor on `device`.

    The frame is expected to be the environment's RGB numpy observation
    (210x160 for Atari Pong — TODO confirm against the env in use).
    """
    grayscale = rgb2gray(observation)
    # Shrink to the 84x84 resolution the network is sized for.
    downsampled = resize(grayscale, (84, 84), mode="constant")
    as_float32 = downsampled.astype(np.float32)
    frame = torch.from_numpy(as_float32)
    return frame.to(device)

In [52]:
# One replay-memory record: the state stack before acting, the action taken,
# the reward received, and the state stack afterwards.
Transition = namedtuple(
    typename="Transition",
    field_names=["state_stack", "action", "reward", "next_state_stack"],
)


class ReplayMemory:
    """Fixed-capacity buffer of `Transition` records.

    Backed by a bounded deque, so once `capacity` records are stored each new
    one evicts the oldest.
    """

    def __init__(self, capacity: int) -> None:
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Build a `Transition` from the positional arguments and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored records chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return self.memory.__len__()

In [53]:
class PongModel(nn.Module):
    """Convolutional Q-network mapping a 4-channel frame stack to action values.

    Three conv+ReLU stages feed a two-layer fully connected head that emits
    one value per action. The hard-coded 64 * 7 * 7 flatten size matches
    84x84 input frames.
    """

    def __init__(self, n_actions: int):
        super().__init__()

        feature_extractor = [
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        self.conv_layers = nn.Sequential(*feature_extractor)

        # Collapse (channels, height, width) into one feature vector per sample.
        self.flatten = nn.Flatten(start_dim=1)

        flat_features = 64 * 7 * 7
        head = [
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        ]
        self.linear_layers = nn.Sequential(*head)

    def forward(self, x):
        """Return the action values for a batch of stacked frames."""
        features = self.conv_layers(x)
        return self.linear_layers(self.flatten(features))

In [54]:
BATCH_SIZE = 128  # transitions sampled from replay memory per optimisation step
GAMMA = 0.99  # discount factor applied to the bootstrapped next-state value
TAU = 0.05  # NOTE(review): not referenced in the cells shown — presumably a soft-update rate; confirm or remove

In [55]:
class PongAgent:
    """Epsilon-greedy DQN agent for Pong with a replay memory and target network.

    The agent keeps a rolling stack of the last four preprocessed frames as
    its state, acts epsilon-greedily with the online network (`main`), stores
    transitions in a replay memory and fits `main` against bootstrapped
    targets from a periodically synchronised copy (`target`).
    """

    # Number of environment steps between hard copies of `main` into `target`.
    TARGET_SYNC_INTERVAL = 1000

    def __init__(self, env: gym.Env, epsilon: float) -> None:
        """Build both networks, the optimiser, the frame stack and the memory.

        Args:
            env: Environment with a discrete action space.
            epsilon: Probability of taking a uniformly random action.
        """
        self.env = env
        self.main = PongModel(env.action_space.n).to(device)  # type: ignore
        self.target = PongModel(env.action_space.n).to(device)  # type: ignore
        # Start the target network as an exact copy of the online network;
        # otherwise early targets come from unrelated random weights.
        self.target.load_state_dict(self.main.state_dict())

        # NOTE(review): lr=0.01 is kept from the original; it is much larger
        # than the values commonly used for DQN (around 1e-4) — consider tuning.
        self.optimiser = torch.optim.Adam(params=self.main.parameters(), lr=0.01)

        self.epsilon = epsilon

        self.state_stack = deque(maxlen=4)
        self.memory = ReplayMemory(capacity=40_000)

        # Environment steps taken across all `train` calls; drives target syncs.
        self.total_steps = 0

    def get_action(self, observation) -> torch.Tensor:
        """Choose an action epsilon-greedily.

        Args:
            observation: Stacked frames, either (4, H, W) or already batched
                as (1, 4, H, W).

        Returns:
            A long tensor of shape (1, 1) holding the action index, so that
            stored actions can be concatenated into a (batch, 1) index tensor.
        """
        # The deque never holds more than `maxlen` frames, so the stack is
        # "not ready" only while it is strictly shorter than that.
        stack_not_full = len(self.state_stack) < self.state_stack.maxlen
        if stack_not_full or random.random() < self.epsilon:
            return torch.tensor(
                [[self.env.action_space.sample()]], dtype=torch.long, device=device
            )

        with torch.no_grad():
            if observation.dim() == 3:
                # Add the batch dimension the network's flatten layer expects.
                observation = observation.unsqueeze(0)
            return self.main(observation).argmax(dim=1, keepdim=True)

    def optimise(self):
        """Run one gradient step on a random minibatch from replay memory.

        Does nothing until the memory holds at least BATCH_SIZE transitions.
        Transitions whose `next_state_stack` is None are terminal and get a
        next-state value of zero.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        transitions = self.memory.sample(BATCH_SIZE)

        # This converts batch-array of Transitions to Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        state_batch = torch.stack(batch.state_stack)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Terminal transitions carry no next state; they must not bootstrap.
        non_final_mask = torch.tensor(
            [nxt is not None for nxt in batch.next_state_stack],
            dtype=torch.bool,
            device=device,
        )
        non_final_next_states = [
            nxt for nxt in batch.next_state_stack if nxt is not None
        ]

        # Q(s_t, a): evaluate every action, then pick the column of the action
        # that was actually taken in each sampled transition.
        state_action_values = self.main(state_batch).gather(1, action_batch)

        # max_a' Q_target(s_{t+1}, a'), left at zero for terminal transitions.
        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        if non_final_next_states:
            with torch.no_grad():
                next_state_values[non_final_mask] = self.target(
                    torch.stack(non_final_next_states)
                ).max(1)[0]

        expected_state_action_values = (next_state_values * GAMMA) + reward_batch

        # Compute Huber loss
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimiser.zero_grad()
        loss.backward()
        # In-place gradient clipping
        nn.utils.clip_grad.clip_grad_value_(self.main.parameters(), 100)
        self.optimiser.step()

    def train(self, n_episodes: int):
        """Play `n_episodes` episodes, learning after every environment step.

        Prints the summed reward of each episode when it ends.
        """
        for _ in range(n_episodes):
            observation, _ = self.env.reset()
            # Begin every episode with a full stack of its own first frame, so
            # no frames leak in from the previous episode and no warm-up loop
            # is needed to fill the stack.
            self._reset_state_stack(observation)
            stacked_state = self.stack_observations()
            done = False
            summed_reward = 0
            while not done:
                # Get next action
                action = self.get_action(stacked_state)

                # Apply action to env (as a plain int, not a tensor)
                next_observation, reward, terminated, truncated, info = self.env.step(
                    action.item()
                )
                summed_reward += reward
                done = terminated or truncated
                # Push new observation to state stack
                self.state_stack.append(convert_observation(next_observation))
                # Create the next state stack
                next_stacked_state = self.stack_observations()

                # Store in replay memory. Only true termination drops the next
                # state; a truncated episode still has a valid one.
                self.memory.push(
                    stacked_state,
                    action,
                    torch.tensor([reward], dtype=torch.float32, device=device),
                    None if terminated else next_stacked_state,
                )
                # Advance the current state so the next stored transition
                # starts from where this one ended.
                stacked_state = next_stacked_state

                self.optimise()

                self.total_steps += 1
                if self.total_steps % self.TARGET_SYNC_INTERVAL == 0:
                    self.target.load_state_dict(self.main.state_dict())
            print(summed_reward)

    def _reset_state_stack(self, observation):
        """Fill the frame stack with copies of an episode's first frame."""
        first_frame = convert_observation(observation)
        self.state_stack.clear()
        self.state_stack.extend([first_frame] * self.state_stack.maxlen)

    def stack_observations(self):
        """Return the frames in the stack as one tensor, oldest frame first."""
        return torch.stack(list(self.state_stack))

In [57]:
# Create the agent with epsilon = 0.5 and run 100 training episodes;
# train() prints the summed reward of each finished episode.
agent = PongAgent(env, 0.5)
agent.train(100)

-20.0
-21.0
-20.0


KeyboardInterrupt: 

In [None]:
# Set epsilon to 0 and run 10 more episodes.
# Note: train() still stores transitions and calls optimise() on every step,
# so this keeps training rather than being a pure evaluation run.
agent.epsilon = 0

agent.train(10)

-21.0
-20.0
-21.0
-20.0
-21.0
-21.0
-20.0
-20.0
-21.0
-21.0


In [59]:
# Watch one episode in a rendering window. The render-mode environment is a
# separate instance from `env`, so it is closed here once the episode ends
# (even on interrupt) and the agent is pointed back at the original `env`.
render_env = gym.make("Pong-v4", render_mode="human")
agent.env = render_env
try:
    agent.train(1)
finally:
    render_env.close()
    agent.env = env

-21.0


In [None]:
# Release the resources held by the environment created at the top of the notebook.
env.close()