### Load libraries

In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import tqdm
import random
from collections import deque, namedtuple
import numpy as np
import matplotlib.pyplot as plt

# Pick the compute device once for the whole notebook: a CUDA GPU if present,
# otherwise Apple-silicon MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

### Define the policy network

In [24]:
class QNet(nn.Module):
    """Three-layer fully connected network with ReLU on the two hidden layers.

    Args:
        n_states: size of the input feature vector.
        n_actions: number of outputs (one value per action).
        n_hidden: width of both hidden layers.
    """

    def __init__(self, n_states, n_actions, n_hidden=64):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.fc1 = nn.Linear(n_states, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_actions)

    def forward(self, x):
        """Return the raw (un-activated) output of the last layer for input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)
    

### Define the replay buffer

In [25]:
class ReplayBuffer():
    """Bounded FIFO store of transitions with uniform random mini-batch sampling.

    Args:
        n_actions: number of actions (stored on the instance, not used by the buffer itself).
        memory_size: maximum number of transitions kept; the oldest are dropped first.
        batch_size: number of transitions returned by ``sample``.
    """

    def __init__(self, n_actions, memory_size, batch_size):
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.memory = deque(maxlen=memory_size)
        self.batch_size = batch_size
        self.memory_size = memory_size
        self.n_actions = n_actions

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        self.memory.append(self.experience(state, action, reward, next_state, done))

    def sample(self):
        """Draw ``batch_size`` distinct transitions and return them as tensors on ``device``.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)``; each field is
            stacked row-wise, with actions cast to long and everything else to float.
        """
        batch = random.sample(self.memory, k=self.batch_size)

        def column(field):
            # One row per sampled transition for the requested field.
            return np.vstack([getattr(item, field) for item in batch if item is not None])

        states = torch.from_numpy(column("state")).float().to(device)
        actions = torch.from_numpy(column("action")).long().to(device)
        rewards = torch.from_numpy(column("reward")).float().to(device)
        next_states = torch.from_numpy(column("next_state")).float().to(device)
        # Booleans go through uint8 before becoming a 0/1 float mask.
        dones = torch.from_numpy(column("done").astype(np.uint8)).float().to(device)

        return (states, actions, rewards, next_states, dones)

### Define the DQN agent

In [26]:
class DQN():
    """DQN agent: an online Q-network, a soft-updated target network and a replay buffer.

    Args:
        n_states: size of the observation vector fed to the networks.
        n_actions: number of discrete actions.
        batch_size: number of transitions sampled per learning step.
        lr: learning rate of the Adam optimizer.
        gamma: discount factor applied to the bootstrapped target.
        memory_size: capacity of the replay buffer.
        tau: interpolation factor of the soft target update.
        learn_step: a learning step is attempted every ``learn_step`` stored transitions.
    """

    def __init__(self, n_states, n_actions, batch_size=64, lr=1e-4, gamma=0.99, memory_size=int(1e5), tau=1e-3, learn_step=5):
        self.n_states = n_states
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.lr = lr
        self.gamma = gamma
        self.memory_size = memory_size
        self.tau = tau
        self.learn_step = learn_step

        # model
        self.net_eval = QNet(n_states, n_actions).to(device)
        self.net_target = QNet(n_states, n_actions).to(device)
        # Start the target network as an exact copy of the online network.
        # Two independent random initialisations would make the first TD targets
        # come from an unrelated network, and with a small tau that mismatch
        # only fades slowly.
        self._sync_target()
        self.optimizer = optim.Adam(self.net_eval.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

        # memory
        self.memory = ReplayBuffer(n_actions, memory_size, batch_size)
        self.counter = 0 # Update cycle counter

    def _sync_target(self):
        """Copy every parameter of the online network into the target network."""
        self.net_target.load_state_dict(self.net_eval.state_dict())

    def act(self, state, epsilon):
        """Choose an action with an epsilon-greedy policy.

        Args:
            state: NumPy array holding a single observation.
            epsilon: probability of picking a uniformly random action.

        Returns:
            The chosen action index (a NumPy integer).
        """
        # Explore first: a random action does not need a forward pass.
        if random.random() < epsilon:
            return random.choice(np.arange(self.n_actions))

        state = torch.from_numpy(state).float().unsqueeze(0).to(device)

        self.net_eval.eval()
        with torch.no_grad():
            action_values = self.net_eval(state)
        self.net_eval.train()

        return np.argmax(action_values.cpu().numpy())

    def save_to_memory(self, state, action, reward, next_state, done):
        """Store one transition and run a learning step every ``learn_step`` calls.

        Learning is skipped until the buffer holds more than ``batch_size`` transitions.
        """
        self.memory.add(state, action, reward, next_state, done)

        self.counter += 1
        if self.counter % self.learn_step == 0 and len(self.memory) > self.batch_size:
            self.learn(self.memory.sample())

    def learn(self, experiences):
        """Run one gradient step on a batch, then soft-update the target network.

        Args:
            experiences: tuple ``(states, actions, rewards, next_states, dones)`` of
                tensors as produced by ``ReplayBuffer.sample``.
        """
        states, actions, rewards, next_states, dones = experiences

        # TD target: r + gamma * max_a' Q_target(s', a'), with the bootstrap term
        # zeroed where done == 1. No gradient flows through the target.
        with torch.no_grad():
            q_target = self.net_target(next_states).max(dim=1, keepdim=True)[0]
            y_j = rewards + self.gamma * (1 - dones) * q_target
        q_eval = self.net_eval(states).gather(1, actions)

        # loss backpropagation
        loss = self.criterion(q_eval, y_j)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # update target network
        self.soft_update()

    def soft_update(self):
        """Move the target network towards the online one: target <- tau * eval + (1 - tau) * target."""
        with torch.no_grad():
            for target_param, eval_param in zip(self.net_target.parameters(), self.net_eval.parameters()):
                target_param.copy_(self.tau * eval_param + (1.0 - self.tau) * target_param)

### Define the train and test functions

In [27]:
def train(env, agent, n_episodes=2000, max_steps=1000, eps_start=1.0, eps_end=0.1, eps_decay=0.995, target=200, save_path="./dqn-trained.h5"):
    """Train ``agent`` on ``env`` with an epsilon-greedy schedule and save its weights.

    Args:
        env: Gymnasium-style environment (``reset()`` returns ``(obs, info)``,
            ``step()`` returns ``(obs, reward, terminated, truncated, info)``).
        agent: object exposing ``act(state, epsilon)``, ``save_to_memory(...)`` and
            a ``net_eval`` module.
        n_episodes: maximum number of episodes to run.
        max_steps: maximum number of steps per episode.
        eps_start: initial exploration rate.
        eps_end: lower bound of the exploration rate.
        eps_decay: multiplicative decay applied to epsilon after every episode.
        target: training stops early once more than 100 episodes have been played
            and the mean score of the last 100 reaches this value.
        save_path: file the online network's ``state_dict`` is written to.

    Returns:
        List with the total reward of every episode played.
    """
    score_history = []
    epsilon = eps_start
    target_reached = False

    bar_format = '{l_bar}{bar:10}| {n:4}/{total_fmt} [{elapsed:>7}<{remaining:>7}, {rate_fmt}{postfix}]'
    progress_bar = tqdm.trange(n_episodes, unit="ep", bar_format=bar_format, ascii=True)

    for _episode in progress_bar:
        state, _ = env.reset()

        score = 0
        for _step in range(max_steps):
            action = agent.act(state, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Only a real termination is stored as `done`: a time-limit truncation
            # is not a terminal state, so the TD target should still bootstrap.
            agent.save_to_memory(state, action, reward, next_state, terminated)
            state = next_state
            score += reward
            # The episode is over in both cases; stepping on without a reset is invalid.
            if terminated or truncated:
                break
        score_history.append(score)
        score_avg = np.mean(score_history[-100:])
        epsilon = max(epsilon * eps_decay, eps_end)

        progress_bar.set_postfix_str(f"Score: {score: 7.2f}, 100 score avg: {score_avg: 7.2f}")
        progress_bar.update(0)

        # Early stopping
        if len(score_history) > 100 and score_avg >= target:
            target_reached = True
            break

    # Release the bar explicitly; an early break leaves it open otherwise.
    progress_bar.close()

    # An explicit flag reports correctly even when the target is hit on the very
    # last episode, and does not rely on a loop variable that is unbound when
    # n_episodes is 0.
    if target_reached:
        print("\nTarget score reached!")
    else:
        print("\nDone!")

    torch.save(agent.net_eval.state_dict(), save_path)

    return score_history

In [28]:
def test(env, agent, loop=3):
    for i in range(loop):
        state, _ = env.reset()
        for t in range(500):
            action = agent.act(state, 0)
            env.render()
            state, reward, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                break
    env.close()
            

### Define visualisation function

In [29]:
def plotScore(scores):
    """Show a 10x5 line chart of ``scores`` (one point per episode)."""
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(scores)
    ax.set_title("Score History")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Score")
    plt.show()

### Training

In [32]:
# Set hyperparameters
SEED = 42               # seed shared by the Python, NumPy, PyTorch and environment RNGs
BATCH_SIZE = 128        # transitions sampled per learning step
LR = 1e-3               # Adam learning rate
EPISODES = 5000         # maximum number of training episodes
TARGET_SCORE = 250.     # early training stop at avg score of last 100 episodes
GAMMA = 0.99            # discount factor
MEMORY_SIZE = 10000     # max memory buffer size
LEARN_STEP = 5          # how often to learn
TAU = 1e-3              # for soft update of target parameters

# Seed every RNG before the networks are created, so weight initialisation,
# exploration and replay sampling can be repeated from a fresh kernel.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_env = gym.make('LunarLander-v3')
# Seeding on reset initialises the environment's own RNG; later reset() calls
# without a seed continue from it.
train_env.reset(seed=SEED)
num_states = train_env.observation_space.shape[0]
num_actions = train_env.action_space.n
agent = DQN(
    n_states = num_states,
    n_actions = num_actions,
    batch_size = BATCH_SIZE,
    lr = LR,
    gamma = GAMMA,
    memory_size = MEMORY_SIZE,
    learn_step = LEARN_STEP,
    tau = TAU
)

In [None]:
# Run training. `train` returns the list of per-episode scores and saves the
# online network's weights to ./dqn-trained.h5 before returning.
score_hist = train(train_env, agent, n_episodes=EPISODES, target=TARGET_SCORE)

In [None]:
# Visualise training history: total score of each training episode, in order.
plotScore(score_hist)

### Test the trained agent

In [None]:
test_env = gym.make('LunarLander-v3', render_mode="human")
# Read the dimensions from the environment that is actually used here, so this
# cell does not depend on `train_env` from the training cell.
num_states = test_env.observation_space.shape[0]
num_actions = test_env.action_space.n
agent = DQN(
    n_states = num_states,
    n_actions = num_actions,
    batch_size = BATCH_SIZE,
    lr = LR,
    gamma = GAMMA,
    memory_size = MEMORY_SIZE,
    learn_step = LEARN_STEP,
    tau = TAU
)

# Load the trained agent. map_location places the tensors on the current device
# even if the checkpoint was written from a different one.
agent.net_eval.load_state_dict(torch.load('./dqn-trained.h5', map_location=device, weights_only=True))

In [None]:
# Play 10 episodes with the greedy policy (epsilon = 0) in the human-rendered
# environment; `test` closes the environment when it finishes.
test(test_env, agent, loop=10)