# Imports

In [1]:
import os
from argparse import Namespace

In [112]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
from seqeval.metrics import f1_score, classification_report
import pytorch_lightning as pl

In [4]:
from src.data.make_conll2003 import get_example_sets, InputExample
from src.models.modeling_t5conll2003 import T5ForConll2003

In [5]:
# Hyperparameters for this experiment, exposed with attribute access
# (hparams.lr, hparams.batch_size, ...).
hparams = Namespace(
    experiment_name="Overfit T5 on CoNLL2003",
    batch_size=2,
    num_workers=2,
    optimizer="Adam",
    lr=5e-3,
    datapath="../data/conll2003",
    shuffle_train=False,
)

In [6]:
# Instantiate the project's T5 wrapper from the 't5-small' checkpoint name,
# passing the hparams namespace through to the model.
model = T5ForConll2003.from_pretrained('t5-small', hparams=hparams)

# Overfit and save

In [9]:
# File used to cache the overfitted weights between runs.
overfit_ckpt = 'overfit.ckpt'
# Set to True to retrain even when a cached checkpoint already exists.
force_overfit = False
# Train when explicitly asked to, or when there is no checkpoint to load yet.
# Without the existence check, a fresh "Restart & Run All" with no cached file
# would fail when the load branch tries to open it.
overfit = force_overfit or not os.path.exists(overfit_ckpt)

In [8]:
# Run the model's prepare_data hook, build its training dataloader and pull
# the first batch from it. The cells below work on this one batch only.
model.prepare_data()
dl_train = model.train_dataloader()
batch = next(iter(dl_train))  # first batch of the training loader

In [10]:
if overfit:
    # Prefer the GPU, but fall back to the CPU so the cell also runs on
    # machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.to(device)
    batch = [x.to(device) for x in batch]

    optimizer = model.configure_optimizers()

    # Fit the same batch over and over and print the loss after each step.
    for _ in range(100):
        loss = model.training_step(batch, 0)['loss']
        loss.backward()
        optimizer.step()
        model.zero_grad()

        print(loss.item())

    # Cache the overfitted weights so later runs can skip the training loop.
    torch.save(model.state_dict(), overfit_ckpt)
else:
    # map_location='cpu' lets a checkpoint that was written on a GPU be
    # restored on a CPU-only machine. The printed value is the result object
    # returned by load_state_dict.
    print(model.load_state_dict(torch.load(overfit_ckpt, map_location='cpu')))

16.06240463256836
12.487531661987305
9.963590621948242
8.65821647644043
7.520566940307617
5.975543022155762
4.154529094696045
3.12961483001709
3.3882970809936523
3.0742173194885254
2.8745617866516113
3.2312960624694824
2.7839860916137695
2.5749077796936035
2.272475481033325
2.0055387020111084
1.9658942222595215
1.7902837991714478
1.694254994392395
1.5763726234436035
1.526963472366333
1.4052767753601074
1.2932432889938354
1.2178740501403809
1.1427955627441406
1.0735819339752197
1.0112868547439575
0.9507379531860352
0.8920056223869324
0.8353450298309326
0.7822046279907227
0.7321492433547974
0.6795676946640015
0.6310276389122009
0.5843082070350647
0.5381196737289429
0.4926541745662689
0.44900447130203247
0.4069681167602539
0.3668065369129181
0.32949236035346985
0.2950741946697235
0.2641274034976959
0.2370295524597168
0.21271643042564392
0.19188906252384186
0.17414654791355133
0.1584789901971817
0.14444348216056824
0.13255152106285095
0.12253027409315109
0.11349175870418549
0.1049143001437

# Evaluation

In [11]:
# Reuse the tokenizer held by the model for the decoding steps below.
tokenizer = model.tokenizer

In [12]:
# Entity marker tokens exposed by the model; the bare last line displays them.
entities_tokens = model.entities_tokens
entities_tokens

['<O>', '<PER>', '<ORG>', '<LOC>', '<MISC>']

In [17]:
# batch[2] is used as the target token ids (assumed from how it is decoded
# below); move it to the CPU first.
target_token_ids = batch[2].cpu()
# Replace every -100 entry (presumably the loss ignore-index, TODO confirm)
# with the pad id so the sequence can be passed to the tokenizer.
target_token_ids = target_token_ids.where(target_token_ids != -100, torch.tensor(tokenizer.pad_token_id))

In [27]:
# Generate output ids for the same batch: batch[0] is passed as the input ids
# and batch[1] as the attention mask, with max_length taken from the model.
predicted_token_ids = model.generate(input_ids=batch[0], attention_mask=batch[1], max_length=model.max_length)

In [18]:
# Decode the first target sequence back to text to inspect the tagged format.
tokenizer.decode(target_token_ids[0])

