In [1]:
import numpy as np
import random
import time

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import evaluate
from easyfsl.methods import FewShotClassifier, PrototypicalNetworks

from statistics import mean

from get_processed_data import get_processed_data
from sampling import undersample, oversample, smote, ncr
from FSLMethods import training_epoch, evaluate_model
from FSLDataset import FSLDataset
from FSLNetworks import DummyNetwork

  warn(


### Splitting data

In [2]:
df, X_train, y_train, X_val, y_val, X_test, y_test = get_processed_data()


Training set shape: (12335, 55) (12335,)
Validation set shape: (1542, 55) (1542,)
Test set shape: (1542, 55) (1542,)


Sampling

In [3]:
def get_sampled_data (X_train, y_train, sampling_method):
    new_X_train, new_y_train = None, None
    
    if sampling_method == 'oversampling':
        new_X_train, new_y_train = oversample(X_train, y_train)
    elif sampling_method == 'undersampling':
        new_X_train, new_y_train = undersample(X_train, y_train)
    elif sampling_method == 'ncr':
        new_X_train, new_y_train = ncr(X_train, y_train)
    else:
        new_X_train, new_y_train = smote(X_train, y_train)

    return new_X_train, new_y_train

In [4]:
new_X_train, new_y_train = get_sampled_data(X_train, y_train, sampling_method = 'smote')

train_df = new_X_train.copy()
train_df['Fraud'] = new_y_train
# train_df['Normal'] = 1 - y_train

validation_df = X_val.copy()
validation_df['Fraud'] = y_val
# validation_df['Normal'] = 1 - y_val

test_df = X_test.copy()
test_df['Fraud'] = y_test
# test_df['Normal'] = 1 - y_test

### Model training (meta-learning / episodic training)

Episodic training simulates the few-shot learning scenario to train a prototypical network. Training data is organized into episodes that resemble few-shot tasks.

Set up

In [5]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)

In [6]:
## Each task contains N_WAY * (N_SHOT + N_QUERY) samples
N_WAY = 2
N_SHOT = 2
N_QUERY = 3

N_TASKS_PER_EPOCH = 10 
N_VALIDATION_TASKS = 100 

## Datasets need to be a FewShotDataset / torch Dataset with .get_labels
train_set = FSLDataset(train_df)
validation_set = FSLDataset(validation_df)

## Sampliers used to generate tasks
train_sampler = TaskSampler(dataset = train_set, n_way = N_WAY, n_shot = N_SHOT, 
                            n_query = N_QUERY, n_tasks = N_TASKS_PER_EPOCH)
validation_sampler = TaskSampler(dataset = validation_set, n_way = N_WAY, n_shot = N_SHOT,
                                 n_query = N_QUERY, n_tasks = N_VALIDATION_TASKS)

## Loader generates an iterable given a dataset and a sampler
train_loader = DataLoader(dataset = train_set, batch_sampler = train_sampler, pin_memory = True,
                          collate_fn = train_sampler.episodic_collate_fn)
validation_loader = DataLoader(dataset = validation_set, batch_sampler = validation_sampler, pin_memory = True,
                               collate_fn = validation_sampler.episodic_collate_fn)

Disclaimer: It is assumed that the label is the last column of the input dataframe... ...
Disclaimer: It is assumed that the label is the last column of the input dataframe... ...


Initializing optimizer, loss function, etc

In [7]:
## Loss fn
LOSS_FN = nn.CrossEntropyLoss()


## Scheduler
    ## Scales learning rate by gamma at the designated milestones
scheduler_milestones = [70, 85]
scheduler_gamma = 0.1


## Optimizer
backbone = DummyNetwork(in_dim = len(X_train.columns), hidden_dim = 256, out_dim = 4) ## TODO: What should the correct dimensions be?
model = PrototypicalNetworks(backbone, use_softmax = True)

LEARNING_RATE = 0.001
MOMENTUM = 0.9
DECAY = 5e-4
train_optimizer = optim.SGD(params = model.parameters(), lr = LEARNING_RATE, momentum = MOMENTUM, 
                            weight_decay = DECAY)
