Before running this notebook, you need to compute head influence (the distance of noisy word representations before and after masking a specific head). You can run the following to do so:

`python ./representations/masked_head_distance.py`

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict
import spacy
import pickle
import os
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def get_tag_attention(weights, heads, df, lang, is_monolingual):
    """Group attention weights of one selected head per layer by POS tag.

    For every sentence, the text in ``df.loc[i, 'line']`` is tagged with spaCy
    (coarse ``pos_`` tags). For every layer ``j`` the head ``heads[i][j]`` is
    picked, its weight vector is sliced, and each remaining weight is appended
    to the list of the tag found at the same position in the tagged sentence.
    The position equal to ``int(df.loc[i, 'index'])`` is skipped.

    Args:
        weights: per-sentence iterable of per-layer containers, indexable by
            head, each giving a sliceable sequence of weights.
        heads: array-like where ``heads[i][j]`` is the head index to use for
            sentence ``i`` and layer ``j``; ``heads.shape[-1]`` is taken as the
            number of layers.
        df: DataFrame with a ``line`` column (sentence text) and an ``index``
            column, addressed by the labels ``0..len(weights)-1``.
            NOTE(review): ``index`` is presumably the position of the
            perturbed word -- confirm against the data preparation code.
        lang: ``'fr'`` loads ``fr_core_news_sm``; any other value loads
            ``en_core_web_sm``.
        is_monolingual: when true the weights are sliced from position 0,
            otherwise from position 1. The last position is always dropped.
            NOTE(review): presumably this strips special tokens -- confirm.

    Returns:
        ``{layer index: defaultdict(list)}`` mapping each POS tag to the
        weights collected for it.

    NOTE(review): positions in the sliced weights are used directly as
    indices into the spaCy token list, so this assumes both share the same
    tokenization -- verify.
    """
    spacy_model = "fr_core_news_sm" if lang == 'fr' else "en_core_web_sm"
    nlp = spacy.load(spacy_model)

    first_position = 0 if is_monolingual else 1

    per_layer = {}
    for layer in range(heads.shape[-1]):
        per_layer[layer] = defaultdict(list)

    for row, sentence_weights in enumerate(weights):
        pos_tags = [token.pos_ for token in nlp(df.loc[row, 'line'])]
        for layer, layer_weights in enumerate(sentence_weights):
            selected = layer_weights[heads[row][layer]][first_position:-1]
            for position, weight in enumerate(selected):
                # Leave out the position recorded in the 'index' column.
                if position != int(df.loc[row, 'index']):
                    per_layer[layer][pos_tags[position]].append(weight)

    return per_layer

In [None]:
def keep_top_n_longest_lists(d, n=10):
    """Return the ``n`` entries of ``d`` whose values are longest.

    Entries are ordered by decreasing ``len(value)``; entries of equal length
    keep their original relative order because the sort is stable.

    Args:
        d: mapping whose values support ``len``.
        n: number of leading entries to keep (used as a slice bound).

    Returns:
        A new ``dict`` with the selected entries in ranked order.
    """
    ranked = sorted(d.items(), key=lambda pair: len(pair[1]), reverse=True)
    return dict(ranked[:n])

In [4]:
def get_average_tag_attention(tag_attention):
    """Average the collected attention weights per tag for every layer.

    For each layer, the per-tag lists are first reduced with
    ``keep_top_n_longest_lists`` (default ``n``), then every remaining list is
    replaced by its arithmetic mean.

    Args:
        tag_attention: mapping ``layer key -> {tag: list of weights}``.

    Returns:
        A long-format ``pandas.DataFrame`` with one row per (layer, tag) and
        the columns ``Attention`` (mean weight), ``Tag`` and ``Layer``
        (layer key plus one).
    """
    columns = defaultdict(list)
    for layer, weights_by_tag in tag_attention.items():
        frequent_tags = keep_top_n_longest_lists(weights_by_tag)
        for tag in frequent_tags:
            values = frequent_tags[tag]
            mean_weight = sum(values) / len(values)
            columns['Attention'].append(mean_weight)
            columns['Tag'].append(tag)
            # Layer keys are shifted by one for reporting.
            columns['Layer'].append(layer + 1)
    return pd.DataFrame(columns)

In [None]:
def load_pickle(file):
    """Read and return the object stored in a pickle file.

    Args:
        file: path of the pickle file, opened in binary read mode.

    Returns:
        The unpickled object.

    NOTE: unpickling can execute arbitrary code; only use this on files you
    produced yourself.
    """
    with open(file, mode='rb') as handle:
        payload = pickle.load(handle)
    return payload

