In [None]:
import os
import timeit
from random import random

import cv2
import numpy as np
import torch
import torch.nn as nn
from gymnasium import make, Env
from torch import optim

CONSTANTS

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: ReplayMemory and Agent.select_action always build 4-frame stacks, so only 4 is currently supported.
CHANNELS = 4  # 4 if using stacked frames, 1 if not
WINDOW_SIZE = 80  # We are using squared preprocessed images for the network, this is the size of the image in pixels.

NUM_ACTIONS = 3  # Number of actions the agent can take
INPUT_SHAPE = (CHANNELS, WINDOW_SIZE, WINDOW_SIZE)  # PyTorch uses (channels, height, width) format

# TODO: Fine tuning
LEARNING_RATE = 0.0005
MIN_MEMORY_CAPACITY = 100  # This should be at least BATCH_SIZE (enforced below)
MEMORY_CAPACITY = 100_000
NUM_EPISODES = 200
BATCH_SIZE = 32
UPDATE_FREQUENCY = 1

# These might be good
GAMMA = 0.99
EPSILON_MAX = 1.0
EPSILON_MIN = 0.01
EPSILON_DECAY = 0.995

# Checkpoint locations read by Agent.save / Agent.load.
MODELS_PATH = "models"
POLICY_NET_PATH = os.path.join(MODELS_PATH, "policy_net.pth")
TARGET_NET_PATH = os.path.join(MODELS_PATH, "target_net.pth")

# Fail fast on inconsistent settings instead of deep inside training.
assert MIN_MEMORY_CAPACITY >= BATCH_SIZE, "MIN_MEMORY_CAPACITY must be at least BATCH_SIZE"
assert 0.0 <= EPSILON_MIN <= EPSILON_MAX <= 1.0, "epsilon bounds must satisfy 0 <= MIN <= MAX <= 1"


def check_if_dirs_exist(dirs):
    """
    Make sure every directory in `dirs` exists, creating missing ones (including parents).
    :param dirs: iterable of directory paths
    """
    for directory in dirs:
        os.makedirs(directory, exist_ok=True)

PREPROCESSING

In [None]:
def crop(frame: np.ndarray) -> np.ndarray:
    """
    Cut the frame down to the region of interest.
    Rows 30..179 and columns 8..151 are kept; any trailing (e.g. colour) axis is left untouched.
    :param frame: the frame(state) image
    :return: a view onto the cropped region
    """
    top, bottom = 30, 180
    left, right = 8, 152
    # A rounder alternative worth comparing is columns 10..149 ([30:180, 10:150]).
    return frame[top:bottom, left:right]


def resize(frame: np.ndarray) -> np.ndarray:
    """
    Scale the frame to a WINDOW_SIZE x WINDOW_SIZE square using cv2's default interpolation.
    :param frame: the frame(state) image
    :return: the resized image
    """
    target_size = (WINDOW_SIZE, WINDOW_SIZE)  # cv2 takes (width, height); both sides are equal here
    return cv2.resize(frame, target_size)


def rgb2gray(rgb: np.ndarray) -> np.ndarray:
    """
    Collapse an RGB image into a single grey channel via OpenCV.

    :param rgb: image array in RGB channel order.
    :return: the single-channel grey image.
    """
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)


def format2pytorch(frame: np.ndarray) -> np.ndarray:
    """
    Prepend a length-1 channel axis so a single-channel image matches PyTorch's (C, H, W) layout.
    (H, W) -> (1, H, W)
    :param frame: the frame(state) image
    :return: a view of `frame` with the extra leading axis
    """
    channel_first = frame[None, :, :]
    return channel_first


def normalize(frame: np.ndarray) -> np.ndarray:
    """
    Convert the frame to float32 and divide every value by 255.
    For uint8 input (assumed to be the usual case) this maps pixels into [0, 1].
    The input array is not modified; a new array is returned.
    :param frame: the frame(state) image
    :return: the normalized image
    """
    as_float = frame.astype(np.float32)
    return as_float / 255.0


def preprocess(frame: np.ndarray) -> np.ndarray:
    """
    Run the full preprocessing pipeline on a raw frame: crop -> resize -> grayscale -> normalize.
    No channel axis is added (format2pytorch is intentionally not part of the pipeline).
    :param frame: the frame(state) image
    :return: the preprocessed image
    """
    for transform in (crop, resize, rgb2gray, normalize):
        frame = transform(frame)
    return frame

REPLAY MEMORY

