In [4]:
import numpy as np
import pandas as pd
import random
import warnings

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import evaluate
from easyfsl.methods import FewShotClassifier, PrototypicalNetworks

from statistics import mean

from get_processed_data import get_processed_data
from FSLMethods import form_datasets, training_epoch, evaluate_model
from FSLDataset import FSLDataset
from FSLNetworks import DummyNetwork

warnings.filterwarnings('ignore')

### Splitting data

In [5]:
## Unpack the processed data: the frame itself plus the X / y train, validation and test splits
(
    df,
    X_train, y_train,
    X_val, y_val,
    X_test, y_test,
) = get_processed_data()

## Datasets need to be a FewShotDataset / torch Dataset exposing .get_labels,
## so the splits are handed to form_datasets before any sampler sees them
train_set, validation_set = form_datasets(
    X_train, y_train,
    X_val, y_val,
    X_test, y_test,
)

Training set shape: (12335, 55) (12335,)
Validation set shape: (1542, 55) (1542,)
Test set shape: (1542, 55) (1542,)
Disclaimer: It is assumed that the label is the last column of the input dataframe... ...
Disclaimer: It is assumed that the label is the last column of the input dataframe... ...


### Model training (meta-learning / episodic training)

Episodic training simulates the few-shot learning scenario to train a prototypical network. Training data is organized into episodes that resemble few-shot tasks.

Set up

In [6]:
## Seed NumPy, PyTorch and Python's built-in RNG (in that order) with one shared value
random_seed = 0
for seed_rng in (np.random.seed, torch.manual_seed, random.seed):
    seed_rng(random_seed)

In [None]:
## Task geometry: each task contains N_WAY * (N_SHOT + N_QUERY) samples
N_WAY, N_SHOT, N_QUERY = 2, 2, 3

## Number of tasks each sampler is asked to produce (passed as n_tasks below)
N_TASKS_PER_EPOCH, N_VALIDATION_TASKS = 10, 100

## Samplers used to generate tasks; both share the same way / shot / query settings
task_geometry = dict(n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY)
train_sampler = TaskSampler(
    dataset=train_set,
    n_tasks=N_TASKS_PER_EPOCH,
    **task_geometry,
)
validation_sampler = TaskSampler(
    dataset=validation_set,
    n_tasks=N_VALIDATION_TASKS,
    **task_geometry,
)

## Loaders generate an iterable given a dataset and a sampler;
## each sampler supplies its own episodic collate function
train_loader = DataLoader(
    dataset=train_set,
    batch_sampler=train_sampler,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
validation_loader = DataLoader(
    dataset=validation_set,
    batch_sampler=validation_sampler,
    pin_memory=True,
    collate_fn=validation_sampler.episodic_collate_fn,
)

Initialize the loss function, model, optimizer, learning-rate scheduler, and TensorBoard writer

In [None]:
## Loss function
## NOTE(review): nn.CrossEntropyLoss expects unnormalized logits, while the model below is built with
## use_softmax = True - confirm training_epoch does not end up applying softmax twice
LOSS_FN = nn.CrossEntropyLoss()


## Model: prototypical network wrapped around the backbone
## TODO: What should the correct dimensions be?
n_input_features = len(X_train.columns)
backbone = DummyNetwork(in_dim=n_input_features, hidden_dim=256, out_dim=4)
model = PrototypicalNetworks(backbone, use_softmax=True)


## Optimizer: SGD with momentum and weight decay over the model's parameters
LEARNING_RATE, MOMENTUM, DECAY = 0.001, 0.9, 5e-4
train_optimizer = optim.SGD(
    params=model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=DECAY,
)


## Scheduler: scales the learning rate by gamma at the designated milestones
scheduler_milestones = [70, 85]
scheduler_gamma = 0.1
train_scheduler = MultiStepLR(
    optimizer=train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)


## Writer: TensorBoard event files go to log_dir
log_dir = 'fsl_logs'
tb_writer = SummaryWriter(log_dir=log_dir)


Train the model

In [None]:
N_EPOCHS = 100
log_update_frequency = 10


def _snapshot_state(module):
    """Return an independent copy of ``module``'s current state dict.

    ``state_dict()`` returns references to the live parameter tensors, which the
    optimizer keeps updating in place. Holding on to that dict directly means the
    "best" checkpoint silently follows the model and always equals the final
    weights. Cloning every entry freezes the values as they are at call time.
    """
    return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}


## Track best parameters (weights and biases) and performance of model
best_state = _snapshot_state(model)
best_f1_score = 0.0
best_recall = 0.0

for epoch in range(N_EPOCHS):
    print(f'Epoch: {epoch}')

    average_epoch_loss = training_epoch(model, train_loader, train_optimizer, LOSS_FN)

    actuals, predictions, _, _, recall, _, f1_score = evaluate_model(model, validation_loader)

    ## Best F1 is tracked for reporting only; the checkpoint is selected on recall below
    if f1_score > best_f1_score:
        best_f1_score = f1_score

    ## A strictly better validation recall becomes the new checkpoint (copied, not referenced)
    if recall > best_recall:
        best_recall = recall
        best_state = _snapshot_state(model)
        print("Ding ding ding! We found a new best model!")

    tb_writer.add_scalar("Train/loss", average_epoch_loss, epoch)
    tb_writer.add_scalar('F1', f1_score, epoch)
    tb_writer.add_scalar('Recall', recall, epoch)

    ## Update the scheduler such that it knows when to adjust the learning rate
    train_scheduler.step()

## Write any buffered scalars to disk so the final epochs show up in TensorBoard
tb_writer.flush()


## Retrieve the best model
missing_keys, unexpected_keys = model.load_state_dict(best_state)
print(f'Best f1-score after {N_EPOCHS} epochs of training: {best_f1_score}')
print(f'Best recall after {N_EPOCHS} epochs of training: {best_recall}')


### Model evaluation

In [None]:
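# TODO: evaluate the best model on the test set. `test_loader` is not defined in the cells shown above and must be built first.
# evaluate(model, test_loader)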