# Part 1) Supervised Learning Task. Training.

- **GOAL**: train a Neural Network to predict the actions of a mid-level player (classification task)
<br>
- **ABOUT THE TRAINING DATA:**
    - Our mid-level player is a _1StepLookaheadAgent_ instance
    - The training data is a set of (_obs_, _action_) pairs:
         - _obs_ is a game board where 1 is the active player and -1 is the opponent
         - _action_ is the column that our mid-level player (the _1StepLookaheadAgent_) chooses to play in _obs_
    - Our supervised learning task: given an '*obs*' (game board), predict '*action*' (classification)
    - To learn more about how the dataset was created, refer to '*src/data/part1_dataset_generator.ipynb*'
    - The dataset used here can be found in: '*src/data/part1_data/part1_supervised_learning_data.txt*'
<br> 
- **DATA PREPROCESSING:**
    - To understand our implementation refer to *'src/models/custom_network.py'*
    - the state of the environment is a _6 x 7_ board with values: {0: empty, 1: player, -1: opponent}
    - The input of the model is a one-hot encoded version of shape *2 x 6 x 7*
        - first channel: {1: player, 0: empty or opponent}
        - second channel: {1: opponent, 0: empty or player}
    - In both channels, the empty positions that can be filled in the current turn (the lowest empty cell of each non-full column) are set to -1
<br>
- **NETWORK ARCHITECTURE:**
    - To understand our implementation refer to *'src/models/custom_network.py'*
    - We have implemented a basic network structure that allows different degrees of complexity
    - The number of trainable parameters indicates the complexity of the network
    - The name of this general architecture is _'CNET\<*N*\>_', where 'N' is the number of convolutional filters and hidden units
    - **!!!!!** We implemented a **two-headed network architecture** because some of the Deep Reinforcement Learning algorithms that we will implement later require two different prediction heads and are based on the network trained here. However, **the second head is NOT trained for this supervised learning task, only the first head**.
    - The kernel sizes (4x4 and 2x2) and the size of the output (7) are fixed (do not depend on '*N*')
    - For instance, the hidden layers of '_CNET*128*_' are:
        - 2D convolutional layer with *128* 4x4 filters (backbone network)
        - 2D convolutional layer with *128* 2x2 filters (backbone network)
        - fully connected layer with *128* units (backbone network)
        - fully connected layer with *128* units followed by a 7-unit output layer (first prediction head) -> outputs the POLICY
        - [NOT TRAINED HERE] fully connected layer with *128* units followed by a 1-unit output layer (second prediction head)
    - For the general case _'CNET\<*N*\>_', change the number '128' to \<N\>
<br>
- **TRAINING STEPS:**
    - Load the 200k data pairs (_obs, action_) from '*src/data/part1_data/part1_supervised_learning_data.txt*'
    - Training-validation-test data split: 160k + 20k + 20k samples (respectively)
    - The following training parameters are the same for all the networks:
        - number of epochs = 20
        - batch size = 64
        - learning rate = 5e-4
        - Loss function = Cross Entropy Loss
        - weight decay (L2 regularization) = 2e-3
    - Every 600 updates, we evaluate the model on the validation data
    - When the training ends, we evaluate the model on the test data
         
- **RESULTS**:
    - We used the network architecture 'CNET128' (defined above and in the code as well)
    - After training for 20 epochs and using the training hyper-parameter values described above:
        - ~87% training accuracy
        - ~85% validation accuracy
        - ~85% test accuracy 
        - minimal overfitting (~2% gap between training and test accuracy)
    - At the end of the training loop, the training and validation accuracies and losses are plotted 
    - The architecture CNET128 is saved in '*src/models/architectures/cnet128.json*'
    - The final weights (after the last epoch) are saved in: '*src/models/saved_models/supervised_cnet128.pt*'
    - The training hyper-parameters are saved in: '*src/models/saved_models/supervised_cnet128_hparams.json*'
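The DATA PREPROCESSING encoding described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the notebook's `obs_to_model_input`: `encode_board` is a hypothetical helper, and it assumes row 0 is the top of the board and row 5 the bottom, so the playable cell of a column is its lowest empty one.

```python
import numpy as np


def encode_board(board: np.ndarray) -> np.ndarray:
    """Hypothetical sketch of the 2 x 6 x 7 encoding (not the repository's code).

    Assumes row 0 is the top of the board and row 5 the bottom,
    i.e. pieces fall towards higher row indices.
    """
    assert board.shape == (6, 7)
    player = (board == 1).astype(np.float32)     # channel 0: 1 where the active player has a piece
    opponent = (board == -1).astype(np.float32)  # channel 1: 1 where the opponent has a piece
    encoded = np.stack([player, opponent])       # shape (2, 6, 7)

    for col in range(board.shape[1]):
        empty_rows = np.flatnonzero(board[:, col] == 0)
        if empty_rows.size > 0:
            # the lowest empty cell of a non-full column is playable this turn
            encoded[:, empty_rows[-1], col] = -1.0
    return encoded
```

On an empty board, every cell of the bottom row is playable, so row 5 is -1 in both channels and everything else is 0.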

## 1) Imports

In [None]:
import os
import random
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import BatchSampler, SubsetRandomSampler
from torchsummary import summary
import matplotlib.pyplot as plt

In [None]:
### YOUR PATH HERE
code_dir = '/home/marc/Escritorio/RL-connect4/'

if os.path.isdir(code_dir):
    # local environment
    os.chdir(code_dir)
    print(f"directory -> '{code_dir}'")
else:
    # google colab environment
    if os.path.isdir('./src'):
        print("'./src' dir already exists")
    else:  # not unzipped yet
        !unzip -q src.zip
        print("'./src.zip' file successfully unzipped")

In [None]:
from src.models.custom_network import CustomNetwork
from src.environment.connect_game_env import ConnectGameEnv

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## 2) Hyper parameters

In [None]:
hparams = {
    # where the data is located
    'data_file_path': './src/data/part1_data/part1_supervised_learning_data.txt',
    
    # network architecture
    'network_arch': {
        'conv_block': [[128,4,0], 'relu', [128,2,0], 'relu'],
        'fc_block': [128, 'relu'],
        'first_head': [128, 'relu', 7],
        'second_head': [128, 'relu', 1]
    },
    'model_name': 'cnet128',

    # train-test split (sizes)
    'train_size': 160000,
    'val_size': 20000,
    'test_size': 20000,

    # Training params
    'num_epochs': 20,
    'batch_size': 64,
    'loss_log_every': 200,
    'validation_every': 600,
    'weight_decay': 2e-3,
    'lr': 5e-4,
    
    # save models
    'save_model': True,  # set to False when debugging so no weights are written
    'save_model_file_path': './src/models/saved_models/supervised_{model_name}.pt'
}


hparams['save_model_file_path'] = hparams['save_model_file_path'].format(model_name=hparams['model_name'])
hparams['n_samples'] = hparams['train_size'] + hparams['val_size'] + hparams['test_size']

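With these values, the training schedule works out as follows (plain arithmetic on the hyper-parameters above; `step_count` is never reset between epochs, so validation and logging are counted over all updates):

```python
import math

# values copied from the hparams dictionary above
train_size, batch_size, num_epochs = 160_000, 64, 20
loss_log_every, validation_every = 200, 600

steps_per_epoch = math.ceil(train_size / batch_size)  # the last batch may be smaller
total_updates = steps_per_epoch * num_epochs
n_validations = total_updates // validation_every     # validate() runs when step_count % 600 == 0
n_loss_logs = total_updates // loss_log_every

print(steps_per_epoch, total_updates, n_validations, n_loss_logs)  # 2500 50000 83 250
```

Since 50000 is not a multiple of 600, the last validation happens at update 49800 rather than at the very end of training; the separate `test()` call evaluates the final model.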
## 3) Create the model

In [None]:
model = CustomNetwork(**hparams['network_arch']).to(device)
summary(model, input_size=model.input_shape, device=device.type)  # prints the table itself; torchsummary defaults to device='cuda'

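For readers who want the CNET\<N\> layout in plain PyTorch, here is a minimal sketch. `CNetSketch` is a hypothetical class written from the description at the top of the notebook, not `CustomNetwork`; it assumes each `conv_block` entry such as `[128,4,0]` means (filters, kernel size, padding).

```python
import torch
from torch import nn


class CNetSketch(nn.Module):
    """Hypothetical two-headed CNET<N> sketch (not the repository's CustomNetwork)."""

    def __init__(self, n: int = 128, n_actions: int = 7):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(2, n, kernel_size=4),  # (2, 6, 7) -> (n, 3, 4), no padding
            nn.ReLU(),
            nn.Conv2d(n, n, kernel_size=2),  # (n, 3, 4) -> (n, 2, 3)
            nn.ReLU(),
            nn.Flatten(),                    # -> n * 2 * 3 features
            nn.Linear(n * 2 * 3, n),
            nn.ReLU(),
        )
        # first head: logits over the 7 columns (the only head trained in Part 1)
        self.first_head = nn.Sequential(nn.Linear(n, n), nn.ReLU(), nn.Linear(n, n_actions))
        # second head: a single output, reserved for the later RL algorithms
        self.second_head = nn.Sequential(nn.Linear(n, n), nn.ReLU(), nn.Linear(n, 1))

    def forward(self, x: torch.Tensor):
        features = self.backbone(x)
        return self.first_head(features), self.second_head(features)
```

Counting weights and biases layer by layer, this sketch has 12N² + 45N + 8 trainable parameters, i.e. 202,376 for N = 128, which can be compared with the total reported by `summary()` above.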
In [None]:
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=hparams['lr'],
    weight_decay=hparams['weight_decay']
)

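For reference, `torch.optim.Adam` implements `weight_decay` by adding `weight_decay * param` to each gradient before the moment estimates are updated. That is the same as adding an L2 penalty `0.5 * weight_decay * ||w||^2` to the loss (`torch.optim.AdamW`, by contrast, decouples the decay from the gradient). The toy check below compares both formulations on an arbitrary quadratic loss:

```python
import torch

torch.manual_seed(0)
w0 = torch.randn(5)
x = torch.randn(5)
wd, lr = 2e-3, 5e-4  # same values as in hparams


def data_loss(w):
    return (w @ x) ** 2  # arbitrary toy loss


# (a) Adam's built-in weight_decay
w_a = w0.clone().requires_grad_(True)
opt_a = torch.optim.Adam([w_a], lr=lr, weight_decay=wd)

# (b) plain Adam on the loss plus an explicit L2 penalty 0.5 * wd * ||w||^2
w_b = w0.clone().requires_grad_(True)
opt_b = torch.optim.Adam([w_b], lr=lr)

for _ in range(3):
    opt_a.zero_grad()
    data_loss(w_a).backward()
    opt_a.step()

    opt_b.zero_grad()
    (data_loss(w_b) + 0.5 * wd * (w_b ** 2).sum()).backward()
    opt_b.step()

print(torch.allclose(w_a, w_b))  # True
```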
## 4) Preprocess the data

In [None]:
# read the data
sep = ';'

with open(hparams['data_file_path'], 'r') as file:
    lines = file.read().split(sep)[:-1]  # last line is ''
assert hparams['n_samples'] <= len(lines), f'not enough data: {hparams["n_samples"]} > {len(lines)}'

obs_list, actions_list = [], []
for line in lines[:hparams['n_samples']]:
    flat_obs = [int(i) for i in line[:-1]]
    obs = np.array(flat_obs).reshape((6,7))
    obs[obs==2] = -1
    obs_list.append(obs)
    actions_list.append(int(line[-1]))

print(f"{len(obs_list)} data samples loaded from '{hparams['data_file_path']}'")

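Judging from the loop above, each `;`-separated record is expected to contain 42 board digits (row-major 6 x 7, with 0 = empty, 1 = player and 2 = opponent) followed by a single digit with the chosen column. A hypothetical `parse_record` helper that makes this format explicit and checks the record length:

```python
import numpy as np


def parse_record(record: str):
    """Hypothetical helper: parse one record into (board, action)."""
    assert len(record) == 43, f'expected 42 board digits + 1 action digit, got {len(record)}'
    board = np.array([int(c) for c in record[:-1]]).reshape((6, 7))
    board[board == 2] = -1  # the file stores the opponent as 2
    action = int(record[-1])
    return board, action


# a synthetic record: one opponent piece in row 5, column 3, and action = column 3
board, action = parse_record('0' * 35 + '0002000' + '3')
print(board[5].tolist(), action)  # [0, 0, 0, -1, 0, 0, 0] 3
```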
In [None]:
# turn the data points into pytorch tensors to feed the model

obs_tensor = torch.cat([model.obs_to_model_input(obs=o) for o in obs_list])
actions_tensor = torch.tensor(actions_list, dtype=torch.long)

# train-val-test split
data_idx = random.sample(range(len(obs_list)), k=hparams['n_samples'])  # random permutation of the sample indices
train_end = hparams['train_size']
val_end = train_end + hparams['val_size']
train_idx = data_idx[:train_end]
val_idx = data_idx[train_end:val_end]
test_idx = data_idx[val_end:val_end + hparams['test_size']]

# train_x and train_y will be moved to device in batches while training
train_x = obs_tensor[train_idx]
train_y = actions_tensor[train_idx]

val_x = obs_tensor[val_idx].to(device)
val_y = actions_tensor[val_idx].to(device)

test_x = obs_tensor[test_idx].to(device)
test_y = actions_tensor[test_idx].to(device)

print(f'train: {train_x.shape},  {train_y.shape}')
print(f'val: {val_x.shape},  {val_y.shape}')
print(f'test:  {test_x.shape},  {test_y.shape}')

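The split above uses Python's `random` module without a seed, so it changes from run to run. If a reproducible split is wanted, one option is a seeded permutation; `split_indices` is a hypothetical helper, not part of the repository:

```python
import torch


def split_indices(n_samples: int, sizes, seed: int = 0):
    """Hypothetical helper: shuffle range(n_samples) and cut it into consecutive chunks."""
    assert sum(sizes) <= n_samples
    generator = torch.Generator().manual_seed(seed)  # fixed seed -> same split every run
    perm = torch.randperm(n_samples, generator=generator)
    chunks, start = [], 0
    for size in sizes:
        chunks.append(perm[start:start + size])
        start += size
    return chunks


train_idx, val_idx, test_idx = split_indices(200_000, (160_000, 20_000, 20_000))
print(len(train_idx), len(val_idx), len(test_idx))  # 160000 20000 20000
```

The returned index tensors can be used directly to index `obs_tensor` and `actions_tensor`.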
## 5) Training loop

In [None]:
loss_func = torch.nn.CrossEntropyLoss()

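`CrossEntropyLoss` expects raw, unnormalised scores (logits) together with class indices of dtype `long`, and applies the log-softmax itself; this is consistent with the first head ending in a plain 7-unit layer without a softmax. The equivalence can be checked directly:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 7)           # raw scores for 7 columns, batch of 4
labels = torch.tensor([0, 3, 6, 2])  # chosen columns (class indices, dtype long)

ce = torch.nn.CrossEntropyLoss()(logits, labels)
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(torch.allclose(ce, manual))    # True
```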
In [None]:
def accuracy(output: torch.Tensor, labels: torch.Tensor) -> float:
    """
    Computes the accuracy of the given predictions
    
    :param output: network output (predictions)
    :param labels: ground truth
    """
    
    pred_y = torch.max(output, 1)[1].data.squeeze()
    acc = torch.sum(pred_y == labels).item() / float(labels.size(0))
    return acc


def validate():
    """
    Validate the model on the validation data: global val_x and val_y
    Returns the validation loss and validation accuracy
    """
    
    is_training = model.training
    model.eval()
    
    with torch.no_grad():
        val_pred, _ = model(val_x)
        val_loss = loss_func(val_pred, val_y)
        val_acc = accuracy(output=val_pred, labels=val_y)
    
    if is_training:
        model.train()
    
    return val_loss.item(), val_acc


def test():
    """
    Test the model on the test data: global test_x and test_y
    Returns the test loss and test accuracy
    """
    
    is_training = model.training
    model.eval()
    
    with torch.no_grad():
        test_pred, _ = model(test_x)
        test_loss = loss_func(test_pred, test_y)
        test_acc = accuracy(output=test_pred, labels=test_y)
    
    if is_training:
        model.train()
    
    return test_loss.item(), test_acc

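`validate()` and `test()` differ only in the tensors they read. A possible refactor, sketched as a hypothetical `evaluate` helper, takes the data as arguments; `logits.argmax(dim=1)` yields the same predicted columns as `torch.max(output, 1)[1]`:

```python
import torch


def evaluate(net, x, y, loss_fn):
    """Hypothetical helper covering both validate() and test(): returns (loss, accuracy)."""
    was_training = net.training
    net.eval()
    with torch.no_grad():
        logits, _ = net(x)  # only the first head is scored
        loss = loss_fn(logits, y).item()
        acc = (logits.argmax(dim=1) == y).float().mean().item()
    net.train(was_training)  # restore the previous train/eval mode
    return loss, acc
```

In this notebook it would be called as `evaluate(model, val_x, val_y, loss_func)` or `evaluate(model, test_x, test_y, loss_func)`.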
In [None]:
history = {'loss': [], 'acc': [], 
           'val_loss': [], 'val_acc': [],
           'test_loss': 0, 'test_acc': 0}

env = ConnectGameEnv()

model.train()

num_epoch_steps = int(np.ceil(len(train_x)/hparams['batch_size']))
step_count = 0
for epoch in range(hparams['num_epochs']):
    
    epoch_step = 0
    index_list = SubsetRandomSampler(range(len(train_x)))
    for batch_index in BatchSampler(index_list, hparams['batch_size'], False):

        batch_x = train_x[batch_index].to(device)
        batch_y = train_y[batch_index].to(device)

        batch_pred, _ = model(batch_x)
        loss = loss_func(batch_pred, batch_y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        history['loss'].append(loss.item())
        history['acc'].append(accuracy(batch_pred.detach(), batch_y))

        step_count += 1
        epoch_step += 1

        if step_count % hparams['loss_log_every'] == 0:
            avg_loss = float(np.mean(history['loss'][-hparams['loss_log_every']:]))
            avg_acc = float(np.mean(history['acc'][-hparams['loss_log_every']:]))
            print(f'Epoch: {epoch+1}/{hparams["num_epochs"]},    ' +
                  f'{epoch_step}/{num_epoch_steps} steps,    ' +
                  f'avg_loss={round(avg_loss, 3)},    ' +
                  f'avg_acc={round(avg_acc, 3)}')

        if step_count % hparams['validation_every'] == 0:
            val_loss, val_acc = validate()
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            print(f'  --> val_loss={round(val_loss, 3)},   val_acc={round(val_acc, 3)}')

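Because the loss is computed from the first output only, backpropagation never reaches the second head, which is why that head stays untrained here. A toy check with a hypothetical stand-in module (`ToyTwoHeads`, not `CustomNetwork`):

```python
import torch
from torch import nn


class ToyTwoHeads(nn.Module):
    """Hypothetical stand-in for a two-headed network (not CustomNetwork)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.first_head = nn.Linear(8, 3)
        self.second_head = nn.Linear(8, 1)

    def forward(self, x):
        h = torch.relu(self.backbone(x))
        return self.first_head(h), self.second_head(h)


torch.manual_seed(0)
toy = ToyTwoHeads()
first_out, _ = toy(torch.randn(16, 4))  # the second output is discarded, as in the loop above
toy_loss = nn.CrossEntropyLoss()(first_out, torch.randint(0, 3, (16,)))
toy_loss.backward()

print(toy.first_head.weight.grad is not None)  # True
print(toy.second_head.weight.grad is None)     # True: no gradient reaches the second head
```

`torch.optim.Adam` skips parameters whose `.grad` is `None`, so the second head's weights are not modified by the optimizer either, weight decay included.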
In [None]:
# test the final model

test_loss, test_acc = test()
history['test_loss'] = test_loss
history['test_acc'] = test_acc
print(f'test_loss: {round(test_loss, 4)}')
print(f'test_acc: {round(test_acc, 4)}')

In [None]:
# save the final model

if hparams['save_model']:
    model.save_weights(
        file_path=hparams['save_model_file_path'],
        training_hparams=hparams,
    )
    print('model saved!')

## 6) Plot training results

In [None]:
def moving_average(data, m):
    """Mean of every full window of length m (returns len(data) - m + 1 values)."""
    avg_data = []
    for i in range(len(data) - m + 1):
        avg_data.append(np.mean(data[i:i+m]))
    return avg_data

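A vectorised alternative uses `np.convolve` with a uniform kernel and `mode='valid'`, which averages every full length-m window and therefore returns `len(data) - m + 1` values; `moving_average_np` is a hypothetical name:

```python
import numpy as np


def moving_average_np(data, m):
    """Mean of every full window of length m, computed as a convolution with a flat kernel."""
    return np.convolve(np.asarray(data, dtype=float), np.ones(m) / m, mode='valid')


print(moving_average_np([1, 2, 3, 4, 5], m=2).tolist())  # [1.5, 2.5, 3.5, 4.5]
```

For the 50000 training updates and `moving = 50`, this gives 49951 averaged points.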
In [None]:
fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))

moving = 50

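# validate() runs after updates validation_every, 2 * validation_every, ...
max_val_x = len(history['val_loss']) * hparams['validation_every']
val_x_range = range(hparams['validation_every'], max_val_x + 1, hparams['validation_every'])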

avg_loss = moving_average(history['loss'], m=moving)

ax0.plot(range(len(avg_loss)), avg_loss, label='training loss')
ax0.plot(val_x_range, history['val_loss'], label='validation loss')
ax0.legend()
ax0.set_xlabel('updates')
ax0.set_ylabel('loss')
ax0.set_title('Training and Validation Losses')

avg_acc = moving_average(history['acc'], m=moving)
ax1.plot(range(len(avg_acc)), avg_acc, label='training acc')
ax1.plot(val_x_range, history['val_acc'], label='validation acc')
ax1.legend()
ax1.set_xlabel('updates')
ax1.set_ylabel('acc')
ax1.set_title('Training and Validation Accuracies')
ax1.axhline(y=1, linestyle='--', alpha=0.3, color='black')
ax1.set_ylim(-0.1, 1.1)