# Speech XAI

In [1]:
# Load IPython's autoreload extension and switch it to mode 2, so that imported
# modules are reloaded before each cell runs (handy while editing library code).
%load_ext autoreload
%autoreload 2

In [2]:
from datasets import Dataset, load_dataset
from IPython.display import display
import numpy as np 
import os
import pandas as pd
from pathlib import Path
from pydub import AudioSegment
import torch
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor

from ferret import SpeechBenchmark, AOPC_Comprehensiveness_Evaluation_Speech, AOPC_Sufficiency_Evaluation_Speech

  from .autonotebook import tqdm as notebook_tqdm
  torchaudio.set_audio_backend("soundfile")
torchvision is not available - cannot save figures


In [3]:
# Identifier of the dataset that `load_dataset` fetches in the "Data" section.
DATASET_ID = "DynamicSuperb/IntentClassification_FluentSpeechCommands-Action"

In [4]:
# Select the device used by the model and the benchmark.
# GPU index 2 is preferred (the original setup), but that index only exists on
# hosts with at least three GPUs; on smaller CUDA hosts fall back to GPU 0
# instead of producing an invalid device string. Without CUDA, run on the CPU.
PREFERRED_GPU_INDEX = 2

if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if PREFERRED_GPU_INDEX < torch.cuda.device_count() else 0
    device = f'cuda:{gpu_index}'
else:
    device = 'cpu'
print(device)

cuda:2


## Data

In [5]:
# Load only the "test" split of DATASET_ID.
# The bare `data` expression on the last line makes the notebook display its summary.
data = load_dataset(DATASET_ID, split="test")
data

Dataset({
    features: ['file', 'speakerId', 'transcription', 'audio', 'label', 'instruction'],
    num_rows: 200
})

In [6]:
# Take the first example of the split; later cells read its "audio" entry
# (waveform array and sampling rate) as the running example.
sample = data[0]
sample

{'file': 'wavs/speakers/Xygv5loxdZtrywr9/77506ae0-452b-11e9-a843-8db76f4b5e29.wav',
 'speakerId': 'Xygv5loxdZtrywr9',
 'transcription': 'Increase the temperature in the washroom',
 'audio': {'path': '77506ae0-452b-11e9-a843-8db76f4b5e29.wav',
  'array': array([0.        , 0.        , 0.        , ..., 0.02133179, 0.01977539,
         0.01849365]),
  'sampling_rate': 16000},
 'label': 'increase',
 'instruction': 'Recognize the action behind the verbal expression. The answer could be activate, bring, change language, deactivate, decrease, or increase.'}

In this notebook we use Wav2Vec2, which expects audio arrays sampled at 16 kHz. Luckily, this is the native sampling rate of our data.

## Models

In [7]:
## Load the classifier and its matching feature extractor
# Both objects are restored from the same pretrained checkpoint.
MODEL_CHECKPOINT = "superb/wav2vec2-base-superb-ic"

model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_CHECKPOINT)

# The model is moved to `device` only when CUDA is present.
if torch.cuda.is_available():
    model = model.to(device)

Some weights of the model checkpoint at superb/wav2vec2-base-superb-ic were not used when initializing Wav2Vec2ForSequenceClassification: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at superb/wav2vec2-base-superb-ic and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_em

## Speech-XAI: the `SpeechBenchmark` class

Note: if not specified otherwise, `SpeechBenchmark` assumes English as the source language.

In [8]:
## Instantiate benchmark class
# The benchmark receives the model, its feature extractor and the device chosen earlier.
benchmark = SpeechBenchmark(model, feature_extractor, device=device)

Let's start by transcribing the example above using WhisperX.

In [9]:
# Transcribe the sample waveform, passing its sampling rate alongside the samples.
sample_audio = sample["audio"]
text, word_timestamps = benchmark.transcribe(
    sample_audio["array"], current_sr=sample_audio["sampling_rate"]
)
# Show the transcript together with the per-word timestamps.
(text, word_timestamps)

Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.2.1. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/whisperx-vad-segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.2.1+cu121. Bad things might happen unless you revert torch to 1.x.