train_scheduler = MultiStepLR(optimizer = train_optimizer, milestones = scheduler_milestones,
                              gamma = scheduler_gamma)


## Writer
log_dir = 'fsl_logs'
tb_writer = SummaryWriter(log_dir = log_dir)


Train the model

In [8]:
N_EPOCHS = 100
log_update_frequency = 10

## Track best parameters (weights and biases) and performance of model
best_state = model.state_dict()
best_f1_score = 0.0
best_recall = 0.0

for epoch in range(N_EPOCHS):
    print(f'Epoch: {epoch}')
    
    average_epoch_loss = training_epoch(model, train_loader, train_optimizer, LOSS_FN)

    actuals, predictions, _, _, recall, _, f1_score = evaluate_model(model, validation_loader)
    
    if f1_score > best_f1_score:
        best_f1_score = f1_score
        # best_state = model.state_dict()
        # print("Ding ding ding! We found a new best model!")

    if recall > best_recall:
        best_recall = recall
        best_state = model.state_dict()
        print("Ding ding ding! We found a new best model!")

    tb_writer.add_scalar("Train/loss", average_epoch_loss, epoch)
    tb_writer.add_scalar('F1', f1_score, epoch)
    tb_writer.add_scalar('Recall', recall, epoch)

    ## Update the scheduler such that it knows when to adjust the learning rate
    train_scheduler.step()


## Retrieve the best model
missing_keys, unexpected_keys = model.load_state_dict(best_state)
print(f'Best f1-score after {N_EPOCHS} epochs of training: {best_f1_score}')
print(f'Best recall after {N_EPOCHS} epochs of training: {best_recall}')


Epoch: 0


100%|██████████| 10/10 [00:00<00:00, 76.62it/s, loss=0.759]
100%|██████████| 100/100 [00:01<00:00, 76.53it/s, f1=0.54, recall=0.537]


Ding ding ding! We found a new best model!
Epoch: 1


100%|██████████| 10/10 [00:00<00:00, 141.45it/s, loss=0.76]
100%|██████████| 100/100 [00:01<00:00, 78.68it/s, f1=0.485, recall=0.477]


Epoch: 2


100%|██████████| 10/10 [00:00<00:00, 112.57it/s, loss=0.718]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 82.70it/s, f1=0.486, recall=0.48]


Epoch: 3


100%|██████████| 10/10 [00:00<00:00, 125.03it/s, loss=0.701]
100%|██████████| 100/100 [00:01<00:00, 70.85it/s, f1=0.481, recall=0.477]


Epoch: 4


100%|██████████| 10/10 [00:00<00:00, 108.69it/s, loss=0.695]
100%|██████████| 100/100 [00:01<00:00, 67.75it/s, f1=0.488, recall=0.497]


Epoch: 5


100%|██████████| 10/10 [00:00<00:00, 115.21it/s, loss=0.714]
100%|██████████| 100/100 [00:01<00:00, 60.29it/s, f1=0.509, recall=0.5] 


Epoch: 6


100%|██████████| 10/10 [00:00<00:00, 80.03it/s, loss=0.697]
100%|██████████| 100/100 [00:01<00:00, 59.55it/s, f1=0.42, recall=0.387]


Epoch: 7


100%|██████████| 10/10 [00:00<00:00, 83.43it/s, loss=0.682]
100%|██████████| 100/100 [00:01<00:00, 63.41it/s, f1=0.477, recall=0.457]


Epoch: 8


100%|██████████| 10/10 [00:00<00:00, 140.42it/s, loss=0.704]
100%|██████████| 100/100 [00:01<00:00, 51.38it/s, f1=0.477, recall=0.453]


Epoch: 9


100%|██████████| 10/10 [00:00<00:00, 133.31it/s, loss=0.776]
100%|██████████| 100/100 [00:01<00:00, 54.83it/s, f1=0.493, recall=0.483]