In [None]:
class ReplayMemory:
    """
    Fixed-capacity ring buffer of single-frame transitions.

    Each entry holds one preprocessed frame (`state`), the action taken, the reward, the done flag and the
    frame that followed (`next_state`). Network inputs of STACK_SIZE consecutive frames are rebuilt at
    sampling time. Once `capacity` entries exist, each new transition overwrites the oldest one.
    """

    # Frames per stacked network input; must agree with CHANNELS / INPUT_SHAPE[0].
    STACK_SIZE = 4

    def __init__(self, capacity=None):
        """
        :param capacity: maximum number of stored transitions; defaults to MEMORY_CAPACITY.
        """
        self.capacity = MEMORY_CAPACITY if capacity is None else capacity
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.next_states = []
        # Slot the next transition is written to. Once the buffer is full this is also the oldest entry.
        self.index: int = 0

    def store(self, state, action, reward, done, next_state):
        """Append a transition, or overwrite the oldest one once the buffer is full."""
        if len(self.states) < self.capacity:
            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.dones.append(done)
            self.next_states.append(next_state)
        else:
            self.states[self.index] = state
            self.actions[self.index] = action
            self.rewards[self.index] = reward
            self.dones[self.index] = done
            self.next_states[self.index] = next_state

        self.index = (self.index + 1) % self.capacity

    def get_last3_frames(self):
        """
        Return the states of the three most recently stored transitions, oldest first.

        Uses the write pointer instead of the end of the lists so the result stays correct after the ring
        buffer has wrapped around.
        :raises IndexError: if fewer than three transitions are stored.
        """
        if len(self.states) < 3:
            raise IndexError("get_last3_frames needs at least 3 stored transitions")
        # self.index is one past the newest entry; negative indices wrap to the end of the lists.
        return [self.states[self.index - offset] for offset in (3, 2, 1)]

    def _index_valid(self, index):
        """Whether the frames index-STACK_SIZE+1 .. index are consecutive frames of a single episode."""
        first = index - self.STACK_SIZE + 1
        # A done flag on any *earlier* frame means an episode reset happens inside the stack. The
        # transition at `index` itself may be terminal: those samples are what teach end-of-episode values.
        if any(self.dones[i] for i in range(first, index)):
            return False
        # After wrapping, the oldest entry (slot self.index) directly follows the newest one in the lists;
        # a stack crossing that seam would join frames that are far apart in time.
        if len(self.states) == self.capacity and (index - self.index) % self.capacity < self.STACK_SIZE - 1:
            return False
        return True

    def _stacked_transition(self, index):
        """Build (state_stack, action, reward, done, next_state_stack) for the transition at `index`."""
        window = range(index - self.STACK_SIZE + 1, index + 1)
        state_stack = [self.states[i] for i in window]
        # next_states[i] is the frame after states[i], so this window is the state stack shifted by one step,
        # ending on the successor of the transition's own state.
        next_state_stack = [self.next_states[i] for i in window]
        return state_stack, self.actions[index], self.rewards[index], self.dones[index], next_state_stack

    def sample(self, batch_size=None):
        """
        Draw a random batch of stacked transitions and move it to DEVICE.

        :param batch_size: number of transitions; defaults to BATCH_SIZE.
        :return: (states, actions, rewards, dones, next_states). states/next_states have shape
                 (B, STACK_SIZE, ...) where ... is the shape of one stored frame; actions (int64),
                 rewards (float32) and dones (bool) have shape (B, 1).
        :raises ValueError: if fewer than STACK_SIZE transitions are stored.
        """
        if batch_size is None:
            batch_size = BATCH_SIZE
        if len(self) < self.STACK_SIZE:
            raise ValueError(f"need at least {self.STACK_SIZE} transitions to sample, have {len(self)}")

        batch = []
        while len(batch) < batch_size:
            # Indices below STACK_SIZE - 1 lack a full history in the lists; randint's upper bound is exclusive.
            index = np.random.randint(self.STACK_SIZE - 1, len(self))
            if self._index_valid(index):
                batch.append(self._stacked_transition(index))

        states, actions, rewards, dones, next_states = zip(*batch)

        states = torch.from_numpy(np.array(states)).float().to(DEVICE)
        # gather() requires int64 indices; np.array of Python ints is int32 on some platforms.
        actions = torch.from_numpy(np.array(actions)).long().to(DEVICE).reshape((-1, 1))
        rewards = torch.from_numpy(np.array(rewards)).float().to(DEVICE).reshape((-1, 1))
        dones = torch.from_numpy(np.array(dones)).bool().to(DEVICE).reshape((-1, 1))
        next_states = torch.from_numpy(np.array(next_states)).float().to(DEVICE)

        return states, actions, rewards, dones, next_states

    def __len__(self):
        return len(self.states)

