# Imports

In [1]:
import os
from argparse import Namespace

In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
from seqeval.metrics import f1_score, classification_report
import pytorch_lightning as pl

In [3]:
from src.data.make_conll2003 import get_example_sets, InputExample
from src.models.modeling_t5conll2003 import T5ForConll2003

In [7]:
# Experiment configuration for the "Overfit T5 on CoNLL2003" run: batch size 2,
# no shuffling of the training set, and the label/sequence-length settings the
# project model reads. Built directly as an argparse.Namespace so every field is
# available as an attribute (e.g. hparams.batch_size).
hparams = Namespace(
    experiment_name="Overfit T5 on CoNLL2003",
    batch_size=2,
    num_workers=2,
    optimizer="Adam",
    lr=5e-3,
    datapath="../data/conll2003",
    shuffle_train=False,
    source_max_length=128,
    target_max_length=256,
    labels_mode='tokens',
    merge_O=True,
)

In [8]:
# Build the project's T5ForConll2003 model from the 't5-small' checkpoint and hand it
# the experiment hparams (presumably read by its data/optimizer hooks — confirm in
# src/models/modeling_t5conll2003).
model = T5ForConll2003.from_pretrained('t5-small', hparams=hparams)

In [9]:
# Prepare the data through the model's own hooks and pull the first training batch
# for inspection. hparams.shuffle_train is False, so this is presumably the same
# batch on every run.
model.prepare_data()
dl_train = model.train_dataloader()
batch = next(iter(dl_train))

In [47]:
def trim_matrix(mat, value):
    """Drop trailing columns that contain only ``value`` (e.g. padding).

    Trimming happens along the last axis: a 2-D ``(rows, cols)`` tensor is cut
    to the width of its longest row of real content, and 1-D or higher-rank
    tensors are handled the same way along their last dimension.

    Args:
        mat: tensor to trim.
        value: padding value to strip from the right, e.g. -100 for masked
            label positions or the model's pad token id for input ids.

    Returns:
        The view ``mat[..., :length]``, where ``length`` is one past the last
        column holding any element != ``value``. If every element equals
        ``value`` the result has zero columns; if ``value`` never fills a
        trailing column the tensor is returned at full width.
    """
    if mat.shape[-1] == 0:
        # Nothing to trim; also avoids reshape(-1, 0), which is ambiguous.
        return mat
    # A column is kept if at least one row has real content there. Unlike
    # locating each row's first `value`, this never truncates a row that has
    # no padding at all, and it does not fail when `value` is absent.
    has_content = (mat != value).reshape(-1, mat.shape[-1]).any(dim=0)
    content_cols = torch.nonzero(has_content)
    length = content_cols.max().item() + 1 if content_cols.numel() else 0
    return mat[..., :length]

In [48]:
# Trim the trailing -100 columns from batch[2]. batch[2] is presumably the target
# labels, and -100 is PyTorch's default cross-entropy ignore_index — confirm
# against the dataset code.
trim_matrix(batch[2], -100)

tensor([[ 3371, 32102, 15092,     7, 32100,  2968, 32104,   580,    12, 30242,
         32100,  2390, 32104, 17871,     3,     5, 32100,     1],
        [ 2737,  1589,  7223, 32101,     1,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100]])

In [39]:
# Scratch check: running count (along the last axis) of -100 entries in each row
# of batch[2], as a float tensor.
eq_val = (batch[2] == -100).float().cumsum(-1)

In [46]:
# Largest column index at which a row's running -100 count is exactly 1. For rows
# whose -100 padding is contiguous and trailing, this is the column where padding
# starts in the longest padded row (18 in this run).
(eq_val == 1.).nonzero()[:, 1].max().item()

18

In [49]:
# Inspect batch[0] (presumably the source input ids); its output ends in a long run of 0s.
batch[0]

tensor([[18742,  4443,  2197,    10,  3371, 15092,     7,  2968,   580,    12,
         30242,  2390, 17871,     3,     5,     1,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [51]:
# Look up the model's pad token id to see whether it matches the trailing 0s in batch[0].
model.config.pad_token_id

0

In [None]:
# Trim the trailing padding columns from batch[0]. Its padding is 0, and
# model.config.pad_token_id evaluates to 0 in this notebook, so use the config
# value rather than a hard-coded constant (trim_matrix requires the value argument).
trim_matrix(batch[0], model.config.pad_token_id)