In [None]:
DOWNSTREAM_TASK = "ner"  # task name: selects the dataset/checkpoint sub-directories and the dataset listing below

# START

In [None]:
import warnings; warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

import logging
import os

import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import *
from pytorch_pretrained_bert import BertAdam

from utils import *
from utils.datasets import BertDataset as BDS

In [None]:
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

In [None]:
DIR_PRETRAINED_MODELS = ENV_VARIABLE['DIR_PRETRAINED_MODELS']
DIR_DATASETS = os.path.join(ENV_VARIABLE['DIR_DATASETS'], DOWNSTREAM_TASK)
DIR_CHECKPOINTS = os.path.join(ENV_VARIABLE['DIR_CHECKPOINTS'], DOWNSTREAM_TASK)

### 0. Available Models & Datasets

In [None]:
available_models = get_available_models()
available_models

In [None]:
available_datasets = get_available_datasets(DOWNSTREAM_TASK)
available_datasets

### 1. Settings

In [None]:
model_name = 'swe-uncased_L-12_H-768_A-12'
#model_name = 'swe-uncased_L-24_H-1024_A-16'
#model_name = 'bert-base-uncased'
#model_name = 'bert-base-multilingual-uncased'

In [None]:
#dataset = 'SUC'
dataset = 'swedish_ner_corpus'

In [None]:
assert model_name in available_models
assert dataset in available_datasets

### 2. Example

In [None]:
example_sentence = \
    'iran har hittills inte reagerat på någondera av de stora påkarna som saudier och irakier hött med :' + \
    ' landsbergis ansåg att gorbatjovs lördagsappell visade att denne ignorerar vädjanden från väst om att' + \
    ' börja tala med regeringen i vilnius .'

In [None]:
if model_name.startswith('swe'):
    pretrained_model_name = f'{DIR_PRETRAINED_MODELS}/{model_name}'
else:
    pretrained_model_name = model_name

pretrained_model_name

In [None]:
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name, do_lower_case=True)
tokenized_text = tokenizer.tokenize(example_sentence)

In [None]:
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
len(indexed_tokens), indexed_tokens[:3]

### 3. Processor

In [None]:
if dataset == 'SUC':
    dataset_path = f'{DIR_DATASETS}/SUC/moreTags/'
    processor = SUCProcessor(dataset_path, tokenizer, do_lower_case=True)
elif dataset == 'swedish_ner_corpus':
    dataset_path = f'{DIR_DATASETS}/swedish_ner_corpus/'
    processor = SwedishNERCorpusProcessor(dataset_path, tokenizer, do_lower_case=True)

dataset_path, processor

In [None]:
label_list = processor.get_label_list()
label_list

#### Prune Examples (temporary: keep only a fraction of the data for faster iteration)

In [None]:
prune_ratio = 0.01

#### Train Data

In [None]:
train_examples_all = processor.get_train_examples()

In [None]:
train_examples = prune_examples(train_examples_all, ratio=prune_ratio)

In [None]:
print(train_examples[8].guid)
print(train_examples[8].text_a)
print(train_examples[8].text_b)
print(train_examples[8].label)

#### Validation Data

In [None]:
val_examples_all = processor.get_test_examples()

In [None]:
val_examples = prune_examples(val_examples_all, ratio=prune_ratio)

In [None]:
print(len(val_examples))
print(val_examples[1].text_a)
print(val_examples[1].label)

#### Dataloader

In [None]:
batch_size = 16

In [None]:
#B-LOC O B-TME O O O O O O O O O O O O O O O B-PRS O O B-PRS O O O O O O O O O O O O O O O B-LOC O
samples_transformer = InputExampleToTensors(tokenizer, max_seq_length=64, label_list=label_list)

In [None]:
train_data = BDS(train_examples, transform=samples_transformer)
train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=batch_size)

In [None]:
validation_data = BDS(val_examples, transform=samples_transformer)
valid_dataloader = DataLoader(validation_data, sampler=SequentialSampler(validation_data), batch_size=batch_size)

### 4. Train

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
pretrained_model_name

