In [None]:
# Downstream task identifier. It is passed to get_available_datasets() and is
# used as the sub-directory name under the dataset and checkpoint roots below.
DOWNSTREAM_TASK = 'ner'

# START

In [None]:
import warnings; warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

import logging

import os
import pickle
import numpy as np
import pandas as pd
import nltk
import torch


from transformers import BertTokenizer, BertForTokenClassification
from pytorch_pretrained_bert import BertAdam

from utils.utils import get_available_models
from utils.utils import get_available_datasets
from utils.utils import prune_examples
from utils.utils import ENV_VARIABLE
from utils.utils import preprocess_data
from utils.utils import get_dataset_path
from utils.utils import save_model_checkpoint
from utils.utils import save_metrics
from utils.ner_trainer import NERTrainer

In [None]:
# Calling getLogger() without a name returns the root logger: raise its
# threshold to WARNING first, then keep a handle on that same logger object.
logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger()

In [None]:
# Root directories come from the project's ENV_VARIABLE mapping.
DIR_PRETRAINED_MODELS = ENV_VARIABLE['DIR_PRETRAINED_MODELS']

# Datasets and checkpoints each live in a per-task sub-directory of their root.
DIR_DATASETS, DIR_CHECKPOINTS = (
    os.path.join(ENV_VARIABLE[root_key], DOWNSTREAM_TASK)
    for root_key in ('DIR_DATASETS', 'DIR_CHECKPOINTS')
)

### 0. Model & Dataset

#### Model

In [None]:
# Identifier of the pretrained BERT weights. The same name is handed to both
# BertTokenizer.from_pretrained and BertForTokenClassification.from_pretrained
# further down, and it is also passed to the checkpoint/metrics save helpers.
pretrained_model_name = 'af-ai-center/bert-base-swedish-uncased'
# Alternative model (disabled):
# pretrained_model_name = 'bert-base-multilingual-uncased'

#### Dataset

In [None]:
# Ask the project helper which datasets exist for the selected downstream task.
# The bare expression on the last line shows the result as the cell output.
available_datasets = get_available_datasets(DOWNSTREAM_TASK)
available_datasets

In [None]:
# Dataset to fine-tune on. Switch by (un)commenting one of the lines below.
# dataset = 'SUC'
dataset = 'swedish_ner_corpus'

# Fail early with an actionable message: a bare assert would only raise an
# empty AssertionError and leave the reader guessing which names are valid.
assert dataset in available_datasets, (
    f"dataset {dataset!r} is not available for task {DOWNSTREAM_TASK!r}; "
    f"choose one of: {available_datasets}"
)

In [None]:
# Resolve the location of the chosen dataset below the per-task dataset root.
# The bare expression on the last line shows the resolved path as cell output.
dataset_path = get_dataset_path(DIR_DATASETS, dataset)
dataset_path

### 1. Parameters

In [None]:
# Reproducibility: seed the random number generators before any data
# preprocessing or model initialisation happens, so that a
# "Restart Kernel -> Run All" repeats the same run.
# NOTE(review): Python's built-in `random` module is not imported in this
# notebook and is therefore not seeded here; seed it as well if the project
# utilities rely on it.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)  # also seeds the CUDA generators when a GPU is present

# Data parameters: forwarded to preprocess_data().
batch_size = 16
max_seq_length = 64
prune_ratio = 0.02  # presumably the fraction of examples kept - confirm in utils

# Training parameters: forwarded to trainer.fit().
num_epochs = 3
learning_rate = {   # unpacked into trainer.fit() as keyword arguments
    'lr_max': 2e-5,
    'lr_schedule': 'linear_with_warmup',
    'lr_warmup_fraction': 0.1,
}

### 2. Tokenizer

In [None]:
# Load the tokenizer that belongs to the selected pretrained model.
# do_lower_case=False is deliberate (see the original author's note on the line).
# NOTE(review): presumably because lower-casing in the BERT tokenizer also
# strips accents, which would fold the Swedish letters å/ä/ö - confirm against
# the model card before changing it.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name, do_lower_case=False)  # needs to be False !!

### 3. Processor (Data)

In [None]:
# Run the project's preprocessing on the dataset found at `dataset_path`.
# `dataloader` is indexed by 'train' and 'valid' when the trainer is built;
# the length of `label_list` determines the size of the classification head.
dataloader, label_list = preprocess_data(
    dataset_path,
    tokenizer,
    batch_size,
    max_seq_length=max_seq_length,
    prune_ratio=prune_ratio,
)

### 4. Model

In [None]:
# Pretrained BERT with a token-classification head; the head gets one output
# per entry of `label_list`.
model = BertForTokenClassification.from_pretrained(
    pretrained_model_name,
    num_labels=len(label_list),
)

### 5. Train

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is only displayed here; no later cell in this notebook
# passes it on - presumably NERTrainer selects the device itself, confirm.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Wire the model, the train/validation loaders and the label set into the
# project's trainer. The fp16 flag is switched on only when CUDA is available.
trainer = NERTrainer(
    model,
    train_dataloader=dataloader['train'],
    valid_dataloader=dataloader['valid'],
    label_list=label_list,
    fp16=torch.cuda.is_available(),
)

# trainer

In [None]:
# Train for `num_epochs` epochs; the learning-rate settings defined in the
# parameters cell are unpacked as keyword arguments.
trainer.fit(num_epochs=num_epochs, verbose=False, **learning_rate)

### 6. Save

In [None]:
# Persist the fine-tuned model via the project helper; the dataset, model name,
# epoch count and prune ratio are passed along with the checkpoint directory.
save_model_checkpoint(
    model,
    DIR_CHECKPOINTS,
    dataset,
    pretrained_model_name,
    num_epochs,
    prune_ratio,
)

In [None]:
# Persist the trainer's metrics via the project helper, using the same
# run-identifying arguments as the model checkpoint.
save_metrics(
    trainer,
    DIR_CHECKPOINTS,
    dataset,
    pretrained_model_name,
    num_epochs,
    prune_ratio,
)