In [None]:
def save_or_load_attention_csv(df, output_path, src, model, error_type):
    """Return the tag-attention table from its CSV file, writing it first if absent.

    The CSV path is ``{output_path}{model}/test.{error_type}.{src}.tag_attention.csv``
    (``output_path`` is concatenated as-is, so it must end with a separator).
    If that file already exists, ``df`` is ignored and the stored CSV is
    returned; otherwise ``df`` is written without its index and read back.

    Args:
        df: DataFrame to store when no CSV exists yet.
        output_path: directory prefix of the CSV path.
        src: source-language code used in the file name.
        model: model directory name below ``output_path``.
        error_type: error type used in the file name.

    Returns:
        The DataFrame read from the CSV file.
    """
    csv_path = f'{output_path}{model}/test.{error_type}.{src}.tag_attention.csv'

    already_cached = os.path.exists(csv_path)
    if already_cached:
        return pd.read_csv(csv_path)

    # First run for this configuration: persist the table, then read it back
    # so the caller always receives the CSV round-tripped form.
    print("yes")
    df.to_csv(csv_path, index=False)
    return pd.read_csv(csv_path)

In [None]:
# Build one figure per (source, target) language pair: a grid with one row per
# model and one column per error type. Every grid cell is split into three
# side-by-side heatmaps (Base, Clean, Noisy) of the average attention per POS
# tag (columns) and layer (rows). The figure is saved as a PDF and shown.
# NOTE(review): this assignment is overwritten by the `for lang in ...` loop below.
lang = "es"
errors = ['article', 'nounnum', 'prep']

# Row labels of the figure. They are zipped with the axes rows further down,
# so their order has to match the insertion order of `models`.
fig_model_names = ['OPUS-MT', 'M2M100', 'MBART', 'NLLB']

