In [1]:
import torch
from transformers import MBartTokenizer, MBartConfig

In [2]:
def model_size(model):
    """Return the combined size of a model's parameters and buffers in MiB.

    Each tensor yielded by ``model.parameters()`` and ``model.buffers()``
    contributes ``nelement() * element_size()`` bytes; the byte total is
    divided by ``1024 ** 2`` before being returned.
    """
    def _total_bytes(tensors):
        # Bytes occupied by one group of tensors (0 for an empty group).
        return sum(t.nelement() * t.element_size() for t in tensors)

    n_bytes = _total_bytes(model.parameters()) + _total_bytes(model.buffers())
    return n_bytes / 1024 ** 2

In [3]:
# Turn on cuDNN's benchmark mode globally; this runs before the model is built or used.
torch.backends.cudnn.benchmark = True

In [4]:
# MBart is imported from the `MBart` module (presumably project-local; not shown here).
from MBart import MBart

# Sanity check: report whether a CUDA device is visible to PyTorch.
print(torch.cuda.is_available())

# Load the tokenizer published with facebook/mbart-large-cc25, with English (en_XX)
# as the source language.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25", src_lang="en_XX")
# Model configuration: 6 encoder / 6 decoder layers, 8 attention heads each,
# d_model and both feed-forward widths set to 512, vocabulary size taken from the tokenizer.
# NOTE(review): max_length=128 matches the "_128" suffix of the dataset directory loaded
# later in this notebook — presumably the tokenized sequence length; confirm.
mbart_config = MBartConfig(encoder_layers=6, decoder_layers=6,
                           encoder_ffn_dim=512, decoder_ffn_dim=512,
                           encoder_attention_heads=8, decoder_attention_heads=8,
                           d_model=512, max_length=128, vocab_size=tokenizer.vocab_size)

# Build the wrapper from the config and print its parameter + buffer size in MiB.
model: MBart = MBart(mbart_config)
print(model_size(model))

True
589.6158866882324


In [5]:
from datasets import load_from_disk

# Load the dataset previously saved to this directory (path is relative to the working directory).
dataset_loaded = load_from_disk("europarl_eng_tokenized_and_masked_128")
# Return the three listed columns as PyTorch tensors ('pt' format).
dataset_loaded.set_format(type='pt', columns=['input_ids', 'attention_mask', 'masked_ids'])

In [6]:
# CustomDataset is imported from the `CustomDataset` module (presumably project-local; not shown here).
from CustomDataset import CustomDataset
from torch.utils.data import DataLoader

# Debugging aid: uncomment to continue with only a small slice of the data.
#dataset_loaded = dataset_loaded[1:500]

# Pull each column out of the loaded dataset by name.
input_ids = dataset_loaded['input_ids']
attention_mask = dataset_loaded['attention_mask']
masked_ids = dataset_loaded['masked_ids']

# Shuffled batches of 4 with the final incomplete batch dropped; 4 worker processes,
# and pinned memory requested for the 'cuda' device.
# NOTE(review): the argument order is (masked_ids, input_ids, attention_mask) — presumably
# (model input, target, mask); confirm against CustomDataset's constructor.
ds_en_loader = DataLoader(CustomDataset(masked_ids, input_ids, attention_mask),
                          batch_size=4, drop_last=True, shuffle=True,
                          pin_memory=True, pin_memory_device='cuda', num_workers=4)

In [7]:
from torch.optim import Adam

# Train through the wrapper's own `fit` method, requesting 5 epochs.
# Adam is constructed with only the model parameters, so all its hyperparameters
# (including the learning rate) are the library defaults.
# The recorded output of this cell shows the run being interrupted by hand during epoch 0.
model.fit(ds_en_loader, Adam(model.parameters()), epochs=5)

Epoch 0


  0%|          | 115/427452 [00:34<35:23:33,  3.35it/s, loss=2.9669]


KeyboardInterrupt: 

In [7]:
# Probe the model: list its 5 most probable tokens for the <mask> position of one sentence.
sentence = "<mask> is in France"
test_ids = tokenizer([sentence], add_special_tokens=True, return_tensors="pt")["input_ids"]

# Inference only: run the forward pass in eval mode and without autograd so the
# prediction is not perturbed by train-time behaviour such as dropout and no graph is kept.
# The previous train/eval mode is restored afterwards so training can be resumed unchanged.
# Assumes model.model is a torch.nn.Module (it exposes `generate` and returns `.logits`).
was_training = model.model.training
model.model.eval()
try:
    with torch.no_grad():
        logits = model.model(test_ids.to('cuda')).logits
finally:
    model.model.train(was_training)

# `.item()` requires exactly one <mask> token in the sentence; it raises otherwise.
masked_index = (test_ids[0] == tokenizer.mask_token_id).nonzero().item()
# Distribution over the vocabulary at the masked position of the only sequence in the batch.
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)
# Decode each candidate id on its own so the list always has one entry per prediction,
# aligned with `values`; decoding all ids as one string and splitting on whitespace
# can merge or split candidates.
[tokenizer.decode([token_id]) for token_id in predictions.tolist()]

['The', 'This', 'I', 'We', 'Mr']

In [8]:
# Generate from the same masked input using beam search with 2 beams, starting the decoder
# from the en_XX language-code token, then decode every returned sequence.
# Special tokens are not stripped: the recorded output shows 'en_XX' and '</s>'.
outputs = model.model.generate(test_ids.to('cuda'), decoder_start_token_id=tokenizer.lang_code_to_id['en_XX'],
                               num_beams=2)
print(tokenizer.batch_decode(outputs))



['en_XX The Commission, is the the the the the the the the the the the the the the the the the the the the the the the the the have the the the the the Commission, the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the</s>']
