# Imports

In [None]:
import os
from argparse import Namespace

In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
from seqeval.metrics import f1_score, classification_report
import pytorch_lightning as pl

In [3]:
from src.data.make_conll2003 import get_example_sets, InputExample
from src.models.modeling_t5conll2003 import T5ForConll2003

In [4]:
# Experiment configuration passed to the model; labels_mode='words' makes targets use
# human-readable label words (see the decoded target text further below).
hparams = Namespace(
    experiment_name="Overfit T5 on CoNLL2003",
    batch_size=2,
    num_workers=2,
    optimizer="Adam",
    lr=5e-3,
    datapath="../data/conll2003",
    shuffle_train=False,
    source_max_length=128,
    target_max_length=256,
    labels_mode='words',
)

In [5]:
# Load pretrained 't5-small' weights into the project's CoNLL2003 wrapper, configured by hparams.
model = T5ForConll2003.from_pretrained('t5-small', hparams=hparams)

# Overfit one batch and save the checkpoint (or reload it)

In [6]:
# overfit=True trains on one batch and writes the checkpoint; False reloads it instead.
overfit, overfit_ckpt = True, 'overfit_words_mode.ckpt'

In [7]:
# Prepare the data and take the first training batch. This single batch is reused by the
# overfit and evaluation cells below: batch[0] is passed as input_ids, batch[1] as
# attention_mask, and batch[2] holds the target ids.
model.prepare_data()
dl_train = model.train_dataloader()
batch = next(iter(dl_train))

In [8]:
if overfit:
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.to(device)
    batch = [x.to(device) for x in batch]

    optimizer = model.configure_optimizers()

    # Fit the same batch repeatedly; the printed loss should fall towards 0.
    for _ in range(10):
        loss = model.training_step(batch, 0)['loss']
        loss.backward()
        optimizer.step()
        model.zero_grad()

        print(loss.item())

    torch.save(model.state_dict(), overfit_ckpt)
else:
    # The model has not been moved off the CPU in this branch. map_location='cpu' lets a
    # checkpoint saved from CUDA tensors load even when no GPU is available.
    print(model.load_state_dict(torch.load(overfit_ckpt, map_location='cpu')))

3.2000813484191895
3.6374001502990723
3.527728796005249
2.0330588817596436
1.4024171829223633
1.1162036657333374
0.7967498898506165
0.3651909828186035
0.15743477642536163
0.0523308590054512


# Evaluation

In [9]:
# Reuse the tokenizer attached to the model, so decoding matches the model's vocabulary.
tokenizer = model.tokenizer

In [10]:
# Entity marker tokens defined by the model (displayed below).
entities_tokens = model.entities_tokens
entities_tokens

['<O>', '<PER>', '<ORG>', '<LOC>', '<MISC>']

In [11]:
# Replace the -100 entries (presumably the loss ignore index) with the pad id so the targets
# can be decoded. Move to CPU once, up front; masked_fill keeps the tensor's dtype and device.
target_token_ids = batch[2].cpu()
target_token_ids = target_token_ids.masked_fill(target_token_ids == -100, tokenizer.pad_token_id)

In [12]:
# Generate predictions for the same batch the model was overfitted on.
# NOTE(review): hparams.target_max_length is 256 — confirm model.max_length is the intended limit.
predicted_token_ids = model.generate(input_ids=batch[0], attention_mask=batch[1], max_length=model.max_length)

In [13]:
# Reference target text for the first example in the batch.
tokenizer.decode(target_token_ids[0])

'EU [Organization] rejects [Other] German [Miscellaneous] call [Other] to [Other] boycott [Other] British [Miscellaneous] lamb [Other]. [Other]'

In [14]:
# Generated text for the first example; after overfitting it should match the target above.
tokenizer.decode(predicted_token_ids[0])

'EU [Organization] rejects [Other] German [Miscellaneous] call [Other] to [Other] boycott [Other] British [Miscellaneous] lamb [Other]. [Other]'

