### Imports

In [None]:
import einops
import torch
import collections
import numpy as np
import seaborn as sns
import warnings
import logging

from fairseq import *
from functools import partial
from collections import defaultdict
from einops import rearrange


logger = logging.getLogger(__name__)
warnings.filterwarnings('ignore')

### Load Model and Data

In [None]:
model_name = "" # multiformer or transformer (String)

# Bind the selected architecture's model class to ``model_``.
if model_name == "multiformer":
    from fairseq.models.speech_to_text.s2t_multiformer import S2TMultiformerModel as model_
elif model_name == "transformer":
    from fairseq.models.speech_to_text.s2t_transformer import S2TTransformerModel as model_
else:
    # Fail here instead of only logging: without a valid choice ``model_`` is
    # never bound and the first use of it would die with an unrelated NameError.
    raise ValueError(
        "Please choose between the two options: multiformer or transformer "
        f"(got model_name={model_name!r})"
    )


In [None]:
ckpt = "avg_7_around_best.pt"  # Averaged checkpoint name (String)
path = ""  # Path to where the model's checkpoints are stored (String)
data_name_or_path = ""  # Path to where the MuST-C language pair is located (String)
split = ""  # Data partition with which to perform the analysis (train_st in the paper) (String)

# Build the model from the checkpoint and switch it to evaluation mode.
model = model_.from_pretrained(path, checkpoint_file=ckpt, data_name_or_path=data_name_or_path)
model.eval()

# Load the requested partition only if the task does not hold it already.
if split not in model.task.datasets:
    model.task.load_dataset(split)

### Functions to compute the contribution of each head to the attention output

In [None]:
def get_sample(split, index):
    """Return the source and target tensors of one dataset example.

    Args:
        split: Name of a data partition available through ``model.task``.
        index: Position of the example inside that partition.

    Returns:
        Tuple ``(src_tensor, tgt_tensor)`` read from the example's ``source``
        and ``target`` attributes.
    """
    # Index the dataset a single time; the original looked the same example up
    # twice (presumably loading the audio features twice - confirm against the
    # dataset implementation).
    sample = model.task.dataset(split)[index]
    return sample.source, sample.target


def trace_forward(src_tensor, tgt_tensor):
    """Run one teacher-forced forward pass and record every module's I/O.

    A forward hook is attached to each module returned by
    ``model.named_modules()``; the hooks are always removed before returning,
    even if the forward pass raises.

    Args:
        src_tensor: Unbatched source tensor; its second-to-last dimension is
            used as the source length.
        tgt_tensor: Unbatched 1-D tensor of target token ids. It is shifted
            right by one position and prefixed with the target dictionary's
            EOS index to form the decoder input.

    Returns:
        Tuple ``(layer_inputs, layer_outputs)`` of ``defaultdict(list)``
        objects keyed by module name. Each list holds one entry per call of
        that module: the tuple of positional inputs and the output,
        respectively.
    """
    layer_inputs = defaultdict(list)
    layer_outputs = defaultdict(list)

    def save_activation(name, mod, inp, out):
        layer_inputs[name].append(inp)
        layer_outputs[name].append(out)

    model.zero_grad()

    # Add the batch dimension and move everything to the model's device.
    src_tokens = src_tensor.unsqueeze(0).to(model.device)
    # Integer lengths created on the same device as the source (the original
    # passed a float CPU tensor, which breaks once the model is not on CPU).
    src_lengths = torch.tensor(
        [src_tokens.size(-2)], dtype=torch.long, device=src_tokens.device
    )
    # Decoder input: EOS followed by the target without its last token.
    start_token = torch.tensor(
        [model.task.tgt_dict.eos_index],
        dtype=tgt_tensor.dtype,
        device=tgt_tensor.device,
    )
    prev_output_tokens = (
        torch.cat([start_token, tgt_tensor[:-1]]).unsqueeze(0).to(model.device)
    )

    handles = [
        module.register_forward_hook(partial(save_activation, name))
        for name, module in model.named_modules()
    ]
    try:
        # Only activations are needed, so skip building the autograd graph.
        with torch.no_grad():
            model.models[0](src_tokens, src_lengths, prev_output_tokens)
    finally:
        # Never leave hooks attached to the shared model.
        for handle in handles:
            handle.remove()

    return layer_inputs, layer_outputs


