In [1]:
from transformers import T5TokenizerFast, T5ForConditionalGeneration

tokenizer = T5TokenizerFast.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').cuda()

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [1]:
import datasets

dataset = datasets.load_dataset('ms_marco', 'v2.1', split='train[:5000]')

Downloading and preparing dataset ms_marco/v2.1 to /home/bary/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

In [5]:
import random

labels = dataset['query']
text = ['; '.join(item['passages']['passage_text']) for item in dataset]

def generate_queries(label):
    words = label.split()
    if len(words) < 2:
        return label
    rnd = random.randrange(1, len(words))
    return ' '.join(words[:rnd])

queries = [generate_queries(label) for label in labels]

In [6]:
dataset = datasets.Dataset.from_dict({'text': text, 'label': labels, 'query': queries})

In [7]:
dataset = dataset.train_test_split(test_size=0.1)

In [8]:
train_dataset = dataset['train']
test_dataset = dataset['test']

In [18]:
list(model.named_parameters())

[('shared.weight',
  Parameter containing:
  tensor([[ 2.5135, -0.7213, -1.6779,  ..., -0.9080,  1.2072,  0.6901],
          [ 1.2128, -0.0105, -0.5603,  ..., -0.7985,  1.0677,  1.2400],
          [ 4.6471,  7.6964,  5.1765,  ...,  6.5332, -9.7251,  0.5026],
          ...,
          [-0.3536,  0.0812, -0.2395,  ..., -0.1056,  0.0703, -0.1708],
          [-0.3536,  0.0810, -0.2397,  ..., -0.1056,  0.0704, -0.1708],
          [-0.3540,  0.0813, -0.2394,  ..., -0.1059,  0.0698, -0.1708]],
         device='cuda:0', requires_grad=True)),
 ('encoder.block.0.layer.0.SelfAttention.q.weight',
  Parameter containing:
  tensor([[ -0.0648,   3.7918,   0.4717,  ...,  -4.3506,   2.2311,  -4.4022],
          [ -0.0227,   4.6708,   1.0731,  ...,  -4.3392,   2.6174,  -4.0526],
          [ -0.0875,   4.0010,   0.3555,  ...,  -3.0482,   1.3064,  -3.8960],
          ...,
          [ -2.0381,  13.7276,   6.2290,  ...,   3.5680, -12.9209, -13.4118],
          [  0.8562, -13.4800,  -2.5859,  ...,  -1.2085,  

In [9]:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

def train(epochs, lr, gamma=0.8, print_every=10):
    global model
    batch_size = 10

    model = model.cuda()

    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    for epoch in range(epochs):

        total_loss = 0
        for i in range(0, len(train_dataset), batch_size):
            batch_number = i // batch_size
            batch = train_dataset[i:i+batch_size]
            tokenized = tokenizer.batch_encode_plus(
                batch['text'], truncation=True, padding='max_length', max_length=256, return_tensors='pt')
            decoder_input = tokenizer.batch_encode_plus(
                batch['query'], truncation=True, padding='max_length', max_length=256, return_tensors='pt')
            labels = tokenizer.batch_encode_plus(
                batch['label'], truncation=True, padding='max_length', max_length=256, return_tensors='pt')
            decoder_attention_mask = decoder_input['attention_mask']
            for i in range(len(decoder_attention_mask)):
                decoder_attention_mask[i][sum(decoder_attention_mask[i]) - 1] = 1
            decoder_input_ids = decoder_input['input_ids']
            decoder_input_ids[decoder_input_ids == 1] = 0.0

            # if i == 0:
            #     print(decoder_input_ids[0])
            #     print(decoder_attention_mask[0])
            #     print(labels['input_ids'][0])
            output = model(
                input_ids=tokenized['input_ids'].cuda(),
                attention_mask=tokenized['attention_mask'].cuda(),
                decoder_input_ids=decoder_input_ids.cuda(),
                decoder_attention_mask=decoder_attention_mask.cuda(),
                labels=labels['input_ids'].cuda(),
                )
            loss = output.loss
            # if i == 0:
            #     print(output)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

            optimizer.step()

            total_loss += loss.item()
            if batch_number % print_every == 0 and batch_number > 0:
                print(f'Epoch: {epoch}, Batch: {batch_number}, Loss: {total_loss / print_every}')
                total_loss = 0

        scheduler.step()
        

In [10]:
train(epochs=3, lr=1e-5, print_every=10)

Epoch: 0, Batch: 10, Loss: 21.79952049255371
Epoch: 0, Batch: 20, Loss: 16.413487052917482
Epoch: 0, Batch: 30, Loss: 14.275240612030029
Epoch: 0, Batch: 40, Loss: 13.070242500305175
Epoch: 0, Batch: 50, Loss: 11.542870140075683
Epoch: 0, Batch: 60, Loss: 10.304472351074219
Epoch: 0, Batch: 70, Loss: 9.359173011779784
Epoch: 0, Batch: 80, Loss: 7.773673248291016
Epoch: 0, Batch: 90, Loss: 7.07403826713562
Epoch: 0, Batch: 100, Loss: 6.049855470657349
Epoch: 0, Batch: 110, Loss: 5.013722658157349
Epoch: 0, Batch: 120, Loss: 4.387119317054749
Epoch: 0, Batch: 130, Loss: 3.3896918535232543
Epoch: 0, Batch: 140, Loss: 3.1415919065475464
Epoch: 0, Batch: 150, Loss: 2.505714476108551
Epoch: 0, Batch: 160, Loss: 2.363430690765381
Epoch: 0, Batch: 170, Loss: 1.948194432258606
Epoch: 0, Batch: 180, Loss: 1.7829601168632507
Epoch: 0, Batch: 190, Loss: 1.6313425540924071
Epoch: 0, Batch: 200, Loss: 1.5138441443443298
Epoch: 0, Batch: 210, Loss: 1.4116634130477905
Epoch: 0, Batch: 220, Loss: 1.255

KeyboardInterrupt: 

In [13]:
model = model.cuda()
for i in range(5):
    item = test_dataset[i]
    print('Query:', item['query'])
    print('Label:', item['label'])
    # print('Text:', item['text'])
    text = tokenizer(item['text'], return_tensors='pt', max_length=256, truncation=True, padding='max_length')
    query = tokenizer(item['query'], return_tensors='pt', max_length=256, truncation=True, padding='max_length')
    decoder_input_ids = query['input_ids']
    decoder_input_ids[decoder_input_ids == 1] = 0
    decoder_attention_mask = query['attention_mask']
    decoder_attention_mask = decoder_attention_mask[:, :sum(decoder_attention_mask[0]) - 1]
    output = model.generate(
        input_ids=text['input_ids'].cuda(),
        attention_mask=text['attention_mask'].cuda(),
        decoder_input_ids=decoder_input_ids.cuda(), 
        decoder_attention_mask=decoder_attention_mask.cuda(),
        max_new_tokens=10, num_beams=3, num_return_sequences=3, early_stopping=True)
    print('Prediction:', [tokenizer.decode(o, skip_special_tokens=True) for o in output])

Query: in what time period must a stroke victim be
Label: in what time period must a stroke victim be treated medically
Prediction: ['in what time period must a stroke victim be', 'in what time period must a stroke victim be', 'in what time period must a stroke victim be']
Query: how many tiles per square metre with tile size 200mm
Label: how many tiles per square metre with tile size 200mm x 100mm?
Prediction: ['how many tiles per square metre with tile size 200mm', 'how many tiles per square metre with tile size 200mm', 'how many tiles per square metre with tile size 200mm']
Query: initiation fee for planet
Label: initiation fee for planet fitness
Prediction: ['initiation fee for planet', 'initiation fee for planet ', 'initiation fee for planet']
Query: how creative writing is beneficial
Label: how creative writing is beneficial to formal writing
Prediction: ['how creative writing is beneficial', 'how creative writing is beneficial', 'how creative writing is beneficial']
Query: diet 

: 