# Workshop on Speech Perception, Part 2: Comparing Audio Transformer representations with Centered Kernel Alignment

*Interpretability & Explainability in AI, MSc A.I., University of Amsterdam, June 2024*

This notebook includes contributions by: Marianne de Heer Kloots & Marta Grasa.
It makes use of the [reference code](https://colab.research.google.com/github/google-research/google-research/blob/master/representation_similarity/Demo.ipynb) provided by Simon Kornblith and colleagues for implementing CKA computations.

## Comparing representations with similarity-based interpretability

In the first notebook for this workshop we made our first steps towards understanding where different types of information are represented in Wav2Vec2, by evaluating how well we can decode interpretable features (like amplitude envelopes and phonemes) across the model's hidden layers.  

In this notebook we experiment with a different technique for analyzing model internals: computing similarities between different representational spaces. Such representational similarity analyses even allow us to compare representational spaces with very different formats, such as model embeddings and syntactic trees (as in [Shen et al., 2023](https://www.isca-archive.org/interspeech_2023/shen23_interspeech.html)), or model embeddings and human brain recordings (as in [Abnar et al., 2019](https://aclanthology.org/W19-4820/)), as long as we are able to define a similarity measure over a common set of stimuli within each representational space. They also allow us to examine, for example, how representations evolve across the layers of a single model, or how the representations generated by different models compare.

For this notebook, we will focus on a method called Centered Kernel Alignment (CKA), proposed for comparing neural network representations by [Kornblith et al. (2019)](https://proceedings.mlr.press/v97/kornblith19a.html). We will use it to compare a range of different Wav2Vec2 models (untrained, pretrained, and finetuned versions of the base and large architectures), and to compare the internal representations of these models to more interpretable feature spaces: MFCCs and GloVe vectors.
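As a compact reference for what the CKA functions below compute (following the definitions in Kornblith et al., 2019): given Gram matrices $K$ and $L$ computed over the same $n$ examples, and the centering matrix $H = I_n - \tfrac{1}{n}\mathbf{1}\mathbf{1}^\top$, CKA is a normalized version of the Hilbert-Schmidt Independence Criterion (HSIC):

$$
\mathrm{HSIC}(K, L) = \frac{\operatorname{tr}(KHLH)}{(n-1)^2}, \qquad
\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}.
$$

With a linear kernel, $K = XX^\top$ and $L = YY^\top$ for column-centered feature matrices $X \in \mathbb{R}^{n \times p_1}$ and $Y \in \mathbb{R}^{n \times p_2}$. Then $HKH = K$, $\operatorname{tr}(KL) = \lVert Y^\top X \rVert_F^2$, and the $(n-1)^2$ factors cancel, which gives the feature-space form used by `feature_space_linear_cka`:

$$
\mathrm{CKA}_{\mathrm{linear}}(X, Y) = \frac{\lVert Y^\top X \rVert_F^2}{\lVert X^\top X \rVert_F \, \lVert Y^\top Y \rVert_F}.
$$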

##### <input type="checkbox"/> <font color='blue'><b>ToThink 5</b></font>: Understanding similarity-based interpretability techniques

Familiarize yourself with similarity-based interpretability techniques by watching [this video](https://youtu.be/u7Dvb_a1D-0). If you're interested, a more in-depth explanation of CKA specifically is also covered in [this talk](https://youtu.be/TBjdvjdS2KM). What is an advantage of CKA as compared to other similarity-based interpretability methods, like CCA?

In [None]:
# @title CKA functions
## taken from Kornblith et al. (2019) - https://cka-similarity.github.io/
import numpy as np

def gram_linear(x):
  """Compute Gram (kernel) matrix for a linear kernel.

  Args:
    x: A num_examples x num_features matrix of features.

  Returns:
    A num_examples x num_examples Gram matrix of examples.
  """
  return x.dot(x.T)


def gram_rbf(x, threshold=1.0):
  """Compute Gram (kernel) matrix for an RBF kernel.

  Args:
    x: A num_examples x num_features matrix of features.
    threshold: Fraction of median Euclidean distance to use as RBF kernel
      bandwidth. (This is the heuristic we use in the paper. There are other
      possible ways to set the bandwidth; we didn't try them.)

  Returns:
    A num_examples x num_examples Gram matrix of examples.
  """
  dot_products = x.dot(x.T)
  sq_norms = np.diag(dot_products)
  sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :]
  sq_median_distance = np.median(sq_distances)
  return np.exp(-sq_distances / (2 * threshold ** 2 * sq_median_distance))


def center_gram(gram, unbiased=False):
  """Center a symmetric Gram matrix.

  This is equivalent to centering the (possibly infinite-dimensional) features
  induced by the kernel before computing the Gram matrix.

  Args:
    gram: A num_examples x num_examples symmetric matrix.
    unbiased: Whether to adjust the Gram matrix in order to compute an unbiased
      estimate of HSIC. Note that this estimator may be negative.

  Returns:
    A symmetric matrix with centered columns and rows.
  """
  if not np.allclose(gram, gram.T):
    raise ValueError('Input must be a symmetric matrix.')
  gram = gram.copy()

  if unbiased:
    # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.
    # L. (2014). Partial distance correlation with methods for dissimilarities.
    # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically
    # stable than the alternative from Song et al. (2007).
    n = gram.shape[0]
    np.fill_diagonal(gram, 0)
    means = np.sum(gram, 0, dtype=np.float64) / (n - 2)
    means -= np.sum(means) / (2 * (n - 1))
    gram -= means[:, None]
    gram -= means[None, :]
    np.fill_diagonal(gram, 0)
  else:
    means = np.mean(gram, 0, dtype=np.float64)
    means -= np.mean(means) / 2
    gram -= means[:, None]
    gram -= means[None, :]

  return gram


def cka(gram_x, gram_y, debiased=False):
  """Compute CKA.

  Args:
    gram_x: A num_examples x num_examples Gram matrix.
    gram_y: A num_examples x num_examples Gram matrix.
    debiased: Use unbiased estimator of HSIC. CKA may still be biased.

  Returns:
    The value of CKA between X and Y.
  """
  gram_x = center_gram(gram_x, unbiased=debiased)
  gram_y = center_gram(gram_y, unbiased=debiased)

  # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased variant) or
  # n*(n-3) (unbiased variant), but this cancels for CKA.
  scaled_hsic = gram_x.ravel().dot(gram_y.ravel())

  normalization_x = np.linalg.norm(gram_x)
  normalization_y = np.linalg.norm(gram_y)
  return scaled_hsic / (normalization_x * normalization_y)


def _debiased_dot_product_similarity_helper(
    xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y,
    n):
  """Helper for computing debiased dot product similarity (i.e. linear HSIC)."""
  # This formula can be derived by manipulating the unbiased estimator from
  # Song et al. (2007).
  return (
      xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)
      + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))


def feature_space_linear_cka(features_x, features_y, debiased=False):
  """Compute CKA with a linear kernel, in feature space.

  This is typically faster than computing the Gram matrix when there are fewer
  features than examples.

  Args:
    features_x: A num_examples x num_features matrix of features.
    features_y: A num_examples x num_features matrix of features.
    debiased: Use unbiased estimator of dot product similarity. CKA may still be
      biased. Note that this estimator may be negative.

  Returns:
    The value of CKA between X and Y.
  """
  features_x = features_x - np.mean(features_x, 0, keepdims=True)
  features_y = features_y - np.mean(features_y, 0, keepdims=True)

  dot_product_similarity = np.linalg.norm(features_x.T.dot(features_y)) ** 2
  normalization_x = np.linalg.norm(features_x.T.dot(features_x))
  normalization_y = np.linalg.norm(features_y.T.dot(features_y))

  if debiased:
    n = features_x.shape[0]
    # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array.
    sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x)
    sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y)
    squared_norm_x = np.sum(sum_squared_rows_x)
    squared_norm_y = np.sum(sum_squared_rows_y)

    dot_product_similarity = _debiased_dot_product_similarity_helper(
        dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,
        squared_norm_x, squared_norm_y, n)
    normalization_x = np.sqrt(_debiased_dot_product_similarity_helper(
        normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x,
        squared_norm_x, squared_norm_x, n))
    normalization_y = np.sqrt(_debiased_dot_product_similarity_helper(
        normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y,
        squared_norm_y, squared_norm_y, n))

  return dot_product_similarity / (normalization_x * normalization_y)
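
To make the relationship between these functions concrete, here is a small self-contained sketch. It re-implements linear CKA in plain NumPy (the helper names `linear_cka_gram` and `linear_cka_features` are hypothetical, not part of the reference code) and uses random toy data rather than model features. It checks that the Gram-matrix and feature-space formulations agree, and illustrates that linear CKA is unchanged by an orthogonal transformation plus isotropic scaling of one representation, but not by an arbitrary invertible linear transformation.

```python
import numpy as np

def linear_cka_gram(x, y):
    """Linear CKA via explicitly centered n x n Gram matrices."""
    n = x.shape[0]
    h = np.eye(n) - np.ones((n, n)) / n  # centering matrix H
    k = h @ (x @ x.T) @ h
    l = h @ (y @ y.T) @ h
    return np.sum(k * l) / (np.linalg.norm(k) * np.linalg.norm(l))

def linear_cka_features(x, y):
    """Linear CKA in feature space: ||Y^T X||_F^2 / (||X^T X||_F ||Y^T Y||_F)."""
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    return (np.linalg.norm(y.T @ x) ** 2
            / (np.linalg.norm(x.T @ x) * np.linalg.norm(y.T @ y)))

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 10))
y = x @ rng.normal(size=(10, 20)) + 0.5 * rng.normal(size=(100, 20))

# 1) Both formulations give the same value.
print(linear_cka_gram(x, y), linear_cka_features(x, y))

# 2) Invariance to an orthogonal transform plus isotropic scaling.
q, _ = np.linalg.qr(rng.normal(size=(10, 10)))  # random orthogonal matrix
print(linear_cka_features(x, 3.0 * x @ q))      # ~1.0

# 3) No invariance to a general invertible linear transform.
a = rng.normal(size=(10, 10))
print(linear_cka_features(x, x @ a))            # noticeably below 1
```

The feature-space form avoids building $n \times n$ matrices, which is why the reference code recommends it when there are fewer features than examples.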

In [None]:
# @title Helper functions
def set_seed(seed):
    """Set random seed."""
    if seed == -1:
        seed = random.randint(0, 1000)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # if you are using GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def plot_gram_matrix(feature_list, subset_ids, ax, gram_fun=gram_linear):
    N_subsets = len(subset_ids)
    im = ax.imshow(gram_fun(np.array(feature_list)))
    divider_lines = [(len(feature_list)/N_subsets)*i for i in range(1,N_subsets)]
    ax.vlines(divider_lines, 0, len(feature_list), color='w')
    ax.hlines(divider_lines, 0, len(feature_list), color='w')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(0, len(feature_list))
    ax.set_ylim(0, len(feature_list))
    text_positions = [(divider_lines[0])/2] +\
    [(divider_lines[i]+divider_lines[i+1])/2 for i in range(len(divider_lines)-1)] +\
    [(divider_lines[-1]+len(feature_list))/2]
    for i, tp in enumerate(text_positions):
        ax.text(tp, -10, subset_ids[i], ha='center')
        ax.text(-10, tp, subset_ids[i], va='center', rotation=90)
    return im

# plotting aesthetics
linestyles = {
    'base_unt': 'dotted',
    'base_pret': 'dashed',
    'base_ft': 'solid',
    'large_unt': 'dotted',
    'large_pret': 'dashed',
    'large_ft': 'solid'
}

colors = {
    'base_unt': 'silver',
    'base_pret': 'mediumpurple',
    'base_ft': 'indigo',
    'large_unt': 'silver',
    'large_pret': 'mediumpurple',
    'large_ft': 'indigo'
}

### Configuration

In [None]:
%%capture
!pip install tgt

In [None]:
import tgt
import torch
import IPython.display as ipd
import librosa
import random
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import gc
from collections import defaultdict
from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2Model, Wav2Vec2ForCTC
from IPython.display import Audio

In [None]:
# download a subset of the speech accent archive with force-aligned transcriptions
# ('please_call_stella' folder)
!gdown 1rH-EWbtJBp0teUFXCjVHE_u2icvNRNNT
!unzip please_call_stella.zip

# GloVe embeddings for all words in the please_call_stella fragment
!gdown 1tx7F6QLwjaWpg_sZY3mel8a13q0-CME0

In [None]:
# setup variables
STELLA_DIR = '/content/please_call_stella'
SAMP_FREQ = 16000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
set_seed(42)

In [None]:
## load all Wav2Vec2 models we will compare
base_pret_ckpt = "facebook/wav2vec2-base"
base_ft_ckpt = "facebook/wav2vec2-base-960h"
large_pret_ckpt = "facebook/wav2vec2-large"
large_ft_ckpt = "facebook/wav2vec2-large-960h"
base_config = Wav2Vec2Config.from_pretrained(base_pret_ckpt)
large_config = Wav2Vec2Config.from_pretrained(large_pret_ckpt)

base_models = ['base_unt', 'base_pret', 'base_ft']
large_models = ['large_unt', 'large_pret', 'large_ft']
models = {
    # untrained base model
    "base_unt": Wav2Vec2Model(base_config),
    # pretrained (only) base model
    "base_pret": Wav2Vec2Model.from_pretrained(base_pret_ckpt),
    # pretrained + finetuned base model
    "base_ft": Wav2Vec2Model.from_pretrained(base_ft_ckpt),
    # untrained large model
    "large_unt": Wav2Vec2Model(large_config),
    # pretrained (only) large model
    "large_pret": Wav2Vec2Model.from_pretrained(large_pret_ckpt),
    # pretrained + finetuned large model
    "large_ft": Wav2Vec2Model.from_pretrained(large_ft_ckpt),
}

# feature extractor is the same for all models
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_ft_ckpt)

## Speech Accent Archive recordings and word segments

To construct our stimulus set for representational similarity measures, we will use a subset of recordings from the [Speech Accent Archive](https://accent.gmu.edu/). The Speech Accent Archive contains recordings of many different speakers reading the same _elicitation paragraph_, which we refer to as 'Please call Stella'. This paragraph was constructed to contain a large variety of difficult English sounds and sound sequences, and is read by speakers of English with a variety of native and non-native accents.  

In this notebook, we will be analyzing model representations on the level of _words_. To this end, we provide word-aligned transcriptions of a subset of the Speech Accent Archive recordings, obtained automatically using the [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/) tool and provided as .TextGrid files in the `please_call_stella` folder (downloaded above).

For demonstration purposes, here we only use a small subset of three male native speakers of North American English. However, you can find more recordings with aligned transcripts in the same folder, in case you are interested in studying speaker or accent effects for your mini-project. Have a look at the `speaker_info` dataframe loaded below for details.


##### <input type="checkbox"/> <font color='green'>**ToDo 13**</font>: Inspecting Speech Accent Archive data

Load our `annotated_speaker_subset` from the Speech Accent Archive by running the cells below. Listen to an example recording and some extracted word segments.

In [None]:
# functions for reading in wav files and TextGrid annotations from our please_call_stella folder
def get_word_annotations(annotations_file):
  word_annotation_tier = tgt.io.read_textgrid(annotations_file).get_tier_by_name('words')
  annotations_dict = {
      'start_times': [intv.start_time for intv in word_annotation_tier],
      'end_times': [intv.end_time for intv in word_annotation_tier],
      'words': [intv.text for intv in word_annotation_tier]
  }
  return annotations_dict

def get_word_audios(full_audio, sr, words, word_start_times, word_end_times):
  word_audios = [
      full_audio[int(wst*sr):int(wet*sr)]
      for wst, wet in zip(word_start_times, word_end_times)
  ]
  return word_audios

In [None]:
# information about all speakers in this speech accent archive subset (e.g. age, gender and native language)
speaker_info = pd.read_csv(f'{STELLA_DIR}/speakers.csv', sep=';', index_col=0)

# we will create a subset containing only recordings by these speakers
speaker_subset_ids = ['english1', 'english51', 'english81']

# load the recordings and annotations for these speakers into a dict
annotated_speaker_subset = {
    speaker_id: {
        'audio': librosa.load(f'{STELLA_DIR}/{speaker_id}.wav', sr=SAMP_FREQ)[0]
    } | get_word_annotations(f'{STELLA_DIR}/{speaker_id}.TextGrid')
    for speaker_id in speaker_subset_ids
}

# list of id strings describing all audio segments (including the speaker and the word corresponding to each segment)
audio_ids = [f'{speaker_id}_{w}' for speaker_id in speaker_subset_ids for w in annotated_speaker_subset[speaker_id]['words']]

# raw audio signals (waveforms) for each word segment
word_audios = [wa
    for speaker_id in speaker_subset_ids
    for wa in get_word_audios(annotated_speaker_subset[speaker_id]['audio'],
                              SAMP_FREQ,
                              annotated_speaker_subset[speaker_id]['words'],
                              annotated_speaker_subset[speaker_id]['start_times'],
                              annotated_speaker_subset[speaker_id]['end_times'])
    ]

In [None]:
speaker_id = 'english1'
print(f'Full recording for {speaker_id}:')
ipd.display(Audio(annotated_speaker_subset[speaker_id]['audio'], rate=SAMP_FREQ))

In [None]:
print('A few word segments:')
for audio_id, word_audio in list(zip(audio_ids, word_audios))[:3]:
  print(audio_id)
  ipd.display(Audio(word_audio, rate=SAMP_FREQ))

## Gram matrices for MFCCs and GloVe embeddings

The first step in our similarity analyses is to compute _Gram matrices_, which capture the _representational similarity_ between each pair of examples in our dataset, as computed with some kernel function over our chosen set of features (see a helpful visualization [here](https://youtu.be/TBjdvjdS2KM?t=219)).  

We will soon be computing these matrices for model internal layers, but we'll first compute them over two sets of somewhat more interpretable features: the average [Mel-frequency cepstral coefficients (MFCCs)](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum) over word segments (capturing the word's acoustics) and the [GloVe vectors](https://nlp.stanford.edu/projects/glove/) for each word in the paragraph (capturing word-level distributional information).
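To see how the two kernels differ before looking at the real features, here is a small self-contained sketch on random toy data. `gram_linear_sketch` and `gram_rbf_sketch` are hypothetical re-implementations mirroring `gram_linear` and `gram_rbf` from the top of the notebook (the RBF version additionally clips tiny negative squared distances caused by floating-point round-off). The linear Gram matrix keeps each vector's squared norm on its diagonal, whereas the RBF kernel gives every example a self-similarity of 1 and depends only on pairwise distances, scaled by the median squared distance.

```python
import numpy as np

def gram_linear_sketch(x):
    """Linear kernel: entry (i, j) is the dot product of examples i and j."""
    return x @ x.T

def gram_rbf_sketch(x, threshold=1.0):
    """RBF kernel whose bandwidth is set from the median squared distance
    (the same heuristic as gram_rbf above)."""
    sq_norms = np.sum(x ** 2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2 * x @ x.T
    sq_dists = np.maximum(sq_dists, 0.0)  # guard against tiny negative round-off
    sq_median = np.median(sq_dists)
    return np.exp(-sq_dists / (2 * threshold ** 2 * sq_median))

rng = np.random.default_rng(0)
# six toy examples; the last three are scaled up to have larger norms
scales = np.array([1.0, 1.0, 1.0, 5.0, 5.0, 5.0])
x = rng.normal(size=(6, 4)) * scales[:, None]

k_lin = gram_linear_sketch(x)
k_rbf = gram_rbf_sketch(x)
print(np.round(np.diag(k_lin), 1))  # squared norms, typically much larger for scaled rows
print(np.round(np.diag(k_rbf), 1))  # self-similarity is always 1.0 under the RBF kernel
```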

##### <input type="checkbox"/> <font color='green'>**ToDo 14**</font>: Inspecting Gram matrices with Linear and RBF kernels

Run the code cells below to plot the Gram matrices for the MFCC and GloVe features over all word segments in our subset of recordings. We can compute similarity between examples using a simple linear kernel (the dot product), but we can also use more complex non-linear kernels like the RBF kernel. What differences do you observe in the Gram matrices of these two feature spaces? What effect does the choice of kernel function have?

In [None]:
# glove embeddings for all words (3 times the same paragraph)
with open('stella_glove_embs.pkl', 'rb') as f:
    glove_emb_dict = pickle.load(f)
glove_embs = np.vstack([glove_emb_dict[audio_id.split('_')[1]] for audio_id in audio_ids])

# MFCC features for all words (3 times the same paragraph, read by each speaker)
# with n_mels=18, librosa returns at most 18 coefficients; n_mfcc=18 makes this explicit
avg_MFCCs = np.vstack([librosa.feature.mfcc(y=wa, sr=SAMP_FREQ, n_mfcc=18, n_mels=18, n_fft=256).mean(axis=1) for wa in word_audios])

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(10,10))
axs[0,0].set_title('MFCCs, Linear kernel')
im = plot_gram_matrix(avg_MFCCs, speaker_subset_ids, axs[0,0], gram_fun=gram_linear)

axs[0,1].set_title('GloVe embeddings, Linear kernel')
im = plot_gram_matrix(glove_embs, speaker_subset_ids, axs[0,1], gram_fun=gram_linear)

axs[1,0].set_title('MFCCs, RBF kernel')
im = plot_gram_matrix(avg_MFCCs, speaker_subset_ids, axs[1,0], gram_fun=gram_rbf)

axs[1,1].set_title('GloVe embeddings, RBF kernel')
im = plot_gram_matrix(glove_embs, speaker_subset_ids, axs[1,1], gram_fun=gram_rbf)

plt.show()

## Gram matrices for Wav2Vec2 representations
We'll now proceed to use the same technique to analyze inter-example similarities at different layers of the Wav2Vec2 model. To do that, we first need to extract the audio frame representations for each of the word segments in our dataset. Running the code cells below will do that.
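Because the extraction code below works with frame indices rather than times, it may help to see the arithmetic involved. The sketch assumes the convolutional feature-encoder settings that are the defaults in Hugging Face's `Wav2Vec2Config` (kernel sizes (10, 3, 3, 3, 3, 2, 2) and strides (5, 2, 2, 2, 2, 2, 2)); the product of the strides should correspond to `model.config.inputs_to_logits_ratio` used below. Under these assumptions each frame advances 320 samples, i.e. 20 ms at 16 kHz. The helper names are hypothetical, not part of `transformers`.

```python
import math

# Assumed feature-encoder settings (Hugging Face Wav2Vec2Config defaults).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)
SAMPLE_RATE = 16000

def num_output_frames(num_samples):
    """Number of CNN output frames for an unpadded waveform of num_samples samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length

def time_to_frame_range(start_s, end_s):
    """Map a segment [start_s, end_s] in seconds to (start_frame, end_frame),
    rounding outwards as in the extraction code below."""
    frame_step = math.prod(CONV_STRIDES) / SAMPLE_RATE  # seconds per frame
    return math.floor(start_s / frame_step), math.ceil(end_s / frame_step)

print(math.prod(CONV_STRIDES))            # 320 samples per frame, i.e. 20 ms
print(num_output_frames(SAMPLE_RATE))     # 49 frames for one second of audio
print(time_to_frame_range(1.234, 1.567))  # (61, 79)
```

Since the unpadded convolutions yield 49 rather than 50 frames per second, mapping word boundaries to frames by dividing by 20 ms gives an alignment that is only approximate at the edges of each segment.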

### Extracting model hidden states for word segments

In [None]:
class SaveOutput:
    def __init__(self):
        self.outputs = defaultdict()

    def __call__(self, name):
        def hook(module, module_in, module_out):
            self.outputs[name] = module_out.detach()
        return hook

    def clear(self):
        self.outputs = defaultdict()

def get_frame_states_for_segments(model, feature_extractor, full_waveform, segment_start_times, segment_end_times):
  """
  Extract layerwise audio frame representations for segments out of a single full waveform,
  indicated by the provided start & end times.
  """
  ## convert segment start & end times to model frame indices
  time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate
  segment_start_frames = [int(np.floor(start_time / time_offset)) for start_time in segment_start_times]
  segment_end_frames = [int(np.ceil(end_time / time_offset)) for end_time in segment_end_times]

  ## register hooks at CNN output and Transformer embeds + layers
  save_output = SaveOutput()
  # the ASR-finetuned model (Wav2Vec2ForCTC) wraps a Wav2Vec2Model in its `wav2vec2` attribute,
  # while the pretrained & untrained models are Wav2Vec2Model instances themselves
  if isinstance(model, Wav2Vec2ForCTC):
      w2v2 = model.wav2vec2
  elif isinstance(model, Wav2Vec2Model):
      w2v2 = model
  else:
      raise TypeError(f'Unsupported model type: {type(model).__name__}')
  # keep the hook handles, so the hooks can be removed after the forward pass
  hook_handles = [
      w2v2.feature_extractor.conv_layers[-1].activation.register_forward_hook(save_output('CNN')),
      w2v2.encoder.layer_norm.register_forward_hook(save_output('embeds')),
  ]
  for i, enc_layer in enumerate(w2v2.encoder.layers):
      hook_handles.append(enc_layer.final_layer_norm.register_forward_hook(save_output(f'T{i+1}')))

  # prepare inputs
  inputs = feature_extractor(full_waveform,
                             sampling_rate=SAMP_FREQ,
                             return_tensors="pt",
                             padding='longest').input_values.to(DEVICE)
  # forward pass
  model.eval()
  model.to(DEVICE)
  with torch.no_grad():
    model(inputs, output_hidden_states=True, output_attentions=False)

  # remove the hooks again: otherwise every call adds another set of hooks to the same
  # model, and the old hooks keep writing to (and keeping alive) earlier SaveOutput objects
  for handle in hook_handles:
    handle.remove()

  ## store only the frame states between the segment starts & ends
  layer_segment_states = {layer: [] for layer in save_output.outputs.keys()}
  for layer in save_output.outputs.keys():
    # for the CNN output
    if layer.startswith('C'):
      for start_frame, end_frame in zip(segment_start_frames, segment_end_frames):
        layer_segment_states[layer].append(save_output.outputs[layer][:, :, start_frame:end_frame].swapaxes(1, 2).detach().cpu().numpy())
    # for the Transformer embeds + layer representations
    else:
      for start_frame, end_frame in zip(segment_start_frames, segment_end_frames):
        layer_segment_states[layer].append(save_output.outputs[layer][:, start_frame:end_frame, :].detach().cpu().numpy())

  del save_output
  gc.collect()
  torch.cuda.empty_cache()

  return layer_segment_states

def frame_states_over_dataset(models, feature_extractor, dataset):
  """
  Extract frame states for a range of models (provided in the `models` dict, which all work with feature_extractor),
  over all annotated word segments in the provided dataset (dict organized by speaker_id with audio, start_times, and end_times for
  each speaker, as in annotated_speaker_subset)
  """
  model_frame_states = {}
  N_speakers = len(dataset)
  for model_id, model in models.items():
    model_name = model.config._name_or_path if model.config._name_or_path else 'randomly initialized'
    print(f'Extracting states from {model_id} ({model_name})...')
    layer_segment_states = []
    for i, (speaker_id, speaker_data) in enumerate(dataset.items()):
      print(f'\tSpeaker {i+1}/{N_speakers} ({speaker_id})...')
      full_waveform = speaker_data['audio']
      word_start_times = speaker_data['start_times']
      word_end_times = speaker_data['end_times']
      layer_segment_states.append(get_frame_states_for_segments(model, feature_extractor, full_waveform, word_start_times, word_end_times))
    model_frame_states[model_id] = {
        layer: [segst for i in range(len(layer_segment_states)) for segst in layer_segment_states[i][layer]]
        for layer in layer_segment_states[0].keys()
    }
    del layer_segment_states
    gc.collect()
    torch.cuda.empty_cache()
  return model_frame_states

In [None]:
# extract frame states for all models in our comparison set and all speakers in our speaker subset
model_frame_states = frame_states_over_dataset(models, feature_extractor, annotated_speaker_subset)

### Computing mean-pooled word representations

##### <input type="checkbox"/> <font color='green'>**ToDo 15**</font>: Computing mean-pooled representations over word segments

To be able to use our kernel functions for computing similarities between words, we first need to create word representations out of the sets of frame states we have for each word segment. For now, we will simply take the mean over all frames within a word segment to create equally shaped vectors for each word segment in our dataset.  
Complete the function below to return a matrix with the mean frame state vector for each word segment.

In [None]:
def get_mean_frame_embs(frame_states):
  """
  :param frame_states:        list of frame states for audio segments, extracted from a single
                              model component; i.e. list of N_segments elements, each a np.array
                              of shape (1, N_frames, D_component), where D_component is 512 for
                              CNN output, and 768 (base) or 1024 (large) for Transformer representations

  :return mean_frame_embs:    np.array of shape (N_segments, D_component), containing the mean-
                              pooled frame embedding for every segment (i.e. the mean over frame
                              states across the segment)
  """
  raise NotImplementedError
  mean_frame_embs = None  # YOUR CODE HERE
  return mean_frame_embs

##### <input type="checkbox"/> <font color='green'>**ToDo 16**</font>: Inspecting Gram matrices of Wav2Vec2 embeddings

Now we are able to construct between-word similarity matrices for all models and layers of interest. Use the code cell below to inspect the Gram matrices for a few different models and layers. Also feel free to experiment with linear and RBF kernels again!

In [None]:
# choose from our set of loaded models from which we have extracted frame states
# ('base_unt', 'base_pret', 'base_ft', 'large_unt', 'large_pret', 'large_ft')
model_id = 'base_ft'

# choose from 'CNN', 'embeds', 'T{i}' (i = 1..12 for base and 1..24 for large)
model_component = 'CNN'

# choose between gram_linear and gram_rbf
gram_fun = gram_linear

# get word representations
model_embs = get_mean_frame_embs(model_frame_states[model_id][model_component])

# plot gram matrix for these word representations
fig, ax = plt.subplots()
im = plot_gram_matrix(model_embs, speaker_subset_ids, ax, gram_fun=gram_fun)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.7])
fig.colorbar(im, cax=cbar_ax)
plt.show()

## CKA similarity between model layers

Now that we have between-word similarity matrices for all models and components, we can compute the similarities between all those matrices with CKA! This allows us, for example, to compare different layers of the same architecture, and to detect differences in processing between models, as done by [Kornblith et al. (2019)](https://proceedings.mlr.press/v97/kornblith19a/kornblith19a.pdf).  
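A caveat worth keeping in mind when reading these plots: our dataset contains only a few hundred word segments (one short paragraph read by three speakers), while Transformer layers have 768 or 1024 dimensions. With few examples relative to features, the default (biased) CKA estimator can report substantial similarity even between unrelated representations; the `debiased=True` option of the CKA functions at the top of this notebook uses an unbiased HSIC estimator instead. The self-contained sketch below illustrates this on independent random Gaussian features (toy data, not model states). `linear_cka`, `hsic_biased`, and `hsic_unbiased` are hypothetical helper names; the unbiased estimator follows the Song et al. (2007) formula referenced in the code comments above.

```python
import numpy as np

def centered_linear_gram(x):
    """Linear Gram matrix of column-centered features."""
    x = x - x.mean(axis=0, keepdims=True)
    return x @ x.T

def hsic_biased(k, l):
    """Biased HSIC for centered Gram matrices, up to a constant factor."""
    return np.sum(k * l)

def hsic_unbiased(k, l):
    """Unbiased HSIC estimator (Song et al.), up to a constant factor."""
    n = k.shape[0]
    k, l = k.copy(), l.copy()
    np.fill_diagonal(k, 0.0)
    np.fill_diagonal(l, 0.0)
    ones = np.ones(n)
    return (np.sum(k * l)
            + (ones @ k @ ones) * (ones @ l @ ones) / ((n - 1) * (n - 2))
            - 2.0 / (n - 2) * (ones @ k @ l @ ones))

def linear_cka(x, y, debiased=False):
    """Linear CKA; the constant HSIC normalization factors cancel in the ratio."""
    hsic = hsic_unbiased if debiased else hsic_biased
    k, l = centered_linear_gram(x), centered_linear_gram(y)
    return hsic(k, l) / np.sqrt(hsic(k, k) * hsic(l, l))

rng = np.random.default_rng(0)
n, d = 200, 768                 # few examples, many features (toy sizes)
x = rng.normal(size=(n, d))
y = rng.normal(size=(n, d))     # independent of x: the true similarity is zero

print(linear_cka(x, y))                 # biased estimate: far from zero
print(linear_cka(x, y, debiased=True))  # debiased estimate: close to zero
```

In this toy setting the biased estimate is inflated because both Gram matrices are dominated by their large diagonals. Whether this matters for a given comparison depends on the data; recomputing with `cka(..., debiased=True)` is one way to check.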

##### <input type="checkbox"/> <font color='green'>**ToDo 17**</font>: Inspecting between-layer similarities across models

Run the cells below to plot the between-layer similarities for each of the models in our comparison set. What differences stand out between the untrained, pretrained and finetuned versions, and between the base and large models?

In [None]:
def compute_model_layer_cka(first_model_id, second_model_id, gram_fun=gram_linear):
  first_model_components = list(model_frame_states[first_model_id].keys())
  second_model_components = list(model_frame_states[second_model_id].keys())

  # compute each component's word-level Gram matrix once, instead of once per pair
  first_grams = {comp: gram_fun(get_mean_frame_embs(model_frame_states[first_model_id][comp]))
                 for comp in first_model_components}
  second_grams = {comp: gram_fun(get_mean_frame_embs(model_frame_states[second_model_id][comp]))
                  for comp in second_model_components}

  cka_df = pd.DataFrame(index=first_model_components, columns=second_model_components, dtype=float)
  for comp1 in first_model_components:
    for comp2 in second_model_components:
      cka_df.loc[comp1, comp2] = cka(first_grams[comp1], second_grams[comp2])

  return cka_df

def plot_model_layer_cka(cka_df, ax):
  first_model_components = list(cka_df.index)
  second_model_components = list(cka_df.columns)

  # fix the colour scale to [0, 1] so that heatmaps of different models are directly comparable
  sns.heatmap(cka_df.values.astype(float), ax=ax, vmin=0, vmax=1)
  ax.set_yticks(np.arange(len(first_model_components))+0.5, first_model_components, rotation=0)
  ax.set_xticks(np.arange(len(second_model_components))+0.5, second_model_components, rotation=90)

In [None]:
cka_base_models = {
    'base_unt': compute_model_layer_cka('base_unt', 'base_unt'),
    'base_pret': compute_model_layer_cka('base_pret', 'base_pret'),
    'base_ft': compute_model_layer_cka('base_ft', 'base_ft')
}
cka_large_models = {
    'large_unt': compute_model_layer_cka('large_unt', 'large_unt'),
    'large_pret': compute_model_layer_cka('large_pret', 'large_pret'),
    'large_ft': compute_model_layer_cka('large_ft', 'large_ft')
}

In [None]:
# plot base models
fig, axs = plt.subplots(1,3, figsize=(16,4))
for i, (model_id, cka_df) in enumerate(cka_base_models.items()):
  plot_model_layer_cka(cka_df, axs[i])
  axs[i].set_title(model_id)
plt.show()

# plot large models
fig, axs = plt.subplots(1,3, figsize=(16,4))
for i, (model_id, cka_df) in enumerate(cka_large_models.items()):
  plot_model_layer_cka(cka_df, axs[i])
  axs[i].set_title(model_id)
plt.show()

## CKA similarity of Wav2Vec2 to MFCCs and GloVe

Next to computing similarities within and across model representations, we can also use CKA to compute similarity between model embeddings and more interpretable features.
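This works even though the feature spaces differ in size, because CKA only compares the $n \times n$ Gram matrices (e.g. low-dimensional MFCC averages versus 768-dimensional layer embeddings). The self-contained sketch below uses random toy data; `linear_cka` is a hypothetical helper re-implementing biased linear CKA in feature space. It embeds 18-dimensional vectors isometrically into 768 dimensions and shows that their linear CKA is 1, since all pairwise dot products are preserved.

```python
import numpy as np

def linear_cka(x, y):
    """Biased linear CKA in feature space; x and y may have different widths."""
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    return (np.linalg.norm(y.T @ x) ** 2
            / (np.linalg.norm(x.T @ x) * np.linalg.norm(y.T @ y)))

rng = np.random.default_rng(0)
n = 150
low_dim = rng.normal(size=(n, 18))     # toy stand-in for a low-dimensional feature space
# q has orthonormal columns (q.T @ q = I), so low_dim @ q.T preserves all pairwise
# dot products, and therefore the Gram matrix, while living in 768 dimensions
q, _ = np.linalg.qr(rng.normal(size=(768, 18)))
high_dim = low_dim @ q.T               # shape (150, 768)
unrelated = rng.normal(size=(n, 768))  # independent random features

print(linear_cka(low_dim, high_dim))   # 1.0 up to floating-point error
print(linear_cka(low_dim, unrelated))  # lower
```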

### Layerwise CKA similarity to interpretable features

##### <input type="checkbox"/> <font color='green'>**ToDo 18**</font>: Inspecting layerwise similarities to MFCCs and GloVe across models

Run the code cells below to obtain layerwise CKA similarities between Wav2Vec2 representations and MFCC vs. GloVe features. How do these layerwise patterns differ between models?

In [None]:
def compute_layerwise_feature_cka(model_id, features, frame_agg_fun=get_mean_frame_embs, gram_fun=gram_linear):
  feature_cka = {
      model_component: cka(gram_fun(frame_agg_fun(model_frame_states[model_id][model_component])),
                           gram_fun(features))
      for model_component in model_frame_states[model_id].keys()
  }
  return feature_cka

def plot_layerwise_feature_cka(model_ids, layerwise_cka_sims, ax):
  for model_id, layerwise_cka in zip(model_ids, layerwise_cka_sims):
    ax.plot(list(layerwise_cka.values()), color=colors[model_id], linestyle=linestyles[model_id], label=model_id)
  ax.set_xticks(range(len(layerwise_cka.keys())), list(layerwise_cka.keys()), rotation=90)
  ax.legend(frameon=False)
  ax.set_ylabel('CKA similarity')

In [None]:
base_layerwise_MFCC_sims = [
    compute_layerwise_feature_cka(model_id, avg_MFCCs)
    for model_id in base_models
]
large_layerwise_MFCC_sims = [
    compute_layerwise_feature_cka(model_id, avg_MFCCs)
    for model_id in large_models
]
base_layerwise_GloVe_sims = [
    compute_layerwise_feature_cka(model_id, glove_embs)
    for model_id in base_models
]
large_layerwise_GloVe_sims = [
    compute_layerwise_feature_cka(model_id, glove_embs)
    for model_id in large_models
]

fig, axs = plt.subplots(1, 4, figsize=(20,4))
axs[0].set_title('Layerwise similarity to MFCCs (base models)')
plot_layerwise_feature_cka(base_models, base_layerwise_MFCC_sims, ax=axs[0])
axs[1].set_title('Layerwise similarity to MFCCs (large models)')
plot_layerwise_feature_cka(large_models, large_layerwise_MFCC_sims, ax=axs[1])
axs[2].set_title('Layerwise similarity to GloVe (base models)')
plot_layerwise_feature_cka(base_models, base_layerwise_GloVe_sims, ax=axs[2])
axs[3].set_title('Layerwise similarity to GloVe (large models)')
plot_layerwise_feature_cka(large_models, large_layerwise_GloVe_sims, ax=axs[3])
plt.show()

### Comparing mean and middle frame embeddings

##### <input type="checkbox"/> <font color='green'>**ToDo 19**</font>: Computing middle frame embeddings

In their paper about word-level information encoded in self-supervised speech model representations, [Pasad et al. (2024)](https://doi.org/10.1162/tacl_a_00656) investigate, amongst other things, which frames within a word segment seem most informative. One remarkable result from this analysis is that the middle frame of the word segment often seems as informative on its own as the mean-pooled word representation averaging all frames within a word segment (see e.g. Figure 4). Here we will try to replicate this analysis by comparing how similar the mean-pooled representations we have used so far and the middle frame embeddings are to GloVe embeddings.  
Complete the function below to return a matrix with the middle frame state vector for each word segment. Then examine the layerwise patterns of similarities to GloVe for mean and middle frame embeddings. Are they in line with the findings in Pasad et al.?


In [None]:
def get_middle_frame_embs(frame_states):
  """
  :param frame_states:        list of frame states for audio segments, extracted from a single
                              model component; i.e. list of N_segments elements, each a np.array
                              of shape (1, N_frames, D_component), where D_component is 512 for
                              CNN output and 768 (base) or 1024 (large) for Transformer
                              representations

  :return middle_frame_embs:  np.array of shape (N_segments, D_component), containing the middle
                              frame state for every segment
  """
  raise NotImplementedError
  middle_frame_indices = ...  # YOUR CODE HERE
  middle_frame_embs = ...  # YOUR CODE HERE
  return middle_frame_embs

In [None]:
def plot_emb_comparison(model_id, features, ax):
  mean_emb_cka_sims = compute_layerwise_feature_cka(model_id, features, frame_agg_fun=get_mean_frame_embs)
  middle_emb_cka_sims = compute_layerwise_feature_cka(model_id, features, frame_agg_fun=get_middle_frame_embs)
  ax.plot(list(mean_emb_cka_sims.values()), linestyle='solid', color='black', label='mean-pooled')
  ax.plot(list(middle_emb_cka_sims.values()), linestyle='dashed', color='black', label='middle frame')
  ax.set_ylabel('CKA similarity')
  ax.set_xticks(range(len(mean_emb_cka_sims.keys())), list(mean_emb_cka_sims.keys()), rotation=90)

fig, axs = plt.subplots(1, len(models), figsize=(28,4), sharey=True)
for i, model_id in enumerate(models.keys()):
  plot_emb_comparison(model_id, glove_embs, axs[i])
  axs[i].set_title(model_id)
  if i == 0:
    axs[i].legend(frameon=False, loc='upper left')
plt.suptitle('Layerwise similarity to GloVe embeddings', y=1)
plt.show()

##### <input type="checkbox"/> <font color='red'>**ToSubmit 5**</font>: Interpreting differences between models with CKA

Choose your favourite model comparison plot out of the ones generated above to include in your report.  
Add a brief (max. 2 sentences) caption explaining what insights this analysis provides on differences in speech sound processing between the untrained, pretrained, and finetuned versions of the base and large Wav2Vec2 models.

# 🚀 Mini-project starting points

This week you have learned about several techniques for analyzing the internals of Audio Transformer models. We hope you are excited to apply these and/or other techniques to further study speech and sound processing models for your mini-project! Below, we provide a few possible starting points for inspiration. Of course, you are also free to come up with your own ideas. But do confirm with your TA that they are realistic to pursue within the mini-project timeframe!

**1. Temporal generalization and context-invariant phonetic encoding**  
In week 1, you applied Temporal Generalization Matrices to compute the cross-accuracy of probes across different model layers. Temporal Generalization analysis has also been used to study the encoding of phoneme sequences across time in neural speech models and the human brain; see [Oli et al. (2024)](https://arxiv.org/pdf/2405.08237) for a study replicating neuroscience findings in a self-supervised RNN model. What is the window of phoneme decodability in Wav2Vec2, and how context-invariant are its phonetic encodings?
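A generalization matrix of this kind can be sketched as follows. The same function covers the temporal variant if each list entry holds features at a different time offset from phoneme onset instead of a different layer. This is a minimal sketch, assuming scikit-learn logistic-regression probes and a single stratified train/test split; `generalization_matrix` is a hypothetical name and not the week 1 code, whose probes may differ.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def generalization_matrix(feats, labels, test_size=0.25, seed=0):
    """Train a probe on each condition and test it on every condition.

    feats:  list of (n_samples, d) arrays, one per layer (or per time offset);
            all must share the same d so a probe can be applied across them
    labels: (n_samples,) class labels, e.g. phoneme IDs
    Returns acc with acc[i, j] = accuracy of the probe trained on feats[i]
    and evaluated on the held-out samples of feats[j].
    """
    labels = np.asarray(labels)
    train_idx, test_idx = train_test_split(
        np.arange(len(labels)), test_size=test_size,
        random_state=seed, stratify=labels,
    )
    n = len(feats)
    acc = np.zeros((n, n))
    for i, train_feats in enumerate(feats):
        probe = LogisticRegression(max_iter=1000)
        probe.fit(train_feats[train_idx], labels[train_idx])
        for j, test_feats in enumerate(feats):
            # Same held-out samples for every j, so entries in a row are comparable
            acc[i, j] = probe.score(test_feats[test_idx], labels[test_idx])
    return acc
```

Roughly, high accuracy far off the diagonal points to a code that is shared across layers (or time points), while accuracy concentrated on the diagonal points to information that is re-encoded from one layer (or time point) to the next.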

**2. Effects of noise or speaker accent on model representations and transcription accuracy**  
In the first notebook of this workshop, we saw that adding noise to the audio signal decreases Wav2Vec2's transcription accuracy (as measured by CER and WER). Similar performance decreases can be observed for other types of audio that the model wasn't exposed to much during training, such as non-native accented speech or speech from speaker groups that are underrepresented in the training dataset (e.g. women, children, elderly speakers). It would be great if we could identify the Transformer components responsible for such biases, so that we might be able to mitigate them using targeted finetuning (as has been done for text models; e.g. [Chintam et al., 2023](https://aclanthology.org/2023.blackboxnlp-1.29/)). Possible steps towards this:
- Can you localize which model components show the most representational sensitivity to changes in noise / speaker identity / speaker accent using similarity-based interpretability techniques? (e.g. using differences in similarities between word representations for native- and non-native accented speech)
- Can you localize which model components are causally involved in the increases in WER/CER observed for noisy audio / accented speakers etc.? (e.g. using causal interventions)
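One way to quantify differences in similarities between word representations is representational similarity analysis (RSA): build a word-by-word dissimilarity matrix for each condition and correlate the two. The sketch below reflects our own choices (cosine distances, Spearman correlation) rather than a method prescribed in this workshop; CKA would be an equally reasonable measure. It assumes the same words, in the same order, are available in both conditions (e.g. native vs. non-native speakers), and both function names are hypothetical.

```python
import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr


def rsa_similarity(reps_a, reps_b):
    """Second-order (RSA) similarity between two sets of word representations.

    reps_a, reps_b: (n_words, d) arrays; row i must be the same word in both,
    e.g. spoken by a native (a) and a non-native (b) speaker.
    """
    # Condensed vectors of pairwise cosine distances between words
    rdm_a = pdist(reps_a, metric="cosine")
    rdm_b = pdist(reps_b, metric="cosine")
    # Rank correlation between the two dissimilarity structures
    return float(spearmanr(rdm_a, rdm_b)[0])


def layerwise_sensitivity(layers_a, layers_b):
    """1 - RSA similarity per layer: higher values mean the layer's
    word-similarity structure changes more between the two conditions."""
    return [1.0 - rsa_similarity(a, b) for a, b in zip(layers_a, layers_b)]
```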

**3. Language-specific phonology and perceptual narrowing**  
Human speech perception is known to become tuned to language-specific experience over development: we start out being able to distinguish speech sounds from many different languages, but over time we become worse at distinguishing sound contrasts that are not used in our native language. This phenomenon is known as [perceptual narrowing](https://en.wikipedia.org/wiki/Perceptual_narrowing#Phoneme_distinction).  
Can you find evidence of language-specific speech perception in Wav2Vec2 models trained on different languages?  
Relevant resources:
- French [LeBenchmark](https://huggingface.co/LeBenchmark) set of Wav2Vec2 models trained on different amounts of French speech data
- English [set of Wav2Vec2 checkpoints](https://huggingface.co/techsword/wav2vec2-base-english-librispeech730h-checkpoints/tree/main) over training released by [Shen et al. (2024)](https://aclanthology.org/2024.naacl-long.239/)
- The [VoxAngeles](https://github.com/pacscilab/voxangeles) Corpus with phonetic transcriptions for 95 different languages, and the [CommonPhone](https://zenodo.org/records/5846137) dataset with phonetic transcriptions for a smaller set of widely spoken languages
- ABX tests of phone discrimination ability as used by [Millet & Dunbar (2022)](https://aclanthology.org/2022.acl-long.523/) and their [collection of stimuli](https://docs.cognitive-ml.fr/perceptimatic/perceptimatic_description.html) targeted at French vs. English speech sound contrasts
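The core of an ABX score can be sketched as below: for tokens A and X of the same phone category and a token B of another category, a trial counts as correct when X is closer to A than to B, so 0.5 is chance and 1.0 is perfect discrimination. This is a simplified sketch under our own assumptions: each token is a single pooled vector compared with cosine distance, whereas ABX evaluations on speech typically compare frame sequences (e.g. with dynamic time warping). The function names are hypothetical.

```python
import numpy as np


def _cosine_dist(P, Q):
    """Pairwise cosine distances between rows of P and rows of Q."""
    P = P / np.linalg.norm(P, axis=1, keepdims=True)
    Q = Q / np.linalg.norm(Q, axis=1, keepdims=True)
    return 1.0 - P @ Q.T


def _abx_one_way(cat_x, cat_other):
    """Accuracy with X and A drawn from cat_x and B drawn from cat_other."""
    d_ax = _cosine_dist(cat_x, cat_x)      # d_ax[a, x]
    d_bx = _cosine_dist(cat_other, cat_x)  # d_bx[b, x]
    n = len(cat_x)
    per_x = []
    for x in range(n):
        a_idx = np.delete(np.arange(n), x)  # A must be a different token than X
        da = d_ax[a_idx, x][:, None]        # (n - 1, 1)
        db = d_bx[:, x][None, :]            # (1, n_b)
        # Correct if A is closer to X than B is; ties count as half
        per_x.append(np.mean((da < db) + 0.5 * (da == db)))
    return float(np.mean(per_x))


def abx_accuracy(cat_a, cat_b):
    """Symmetric ABX accuracy for two (n_tokens, d) arrays of phone tokens."""
    return 0.5 * (_abx_one_way(cat_a, cat_b) + _abx_one_way(cat_b, cat_a))
```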