# Workshop Week 1: Introduction to Posthoc Interpretability

*Interpretability & Explainability in AI, MSc A.I., University of Amsterdam, June 2022*

## Lab 2: Probing Audio Models

In the previous notebook, you trained probes on the hidden states of a fine-tuned RoBERTa model to gain insight into how sentiment is represented across its layers. In this notebook, you will perform a similar analysis on a self-supervised **speech model: Wav2Vec2**. This model learns powerful speech representations from raw audio data and can be applied to many downstream tasks, including **Automatic Speech Recognition (ASR)**.

Before the rise of deep learning, ASR was performed using a pipeline of separate components, each handling one part of the task. This traditional ASR pipeline consists of the following steps:

1.   **Feature extraction**: extracting relevant time-frequency information from the raw audio signal.
2.   **Acoustic modelling**: mapping the extracted audio features to a sequence of [phonemes](https://prowritingaid.com/phoneme), the smallest units of sound that distinguish one word from another. Phonemes refer only to sounds and do not necessarily correspond to written letters. For example, the words *fit*, *phone*, and *laugh* all contain the same phoneme: /f/. Since the same phoneme can be spelled in several different ways (*f*, *ph*, *gh*), it is useful for an ASR model to have an intermediate phoneme representation.
3.   **Language modelling**: mapping the sequence of phonemes to a sequence of written words.

Self-supervised speech models such as Wav2Vec2, once fine-tuned for ASR, are able to perform all of the above steps in an **end-to-end** fashion. In this assignment, we will look for evidence that the second subcomponent, acoustic modelling, is carried out implicitly inside the model. We will therefore probe the layers of the model for **phoneme information**.







## Configuration

Don't forget to enable the GPU runtime at the top! (Runtime -> Change runtime type)

Install necessary packages:

In [None]:
!pip install datasets==1.18.3 # we need this specific version of datasets in order to load the TIMIT data
!pip install transformers

Select the device to run on (the GPU, if one is available):

In [None]:
from tqdm import tqdm, trange  # progress bars (optional)
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Set a random seed for reproducibility of the experiments:

In [None]:
import numpy as np
import random

def set_seed(seed):
    """Set the random seed for Python, NumPy and PyTorch (pass -1 to pick a random seed)."""
    if seed == -1:
        seed = random.randint(0, 1000)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # CUDA seeds (these calls are silently ignored when no GPU is available)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN speed for deterministic results
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(42)

# The TIMIT dataset

We will use the [TIMIT Acoustic-Phonetic Continuous Speech Corpus](https://catalog.ldc.upenn.edu/LDC93S1), which contains recordings of 630 speakers from eight major dialect regions of American English. Each speaker read aloud ten phonetically rich sentences, designed to elicit a wide variety of speech sounds. The corpus includes the raw waveform of each spoken sentence, sampled at 16 kHz, together with time-aligned orthographic, word-level and phonetic transcriptions.

<font color='green'>**ToDo1**</font>

Load the TIMIT corpus by running the cell below.

In [None]:
from datasets import load_dataset, load_metric

timit = load_dataset("timit_asr", "clean").shuffle(seed=42)

# Print number of train and test samples
print(len(timit['train']), len(timit['test']))

<font color='green'>**ToThink1**</font>

Look at some examples of the data and try to understand the annotations.

What information is annotated in ```phonetic_detail```? Tip: an explanation of the labels can be found [here](https://catalog.ldc.upenn.edu/docs/LDC93S1/PHONCODE.TXT).



In [None]:
# YOUR CODE HERE

# The Wav2Vec2 Model

The Wav2Vec2 model was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, and Michael Auli.

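The model takes raw audio waveforms as input. A convolutional feature encoder turns the waveform into a sequence of latent frames, each covering about 25 milliseconds of audio, with a new frame every 20 milliseconds. During pre-training, a proportion of these frames is masked, and for each masked position the model has to identify the correct (quantized) speech representation among a set of distractors. In doing so, the model learns powerful speech representations in a self-supervised manner.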
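Each frame that the convolutional feature encoder outputs summarizes a short window of audio, and consecutive frames are a fixed number of samples apart. The sketch below computes both numbers, plus the number of frames produced for a given input length, assuming the kernel sizes `(10, 3, 3, 3, 3, 2, 2)` and strides `(5, 2, 2, 2, 2, 2, 2)` of the wav2vec2-base configuration (check `model.config.conv_kernel` and `model.config.conv_stride` to confirm). The helper names are illustrative and not part of the transformers API.

```python
# Kernel sizes and strides of the seven 1-D convolutions in the feature encoder
# (assumed from the wav2vec2-base configuration; compare with model.config).
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)
SAMPLE_RATE = 16000


def num_output_frames(num_samples, kernels=CONV_KERNEL, strides=CONV_STRIDE):
    """Number of frames produced for an input of `num_samples` (no padding)."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        length = (length - kernel) // stride + 1
    return length


def receptive_field_and_stride(kernels=CONV_KERNEL, strides=CONV_STRIDE):
    """Receptive field and hop size (both in samples) of one output frame."""
    field, jump = 1, 1
    for kernel, stride in zip(kernels, strides):
        field += (kernel - 1) * jump  # each layer widens the window by (k - 1) input hops
        jump *= stride                # distance between neighbouring outputs grows by the stride
    return field, jump


field, hop = receptive_field_and_stride()
print(f"receptive field: {field} samples = {1000 * field / SAMPLE_RATE:.0f} ms")
print(f"frame stride:    {hop} samples = {1000 * hop / SAMPLE_RATE:.0f} ms")
print("frames for 1 s of audio:", num_output_frames(SAMPLE_RATE))
```

At 16 kHz this gives a 400-sample (25 ms) window with a 320-sample (20 ms) hop, i.e. 49 frames per second of audio, so a 3-second TIMIT sentence yields 149 frames.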

After pre-training, the model can be fine-tuned for several downstream tasks. We will investigate a model version that is fine-tuned for **Automatic Speech Recognition**, i.e. predicting the written transcription that corresponds to the spoken input. More specifically, the model was fine-tuned to predict a character (or a special blank token) for each of the frames described above. These frame-level predictions are then collapsed into a well-formed transcription using [Connectionist Temporal Classification](https://distill.pub/2017/ctc/) (CTC).
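Greedy CTC decoding takes the most likely symbol for every frame, merges consecutive repeats, and then removes the blank token. A minimal sketch in plain Python, assuming the per-frame predictions are already available as characters and that the blank is written `_` (in the `facebook/wav2vec2-base-960h` vocabulary the blank is the padding token `<pad>` and word boundaries are `|`; `processor.batch_decode` takes care of this for you):

```python
from itertools import groupby

BLANK = "_"  # illustrative blank symbol; wav2vec2-base-960h uses "<pad>"


def ctc_collapse(frame_symbols, blank=BLANK):
    """Greedy CTC post-processing: merge repeated symbols, then drop blanks."""
    merged = [symbol for symbol, _ in groupby(frame_symbols)]
    return [symbol for symbol in merged if symbol != blank]


# One predicted character per frame; "|" marks a word boundary.
frames = list("HH_EE_LL_LLOO||WW_OO_RR_LL_DD")
print("".join(ctc_collapse(frames)).replace("|", " "))  # HELLO WORLD
```

The blank between the two `LL` runs is what lets the doubled letter in *HELLO* survive the merge step.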





<font color='green'>**ToThink2**</font>

Read [this blog post](https://jonathanbgn.com/2021/09/30/illustrated-wav2vec-2.html) about Wav2Vec2 and make sure you understand the different components of the model.

<font color='green'>**ToDo2**</font>

Load the model and processor by running the cell below. Examine the architecture carefully.

In [None]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Load model and processor
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Set model to evaluation mode
model.eval()
model.to(DEVICE)

## Prepare model input and analyze transcriptions

<font color='green'>**ToDo3**</font>

The Wav2Vec 2.0 model takes raw waveforms as input. Select one waveform from the data, plot the signal and listen to the audio.

In [None]:
import matplotlib.pyplot as plt
import IPython.display as ipd
import numpy as np
import librosa
import librosa.display  # waveshow lives in this submodule, which must be imported explicitly

# Select a data item you want to examine
item_index = 0

# Retrieve the waveform from the data
raw_waveform = # YOUR CODE HERE

# Plot the raw waveform
plt.figure(figsize=(15,5))
librosa.display.waveshow(raw_waveform, sr=16000, alpha=0.4)

# Print corresponding text
text = # YOUR CODE HERE
print("Text:", text)

# Listen to the audio
ipd.Audio(data=np.asarray(raw_waveform), autoplay=False, rate=16000)

<font color='green'>**ToDo4**</font>

Let the model generate a transcription for the waveform. Are there any mistakes?

<font color='red'>**ToSubmit1**</font>

Please submit your plot of the waveform signal, together with the input text and the transcription that Wav2Vec2 generated for that particular waveform.

In [None]:
# Tokenize
input_values = processor(raw_waveform, sampling_rate=16000, return_tensors="pt").input_values  # shape: (batch_size=1, num_samples)
print("Input shape:", input_values.shape)

# Generate transcription with the model
# YOUR CODE HERE

# Extract hidden states

We will now extract the hidden states from the model's Transformer layers. These hidden states will serve as the training and evaluation data for our probing classifiers.

Concretely, we will perform the following steps to achieve this:

1.  **Prepare input**: Retrieve the raw waveforms (i.e. audio arrays) from the TIMIT corpus and process them using the Wav2Vec2 processor.
2.  **Forward pass**: Pass the waveforms through the model with ```output_hidden_states``` set to True.
3. **Save hidden states**: Save the hidden states in a dictionary, organized by Transformer layer index and waveform index. Each waveform will have a list of frame-level hidden states.
4. **Sort hidden states per phoneme class**: Sort the hidden states by phoneme class using the time-aligned transcriptions from TIMIT.



<font color='green'>**ToDo5**</font>

Finish the function below to extract the hidden states from the Transformer layers of the model. Hint: when the forward pass is run with ```output_hidden_states=True```, the hidden states are available as the ```.hidden_states``` attribute of the model output.



In [None]:
def extract_hidden_states(model, processor, inputs, num_layers):
    '''
    Extract hidden states from Wav2Vec 2.0 transformer layers.
    :param model: Wav2Vec 2.0 model
    :param processor: Wav2Vec 2.0 processor
    :param inputs: list of TIMIT instances (i.e. timit['train'] or timit['test'])
    :param num_layers: number of hidden-state layers to save (Transformer input plus Transformer layers)
    :return: dictionary containing frame-level hidden states saved per transformer layer and per waveform
    '''

    # Get waveforms
    waveforms = [instance["audio"]["array"] for instance in inputs]

    # Here we will save all frame-level hidden states, sorted by layer and waveform
    frame_states = {
        layer_idx: {
            waveform_idx: []
            for waveform_idx in range(len(waveforms))
        }
        for layer_idx in range(num_layers)
    }

    for waveform_idx, waveform in enumerate(waveforms):

        print(f'Extracting hidden states from waveform {waveform_idx + 1} out of {len(waveforms)} waveforms...')

        # Process waveform using the Wav2Vec2 processor (tensor of shape (1, num_samples))
        processed_input = processor(waveform, sampling_rate=16000, return_tensors="pt", padding='longest').input_values

        with torch.no_grad():
            # processed_input is already a tensor, so move it to the device instead of re-wrapping it
            input_tensor = processed_input.to(DEVICE)

            # forward pass
            model_output = # YOUR CODE HERE

            # get all hidden outputs
            transformer_layers = # YOUR CODE HERE

        # Save frame-level hidden states, organized by layer and waveform.
        # Each layer output has shape (batch_size=1, num_frames, hidden_size).
        for layer_idx, layer in enumerate(transformer_layers):
            for frame in layer[0]:
                frame_states[layer_idx][waveform_idx].append(frame.cpu())

    return frame_states

Function to sort the hidden states by phoneme class using the time-aligned transcriptions from TIMIT:

In [None]:
import math
from collections import defaultdict

def sort_states_per_phoneme(data, frame_states, num_layers):
    '''
    Group frame-level hidden states by phoneme label, per layer.
    :param data: TIMIT instances the hidden states were extracted from (same order as in frame_states)
    :param frame_states: output of extract_hidden_states
    :param num_layers: number of layers in frame_states
    :return: dictionary mapping layer index -> phoneme label -> list of frame-level hidden states
    '''

    frame_states_per_phoneme = {
        layer_idx: defaultdict(list)
        for layer_idx in range(num_layers)
    }

    for layer_idx, layer in frame_states.items():

        for waveform_idx, waveform in layer.items():

            # retrieve phoneme annotation for the current waveform
            phonemes = data[waveform_idx]["phonetic_detail"]
            phoneme_indices = defaultdict(list)

            for start, stop, phoneme in zip(phonemes['start'], phonemes['stop'], phonemes['utterance']):

                # convert sample offsets to frame indices: divide by the sample rate (16 kHz)
                # and by the frame stride (20 ms); round the start down and the stop up
                start_index = math.floor((start / 16000) / 0.020)
                stop_index = math.ceil((stop / 16000) / 0.020)
                phoneme_indices[phoneme].extend(range(start_index, stop_index))

            # find the hidden states at these frame indices and save them per layer and per phoneme
            for phoneme, indices in phoneme_indices.items():
                phoneme_states = [waveform[idx] for idx in indices if idx < len(waveform)]
                frame_states_per_phoneme[layer_idx][phoneme].extend(phoneme_states)

    return frame_states_per_phoneme
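The index arithmetic in `sort_states_per_phoneme` converts TIMIT's sample offsets into frame indices by dividing by the sample rate and by 0.020 s, the spacing between consecutive Wav2Vec2 frames, rounding the start down and the stop up. A standalone sketch of that step; `segment_to_frames` is a hypothetical helper and the segment boundaries below are made-up values, not taken from TIMIT:

```python
import math

SAMPLE_RATE = 16000   # TIMIT sample rate (Hz)
FRAME_STRIDE = 0.020  # seconds between consecutive Wav2Vec2 frames


def segment_to_frames(start, stop, sample_rate=SAMPLE_RATE, frame_stride=FRAME_STRIDE):
    """Frame indices assigned to the segment [start, stop), given in samples."""
    start_index = math.floor((start / sample_rate) / frame_stride)
    stop_index = math.ceil((stop / sample_rate) / frame_stride)
    return list(range(start_index, stop_index))


# Hypothetical segments: (start sample, stop sample, label)
segments = [(0, 2260, "h#"), (2260, 3720, "sh"), (3720, 5160, "iy")]
for start, stop, label in segments:
    print(label, segment_to_frames(start, stop))
```

Because the start is rounded down and the stop rounded up, a frame that straddles a boundary (frames 7 and 11 in this example) is assigned to both neighbouring phonemes.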

<font color='green'>**ToDo6**</font>

Extract the hidden states for training and testing our probes (this might take some time):

In [None]:
import math
from collections import defaultdict

set_seed(42)

# We select a relatively small number of sentences since each one is split into a large number of frames
train_size = 800
test_size = 100
num_layers = 13 # Transformer input plus 12 Transformer layers

# Generate random indices to select a subset of the train and test data
train_indices = random.sample(range(0, len(timit['train'])), train_size)
test_indices = random.sample(range(0, len(timit['test'])), test_size)

train_subset = timit['train'].select(train_indices)
test_subset = timit['test'].select(test_indices)
print(len(train_subset), len(test_subset))

# Extract frame-level hidden states for training and testing the probes (make sure you pass the data subset)
frame_states_train = # YOUR CODE HERE
frame_states_test = # YOUR CODE HERE

# Sort the hidden states by phoneme (make sure you pass the data subset)
phoneme_states_train = # YOUR CODE HERE
phoneme_states_test = # YOUR CODE HERE

# Map fine-grained phoneme labels to broader categories

Phonemes can have different realizations depending on the context in which they occur, or depending on the dialect of the speaker. TIMIT's phonetic transcriptions use 61 labels in total, many of which annotate such realizations. Following the commonly used 39-category folding, we merge labels that sound very similar into a single category (e.g. all stop closures and pauses become `sil`) and discard the glottal stop `q`, so that we end up with 39 broader categories.

In [None]:
# The original labels
print(phoneme_states_train[0].keys())

In [None]:
phoneme_mapping = {
    'p': 'p',
    'b': 'b',
    't': 't',
    'd': 'd',
    'k': 'k',
    'g': 'g',
    'dx': 'dx',
    'f': 'f',
    'v': 'v',
    'dh': 'dh',
    'th': 'th',
    's': 's',
    'z': 'z',
    'r': 'r',
    'w': 'w',
    'y': 'y',
    'jh': 'jh',
    'ch': 'ch',
    'iy': 'iy',
    'eh': 'eh',
    'ey': 'ey',
    'ae': 'ae',
    'aw': 'aw',
    'ay': 'ay',
    'oy': 'oy',
    'ow': 'ow',
    'uh': 'uh',
    'ah': 'ah',
    'ax': 'ah',
    'ax-h': 'ah',
    'aa': 'aa',
    'ao': 'aa',
    'er': 'er',
    'axr': 'er',
    'hh': 'hh',
    'hv': 'hh',
    'ih': 'ih',
    'ix': 'ih',
    'l': 'l',
    'el': 'l',
    'm': 'm',
    'em': 'm',
    'n': 'n',
    'en': 'n',
    'nx': 'n',
    'ng': 'ng',
    'eng': 'ng',
    'sh': 'sh',
    'zh': 'sh',
    'uw': 'uw',
    'ux': 'uw',
    'pcl': 'sil',
    'bcl': 'sil',
    'tcl': 'sil',
    'dcl': 'sil',
    'kcl': 'sil',
    'gcl': 'sil',
    'h#': 'sil',
    'pau': 'sil',
    'epi': 'sil'
}
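To see what the folding does to a label sequence, the sketch below applies an excerpt of `phoneme_mapping` to a hypothetical TIMIT-style transcription, drops the glottal stop `q` (as `data_loader` does further below) and counts the resulting categories. The excerpt, the label sequence and the `fold_labels` helper are illustrative only.

```python
from collections import Counter

# Excerpt of phoneme_mapping (the full dictionary above has 60 entries)
mapping_excerpt = {
    'h#': 'sil', 'pau': 'sil', 'kcl': 'sil', 'tcl': 'sil',
    'k': 'k', 't': 't', 'ih': 'ih', 'ix': 'ih',
    'ah': 'ah', 'ax': 'ah', 'n': 'n', 'en': 'n',
}


def fold_labels(labels, mapping, discard=('q',)):
    """Map fine-grained TIMIT labels to broad categories, skipping discarded labels."""
    return [mapping[label] for label in labels if label not in discard]


example_labels = ['h#', 'kcl', 'k', 'ix', 'q', 'ax', 'tcl', 't', 'en', 'pau', 'h#']
folded = fold_labels(example_labels, mapping_excerpt)
print(folded)
print(Counter(folded))
```

With the full dictionary, `len(set(phoneme_mapping.values()))` gives the 39 broad categories.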

# Train and evaluate probing classifiers

It is now time to train and evaluate our probes. We define a separate probing classifier for each layer of the Wav2Vec2 model to see how phoneme information is represented across layers. We use logistic regression as the probing classifier (but feel free to experiment with other classifiers, such as an SVM).

In [None]:
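from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

import warnings

# Silence all warnings (including scikit-learn's ConvergenceWarning, which max_iter=10 is likely to trigger)
warnings.filterwarnings("ignore")

# Define layer-wise probes, to be trained and tested on frame-level Wav2Vec2 embeddings.
# max_iter is kept very low to limit training time; raising it may change the accuracies.
layer_probes = {
    layer_idx: LogisticRegression(solver="liblinear", penalty="l2", max_iter=10)
    for layer_idx in range(num_layers)
}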

Helper functions to convert the hidden states into feature and label lists for the scikit-learn classifiers, and to balance the classes by downsampling every class to the size of the smallest one:

In [None]:
from collections import Counter

import numpy as np

def data_loader(hidden_state_dict, layer_idx, target_phonemes=None):
    '''
    Convert the hidden states of one layer into feature vectors X and broad-category labels y.
    Frames labelled with the glottal stop 'q' are skipped. If target_phonemes is given,
    only frames whose fine-grained TIMIT label is in that list are kept.
    '''
    X = []
    y = []

    for phoneme, hidden_state_list in hidden_state_dict[layer_idx].items():
        if phoneme == 'q':
            continue
        if target_phonemes is not None and phoneme not in target_phonemes:
            continue
        for hidden_state in hidden_state_list:
            if hidden_state is not None:
                X.append(np.array(hidden_state))
                y.append(phoneme_mapping[phoneme])

    return X, y

def balance_classes(X_instances, y_labels):
    '''Downsample every class to the size of the smallest class, keeping the first instances of each class.'''

    class_distribution = Counter(y_labels)
    num_instances_per_class = min(class_distribution.values())

    balanced_data_X = []
    balanced_data_y = []

    for label in class_distribution:
        selected = [x for x, y in zip(X_instances, y_labels) if y == label][:num_instances_per_class]
        balanced_data_X.extend(selected)
        balanced_data_y.extend([label] * len(selected))

    return balanced_data_X, balanced_data_y
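`balance_classes` keeps the *first* instances of every class. Since frames are stored waveform by waveform, for frequent classes the selected frames can come from only the first few utterances. A possible alternative, sketched here and not part of the original lab, is to draw the instances at random with a fixed seed; `balance_classes_random` is a hypothetical name with the same inputs and outputs:

```python
import random
from collections import Counter, defaultdict


def balance_classes_random(X_instances, y_labels, seed=42):
    """Downsample every class to the size of the smallest class, picking instances at random."""
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    indices_per_class = defaultdict(list)
    for idx, label in enumerate(y_labels):
        indices_per_class[label].append(idx)

    num_instances_per_class = min(len(idxs) for idxs in indices_per_class.values())

    balanced_data_X, balanced_data_y = [], []
    for label, idxs in indices_per_class.items():
        # sample without replacement, then restore the original order
        for idx in sorted(rng.sample(idxs, num_instances_per_class)):
            balanced_data_X.append(X_instances[idx])
            balanced_data_y.append(label)
    return balanced_data_X, balanced_data_y


toy_X = list(range(10))
toy_y = ['a'] * 7 + ['b'] * 3
X_bal, y_bal = balance_classes_random(toy_X, toy_y)
print(Counter(y_bal))
```

Keeping the seed fixed preserves the reproducibility that `set_seed` aims for.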

<font color='green'>**ToDo7**</font>

Train and evaluate the probing classifier for each layer (this will take some time). For each category of phonemes (stops, fricatives, nasals, approximants and vowels), a separate set of probes is trained to distinguish the individual phonemes within that category. Examine how the probing accuracy differs between these categories. The different categories are explained [here](http://www.ello.uos.de/field.php/PhoneticsandPhonology/MannerOfArticulation).

In [None]:
from collections import Counter

phoneme_dict = {
    'stops': ['p', 't', 'k', 'b', 'd', 'g'],
    'fricatives': ['f', 'v', 'th', 'dh', 's', 'z', 'sh'],
    'nasals': ['m', 'n', 'ng'],
    'approximants': ['l', 'w', 'y', 'hh'],
    'vowels': ['aa', 'ow', 'iy', 'eh', 'uh']
}

# Save layer-wise accuracies for each phoneme category
accs_per_category = []

for phoneme_category, target_phonemes in phoneme_dict.items():

    print(f"Training layer-wise probes to predict {phoneme_category}...")
    accs = []

    # Train and test an individual probe for each layer
    for layer_idx in range(num_layers):

        # Put train data in the right format for our probing classifier + balance classes
        train_X, train_y = data_loader(phoneme_states_train, layer_idx, target_phonemes=target_phonemes)
        train_X, train_y = balance_classes(train_X, train_y)

        if layer_idx == 0:
            print("Balanced class distribution TRAIN:", Counter(train_y))

        # Train model
        # YOUR CODE HERE

        # Put test data in the right format for our probing classifier + balance classes
        test_X, test_y = data_loader(phoneme_states_test, layer_idx, target_phonemes=target_phonemes)
        test_X, test_y = balance_classes(test_X, test_y)

        if layer_idx == 0:
            print("Balanced class distribution TEST:", Counter(test_y))

        # Predict
        test_pred = # YOUR CODE HERE

        # Calculate accuracy
        test_acc = accuracy_score(test_y, test_pred)
        print(f'Accuracy for layer {layer_idx}:', test_acc)
        accs.append(test_acc)

    accs_per_category.append(accs)
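Because `balance_classes` gives every phoneme in a category the same number of test frames, an uninformed probe scores 1/k on average, where k is the number of phonemes in the category (assuming every listed phoneme actually occurs in the data). These chance levels are a useful reference when reading the plot. A small sketch that computes them; the dictionary restates `phoneme_dict` from above under the hypothetical name `phoneme_categories` so the snippet runs on its own:

```python
from fractions import Fraction

# Same contents as phoneme_dict above
phoneme_categories = {
    'stops': ['p', 't', 'k', 'b', 'd', 'g'],
    'fricatives': ['f', 'v', 'th', 'dh', 's', 'z', 'sh'],
    'nasals': ['m', 'n', 'ng'],
    'approximants': ['l', 'w', 'y', 'hh'],
    'vowels': ['aa', 'ow', 'iy', 'eh', 'uh'],
}

# Chance accuracy for a balanced k-way classification problem is 1/k
chance_levels = {category: Fraction(1, len(phonemes)) for category, phonemes in phoneme_categories.items()}
for category, chance in chance_levels.items():
    print(f"{category:>12}: {float(chance):.3f}")
```

These values can be drawn as dashed reference lines in the plot, e.g. with `plt.axhline(float(chance), linestyle='--')`.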

<font color='green'>**ToDo8**</font>

Plot the layer-wise accuracies per phoneme category.

<font color='green'>**ToThink3**</font>

Does the layer-wise evolution of phoneme information match the subcomponents of the traditional ASR pipeline?

<font color='red'>**ToSubmit2**</font>

Please submit your plot of the probing results. Make sure to add a caption in which you briefly explain the observed pattern of the probing accuracies over layers and how this relates to the traditional ASR pipeline.

In [None]:
plt.figure(figsize=(8, 5))

# Plot layer-wise probing accuracy per phoneme category
for phoneme_category, accs in zip(phoneme_dict.keys(), accs_per_category):
    plt.plot(range(num_layers), accs, marker='.', label=phoneme_category)

plt.title('Representation of phonemes in Wav2Vec 2.0', fontsize=16)
plt.xlabel('Wav2Vec 2.0 layer (0 = Transformer input)', fontsize=12)
plt.xticks(range(num_layers), fontsize=12)
plt.ylabel('Probing accuracy', fontsize=12)
plt.ylim((0.0, 1.0))
plt.legend()

plt.show()