# Speech XAI

In [None]:
# IPython magics: reload all imported modules before executing each cell, so edits
# to library code (presumably a local ferret checkout) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from datasets import Dataset
from IPython.display import display
import numpy as np 
import os
import pandas as pd
from pathlib import Path
from pydub import AudioSegment
import torch
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor


import random
import warnings

# Silence every warning category globally (presumably to keep the notebook output readable).
warnings.filterwarnings('ignore')

## Set seed
# Seed every RNG that downstream code may draw from (Python, NumPy, PyTorch) for reproducible runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed seeds the CPU generator and all CUDA devices,
# so a separate torch.cuda.manual_seed call is unnecessary.
torch.manual_seed(SEED)


from ferret import SpeechBenchmark, AOPC_Comprehensiveness_Evaluation_Speech, AOPC_Sufficiency_Evaluation_Speech

## Data

In [None]:
dataset_name = 'FSC'
# Root of the Fluent Speech Commands (FSC) dataset. Defaults to
# ~/data/speech/fluent_speech_commands_dataset; set the FSC_DATA_DIR environment
# variable to point at a different location. Kept as a str for later string use.
data_dir = os.environ.get(
    'FSC_DATA_DIR',
    str(Path.home() / 'data' / 'speech' / 'fluent_speech_commands_dataset'),
)

# We read the test split of the FSC dataset.
df = pd.read_csv(os.path.join(data_dir, 'data', 'test_data.csv'))
# The CSV stores audio paths relative to the dataset root; make them absolute
# so they can be opened regardless of the notebook's working directory.
df["path"] = df["path"].apply(lambda rel_path: os.path.join(data_dir, rel_path))

dataset = Dataset.from_pandas(df)

## Models

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_str = 'cuda:0'
else:
    device_str = 'cpu'
device = torch.device(device_str)

print(device)

In [None]:
## Load model
# A single checkpoint name guarantees the classifier and its feature extractor match
# (presumably the SUPERB intent-classification "ic" model, judging by its name).
MODEL_CHECKPOINT = "superb/wav2vec2-base-superb-ic"
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_CHECKPOINT)

# `device` is already the CPU when no GPU is present, so moving unconditionally
# is a no-op in that case and no availability check is needed here.
model = model.to(device)

## Speech-XAI: the `SpeechBenchmark` class

In [None]:
## Instantiate benchmark class
# Wrap the classifier and its feature extractor in ferret's SpeechBenchmark; this object
# provides the explain / show_table / explain_variations / plot_variations calls used below.
# device_str is 'cuda:0' when CUDA is available, otherwise 'cpu'.
benchmark = SpeechBenchmark(model, feature_extractor, device=device_str)

In [None]:
## Example
# Test sample used throughout the notebook; its annotations are:
# 'transcription': 'Turn up the bedroom heat.'
# 'action': 'increase'
# 'object': 'heat'
# 'location': 'bedroom'

idx = 136
sample = dataset[idx]
audio_path = sample['path']

# Load the clip and show it in the notebook output.
audio = AudioSegment.from_wav(audio_path)
display(audio)

## Explain word importance

### Word importance

In [None]:
# Word-level importance with the leave-one-out ('LOO') methodology.
explanation = benchmark.explain(audio_path_or_array=audio_path, methodology='LOO')

loo_table = benchmark.show_table(explanation, decimals=3)
display(loo_table)

In [None]:
# Word-level importance with the 'LIME' methodology; this explanation is the one
# scored by the AOPC faithfulness metrics in the next cell.
explanation = benchmark.explain(audio_path_or_array=audio_path, methodology='LIME')

lime_table = benchmark.show_table(explanation, decimals=3)
display(lime_table)

In [None]:
# Faithfulness of the most recent `explanation` (the LIME one from the previous cell),
# scored with ferret's AOPC comprehensiveness and sufficiency evaluators. Both are built
# on the benchmark's model helper so they query the same model that was explained.
aopc_compr = AOPC_Comprehensiveness_Evaluation_Speech(benchmark.model_helper)
evaluation_output_c = aopc_compr.compute_evaluation(explanation)

aopc_suff = AOPC_Sufficiency_Evaluation_Speech(benchmark.model_helper)
evaluation_output_s = aopc_suff.compute_evaluation(explanation)

# Last expression of the cell: the notebook displays both results as a tuple.
evaluation_output_c, evaluation_output_s

### Working with transcriptions explicitly

`Ferret` offers an interface with ASR (automatic speech recognition) models from [`WhisperX`](https://github.com/m-bain/whisperX) in the form of the `transcribe_audio` function. `Ferret` calls it internally, so there is no need to call it yourself. Should the need arise, though, the cells below show two things. First, how to generate the word-level transcript (with per-word time alignments) that the `SpeechBenchmark.explain` method uses internally. Second, how to pass that transcript back to `explain` through its `words_trascript` argument.

In [None]:
from ferret.explainers.explanation_speech.utils_removal import transcribe_audio

In [None]:
# Run the WhisperX-backed ASR explicitly. `text` is the transcript; `words_trascript`
# (spelling kept to match the `explain` keyword used below) holds the word-level
# transcript with time alignments. device.type is 'cuda' or 'cpu'.
text, words_trascript = transcribe_audio(
    audio_path=audio_path,
    device=device.type,
    batch_size=2,
    compute_type="float32",
    language='en'
)

In [None]:
# LOO explanation reusing the transcript computed above (presumably this skips the
# internal ASR step). Uses the same `audio_path_or_array` keyword as the earlier
# `benchmark.explain` calls instead of `audio_path`.
explanation = benchmark.explain(
    audio_path_or_array=audio_path,
    methodology='LOO',
    # Transcripts are passed explicitly; `words_trascript` (sic) is the keyword ferret expects.
    words_trascript=words_trascript
)

display(benchmark.show_table(explanation, decimals=3))

## Explain paralinguistic impact

In [None]:
# Paralinguistic explanation ('perturb_paraling'). Uses the same `audio_path_or_array`
# keyword as the other `benchmark.explain` calls in this notebook instead of `audio_path`.
explain_table = benchmark.explain(
    audio_path_or_array=audio_path,
    methodology='perturb_paraling',
)
display(benchmark.show_table(explain_table, decimals=2))

## Show variations

In [None]:
# Paralinguistic perturbations whose effect on the prediction is computed below.
perturbation_types = ['time stretching', 'pitch shifting', 'reverberation', 'noise']
# The audio is passed positionally. `explain` takes it as `audio_path_or_array`,
# and a positional argument works whichever name `explain_variations` uses for it.
variations_table = benchmark.explain_variations(
    audio_path,
    perturbation_types=perturbation_types
)

In [None]:
# Plot a subset of the computed variations ('reverberation' is left out of the figure).
variations_table_plot = {
    perturbation: variations_table[perturbation]
    for perturbation in variations_table
    if perturbation in ('time stretching', 'pitch shifting', 'noise')
}
fig = benchmark.plot_variations(
    variations_table_plot,
    show_diff=True,
    figsize=(4.6, 4.2),
)
# To export the figure, uncomment:
# fig.savefig(f'example_{dataset_name}_context.pdf', bbox_inches='tight')