'EU <ORG> rejects <O> German <MISC> call <O> to <O> boycott <O> British <MISC> lamb <O>. <O> '

In [26]:
# Decode the second predicted sequence (index 1) back to text.
# NOTE(review): the target cell decodes index 0, so the two displayed strings
# belong to different examples; use the same index to compare them directly.
tokenizer.decode(predicted_token_ids[1])

'Peter Blackburn <PER> '

In [95]:
def get_entities_from_tokens(tokens, tokenizer, entities_tokens, length=0, fill_token='O'):
    """Turn a tagged token sequence into one BIO label per text token.

    The sequence is read as ``tok tok <PER> tok <O> ...``: each run of ordinary
    tokens is labelled by the entity marker that follows it. The first token of
    a run gets ``B-<entity>`` and the remaining ones ``I-<entity>``; a run
    closed by ``<O>`` gets ``O`` for every token. Tokens after the last marker
    are dropped because nothing labels them.

    Args:
        tokens: sequence of token strings; may be empty.
        tokenizer: object exposing ``pad_token`` and ``eos_token``.
        entities_tokens: collection of marker tokens such as ``'<PER>'``.
        length: when > 0, the result is truncated or padded to exactly this
            many labels; when 0 the result is returned as is.
        fill_token: label used for padding when the result is too short.

    Returns:
        List of label strings.
    """
    labels = []
    pending = 0  # number of text tokens seen since the last marker

    # A single leading pad token is skipped instead of ending the sequence.
    # The length guard keeps an empty sequence from raising IndexError.
    if len(tokens) > 0 and tokens[0] == tokenizer.pad_token:
        tokens = tokens[1:]

    for token in tokens:
        if token in entities_tokens:
            # A marker with no text token before it has nothing to label;
            # emitting a label for it would shift every later label by one.
            if pending:
                entity = token[1:-1]  # strip the surrounding '<' and '>'
                if entity == 'O':
                    labels.extend(['O'] * pending)
                else:
                    labels.append(f'B-{entity}')
                    labels.extend([f'I-{entity}'] * (pending - 1))
                pending = 0
        elif token in (tokenizer.eos_token, tokenizer.pad_token):
            # End of the meaningful part of the sequence.
            break
        else:
            pending += 1

    if length > 0:
        # Force the requested length: cut the surplus, then pad the deficit.
        labels = labels[:length]
        labels.extend([fill_token] * (length - len(labels)))
    return labels

In [108]:
def get_trues_and_preds_entities(target_token_ids, predicted_token_ids,
                                tokenizer, entities_tokens, fill_token='O'):
    """Convert batches of target and predicted ids into aligned label lists.

    Each pair of sequences is mapped to tokens with the tokenizer and then to
    labels with ``get_entities_from_tokens``. The predicted labels are
    truncated or padded with ``fill_token`` to the length of the target labels.

    Returns:
        Tuple ``(all_target_labels, all_predicted_labels)``, one inner list
        per example.
    """
    assert len(target_token_ids) == len(predicted_token_ids)

    gold_batch, pred_batch = [], []
    for gold_ids, pred_ids in zip(target_token_ids, predicted_token_ids):
        # ids -> token strings
        gold_tokens = tokenizer.convert_ids_to_tokens(gold_ids)
        pred_tokens = tokenizer.convert_ids_to_tokens(pred_ids)

        # token strings -> labels; the target length drives the prediction length
        gold_labels = get_entities_from_tokens(gold_tokens, tokenizer, entities_tokens)
        pred_labels = get_entities_from_tokens(
            pred_tokens, tokenizer, entities_tokens,
            length=len(gold_labels), fill_token=fill_token,
        )

        gold_batch.append(gold_labels)
        pred_batch.append(pred_labels)

    return gold_batch, pred_batch

In [109]:
# Convert the target and generated ids of the batch into per-example label
# lists, with predictions aligned to the target lengths.
target_entities, predicted_entities = get_trues_and_preds_entities(target_token_ids, predicted_token_ids, tokenizer, entities_tokens)

# Seqeval

In [111]:
# seqeval report computed from the label lists of the overfitted batch.
print(classification_report(target_entities, predicted_entities))

           precision    recall  f1-score   support

      PER       1.00      1.00      1.00         1
     MISC       1.00      1.00      1.00         2
      ORG       1.00      1.00      1.00         1

micro avg       1.00      1.00      1.00         4
macro avg       1.00      1.00      1.00         4



In [114]:
# Show the target labels, one inner list per example in the batch.
print(target_entities)

[['B-ORG', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O'], ['B-PER', 'I-PER', 'I-PER']]


In [115]:
# Show the predicted labels, one inner list per example in the batch.
print(predicted_entities)

[['B-ORG', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O'], ['B-PER', 'I-PER', 'I-PER']]
