In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch
import pickle

#torch.cuda.set_device(3)
#torch.cuda.current_device()

import warnings
from pathlib import Path
from wrappers.transformer_wrapper import FairseqTransformerHub
from wrappers.multilingual_transformer_wrapper import FairseqMultilingualTransformerHub
from alignment.aer import aer
import itertools

import alignment.align as align
from types import SimpleNamespace

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

import logging
logger = logging.getLogger()

logger.setLevel('WARNING')
warnings.simplefilter('ignore')

from dotenv import load_dotenv
load_dotenv()
device = "cuda" if torch.cuda.is_available() else "cpu"

2023-01-21 07:23:09 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


## Load model

In [3]:
model = 'multilingual'# bilingual/multilingual
model_size = 'small'

In [4]:
if model == 'bilingual':
    # Bilingual paths
    europarl_dir = Path(os.environ['EUROPARL_DATA_DIR'])
    ckpt_dir = Path(os.environ['EUROPARL_CKPT_DIR'])
    # Choose model
    model_type = 'baseline'
    seed = 9819 # 2253  2453  5498  9240  9819
    model_name = f"{model_type}/{seed}"

    args = SimpleNamespace(
        src = "de",
        tgt = "en",
        tokenizer = "bpe",
        test_set_dir = Path(os.environ['EUROPARL_DATA_DIR']) / "processed_data/",
        model_name_save = model_name.replace('/','_'),
        pre_layer_norm = False,
        num_layers = 6
        )

elif model == 'multilingual':
    # Multilingual paths
    ckpt_dir = Path(os.environ['M2M_CKPT_DIR'])
    europarl_dir = Path("./data/de-en")
    model_size = 'big' # small (412M) /big (1.2B)

    args = SimpleNamespace(
        src = "de",
        tgt = "en",
        tokenizer = "spm",
        test_set_dir = Path("./data/de-en").as_posix(),
        model_name_save = f'm2m100_{model_size}',
        pre_layer_norm = True,
        num_layers = 12
        )

In [5]:
#ckpt_dir = Path(os.environ['IWSLT14_CKPT_DIR'])

lang_flores_dict = {'en': 'eng', 'es': 'spa', 'zu': 'zul',
                    'de': 'deu', 'yo': 'yor', 'ms': 'msa',
                    'fr': 'fra', 'xh': 'xho'}
source_lang = 'de'
target_lang = 'en'
if model == 'bilingual':
    hub = FairseqTransformerHub.from_pretrained(
        ckpt_dir / f"{model_type}/{seed}",
        checkpoint_file=f"checkpoint_best.pt",
        data_name_or_path=(europarl_dir / "processed_data/fairseq_preprocessed_data").as_posix(), # processed data
    )

elif model == 'multilingual':
    # Checkpoint names
    if model_size=='big':
        checkpoint_file = '1.2B_last_checkpoint.pt'
    else:
        checkpoint_file = '418M_last_checkpoint.pt'
    data_name_or_path='.'
    hub = FairseqMultilingualTransformerHub.from_pretrained(
        ckpt_dir,
        checkpoint_file=checkpoint_file,
        data_name_or_path=data_name_or_path,
        source_lang= args.src,
        target_lang= args.tgt,
        lang_pairs =f'{source_lang}-{target_lang}',
        fixed_dictionary=f'{ckpt_dir}/model_dict.128k.txt')




## Compute AER

In [6]:
mode_list = ['alti', 'decoder.encoder_attn', 'alti_enc_cross_attn','attn_w', 'cross_attn_contributions_proj', 'cross_attn_contrib_proj_alti', 'vector_norms_cross']
aer_obt = aer(args, mode_list)

In [7]:
contrib_type = 'l1'
aer_obt.extract_contribution_matrix(hub, contrib_type)

0


RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 11.93 GiB total capacity; 10.57 GiB already allocated; 18.12 MiB free; 11.07 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
aer_obt.extract_alignments(final_punc_mark=False, ignore_eos=True)

dict_keys(['alti', 'decoder.encoder_attn', 'alti_enc_cross_attn', 'attn_w', 'cross_attn_contributions_proj', 'cross_attn_contrib_proj_alti'])
alti
decoder.encoder_attn
alti_enc_cross_attn
attn_w
cross_attn_contributions_proj
cross_attn_contrib_proj_alti
alti
decoder.encoder_attn
alti_enc_cross_attn
attn_w
cross_attn_contributions_proj
cross_attn_contrib_proj_alti


In [None]:
mode_list = ['alti', 'decoder.encoder_attn', 'alti_enc_cross_attn', 'attn_w','cross_attn_contributions_proj', 'cross_attn_contrib_proj_alti', 'vector_norms_cross']

for setting in ['AWI', 'AWO']:
    print(f'{setting}:\n')
    results = aer_obt.calculate_aer(setting)
    for mode in mode_list:
        print('Mode:', mode)
        print(results[mode]['aer'])
        print()

AWI:

Mode: alti
[0.5973234794766803, 0.4379306901028662, 0.4624488165384999, 0.5834415260161789, 0.6851093578348147, 0.700838909417757]

Mode: decoder.encoder_attn
[0.47103765105363027, 0.2622590632178169, 0.3613802057325477, 0.8368121442125237, 0.830720063916908, 0.8438030560271647]

Mode: alti_enc_cross_attn
[0.5973234794766803, 0.41980425446919, 0.4998002596624388, 0.8016578448017577, 0.8072505742534705, 0.8353140916808149]

Mode: attn_w
[0.4966543493458504, 0.2459802257065814, 0.40002996105063415, 0.8208828522920204, 0.8297712973134925, 0.8507939678418057]

Mode: cross_attn_contributions_proj
[0.6370218715669629, 0.4268950364526116, 0.6398681713772096, 0.8267252571656847, 0.889243982822331, 0.8449515629681414]

Mode: cross_attn_contrib_proj_alti
[0.7283032058324179, 0.5296115050434436, 0.6806651353240787, 0.8244781783681214, 0.843353640267652, 0.8578348147408369]

AWO:

Mode: alti
[0.7944172575651653, 0.7859782283032059, 0.7295016478577849, 0.6367721961450115, 0.5687106761210426, 