In [None]:
# Task sub-folder name; scopes the dataset and checkpoint directories below.
DOWNSTREAM_TASK = 'ner'

In [None]:
import os
import pickle
import matplotlib.pyplot as plt

from utils.utils import ENV_VARIABLE

In [None]:
# Root folders come from the project-level ENV_VARIABLE mapping. The dataset
# and checkpoint folders are narrowed to the DOWNSTREAM_TASK sub-folder; the
# pretrained-model folder is used as-is.
DIR_PRETRAINED_MODELS = ENV_VARIABLE['DIR_PRETRAINED_MODELS']
DIR_DATASETS = os.path.join(
    ENV_VARIABLE['DIR_DATASETS'], DOWNSTREAM_TASK
)
DIR_CHECKPOINTS = os.path.join(
    ENV_VARIABLE['DIR_CHECKPOINTS'], DOWNSTREAM_TASK
)

# SETTINGS

In [None]:
# Identifier of the pretrained checkpoint whose run metrics are inspected.
# Alternative choice (uncomment to use instead):
# pretrained_model_name = 'af-ai-center/bert-base-swedish-uncased'
pretrained_model_name = 'bert-base-multilingual-uncased'

In [None]:
# Keep only the final '/'-separated segment (drops any 'org/' prefix) so the
# name can be embedded in the metrics file name.
model_name = pretrained_model_name.rsplit('/', 1)[-1]

In [None]:
# dataset = 'SUC'
dataset = 'swedish_ner_corpus'

In [None]:
# Run settings: both select which metrics file is loaded, and num_epochs also
# positions the per-epoch validation points in the plots.
num_epochs, prune_ratio = 2, 0.1

# LOAD METRICS

In [None]:
# Metrics pickle for the run identified by the settings above. os.path.join is
# used instead of a './'-prefixed f-string: with an absolute DIR_CHECKPOINTS,
# './' + '/abs/dir' becomes the *relative* path './abs/dir' and the file is
# looked up under the current working directory instead.
pkl_path = os.path.join(
    DIR_CHECKPOINTS,
    f'metrics__{dataset}__{model_name}__{num_epochs}__{prune_ratio}.pkl',
)
# NOTE: pickle.load can execute arbitrary code; only load metrics files that
# come from your own (trusted) training runs.
with open(pkl_path, 'rb') as f:
    metrics = pickle.load(f)

# INSPECT & PLOT METRICS

In [None]:
def display(_metrics):
    """Print the raw batch- and epoch-level metrics, grouped by split.

    Prints ``_metrics['batch']['train']``, then ``_metrics['batch']['valid']``
    and ``_metrics['epoch']['valid']``, each preceded by a small header.

    NOTE(review): once this cell runs, this name shadows IPython's built-in
    ``display`` in the notebook namespace.
    """
    # (split header, [(sub-header, granularity key, split key), ...])
    layout = (
        ('--- train ---', [('> batch', 'batch', 'train')]),
        ('--- valid ---', [('> batch', 'batch', 'valid'),
                           ('> epoch', 'epoch', 'valid')]),
    )
    for split_header, entries in layout:
        print(split_header)
        for sub_header, granularity, split in entries:
            print(sub_header)
            # Looked up only when printed, so output order is header-by-header.
            print(_metrics[granularity][split])

# display(metrics)

In [None]:
def plot_learning_rate(metrics, ax=None):
    """Scatter the per-batch learning rate recorded for the training split.

    Args:
        metrics: nested metrics dict; reads ``metrics['batch']['train']['lr']``
            (presumably one value per training batch -- confirm against the
            training loop).
        ax: optional matplotlib Axes to draw on. When None, a new figure and
            axes are created (the original behaviour).
    """
    lr = metrics['batch']['train']['lr']
    if ax is None:
        _, ax = plt.subplots()
    # Markers only: individual scheduler steps are easier to see without lines.
    ax.plot(lr, linestyle='', marker='.')
    ax.set_xlabel('batch')
    ax.set_ylabel('learning rate')
    ax.set_title('learning rate')
    
# Learning-rate schedule of the loaded run.
plot_learning_rate(metrics=metrics);

In [None]:
def plot_metric(metrics, num_epochs, metric, f1_spec=None, ax=None):
    """Plot per-batch training values and per-epoch validation values of a metric.

    Args:
        metrics: nested dict read as ``metrics['batch']['train'][metric]`` and
            ``metrics['epoch']['valid'][metric]``. When ``f1_spec`` is given,
            each of those is further indexed by ``f1_spec[0]`` then
            ``f1_spec[1]``.
        num_epochs: number of training epochs; one validation value per epoch
            is expected.
        metric: metric key, e.g. ``'loss'``, ``'acc'`` or ``'f1'``.
        f1_spec: pair of nested keys such as ``('macro', 'all')``; required
            for ``'f1'``. The first key also selects the colour
            (``'f1_macro'`` / ``'f1_micro'``).
        ax: optional matplotlib Axes; a new figure is created when None.

    Raises:
        ValueError: if ``metric == 'f1'`` and ``f1_spec`` is None.
    """
    ### PREP ###
    if metric == 'f1' and f1_spec is None:
        raise ValueError("metric 'f1' requires f1_spec, e.g. ('macro', 'all')")

    batch_train = metrics['batch']['train'][metric]
    epoch_valid = metrics['epoch']['valid'][metric]
    if f1_spec is None:
        metric_spec = metric
    else:
        batch_train = batch_train[f1_spec[0]][f1_spec[1]]
        epoch_valid = epoch_valid[f1_spec[0]][f1_spec[1]]
        metric_spec = f'{metric}_{f1_spec[0]}'

    clr = {'loss': 'r',
           'acc': 'green',
           'f1_macro': 'orange',
           'f1_micro': 'blue',
           }
    # Unknown metrics fall back to one neutral colour (instead of a KeyError),
    # so train and valid points still share a colour.
    color = clr.get(metric_spec, 'gray')

    ### PLOT ###
    if ax is None:
        _, ax = plt.subplots()

    ax.plot(batch_train,
            linestyle='-', marker='.', color=color, alpha=0.3, label='train')

    # Validation point for epoch i is placed after len(batch_train) * i /
    # num_epochs batches, i.e. assuming every epoch has the same batch count.
    x = [len(batch_train) * i / num_epochs for i in range(1, num_epochs + 1)]
    ax.plot(x, epoch_valid,
            linestyle='', marker='o', color=color, label='valid')

    ax.set_xlabel('batch')
    ax.set_ylabel(metric)
    # Loss is only bounded below; every other metric is treated as a score
    # in [0, 1].
    ax.set_ylim([0, None] if metric == 'loss' else [0, 1])
    if metric == 'f1':
        ax.set_title(f'f1 score: {f1_spec[0]}, {f1_spec[1]}')
    else:
        ax.set_title(metric)
    ax.legend()

In [None]:
# Side-by-side panels: loss on the left, accuracy on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
for panel, metric_name in zip(ax, ('loss', 'acc')):
    plot_metric(metrics, num_epochs, metric_name, ax=panel)

In [None]:
# Macro-averaged F1: 'all' on the left, 'fil' on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
for panel, subset in zip(ax, ('all', 'fil')):
    plot_metric(metrics, num_epochs, 'f1', ('macro', subset), ax=panel)

In [None]:
# Micro-averaged F1: 'all' on the left, 'fil' on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
for panel, subset in zip(ax, ('all', 'fil')):
    plot_metric(metrics, num_epochs, 'f1', ('micro', subset), ax=panel)