In [1]:
import transformers, torch, nlp

datasets = [dataset.id for dataset in nlp.list_datasets()]
print(datasets)

['aeslc', 'ag_news', 'ai2_arc', 'allocine', 'anli', 'arcd', 'art', 'billsum', 'biomrc', 'blended_skill_talk', 'blimp', 'blog_authorship_corpus', 'bookcorpus', 'boolq', 'break_data', 'c4', 'cfq', 'civil_comments', 'cmrc2018', 'cnn_dailymail', 'coarse_discourse', 'com_qa', 'commonsense_qa', 'compguesswhat', 'coqa', 'cornell_movie_dialog', 'cos_e', 'cosmos_qa', 'crime_and_punish', 'csv', 'definite_pronoun_resolution', 'discofuse', 'docred', 'drop', 'eli5', 'emotion', 'empathetic_dialogues', 'eraser_multi_rc', 'esnli', 'event2Mind', 'fever', 'flores', 'fquad', 'gap', 'germeval_14', 'ghomasHudson/cqc', 'gigaword', 'glue', 'hansards', 'hellaswag', 'hyperpartisan_news_detection', 'imdb', 'jeopardy', 'json', 'k-halid/ar', 'kor_nli', 'lc_quad', 'lhoestq/c4', 'librispeech_lm', 'lince', 'lm1b', 'math_dataset', 'math_qa', 'mlqa', 'movie_rationales', 'multi_news', 'multi_nli', 'multi_nli_mismatch', 'mwsc', 'natural_questions', 'newsroom', 'openbookqa', 'opinosis', 'pandas', 'para_crawl', 'pg19', 'p

In [2]:
dataset = nlp.load_dataset('wmt_t2t')

In [3]:
# encoder_config = transformers.BertConfig()
# decoder_config = transformers.BertConfig()

# model_config = transformers.EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

# model = transformers.EncoderDecoderModel(model_config)

model = transformers.EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer

In [4]:
en_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
de_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
from torch.utils.data import DataLoader

In [6]:
dl = DataLoader(dataset['train'], batch_size=1)

In [7]:
data = next(iter(dl))
data

{'translation': {'de': ['Wiederaufnahme der Sitzungsperiode'],
  'en': ['Resumption of the session']}}

In [8]:
max_len = 50

de_tokenized_data = de_tokenizer.batch_encode_plus(
    data['translation']['de'],
    max_length=max_len, 
    truncation=True,
    pad_to_max_length=True, 
    return_attention_mask=True, 
    return_tensors='pt'
)

en_tokenized_data = en_tokenizer.batch_encode_plus(
    data['translation']['en'],
    max_length=max_len, 
    truncation=True,
    pad_to_max_length=True, 
    return_attention_mask=True, 
    return_tensors='pt'
)

In [11]:
optimizer = torch.optim.AdamW(model.parameters())
epochs = 100

for e in range(epochs):

    optimizer.zero_grad()
    result = model(
        input_ids=en_tokenized_data['input_ids'],
        attention_mask=en_tokenized_data['attention_mask'],
        decoder_input_ids=de_tokenized_data['input_ids'],
        decoder_attention_mask=de_tokenized_data['attention_mask'],
        labels=de_tokenized_data['input_ids']
    )

    loss, output = result[:2]
    print(f"Epoch: {e}, Loss: {loss.item()}")
    loss.backward()
    optimizer.step()
    
    if e % 10 == 0:
        generated_examples = model.generate(en_tokenized_data['input_ids'], attention_mask=en_tokenized_data['attention_mask'], decoder_start_token_id=101)
        print(f"Examples: \n", de_tokenizer.batch_decode(generated_examples))

Epoch: 0, Loss: 13.902228355407715
Examples: 
 ['[unused105]uptuptuptuptuptuptuptuptuptuptuptuptuptuptuptuptuptuptupt']
Epoch: 1, Loss: 7.812264442443848
Epoch: 2, Loss: 3.7805967330932617
Epoch: 3, Loss: 2.996269941329956
Epoch: 4, Loss: 2.232943296432495
Epoch: 5, Loss: 1.7437382936477661
Epoch: 6, Loss: 1.2134402990341187
Epoch: 7, Loss: 1.0368692874908447
Epoch: 8, Loss: 0.9304614067077637
Epoch: 9, Loss: 0.967970073223114
Epoch: 10, Loss: 0.9651787281036377
Examples: 
 ['[unused105] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']
Epoch: 11, Loss: 0.8796783089637756
Epoch: 12, Loss: 0.8711527585983276
Epoch: 13, Loss: 0.8390348553657532
Epoch: 14, Loss: 0.7865239381790161
Epoch: 15, Loss: 0.7854759097099304
Epoch: 16, Loss: 0.7550427317619324
Epoch: 17, Loss: 0.7746438980102539
Epoch: 18, Loss: 0.7383840680122375
Epoch: 19, Loss: 0.7383842468261719
Epoch: 20, Loss: 0.745490550994873
Examples: 
 ['[unused105] [PAD] 