Epoch: 10


100%|██████████| 10/10 [00:00<00:00, 122.08it/s, loss=0.731]
100%|██████████| 100/100 [00:01<00:00, 53.35it/s, f1=0.462, recall=0.44]


Epoch: 11


100%|██████████| 10/10 [00:00<00:00, 142.71it/s, loss=0.692]
100%|██████████| 100/100 [00:01<00:00, 71.22it/s, f1=0.526, recall=0.53]


Epoch: 12


100%|██████████| 10/10 [00:00<00:00, 139.14it/s, loss=0.706]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 74.84it/s, f1=0.499, recall=0.49]


Epoch: 13


100%|██████████| 10/10 [00:00<00:00, 117.42it/s, loss=0.703]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 74.24it/s, f1=0.494, recall=0.497]


Epoch: 14


100%|██████████| 10/10 [00:00<00:00, 85.83it/s, loss=0.688]
100%|██████████| 100/100 [00:01<00:00, 68.10it/s, f1=0.483, recall=0.473]


Epoch: 15


100%|██████████| 10/10 [00:00<00:00, 155.05it/s, loss=0.723]
100%|██████████| 100/100 [00:01<00:00, 67.23it/s, f1=0.487, recall=0.463]


Epoch: 16


100%|██████████| 10/10 [00:00<00:00, 131.12it/s, loss=0.697]
100%|██████████| 100/100 [00:01<00:00, 70.56it/s, f1=0.473, recall=0.46]


Epoch: 17


100%|██████████| 10/10 [00:00<00:00, 138.16it/s, loss=0.686]
100%|██████████| 100/100 [00:01<00:00, 70.54it/s, f1=0.481, recall=0.46]


Epoch: 18


100%|██████████| 10/10 [00:00<00:00, 119.34it/s, loss=0.72]
100%|██████████| 100/100 [00:01<00:00, 72.19it/s, f1=0.52, recall=0.517]


Epoch: 19


100%|██████████| 10/10 [00:00<00:00, 138.58it/s, loss=0.69]
100%|██████████| 100/100 [00:01<00:00, 64.37it/s, f1=0.497, recall=0.473]


Epoch: 20


100%|██████████| 10/10 [00:00<00:00, 121.57it/s, loss=0.731]
100%|██████████| 100/100 [00:01<00:00, 67.31it/s, f1=0.478, recall=0.457]


Epoch: 21


100%|██████████| 10/10 [00:00<00:00, 163.98it/s, loss=0.719]
100%|██████████| 100/100 [00:01<00:00, 68.47it/s, f1=0.521, recall=0.52]


Epoch: 22


100%|██████████| 10/10 [00:00<00:00, 87.60it/s, loss=0.739]
100%|██████████| 100/100 [00:01<00:00, 71.47it/s, f1=0.492, recall=0.48]


Epoch: 23


100%|██████████| 10/10 [00:00<00:00, 161.45it/s, loss=0.736]
100%|██████████| 100/100 [00:01<00:00, 57.17it/s, f1=0.503, recall=0.483]


Epoch: 24


100%|██████████| 10/10 [00:00<00:00, 130.50it/s, loss=0.659]
100%|██████████| 100/100 [00:01<00:00, 62.61it/s, f1=0.481, recall=0.457]


Epoch: 25


100%|██████████| 10/10 [00:00<00:00, 76.62it/s, loss=0.712]
100%|██████████| 100/100 [00:01<00:00, 67.43it/s, f1=0.51, recall=0.493]


Epoch: 26


100%|██████████| 10/10 [00:00<00:00, 138.17it/s, loss=0.671]
100%|██████████| 100/100 [00:01<00:00, 67.86it/s, f1=0.487, recall=0.493]


Epoch: 27


100%|██████████| 10/10 [00:00<00:00, 121.18it/s, loss=0.684]
100%|██████████| 100/100 [00:01<00:00, 62.14it/s, f1=0.493, recall=0.487]


Epoch: 28


