# Imports

In [2]:
import os
from argparse import Namespace

In [3]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [4]:
import pytorch_lightning as pl

In [5]:
from src.data.make_conll2003 import get_example_sets, InputExample
from src.input.dataset import T5NERDataset
from src.models.modeling_t5conll2003 import T5ForConll2003

In [6]:
# Hyperparameters for the overfitting run, exposed as attributes
# (hparams.batch_size, hparams.lr, ...) just like a parsed argparse result.
hparams = Namespace(
    experiment_name="Overfit T5 on CoNLL2003",
    batch_size=2,
    num_workers=2,
    optimizer="Adam",
    lr=5e-4,
    datapath="../data/conll2003",
)

In [7]:
class OverfitT5(T5ForConll2003):
    """T5ForConll2003 variant used as an overfitting sanity check.

    Training, validation and testing all iterate ``self.train_dataset`` in a
    fixed (unshuffled) order. Together with the Trainer's ``overfit_pct``,
    every loop therefore sees the same small subset of training examples, and
    a working pipeline should be able to memorise it.
    """

    def _fixed_order_dataloader(self) -> DataLoader:
        """Build an unshuffled DataLoader over the training split.

        Shared by all three loops so batch size, worker count and ordering
        cannot drift apart between them.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def train_dataloader(self) -> DataLoader:
        """Training loader over the training split, in fixed order."""
        # With a shuffled loader, `overfit_pct` keeps only the first few batches
        # of a fresh random permutation each epoch. The model would then see a
        # different handful of examples every epoch and never overfit a fixed
        # subset, and that subset would not match what val/test evaluate.
        # NOTE(review): presumably the parent's train loader shuffles; this also
        # assumes it uses no custom collate_fn/sampler (val/test already rely on
        # the default collate) — confirm against T5ForConll2003.
        return self._fixed_order_dataloader()

    def val_dataloader(self) -> DataLoader:
        """Validate on the training split so validation measures memorisation."""
        return self._fixed_order_dataloader()

    def test_dataloader(self) -> DataLoader:
        """Test on the training split so test metrics measure memorisation."""
        return self._fixed_order_dataloader()

In [8]:
# Instantiate the overfitting subclass from the pretrained 't5-small' checkpoint.
# `hparams` is passed through as a keyword; how T5ForConll2003 consumes it
# (presumably batch_size / num_workers / lr / datapath) lives in src.models — verify.
model = OverfitT5.from_pretrained('t5-small', hparams=hparams)

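# Overfit with PyTorch Lightning

Train and evaluate on the same tiny subset of the training data; a correctly wired pipeline should drive F1 on that subset close to 1.0.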

In [9]:
# Single-GPU Lightning trainer for the overfitting check. `overfit_pct=0.001`
# limits every loop (train/val/test) to 0.1% of its batches, so each of the
# 50 epochs runs only a handful of batches.
# NOTE(review): `overfit_pct` is an older PyTorch Lightning flag (later replaced
# by `overfit_batches`); keep it in sync with the installed Lightning version.
trainer = pl.Trainer(gpus=1, max_epochs=50, overfit_pct=0.001)

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]


In [None]:
# Train the model. Lightning runs a validation sanity check first, then a
# validation pass after each training epoch.
# NOTE(review): the saved output of this cell ends in multiprocessing worker
# tracebacks; these look like DataLoader workers (num_workers=2) being torn
# down when the run was interrupted — TODO confirm.
trainer.fit(model)


    | Name                                                            | Type                  | Params
------------------------------------------------------------------------------------------------------
0   | shared                                                          | Embedding             | 16 M  
1   | encoder                                                         | T5Stack               | 35 M  
2   | encoder.block                                                   | ModuleList            | 18 M  
3   | encoder.block.0                                                 | T5Block               | 3 M   
4   | encoder.block.0.layer                                           | ModuleList            | 3 M   
5   | encoder.block.0.layer.0                                         | T5LayerSelfAttention  | 1 M   
6   | encoder.block.0.layer.0.SelfAttention                           | T5Attention           | 1 M   
7   | encoder.block.0.layer.0.SelfAttention.q                         | 

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Process Process-131:
Process Process-132:
Traceback (most recent call last):
  File "/home/israel/miniconda3/envs/t5ner/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/israel/miniconda3/envs/t5ner/lib/python3.8/multiprocessing/util.py", line 334, in _exit_function
    _run_finalizers(0)
  File "/home/israel/miniconda3/envs/t5ner/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/israel/miniconda3/envs/t5ner/lib/python3.8/multiprocessing/util.py", line 334, in _exit_function
    _run_finalizers(0)
  File "/home/israel/miniconda3/envs/t5ner/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/israel/miniconda3/envs/t5ner/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/israel/miniconda3/envs/t5ner/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callba

In [30]:
# Evaluate via the model's test_dataloader (which in OverfitT5 reads the
# training split); prints a per-entity precision/recall/F1 report and the loss.
# NOTE(review): an F1 of 0 here means the model failed to fit even the data it
# was trained on — inspect a single batch below to find out why.
trainer.test(model)



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
TEST RESULTS
{'f1': 0,
 'report': '           precision    recall  f1-score   support\n'
           '\n'
           '     MISC       0.00      0.00      0.00         8\n'
           '      ORG       0.00      0.00      0.00         9\n'
           '      PER       0.00      0.00      0.00         9\n'
           '      LOC       0.00      0.00      0.00         7\n'
           '\n'
           'micro avg       0.00      0.00      0.00        33\n'
           'macro avg       0.00      0.00      0.00        33\n',
 'test_loss': tensor(1.1957, device='cuda:0')}
--------------------------------------------------------------------------------



In [31]:
# Take the first batch from the training loader so one batch can be inspected by hand.
dl_train = model.train_dataloader()
batch = next(iter(dl_train))

In [32]:
# Move every tensor in the batch onto the device the model's weights live on,
# instead of hard-coding `.cuda()`. This keeps the batch and model together on
# any GPU index, or on CPU, and is a no-op for tensors already there.
model_device = next(model.parameters()).device
batch = [x.to(model_device) for x in batch]

In [33]:
# Run the model's evaluation step on this one batch. As the unpacked names
# suggest, it yields the raw outputs plus per-sentence gold and predicted tag
# lists, which the next cells compare directly.
outputs, target_entities, predicted_entities = model._handle_eval_batch(batch)

In [34]:
# Gold BIO tags ('O', 'B-MISC', 'I-MISC', ...), one list per sentence in the batch.
# NOTE(review): the first sentence decodes to two words ("Attendance 3,000") but
# carries four gold tags — verify how tags are aligned to words vs. sub-tokens.
target_entities

[['O', 'O', 'O', 'O'],
 ['O',
  'O',
  'B-MISC',
  'I-MISC',
  'I-MISC',
  'I-MISC',
  'I-MISC',
  'I-MISC',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O']]

In [35]:
# Decode the first example of the batch's first tensor (the encoder input)
# back to text, including the "Extract Entities:" task prefix.
model.tokenizer.decode(batch[0][0])

'Extract Entities: Attendance 3,000'

In [36]:
# Token ids of the target sequences for this batch (presumably the decoder
# labels the model is trained on — confirm in T5ForConll2003).
target_token_ids = model.get_target_token_ids(batch)

In [37]:
# Decode the first target sequence; each word is followed by its tag token (e.g. "<O>").
model.tokenizer.decode(target_token_ids[0])

'Attendance <O> 3,000 <O> '

In [38]:
# Predicted BIO tags, same layout as target_entities, for element-wise comparison.
# In the saved run every prediction is 'O', including the B-MISC/I-MISC span.
predicted_entities

[['O', 'O', 'O', 'O'],
 ['O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O']]