# Training a PPO Agent

**GOAL:** Train a _PPO Agent_ to play Connect4. This notebook belongs to the second part of our training pipeline (that part also covers other RL algorithms):
  - **Part 1) Supervised Learning**
    - refer to *'src/train/part1_supervised_learning.ipynb'*.
    - RESULT: a pre-trained network with basic knowledge of the game
  - **Part 2) Reinforcement Learning**
    - In this case: Proximal Policy Optimization
    - **TRANSFER LEARNING FROM PART 1:**
      - Load the pre-trained weights from Part 1
      - Freeze the convolutional block (the *feature extractor*, which is not trained here)
      - Train the remaining fully connected layers to estimate the optimal policy and the state values.
<br>

**METHOD:**
   - We used an *Experience Buffer* to store different episodes
       - capacity = 2000
       - exponent for reward backpropagation = 3
       - for more details on the implementation refer to '*src/data/replay_memory.py*'
   - The network architecture we used is defined in '*src/models/architectures/cnet128.json*'
   - We applied *transfer learning* to use the knowledge learned in '*src/train/part1_supervised_learning.ipynb*'
       - 1. load the network weights from '*src/models/saved_models/supervised_cnet128.pt*' (the path set in `pretrained_model_weights`)
       - 2. freeze the convolutional block (*feature extractor*)
       - 3. train the fully-connected layers to learn the policy and the state values
   - There is an '*old agent*', an older and stable version of the agent. It is updated when:
       - the agent achieves a new best win rate against the 1-Step Lookahead Agent
   - When the win rate of the current network against the 1-Step Lookahead Agent falls 0.08 or more below the best one, the latest changes are undone and the network goes back to the most recent *old weights*
<br>
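
The promote/rollback rule described above can be sketched in plain Python. This is only an illustration: the function name `checkpoint_decision` and its return labels are hypothetical and do not exist in the repository. The default threshold of 0.08 mirrors the value of `vs_1StepLA_win_rate_decrease_to_undo_updates` used in this notebook.

```python
def checkpoint_decision(win_rate, best_win_rate, max_drop=0.08):
    """Hypothetical helper: decide what to do after a competition vs the 1-Step Lookahead Agent.

    Returns a (decision, new_best_win_rate) tuple where decision is:
      - "promote":  new best win rate, so the current weights become the old (stable) weights
      - "rollback": win rate fell max_drop or more below the best, so the old weights are restored
      - "keep":     neither condition holds, so training continues unchanged
    """
    if win_rate > best_win_rate:
        return "promote", win_rate
    if win_rate <= best_win_rate - max_drop:
        return "rollback", best_win_rate
    return "keep", best_win_rate


# Example: starting from a best win rate of 0.5
print(checkpoint_decision(0.60, 0.5))  # new best
print(checkpoint_decision(0.45, 0.5))  # small drop, nothing happens
print(checkpoint_decision(0.40, 0.5))  # large drop, undo the latest updates
```

Note that a rollback leaves the best win rate untouched, so the agent has to beat the same record again before the old weights are replaced.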

**TRAINING:**
   - We trained for approx. 100k updates (320 iterations of 315 mini-batch updates each)
   - The learning hyperparameters are:
       - buffer capacity = 2000
       - ppo epochs = 5
       - c1 = 0.75
       - c2 = 0.04
       - learning rate = 1e-4
       - batch size = 32
       - weight decay (L2 regularization) = 5e-5
       - discount factor (gamma) = 0.95
       - loss function (critic) = Smooth L1
   - Every 5 iterations (`comp_every`), the PPO agent competes against:
       - the Random Agent
       - the older network
       - the 1-Step Lookahead Agent
<br>
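
To make the roles of the clip parameter, `c1` and `c2` concrete, here is a minimal per-sample sketch of the PPO loss in plain Python (no PyTorch). It is an illustration under stated assumptions, not the notebook's implementation: the function names are hypothetical, the clip value 0.2 comes from `clip_param` in the hyperparameter cell, and the critic term follows the standard Smooth L1 definition with `beta = 1`.

```python
import math


def smooth_l1(prediction, target):
    """Smooth L1 (Huber-style) loss for one sample, with beta = 1."""
    diff = abs(prediction - target)
    return 0.5 * diff ** 2 if diff < 1.0 else diff - 0.5


def ppo_sample_loss(logp_new, logp_old, advantage, value_pred, value_target,
                    entropy, clip_param=0.2, c1=0.75, c2=0.04):
    """Hypothetical helper: PPO loss for a single transition.

    loss = -clipped_surrogate + c1 * critic_loss - c2 * entropy
    """
    # Probability ratio pi_new / pi_old, computed from log-probabilities
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = min(max(ratio, 1.0 - clip_param), 1.0 + clip_param)
    # Pessimistic (minimum) of the unclipped and clipped objectives
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    critic_loss = smooth_l1(value_pred, value_target)
    return -surrogate + c1 * critic_loss - c2 * entropy


# A ratio of 1.5 with a positive advantage is clipped to 1.2
print(ppo_sample_loss(math.log(1.5), 0.0, 1.0, 0.0, 0.0, 0.0))
# With a negative advantage the unclipped (more pessimistic) term is used
print(ppo_sample_loss(math.log(1.5), 0.0, -1.0, 0.0, 0.0, 0.0))
```

The clipping removes the incentive to push the new policy far from the one that collected the data, while `c1` weighs the critic error and `c2` rewards higher entropy (more exploration).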

**PPO RESULTS:**
   - Our best PPO agent beats the 1-Step Lookahead Agent **≈84%** of the time
   - The weights of the model are saved in '*src/models/saved_models/best_ppo.pt*'
   - The training hyperparameters are saved in '*src/models/saved_models/best_ppo_hparams.json*'
   - Plots of the training losses
   - Plots of the average game length in self-play games
   - Plots of the evolution of the win rate vs 1StepLA

## 1) Imports

In [None]:
import copy
import os
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
from torchsummary import summary
import matplotlib.pyplot as plt
from torch.utils.data import BatchSampler, SubsetRandomSampler

In [None]:
### YOUR PATH HERE
code_dir = '/home/marc/Escritorio/RL-connect4/'

if os.path.isdir(code_dir):
    # local environment
    os.chdir(code_dir)
    print(f"directory -> '{code_dir }'")
else:
    # google colab environment
    if os.path.isdir('./src'):
        print("'./src' dir already exists")
    else:  # not unzipped yet
        !unzip -q src.zip
        print("'./src.zip' file successfully unzipped")

In [None]:
from src.agents.baselines.random_agent import RandomAgent
from src.agents.baselines.n_step_lookahead_agent import NStepLookaheadAgent
from src.models.custom_network import CustomNetwork
from src.agents.trainable.pg_agent import PGAgent
from src.environment.connect_game_env import ConnectGameEnv
from src.data.replay_memory import ReplayMemory
from src.eval.competition import competition

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
timestamp = datetime.now().strftime("%d%m_%H%M")
print(f"'{timestamp}'")

In [None]:
SAVE_MODELS = False  # if False, it is debug mode 

## 2) Hyper parameters
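
As a rough sanity check on the training length, the number of optimizer updates implied by the values in this section can be worked out as follows. The sketch assumes that the buffer is full at every iteration and that the last, smaller mini-batch of each epoch is kept rather than dropped.

```python
import math

# Values taken from the hyperparameter dictionary of this notebook
buffer_capacity = 2000
batch_size = 32
ppo_epoch = 5
num_iterations = 320

# 2000 / 32 = 62.5, so each epoch has 62 full mini-batches plus one partial one
batches_per_epoch = math.ceil(buffer_capacity / batch_size)
updates_per_iteration = ppo_epoch * batches_per_epoch
total_updates = num_iterations * updates_per_iteration

print(batches_per_epoch, updates_per_iteration, total_updates)
```

With these settings a full run performs roughly 100k gradient updates.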

In [None]:
hparams = {
    # environment, data, memory
    'buffer_capacity': 2000,
    'reward_backprop_exponent': 3,
    
    # agent properties and model architecture
    'avg_symmetric_probs': True,
    'model_arch_path': './src/models/architectures/cnet128.json',
    'pretrained_model_weights': './src/models/saved_models/supervised_cnet128.pt',
    'freeze_conv_block': True,
    
    # Information displayed while training
    'loss_log_every': 1,
    'comp_every': 5,
    'vs_1StepLA_win_rate_decrease_to_undo_updates': 0.08,
    'moving_avg': 100,
        
    # PPO hyperparameters
    'num_iterations': 320,
    'clip_param': 0.2,
    'ppo_epoch': 5,
    'c1': 0.75,
    'c2': 0.04,
    
    # Training loop params
    'batch_size': 32,
    'gamma' : 0.95,
    'weight_decay': 5e-5,
    'lr': 1e-4
}

## 3) PPO Agent

In [None]:
def load_state_dict(from_: nn.Module, to_: nn.Module) -> None:
    """
    Copies the weights from the module 'from_' to the module 'to_'
    Loading the state_dict from a module does not keep the
    convolutional_block frozen, so it has to be done manually.
    """

    to_.load_state_dict(from_.state_dict())
    if hparams['freeze_conv_block']:
        for param in to_.conv_block.parameters():
            param.requires_grad = False

            

def create_model() -> nn.Module:
    """
    Create a PPO network following the architecture in hparams['model_arch_path']
    and initialize it with the weights stored in hparams['pretrained_model_weights'].
    The convolutional block is frozen when hparams['freeze_conv_block'] is True.
    """

    ppo_net = CustomNetwork.from_architecture(
        file_path=hparams['model_arch_path']
    ).to(device)
    
    ppo_net.load_weights(hparams['pretrained_model_weights'])
    
    # Freeze the feature extractor only when requested, mirroring load_state_dict()
    if hparams['freeze_conv_block']:
        for param in ppo_net.conv_block.parameters():
            param.requires_grad = False
    
    return ppo_net


def create_agent():
    """
    Create a new PPO Agent
    """

    model_ = create_model()
    agent_ = PGAgent(
        model=model_,
        stochastic_mode=True,
        avg_symmetric_probs=hparams['avg_symmetric_probs'],
        name='PPO Agent'
    )
    return agent_

In [None]:
ppo_agent = create_agent()

print("model device is cuda?", next(ppo_agent.model.parameters()).is_cuda)
print()
print(summary(ppo_agent.model, input_size=ppo_agent.model.input_shape))

In [None]:
old_ppo_agent = create_agent()
load_state_dict(from_=ppo_agent.model, to_=old_ppo_agent.model)
old_ppo_agent.model.eval()

In [None]:
agent_name = ppo_agent.name.replace(' ', '_')
save_best_vs_1StepLA_file = f'{agent_name}_'+'{win_rate}_vs_1StepLA_'+f'{timestamp}.pt'

print('"' + save_best_vs_1StepLA_file + '"')

## 4) Experience Buffer

In [None]:
buffer = ReplayMemory(
    capacity=hparams['buffer_capacity'],
    reward_backprop_exponent=hparams['reward_backprop_exponent']
)

## 5) Prepare the training loop

In [None]:
critic_loss_func = nn.SmoothL1Loss()

optimizer = torch.optim.Adam(
    params=ppo_agent.model.parameters(), 
    lr=hparams['lr'],
    weight_decay=hparams['weight_decay']
)

In [None]:
def compute_avg_target_value(policy_: nn.Module, states_: torch.Tensor) -> torch.Tensor:
    """
    Compute the value of the given states.
    NOTE: output does not contain gradient, it is part of the target
    It takes advantage of the board symmetry.
    """

    with torch.no_grad():
        _, v = policy_(states_)

    sym_states = torch.flip(states_, dims=[-1])
    with torch.no_grad():
        _, sym_v = policy_(sym_states)
    
    avg_v = (v + sym_v) / 2
    return avg_v

In [None]:
def training_step(policy_, 
                  optimizer_,
                  buffer_, 
                  hparams_, 
                  critic_loss_func_):

    gamma = hparams_['gamma']
    ppo_epoch = hparams_['ppo_epoch']
    batch_size = hparams_['batch_size']
    clip_param = hparams_['clip_param']
    c1 = hparams_['c1']
    c2 = hparams_['c2']  # entropy coefficient
    
    policy_.train()
    
    data = copy.deepcopy(buffer_.all_data())
    transitions = buffer_.Transition(*zip(*data))

    tuple_state_batch = tuple([policy_.obs_to_model_input(obs=s)
                               for s in transitions.state])
    tuple_next_state_batch = tuple([policy_.obs_to_model_input(obs=s_)
                                    for s_ in transitions.next_state])
    
    state_batch = torch.cat(tuple_state_batch).float().to(device)
    next_state_batch = torch.cat(tuple_next_state_batch).float().to(device)
    action_batch = torch.tensor(transitions.action, dtype=torch.long, device=device)
    reward_batch = torch.tensor(transitions.reward, dtype=torch.float, device=device).view(-1,1)
    not_done_mask = 1 - torch.tensor(transitions.done, device=device, dtype=torch.int).view(-1,1)

    
    old_a_logp = torch.tensor(transitions.log_prob, 
                              dtype=torch.float,
                              device=device).view(-1, 1)

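    with torch.no_grad():
        # One-step TD target. The bootstrap term is subtracted, presumably because
        # next_state is evaluated from the opponent's point of view (zero-sum self-play).
        target_v = reward_batch - gamma*not_done_mask*compute_avg_target_value(policy_, next_state_batch)
        # Advantage estimate: TD target minus the current value estimate.
        adv = target_v - compute_avg_target_value(policy_, state_batch)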
        
    total_losses, policy_losses, value_losses, entropies, ratios = [], [], [], [], []

    for epoch in range(ppo_epoch):
        # Sample mini-batch indices over the transitions actually stored in the batch
        for index in BatchSampler(SubsetRandomSampler(range(len(action_batch))), batch_size, False):
            # Single forward pass: policy logits and state values for the mini-batch
            logits, values = policy_(state_batch[index])
            dist = torch.distributions.Categorical(logits=logits)
            entropy = dist.entropy()
            
            a_logp = dist.log_prob(action_batch[index]).unsqueeze(-1)
            a_logp_old = old_a_logp[index]

            # Compute ratio: pi/pi_old = e^(ln pi - ln pi_old)
            ratio = torch.exp(a_logp - a_logp_old)

            surr1 = ratio * adv[index]

            surr2 = torch.clamp(ratio, 1.0-clip_param, 1.0+clip_param) * adv[index]

            # NOTE: 'policy_loss' holds the clipped surrogate objective (to be maximized);
            # its sign is flipped when the total loss is built below.
            policy_loss = torch.min(surr1, surr2).mean()
            value_loss = critic_loss_func_(values, target_v[index])
            entropy = entropy.mean()

            loss = - policy_loss + c1*value_loss - c2*entropy

            optimizer_.zero_grad()
            loss.backward()
            optimizer_.step()
            
            total_losses.append(loss.item())
            policy_losses.append(-policy_loss.item())
            value_losses.append(value_loss.item())
            entropies.append(entropy.item())
            ratios.append(ratio.mean().item())

    return total_losses, policy_losses, value_losses, entropies, ratios

## 6) Training loop

In [None]:
history = {'total_losses': [], 'policy_losses': [], 'value_losses': [], 
           'entropies': [], 'ratios': [],
           'vs_random_win_rate': [], 'vs_random_avg_game_len': [],
           'vs_1StepLA_win_rate': [], 'vs_1StepLA_avg_game_len': [],
           'vs_old_self_win_rate': [], 'vs_old_self_avg_game_len': [],
           'comp_every': hparams['comp_every'], 'comp_n_episodes': 100,
          }

vs_1StepLA_best_win_rate = 0.5

if not os.path.exists('checkpoints'):
    os.makedirs('checkpoints')

env = ConnectGameEnv()
comp_env = ConnectGameEnv()

random_opponent = RandomAgent()
oneStepLA = NStepLookaheadAgent(n=1, prefer_central_columns=True)

eps_ = np.finfo(np.float32).eps.item()
n_updates = 0
for i_iter in range(hparams['num_iterations']):
    
    buffer.reset()
    ppo_agent.model.train()
    
    while len(buffer) != buffer.capacity:
        buffer.push_self_play_episode_transitions(
            agent=ppo_agent,
            env=env,
            init_random_obs=True,
            push_symmetric=True,
            exploration_rate=0.10,
        )
    
    total_losses, policy_loss, value_loss, avg_entropy, ratios = training_step(
        policy_=ppo_agent.model, 
        optimizer_=optimizer, 
        buffer_=buffer, 
        hparams_=hparams,
        critic_loss_func_=critic_loss_func,
    )
    n_updates += len(total_losses)

    history['total_losses'].extend(total_losses)
    history['policy_losses'].extend(policy_loss)
    history['value_losses'].extend(value_loss)
    history['entropies'].extend(avg_entropy)
    history['ratios'].extend(ratios)
    
    if (i_iter+1) % hparams['loss_log_every'] == 0:
        last_total_loss_vals = history['total_losses'][-hparams['moving_avg']:]
        last_policy_loss_vals = history['policy_losses'][-hparams['moving_avg']:]
        last_value_loss_vals = history['value_losses'][-hparams['moving_avg']:]
        last_avg_entropy_vals = history['entropies'][-hparams['moving_avg']:]
        last_ratio_vals = history['ratios'][-hparams['moving_avg']:]
        print(f"Iter: {i_iter+1}/{hparams['num_iterations']}   " +
              f"update: {n_updates}   " +
              f"Loss: {round(np.mean(last_total_loss_vals), 4)}   " +
              f"P_Loss: {round(np.mean(last_policy_loss_vals), 4)}   " +
              f"V_Loss: {round(np.mean(last_value_loss_vals), 4)}   " +
              f"S: {round(np.mean(last_avg_entropy_vals), 3)}   ")
    
    # compete against the opponents to measure the performance
    if (i_iter+1) % hparams['comp_every'] == 0:
        # compete against the Random Agent
        ppo_agent.model.eval()
        with torch.no_grad():
            res1, o1 = competition(
                env=comp_env, 
                agent1=ppo_agent, 
                agent2=random_opponent,
                progress_bar=False,
            )
        win_rate_rand = round(res1['win_rate1'], 3)
        print(f"    {win_rate_rand} vs. RAND" +
              f"    avg_len={round(res1['avg_game_len'], 2)}")
        history['vs_random_win_rate'].append(win_rate_rand)
        history['vs_random_avg_game_len'].append(res1['avg_game_len'])

        # compete against the old (stable) version of the network
        old_ppo_agent.model.eval()
        with torch.no_grad():
            res2, o2 = competition(
                env=comp_env, 
                agent1=ppo_agent, 
                agent2=old_ppo_agent,
                progress_bar=False,
            )
        win_rate_self = round(res2['win_rate1'], 3)
        print(f"    {win_rate_self} vs. SELF" +
              f"    avg_len={round(res2['avg_game_len'], 2)}")
        history['vs_old_self_win_rate'].append(win_rate_self)
        history['vs_old_self_avg_game_len'].append(res2['avg_game_len'])

        # compete against the 1-Step Lookahead Agent
        with torch.no_grad():
            res3, o3 = competition(
                env=comp_env, 
                agent1=ppo_agent,
                agent2=oneStepLA,
                progress_bar=False,
            )
        win_rate_1StepLA = round(res3['win_rate1'], 3)
        print(f"    {win_rate_1StepLA} vs. 1StepLA" +
              f"    avg_len={round(res3['avg_game_len'], 2)}")
        history['vs_1StepLA_win_rate'].append(win_rate_1StepLA)
        history['vs_1StepLA_avg_game_len'].append(res3['avg_game_len'])

        if win_rate_1StepLA > vs_1StepLA_best_win_rate:
            vs_1StepLA_best_win_rate = win_rate_1StepLA
            load_state_dict(from_=ppo_agent.model, to_=old_ppo_agent.model)
            old_ppo_agent.model.eval()
            if SAVE_MODELS:
                # round() avoids float truncation (e.g. int(0.58*100) == 57)
                file_name = "checkpoints/" + save_best_vs_1StepLA_file.format(
                    win_rate=int(round(win_rate_1StepLA*100))
                )
                ppo_agent.model.save_weights(
                    file_path=file_name,
                    training_hparams=hparams,
                )
                print(f"        new best {file_name} is saved!!!")
        elif win_rate_1StepLA <= vs_1StepLA_best_win_rate-hparams['vs_1StepLA_win_rate_decrease_to_undo_updates']:
            load_state_dict(from_=old_ppo_agent.model, to_=ppo_agent.model)
            print("        undoing last updates...")

## 7) Plot training results
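
The loss curve in this section is smoothed with a moving average. The following dependency-free sketch shows what that smoothing computes; `moving_average_valid` is a hypothetical name used only for illustration. For inputs at least as long as the window it returns the same `len(x) - w + 1` values as `np.convolve(x, np.ones(w), 'valid') / w`. Unlike the NumPy version, it returns an empty list when the input is shorter than the window, which makes that case easy to spot.

```python
def moving_average_valid(values, window):
    """Hypothetical helper: mean of every full window of `window` consecutive values."""
    if window <= 0:
        raise ValueError("window must be positive")
    if len(values) < window:
        return []  # no full window fits in the input

    running_sum = sum(values[:window])
    averages = [running_sum / window]
    for i in range(window, len(values)):
        # Slide the window one position: add the new value, drop the oldest one
        running_sum += values[i] - values[i - window]
        averages.append(running_sum / window)
    return averages


print(moving_average_valid([1, 2, 3, 4, 5], 3))
```

Because each output point needs a full window, the smoothed curve is `w - 1` points shorter than the raw one.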

In [None]:
def moving_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w

In [None]:
# Use the full loss history, not just the losses of the last iteration
data = moving_average(history['total_losses'][:100000], w=1000)
x_vals = [x/1000 for x in range(len(data))]

plt.plot(x_vals, data)
plt.title('PPO Training Loss')
plt.xlabel("updates (in thousands)")
plt.ylabel("loss")
#plt.gca().xaxis.set_major_locator(MultipleLocator(10))
#plt.gca().yaxis.set_major_locator(MultipleLocator(0.025))
plt.show()

In [None]:
data = history['vs_1StepLA_win_rate']
# One point per competition: place it at the number of updates completed so far,
# assuming competitions are evenly spaced over training.
updates_per_point = len(history['total_losses']) / max(len(data), 1)
x_vals = [(i + 1) * updates_per_point / 1000 for i in range(len(data))]

plt.plot(x_vals, data)
plt.title('PPO win rate vs 1StepLA')
plt.xlabel("updates (in thousands)")
plt.ylabel("win rate")
#plt.gca().xaxis.set_major_locator(MultipleLocator(10000))
plt.axhline(1, linestyle='--', alpha=0.4)
plt.axhline(0.5, linestyle='--', alpha=0.4)
plt.ylim(0.35, 1.09)
#plt.xlim(0, 105)
plt.show()

In [None]:
data = history['vs_old_self_avg_game_len']
# One point per competition: place it at the number of updates completed so far,
# assuming competitions are evenly spaced over training.
updates_per_point = len(history['total_losses']) / max(len(data), 1)
x_vals = [(i + 1) * updates_per_point / 1000 for i in range(len(data))]

plt.title('PPO self-play game length')
plt.plot(x_vals, data)
plt.xlabel("updates (in thousands)")
plt.ylabel("game length")
#plt.gca().xaxis.set_major_locator(MultipleLocator(10))
plt.axhline(42, linestyle='--', alpha=0.4)
plt.axhline(7, linestyle='--', alpha=0.4)
plt.ylim(-1, 45)
plt.show()