100%|██████████| 10/10 [00:00<00:00, 82.77it/s, loss=0.727]
100%|██████████| 100/100 [00:01<00:00, 58.26it/s, f1=0.485, recall=0.48]


Epoch: 29


100%|██████████| 10/10 [00:00<00:00, 85.15it/s, loss=0.701]
100%|██████████| 100/100 [00:01<00:00, 62.89it/s, f1=0.492, recall=0.463]


Epoch: 30


100%|██████████| 10/10 [00:00<00:00, 106.33it/s, loss=0.661]
100%|██████████| 100/100 [00:01<00:00, 60.19it/s, f1=0.527, recall=0.537]


Epoch: 31


100%|██████████| 10/10 [00:00<00:00, 133.08it/s, loss=0.66]
100%|██████████| 100/100 [00:01<00:00, 52.08it/s, f1=0.404, recall=0.37]


Epoch: 32


100%|██████████| 10/10 [00:00<00:00, 154.41it/s, loss=0.71]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 54.44it/s, f1=0.486, recall=0.457]


Epoch: 33


100%|██████████| 10/10 [00:00<00:00, 126.63it/s, loss=0.68]
100%|██████████| 100/100 [00:01<00:00, 53.61it/s, f1=0.463, recall=0.43]


Epoch: 34


100%|██████████| 10/10 [00:00<00:00, 116.06it/s, loss=0.669]
100%|██████████| 100/100 [00:01<00:00, 66.54it/s, f1=0.481, recall=0.477]


Epoch: 35


100%|██████████| 10/10 [00:00<00:00, 125.20it/s, loss=0.685]
100%|██████████| 100/100 [00:01<00:00, 53.17it/s, f1=0.495, recall=0.47]


Epoch: 36


100%|██████████| 10/10 [00:00<00:00, 138.96it/s, loss=0.73]
100%|██████████| 100/100 [00:01<00:00, 63.11it/s, f1=0.527, recall=0.503]


Epoch: 37


100%|██████████| 10/10 [00:00<00:00, 134.70it/s, loss=0.686]
100%|██████████| 100/100 [00:01<00:00, 60.90it/s, f1=0.454, recall=0.43]


Epoch: 38


100%|██████████| 10/10 [00:00<00:00, 86.44it/s, loss=0.622]
100%|██████████| 100/100 [00:01<00:00, 67.69it/s, f1=0.489, recall=0.463]


Epoch: 39


100%|██████████| 10/10 [00:00<00:00, 123.71it/s, loss=0.679]
100%|██████████| 100/100 [00:01<00:00, 62.56it/s, f1=0.478, recall=0.463]


Epoch: 40


100%|██████████| 10/10 [00:00<00:00, 130.53it/s, loss=0.733]
100%|██████████| 100/100 [00:01<00:00, 62.07it/s, f1=0.47, recall=0.45] 


Epoch: 41


100%|██████████| 10/10 [00:00<00:00, 123.42it/s, loss=0.684]
100%|██████████| 100/100 [00:02<00:00, 46.56it/s, f1=0.521, recall=0.52]


Epoch: 42


100%|██████████| 10/10 [00:00<00:00, 113.58it/s, loss=0.659]
100%|██████████| 100/100 [00:02<00:00, 48.36it/s, f1=0.514, recall=0.51]


Epoch: 43


100%|██████████| 10/10 [00:00<00:00, 95.85it/s, loss=0.714]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 50.49it/s, f1=0.534, recall=0.557]


Ding ding ding! We found a new best model!
Epoch: 44


100%|██████████| 10/10 [00:00<00:00, 80.57it/s, loss=0.598]
100%|██████████| 100/100 [00:01<00:00, 56.27it/s, f1=0.472, recall=0.437]


Epoch: 45


100%|██████████| 10/10 [00:00<00:00, 107.95it/s, loss=0.647]
100%|██████████| 100/100 [00:01<00:00, 53.24it/s, f1=0.523, recall=0.53]


Epoch: 46