def get_layer_contributions(layer, contrib_type, model_name):
    """Compute each head's relative contribution to one layer's attention output.

    The output projection is split per head so that the projected output is
    the sum over heads of ``pre_wo[:, h] @ wo[h]`` plus the bias. The norm of
    each head's term is normalised across heads for every time step, and the
    median over time steps is returned.

    Reads the module-level ``model``, ``layer_inputs`` and ``layer_outputs``;
    the latter two must come from a ``trace_forward`` call on the sample being
    analysed.

    Args:
        layer: Index into ``model.models[0].encoder.transformer_layers``.
        contrib_type: ``'norm_l1'`` or ``'norm_l2'``; the vector norm applied
            to each head's output.
        model_name: ``"multiformer"`` or ``"transformer"``; selects the name
            of the output projection and the layout of the recorded tensors.

    Returns:
        Tensor of shape ``(1, num_heads)``, detached from the autograd graph.

    Raises:
        ValueError: If ``model_name`` or ``contrib_type`` is not supported.
        AssertionError: If the per-head decomposition does not reproduce the
            recorded projection output.
    """
    if contrib_type == 'norm_l1':
        norm_order = 1
    elif contrib_type == 'norm_l2':
        norm_order = 2
    else:
        raise ValueError(
            f"Unknown contrib_type {contrib_type!r}: expected 'norm_l1' or 'norm_l2'"
        )

    self_attn = model.models[0].encoder.transformer_layers[layer].self_attn
    num_heads = self_attn.num_heads
    prefix = f"models.0.encoder.transformer_layers.{layer}.self_attn"

    # No gradients are needed; without this the result would keep an autograd
    # graph alive (through the projection weight) for every analysed sample.
    with torch.no_grad():
        if model_name == "multiformer":
            projection = self_attn.wo
            # First recorded call, first positional input, first batch item.
            pre_wo = rearrange(
                layer_inputs[f"{prefix}.wo"][0][0][0],
                't (h c) -> t h c',
                h=num_heads,
            )
            pos_wo = layer_outputs[f"{prefix}.wo"][0][0] # T x D

        elif model_name == "transformer":
            projection = self_attn.out_proj
            # Recorded tensors are time-major here; keep the first batch item.
            pre_wo = rearrange(
                layer_inputs[f"{prefix}.out_proj"][0][0],
                't b (h c) -> b t h c',
                h=num_heads,
            )[0]
            pos_wo = rearrange(
                layer_outputs[f"{prefix}.out_proj"][0],
                't b c -> b t c',
                c=self_attn.embed_dim,
            )[0]

        else:
            raise ValueError(
                "Please choose between the two options: multiformer or transformer "
                f"(got model_name={model_name!r})"
            )

        bo = projection.bias # D
        # Weight is D x (h·d_h); regroup it into one (d_h x D) matrix per head.
        wo = rearrange(projection.weight, 'd (h c) -> h c d', h=num_heads)

        # Output of every head after its slice of the projection: T x h x D.
        pos_wo_ = torch.einsum('t h c , h c d -> t h d', pre_wo, wo)

        # Sanity check: summing the heads and adding the bias must give back
        # the recorded output. Use the absolute error so that positive and
        # negative deviations cannot cancel each other out.
        mean_abs_error = (pos_wo_.sum(-2) + bo - pos_wo).abs().mean()
        assert mean_abs_error < 1e-6, (
            f"Per-head decomposition of layer {layer} does not match the "
            f"recorded output (mean abs error {mean_abs_error.item():.3e})"
        )

        contrib = torch.norm(pos_wo_, dim=-1, p=norm_order) # T x h
        contrib = contrib / contrib.sum(-1, keepdim=True) # share per head
        contrib = contrib.median(-2, keepdim=True).values # 1 x h

    return contrib


def get_contributions(contrib_type, model_name):
    """Collect the head contributions of every encoder layer.

    Calls ``get_layer_contributions`` once per entry of
    ``model.models[0].encoder.transformer_layers``, in order, and stacks the
    results along a new second-to-last dimension.
    """
    num_layers = len(model.models[0].encoder.transformer_layers)
    per_layer = [
        get_layer_contributions(layer_idx, contrib_type, model_name)
        for layer_idx in range(num_layers)
    ]
    return torch.stack(per_layer, dim=-2)

### Analysis with N Random Samples

In [None]:
num_samples = 500 # Integer number of samples to perform the analysis (500 in the paper) (Integer)

# Draw the example indices up front (independent draws, so repeats are possible).
dataset_size = len(model.task.dataset(split))
sample_indices = torch.randint(dataset_size, (num_samples,)).tolist()

contributions = []
for index in sample_indices:
    src_tensor, tgt_tensor = get_sample(split, index)
    # These two names must stay at module level: get_layer_contributions
    # reads them as globals rather than receiving them as arguments.
    layer_inputs, layer_outputs = trace_forward(src_tensor, tgt_tensor)
    contributions.append(get_contributions('norm_l1', model_name))

# Concatenate the per-sample results on the first axis, then move that axis last.
contributions = rearrange(
    torch.cat(contributions, dim=0),
    't l h -> l h t',
    h=model.models[0].encoder.transformer_layers[0].self_attn.num_heads,
)

In [None]:
# Median over the last (sample) axis; move to CPU before converting, because
# ``.numpy()`` only works on CPU tensors and the model may live on another device.
contrib_median = contributions.median(-1).values.detach().cpu().numpy()
# Heatmap of the resulting 2-D array with a fixed colour range of 0.1 to 0.5.
ax = sns.heatmap(contrib_median, linewidths=.5, vmin=0.1, vmax=0.5, cmap="Blues",)