In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports all modules
# before executing each cell so edits to imported packages are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch
import pickle

# Index of the GPU this notebook is pinned to.
GPU_INDEX = 5

# Only pin a GPU when CUDA is actually present. The `device` variable defined
# at the end of this cell falls back to "cpu", but an unconditional
# set_device() call raises on a machine without CUDA before that fallback is
# ever reached.
if torch.cuda.is_available():
    torch.cuda.set_device(GPU_INDEX)
    torch.cuda.current_device()

import warnings
from pathlib import Path
from wrappers.transformer_wrapper import FairseqTransformerHub
from wrappers.multilingual_transformer_wrapper import FairseqMultilingualTransformerHub
from alignment.aer import aer
import itertools

import alignment.align as align
from types import SimpleNamespace

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

import logging
from dotenv import load_dotenv

# Root logger: only WARNING and above get through.
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

# Suppress every Python warning for the rest of the session.
warnings.simplefilter('ignore')

# Load variables from a .env file into the process environment
# (presumably the *_DIR variables read via os.environ further down -- confirm).
load_dotenv()

# Device string: "cuda" when a CUDA device is available, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Load model

In [3]:
# Which model family to analyse: "bilingual" or "multilingual".
model = "bilingual"
# Size tag of the model; the multilingual branches below use it to name the
# saved outputs and to pick the checkpoint file ("small" or "big").
model_size = "small"

In [4]:
# Build the data/checkpoint paths and the `args` namespace for the selected
# model. Both branches define `europarl_dir`, `ckpt_dir` and `args`.
if model == 'bilingual':
    # Bilingual paths (both variables must be present in the environment)
    europarl_dir = Path(os.environ['EUROPARL_DATA_DIR'])
    ckpt_dir = Path(os.environ['EUROPARL_CKPT_DIR'])
    # Choose model
    model_type = 'baseline'
    seed = 9819 # 2253  2453  5498  924  9819
    model_name = f"{model_type}/{seed}"

    args = SimpleNamespace(
        src = "de",
        tgt = "en",
        tokenizer = "bpe",
        # Reuse europarl_dir instead of reading the environment a second time.
        test_set_dir = europarl_dir / "processed_data/",
        # '/' is replaced so the name can be used as a flat identifier.
        model_name_save = model_name.replace('/','_'),
        pre_layer_norm = False,
        num_layers = 6
        )

elif model == 'multilingual':
    # Multilingual paths
    ckpt_dir = Path(os.environ['M2M_CKPT_DIR'])
    europarl_dir = Path("./data/de-en")
    # `model_size` comes from the configuration cell above and is deliberately
    # NOT re-assigned here: overriding it in this branch made it impossible to
    # select the 'big' model. 'small' maps to the 418M checkpoint and 'big' to
    # the 1.2B checkpoint when the hub is loaded.

    args = SimpleNamespace(
        src = "de",
        tgt = "en",
        tokenizer = "spm",
        # Kept as a POSIX string here (the bilingual branch passes a Path).
        test_set_dir = europarl_dir.as_posix(),
        model_name_save = f'm2m100_{model_size}',
        pre_layer_norm = True,
        # NOTE(review): fixed at 12 regardless of model_size -- confirm this
        # is also correct for the 'big' checkpoint.
        num_layers = 12
        )

else:
    # Fail here rather than with a NameError on `args` in a later cell.
    raise ValueError(
        f"Unknown model {model!r}; expected 'bilingual' or 'multilingual'"
    )

In [5]:
# Alternative checkpoint root, kept for reference:
#ckpt_dir = Path(os.environ['IWSLT14_CKPT_DIR'])

# Two-letter language code -> three-letter code.
lang_flores_dict = {'en': 'eng', 'es': 'spa', 'zu': 'zul',
                    'de': 'deu', 'yo': 'yor', 'ms': 'msa',
                    'fr': 'fra', 'xh': 'xho'}

# Take the language pair from `args` so there is a single source of truth;
# previously it was hard-coded here a second time and mixed with args.src /
# args.tgt inside the same from_pretrained call.
source_lang = args.src
target_lang = args.tgt

