In [1]:
# Auto-reload imported modules before each cell executes, so edits to the local
# project packages take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch
import pickle

# Pin this process to one specific GPU. The call is skipped when CUDA is not
# available so that the CPU fallback computed for `device` at the end of this
# cell can actually take effect instead of the cell failing here.
GPU_INDEX = 2
if torch.cuda.is_available():
    torch.cuda.set_device(GPU_INDEX)
    # Return value is intentionally discarded; the call only queries the active device.
    torch.cuda.current_device()

import warnings
from pathlib import Path
from wrappers.transformer_wrapper import FairseqTransformerHub
from alignment.aer import aer
import itertools

import alignment.align as align

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

import logging

# Keep the notebook output quiet: Python warnings are suppressed and the root
# logger only lets WARNING and above through.
warnings.simplefilter('ignore')
logger = logging.getLogger()
logger.setLevel('WARNING')

from dotenv import load_dotenv

# Load variables from a local .env file; the configuration cells below read
# their data/checkpoint directories from os.environ.
load_dotenv()

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Load model

### Bilingual

In [3]:
# Bilingual (Europarl) configuration.
# NOTE: the Multilingual configuration cell further down assigns the same names
# (ckpt_dir, NUM_LAYERS, model_name_save, store_filename, test_set_dir, src, tgt,
# tokenizer, gold_alignment, pre_layer_norm, data_sample); whichever of the two
# cells ran last determines what the "Compute AER" cells use.

# Paths (both directories are read from environment variables and must be set)
europarl_dir = Path(os.environ['EUROPARL_DATA_DIR'])
ckpt_dir = Path(os.environ['EUROPARL_CKPT_DIR'])
# Alternative IWSLT14 locations, currently disabled:
#iwslt14_dir = Path(os.environ['IWSLT14_DATA_DIR'])
#ckpt_dir = Path(os.environ['IWSLT14_CKPT_DIR'])

# Choose model
model_type = 'baseline'
seed = 5498 # values listed by the author: 2253, 2453, 5498, 9240, 9819
# Sub-path "<model_type>/<seed>"; joined onto ckpt_dir when the hub is loaded.
model_name = f"{model_type}/{seed}"

data_sample = 'interactive' # generate/interactive

# Passed to aer(...) in the "Compute AER" section.
NUM_LAYERS = 6

# Get sample from Gold alignment dataset
# test_src_bpe = europarl_dir / "processed_data/test.bpe.de"
# test_tgt_bpe = europarl_dir / "processed_data/test.bpe.en"

# test_src_word = europarl_dir / "data_in_progress/test.uc.de"
# test_tgt_word = europarl_dir / "data_in_progress/test.uc.en"

# Gold (reference) alignment file.
gold_alignment = europarl_dir / "gold_alignment/alignment.talp"

# Test-set location, language pair and tokenizer name; all four are passed to aer(...).
test_set_dir = europarl_dir / "processed_data/"
src = "de"
tgt = "en"
tokenizer = "bpe"

# '/' is replaced by '_' so the model name fits in a single directory component
# of store_filename.
model_name_save = model_name.replace('/','_')
store_filename = f'./results/alignments/{model_name_save}/extracted_matrix.pkl'

# Forwarded to aer_obt.extract_contribution_matrix(...).
pre_layer_norm = False

In [4]:
# Load the bilingual model: `ckpt_dir / model_name` is the checkpoint directory,
# and the preprocessed-data directory is handed over as a POSIX path string.
hub = FairseqTransformerHub.from_pretrained(
    ckpt_dir / model_name,
    checkpoint_file="checkpoint_best.pt",
    data_name_or_path=(europarl_dir / "processed_data/fairseq_preprocessed_data").as_posix(), # processed data
)

### Multilingual

In [3]:
# Multilingual (M2M-100) configuration.
# NOTE: this cell assigns the same names as the Bilingual configuration cell
# above (ckpt_dir, NUM_LAYERS, model_name_save, store_filename, test_set_dir,
# src, tgt, tokenizer, gold_alignment, pre_layer_norm, data_sample); whichever
# of the two cells ran last determines what the "Compute AER" cells use.
model_size = 'small' # small (418M) / big (1.2B); selects the checkpoint file below
data_sample = 'interactive' # generate/interactive; 'generate' makes this cell read M2M_DATA_DIR
teacher_forcing = False # teacher forcing/free decoding