100%|██████████| 10/10 [00:00<00:00, 80.10it/s, loss=0.643]
100%|██████████| 100/100 [00:01<00:00, 55.43it/s, f1=0.519, recall=0.51]


Epoch: 47


100%|██████████| 10/10 [00:00<00:00, 80.68it/s, loss=0.693]
100%|██████████| 100/100 [00:01<00:00, 57.00it/s, f1=0.536, recall=0.527]


Epoch: 48


100%|██████████| 10/10 [00:00<00:00, 124.78it/s, loss=0.67]
100%|██████████| 100/100 [00:01<00:00, 54.42it/s, f1=0.523, recall=0.54]


Epoch: 49


100%|██████████| 10/10 [00:00<00:00, 108.33it/s, loss=0.655]
100%|██████████| 100/100 [00:02<00:00, 49.45it/s, f1=0.482, recall=0.47]


Epoch: 50


100%|██████████| 10/10 [00:00<00:00, 68.32it/s, loss=0.651]
100%|██████████| 100/100 [00:02<00:00, 43.79it/s, f1=0.567, recall=0.59]


Ding ding ding! We found a new best model!
Epoch: 51


100%|██████████| 10/10 [00:00<00:00, 91.97it/s, loss=0.655]
100%|██████████| 100/100 [00:01<00:00, 59.26it/s, f1=0.486, recall=0.46]


Epoch: 52


100%|██████████| 10/10 [00:00<00:00, 82.89it/s, loss=0.608]
100%|██████████| 100/100 [00:01<00:00, 65.11it/s, f1=0.497, recall=0.48]


Epoch: 53


100%|██████████| 10/10 [00:00<00:00, 101.81it/s, loss=0.635]
100%|██████████| 100/100 [00:01<00:00, 59.33it/s, f1=0.51, recall=0.517]


Epoch: 54


100%|██████████| 10/10 [00:00<00:00, 96.39it/s, loss=0.686]
100%|██████████| 100/100 [00:01<00:00, 69.69it/s, f1=0.497, recall=0.483]


Epoch: 55


100%|██████████| 10/10 [00:00<00:00, 111.23it/s, loss=0.725]
100%|██████████| 100/100 [00:01<00:00, 69.46it/s, f1=0.514, recall=0.5] 


Epoch: 56


100%|██████████| 10/10 [00:00<00:00, 84.97it/s, loss=0.628]
100%|██████████| 100/100 [00:01<00:00, 63.62it/s, f1=0.525, recall=0.523]


Epoch: 57


100%|██████████| 10/10 [00:00<00:00, 94.22it/s, loss=0.685]
100%|██████████| 100/100 [00:01<00:00, 66.38it/s, f1=0.497, recall=0.483]


Epoch: 58


100%|██████████| 10/10 [00:00<00:00, 74.79it/s, loss=0.651]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 69.03it/s, f1=0.515, recall=0.513]


Epoch: 59


100%|██████████| 10/10 [00:00<00:00, 92.64it/s, loss=0.674]
100%|██████████| 100/100 [00:01<00:00, 67.59it/s, f1=0.506, recall=0.503]


Epoch: 60


100%|██████████| 10/10 [00:00<00:00, 79.19it/s, loss=0.652]
100%|██████████| 100/100 [00:01<00:00, 67.48it/s, f1=0.543, recall=0.543]


Epoch: 61


100%|██████████| 10/10 [00:00<00:00, 114.67it/s, loss=0.707]
100%|██████████| 100/100 [00:01<00:00, 65.59it/s, f1=0.482, recall=0.45]


Epoch: 62


100%|██████████| 10/10 [00:00<00:00, 85.91it/s, loss=0.69]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 66.90it/s, f1=0.504, recall=0.51]


Epoch: 63


100%|██████████| 10/10 [00:00<00:00, 79.52it/s, loss=0.654]
100%|██████████| 100/100 [00:01<00:00, 60.76it/s, f1=0.48, recall=0.463]


Epoch: 64


