In [1]:
import torch
from transformers import (
    BertTokenizerFast, BertConfig, BertLMHeadModel, BertModel,
    AutoModel, EncoderDecoderModel, AutoTokenizer,
)
from torch.optim import AdamW

# pip install -U transformers datasets
import random, math
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW

from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModel, BertConfig, BertLMHeadModel, EncoderDecoderModel
)


####
SRC_CKPT = "bert-base-uncased"              # encoder (EN)
TGT_CKPT = "bert-base-multilingual-cased"   # decoder (FR-capable)


#########the data
ds = load_dataset("Helsinki-NLP/opus_books", "en-fr", split="train")  # ~1M pairs

pairs = [(ex["translation"]["en"], ex["translation"]["fr"]) for ex in ds.select(range(2000))]
random.shuffle(pairs)
pairs = pairs[:100]  # exactly 100

src_list, tgt_list = zip(*pairs)
#print(src_list)


# ---- tokenizers
torch.manual_seed(0)
enc =  "bert-base-uncased"
dec= "bert-base-multilingual-cased"   # decoder (FR-capable)

tok_src = BertTokenizerFast.from_pretrained(enc)
tok_tgt = BertTokenizerFast.from_pretrained(dec)
PAD_ID = tok_tgt.pad_token_id
EOS_ID = tok_tgt.sep_token_id
BOS_ID = tok_tgt.cls_token_id

MAX_SRC_LEN=100

def tokenize_src(text):
    x=tok_src(text,
              padding='max_length', 
              truncation=True, 
              max_length=MAX_SRC_LEN, 
              return_tensors="pt"
              )
    return x    

def tokenize_tgt(text):
    x=tok_tgt(text,
              padding='max_length',
              truncation=True, 
              max_length=MAX_SRC_LEN, 
              return_tensors="pt"
              )     
    return x


X=tokenize_src(src_list)
Y=tokenize_tgt(tgt_list)


print(src_list[:1])

BATCH_SIZE=8

train_dl = DataLoader(TensorDataset( X['input_ids'],X['attention_mask'],Y['input_ids']),
                                     batch_size=BATCH_SIZE,
                                     shuffle=True)


# X2 = tok_src(["cats are cute", "i like tea"],
#               return_tensors="pt",
#               padding=True,
#               truncation=True)


# Y2 = tok_tgt(["les chats sont mignons", "j'aime le thé"], 
#              return_tensors="pt",
#              padding=True,
#              truncation=True,
#              add_special_tokens=False)  # no [CLS]


# labels2 = Y2.input_ids#.clone() 
# labels2[labels2 == tok_tgt.pad_token_id] = -100



######### model!



# ---------- CORRECT: set decoder_start_token_id ON CONFIG before forward
dec_cfg_ok = BertConfig.from_pretrained(dec,
                                         is_decoder=True,
                                         add_cross_attention=True)


good = EncoderDecoderModel(encoder=AutoModel.from_pretrained(enc),
                           decoder=BertLMHeadModel.from_pretrained(dec, config=dec_cfg_ok),
                           )

# Required for loss computation (right-shift uses this)
good.config.decoder_start_token_id = tok_tgt.cls_token_id
good.config.eos_token_id = tok_tgt.sep_token_id
good.config.pad_token_id = tok_tgt.pad_token_id
good.config.vocab_size = good.config.decoder.vocab_size
good.config.tie_encoder_decoder = True

#add optimizer
optimizer = AdamW(good.parameters(), lr=1e-5)


model=good
# train then test again
model.train()
opt = AdamW(good.parameters(), lr=1e-5)

steps = 0
EPOCHS=10