In [None]:
model = BertForTokenClassification.from_pretrained(pretrained_model_name, num_labels=len(label_list))
model

In [None]:
trainer = NERTrainer(model, train_dataloader, valid_dataloader, label_list, fp16=True if device == "cuda" else False)
trainer

In [None]:
trainer.fit(learning_rate=2e-5, num_epochs=1)  # 4

#### Different learning rates (TODO: section is empty, not yet explored)

#### Save Model Checkpoint

In [None]:
# Build the path with os.path.join: prefixing './' to an absolute DIR_CHECKPOINTS would turn
# it into a cwd-relative path. Create the directory first so torch.save does not fail on a
# fresh setup; display the final path for reference.
os.makedirs(DIR_CHECKPOINTS, exist_ok=True)
checkpoint_path = os.path.join(DIR_CHECKPOINTS, f'saved__{dataset}__{model_name}.pkl')
torch.save(model.state_dict(), checkpoint_path)
checkpoint_path

### 5. Investigate

In [None]:
import numpy as np

In [None]:
sent_tokenizer = tokenizer
sentence_lower = 'En som arbetar mycket hårt är erik som är politisk aktiv inom anderst i sverige .'
sentence_lower = sentence_lower.lower()
print(sentence_lower)

In [None]:
example = InputExample("", sentence_lower, label='O')
example

In [None]:
to_tensors = InputExampleToTensors(tokenizer, max_seq_length=128, label_list=label_list)
input_ids, input_mask, segment_ids, label_id = to_tensors(example)
input_ids, input_mask, segment_ids, label_id

In [None]:
tokens_tensor = input_ids.view(1,-1)
segments_tensors = segment_ids.view(1,-1)

In [None]:
if device == 'cuda':
    model.to('cuda')
    tokens_tensor.to('cuda')
    segments_tensors.to('cuda')

In [None]:
model.eval()

In [None]:
logits = model(tokens_tensor, segments_tensors)
logits

In [None]:
res = []
res.extend(logits[0].argmax(-1))
res

In [None]:
np_logits = logits[0].detach().cpu().numpy()
np.argmax(np_logits, axis=2)

In [None]:
lst = np.argmax(np_logits, axis=2)[0].tolist()
lst = lst[1:]
lst

In [None]:
splitinput = tokenizer.tokenize(sentence_lower)
splitinput

In [None]:
for num, word in zip(lst, splitinput):
    if num == 4:
        print("PERSON: " + word)
    elif num == 5:
        print("ORG: " + word)
    elif num == 6:
        print("LOCATION: " + word)
    elif num == 7:
        print("WORK: " + word)
    elif num == 8:
        print("PRODUCT: " + word)
    else:
        print(num, word)
#    if num == 4 or num == 5:
#        print('{} {}'.format(num, word))

In [None]:
print("PER: " + str(trainer.total_per_correct / trainer.total_per))
print("LOC: " + str(trainer.total_loc_correct / trainer.total_loc))
print("ORG: " + str(trainer.total_org_correct / trainer.total_org))

In [None]:
trainer.labelDict

In [None]:
trainer.val_f1_score_hist

In [None]:
def cluster(my_list, n):
    """Return the mean of each consecutive, non-overlapping chunk of ``my_list``.

    The list is split into chunks of ``n`` items (the last chunk may be shorter) and the
    arithmetic mean of every chunk is returned in order.

    Args:
        my_list: sequence of numbers supporting ``len``, slicing and ``sum``.
        n: chunk size; must be a positive integer.

    Returns:
        A list with one mean per chunk; empty when ``my_list`` is empty.

    Raises:
        ValueError: if ``n`` is not positive (previously n=0 raised ZeroDivisionError and a
            negative n silently returned an empty list).
    """
    if n <= 0:
        raise ValueError(f"chunk size n must be positive, got {n}")
    chunks = [my_list[start:start + n] for start in range(0, len(my_list), n)]
    # Every chunk is non-empty by construction, so len(chunk) > 0.
    return [sum(chunk) / len(chunk) for chunk in chunks]