# Imports

In [19]:
import os
from typing import List, Dict

In [45]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [50]:
import pytorch_lightning as pl

In [3]:
from src.data.make_conll2003 import get_example_sets, InputExample

# Note:
Some steps may be executed more than once in this notebook; this is intentional, for debugging purposes.

The final code will not contain these duplicates.

# Data

In [6]:
# Load the CoNLL-2003 splits as lists of InputExample; the split names are shown in the output below.
folderpath = '../data/conll2003/'
sets_dict = get_example_sets(folderpath)
sets_dict.keys()

dict_keys(['train', 'valid', 'test'])

In [7]:
# First training example: the raw sentence and its target, where every word is followed by its label token.
sets_dict['train'][0]

Source: EU rejects German call to boycott British lamb .
Target: EU <ORG> rejects <O> German <MISC> call <O> to <O> boycott <O> British <MISC> lamb <O> . <O>

# Tokenizer

In [8]:
# Label tokens that follow each word in the target sequences (see the example target above).
NER_LABELS = [
    '<O>',
    '<PER>',
    '<ORG>',
    '<LOC>',
    '<MISC>'
]

In [9]:
# Checkpoint name shared by the tokenizer here and the model loaded further down.
pretrained_model = 't5-small'

In [10]:
tokenizer = T5Tokenizer.from_pretrained(pretrained_model)

In [11]:
# Register the labels as tokens so each one is a single token; the output is the number of tokens added.
tokenizer.add_tokens(NER_LABELS)

5

# Features

In [124]:
def get_entities_from_token_ids(token_ids: List[int], tokenizer: 'T5Tokenizer', NER_LABELS: List[str]) -> Dict[str, List[str]]:
    """Group a tagged token-id sequence into entity strings, one list per label.

    The sequence is expected in this notebook's target format: the tokens of
    each word are followed by that word's label token, e.g. ``EU <ORG> rejects <O>``.
    Every label token closes the text accumulated since the previous label and
    files it under that label.

    Args:
        token_ids: ids to interpret, e.g. a row of ``lm_labels`` or of ``model.generate`` output.
        tokenizer: tokenizer in which the label tokens were registered with ``add_tokens``.
        NER_LABELS: label tokens to recognise; each one becomes a key of the result.

    Returns:
        Dict mapping every label to the text spans tagged with it, in order of
        appearance. Special tokens (``<pad>``, ``</s>``, ...) are skipped, empty
        spans (two labels in a row) are not recorded, and trailing text that is
        never closed by a label is discarded.
    """
    label_set = set(NER_LABELS)
    # Special tokens carry no entity text. Depending on the tokenizer version,
    # convert_tokens_to_string may render them literally (e.g. the <pad> that
    # typically leads generate() output), so drop them explicitly.
    ignored_tokens = set(getattr(tokenizer, 'all_special_tokens', ())) - label_set
    entities = {label: [] for label in NER_LABELS}
    current_entity = []
    for token in tokenizer.convert_ids_to_tokens(token_ids):
        if token in label_set:
            if current_entity:
                entities[token].append(tokenizer.convert_tokens_to_string(current_entity).strip())
            current_entity = []
        elif token not in ignored_tokens:
            current_entity.append(token)
    return entities

In [125]:
class InputFeature:
    """Fixed-length, id-encoded form of a single InputExample.

    Attributes:
        source_token_ids: ids of the prefixed source sentence (encoder input).
        target_token_ids: ids of the tagged target sentence (used as decoder labels).
        attention_mask: source mask, 1 for real tokens and 0 for padding.
        example: the originating InputExample, kept for inspection (may be None).
    """

    def __init__(self, source_token_ids, target_token_ids, attention_mask, example=None):
        self.source_token_ids, self.target_token_ids = source_token_ids, target_token_ids
        self.attention_mask = attention_mask
        # Optional back-reference so a feature can be traced to its raw sentence.
        self.example = example