for epoch in range(EPOCHS):
    Loss=0
    for ids,mask, labels in train_dl:
        opt.zero_grad()
        out = model(input_ids=ids, attention_mask=mask, labels=labels)
        out.loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        steps += 1
        Loss+=out.loss
    print(f"epoch {epoch+1}/{EPOCHS} done")
    print(f'Loss: {Loss/len(train_dl)}')


    
    X=tokenize_src(src_list[:1])
    gen=model.generate( X['input_ids'],attention_mask= X['attention_mask'],
                        num_beams=4, max_new_tokens=64, early_stopping=True,
                        decoder_start_token_id=BOS_ID, eos_token_id=EOS_ID, pad_token_id=PAD_ID,
                        bad_words_ids=[[PAD_ID]],          # block PAD
                        repetition_penalty=1.1,            # mild
                        no_repeat_ngram_size=3             # optional hygiene 
                        )   
#     print('run', str(i))
#     print('loss', loss.detach().numpy())
#     #print('Generated token:',gen2)
    print(" Generated words:", [tok_tgt.decode(g, skip_special_tokens=True) for g in gen])
#     print('\n')    
#model.eval()


# good.train()
# for i in range(0,10):
#     out = good(input_ids=X2["input_ids"],
#             attention_mask=X2["attention_mask"],
#             labels=labels2) 

#     loss=out.loss
#     loss.backward()
#     optimizer.step()

#     gen2 = good.generate(X2["input_ids"],
#                         attention_mask=X2["attention_mask"],
#                         num_beams=4,
#                         #max_new_tokens=24,
#                         no_repeat_ngram_size=3,
#                         #early_stopping=True,
#                         decoder_start_token_id=tok_tgt.cls_token_id,
#                         # eos_token_id=tok_tgt.sep_token_id,
#                         # pad_token_id=tok_tgt.pad_token_id
#                         )

#     print('run', str(i))
#     print('loss', loss.detach().numpy())
#     #print('Generated token:',gen2)
#     print(" Generated words:", [tok_tgt.decode(g, skip_special_tokens=True) for g in gen2])
#     print('\n')









('He was holding in his hand a little wheel of blackened wood; a string of partly burnt squibs was twisted round it; evidently a Catherine wheel from the fireworks display on the fourteenth of July.', 'Young Roy, half choking with laughter, pushes us allfrombehind to hurry us out.', 'Nothing moves yet in that clear wintry lansdcape.', 'I am caught in failure; here is my chance to quicken their curiosity: I decide to explain who this gipsy was, where he came from; his strange fate . . . Boujardon and Delouche do not care to listen.', 'The table has not been laid; we all eat off our knees, each settling where best he can in the dark classroom.', "They even confided to us, while M. Seurel was starting off again at the head of our party : 'There was another chap as went by. That tall fellow, you know. . .", 'This, as we knew, would be the only sight the whole day, which would pass like muddy water along the gutter.', 'About midnight I woke up suddenly.', "And we shall have to find the rest

Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bia

epoch 1/10 done
Loss: 12.095444679260254
 Generated words: ['- - - 提 携 携 携 攜 携 攜 攜 攜 携 便 携 携 帶 带 帶 带 带 带 帶 帶 帶 攜 携 带 带 攜 攜 带 帶 攜 帶 帶 带 攜 帶 带 帯 带 带 携 携 带 攜 带 带 帶 帶 携 帶 帶 帶 带 带 帶 带 带']
epoch 2/10 done
Loss: 6.216301918029785
 Generated words: ['visaugi mature mature matureessen mature mature adulte mature matureicata mature mature ave mature mature Astragalus mature mature صاحب mature mature strict mature mature పౌర mature mature gaya mature mature깃 mature mature style mature mature 펙 mature maturesite mature maturegant mature mature vite mature maturesica mature maturerain mature maturepara mature mature conforme mature mature삿 mature']
epoch 3/10 done
Loss: 3.8064959049224854
 Generated words: [",..,, à à à et à à, à de à à une à, d à à par à à tout, à d d à au à - à à se à à termine tout à à pour à à à à chi à à d, à ' à à au au à à les à"]
epoch 4/10 done
Loss: 2.886213779449463
 Generated words: [',,...,, à à,., de,,,,, la,, d.,., à une, à.,. à à. à, à à et., une., un, à,, pour,,s à