100%|██████████| 10/10 [00:00<00:00, 100.21it/s, loss=0.675]
100%|██████████| 100/100 [00:01<00:00, 66.14it/s, f1=0.518, recall=0.517]


Epoch: 65


100%|██████████| 10/10 [00:00<00:00, 95.75it/s, loss=0.63]
100%|██████████| 100/100 [00:01<00:00, 58.52it/s, f1=0.515, recall=0.513]


Epoch: 66


100%|██████████| 10/10 [00:00<00:00, 77.35it/s, loss=0.719]
100%|██████████| 100/100 [00:01<00:00, 56.06it/s, f1=0.477, recall=0.457]


Epoch: 67


100%|██████████| 10/10 [00:00<00:00, 93.00it/s, loss=0.635]
100%|██████████| 100/100 [00:01<00:00, 65.89it/s, f1=0.517, recall=0.507]


Epoch: 68


100%|██████████| 10/10 [00:00<00:00, 116.94it/s, loss=0.709]
100%|██████████| 100/100 [00:01<00:00, 63.02it/s, f1=0.537, recall=0.553]


Epoch: 69


100%|██████████| 10/10 [00:00<00:00, 90.49it/s, loss=0.646]
100%|██████████| 100/100 [00:01<00:00, 62.14it/s, f1=0.492, recall=0.49]


Epoch: 70


100%|██████████| 10/10 [00:00<00:00, 76.03it/s, loss=0.645]
100%|██████████| 100/100 [00:01<00:00, 68.73it/s, f1=0.497, recall=0.48]


Epoch: 71


100%|██████████| 10/10 [00:00<00:00, 89.68it/s, loss=0.624]
100%|██████████| 100/100 [00:01<00:00, 66.42it/s, f1=0.472, recall=0.447]


Epoch: 72


100%|██████████| 10/10 [00:00<00:00, 94.13it/s, loss=0.762]
100%|██████████| 100/100 [00:01<00:00, 64.48it/s, f1=0.536, recall=0.523]


Epoch: 73


100%|██████████| 10/10 [00:00<00:00, 73.33it/s, loss=0.642]
100%|██████████| 100/100 [00:01<00:00, 68.23it/s, f1=0.463, recall=0.433]


Epoch: 74


100%|██████████| 10/10 [00:00<00:00, 119.15it/s, loss=0.694]
100%|██████████| 100/100 [00:01<00:00, 64.98it/s, f1=0.536, recall=0.527]


Epoch: 75


100%|██████████| 10/10 [00:00<00:00, 97.81it/s, loss=0.691]
100%|██████████| 100/100 [00:01<00:00, 66.15it/s, f1=0.527, recall=0.54]


Epoch: 76


100%|██████████| 10/10 [00:00<00:00, 88.53it/s, loss=0.687]
100%|██████████| 100/100 [00:01<00:00, 64.37it/s, f1=0.492, recall=0.477]


Epoch: 77


100%|██████████| 10/10 [00:00<00:00, 102.73it/s, loss=0.66]
100%|██████████| 100/100 [00:01<00:00, 65.12it/s, f1=0.539, recall=0.553]


Epoch: 78


100%|██████████| 10/10 [00:00<00:00, 82.30it/s, loss=0.651]
100%|██████████| 100/100 [00:01<00:00, 59.76it/s, f1=0.438, recall=0.413]


Epoch: 79


100%|██████████| 10/10 [00:00<00:00, 74.90it/s, loss=0.636]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 62.06it/s, f1=0.452, recall=0.423]


Epoch: 80


100%|██████████| 10/10 [00:00<00:00, 80.96it/s, loss=0.673]
100%|██████████| 100/100 [00:01<00:00, 63.42it/s, f1=0.51, recall=0.517]


Epoch: 81


100%|██████████| 10/10 [00:00<00:00, 92.10it/s, loss=0.625]
100%|██████████| 100/100 [00:01<00:00, 61.99it/s, f1=0.522, recall=0.503]


Epoch: 82


