# Imports

In [1]:
import os
from typing import List

In [2]:
from transformers import T5Tokenizer
import transformers

from src.utils import read_txt
from src.input.example import InputExample
from src.input.feature import InputFeature

In [3]:
# Location of the CoNLL-2003 training file, relative to the notebook's working directory.
filepath = '../data/conll2003/train.txt'
# Blocks in the file are separated by blank lines; the slice discards the first and last block.
# NOTE(review): presumably the first block is a document header and the last is an empty
# trailing chunk — confirm against the file so no real sentence is dropped.
text_examples = read_txt(filepath).split('\n\n')[1:-1]
# Show the first block: one token per row, with the word in the first column and the
# entity tag in the last column (the only two columns used further down).
text_examples[0]

'EU NNP B-NP B-ORG\nrejects VBZ B-VP O\nGerman JJ B-NP B-MISC\ncall NN I-NP O\nto TO B-VP O\nboycott VB I-VP O\nBritish JJ B-NP B-MISC\nlamb NN I-NP O\n. . O O'

In [4]:
# Maps an entity type (or 'O' for tokens outside any entity) to the bracketed marker
# that convert_text_to_example_with appends after the word/span in the target sequence.
# NOTE(review): 'LOC' maps to '[Local]'; if '[Location]' was intended, be aware that
# changing it alters every generated target — confirm before editing.
labels2words = {
    'O': '[Other]',
    'PER': '[Person]',
    'LOC': '[Local]',
    'MISC': '[Miscellaneous]',
    'ORG': '[Organization]'
}

# Examples

In [5]:
def convert_text_to_example_with(text, labels2words=None, split_line_by='\n', split_row_by=' '):
    """Build an ``InputExample`` from one column-formatted, IOB-tagged sentence.

    Every non-blank row of ``text`` is split on ``split_row_by``; the first
    field is used as the word and the last field as its tag. The source
    sequence is the plain list of words. The target sequence repeats the
    words and inserts a marker after each ``'O'`` token and after each
    entity span (a starting tag followed by its ``I-<type>`` continuations).

    Args:
        text: The annotated sentence, one token per row.
        labels2words: Optional mapping from an entity type (or ``'O'``) to the
            marker emitted in the target. Types missing from the mapping fall
            back to ``'<TYPE>'``. Defaults to an empty mapping.
        split_line_by: Separator between rows.
        split_row_by: Separator between the columns of a row.

    Returns:
        ``InputExample(source_words, target_words)``.
    """
    # Avoid a shared mutable default argument.
    if labels2words is None:
        labels2words = {}

    words, labels = [], []
    for row in text.split(split_line_by):
        if not row.strip():
            # Blank rows (e.g. produced by a trailing newline) carry no token;
            # keeping them would emit an empty word with a bogus '<>' marker.
            continue
        fields = row.split(split_row_by)
        words.append(fields[0])
        labels.append(fields[-1])

    source_words, target_words = [], []

    start = 0
    while start < len(words):
        label = labels[start]
        end = start + 1

        if label == 'O':
            ent_label = label
        else:
            # Strip only the positional prefix ("B-"/"I-") so that hyphenated
            # entity types such as "WORK-OF-ART" are kept whole.
            ent_label = label.split('-', 1)[-1]
            # Absorb the continuation tokens that belong to this entity.
            while end < len(labels) and labels[end] == f'I-{ent_label}':
                end += 1

        span = words[start:end]
        source_words.extend(span)
        target_words.extend(span)
        target_words.append(labels2words.get(ent_label, f'<{ent_label}>'))
        start = end

    return InputExample(source_words, target_words)

In [6]:
# Convert the first raw block into an InputExample, using the label-to-marker mapping defined earlier.
example = convert_text_to_example_with(text_examples[0], labels2words=labels2words)

In [7]:
# Print the example's string form; the output below lists its source and target sequences.
print(example)

Source: EU rejects German call to boycott British lamb .
Target: EU [Organization] rejects [Other] German [Miscellaneous] call [Other] to [Other] boycott [Other] British [Miscellaneous] lamb [Other] . [Other]