In [126]:
def _encode_fixed_length(text: str, tokenizer: T5Tokenizer, max_length: int):
    """Tokenize `text` into exactly `max_length` ids: truncated content, one EOS, then PAD.

    Returns:
        (token_ids, attention_mask): the mask is 1 for content and EOS positions, 0 for padding.
    """
    tokens = tokenizer.tokenize(text)[:max_length - 1]  # keep one slot free for EOS
    tokens.append(tokenizer.eos_token)
    n_padding = max_length - len(tokens)
    attention_mask = [1] * len(tokens) + [0] * n_padding
    tokens += [tokenizer.pad_token] * n_padding
    return tokenizer.convert_tokens_to_ids(tokens), attention_mask


def convert_example_to_feature(example: InputExample, tokenizer:T5Tokenizer, max_length:int=512, prefix:str='Extract Entities:') -> InputFeature:
    """Encode one example as fixed-length source/target id sequences for T5.

    The source text is ``f'{prefix} {example.source}'`` and the target text is
    ``example.target``. Each is tokenized, truncated to ``max_length - 1``
    tokens, terminated with EOS and right-padded with PAD to exactly
    ``max_length`` ids.

    Args:
        example: example providing ``source`` and ``target`` strings.
        tokenizer: tokenizer used for both sequences.
        max_length: length of every produced id sequence; must be >= 1.
        prefix: task prefix prepended (with a space) to the source sentence.

    Returns:
        InputFeature with source ids, target ids, the source attention mask and the example.

    Raises:
        ValueError: if ``max_length < 1`` (no room for the EOS token).
    """
    if max_length < 1:
        raise ValueError(f'max_length must be at least 1 to fit the EOS token, got {max_length}')

    source_token_ids, attention_mask = _encode_fixed_length(f'{prefix} {example.source}', tokenizer, max_length)
    # Only the source mask is kept: the target ids are consumed as decoder labels.
    target_token_ids, _ = _encode_fixed_length(example.target, tokenizer, max_length)

    assert max_length == len(source_token_ids), f'Max length is {max_length} and len(source_token_ids) is {len(source_token_ids)}'
    assert max_length == len(target_token_ids), f'Max length is {max_length} and len(target_token_ids) is {len(target_token_ids)}'
    assert max_length == len(attention_mask), f'Max length is {max_length} and len(attention_mask) is {len(attention_mask)}'

    return InputFeature(source_token_ids, target_token_ids, attention_mask, example)

In [127]:
def convert_examples_to_features(examples: List[InputExample], tokenizer:T5Tokenizer, max_length:int=512, prefix:str='Extract Entities:')->List[InputFeature]:
    """Encode each example with `convert_example_to_feature`, preserving input order."""
    features = []
    for example in examples:
        features.append(convert_example_to_feature(example, tokenizer, max_length=max_length, prefix=prefix))
    return features

In [128]:
def convert_example_sets_to_features_sets(examples_sets: Dict[str, List[InputExample]], tokenizer:T5Tokenizer, max_length:int=512, prefix:str='Extract Entities:') -> Dict[str, List[InputFeature]]:
    """Featurize every split in `examples_sets`, keeping the split names (e.g. 'train') as keys."""
    features_sets = {}
    for split_name, split_examples in examples_sets.items():
        features_sets[split_name] = convert_examples_to_features(split_examples, tokenizer, max_length=max_length, prefix=prefix)
    return features_sets

In [129]:
# Featurize all splits with max_length=128 (shorter than the function default of 512).
features_sets = convert_example_sets_to_features_sets(sets_dict, tokenizer, max_length=128)

In [130]:
feature = features_sets['train'][0]

In [131]:
# The originating example is kept on the feature for inspection.
feature.example

Source: EU rejects German call to boycott British lamb .
Target: EU <ORG> rejects <O> German <MISC> call <O> to <O> boycott <O> British <MISC> lamb <O> . <O>

In [132]:
# Round-trip check: decoding the ids gives back the prefixed source and the tagged target.
tokenizer.decode(feature.source_token_ids), tokenizer.decode(feature.target_token_ids)

('Extract Entities: EU rejects German call to boycott British lamb.',
 'EU <ORG> rejects <O> German <MISC> call <O> to <O> boycott <O> British <MISC> lamb <O>. <O> ')