100%|██████████| 10/10 [00:00<00:00, 109.27it/s, loss=0.692]
100%|██████████| 100/100 [00:01<00:00, 58.16it/s, f1=0.515, recall=0.497]


Epoch: 83


100%|██████████| 10/10 [00:00<00:00, 100.48it/s, loss=0.628]
100%|██████████| 100/100 [00:01<00:00, 55.08it/s, f1=0.499, recall=0.48]


Epoch: 84


100%|██████████| 10/10 [00:00<00:00, 90.49it/s, loss=0.782]
100%|██████████| 100/100 [00:01<00:00, 60.43it/s, f1=0.527, recall=0.53]


Epoch: 85


100%|██████████| 10/10 [00:00<00:00, 103.62it/s, loss=0.668]
100%|██████████| 100/100 [00:01<00:00, 56.74it/s, f1=0.446, recall=0.403]


Epoch: 86


100%|██████████| 10/10 [00:00<00:00, 114.94it/s, loss=0.679]
100%|██████████| 100/100 [00:01<00:00, 58.32it/s, f1=0.505, recall=0.503]


Epoch: 87


100%|██████████| 10/10 [00:00<00:00, 125.00it/s, loss=0.681]
100%|██████████| 100/100 [00:01<00:00, 61.78it/s, f1=0.544, recall=0.54]


Epoch: 88


100%|██████████| 10/10 [00:00<00:00, 74.89it/s, loss=0.587]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 54.81it/s, f1=0.487, recall=0.487]


Epoch: 89


100%|██████████| 10/10 [00:00<00:00, 81.95it/s, loss=0.605]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 58.71it/s, f1=0.524, recall=0.517]


Epoch: 90


100%|██████████| 10/10 [00:00<00:00, 92.16it/s, loss=0.605]
100%|██████████| 100/100 [00:01<00:00, 56.97it/s, f1=0.558, recall=0.573]


Epoch: 91


100%|██████████| 10/10 [00:00<00:00, 94.77it/s, loss=0.601]
100%|██████████| 100/100 [00:01<00:00, 55.65it/s, f1=0.486, recall=0.483]


Epoch: 92


100%|██████████| 10/10 [00:00<00:00, 90.49it/s, loss=0.588]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 57.94it/s, f1=0.545, recall=0.57]


Epoch: 93


100%|██████████| 10/10 [00:00<00:00, 71.41it/s, loss=0.615]
100%|██████████| 100/100 [00:01<00:00, 59.90it/s, f1=0.509, recall=0.497]


Epoch: 94


100%|██████████| 10/10 [00:00<00:00, 79.35it/s, loss=0.586]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 100/100 [00:01<00:00, 55.04it/s, f1=0.485, recall=0.487]


Epoch: 95


100%|██████████| 10/10 [00:00<00:00, 106.95it/s, loss=0.639]
100%|██████████| 100/100 [00:01<00:00, 58.51it/s, f1=0.494, recall=0.447]


Epoch: 96


100%|██████████| 10/10 [00:00<00:00, 82.28it/s, loss=0.691]
100%|██████████| 100/100 [00:01<00:00, 58.91it/s, f1=0.524, recall=0.477]


Epoch: 97


100%|██████████| 10/10 [00:00<00:00, 96.61it/s, loss=0.622]
100%|██████████| 100/100 [00:01<00:00, 55.37it/s, f1=0.513, recall=0.513]


Epoch: 98


100%|██████████| 10/10 [00:00<00:00, 103.61it/s, loss=0.66]
100%|██████████| 100/100 [00:01<00:00, 58.48it/s, f1=0.523, recall=0.5] 


Epoch: 99


100%|██████████| 10/10 [00:00<00:00, 77.81it/s, loss=0.675]
100%|██████████| 100/100 [00:01<00:00, 61.82it/s, f1=0.509, recall=0.523]

Best f1-score after 100 epochs of training: 0.5673076923076923
Best recall after 100 epochs of training: 0.59





### Model evaluation

In [9]:
# evaluate(model, test_loader) ##TODO: Implement method