# Features

In [8]:
# Instantiate the T5 tokenizer from the 't5-small' pretrained identifier.
tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [9]:
def convert_example_to_feature(example: InputExample, tokenizer: transformers.PreTrainedTokenizer,
                               max_length : int = 512,
                               source_max_length: int = None,
                               target_max_length: int = None,
                               prefix: str = 'Extract Entities:') -> InputFeature:
    """Tokenize an example into fixed-length source/target id lists.

    The source text is ``prefix`` followed by ``example.source``; the target
    text is ``example.target``. Both are tokenized, truncated so that one
    position remains for the EOS token, terminated with EOS and then padded
    to their maximum length. The source is padded with the tokenizer's pad
    token (attention mask 0 on those positions); the target is padded with
    ``-100``.

    Args:
        example: Object exposing ``source`` and ``target`` text.
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids``,
            ``eos_token`` and ``pad_token``.
        max_length: Length used for either side when its specific limit is
            not given.
        source_max_length: Length of the source ids / attention mask.
        target_max_length: Length of the target ids.
        prefix: Task prefix prepended to the source text.

    Returns:
        ``InputFeature(source_token_ids, target_token_ids, attention_mask, example)``.

    Raises:
        ValueError: If an effective maximum length is smaller than 1, since
            at least the EOS token must fit.
    """
    if source_max_length is None:
        source_max_length = max_length
    if target_max_length is None:
        target_max_length = max_length

    # A non-positive limit would turn the truncation slice bound negative and
    # silently drop the last token instead of truncating.
    if source_max_length < 1 or target_max_length < 1:
        raise ValueError(
            'source_max_length and target_max_length must be >= 1, got '
            f'{source_max_length} and {target_max_length}')

    source = f'{prefix} {example.source}'.strip()
    target = example.target

    # Keep one position free for the EOS token appended to both sequences.
    source_tokens = tokenizer.tokenize(source)[:source_max_length - 1] + [tokenizer.eos_token]
    target_tokens = tokenizer.tokenize(target)[:target_max_length - 1] + [tokenizer.eos_token]

    # Source: real tokens get mask 1, pad tokens get mask 0.
    source_padding = source_max_length - len(source_tokens)
    attention_mask = [1] * len(source_tokens) + [0] * source_padding
    source_token_ids = tokenizer.convert_tokens_to_ids(
        source_tokens + [tokenizer.pad_token] * source_padding)

    # Target: padded with -100 rather than the pad id.
    # NOTE(review): presumably so the loss ignores padded positions — confirm
    # against the training code.
    target_padding = target_max_length - len(target_tokens)
    target_token_ids = tokenizer.convert_tokens_to_ids(
        target_tokens) + [-100] * target_padding

    assert source_max_length == len(
        source_token_ids), f'Max length is {source_max_length} and len(source_token_ids) is {len(source_token_ids)}'
    assert target_max_length == len(
        target_token_ids), f'Max length is {target_max_length} and len(target_token_ids) is {len(target_token_ids)}'
    assert source_max_length == len(
        attention_mask), f'Max length is {source_max_length} and len(attention_mask) is {len(attention_mask)}'

    return InputFeature(source_token_ids, target_token_ids, attention_mask, example)

In [13]:
# Build a feature whose source and target are both sized to 256 positions
# (max_length applies to each side because no per-side limit is passed).
feature = convert_example_to_feature(example, tokenizer, max_length=256)

In [14]:
# Sanity check: decode the source ids back to text.
tokenizer.decode(feature.source_token_ids)

'Extract Entities: EU rejects German call to boycott British lamb.'

In [15]:
# Sanity check: decode the target ids back to text.
# NOTE(review): the attribute read here is `target_ids`, while the value is built as
# `target_token_ids` in convert_example_to_feature — confirm the InputFeature attribute name.
# NOTE(review): targets are padded with -100 there; confirm decode tolerates those ids.
tokenizer.decode(feature.target_ids)

'EU [Organization] rejects [Other] German [Miscellaneous] call [Other] to [Other] boycott [Other] British [Miscellaneous] lamb [Other]. [Other]'