if model == 'bilingual':
    hub = FairseqTransformerHub.from_pretrained(
        ckpt_dir / f"{model_type}/{seed}",
        checkpoint_file="checkpoint_best.pt",
        data_name_or_path=(europarl_dir / "processed_data/fairseq_preprocessed_data").as_posix(), # processed data
    )

elif model == 'multilingual':
    # Checkpoint names, keyed by model_size. An unrecognised size is rejected
    # instead of silently loading the 418M checkpoint.
    checkpoint_files = {
        'small': '418M_last_checkpoint.pt',
        'big': '1.2B_last_checkpoint.pt',
    }
    if model_size not in checkpoint_files:
        raise ValueError(
            f"Unknown model_size {model_size!r}; expected 'small' or 'big'"
        )
    checkpoint_file = checkpoint_files[model_size]
    data_name_or_path = '.'
    hub = FairseqMultilingualTransformerHub.from_pretrained(
        ckpt_dir,
        checkpoint_file=checkpoint_file,
        data_name_or_path=data_name_or_path,
        source_lang=source_lang,
        target_lang=target_lang,
        lang_pairs=f'{source_lang}-{target_lang}',
        fixed_dictionary=f'{ckpt_dir}/model_dict.128k.txt')

else:
    # Fail here rather than with a NameError on `hub` in a later cell.
    raise ValueError(
        f"Unknown model {model!r}; expected 'bilingual' or 'multilingual'"
    )


## Compute AER

In [6]:
# Modes handed to the AER helper; the same names are used later as keys into
# the per-setting results.
mode_list = [
    'alti',
    'decoder.encoder_attn',
    'alti_enc_cross_attn',
    'attn_w',
]

# AER helper built from the run configuration and the list of modes.
aer_obt = aer(args, mode_list)

In [14]:
# Contribution type passed to the extraction step.
contrib_type = 'l1'

# Run the extraction with the loaded hub; the numbers printed below appear to
# be a progress counter (TODO confirm).
aer_obt.extract_contribution_matrix(
    hub,
    args.model_name_save,
    contrib_type,
    pre_layer_norm=args.pre_layer_norm,
)

0
200
400


In [7]:
# Extract alignments for every mode, with the final-punctuation-mark option
# switched off. Presumably this consumes the matrices produced by
# extract_contribution_matrix -- confirm against alignment.aer. The output below
# lists each mode twice.
aer_obt.extract_alignments(final_punc_mark=False)

alti
decoder.encoder_attn
alti_enc_cross_attn
attn_w
alti
decoder.encoder_attn
alti_enc_cross_attn
attn_w


In [8]:
# Modes to report, in display order.
mode_list = [
    'alti',
    'decoder.encoder_attn',
    'alti_enc_cross_attn',
    'attn_w',
]

# For each evaluation setting, print the list of AER values stored under every
# mode (one list per mode, followed by a blank line).
for setting in ('AWI', 'AWO'):
    print(f'{setting}:\n')
    results = aer_obt.calculate_aer(setting)
    for mode in mode_list:
        print('Mode:', mode)
        print(results[mode]['aer'], end='\n\n')

AWI:

Mode: alti
[0.5973234794766803, 0.4379306901028662, 0.4624488165384999, 0.5834415260161789, 0.6851093578348147, 0.700838909417757]

Mode: decoder.encoder_attn
[0.47103765105363027, 0.2602533927587761, 0.3611249875112399, 0.8550934048921457, 0.830720063916908, 0.8438843884388438]

Mode: alti_enc_cross_attn
[0.5973234794766803, 0.41980425446919, 0.4998002596624388, 0.8016479400749064, 0.8072505742534705, 0.8353140916808149]

Mode: attn_w
[0.6170462894930198, 0.2819765747916788, 0.46417410080118193, 0.9033638068448195, 0.9321608040201005, 0.9402746494223643]

AWO:

Mode: alti
[0.7944172575651653, 0.7859782283032059, 0.7295016478577849, 0.6367721961450115, 0.5687106761210426, 0.5575751523020074]

Mode: decoder.encoder_attn
[0.8521921502047338, 0.8555378008588834, 0.6951499999999999, 0.47756297006713444, 0.38919404773794064, 0.5158]

Mode: alti_enc_cross_attn
[0.7944172575651653, 0.7869269949066213, 0.6630380505343054, 0.5599720363527414, 0.5099870168780585, 0.5739039248976331]

Mode: