<a href="https://colab.research.google.com/github/erees1/alignment-jam/blob/main/Whisper_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

Whisper attention investigation

# Setup

In [75]:
try:
  import google.colab
  IN_COLAB = True
  print("Running as a Colab notebook")
  %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
  IN_COLAB = False
  print("Running as a Jupyter notebook - intended for development only!")

Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/PySvelte.git
  Cloning https://github.com/neelnanda-io/PySvelte.git to /tmp/pip-req-build-2pla7jr_
  Running command git clone -q https://github.com/neelnanda-io/PySvelte.git /tmp/pip-req-build-2pla7jr_


In [76]:
import torch
import torch.nn as nn

from collections import defaultdict

## Data

In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
import datasets
import torch

In [None]:
# Small LibriSpeech sample that provides the audio clips used in this notebook
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
)

# Whisper-base checkpoint: `processor` turns audio into input features and
# decodes token ids; `model` is the network whose attention is inspected
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

Downloading builder script:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

Downloading and preparing dataset librispeech_asr_dummy/clean to /root/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset librispeech_asr_dummy downloaded and prepared to /root/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b. Subsequent calls will reuse this data.


Downloading:   0%|          | 0.00/185k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/827 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/494k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.11k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.98k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/290M [00:00<?, ?B/s]

In [None]:
# Self-attention module of the first decoder layer; its `forward` source is
# printed in the cells that follow.
m = model.model.decoder.layers[0].self_attn

# Inferring the model code

In [None]:
# Have a look at the source code for the forward method
import inspect
def print_method(func):
    """Print the source code of `func` as returned by `inspect.getsource`."""
    print(inspect.getsource(func))
# Show how the decoder self-attention module computes its outputs
print_method(m.forward)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_

In [None]:
# NOTE(review): this reloads the same "openai/whisper-base" checkpoint that was
# already loaded in the Data section, replacing `model` with a fresh instance.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

