# It's a good idea to have a reference implementation

In [1]:
import gymnasium as gym
import math
import random
from collections import namedtuple, deque

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Pick the best accelerator that is actually available instead of hard-coding
# Apple's MPS backend, so the notebook also runs on CUDA and CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# Open question: will the stochasticity of Blackjack help highlight the
# usefulness of a distributional approach? Let's hope so.
env = gym.make("Blackjack-v1")

[Environment description](https://gymnasium.farama.org/environments/toy_text/blackjack/)

In [3]:
# Display the environment's action space and observation space.
env.action_space, env.observation_space

(Discrete(2), Tuple(Discrete(32), Discrete(11), Discrete(2)))

[Reference implementation of DQN](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html?highlight=dqn)

In [4]:
# One step of experience: the state, the action taken in it, the state that
# followed, and the reward received.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)


class ReplayMemory:
    """Bounded buffer of transitions; once full, each new entry evicts the oldest one."""

    def __init__(self, capacity):
        # A deque with maxlen drops items from the opposite end when it overflows.
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Return how many entries are currently stored."""
        return len(self.memory)

    def push(self, *args):
        """Build a Transition from the positional arguments and store it."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored entries chosen at random."""
        return random.sample(self.memory, k=batch_size)

In [5]:
class DQN(nn.Module):
    """Fully connected network: n_observations -> 128 -> 64 -> n_actions, ReLU between layers."""

    def __init__(self, n_observations, n_actions):
        super().__init__()
        # The attribute names double as state_dict keys, so they are kept unchanged.
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 64)
        self.layer3 = nn.Linear(64, n_actions)

    def forward(self, x):
        """Return one output per action for every row of ``x``."""
        hidden = x
        for layer in (self.layer1, self.layer2):
            hidden = F.relu(layer(hidden))
        # No activation on the output layer.
        return self.layer3(hidden)

In [6]:
class EpsilonScheduler:
    """Exponentially decaying epsilon schedule.

    The value starts at ``eps_start`` and approaches ``eps_end`` as the internal
    step counter grows; ``eps_decay`` is the time scale of the decay (larger
    values decay more slowly).
    """

    def __init__(self, eps_start, eps_end, eps_decay):
        self.eps_start, self.eps_end, self.eps_decay = eps_start, eps_end, eps_decay
        self.steps_done = 0

    def get_epsilon(self):
        """Return epsilon for the current step, then advance the step counter by one."""
        decay_factor = math.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1
        return self.eps_end + (self.eps_start - self.eps_end) * decay_factor

    def reset(self):
        """Restart the schedule from step zero."""
        self.steps_done = 0

In [7]:
# Display the observation space again for reference.
env.observation_space

Tuple(Discrete(32), Discrete(11), Discrete(2))

In [8]:
# Reset the environment and display the first element of the result (the initial observation).
env.reset()[0]

(12, 8, 0)

In [9]:
# Hyperparameters
BATCH_SIZE = 128  # default number of transitions sampled from replay memory per optimization step
GAMMA = 0.99  # default discount factor applied to next-state values
EPS_START = 0.9  # epsilon at step 0 of the schedule
EPS_END = 0.05  # value that epsilon decays towards
EPS_DECAY = 1000  # time scale (in steps) of the exponential epsilon decay; larger means slower
LR = 1e-4  # learning rate passed to the AdamW optimizer

# Network dimensions are read from the environment itself.
n_actions = env.action_space.n
n_observations = len(env.reset()[0])  # number of components in the observation returned by reset()

# Two networks with the same architecture; target_net starts as an exact copy of policy_net.
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

# Only policy_net's parameters are handed to the optimizer.
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10_000)  # replay buffer capacity
epsilon_scheduler = EpsilonScheduler(EPS_START, EPS_END, EPS_DECAY)

In [17]:
def select_action(state, eps_scheduler):
    """Choose an action for ``state`` epsilon-greedily.

    One uniform random number is drawn and compared with the epsilon returned by
    ``eps_scheduler.get_epsilon()``. Below epsilon the action is sampled from
    ``env.action_space``; otherwise it is the index of the largest
    ``policy_net`` output for ``state``.

    Relies on the notebook-level ``env``, ``policy_net`` and ``device``.

    Returns:
        A tensor of shape ``(1, 1)`` holding the chosen action index.
    """
    draw = random.random()
    explore = draw < eps_scheduler.get_epsilon()

    if not explore:
        # Greedy branch: pure inference, so gradient tracking is switched off.
        with torch.no_grad():
            best_action = policy_net(state).max(1).indices
        return best_action.view(1, 1)

    random_action = env.action_space.sample()
    return torch.tensor([[random_action]], device=device, dtype=torch.long)