# Dataset

In [133]:
class FeaturesDataset(Dataset):
    """Map-style dataset over a pre-computed sequence of features.

    Length and storage are handled here; subclasses define how a single
    feature becomes model inputs by overriding ``__getitem__``.
    """

    def __init__(self, features):
        # Stored as given (no copy).
        self.features = features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        raise NotImplementedError

In [134]:
class T5NERDataset(FeaturesDataset):
    """Yields ``(input_ids, attention_mask, lm_labels)`` long tensors for T5.

    ``tokenizer`` and ``NER_LABELS`` are only stored; ``__getitem__`` does not use them.

    NOTE(review): ``lm_labels`` keeps the PAD id at padded target positions.
    Hugging Face T5 only ignores label -100 in the loss, so padding currently
    contributes to the loss — confirm this is intended.
    """

    def __init__(self, features, *args, tokenizer=None, NER_LABELS=None, **kwargs):
        super().__init__(features, *args, **kwargs)
        self.tokenizer, self.NER_LABELS = tokenizer, NER_LABELS

    def __getitem__(self, idx):
        feature = self.features[idx]
        # Order matters: callers unpack as (input_ids, attention_mask, lm_labels).
        fields = (feature.source_token_ids, feature.attention_mask, feature.target_token_ids)
        return tuple(torch.tensor(field, dtype=torch.long) for field in fields)

In [135]:
# Small, unshuffled debug loader over the first 10 training features.
ds_debug = T5NERDataset(features_sets['train'][:10])
dl_debug = DataLoader(ds_debug, batch_size=2, shuffle=False)

In [136]:
input_ids, attention_mask, lm_labels = next(iter(dl_debug))

In [137]:
# Each tensor has shape (batch_size, max_length), i.e. (2, 128) here.
input_ids.shape, attention_mask.shape, lm_labels.shape

(torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128]))

# Model