In [None]:
# Print the source of `generate` to see which arguments it accepts
print_method(model.generate)

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        penalty_alpha: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[

# Attention Plots

In [None]:
def register_hooks(model):
    """Attach forward hooks that record attention weights of every decoder layer.

    For each layer in `model.model.decoder.layers`, a hook is registered on
    its `self_attn` and on its `encoder_attn` module (in that order).

    Returns:
        A tuple `(hooks, attention_scores)`:
        - hooks: the handles returned by `register_forward_hook`, to be
          removed by the caller once recording is finished.
        - attention_scores: a `defaultdict(list)` keyed by
          'layer-{i}.self_attn' / 'layer-{i}.encoder_attn'; every forward
          call of a hooked module appends the detached second element of
          that module's output.
    """
    attention_scores = defaultdict(list)

    def make_recorder(key):
        # Forward hooks are called as hook(module, inputs, outputs)
        def record(module, inputs, outputs):
            # outputs[1] is taken to be the attention weights; this assumes
            # the module was run with attention outputs enabled.
            attention_scores[key].append(outputs[1].detach())
        return record

    hooks = []
    for i, decoder_layer in enumerate(model.model.decoder.layers):
        targets = (
            (f'layer-{i}.self_attn', decoder_layer.self_attn),
            (f'layer-{i}.encoder_attn', decoder_layer.encoder_attn),
        )
        for key, attn_module in targets:
            hooks.append(attn_module.register_forward_hook(make_recorder(key)))
    return hooks, attention_scores

In [308]:
def get_num_features(audio_len):
    """Estimate how many encoder positions an audio clip of `audio_len` samples spans.

    Reads the hop length from the notebook-level `processor`.
    The constant 400 is presumably the analysis window length in samples -
    TODO confirm against the feature extractor configuration.
    """
    hop = processor.feature_extractor.hop_length
    frame_count = (audio_len - 400) // hop + 1
    # Whisper has a conv layer with stride of 2 at the start of the encoder,
    # so the attention sequence is half as long as the frame sequence.
    return frame_count // 2

In [309]:
def run_model_with_hooks(model, data):
    """Transcribe one audio example while recording decoder attention weights.

    Args:
        model: model exposing `generate` and the decoder layout expected by
            `register_hooks`.
        data: mapping where `data["audio"]["array"]` is the raw waveform.
            The waveform is passed to the processor with
            `sampling_rate=16000`, so it is assumed to be 16 kHz audio.

    Returns:
        A tuple `(transcription_as_list, attention_scores, feat_len)`:
        - transcription_as_list: one decoded string per generated token id.
        - attention_scores: the dict filled by the hooks from `register_hooks`.
        - feat_len: value of `get_num_features` for this clip's length.
    """
    audio = data["audio"]["array"]
    feat_len = get_num_features(len(audio))

    input_features = processor(
        audio, return_tensors="pt", sampling_rate=16000, return_attention_mask=True
    ).input_features

    hooks, attention_scores = register_hooks(model)
    try:
        # Generate token ids; 50258 is presumably the start-of-transcript
        # token id for this checkpoint - TODO confirm against the tokenizer.
        predicted_ids = model.generate(
            input_features,
            decoder_input_ids=torch.tensor([[50258]]),
            output_attentions=True,
        )
    finally:
        # Always detach the hooks, even if generation fails, so they do not
        # stay attached to the model and keep recording on later calls.
        for h in hooks:
            h.remove()

    # Decode every generated id separately so tokens can be used as plot labels
    transcription_as_list = [processor.decode(p) for p in predicted_ids.flatten()]
    return transcription_as_list, attention_scores, feat_len

In [310]:
def process_outputs(attention_scores, feat_len):
    """
    Produce a dictionary where key is the layer name and each value is of shape
    N, N, D where N is the attention seq length and D is number of heads.

    Args:
        attention_scores: mapping from layer name to a list of per-token
            attention tensors; entry i is expected to have shape
            (1, n_heads, 1, attn_len) for the i-th generated token.
        feat_len: number of encoder (input) positions present in the
            cross-attention tensors.

    Returns:
        Dict with the same keys. For self-attention layers N is the number of
        generated tokens. For layers whose name contains 'encoder_attn', N is
        `feat_len` plus the number of generated tokens, and the rows of the
        generated tokens are the last rows of the matrix. Positions that were
        not attended to are zero. An empty input gives an empty dict.
    """
    output = {}
    if not attention_scores:
        # Nothing was recorded; avoid indexing into an empty mapping.
        return output

    # Number of generated tokens, taken from the first recorded layer
    len_of_output_seq = len(next(iter(attention_scores.values())))
    for k, v in attention_scores.items():
        # v is list of attention scores where v[i] is
        # the attention for the ith token
        is_cross_attention = 'encoder_attn' in k
        if is_cross_attention:
            # len of input_sequence + len of output seq required to plot
            # attn pairs
            seq_len = feat_len + len_of_output_seq
        else:
            seq_len = len_of_output_seq

        n_heads = v[-1].shape[1]

        full_shape = torch.zeros((seq_len, seq_len, n_heads))
        for token_id, score in enumerate(v):
            # (1, n_heads, 1, attn_len) -> (attn_len, n_heads)
            s = score.reshape((n_heads, score.shape[-1])).permute(1, 0)
            expand = torch.zeros((seq_len, n_heads))
            expand[:s.shape[0]] = s
            if is_cross_attention:
                # Negative index: generated tokens occupy the last rows,
                # after the feat_len input positions.
                row = token_id - len_of_output_seq
            else:
                row = token_id
            full_shape[row] = expand
        output[k] = full_shape
    return output

In [311]:
# NOTE(review): loads the "openai/whisper-base" checkpoint again, replacing
# `model` with a fresh instance before the attention run below.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

In [304]:
# Shape of the input features for the first example. The sampling rate is
# passed explicitly (same value as in `run_model_with_hooks`) so the processor
# does not warn about a missing `sampling_rate` argument.
processor(ds[0]['audio']['array'], sampling_rate=16000).input_features[0].shape

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


(80, 3000)

In [293]:
# Transcribe the first dataset example and collect the recorded attention
# scores together with the estimated number of encoder positions.
transcript_as_list, raw_attention_scores, feat_len = run_model_with_hooks(model, ds[0])



In [295]:
# Reduce the length of the encoder attention scores by sum-pooling over
# neighbouring input positions - modifies raw_attention_scores in place!
# NOTE: because of the in-place update this cell is not idempotent; re-run the
# `run_model_with_hooks` cell before executing this one a second time.
window_size = 3
stride = 3
remove_padding = False
if remove_padding:
    # Keep only the positions estimated to contain audio
    input_tokens_to_plot = feat_len
else:
    # Keep every encoder position present in the recorded scores
    input_tokens_to_plot = raw_attention_scores['layer-0.encoder_attn'][0][0, 0, 0].shape[-1]

from torch.nn import AvgPool1d
avg = AvgPool1d(window_size, stride, 0)
def sum_pool_1d(x):
    """Sum the values in each pooling window (window average times window size)."""
    return avg(x) * avg.kernel_size[0]

for layer_name, attention in raw_attention_scores.items():
    if 'encoder' in layer_name:
        for token_id, attn_at_token in enumerate(attention):
            cropped = attn_at_token[:, :, :, :input_tokens_to_plot]
            # Drop the singleton dims for pooling, then restore them so each
            # entry keeps the same number of dimensions as before.
            attention[token_id] = sum_pool_1d(cropped.squeeze()).unsqueeze(0).unsqueeze(-2)

# Read the pooled length from the stored scores rather than from a variable
# leaked out of the loop above.
new_input_tokens_len = raw_attention_scores['layer-0.encoder_attn'][0].shape[-1]
print(new_input_tokens_len)

500


In [296]:
# Turn the per-token attention lists into one dense (N, N, n_heads) tensor per
# layer, using the pooled input length for the encoder-attention layers.
attention_scores = process_outputs(raw_attention_scores, new_input_tokens_len)

In [297]:
import pysvelte
n_layers = 4
# Every decoder layer contributes two entries (self-attention followed by
# encoder attention), hence 2 * n_layers items are plotted.
layers_to_plot = list(attention_scores.items())[:n_layers*2]
for layer_name, layer_attention in layers_to_plot:
    print()
    print(f"{layer_name}")
    print("-" * 20)
    # The last transcript entry is dropped so the labels match the number of
    # recorded decoding steps.
    tokens = transcript_as_list[:-1]
    if 'encoder_attn' in layer_name:
        # Pooled audio positions have no text, so label them with '*'
        tokens = ['*'] * new_input_tokens_len + tokens
    pysvelte.AttentionMulti(tokens=tokens, attention=layer_attention).show()



layer-0.self_attn
--------------------



layer-0.encoder_attn
--------------------



layer-1.self_attn
--------------------



layer-1.encoder_attn
--------------------



layer-2.self_attn
--------------------



layer-2.encoder_attn
--------------------



layer-3.self_attn
--------------------



layer-3.encoder_attn
--------------------
