<a href="https://colab.research.google.com/github/erees1/alignment-jam/blob/main/Whisper_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

Whisper attention investigation

# Setup

In [1]:
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")

Running as a Jupyter notebook - intended for development only!


In [2]:
import torch
import torch.nn as nn

from collections import defaultdict

## Data

In [3]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
import datasets
import torch

In [4]:
# load dummy dataset and read soundfiles
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

Found cached dataset librispeech_asr_dummy (/home/edwardr/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


# Infering the model code

In [5]:
# Have a look at the source code for the forward method
import inspect
def print_method(func):
    lines = inspect.getsource(func)
    print(lines)
print_method(m.forward)

NameError: name 'm' is not defined

In [None]:
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

In [None]:
print_method(model.generate)

# Attention Plots

In [13]:
def register_hooks(model):
    """
    Add hooks to the model to extract attention scores from both self and cross attn modules
    """
    attention_scores = defaultdict(list)
    def getActivation(name):
        # the hook signature
        def hook(module, input, output):
            attention_scores[name].append(output[1].detach())
        return hook

    hooks = []
    for i, layer in enumerate(model.model.decoder.layers): 
        h = layer.self_attn.register_forward_hook(getActivation(f'layer-{i}.self_attn'))
        hooks.append(h)
        h = layer.encoder_attn.register_forward_hook(getActivation(f'layer-{i}.encoder_attn'))
        hooks.append(h)
    return hooks, attention_scores

In [14]:
def get_num_features(audio_len):
    """
    Get number of whisper features / tokens corresponding to an audio length
    """
    n_mfccs = ((audio_len - 400) // processor.feature_extractor.hop_length + 1)
    attn_len = n_mfccs // 2 # Whisper has a conv layer with stride of 2 at the start of the encoder
    return  attn_len

In [15]:
def run_model_with_hooks(model, data):
    """
    Add hooks to a model and run it
    """
    hooks, attention_scores = register_hooks(model)
    audio_len = len(data)
    feat_len = get_num_features(audio_len)

    input_features = processor(data, return_tensors="pt", sampling_rate=16000, return_attention_mask = True).input_features 
    # Generate logits
    predicted_ids = model.generate(input_features, decoder_input_ids = torch.tensor([[50258]]), output_attentions=True)
    # take argmax and decode
    transcription_as_list = [processor.decode(p) for p in predicted_ids.flatten()]
    [h.remove() for h in hooks]
    return transcription_as_list, attention_scores, feat_len

In [16]:
def process_outputs(attention_scores, feat_len):
    """
    Convert the activations stored by the hooks into a dictionary where key is the layer name and each value is of shape
    N, N, D where N is the attention seq length and D is number of heads.
    """

    output = {}
    len_of_output_seq = len(list(attention_scores.values())[0])
    for k, v in attention_scores.items():
        # v is list of attention scores where v[i] is 
        # the attention for the ith token
        # v[i] has shape (1, n_heads, 1, attn_len)
        if 'encoder_attn' in k:
            # len of input_sequence + len of output seq required to plot
            # attn pairs
            seq_len = feat_len + len_of_output_seq  
        else:
            seq_len = len_of_output_seq
        
        n_heads = v[-1].shape[1]

        full_shape = torch.zeros((seq_len, seq_len, n_heads))
        for token_id, score in enumerate(v):
            expand = torch.zeros((seq_len, n_heads))
            s = score.reshape((n_heads, score.shape[-1]))
            s = s.permute(-1, 0)   # (s seq_len, num_heads)
            if 'encoder' in k:
                expand[:s.shape[0]] = s
                
                full_shape[token_id - len_of_output_seq] = expand
            else:
                expand[:s.shape[0]] = s  # (seq_len, num_heads)
                full_shape[token_id, ...] = expand
        output[k] = full_shape
    return output

In [26]:
from whisper.audio import load_audio
def get_attn_scores_for_file(file_path):
    """
    Plot the attention scores using pysvelte for an audio file
    """

    transcript_as_list, raw_attention_scores, feat_len = run_model_with_hooks(model, load_audio(file_path))

    # reduce length of the attention scores - modifies in place!
    window_size = 2
    stride = 2
    remove_padding = True
    if remove_padding:
        input_tokens_to_plot = feat_len
    else:
        input_tokens_to_plot = raw_attention_scores['layer-0.encoder_attn'][0][0, 0, 0].shape[-1]

    from torch.nn import AvgPool1d
    avg = AvgPool1d(window_size, stride, 0)
    def sum_pool_1d(x):
        return avg(x) * avg.kernel_size[0]

    avg = AvgPool1d(window_size, stride, 0)
    
    for layer_name, attention in raw_attention_scores.items():
        if 'encoder' in layer_name:
            new_a = []
            for token_id, attn_at_token in enumerate(attention):
                a = sum_pool_1d(attn_at_token[:,:,:,:input_tokens_to_plot].squeeze()).unsqueeze(0).unsqueeze(-2)
                attention[token_id] = a
            raw_attention_scores[layer_name] = attention
    new_input_tokens_len = a.shape[-1]
    attention_scores = process_outputs(raw_attention_scores, new_input_tokens_len)

    import pysvelte
    n_layers=1
    for k, v in list(attention_scores.items())[:n_layers*2]:
        print()
        print(f"{k}")
        print("-" * 20)
        if 'encoder_attn' in k:
            tokens=['* '] * new_input_tokens_len + transcript_as_list[:-1]
        else:
            tokens = transcript_as_list[:-1]
        pysvelte.AttentionMulti(tokens=tokens, attention=v).show()


In [27]:
import gdown
# Download audio file that has halucination
gdown.download(url="https://drive.google.com/u/0/uc?id=14R1MGTozskNpVXWhkt6ZkmSbc_9DKuLU", output="audio1.wav")

Downloading...
From: https://drive.google.com/u/0/uc?id=14R1MGTozskNpVXWhkt6ZkmSbc_9DKuLU
To: /home/edwardr/git/alignment-jam/audio1.wav
100%|██████████████████████████████████████████████████████████████████████████████| 1.09M/1.09M [00:00<00:00, 33.9MB/s]


'audio1.wav'

In [28]:
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

In [29]:
get_attn_scores_for_file('audio1.wav')




layer-0.self_attn
--------------------



layer-0.encoder_attn
--------------------