for src in ['en']:
    for lang in ['es']:
        # Skip French sources unless the target is Spanish
        # (never triggers with the lists currently iterated).
        if src == 'fr' and lang != 'es':
            continue

        # Key: directory-name prefix used for the '-{error}' and '-clean-{error}'
        # runs; value: directory name used for the 'Base' run.
        models = {
            f'opus-bi-{src}-{lang}': f'opus-mt-{src}-{lang}',
            f'm2m100-{src}-{lang}': 'm2m100_418M', 
            f'mbart-{src}-{lang}': 'mbart-large-50-many-to-many-mmt',
            f'nllb-{src}-{lang}': 'nllb-200-distilled-600M'
        }
        # Titles of the three heatmaps per cell; order matches `dfs` below.
        model_types = ['Base', 'Clean', 'Noisy']

        # These axes only serve as placeholders: their positions are used to
        # place the real heatmap axes, then they are removed.
        fig, axes = plt.subplots(len(models), len(errors), figsize=(32, 24), sharex=True, sharey=True)
        fig.suptitle(f'Average Attention Scores for POS Tags on {src.capitalize()}-{lang.capitalize()}', fontsize=20, y=0.95, fontweight='bold')

        for i, (finetuned_model, base_model) in enumerate(models.items()):
            for j, error in enumerate(errors):

                # Head-masking distance files for the three model variants.
                # NOTE(review): presumably produced by masked_head_distance.py -- confirm.
                base_path = f'../outputs/representations'
                noisy_file = f'{base_path}/head_masking/{src}-{lang}/{finetuned_model}-{error}/test.{error}.{src}.to_clean.distance.pkl'
                clean_file = f'{base_path}/head_masking/{src}-{lang}/{finetuned_model}-clean-{error}/test.{error}.{src}.to_clean.distance.pkl'
                base_file = f'{base_path}/head_masking/{src}-{lang}/{base_model}/test.{error}.{src}.to_clean.distance.pkl'


                noisy_output = load_pickle(noisy_file)
                clean_output = load_pickle(clean_file)
                base_output = load_pickle(base_file)

                # Attention-weight files for the same three model variants.
                noisy_weights_file = f'{base_path}/attention_weights/{src}-{lang}/{finetuned_model}-{error}/test.{error}.{src}.attention_weights.pkl'
                clean_weights_file = f'{base_path}/attention_weights/{src}-{lang}/{finetuned_model}-clean-{error}/test.{error}.{src}.attention_weights.pkl'
                base_weights_file = f'{base_path}/attention_weights/{src}-{lang}/{base_model}/test.{error}.{src}.attention_weights.pkl'

                noisy_weights = load_pickle(noisy_weights_file)
                clean_weights = load_pickle(clean_weights_file)
                base_weights = load_pickle(base_weights_file)

                # Keep only rows whose label is not 'clean' and renumber them from 0,
                # because get_tag_attention addresses rows with df.loc[i, ...] using
                # the enumerate position of each sentence.
                df_file = f'../data/grammar-noise/{src}-{lang}/test.{error}.{src}.pkl'
                df = pd.read_pickle(df_file)
                df = df[~df['label'].isin(['clean'])].reset_index(drop=True)

                # Index of the largest distance along the last axis.
                # NOTE(review): presumably the last axis enumerates heads, giving one
                # selected head per (sentence, layer) -- confirm the array layout.
                noisy_influence_heads = noisy_output.argmax(-1)
                clean_influence_heads = clean_output.argmax(-1)
                base_influence_heads = base_output.argmax(-1)

                # Controls whether get_tag_attention slices the weights from
                # position 0 (True) or position 1 (False).
                is_monolingual = 'opus' in finetuned_model

                # The source language is passed as `lang` to select the spaCy model.
                noisy_tag_attention = get_tag_attention(noisy_weights, noisy_influence_heads, df, src, is_monolingual)
                clean_tag_attention = get_tag_attention(clean_weights, clean_influence_heads, df, src, is_monolingual)
                base_tag_attention = get_tag_attention(base_weights, base_influence_heads, df, src, is_monolingual)

                noisy_df = get_average_tag_attention(noisy_tag_attention)
                clean_df = get_average_tag_attention(clean_tag_attention)
                base_df = get_average_tag_attention(base_tag_attention)

                output_path = f'../outputs/representations/attention_weights/{src}-{lang}/'

                # NOTE: when a CSV already exists, the frame computed above is
                # discarded and the stored CSV is used instead.
                noisy_df = save_or_load_attention_csv(noisy_df, output_path, src, f'{finetuned_model}-{error}', error)
                clean_df = save_or_load_attention_csv(clean_df, output_path, src, f'{finetuned_model}-clean-{error}', error)
                base_df = save_or_load_attention_csv(base_df, output_path, src, base_model, error)

                # Shared colour range so the three heatmaps of this cell are comparable.
                dfs = [base_df, clean_df, noisy_df]
                vmin = min(df['Attention'].min() for df in dfs)
                vmax = max(df['Attention'].max() for df in dfs)

                # NOTE: this loop rebinds `df`, replacing the sentence DataFrame loaded above.
                for k, (df, title) in enumerate(zip(dfs, model_types)):
                    # Split the placeholder's rectangle into three narrower axes,
                    # each offset by 1.1 times its own width.
                    pos = axes[i, j].get_position()
                    width = pos.width / 3.3
                    new_pos = [pos.x0 + k * (width * 1.1), pos.y0, width, pos.height]
                    sub_ax = fig.add_axes(new_pos)

                    # y labels only on the first heatmap of a cell; the x-axis title
                    # only on the bottom row (and there only on the middle heatmap).
                    show_y_labels = (k == 0)
                    show_x_labels = (i == len(models) - 1)

                    # Ordered categorical over 1..max layer, then reshape to layers x tags.
                    df['Layer'] = pd.Categorical(df['Layer'], categories=range(1, int(df['Layer'].max())+1), ordered=True)
                    pivot_data = df.pivot(index="Layer", columns="Tag", values="Attention")

                    sns.heatmap(pivot_data, ax=sub_ax, cmap="YlGnBu", vmin=vmin, vmax=vmax, cbar=False, annot=False, fmt='.2f')

                    sub_ax.set_title(title, fontsize=10, fontweight='bold')

                    if show_y_labels:
                        sub_ax.set_ylabel('Layer', fontsize=10, fontweight='bold')
                        sub_ax.set_yticklabels(range(1, int(df['Layer'].max())+1), rotation=90, ha='right', fontsize=8, fontweight='bold')
                    else:
                        sub_ax.set_ylabel('')
                        sub_ax.set_yticklabels([])

                    if show_x_labels and k == 1:
                        sub_ax.set_xlabel('Tag', fontsize=10, fontweight='bold')
                    else:
                        sub_ax.set_xlabel('')
                    # Tag tick labels are drawn on every heatmap, regardless of show_x_labels.
                    sub_ax.set_xticklabels(pivot_data.columns, rotation=90, ha='right', fontsize=8, fontweight='bold')

                # The placeholder is no longer needed once its three sub-axes exist.
                axes[i, j].remove()

        # Column headers (error types) and row headers (model names). The placeholder
        # axes were removed from the figure above; only their stored positions are read.
        for ax, col in zip(axes[0], errors):
            fig.text(ax.get_position().x0 + ax.get_position().width / 2, 0.91, col.capitalize(), ha='center', va='bottom', fontsize=16, fontweight='bold')

        for ax, row in zip(axes[:, 0], fig_model_names):
            fig.text(0.09, ax.get_position().y0 + ax.get_position().height / 2, row, ha='right', va='center', fontsize=16, fontweight='bold', rotation=90)

        # NOTE(review): called after the heatmap axes were positioned manually --
        # confirm it still has the intended effect on the layout.
        plt.subplots_adjust(hspace=0., wspace=0.1)
        plt.savefig(f'../figures/attention_to_pos/attention_{src}-{lang}.pdf', dpi=300, bbox_inches='tight')

        plt.show()
