In [1]:
from datasets import load_dataset
from transformers import MT5Tokenizer, MT5ForConditionalGeneration, MT5Config
import torch
from transformers.utils import is_torch_fx_proxy

from noise_functions.MT6NoiseFunction import MT6NoiseFunction

In [3]:
def model_size(model):
    """Return the in-memory footprint of ``model`` in mebibytes (MiB).

    Adds up ``nelement() * element_size()`` for every parameter and every
    registered buffer of the module, then divides the byte count by 1024**2.
    (The historical variable name said "mb", but the divisor makes it MiB.)

    Args:
        model: a ``torch.nn.Module`` (anything exposing ``parameters()`` and
            ``buffers()`` iterables of tensors).

    Returns:
        float: size of parameters plus buffers, in MiB.
    """
    def _nbytes(tensors):
        # Raw storage size of each tensor: element count times bytes per element.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _nbytes(model.parameters()) + _nbytes(model.buffers())
    return total_bytes / 1024 ** 2

In [89]:
# Local checkpoint of the MT6-pretrained mT5 model under inspection.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
CHECKPOINT_PATH = "/home/n.dallanoce/PyCharm/pretraining/weights/mt6/checkpoint-4800"

tok_en = MT5Tokenizer.from_pretrained("google/mt5-base")

# Inference device. The name says "cuda" but it is currently pinned to the CPU.
cuda_dev = "cpu"

# Load the checkpoint, place it on `cuda_dev` and switch to inference mode
# (`Module.eval()` is exactly `Module.train(False)`).
model = MT5ForConditionalGeneration.from_pretrained(CHECKPOINT_PATH).to(cuda_dev).eval()
print(model_size(model))

542.5087890625


In [5]:
# Only the first N_ROWS lines of the English CC100 training split are read,
# which keeps this exploratory sample small.
CC100_CACHE_DIR = "/data/n.dallanoce/cc100/hugg_en"
N_ROWS = 1024
dataset = load_dataset(
    "cc100",
    lang="en",
    cache_dir=CC100_CACHE_DIR,
    split=f"train[0:{N_ROWS}]",
    # NOTE(review): `ignore_verifications` is deprecated in recent `datasets`
    # releases in favour of `verification_mode="no_checks"`. It is kept here
    # for the library version this notebook was run with; confirm before upgrading.
    ignore_verifications=True,
)

Using custom data configuration en-lang=en
Found cached dataset cc100 (/data/n.dallanoce/cc100/hugg_en/cc100/en-lang=en/0.0.0/8159941b93eb06d0288bb80be26ddfe8213c0c5e33286619c85ad8e1ee0eb91c)


In [95]:
# Fraction of the sentence handed to the MT6 noise function as `noise_density`.
NOISE_DENSITY = 0.35

# Pick one CC100 line and corrupt it. The row index is reused as the noise
# seed so that re-running the cell presumably yields the same masked example.
index = 111
sent = dataset[index]['text']
noise_fn = MT6NoiseFunction()
src_sent, tgt_sent = noise_fn.compute_for_mt5(sent, seed=index, noise_density=NOISE_DENSITY)
print(f"original: {sent} \n \nsource: {src_sent}")

original: Read about what EEI’s International Programs has accomplished in the first-half of 2018
 
 
source: <extra_id_0> what EEI’s International <extra_id_1> has accomplished <extra_id_2> of <extra_id_3>


In [96]:
# Tokenize the noised source and move it to whichever device the model lives on.
# Without the move, generation fails with a device-mismatch error as soon as
# `cuda_dev` is switched to a GPU.
encoded = tok_en(src_sent, return_tensors="pt").to(model.device)
input_ids = encoded.input_ids
# Pass the attention mask explicitly rather than letting `generate` infer it.
sequence_ids = model.generate(input_ids, attention_mask=encoded.attention_mask, max_length=64)
# Keep special tokens so the predicted <extra_id_N> sentinels can be compared
# against the label produced by the noise function.
sequences = tok_en.batch_decode(sequence_ids, skip_special_tokens=False)
print(f"prediction: {sequences} \n \nlabel: {tgt_sent}")

prediction: ['<pad> <extra_id_0> vitamins <extra_id_1> a <extra_id_2> life after <extra_id_3> bypass <extra_id_4> recommend:</s>'] 
 
label: Read about <extra_id_0> Programs <extra_id_1> in the first-half <extra_id_2> 2018