In [18]:
# Label words that appear in the decoded target/generated text, keyed by CoNLL entity type.
# NOTE(review): confirm '[Local]' is the exact word the model's targets use for LOC.
# If it differs, LOC spans will not be converted to '<LOC>' tokens.
labels2words = dict(
    O='[Other]',
    PER='[Person]',
    LOC='[Local]',
    MISC='[Miscellaneous]',
    ORG='[Organization]',
)

# Inverse mapping: label word -> entity marker token such as '<ORG>'.
entities2tokens = {word: '<' + label + '>' for label, word in labels2words.items()}
entities2tokens

{'[Other]': '<O>',
 '[Person]': '<PER>',
 '[Local]': '<LOC>',
 '[Miscellaneous]': '<MISC>',
 '[Organization]': '<ORG>'}

In [21]:
# Swap each label word in the generated text for its entity marker, then re-tokenize.
# This checks that every marker comes out as a single token.
decoded_prediction = tokenizer.decode(predicted_token_ids[0])
sentence = decoded_prediction
for entity_word, entity_token in entities2tokens.items():
    sentence = sentence.replace(entity_word, entity_token)
tokenizer.tokenize(sentence)

['▁EU',
 '<ORG>',
 '▁reject',
 's',
 '<O>',
 '▁German',
 '<MISC>',
 '▁call',
 '<O>',
 '▁to',
 '<O>',
 '▁boycott',
 '<O>',
 '▁British',
 '<MISC>',
 '▁lamb',
 '<O>',
 '▁',
 '.',
 '<O>']

In [22]:
def get_entities_from_tokens(tokens, tokenizer, entities_tokens, length=0, fill_token='O'):
    """Convert a token sequence containing entity markers into per-token IOB tags.

    Each marker in ``entities_tokens`` (e.g. ``'<ORG>'``) closes the group of ordinary
    tokens seen since the previous marker, and that whole group is tagged. The first token
    gets ``B-<TYPE>`` and the rest get ``I-<TYPE>``; the ``'<O>'`` marker tags every token
    ``O``. Tags are emitted per token, so a word split into several subword tokens
    receives several tags.

    Args:
        tokens: list of token strings, optionally starting with the pad token
            (which is skipped).
        tokenizer: object exposing ``pad_token`` and ``eos_token``. Parsing stops at the
            first EOS or pad token after the start.
        entities_tokens: collection of entity marker tokens of the form ``'<TYPE>'``.
        length: if > 0, the result is truncated or right-padded to exactly this length.
        fill_token: tag used for right-padding when ``length`` is given.

    Returns:
        List of IOB tag strings. Tokens after the last entity marker are not tagged.
    """
    sequence_entities = []  # tags collected so far
    current_entity = []  # tokens waiting for their closing entity marker
    # A leading pad token (presumably the decoder start token in generated output) carries
    # no word, so skip it. The length check avoids an IndexError on an empty sequence.
    if len(tokens) > 0 and tokens[0] == tokenizer.pad_token:
        tokens = tokens[1:]
    for token in tokens:
        if token in entities_tokens:
            # A marker with no preceding tokens tags nothing. Emitting a B- tag here would
            # describe a non-existent word and shift every later tag out of alignment.
            if current_entity:
                entity = token[1:-1]  # strip '<' and '>'
                if entity == 'O':
                    blabel = ilabel = entity
                else:
                    blabel = f'B-{entity}'
                    ilabel = f'I-{entity}'
                sequence_entities += [blabel] + [ilabel] * (len(current_entity) - 1)
                current_entity.clear()
        elif token in (tokenizer.eos_token, tokenizer.pad_token):
            break
        else:
            current_entity.append(token)
    if length > 0:
        # Truncate, then right-pad, so the result has exactly `length` tags.
        sequence_entities = sequence_entities[:length]
        sequence_entities += [fill_token] * (length - len(sequence_entities))
    return sequence_entities

