In [None]:
# Task identifier; also used as the sub-folder name under the dataset and checkpoint directories.
DOWNSTREAM_TASK: str = 'ner'

# START

In [None]:
import warnings; warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

import logging
import os
import random
import numpy as np
import pandas as pd
import nltk
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import BertTokenizer, BertForTokenClassification
from pytorch_pretrained_bert import BertAdam

from utils.utils import get_available_models, get_available_datasets, prune_examples, ENV_VARIABLE
from utils.bert_dataset import BertDataset
from utils.input_example import InputExample
from utils.input_example_to_tensors import InputExampleToTensors
from utils.ner_processor import NerProcessor
from utils.ner_trainer import NERTrainer

In [None]:
# Configure the root logger to emit WARNING and above only.
logger = logging.getLogger(name=None)
logger.setLevel(level=logging.WARNING)

In [None]:
DIR_PRETRAINED_MODELS = ENV_VARIABLE['DIR_PRETRAINED_MODELS']
# Dataset and checkpoint directories are task-specific: <base dir>/<DOWNSTREAM_TASK>.
DIR_DATASETS, DIR_CHECKPOINTS = (
    os.path.join(ENV_VARIABLE[env_key], DOWNSTREAM_TASK)
    for env_key in ('DIR_DATASETS', 'DIR_CHECKPOINTS')
)

### 0. Available Datasets

In [None]:
# Datasets available for this task; the `dataset` setting below is checked against this.
available_datasets = get_available_datasets(DOWNSTREAM_TASK)

available_datasets

### 1. Settings

In [None]:
# Pretrained checkpoint to fine-tune. Alternative tried: 'bert-base-multilingual-uncased'.
pretrained_model_name: str = 'af-ai-center/bert-base-swedish-uncased'

In [None]:
# Dataset to train on (alternative: 'SUC'); must be one of `available_datasets`.
dataset = 'swedish_ner_corpus'
assert dataset in available_datasets

In [None]:
batch_size = 16   # used for both the training and the validation DataLoader
num_epochs = 2
prune_ratio = 0.1  # passed to prune_examples as `ratio` — presumably the fraction of each split kept; TODO confirm

# Fix all random seeds so a Restart & Run All reproduces the same run. Covered:
# batch shuffling (RandomSampler), initialisation of the new classification head,
# dropout, and example pruning if prune_examples is random.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA

### 2. Tokenizer

In [None]:
# Lower-casing is NOT done by the tokenizer; NerProcessor is created with
# do_lower_case=True and applies .lower() to the text instead.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name,
    do_lower_case=False,  # needs to be False !!
)

### 3. Processor (Data)

In [None]:
# Each dataset lives in its own sub-directory of DIR_DATASETS. Building the path from
# `dataset` covers every entry of `available_datasets`, not just two hard-coded ones;
# the old if/elif left `dataset_path` undefined (NameError) for any other dataset.
# The trailing '/' is kept as before — NerProcessor may rely on it (TODO confirm).
dataset_path = f'{DIR_DATASETS}/{dataset}/'
assert os.path.isdir(dataset_path), f'dataset directory not found: {dataset_path}'

dataset_path

In [None]:
# The processor does the lower-casing itself (applies .lower()), complementing the
# tokenizer, which is created with do_lower_case=False.
processor = NerProcessor(
    dataset_path,
    tokenizer,
    do_lower_case=True,  # needs to be True (applies .lower()) !!
)

processor

In [None]:
# Label set of the dataset; its length sets `num_labels` of the classification head.
label_list = processor.get_label_list()

label_list

#### Prune Examples (temporary: train and validate on a `prune_ratio` subset of each split to speed up experiments)

#### Train Data

In [None]:
# Full training split (pruned in the next cell).
train_input_examples_all = processor.get_input_examples('train')

In [None]:
# Reduce the training split according to prune_ratio (see Settings).
train_input_examples = prune_examples(
    train_input_examples_all,
    ratio=prune_ratio,
)