def optimize_model(policy_net, target_net, memory, batch_size=BATCH_SIZE, gamma=GAMMA):
    """Run one optimization step on a batch sampled from ``memory``.

    Does nothing until ``memory`` holds at least ``batch_size`` transitions.
    Otherwise the target for each sampled transition is
    ``reward + gamma * max(target_net(next_state))``, where the second term is
    zero for transitions whose ``next_state`` is ``None``. The outputs of
    ``policy_net`` for the taken actions are fitted to these targets with a
    smooth-L1 loss.

    Relies on the notebook-level ``optimizer``, ``device`` and ``Transition``.

    Args:
        policy_net: network whose outputs are fitted and whose gradients are clipped.
        target_net: network providing the bootstrap values; evaluated without gradients.
        memory: replay buffer supporting ``len()`` and ``sample(batch_size)``.
        batch_size: number of transitions used for the step.
        gamma: discount factor applied to the next-state values.

    Returns:
        None.
    """
    # Wait until the memory holds enough transitions for a full batch.
    if len(memory) < batch_size:
        return

    transitions = memory.sample(batch_size)
    # Transpose a list of Transitions into one Transition whose fields are tuples.
    batch = Transition(*zip(*transitions))

    non_final_next_states = [s for s in batch.next_state if s is not None]
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Network output for the action that was actually taken in each sampled state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Entries for transitions without a next state stay at zero.
    next_state_values = torch.zeros(batch_size, device=device)
    # torch.cat raises on an empty list, so skip the bootstrap step when every
    # sampled transition ended its episode.
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(torch.cat(non_final_next_states)).max(1).values
    expected_state_action_values = (next_state_values * gamma) + reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-100, 100] before the parameter update.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

In [18]:
num_episodes = 100
TAU = 0.005  # soft-update rate of the target network (value used in the reference DQN tutorial)

for i_episode in range(num_episodes):
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    while True:
        action = select_action(state, epsilon_scheduler)
        observation, reward, terminated, truncated, info = env.step(action.item())
        # Explicit dtype keeps the reward tensor float32 whatever numeric type the environment returns.
        reward = torch.tensor([reward], dtype=torch.float32, device=device)
        done = terminated or truncated

        # A terminal step has no successor state; optimize_model treats None as "no bootstrap value".
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        memory.push(state, action, next_state, reward)

        state = next_state

        optimize_model(policy_net, target_net, memory)

        # Soft update of the target network: theta_target <- TAU * theta_policy + (1 - TAU) * theta_target.
        # Previously the target network reloaded its own weights here and so never followed the policy network.
        policy_net_state_dict = policy_net.state_dict()
        target_net_state_dict = target_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = (
                policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
            )
        target_net.load_state_dict(target_net_state_dict)

        if done:
            break

In [19]:
def evaluate_agent(env, num_episodes, device, policy_net=None):
    """Play ``num_episodes`` episodes, print a summary line and return the win rate.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` and
            ``action_space.sample()``; ``step`` must return the 5-tuple
            ``(observation, reward, terminated, truncated, info)``.
        num_episodes: number of episodes to play.
        device: torch device on which observation tensors are created.
        policy_net: optional callable mapping a ``(1, n_observations)`` float
            tensor to one row of action scores. When given, the highest-scoring
            action is always taken ("Trained" agent); when ``None``, actions are
            sampled from ``env.action_space`` ("Random" agent).

    Returns:
        Fraction of episodes whose last reward was positive (0.0 when
        ``num_episodes`` is 0).
    """
    agent = "Random" if policy_net is None else "Trained"

    def choose_action(state):
        """Return the next action as a plain int."""
        if policy_net is None:
            return int(env.action_space.sample())
        with torch.no_grad():
            return policy_net(state).max(1).indices.item()

    wins = 0
    total_earnings = 0
    for _ in range(num_episodes):
        observation, info = env.reset()

        while True:
            state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
            observation, reward, terminated, truncated, info = env.step(choose_action(state))
            total_earnings += reward

            # End the episode on truncation as well as termination, as the
            # training loop does; stepping past either is not valid.
            if terminated or truncated:
                if reward > 0:
                    wins += 1
                break

    # Guard the averages so that an empty evaluation reports zeros instead of raising.
    win_rate = wins / num_episodes if num_episodes else 0.0
    average_earnings = total_earnings / num_episodes if num_episodes else 0.0
    print(f'{agent} agent => Evaluation over {num_episodes} episodes. Win Rate: {win_rate:.2f}; Avg. Earnings: {average_earnings:.2f}')

    return win_rate

In [20]:
num_evaluation_episodes = 1000
# Evaluate the trained policy first, then an agent without a policy_net (random actions) as a baseline.
# The trailing semicolon keeps the notebook from displaying the returned win rate.
evaluate_agent(env, num_evaluation_episodes, device, policy_net=policy_net);
evaluate_agent(env, num_evaluation_episodes, device);

Trained agent => Evaluation over 1000 episodes. Win Rate: 0.40; Avg. Earnings: -0.13
Random agent => Evaluation over 1000 episodes. Win Rate: 0.30; Avg. Earnings: -0.34


Not too shabby.