In [1]:
from torch.utils.data import DataLoader
from datasets import Dataset
from typing import Dict, List, Any
from datasets import load_dataset
from transformers import AutoTokenizer


def create_dataloader(trans_pair_ds: Dataset, input_column: str, fn_kwargs: Dict[str, Any],
                      batch_size: int) -> DataLoader:
    trans_pair_ds = trans_pair_ds.map(tokenize, batched=True, input_columns=[input_column],
                                      fn_kwargs=fn_kwargs)
    # trans_pair_ds = trans_pair_ds.remove_columns(column_names=['translation', 'original_text'])
    trans_pair_ds = trans_pair_ds.with_format('torch', columns=["input_ids", "labels", "attention_mask"],
                                              output_all_columns=False)

    # ids = [e['input_ids'].view(1, -1) for e in iter(trans_pair_ds)]
    test_loader = DataLoader(trans_pair_ds, batch_size=batch_size, drop_last=True, pin_memory=False)
    return test_loader


def get_wmt_dataset(lang_pair: str, num_of_rows: int = None) -> Dataset:
    wmt14 = "wmt14"
    split = "test"
    lang_config = lang_pair.split("-")
    assert len(lang_config) == 2
    if lang_config[0] == "en":
        lang_config[0], lang_config[1] = lang_config[1], lang_config[0]
    lang_config = "-".join(lang_config)
    if "es" in lang_config:
        wmt14 = "nikodallanoce/wmt14"
        split = "validation"
    split = split if num_of_rows is None else split + f"[:{num_of_rows}]"
    ds = load_dataset(wmt14, lang_config,
                      cache_dir="/data/n.dallanoce/wmt14",
                      split=split,
                      verification_mode='no_checks')
    return ds


def tokenize(examples: List[Dict[str, str]], **kwargs):
    tokenizer = kwargs['tokenizer']
    src_lang: str = kwargs['lang1']
    tgt_lang: str = kwargs['lang2']
    if "task" in kwargs:
        task: str = kwargs['task']
        batch_src: List[str] = [task + e[src_lang] for e in examples]
    else:
        batch_src: List[str] = [e[src_lang] for e in examples]
    batch_tgt: List[str] = [e[tgt_lang] for e in examples]
    # tokenize the batch of sentences
    outputs = tokenizer(batch_src, text_target=batch_tgt, return_special_tokens_mask=False,
                        add_special_tokens=True, truncation=True,
                        max_length=128, padding='max_length',
                        return_attention_mask=True, return_tensors='pt')
    # labels = tokenizer(batch_tgt, truncation=False)
    # batch_tgt = tokenizer.batch_decode(labels['input_ids'], skip_special_tokens=True)

    return {'input_ids': outputs['input_ids'], 'labels': outputs['labels'], 'attention_mask': outputs['attention_mask']}




In [2]:
import os
from CosineSim import CosineSim

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# mBART

In [3]:
from transformers import MBartConfig, MBartForConditionalGeneration
from utilities.models import get_all_mbart_models

tok_mbart = AutoTokenizer.from_pretrained("nikodallanoce/mbart-cc4-vanilla-32k-5", src_lang="en_XX", tgt_lang="fr_XX")

fn_kwargs_mbart = {'tokenizer': tok_mbart, 'lang1': "en", 'lang2': "fr"}
wmt14_ds_mbart = get_wmt_dataset(fn_kwargs_mbart['lang1'] + "-" + fn_kwargs_mbart['lang2'], num_of_rows=512)
dataloader_mbart = create_dataloader(wmt14_ds_mbart, "translation", fn_kwargs_mbart, 32)
mbart_config = MBartConfig(encoder_layers=6, decoder_layers=6,
                           encoder_ffn_dim=2048, decoder_ffn_dim=2048,
                           encoder_attention_heads=8, decoder_attention_heads=8,
                           d_model=512, max_length=128, vocab_size=tok_mbart.vocab_size, dropout=0.1)
rnd_mbart = MBartForConditionalGeneration(mbart_config)

mbart_models = get_all_mbart_models()

Found cached dataset wmt14 (/data/n.dallanoce/wmt14/wmt14/fr-en/1.0.0/2de185b074515e97618524d69f5e27ee7545dcbed4aa9bc1a4235710ffca33f4)
Loading cached processed dataset at /data/n.dallanoce/wmt14/wmt14/fr-en/1.0.0/2de185b074515e97618524d69f5e27ee7545dcbed4aa9bc1a4235710ffca33f4/cache-3bd4b67a2a65c9e0.arrow