In [None]:
# Sanity check: id, raw text, tokenizer output and labels of one training example.
train_example = train_input_examples[8]
print(train_example.guid)
print(train_example.text_a)
print(tokenizer.tokenize(train_example.text_a))
print(train_example.labels_a)

#### Validation Data

In [None]:
# NOTE(review): the 'test' split serves as validation data in this notebook.
valid_input_examples_all = processor.get_input_examples('test')

In [None]:
# Reduce the validation split with the same prune_ratio as the training split.
valid_input_examples = prune_examples(
    valid_input_examples_all,
    ratio=prune_ratio,
)

In [None]:
# Sanity check: size of the validation set and one example with its tokens and labels.
print(len(valid_input_examples))
valid_example = valid_input_examples[1]
print(valid_example.text_a)
print(tokenizer.tokenize(valid_example.text_a))
print(valid_example.labels_a)

#### Dataloader

In [None]:
# Turns InputExamples into tensors; sequences presumably truncated/padded to 64 tokens (TODO confirm).
samples_transformer = InputExampleToTensors(
    tokenizer,
    max_seq_length=64,
    label_tuple=tuple(label_list),
)

In [None]:
# Training batches are drawn in random order.
train_data = BertDataset(train_input_examples, transform=samples_transformer)
train_dataloader = DataLoader(
    train_data,
    sampler=RandomSampler(train_data),
    batch_size=batch_size,
)

In [None]:
# Validation batches are read in a fixed, sequential order.
valid_data = BertDataset(valid_input_examples, transform=samples_transformer)
valid_dataloader = DataLoader(
    valid_data,
    sampler=SequentialSampler(valid_data),
    batch_size=batch_size,
)

### 4. Model

In [None]:
# Pretrained BERT with a token-classification head sized to the dataset's label set.
model = BertForTokenClassification.from_pretrained(
    pretrained_model_name,
    num_labels=len(label_list),
)

model

### 5. Train

In [None]:
# GPU when available, otherwise CPU.
# NOTE(review): `device` is not passed to NERTrainer in this notebook.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

device

In [None]:
# The third argument is the validation loader: the held-out `valid_dataloader` must go here.
# Previously `train_dataloader` was passed twice, so the "valid" metrics plotted below
# were computed on training data and `valid_dataloader` was never used.
trainer = NERTrainer(model,
                     train_dataloader,
                     valid_dataloader,
                     label_list,
                     fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
                    )

trainer

In [None]:
# Fine-tune for `num_epochs` epochs at learning rate 2e-5, without verbose output.
trainer.fit(
    learning_rate=2e-5,
    num_epochs=num_epochs,
    verbose=False,
)

### 6. Results

In [None]:
def plot_learning_rate(metrics, ax=None):
    """Plot the learning rate recorded for each training batch.

    Args:
        metrics: trainer metrics dict; the values are read from
            ``metrics['batch']['train']['lr']``.
        ax: matplotlib Axes to draw on; a new figure is created when None
            (same convention as ``plot_metric``).

    Returns:
        The Axes that was drawn on.
    """
    lr = metrics['batch']['train']['lr']
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(lr, linestyle='', marker='.')
    ax.set_xlabel('batch')
    ax.set_ylabel('learning rate')
    ax.set_title('learning rate')  # a figure should stand alone when the notebook is skimmed
    return ax
    
plot_learning_rate(metrics=trainer.metrics);