(' Increase the temperature in the washroom.',
 [{'word': 'Increase', 'start': 0.737, 'end': 1.02, 'score': 0.438},
  {'word': 'the', 'start': 1.04, 'end': 1.121, 'score': 0.141},
  {'word': 'temperature', 'start': 1.141, 'end': 1.526, 'score': 0.444},
  {'word': 'in', 'start': 1.546, 'end': 1.627, 'score': 0.848},
  {'word': 'the', 'start': 1.647, 'end': 1.728, 'score': 0.953},
  {'word': 'washroom.', 'start': 1.768, 'end': 2.132, 'score': 0.588}])

## Explain word importance

### Word importance

In [19]:
# Word-level explanation with the 'LOO' methodology, reusing the word
# timestamps obtained from the transcription step.
loo_arguments = dict(
    audio_path_or_array=sample["audio"]["array"],
    current_sr=sample["audio"]["sampling_rate"],
    methodology='LOO',
    word_timestamps=word_timestamps,
)
explanation = benchmark.explain(**loo_arguments)
# Tabular alternative (currently disabled):
#   display(benchmark.show_table(explanation, decimals=3))
print(explanation)

ExplanationSpeech(features=['Increase', 'the', 'temperature', 'in', 'the', 'washroom.'], scores=array([[ 0.47325948, -0.45515063, -0.10200211, -0.15734437, -0.12148061,
         0.0109534 ],
       [ 0.07733697, -0.02064097,  0.34651279, -0.01588559, -0.01463729,
        -0.02365428],
       [-0.01432282, -0.01848161, -0.00988954, -0.00070852, -0.01123005,
         0.32860303]]), explainer='loo_speech+silence', target=[3, 4, 3], audio=<ferret.speechxai_utils.FerretAudio object at 0x7fb7ed858130>)


In [20]:
# Same sample and word timestamps, this time with the 'LIME' methodology.
lime_arguments = dict(
    audio_path_or_array=sample["audio"]["array"],
    current_sr=sample["audio"]["sampling_rate"],
    methodology='LIME',
    word_timestamps=word_timestamps,
)
explanation = benchmark.explain(**lime_arguments)
print(explanation)
# Tabular alternative (currently disabled):
#   display(benchmark.show_table(explanation, decimals=3))

ExplanationSpeech(features=['Increase', 'the', 'temperature', 'in', 'the', 'washroom.'], scores=array([[ 0.30518979, -0.05905298,  0.02406042,  0.06312685, -0.01027066,
         0.00634839],
       [-0.00192933,  0.04791304,  0.30365684,  0.01351917, -0.02577572,
         0.13388124],
       [ 0.07868745, -0.02967894,  0.21510287,  0.02970933,  0.03952176,
         0.44306288]]), explainer='LIME+silence', target=[3, 4, 3], audio=<ferret.speechxai_utils.FerretAudio object at 0x7fb7ed7f19d0>)


We can run the same function but with no word timestamps. The class will generate them automatically.

In [25]:
# Sanity check: inspect the Python type of the value passed as `current_sr`.
type(sample["audio"]["sampling_rate"])

int

In [29]:
# 'LIME' explanation without `word_timestamps`: none are supplied here, so the
# benchmark has to obtain them itself (see the transcription log in the output).
lime_auto_arguments = dict(
    audio_path_or_array=sample["audio"]["array"],
    current_sr=sample["audio"]["sampling_rate"],
    methodology='LIME',
)
explanation = benchmark.explain(**lime_auto_arguments)
print(explanation)
# Tabular alternative (currently disabled):
#   display(benchmark.show_table(explanation, decimals=3))

