# Training a Vanilla DQN or a Dueling DQN Agent

**GOAL:** Train a _Vanilla Deep Q-Network_ or a _Dueling Deep Q-Network_ to play Connect4. The same notebook trains either agent, depending on the `is_dueling_dqn` hyperparameter. This notebook belongs to the second part of our training pipeline (which also includes other RL algorithms):
  - **Part 1) Supervised Learning**
    - refer to *'src/train/part1_supervised_learning.ipynb'*.
    - RESULT: a pre-trained network with basic knowledge of the game
  - **Part 2) Reinforcement Learning**
    - In this case: Vanilla and Dueling DQNs
    - **TRANSFER LEARNING FROM PART 1:**
      - Load the pre-trained weights from Part 1
      - Freeze the convolutional block (*feature extractor*, it is not trained here)
      - Train the remaining fully-connected layers to estimate the optimal Q-values
<br>
                     
**METHOD:**
   - We implemented the *Minimax DQN* algorithm for turn-based games
   - We used an *Experience Replay Memory* to break the correlation between consecutive samples
       - capacity = 60k
       - minimum size = 30k
       - exponent for reward backpropagation = 3
       - for more details on the implementation refer to '*src/data/replay_memory.py*'
   - We used a *Target Network* to keep the learning targets stable between updates.
   - The network architecture we used is defined in '*src/models/architectures/cnet128.json*'
   - We applied *transfer learning* to use the knowledge learned in '*src/train/part1_supervised_learning.ipynb*'
       - 1. load the network weights from '*src/models/saved_models/supervised_cnet128.pt*' (the path set in `hparams['pretrained_model_weights']`)
       - 2. freeze the convolutional block (*feature extractor*)
       - 3. train the fully-connected layers to learn the Q-values
   - We used a *Decaying epsilon-greedy* scheme to ensure sufficient exploration:
       - epsilon start = 0.8
       - epsilon decay = 600
       - epsilon end = 0.05
   - There is an '*old agent*', an older and stable version of the agent. It is updated when:
       - the agent achieves a new best win rate against the 1-Step Lookahead Agent
   - When the win rate against the 1-Step Lookahead Agent drops 8 percentage points or more below the best one, the latest changes are undone and the network goes back to the most recent *old weights*
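
The exploration and learning rules listed above can be written out explicitly. This is a restatement of what `compute_eps_threshold` and `training_step` compute in this notebook. The symbols $k$ (argument passed to the schedule), $\tau$ (`eps_decay`), $d$ (1 if the next state is terminal, 0 otherwise) and $\theta^-$ (weights of the network used to evaluate the next state) are notation introduced here for illustration.

$$
\epsilon(k) = \epsilon_{\text{end}} + (\epsilon_{\text{start}} - \epsilon_{\text{end}})\, e^{-k/\tau}
$$

$$
y = r - (1 - d)\,\gamma \max_{a'} Q_{\theta^-}(s', a')
$$

The minus sign in the target is the nega-max step: $s'$ is evaluated from the opponent's point of view, so the best value for the opponent is the worst one for the player that just moved.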
    
    
**TRAINING:**
   - We trained for 100k update steps (gradient updates)
   - The learning hyperparameters are:
       - learning rate = 1e-4
       - batch size = 48
       - weight decay (L2 regularization) = 5e-4
       - discount factor (gamma) = 0.95
       - 20 updates per new training episode
       - loss function = Mean Squared Error (MSE)
       - update target network every 400 updates
   - Every 1000 updates, the DQN agent competes against:
       - the Random Agent
       - the older (stable) version of the network
       - the 1-Step Lookahead Agent
       
**Vanilla DQN RESULTS:**
   - Our best DQNAgent beats the 1-Step LookAhead Agent **≈87%** of the time
   - The weights of the model are saved in '*src/models/saved_models/best_dqn.pt*'
   - The training hyperparameters are saved in '*src/models/saved_models/best_dqn_hparams.json*'
   - Plots of the training losses
   - Plots of the average game length in self-play games
   - Plots of the evolution of the win rate vs 1StepLA
   
**Dueling DQN RESULTS:**
   - Our best DuelingDQNAgent beats the 1-Step LookAhead Agent **≈94%** of the time
   - The weights of the model are saved in '*src/models/saved_models/best_dueling_dqn.pt*'
   - The training hyperparameters are saved in '*src/models/saved_models/best_dueling_dqn_hparams.json*'
   - Plots of the training losses
   - Plots of the average game length in self-play games
   - Plots of the evolution of the win rate vs 1StepLA

## 1) Imports

In [None]:
import math
import copy
import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
from torchsummary import summary
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FormatStrFormatter

In [None]:
# YOUR PATH HERE
code_dir = '/home/marc/Escritorio/RL-connect4/'

if os.path.isdir(code_dir):
    # local environment
    os.chdir(code_dir)
    print(f"directory -> '{code_dir}'")
else:
    # google colab environment
    if os.path.isdir('./src'):
        print("'./src' dir already exists")
    else:  # not unzipped yet
        !unzip -q src.zip
        print("'./src.zip' file successfully unzipped")

In [None]:
from src.agents.baselines.random_agent import RandomAgent
from src.agents.baselines.n_step_lookahead_agent import NStepLookaheadAgent
from src.models.custom_network import CustomNetwork
from src.agents.trainable.dqn_agent import DQNAgent
from src.agents.trainable.dueling_dqn_agent import DuelingDQNAgent
from src.environment.connect_game_env import ConnectGameEnv
from src.data.replay_memory import ReplayMemory
from src.eval.competition import competition
from src.environment.env_utils import get_illegal_actions

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
timestamp = datetime.now().strftime("%d%m_%H%M")
print(f"'{timestamp}'")

In [None]:
SAVE_MODELS = False  # if False, run in debug mode (no checkpoints are written)

## 2) Hyperparameters

In [None]:
hparams = {
    # Agent type (Vanilla DQN or DuelingDQN)
    'is_dueling_dqn': True,
    
    # environment, data, memory
    'memory_capacity': 60000,
    'min_memory_size': 30000,
    'reward_backprop_exponent': 3,
    
    # agent properties and model architecture
    'avg_symmetric_q_vals': True,
    'model_arch_path': './src/models/architectures/cnet128.json',
    'pretrained_model_weights': './src/models/saved_models/supervised_cnet128.pt',
    'freeze_conv_block': True,
    
    # Information displayed while training
    'loss_log_every': 200,
    'comp_every': 1000,
    'vs_1StepLA_win_rate_decrease_to_undo_updates': 0.08,
        
    # Exploration Scheme (epsilon decay)
    'eps_start': 0.8,
    'eps_end': 0.05,
    'eps_decay': 600,
    
    # Training loop params
    'num_steps': 100000,
    'batch_size': 48,
    'gamma' : 0.95,
    'weight_decay': 5e-4,
    'lr': 1e-4,
    'n_updates_per_new_episode': 20,

    # target network update
    'target_net_update_every': 400
}

In [None]:
def load_state_dict(from_: nn.Module, to_: nn.Module) -> None:
    """
    Copies the weights from the module 'from_' to the module 'to_'
    Loading the state_dict from a module does not keep the
    convolutional_block frozen, so it has to be done manually.
    """

    to_.load_state_dict(from_.state_dict())
    if hparams['freeze_conv_block']:
        for param in to_.conv_block.parameters():
            param.requires_grad = False

            

def create_model() -> nn.Module:
    """
    Create a DQN (or DuelingDQN) following the architecture in
    hparams['model_arch_path'], and initialize the network with the
    weights stored in hparams['pretrained_model_weights'].
    If it is DQN (not Dueling), remove the second prediction head
    since it is not used.
    """

    q_net = CustomNetwork.from_architecture(
        file_path=hparams['model_arch_path']
    ).to(device)

    q_net.load_weights(hparams['pretrained_model_weights'])
    
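    # freeze the feature extractor only if requested in the hyperparameters
    if hparams['freeze_conv_block']:
        for param in q_net.conv_block.parameters():
            param.requires_grad = False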
    
    if not hparams['is_dueling_dqn']:
        q_net.second_head = nn.Sequential()
    
    return q_net


def create_agent():
    """
    Create a new DQN (or DuelingDQN) Agent
    """

    model_ = create_model()
    
    if hparams['is_dueling_dqn']: 
        agent_ = DuelingDQNAgent(
            model=model_,
            avg_symmetric_q_vals=hparams['avg_symmetric_q_vals'],
        )
    else:
        agent_ = DQNAgent(
            model=model_,
            avg_symmetric_q_vals=hparams['avg_symmetric_q_vals'],
        )

    return agent_
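
The transfer-learning setup can be sanity-checked on a toy network. The sketch below is an illustration only: `TinyNet`, its `fc_block` attribute and the `freeze` helper are hypothetical stand-ins (the real `CustomNetwork` is not shown in this notebook), and it assumes a plain `nn.Module` whose `load_state_dict` is not overridden. It freezes the convolutional block, copies weights from another instance, and lists which parameters are still trainable.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for CustomNetwork: a conv block plus an FC head."""

    def __init__(self) -> None:
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(2, 4, kernel_size=3, padding=1), nn.ReLU()
        )
        self.fc_block = nn.Sequential(nn.Flatten(), nn.Linear(4 * 6 * 7, 7))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc_block(self.conv_block(x))


def freeze(module: nn.Module) -> None:
    """Exclude every parameter of 'module' from gradient updates."""
    for param in module.parameters():
        param.requires_grad = False


src, dst = TinyNet(), TinyNet()
freeze(dst.conv_block)
dst.load_state_dict(src.state_dict())  # copy the values from src into dst

trainable = [name for name, p in dst.named_parameters() if p.requires_grad]
frozen = [name for name, p in dst.named_parameters() if not p.requires_grad]
print("trainable:", trainable)
print("frozen:   ", frozen)
```

Running the same check on the real model is a cheap way to confirm that only the fully-connected layers reach the optimizer with gradients.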

In [None]:
dqn_agent = create_agent()

print("model device is cuda?", next(dqn_agent.model.parameters()).is_cuda)
print()
print(summary(dqn_agent.model, input_size=dqn_agent.model.input_shape))

In [None]:
agent_name = dqn_agent.name.replace(' ', '_')
save_best_vs_1StepLA_file = f'{agent_name}_'+'{win_rate}_vs_1StepLA_'+f'{timestamp}.pt'

print('"' + save_best_vs_1StepLA_file + '"')

In [None]:
old_dqn_agent = create_agent()
load_state_dict(from_=dqn_agent.model, to_=old_dqn_agent.model)
old_dqn_agent.model.eval()

In [None]:
target_net = create_model()
load_state_dict(from_=dqn_agent.model, to_=target_net)
target_net.eval()

## 4) Experience Replay Memory

In [None]:
memory = ReplayMemory(
    capacity=hparams['memory_capacity'],
    reward_backprop_exponent=hparams['reward_backprop_exponent']
)
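
For readers unfamiliar with experience replay, here is a minimal sketch of a fixed-capacity buffer with uniform sampling. It is not the `ReplayMemory` class from '*src/data/replay_memory.py*': the class name `SimpleReplayBuffer`, the field order of `Transition`, and the absence of reward backpropagation and symmetric transitions are all simplifying assumptions.

```python
import random
from collections import deque, namedtuple

# Assumed field order, chosen for this sketch only
Transition = namedtuple(
    "Transition", ["state", "action", "next_state", "reward", "done"]
)


class SimpleReplayBuffer:
    """Fixed-capacity buffer: the oldest transitions are evicted first."""

    def __init__(self, capacity: int) -> None:
        self.buffer = deque(maxlen=capacity)

    def push(self, *args) -> None:
        self.buffer.append(Transition(*args))

    def sample(self, batch_size: int) -> list:
        # uniform sampling without replacement breaks temporal correlation
        return random.sample(list(self.buffer), batch_size)

    def __len__(self) -> int:
        return len(self.buffer)


buffer = SimpleReplayBuffer(capacity=3)
for t in range(5):
    buffer.push(t, 0, t + 1, 0.0, False)  # dummy integer "states"

batch = buffer.sample(2)
columns = Transition(*zip(*batch))  # list of transitions -> tuple of columns
print(len(buffer), columns.state)
```

The `Transition(*zip(*batch))` idiom is the same one `training_step` uses to turn a list of transitions into per-field batches.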

## 5) Prepare the training loop

In [None]:
loss_func = nn.MSELoss()

optimizer = torch.optim.Adam(
    params=dqn_agent.model.parameters(), 
    lr=hparams['lr'],
    weight_decay=hparams['weight_decay']
)

In [None]:
def compute_eps_threshold(episode: int) -> float:
    """
    Implements the decaying epsilon scheme.
    Returns the exploration rate for the given episode number.
    """

    return (hparams['eps_end'] + (hparams['eps_start'] - hparams['eps_end']) 
            * math.exp(-1. * episode / hparams['eps_decay']))


# Plot the exploration rate for each training episode
num_episodes = hparams['num_steps'] // hparams['n_updates_per_new_episode']
episodes_x = range(num_episodes)
eps_y = [compute_eps_threshold(x) for x in episodes_x]

plt.plot(episodes_x, eps_y)
plt.axhline(hparams['eps_start'], linestyle='--', alpha=0.4)
plt.axhline(hparams['eps_end'], linestyle='--', alpha=0.4)
plt.ylim(0, 1.1)
plt.ylabel('eps')
plt.xlabel('training episode')
plt.title('decaying epsilon scheme')
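
The exploration rate only matters at action-selection time, which happens inside the agent and is not shown in this notebook. A minimal sketch of epsilon-greedy selection restricted to legal columns follows. The function name `epsilon_greedy` and its signature are hypothetical, not the agent's actual API.

```python
import random
from typing import Optional, Sequence


def epsilon_greedy(q_vals: Sequence[float],
                   legal_actions: Sequence[int],
                   eps: float,
                   rng: Optional[random.Random] = None) -> int:
    """With probability eps pick a random legal action, else the best legal one."""
    rng = rng or random
    if rng.random() < eps:
        return rng.choice(list(legal_actions))
    return max(legal_actions, key=lambda a: q_vals[a])


q_vals = [0.1, 0.9, 0.3]          # column 1 has the highest Q-value...
legal = [0, 2]                    # ...but it is full, so it is not legal
greedy_action = epsilon_greedy(q_vals, legal, eps=0.0)
random_action = epsilon_greedy(q_vals, legal, eps=1.0, rng=random.Random(0))
print(greedy_action, random_action)
```

Restricting both branches to `legal_actions` keeps the agent from exploring moves into full columns.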

In [None]:
def compute_q_vals(model_: nn.Module, state_batch_: torch.Tensor) -> torch.Tensor:
    """
    Compute the Q-vals for the given batch of states.
    This function accepts both DQN and DuelingDQN models
    NOTE: output contains gradients since it is a prediction to learn from
    """

    if hparams['is_dueling_dqn']:
        adv, v = model_(state_batch_)
        q_vals = v + (adv - adv.mean(dim=1).unsqueeze(-1))
    else:
        q_vals = model_(state_batch_)

    return q_vals


def compute_next_q_vals(target_net_: nn.Module, next_state_batch_: torch.Tensor) -> torch.Tensor:
    """
    Compute the Q-vals for the given batch of next states.
    This function accepts both DQN and DuelingDQN models.
    NOTE: the output does not carry gradients, since it is part of the target.
    If hparams['avg_symmetric_q_vals'] is set, the Q-vals are averaged with
    those predicted for the horizontally mirrored states.
    NOTE: illegal actions in the next states are not masked out here.
    """

    with torch.no_grad():
        next_q_vals = compute_q_vals(model_=target_net_,
                                     state_batch_=next_state_batch_)

    if hparams['avg_symmetric_q_vals']:
        sym_next_state_batch = torch.flip(next_state_batch_, dims=[-1])
        with torch.no_grad():
            sym_next_q_vals = compute_q_vals(
                model_=target_net_, state_batch_=sym_next_state_batch
            )
        next_q_vals = (next_q_vals + sym_next_q_vals.flip(dims=[1])) / 2

    return next_q_vals
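
The two operations in the cell above, the dueling aggregation and the symmetric average, can be tried on small tensors. The sketch below re-implements them as standalone functions. `dueling_q_values`, `symmetric_average` and `toy_q_fn` are names made up for this example, and the state shape (batch, channels, rows, columns) is an assumption.

```python
import torch


def dueling_q_values(adv: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Combine advantages (B, A) and state values (B, 1) into Q-values (B, A)."""
    return v + (adv - adv.mean(dim=1, keepdim=True))


def symmetric_average(q_fn, states: torch.Tensor) -> torch.Tensor:
    """Average q_fn(states) with the mirrored prediction on mirrored states."""
    q = q_fn(states)
    q_mirrored = q_fn(torch.flip(states, dims=[-1])).flip(dims=[1])
    return (q + q_mirrored) / 2


adv = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 6.0]])
v = torch.tensor([[0.5], [-1.0]])
q = dueling_q_values(adv, v)
print(q)  # each row of q averages to the corresponding state value


def toy_q_fn(states: torch.Tensor) -> torch.Tensor:
    """Hypothetical Q-function that is deliberately not left/right symmetric."""
    weights = torch.arange(1, states.shape[-1] + 1, dtype=states.dtype)
    return states.sum(dim=(1, 2)) * weights


states = torch.rand(4, 2, 6, 7)
q_avg = symmetric_average(toy_q_fn, states)
q_avg_of_mirror = symmetric_average(toy_q_fn, torch.flip(states, dims=[-1]))
```

Subtracting the mean advantage makes the decomposition identifiable: the mean of each row of Q-values equals the state value. The symmetric average guarantees that a board and its mirror image receive mirrored Q-values, whatever the underlying network predicts.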

In [None]:
def training_step(online_net_, 
                  target_net_, 
                  optimizer_, 
                  loss_func_,
                  memory_, 
                  batch_size_, 
                  gamma_):
    
    if len(memory_) < hparams['min_memory_size']:
        raise Exception("len(memory) is below its minimum value")
    
    online_net_.train()

    # get a batch of transitions (copy)
    transitions = copy.deepcopy(memory_.sample(batch_size_))
    batch = memory_.Transition(*zip(*transitions))
    
    # preprocess the states to feed the model
    tuple_state_batch = tuple([online_net_.obs_to_model_input(obs=s)
                               for s in batch.state])
    tuple_next_state_batch = tuple([online_net_.obs_to_model_input(obs=s_)
                                    for s_ in batch.next_state])
    
    # turn the batch elements (lists) into pytorch tensors
    state_batch = torch.cat(tuple_state_batch).float().to(device)
    next_state_batch = torch.cat(tuple_next_state_batch).float().to(device)
    action_batch = torch.tensor(batch.action, device=device)
    reward_batch = torch.tensor(batch.reward, device=device)
    not_done_mask = 1 - torch.tensor(batch.done, device=device, dtype=torch.int)

    # predict the Q-values for the (state,action) pairs in the batch
    state_action_values = compute_q_vals(model_=online_net_, state_batch_=state_batch)
    state_action_values = state_action_values.gather(1, action_batch.unsqueeze(1))

    # evaluate the next states with the target network (not the online one)
    next_state_values = compute_next_q_vals(target_net_=target_net_, 
                                            next_state_batch_=next_state_batch)
        
    next_state_values = next_state_values.max(1)[0]

    # nega-max target (MINIMAX DQN)
    target = reward_batch - not_done_mask*gamma_*next_state_values

    # Compute regression loss between predicted Q values and targets
    loss = loss_func_(state_action_values, target.unsqueeze(1))

    # Take an optimization step (the optimizer is passed in by the caller)
    optimizer_.zero_grad()
    loss.backward()
    optimizer_.step()

    return loss.item()
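
A small numeric example may help with the sign of the nega-max target used in `training_step`. The helper `negamax_target` below is a hypothetical standalone version of that single line, written for this illustration, with made-up rewards and Q-values.

```python
import torch


def negamax_target(reward: torch.Tensor,
                   done: torch.Tensor,
                   next_q_vals: torch.Tensor,
                   gamma: float) -> torch.Tensor:
    """target = r - (1 - done) * gamma * max_a' Q(s', a')"""
    not_done = 1.0 - done.float()
    best_next = next_q_vals.max(dim=1).values  # best reply of the opponent
    return reward - not_done * gamma * best_next


reward = torch.tensor([0.0, 1.0])        # second transition wins the game
done = torch.tensor([False, True])
next_q = torch.tensor([[0.2, 0.5, -0.1],  # opponent's Q-values in s'
                       [0.3, 0.0, 0.1]])  # ignored: the episode has ended
target = negamax_target(reward, done, next_q, gamma=0.95)
print(target)
```

In the first transition the opponent's best reply is worth 0.5 to them, so the move is worth about -0.475 to the player who made it. In the second, the game is over and the target is just the reward.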

## 6) Training loop

In [None]:
history = {'losses': [],
           'vs_random_win_rate': [], 'vs_random_avg_game_len': [],
           'vs_1StepLA_win_rate': [], 'vs_1StepLA_avg_game_len': [],
           'vs_old_self_win_rate': [], 'vs_old_self_avg_game_len': [],
           'comp_every': hparams['comp_every'], 'comp_n_episodes': 100,
          }

vs_1StepLA_best_win_rate = 0.5

os.makedirs('checkpoints', exist_ok=True)

env = ConnectGameEnv()
comp_env = ConnectGameEnv()

random_opponent = RandomAgent()
oneStepLA = NStepLookaheadAgent(n=1, prefer_central_columns=True)

episode_count = 0
step_count = 0
while step_count < hparams['num_steps']:
    
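    # compute the exploration rate for the training episode
    # NOTE: the decay is driven by step_count (number of updates so far),
    # so eps stays at eps_start while the replay memory is being filled
    eps = compute_eps_threshold(step_count)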

    # generate a new training episode (self-play)
    memory.push_self_play_episode_transitions(
        agent=dqn_agent,
        env=env,
        init_random_obs=True,
        push_symmetric=True,
        exploration_rate=eps
    )
        
    episode_count += 1
    if len(memory) < hparams['min_memory_size']:
        continue
    
    for _ in range(hparams['n_updates_per_new_episode']):
        step_count += 1
        # Perform one step of the optimization
        loss = training_step(online_net_=dqn_agent.model, 
                             target_net_=target_net, 
                             optimizer_=optimizer,
                             loss_func_=loss_func, 
                             memory_=memory, 
                             batch_size_=hparams['batch_size'], 
                             gamma_=hparams['gamma'])

        history['losses'].append(loss)
        
        # update the target_network
        if step_count % hparams['target_net_update_every'] == 0:
            load_state_dict(from_=dqn_agent.model, to_=target_net)
    
        # display information about the training process
        if step_count % hparams['loss_log_every'] == 0:
            last_losses = history['losses'][-hparams['loss_log_every']:]
            print(f"Step: {step_count}/{hparams['num_steps']}    " +
                  f"AvgLoss: {round(np.mean(last_losses),4)}")
        
        # compete against the opponents to measure the performance
        if step_count % hparams['comp_every'] == 0:
            # compete against the Random Agent
            dqn_agent.model.eval()
            with torch.no_grad():
                res1, o1 = competition(
                    env=comp_env, 
                    agent1=dqn_agent, 
                    agent2=random_opponent,
                    progress_bar=False)
            win_rate_rand = round(res1['win_rate1'], 3)
            print(f"    {win_rate_rand} vs. RAND" +
                  f"    avg_len={round(res1['avg_game_len'], 2)}")
            history['vs_random_win_rate'].append(win_rate_rand)
            history['vs_random_avg_game_len'].append(res1['avg_game_len'])
        
            # compete against the old (stable) version of the network
            old_dqn_agent.model.eval()
            with torch.no_grad():
                res2, o2 = competition(
                    env=comp_env, 
                    agent1=dqn_agent, 
                    agent2=old_dqn_agent,
                    progress_bar=False,
                )
            win_rate_self = round(res2['win_rate1'], 3)
            print(f"    {win_rate_self} vs. SELF" +
                  f"    avg_len={round(res2['avg_game_len'], 2)}")
            history['vs_old_self_win_rate'].append(win_rate_self)
            history['vs_old_self_avg_game_len'].append(res2['avg_game_len'])
        
            # compete against the 1StepLA
            with torch.no_grad():
                res3, o3 = competition(
                    env=comp_env, 
                    agent1=dqn_agent,
                    agent2=oneStepLA,
                    progress_bar=False,
                )
            win_rate_1StepLA = round(res3['win_rate1'], 3)
            print(f"    {win_rate_1StepLA} vs. 1StepLA" +
                  f"    avg_len={round(res3['avg_game_len'], 2)}")
            history['vs_1StepLA_win_rate'].append(win_rate_1StepLA)
            history['vs_1StepLA_avg_game_len'].append(res3['avg_game_len'])
            
            if win_rate_1StepLA > vs_1StepLA_best_win_rate:
                vs_1StepLA_best_win_rate = win_rate_1StepLA
                load_state_dict(from_=dqn_agent.model, to_=old_dqn_agent.model)
                old_dqn_agent.model.eval()
                load_state_dict(from_=dqn_agent.model, to_=target_net)
                target_net.eval()
                if SAVE_MODELS:
                    # round before casting: int() alone truncates (e.g. 0.57*100 -> 56)
                    file_name = "checkpoints/" + save_best_vs_1StepLA_file.format(
                        win_rate=int(round(win_rate_1StepLA * 100))
                    )
                    dqn_agent.model.save_weights(
                        file_path=file_name,
                        training_hparams=hparams,
                    )
                    print(f"        new best {file_name} is saved!!!")
            elif win_rate_1StepLA <= vs_1StepLA_best_win_rate-hparams['vs_1StepLA_win_rate_decrease_to_undo_updates']:
                load_state_dict(from_=old_dqn_agent.model, to_=dqn_agent.model)
                load_state_dict(from_=old_dqn_agent.model, to_=target_net)
                target_net.eval()
                print("        undoing last updates...")
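
The promote/rollback rule at the end of the loop can be isolated as a pure function, which makes it easy to test. This is a sketch of the rule as read from the loop above. The function name `checkpoint_decision` and the returned labels are invented for this example.

```python
def checkpoint_decision(win_rate: float,
                        best_win_rate: float,
                        max_drop: float) -> str:
    """Decide what to do after a competition against the 1-Step Lookahead Agent.

    'promote'  -> new best: copy the online weights to the old agent and target net
    'rollback' -> large drop: restore the online and target nets from the old agent
    'keep'     -> otherwise: continue training unchanged
    """
    if win_rate > best_win_rate:
        return "promote"
    if win_rate <= best_win_rate - max_drop:
        return "rollback"
    return "keep"


# max_drop mirrors hparams['vs_1StepLA_win_rate_decrease_to_undo_updates']
for win_rate in (0.90, 0.82, 0.78):
    print(win_rate, checkpoint_decision(win_rate, best_win_rate=0.87, max_drop=0.08))
```

Note that the best win rate only moves up, so a rollback always restores the weights that achieved that best score.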

## 7) Plot training results

In [None]:
def moving_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w
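
As a quick check of what `moving_average` returns, the example below repeats the definition so that it runs on its own and applies it to a short list. With the `'valid'` mode, the output has `len(x) - w + 1` entries, so the smoothed curve is slightly shorter than the raw one.

```python
import numpy as np


def moving_average(x, w):
    # same definition as in the cell above
    return np.convolve(x, np.ones(w), 'valid') / w


raw = [1.0, 2.0, 3.0, 4.0, 5.0]
smoothed = moving_average(raw, w=2)
print(smoothed)       # means of consecutive pairs
print(len(smoothed))  # len(raw) - w + 1
```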

In [None]:
# Plot the training losses
                       
if hparams['is_dueling_dqn']:
    name = 'Dueling DQN'
else:
    name = 'Vanilla DQN'

data = moving_average(history['losses'][1000:100000], w=200)
x_vals = [x/1000 for x in range(len(data))]
plt.plot(x_vals, data)
plt.title(f'{name} Training Loss (MSE)')
plt.xlabel("updates (in thousands)")
plt.ylabel("loss")
#plt.gca().xaxis.set_major_locator(MultipleLocator(10))
#plt.gca().yaxis.set_major_locator(MultipleLocator(0.025))
plt.ylim(0.36,0.58)
plt.show()

In [None]:
num_updates = len(history['vs_old_self_win_rate']) * history['comp_every']
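# the first competition takes place after 'comp_every' updates, not at update 0
x_vals = range(history['comp_every'], num_updates + 1, history['comp_every'])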
x_vals = [x/1000 for x in x_vals]
data = history['vs_old_self_avg_game_len']

plt.title(f'{name} self-play game length')
plt.plot(x_vals, data)
plt.xlabel("updates (in thousands)")
plt.ylabel("game length")
#plt.gca().xaxis.set_major_locator(MultipleLocator(10))
plt.axhline(42, linestyle='--', alpha=0.4)
plt.axhline(7, linestyle='--', alpha=0.4)
plt.ylim(-1, 45)

In [None]:
num_updates = len(history['vs_1StepLA_win_rate']) * history['comp_every']
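# the first competition takes place after 'comp_every' updates, not at update 0
x_vals = range(history['comp_every'], num_updates + 1, history['comp_every'])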
x_vals = [x/1000 for x in x_vals]
data = history['vs_1StepLA_win_rate']

plt.plot(x_vals, data)
plt.title(f'{name} win rate vs 1StepLA')
plt.xlabel("updates (in thousands)")
plt.ylabel("win rate")
#plt.gca().xaxis.set_major_locator(MultipleLocator(10))
plt.axhline(1, linestyle='--', alpha=0.4)
plt.axhline(0.5, linestyle='--', alpha=0.4)
plt.ylim(0.35, 1.09)
plt.show()