In [1]:
import nlp, transformers, torch, tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
from torch.utils.data import DataLoader

dataset = nlp.load_dataset('wmt_t2t')
model = transformers.BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn').to(device)
tokenizer = transformers.BartTokenizer.from_pretrained('facebook/bart-large-cnn')

dl = DataLoader(dataset['train'], batch_size=10)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)

epochs = 201
max_len = 50

data = next(iter(dl))

for e in range(epochs):

    en_tokenized_data = tokenizer.batch_encode_plus(
        data['translation']['en'],
        max_length=max_len, 
        truncation=True,
        pad_to_max_length=True, 
        return_attention_mask=True, 
        return_tensors='pt'
    )

    en_tokenized_data = en_tokenized_data.to(device)
    
    if e % 10 == 0:
        model.eval()
        generated_examples = model.generate(en_tokenized_data['input_ids'], 
                                            attention_mask=en_tokenized_data['attention_mask'], 
                                            decoder_start_token_id=tokenizer.bos_token_id)
        print(f"Examples: \n", tokenizer.batch_decode(generated_examples)[0])
    
    model.train()
    optimizer.zero_grad()

    result = model(
        input_ids=en_tokenized_data['input_ids'],
        attention_mask=en_tokenized_data['attention_mask'],
        decoder_input_ids=en_tokenized_data['input_ids'],
        decoder_attention_mask=en_tokenized_data['attention_mask'],
        labels=en_tokenized_data['input_ids']
    )

    loss, output, output2 = result[:3]
    print(f"Epoch: {e}, Loss: {loss.item()}")
    loss.backward()
    optimizer.step()

><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Epoch: 0, Loss: 10.76230525970459
Epoch: 1, Loss: 10.030280113220215
Epoch: 2, Loss: 9.497891426086426
Epoch: 3, Loss: 8.852696418762207
Epoch: 4, Loss: 8.497407913208008
Epoch: 5, Loss: 8.063995361328125
Epoch: 6, Loss: 7.533941268920898
Epoch: 7, Loss: 7.111499786376953
Epoch: 8, Loss: 6.74473237991333
Epoch: 9, Loss: 6.354745388031006
Examples: 
 <s><s><s>Resumption of the session. Resumption of of the the session of the House of Representatives. Resuming of the Senate session. resuming the session the Senate resumed of the chamber. resumption the session resumed the session and resumed the the House session.Resumption the the</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><p

# Análise dos resultados obtidos vs resultados esperados

In [3]:
result = model(
    input_ids=en_tokenized_data['input_ids'],
    attention_mask=en_tokenized_data['attention_mask'],
    decoder_input_ids=en_tokenized_data['input_ids'],
    decoder_attention_mask=en_tokenized_data['attention_mask'],
    labels=en_tokenized_data['input_ids']
)

In [4]:
loss, logits = result[0].cpu(), result[1].cpu()
inputs = en_tokenized_data['input_ids'].cpu()
logits.shape, inputs.shape

(torch.Size([10, 50, 50264]), torch.Size([10, 50]))

### Resultados do método `model.generate`

In [5]:
tokenizer.batch_decode(model.generate(inputs[:1].cuda()))

["</s><s><s>Resumptionumptionumption suspension suspension suspensionumption suspensionumptionumptionResumption suspension cabin<pad><pad><pad>\xa0c<pad><pad> 'c 'C'' ''''\xa0 '\xa0' 'C' 'ig 'g''g' 'gig'' g''G' 'G''Gu''igigg 'gg 'igaggggigggG 'gagg 'Gg 'gal 'gagi 'gagu 'g gggagaggigag 'gaggingggagiggGugagiagiagigagiaggagGgagiGuGagagiagiGagiagiGuGuagiagi"]

In [6]:
tokenizer.batch_decode(model.generate(torch.argmax(torch.nn.functional.softmax(logits[:1].cuda(), dim=-1), dim=-1)))

['</s><s><s><s>The The The The the The TheThe TheTheTheThe the theThe The theThe theTheThethetheThetheTheTheDespiteTheThebutThebut thebutbutbut theThebutbut...but.Thebut.but)...but......but....but......but......but...)...but)...)...)...but...)...)......)............)......)...)......"but......"...)...............)...............,...).........,...)...)........)...........)...)..."...)...)...,...)......")......"..."......"..."..................."..."bek<pad><pad><pad>)..."...<pad><pad> \'<pad><pad>']

### Resultados da saída do modelo

In [8]:
inputs[:1].cpu()

tensor([[    0, 20028, 21236,     9,     5,  1852,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]])

In [11]:
torch.argmax(torch.nn.functional.softmax(output[:1].cpu(), dim=-1), dim=-1)

tensor([[    0, 20028, 21236,     9,     5,  1852,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]])