# Paths
# Checkpoint path (read from an environment variable and must be set)
ckpt_dir = Path(os.environ['M2M_CKPT_DIR'])

# Passed to aer(...) in the "Compute AER" section.
NUM_LAYERS = 12

model_name_save = f'm2m100_{model_size}'
store_filename = f'./results/alignments/{model_name_save}/extracted_matrix.pkl'

# Test-set location, language pair and tokenizer name; all four are passed to aer(...).
test_set_dir = Path("./data/de-en")
src = "de"
tgt = "en"
tokenizer = "spm"

# Gold (reference) alignment file, located next to the test set.
gold_alignment = test_set_dir / "alignment.talp"

# Path to binarized data: only the 'generate' sampling mode needs it;
# every other mode passes "." to avoid loading.
if data_sample != 'generate':
    data_name_or_path = '.'
else:
    m2m_data_dir = Path(os.environ['M2M_DATA_DIR'])
    data_name_or_path = f'{m2m_data_dir}/data_bin'

# Checkpoint names: 'big' selects the 1.2B file, anything else the 418M file.
checkpoint_file = '1.2B_last_checkpoint.pt' if model_size == 'big' else '418M_last_checkpoint.pt'
pre_layer_norm = True

In [4]:
from wrappers.multilingual_transformer_wrapper import FairseqMultilingualTransformerHub

# Load the multilingual model selected in the configuration cell. The language
# arguments are derived from `src` / `tgt` so the hub always uses the same
# language pair as the AER evaluator built from those variables.
hub = FairseqMultilingualTransformerHub.from_pretrained(
    ckpt_dir,
    checkpoint_file=checkpoint_file,
    data_name_or_path=data_name_or_path,
    source_lang=src,
    target_lang=tgt,
    lang_pairs=f'{src}-{tgt}',
)




## Compute AER

In [5]:
# Modes whose AER is reported at the end of the notebook.
mode_list = [
    'alti',
    'decoder.encoder_attn',
    'alti_enc_cross_attn',
]

# Evaluator built from the active configuration (all arguments are positional).
aer_obt = aer(
    test_set_dir,
    model_name_save,
    mode_list,
    NUM_LAYERS,
    src,
    tgt,
    tokenizer,
)

In [6]:
# Extract the contribution matrices from the loaded hub.
# NOTE(review): 'l1' presumably selects an L1-norm based contribution measure —
# confirm against alignment.aer.
contrib_type = 'l1'
aer_obt.extract_contribution_matrix(
    hub,
    model_name_save,
    contrib_type,
    pre_layer_norm=pre_layer_norm,
)

0
200
400


In [7]:
# Alignment extraction step, run after the contribution matrices were extracted.
# NOTE(review): final_punc_mark=True presumably controls how the sentence-final
# punctuation mark is handled — confirm against alignment.aer.
aer_obt.extract_alignments(final_punc_mark=True)

In [8]:
num_layers = NUM_LAYERS
mode_list = ['alti', 'decoder.encoder_attn', 'alti_enc_cross_attn']

# For each of the two settings, print the list stored under 'aer' for every mode.
for setting in ('AWI', 'AWO'):
    print(f'{setting}:\n')
    results = aer_obt.calculate_aer(setting)
    for mode in mode_list:
        print('Mode:', mode)
        # The trailing blank line is folded into this print via `end`.
        print(results[mode]['aer'], end='\n\n')

AWI:

Mode: alti
[0.41421152501747727, 0.3838010586237891, 0.44407270548287225, 0.5724558074503145, 0.6523519424747828, 0.6881553979826226]

Mode: decoder.encoder_attn
[0.2561669829222012, 0.22106261859582543, 0.5246179966044142, 0.8418555877359433, 0.8218815539798262, 0.8439029261959452]

Mode: alti_enc_cross_attn
[0.41421152501747727, 0.37795865375012483, 0.587885748526915, 0.8095475881354239, 0.7896234894636972, 0.8222810346549485]

AWO:

Mode: alti
[0.7648556876061121, 0.7424847697992609, 0.6704284430240688, 0.5896334764805753, 0.5265654648956357, 0.5130330570258663]

Mode: decoder.encoder_attn
[0.8267751922500749, 0.7947668031558973, 0.573204833716169, 0.484470188754619, 0.3423050034954559, 0.4892140217716968]

Mode: alti_enc_cross_attn
[0.7648556876061121, 0.7307500249675423, 0.5763008089483671, 0.5444422251073604, 0.4864675921302307, 0.5384500149805254]