NET

In [None]:
class DQN(nn.Module):
    """
    Convolutional Q-network: three conv layers, a 512-unit hidden layer and one linear output per action.
    The network also owns its optimizer (Adam) and loss function (MSE).
    """

    def __init__(self, input_shape=None, num_actions=None, learning_rate=None):
        """
        :param input_shape: (channels, height, width) of a single input; defaults to INPUT_SHAPE.
        :param num_actions: number of Q-values produced; defaults to NUM_ACTIONS.
        :param learning_rate: Adam learning rate; defaults to LEARNING_RATE.
        """
        super(DQN, self).__init__()
        if input_shape is None:
            input_shape = INPUT_SHAPE
        if num_actions is None:
            num_actions = NUM_ACTIONS
        if learning_rate is None:
            learning_rate = LEARNING_RATE

        self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()

        # Derive the flattened conv output size from the input shape instead of hard-coding 64 * 6 * 6,
        # which is only correct for 80x80 inputs.
        self.fc = nn.Linear(self._feature_size(input_shape), 512)
        self.output = nn.Linear(512, num_actions)

        # TODO: Maybe use RMSProp? (optim.RMSprop(self.parameters(), lr=learning_rate))
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.loss = nn.MSELoss()

    def _feature_size(self, input_shape):
        """Number of values the conv stack produces for one input of shape `input_shape`."""
        with torch.no_grad():
            return self._forward_features(torch.zeros(1, *input_shape)).numel()

    def _forward_features(self, x):
        """Apply the three conv layers, each followed by ReLU."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        return x

    def forward(self, x):
        """
        :param x: batch of stacked frames, (batch, channels, height, width)
        :return: Q-values, (batch, num_actions)
        """
        x = self._forward_features(x)
        x = self.flatten(x)
        x = self.relu(self.fc(x))
        x = self.output(x)
        return x


AGENT

In [None]:
class Agent:
    """
    Epsilon-greedy DQN agent: a policy network, a target network synced on demand, and a replay memory.
    """

    def __init__(self, action_space):
        """
        :param action_space: the environment's action space; only its ``sample()`` method is used here.
        """
        self.action_space = action_space
        self.epsilon: float = EPSILON_MAX
        self.replay_memory: ReplayMemory = ReplayMemory()

        self.policy_net: DQN = DQN().to(DEVICE)
        self.target_net: DQN = DQN().to(DEVICE)
        self.update_target_net()

        self.policy_net.train()
        self.target_net.eval()

        self.total_loss = 0.0

    def update_target_net(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def select_action(self, state):
        """
        Pick an action epsilon-greedily.

        The greedy branch stacks the three most recently stored frames with `state` as the network input.
        NOTE(review): right after a reset those three frames still come from the previous episode.
        :param state: the current preprocessed frame
        :return: the chosen action
        """
        # Also explore while the memory cannot yet supply three previous frames (avoids an IndexError).
        # random() is evaluated first, so the RNG stream is the same as before.
        if random() < self.epsilon or len(self.replay_memory) < 3:
            return self.action_space.sample()

        last3_frames = self.replay_memory.get_last3_frames()
        stacked_frames = last3_frames + [state]
        stacked_frames = torch.from_numpy(np.array(stacked_frames)).float().unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            action = torch.argmax(self.policy_net(stacked_frames))

        return action.item()

    def decay_epsilon(self):
        """Multiply epsilon by EPSILON_DECAY, never going below EPSILON_MIN."""
        self.epsilon = max(self.epsilon * EPSILON_DECAY, EPSILON_MIN)

    def set_loss(self, loss):
        """Overwrite the running loss accumulator (used to reset it each episode)."""
        self.total_loss = loss

    def please_learn(self):
        """
        Do one gradient step on a replay batch toward the target r + GAMMA * max_a' Q_target(s', a'),
        dropping the bootstrap term for terminal transitions. Adds the batch loss to ``total_loss``.
        Does nothing while the memory holds fewer than BATCH_SIZE transitions.
        """
        if len(self.replay_memory) < BATCH_SIZE:
            return

        states, actions, rewards, dones, next_states = self.replay_memory.sample()

        predicted_qs = self.policy_net(states).gather(1, actions)

        # The target is a fixed regression label: compute it without autograd so no gradient flows into
        # (and accumulates on) the target network's parameters.
        with torch.no_grad():
            next_qs = torch.max(self.target_net(next_states), dim=1).values.reshape(-1, 1)
            next_qs = next_qs.masked_fill(dones.bool(), 0.0)
            target_qs = rewards + GAMMA * next_qs

        loss = self.policy_net.loss(predicted_qs, target_qs)
        self.policy_net.optimizer.zero_grad()
        loss.backward()
        self.policy_net.optimizer.step()

        self.total_loss += loss.item()

    def save(self):
        """
        Write both networks' weights to POLICY_NET_PATH / TARGET_NET_PATH, after calling
        check_if_dirs_exist on MODELS_PATH.
        NOTE(review): confirm those names are defined in the notebook before relying on this.
        """
        check_if_dirs_exist([MODELS_PATH])
        torch.save(self.policy_net.state_dict(), POLICY_NET_PATH)
        torch.save(self.target_net.state_dict(), TARGET_NET_PATH)

    def load(self):
        """
        Load both networks' weights. Tensors are mapped onto DEVICE so checkpoints saved on a GPU also
        load on a CPU-only machine.
        """
        self.policy_net.load_state_dict(torch.load(POLICY_NET_PATH, map_location=DEVICE))
        self.target_net.load_state_dict(torch.load(TARGET_NET_PATH, map_location=DEVICE))
        self.target_net.eval()


MAIN

In [None]:
def reset(env: Env):
    """
    Begin a new episode and return its first observation, already preprocessed.
    The info dict from env.reset() is ignored.
    """
    observation, _ = env.reset()
    return preprocess(observation)


def step(env: Env, action: int):
    """
    Advance the environment by one action.
    :return: (preprocessed next observation, reward, done, info); done is terminated or truncated.
    """
    observation, reward, terminated, truncated, info = env.step(action)
    processed = preprocess(observation)
    return processed, reward, terminated or truncated, info


def init_memory(env: Env, agent: Agent):
    """
    Warm up the replay memory by playing whole episodes with the agent's current policy until the memory
    holds at least max(MIN_MEMORY_CAPACITY, BATCH_SIZE) transitions. Episodes are always finished, so the
    memory can end up well above that threshold.
    """
    required = max(MIN_MEMORY_CAPACITY, BATCH_SIZE)
    while len(agent.replay_memory) < required:
        state = reset(env)
        while True:
            action = agent.select_action(state)
            next_state, reward, done, _info = step(env, action)
            agent.replay_memory.store(state, action, reward, done, next_state)
            state = next_state
            if done:
                break


def train(env: Env, agent: Agent):
    """
    Train `agent` on `env` for NUM_EPISODES episodes, printing per-episode statistics.

    The replay memory is warmed up first. During an episode the agent learns every UPDATE_FREQUENCY steps;
    after each episode the target network is synced and epsilon is decayed.
    :return: list of total rewards, one per episode (lets callers plot learning progress).
    """
    init_memory(env, agent)
    print(f"Memory initialized with {len(agent.replay_memory)} samples! The training shall begin! Let's rock!")

    reward_history = []
    # NOTE(review): tracked but not otherwise used yet (e.g. for checkpointing the best model).
    best_score = -np.inf

    for episode in range(NUM_EPISODES):
        start_time = timeit.default_timer()
        state = reset(env)
        done = False
        episode_reward = 0
        step_counter = 0
        learn_calls = 0
        agent.set_loss(0.0)

        while not done:
            action = agent.select_action(state)
            next_state, reward, done, info = step(env, action)
            agent.replay_memory.store(state, action, reward, done, next_state)
            if step_counter % UPDATE_FREQUENCY == 0:
                agent.please_learn()
                learn_calls += 1

            state = next_state
            episode_reward += reward
            step_counter += 1

        # We update the target net every episode, one episode has around 4k steps
        agent.update_target_net()

        agent.decay_epsilon()
        reward_history.append(episode_reward)

        current_avg_score = np.mean(reward_history[-20:])  # moving average over last 20 episodes

        print(
            f"Episode {episode + 1} | Reward: {episode_reward} | Avg Reward: {current_avg_score} | Epsilon: {agent.epsilon}"
        )
        # Average over learning calls, not env steps, so the figure stays meaningful for UPDATE_FREQUENCY > 1.
        print(f"Avg Loss: {agent.total_loss / max(learn_calls, 1)} | Steps: {step_counter}")
        print(f"Episode {episode + 1} took {timeit.default_timer() - start_time} seconds.")
        print("-" * 100)

        if current_avg_score > best_score:
            best_score = current_avg_score

    return reward_history


if __name__ == "__main__":
    env: Env = make("ALE/Skiing-v5")
    try:
        agent = Agent(action_space=env.action_space)
        train(env, agent)
    finally:
        # Release the environment's resources even if training fails or is interrupted.
        env.close()