In [1]:
import nlp, transformers, torch, tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
from torch.utils.data import DataLoader

dataset = nlp.load_dataset('wmt_t2t')
model = transformers.BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn').to(device)
tokenizer = transformers.BartTokenizer.from_pretrained('facebook/bart-large-cnn')

dl = DataLoader(dataset['train'], batch_size=10)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.000001)

epochs = 1001
max_len = 50

data = next(iter(dl))

for e in range(epochs):

    en_tokenized_data = tokenizer.batch_encode_plus(
        data['translation']['en'],
        max_length=max_len, 
        truncation=True,
        pad_to_max_length=True, 
        return_attention_mask=True, 
        return_tensors='pt'
    )

    en_tokenized_data = en_tokenized_data.to(device)
    
    if e % 10 == 0:
        model.eval()
        generated_examples = model.generate(en_tokenized_data['input_ids'], 
                                            attention_mask=en_tokenized_data['attention_mask'], 
                                            decoder_start_token_id=tokenizer.bos_token_id)
        print(f"Examples: \n", tokenizer.batch_decode(generated_examples)[0])
    
    model.train()
    optimizer.zero_grad()

    result = model(
        input_ids=en_tokenized_data['input_ids'],
        attention_mask=en_tokenized_data['attention_mask'],
        decoder_input_ids=en_tokenized_data['input_ids'],
        decoder_attention_mask=en_tokenized_data['attention_mask'],
        labels=en_tokenized_data['input_ids']
    )

    loss, output, output2 = result[:3]
    print(f"Epoch: {e}, Loss: {loss.item()}")
    loss.backward()
    optimizer.step()

<s><s><s><pad><pad><pad>angaangaangaacacacabacaccacacaciacacancacacapacacAcacacacaacac ChickacacCacacACacacicacacciacac Tacacacaciesacacacyacacaclacacicasacaccasacaccracac acacclacacicaacacccacac cracacacanacaccapacacjacacaccuacacamacacac Acacaccatacac)...acacocalacacchiacac Cubanacacicanacac ACacacCapacacacciacacCubacacCatacacCamacac camacac</s>
Epoch: 740, Loss: 0.19808423519134521
Epoch: 741, Loss: 0.19004777073860168
Epoch: 742, Loss: 0.19002141058444977
Epoch: 743, Loss: 0.19766700267791748
Epoch: 744, Loss: 0.19140496850013733
Epoch: 745, Loss: 0.18866917490959167
Epoch: 746, Loss: 0.1875365674495697
Epoch: 747, Loss: 0.1900414377450943
Epoch: 748, Loss: 0.1918729990720749
Epoch: 749, Loss: 0.18556812405586243
Examples: 
 <s><s><s><pad><pad><pad>angaangaangaacacacabacaccacacaciacacancacacapacacAcacacacaacac ChickacacCacacACacacicacacciacac Tacacacaciesacacaclacacicasacacccacaccasacaccracac acacclacac cracacacanacacicaacaccapaciacicaciacaciaciaciac cacaciacaacicicacicciacciaciacci

# Análise dos resultados obtidos vs resultados esperados

In [3]:
result = model(
    input_ids=en_tokenized_data['input_ids'],
    attention_mask=en_tokenized_data['attention_mask'],
    decoder_input_ids=en_tokenized_data['input_ids'],
    decoder_attention_mask=en_tokenized_data['attention_mask'],
    labels=en_tokenized_data['input_ids']
)

In [4]:
loss, logits = result[0].cpu(), result[1].cpu()
inputs = en_tokenized_data['input_ids'].cpu()
logits.shape, inputs.shape

(torch.Size([10, 50, 50264]), torch.Size([10, 50]))

In [5]:
tokenizer.batch_decode(model.generate(inputs[:1].cuda()))

['</s><s><s><s>Resumption of of of the of the the session session sessionsession sessionsessionsession session session sessions session sessionSession session session Session session session resume session session.Resumptionumptionumption resumedumption resumed resumed resumedumption resume resume resumeumption resumed resume resume resumed resumed resumeumptionumption resumeumption resume resumed resume session suspension resume session resume resume']

In [6]:
tokenizer.batch_decode(model.generate(torch.argmax(torch.nn.functional.softmax(logits[:1].cuda(), dim=-1), dim=-1)))

['</s><s>Resumptionumption of of of the of the the session session sessionsession session sessionSession session session sessions session sessions sessions session session.Session sessionSessionSessionSession sessionsessionSessionSessionsessionSession session SessionSessionSessionrapSessionSession SessionSession Session sessionSessionResumptionResumptionSessionSessionRepSessionSessionSnapSnapSnapResResResSessionResResumptionRepResResumpResResRepResRecResResReportsResResPsResRespsResRes)",ResRes),"ResResumpsResResSpotResResspResRes.""ResRes<pad><pad><pad> \'Res<pad>\'\'Res \'ResRes\xa0Rep\xa0 \'Res "\'\'\'Res\xa0\'\' "\'" " \'\xa0Res \'']