In [164]:
class T5ForNERWithPL(T5ForConditionalGeneration, pl.LightningModule):
    """T5 conditional-generation model with PyTorch Lightning train/val/test hooks.

    Batches are ``(input_ids, attention_mask, lm_labels)`` tuples, the format
    yielded by ``T5NERDataset``. Subclasses must implement the optimizer and
    dataloader hooks, which raise ``NotImplementedError`` here.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "T5ForNERWithPL":
        """Load a pretrained checkpoint and record its name on the class.

        Subclasses read ``pretrained_model_name_or_path`` later, e.g. to load the
        matching tokenizer.
        """
        cls.pretrained_model_name_or_path = pretrained_model_name_or_path
        return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

    def _handle_batch(self, batch):
        """Forward pass with ``lm_labels``; element 0 of the result is used as the loss."""
        input_ids, attention_mask, lm_labels = batch
        outputs = self(input_ids=input_ids, attention_mask=attention_mask, lm_labels=lm_labels)
        return outputs

    def _average_key(self, outputs, key):
        """Mean of ``o[key]`` over a list of per-step output dicts, as a float tensor."""
        return torch.mean(torch.stack([o[key] for o in outputs]).float())

    def training_step(self, batch, batch_idx=None):
        """Lightning training hook; Lightning calls it as ``training_step(batch, batch_idx)``."""
        outputs = self._handle_batch(batch)
        return {'loss': outputs[0]}

    def train_step(self, batch, batch_idx=None):
        """Backward-compatible alias of ``training_step`` (Lightning never calls this name)."""
        return self.training_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx=None):
        """Lightning validation hook; ``batch_idx`` is accepted because Lightning passes it."""
        outputs = self._handle_batch(batch)
        return {'val_loss': outputs[0]}

    def test_step(self, batch, batch_idx=None):
        """Lightning test hook; ``batch_idx`` is accepted because Lightning passes it."""
        outputs = self._handle_batch(batch)
        return {'test_loss': outputs[0]}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses."""
        loss_avg = self._average_key(outputs, 'val_loss')
        return {'val_loss': loss_avg}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses."""
        loss_avg = self._average_key(outputs, 'test_loss')
        return {'test_loss': loss_avg}

    def configure_optimizers(self):
        raise NotImplementedError

    def train_dataloader(self):
        raise NotImplementedError

    def val_dataloader(self):
        raise NotImplementedError

    def test_dataloader(self):
        raise NotImplementedError

In [168]:
class T5ForConll2003(T5ForNERWithPL):
    """T5ForNERWithPL bound to the CoNLL-2003 data with fixed hyper-parameters.

    ``prepare_data`` does the following:
      - loads the splits;
      - builds a tokenizer that knows the NER label tokens;
      - featurizes every split;
      - wraps each split in a ``T5NERDataset`` so the dataloader hooks can batch it.
    """

    @property
    def pretrained_model_name(self):
        # Recorded by T5ForNERWithPL.from_pretrained.
        return self.pretrained_model_name_or_path

    @property
    def datapath(self):
        return '../data/conll2003/'

    @property
    def max_length(self):
        return 128

    @property
    def batch_size(self):
        return 2

    @property
    def num_workers(self):
        return 2

    @property
    def ner_labels(self):
        """Label tokens used in the targets (a copy of the notebook-level NER_LABELS)."""
        return list(NER_LABELS)

    def get_examples(self):
        return get_example_sets(self.datapath)

    def get_tokenizer(self):
        """Load the pretrained tokenizer and register the NER label tokens.

        Without ``add_tokens``, labels such as '<ORG>' in the targets would be
        tokenized as ordinary text rather than as one token each.
        """
        tokenizer = T5Tokenizer.from_pretrained(self.pretrained_model_name)
        tokenizer.add_tokens(self.ner_labels)
        return tokenizer

    def get_features(self, examples, tokenizer, max_length):
        return convert_example_sets_to_features_sets(examples, tokenizer, max_length=max_length)

    def get_datasets(self, features):
        """Wrap the train/valid/test feature lists in ``T5NERDataset`` objects.

        DataLoader's default collate cannot batch raw ``InputFeature`` objects,
        so the lists must be wrapped in datasets that yield tensors.
        """
        tokenizer = getattr(self, 'tokenizer', None)
        return tuple(
            T5NERDataset(features[split], tokenizer=tokenizer, NER_LABELS=self.ner_labels)
            for split in ('train', 'valid', 'test')
        )

    def prepare_data(self):
        examples = self.get_examples()
        self.tokenizer = self.get_tokenizer()
        # Grow the embedding matrix only if the added label tokens fall outside it.
        if len(self.tokenizer) > self.config.vocab_size:
            self.resize_token_embeddings(len(self.tokenizer))
        features = self.get_features(examples, self.tokenizer, max_length=self.max_length)
        self.train_dataset, self.val_dataset, self.test_dataset = self.get_datasets(features)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [169]:
model = T5ForConll2003.from_pretrained(pretrained_model)

In [170]:
# Forward pass on the debug batch with the target ids as labels; no gradients are needed here.
with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, lm_labels=lm_labels)

In [171]:
# outputs[0] is the loss of the not-yet-fine-tuned model on this batch.
len(outputs), outputs[0]

(4, tensor(13.1443))

In [172]:
# Generate output sequences for the debug batch, capped at 50 tokens.
predicted_token_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=50)

In [173]:
predicted_token_ids.shape

torch.Size([2, 5])

In [174]:
# Compare prediction, input and reference target for example i of the batch.
i = 0
tokenizer.decode(predicted_token_ids[i]), tokenizer.decode(input_ids[i]), tokenizer.decode(lm_labels[i])

('<extra_id_0>:',
 'Extract Entities: EU rejects German call to boycott British lamb.',
 'EU <ORG> rejects <O> German <MISC> call <O> to <O> boycott <O> British <MISC> lamb <O>. <O> ')

In [175]:
# Decoded reference target for example i.
sequence = tokenizer.decode(lm_labels[i])
sequence

'EU <ORG> rejects <O> German <MISC> call <O> to <O> boycott <O> British <MISC> lamb <O>. <O> '