Transcribing audio to get word level timestamps...


Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.2.1. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/whisperx-vad-segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.2.1+cu121. Bad things might happen unless you revert torch to 1.x.
Transcribed audio with whisperX into:  Increase the temperature in the washroom.
ExplanationSpeech(features=['Increase', 'the', 'temperature', 'in', 'the', 'washroom.'], scores=array([[ 2.73476301e-01, -2.75996308e-02,  2.68968859e-02,
         4.38230033e-02, -9.83693653e-03,  3.43606501e-02],
       [-4.55664511e-02,  2.00727565e-04,  3.07805104e-01,
        -7.30904579e-03,  8.18154319e-03,  1.45066594e-01],
       [ 7.67946057e-02, -1.63121582e-02,  1.69544374e-01,
         1.03233484e-02,  6.95427995e-02,  4.02942428e-01]]), explainer='LIME+silence', target=[3, 4, 3], audio=<ferret.speechxai_utils.FerretAudio object at 0x7fb7ec504c40>, word_timestamps=[{'word': 'Increase', 'start': 0.737, 'end': 1.02, 'score': 0.438}, {'word': 'the', 'sta

In [30]:
# Score the most recent `explanation` with the AOPC comprehensiveness evaluator.
# Both evaluators are built from the benchmark's `model_helper`.
aopc_compr = AOPC_Comprehensiveness_Evaluation_Speech(benchmark.model_helper)
evaluation_output_c = aopc_compr.compute_evaluation(explanation)

# Score the same `explanation` with the AOPC sufficiency evaluator.
aopc_suff = AOPC_Sufficiency_Evaluation_Speech(benchmark.model_helper)
evaluation_output_s = aopc_suff.compute_evaluation(explanation)

# Display both evaluation results side by side.
evaluation_output_c, evaluation_output_s

(EvaluationSpeech(name='aopc_compr_speech', score=[0.32901989901438355, 0.4174739196896553, 0.5148161690682173], target=[3, 4, 3]),
 EvaluationSpeech(name='aopc_suff', score=[0.17665663920342922, -0.009631142020225525, -0.01769007444381714], target=[3, 4, 3]))

## Explain paralinguistic impact

In [32]:
# Request an explanation with the 'perturb_paraling' methodology and render it as a table.
# NOTE(review): in the recorded run this cell ended with
#   "TypeError: compute_explanation() got an unexpected keyword argument 'audio'".
# The error appears to come from inside the library call rather than from the
# arguments passed here — confirm against the installed ferret version.
explain_table = benchmark.explain(
    audio_path_or_array=sample["audio"]["array"],
    current_sr=sample["audio"]["sampling_rate"],
    methodology='perturb_paraling',
)
display(benchmark.show_table(explain_table, decimals=2))

Transcribing audio to get word level timestamps...


Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.2.1. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/whisperx-vad-segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.2.1+cu121. Bad things might happen unless you revert torch to 1.x.
Transcribed audio with whisperX into:  Increase the temperature in the washroom.


TypeError: compute_explanation() got an unexpected keyword argument 'audio'

## Show variation

In [None]:
# Perturbation families to analyse; the list is passed to `explain_variations`.
perturbation_types = ['time stretching', 'pitch shifting', 'reverberation', 'noise']
# NOTE(review): `audio_path` is not defined in any cell visible in this notebook and
# this cell has no execution count, so it will presumably raise NameError on a fresh
# kernel. Earlier cells pass the waveform via `audio_path_or_array=` / `current_sr=`;
# verify which arguments `explain_variations` expects before running.
variations_table = benchmark.explain_variations(
    audio_path=audio_path,
    perturbation_types=perturbation_types
)

In [None]:
# Keep only the perturbation types that should appear in the figure
# (every key of `variations_table` outside this selection is left out).
plotted_perturbations = ('time stretching', 'pitch shifting', 'noise')
variations_table_plot = {
    perturbation: variations_table[perturbation]
    for perturbation in variations_table
    if perturbation in plotted_perturbations
}

fig = benchmark.plot_variations(
    variations_table_plot,
    show_diff=True,
    figsize=(4.6, 4.2),
)
# Optional export (disabled; `dataset_name` is not defined in the visible cells):
#   fig.savefig(f'example_{dataset_name}_context.pdf', bbox_inches='tight')