In [None]:
def plot_metric(metrics, num_epochs, metric, f1_spec=None, ax=None):
    """Plot a metric per training batch together with its per-epoch validation value.

    Args:
        metrics: trainer metrics dict. Without ``f1_spec`` the series are read from
            ``metrics['batch']['train'][metric]`` and ``metrics['epoch']['valid'][metric]``;
            with ``f1_spec`` both are indexed one level deeper by
            ``[f1_spec[0]][f1_spec[1]]``.
        num_epochs: number of training epochs; one validation value per epoch is expected.
        metric: 'loss', 'acc' or 'f1'.
        f1_spec: pair such as ('macro', 'all') selecting the f1 variant; required when
            ``metric == 'f1'``. Its first element ('macro' or 'micro') selects the colour.
        ax: matplotlib Axes to draw on; a new figure is created when None.

    Returns:
        The Axes that was drawn on.

    Raises:
        ValueError: if ``metric == 'f1'`` and ``f1_spec`` is None.
    """
    if metric == 'f1' and f1_spec is None:
        raise ValueError("metric 'f1' requires f1_spec, e.g. ('macro', 'all')")

    ### PREP ###
    batch_train = metrics['batch']['train'][metric]
    epoch_valid = metrics['epoch']['valid'][metric]
    if f1_spec is None:
        metric_spec = metric
    else:
        f1_avg, f1_subset = f1_spec[0], f1_spec[1]
        batch_train = batch_train[f1_avg][f1_subset]
        epoch_valid = epoch_valid[f1_avg][f1_subset]
        metric_spec = f'{metric}_{f1_avg}'

    clr = {'loss': 'r',
           'acc': 'green',
           'f1_macro': 'orange',
           'f1_micro': 'blue',
          }
    color = clr[metric_spec]

    ### PLOT ###
    if ax is None:
        _, ax = plt.subplots()

    ax.plot(batch_train,
            linestyle='-', marker='.', color=color, alpha=0.3, label='train')

    # One validation value per epoch: place each at the batch index where that epoch
    # ends, assuming every epoch has the same number of batches.
    x = [len(batch_train) * float(i) / num_epochs for i in range(1, num_epochs + 1)]
    ax.plot(x, epoch_valid,
            linestyle='', marker='o', color=color, label='valid')

    ax.set_xlabel('batch')
    ax.set_ylabel(metric)
    # Loss has no fixed upper bound; accuracy and f1 scores are fractions in [0, 1].
    ax.set_ylim([0, None] if metric == 'loss' else [0, 1])
    if metric in ('loss', 'acc'):
        ax.set_title(metric)
    elif metric == 'f1':
        ax.set_title(f'f1 score: {f1_avg}, {f1_subset}')
    ax.legend()
    return ax

In [None]:
# Loss (left) and accuracy (right): per-batch training curve plus per-epoch validation points.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
plot_metric(trainer.metrics, num_epochs, metric='loss', ax=ax[0])
plot_metric(trainer.metrics, num_epochs, metric='acc', ax=ax[1]);

In [None]:
# Macro-averaged f1 for the 'all' and 'fil' variants (exact meaning defined by NERTrainer — TODO document).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
plot_metric(trainer.metrics, num_epochs, 'f1', f1_spec=('macro', 'all'), ax=ax[0])
plot_metric(trainer.metrics, num_epochs, 'f1', f1_spec=('macro', 'fil'), ax=ax[1]);

In [None]:
# Micro-averaged f1 for the 'all' and 'fil' variants (exact meaning defined by NERTrainer — TODO document).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
plot_metric(trainer.metrics, num_epochs, 'f1', f1_spec=('micro', 'all'), ax=ax[0])
plot_metric(trainer.metrics, num_epochs, 'f1', f1_spec=('micro', 'fil'), ax=ax[1]);

#### Different learning rates (TODO: not implemented yet)

#### Save Model Checkpoint

In [None]:
# Save the fine-tuned weights (state_dict only).
# - `model_name` was never defined; the model identifier is `pretrained_model_name`.
# - That identifier contains '/' (e.g. 'af-ai-center/...'), which would be read as a
#   non-existent sub-directory, so it is flattened for the file name.
# - os.path.join instead of f'./{DIR_CHECKPOINTS}/...' keeps an absolute DIR_CHECKPOINTS absolute.
os.makedirs(DIR_CHECKPOINTS, exist_ok=True)
checkpoint_model_name = pretrained_model_name.replace('/', '_')
checkpoint_path = os.path.join(DIR_CHECKPOINTS,
                               f'saved__{dataset}__{checkpoint_model_name}.pkl')
torch.save(model.state_dict(), checkpoint_path)

checkpoint_path