In [25]:
def get_tokens(token_ids, tokenizer, entities):
    """Turn a sequence of token ids into a list of token strings.

    If ``entities`` is a dict mapping label words to entity marker tokens (e.g.
    ``'[Organization]' -> '<ORG>'``), the ids are decoded to text. Every label word is then
    replaced by its marker and the text is re-tokenized, so each marker becomes its own
    token. For any other ``entities`` value, the ids are converted to tokens one-to-one.
    """
    if not isinstance(entities, dict):
        return tokenizer.convert_ids_to_tokens(token_ids)

    text = tokenizer.decode(token_ids)
    for entity_word, entity_token in entities.items():
        text = text.replace(entity_word, entity_token)
    return tokenizer.tokenize(text)

In [26]:
def get_trues_and_preds_entities(target_token_ids, predicted_token_ids,
                                 tokenizer, entities, fill_token='O'):
    """Convert a batch of target and predicted token ids into per-sequence IOB tag lists.

    Args:
        target_token_ids: batch of reference token-id sequences.
        predicted_token_ids: batch of generated token-id sequences with the same batch size.
        tokenizer: tokenizer used to turn ids into tokens.
        entities: either a dict mapping label words to entity marker tokens (the ids are
            decoded and re-tokenized) or a sequence of marker tokens (the ids are converted
            directly).
        fill_token: tag used to pad predicted sequences that are shorter than their target.

    Returns:
        ``(all_target_entities, all_predicted_entities)``: two lists of tag lists. Each
        predicted list is truncated or padded to the length of the matching target list.
    """
    # Both batches must contain the same number of sequences.
    assert len(target_token_ids) == len(predicted_token_ids)
    if isinstance(entities, dict):
        entities_tokens = list(entities.values())
    else:
        entities_tokens = entities

    all_target_entities, all_predicted_entities = [], []
    for target_ids, predicted_ids in zip(target_token_ids, predicted_token_ids):
        target_tokens = get_tokens(target_ids, tokenizer, entities)
        predicted_tokens = get_tokens(predicted_ids, tokenizer, entities)

        target_entities = get_entities_from_tokens(target_tokens, tokenizer, entities_tokens)
        # Align the prediction to the target length so seqeval can compare them tag by tag.
        predicted_entities = get_entities_from_tokens(
            predicted_tokens, tokenizer, entities_tokens,
            length=len(target_entities), fill_token=fill_token)

        all_target_entities.append(target_entities)
        all_predicted_entities.append(predicted_entities)
    return all_target_entities, all_predicted_entities

In [27]:
# Convert the whole batch into seqeval-style tag sequences, parsing label words via entities2tokens.
target_entities, predicted_entities = get_trues_and_preds_entities(target_token_ids, predicted_token_ids, tokenizer, entities=entities2tokens)

In [29]:
# Display both tag sequences; each prediction is padded/truncated to its target's length.
target_entities, predicted_entities

([['B-ORG', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O'],
  ['B-PER', 'I-PER', 'I-PER']],
 [['B-ORG', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O'],
  ['B-PER', 'I-PER', 'I-PER']])

# Seqeval

In [111]:
# print() renders the report string's newlines. Support counts entity spans, not individual tags.
print(classification_report(target_entities, predicted_entities))

           precision    recall  f1-score   support

      PER       1.00      1.00      1.00         1
     MISC       1.00      1.00      1.00         2
      ORG       1.00      1.00      1.00         1

micro avg       1.00      1.00      1.00         4
macro avg       1.00      1.00      1.00         4



In [114]:
# Same target tags as displayed above, printed on a single line.
print(target_entities)

[['B-ORG', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O'], ['B-PER', 'I-PER', 'I-PER']]


In [115]:
# Same predicted tags as displayed above, printed on a single line.
print(predicted_entities)

[['B-ORG', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O'], ['B-PER', 'I-PER', 'I-PER']]
