In [8]:
import torch
from transformers import (
    BertTokenizerFast, BertConfig, BertLMHeadModel, BertModel,
    AutoModel, EncoderDecoderModel, AutoTokenizer,
)
from torch.optim import AdamW



torch.manual_seed(0)
enc = dec = "bert-base-uncased"
tok_src = BertTokenizerFast.from_pretrained(enc)
tok_tgt = tok_src#BertTokenizerFast.from_pretrained(dec)


# ---------- CORRECT: set decoder_start_token_id ON CONFIG before forward
dec_cfg_ok = BertConfig.from_pretrained(dec,
                                         is_decoder=True,
                                         add_cross_attention=True)


good = EncoderDecoderModel(encoder=AutoModel.from_pretrained(enc),
                           decoder=BertLMHeadModel.from_pretrained(dec, config=dec_cfg_ok),
                           )

# Required for loss computation (right-shift uses this)
good.config.decoder_start_token_id = tok_tgt.cls_token_id
good.config.eos_token_id = tok_tgt.sep_token_id
good.config.pad_token_id = tok_tgt.pad_token_id
good.config.vocab_size = good.config.decoder.vocab_size
good.config.tie_encoder_decoder = True


#add optimizer
optimizer = AdamW(good.parameters(), lr=1e-5)


X2 = tok_src(["cats are cute", "i like tea"],
              return_tensors="pt",
              padding=True,
              truncation=True)


Y2 = tok_tgt(["les chats sont mignons", "j'aime le thé"], 
             return_tensors="pt",
             padding=True,
             truncation=True,
             add_special_tokens=False)  # no [CLS]


labels2 = Y2.input_ids#.clone() 
labels2[labels2 == tok_tgt.pad_token_id] = -100



for i in range(0,5):
    out = good(input_ids=X2["input_ids"],
            attention_mask=X2["attention_mask"],
            labels=labels2) 

    loss=out.loss
    loss.backward()
    optimizer.step()

    gen2 = good.generate(X2["input_ids"],
                        attention_mask=X2["attention_mask"],
                        #num_beams=4,
                        #max_new_tokens=24,
                        no_repeat_ngram_size=3,
                        #early_stopping=True,
                        decoder_start_token_id=tok_tgt.cls_token_id,
                        # eos_token_id=tok_tgt.sep_token_id,
                        # pad_token_id=tok_tgt.pad_token_id
                        )

    print('run', str(i))
    print('loss', loss.detach().numpy())
    #print('Generated token:',gen2)
    print(" Generated words:", [tok_tgt.decode(g, skip_special_tokens=True) for g in gen2])
    print('\n')









Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.e

run 0
loss 11.266043
 Generated words: ['the i school school school a school school me school school up school school kindergarten school school two school school', 'i i i ce ce ce ke ce ce se ce ce ko ce ce ka ce cece ce']


run 1
loss 9.112206
 Generated words: ['. it it it him him him me me me him me him him sen sen sen me me sen', 'the. school school school sam sam sam school school so so so sam sam so sam school sam school']


run 2
loss 9.4951725
 Generated words: ['is is is can can can is is while while while can can while while like like like while while', 'a is is is be be be is is so so so is is it so so as as as']


run 3
loss 8.448451
 Generated words: ['##sss is is isssnessnessnessss sirss sen sen sens', 'know know know is is is know know be be be is is be be so so so is is']


run 4
loss 7.659356
 Generated words: ['##sss type type typessnessnessnessss sulnessness sulnesssness', "##sss issstssshsss 'ss sirssess"]