In [4]:
cs = CosineSim(mbart_models["M1"], mbart_models["M2_replay"])
sim = cs.compute_enc_hidd_states(dataloader_mbart)
print(round(sim, 4))

100%|██████████| 16/16 [00:08<00:00,  1.87it/s]

0.9104





In [6]:
from tqdm import tqdm

langs = ["en", "fr", "de", "es"]
model_lst = [(k, v) for k, v in mbart_models.items()]

for lang in tqdm(langs):
    fn_kwargs_mbart = {'tokenizer': tok_mbart, 'lang1': lang, 'lang2': "fr"}
    if lang != "en":
        fn_kwargs_mbart['lang2'] = "en"
    wmt14_ds_mbart = get_wmt_dataset(fn_kwargs_mbart['lang1'] + "-" + fn_kwargs_mbart['lang2'], num_of_rows=512)
    dataloader_mbart = create_dataloader(wmt14_ds_mbart, "translation", fn_kwargs_mbart, 128)
    for i in range(len(model_lst) - 1):
        mi_name, model_i = model_lst[i]
        for j in range(i + 1, len(model_lst)):
            mj_name, model_j = model_lst[j]
            cs = CosineSim(model_i, model_j)
            sim = cs.compute_enc_hidd_states(dataloader_mbart, show_tqdm=False)
            print(f"Similarity between {mi_name} and {mj_name} is {round(sim, 4)}")

  0%|          | 0/4 [00:00<?, ?it/s]Found cached dataset wmt14 (/data/n.dallanoce/wmt14/wmt14/fr-en/1.0.0/2de185b074515e97618524d69f5e27ee7545dcbed4aa9bc1a4235710ffca33f4)
Loading cached processed dataset at /data/n.dallanoce/wmt14/wmt14/fr-en/1.0.0/2de185b074515e97618524d69f5e27ee7545dcbed4aa9bc1a4235710ffca33f4/cache-3bd4b67a2a65c9e0.arrow


Similarity between M1 and M2 is 0.587
Similarity between M1 and M2_de_only is 0.0131
Similarity between M1 and M2_replay is 0.9104
Similarity between M1 and M3 is 0.4945
Similarity between M1 and M3_replay is 0.8724
Similarity between M1 and M1F1 is 0.6012
Similarity between M1 and M2F1 is 0.5645
Similarity between M1 and M2F1_replay is 0.5939
Similarity between M1 and M2F2 is 0.489
Similarity between M1 and MF2_ft_only is 0.0015


  0%|          | 0/4 [00:50<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 1.96 GiB (GPU 0; 15.75 GiB total capacity; 12.12 GiB already allocated; 1.90 GiB free; 12.75 GiB reserved in total by PyTorch)

In [None]:
sim = cs.compute_logits(dataloader_mbart)
print(round(sim, 4))

# mT6

In [9]:
from transformers import MT5ForConditionalGeneration, MT5Config
from utilities.models import get_all_mt6_models

tok_mt6 = AutoTokenizer.from_pretrained("nikodallanoce/mt5-cc4-vanilla-32k-5")
rnd_mt6 = MT5ForConditionalGeneration(
    MT5Config(num_layers=6, d_model=512, num_heads=8, d_ff=2048, vocab_size=len(tok_mt6), max_length=128,
              tie_word_embeddings=True))
fn_kwargs_mt6 = {'tokenizer': tok_mt6, 'lang1': "en", 'lang2': "es"}
wmt14_ds_mt6 = get_wmt_dataset(fn_kwargs_mt6['lang1'] + "-" + fn_kwargs_mt6['lang2'], num_of_rows=512)
dataloader_mt6 = create_dataloader(wmt14_ds_mt6, "translation", fn_kwargs_mt6, 32)
mt6_models = get_all_mt6_models()

Found cached dataset wmt14 (/data/n.dallanoce/wmt14/nikodallanoce___wmt14/es-en/1.0.0/87db7d5f83bc44f038b67325c372011ddb3cb63ec2bb219b5736426178356f0a)


Map:   0%|          | 0/512 [00:00<?, ? examples/s]

In [11]:
cs = CosineSim(mbart_models["M1"], mbart_models["M2_replay"])
sim = cs.compute_enc_hidd_states(dataloader_mt6)
print(round(sim, 4))

100%|██████████| 16/16 [00:06<00:00,  2.32it/s]

0.7614



