In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from datasets import Dataset
from IPython.display import display
import numpy as np 
import os
import pandas as pd
from pathlib import Path
from pydub import AudioSegment
import torch
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor


import warnings

# Install a process-wide filter that ignores every warning.
warnings.filterwarnings(action='ignore')

## Reproducibility: seed torch (CPU, then CUDA) and NumPy with one shared value
SEED = 42
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    seed_fn(SEED)


from ferret import SpeechBenchmark

In [None]:
# Dataset tag (reused later for file names) and on-disk location of the data.
dataset_name = 'FSC'
data_dir = f'{Path.home()}/data/speech/fluent_speech_commands_dataset'

# Read the FSC test split and join every entry of its `path` column onto data_dir.
df = pd.read_csv(f"{data_dir}/data/test_data.csv").assign(
    path=lambda frame: frame["path"].map(
        lambda stored_path: os.path.join(data_dir, stored_path)
    )
)

# Expose the frame through the `datasets` API for index-based access.
dataset = Dataset.from_pandas(df)

In [None]:
# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device_str = 'cuda:0'
else:
    device_str = 'cpu'

# Keep both forms: the string (device_str) and the torch.device object.
device = torch.device(device_str)

print(device)

In [None]:
## Load the pretrained classifier and the feature extractor from the same checkpoint
MODEL_CHECKPOINT = "superb/wav2vec2-base-superb-ic"
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_CHECKPOINT)

# Only move the model when CUDA is available; otherwise it stays where it was loaded.
if torch.cuda.is_available():
    model = model.to(device)

In [None]:
## Build the speech benchmark around the loaded model and feature extractor
# The device is passed as the string form (device_str), not the torch.device object.
benchmark = SpeechBenchmark(
    model,
    feature_extractor,
    device=device_str,
)

In [None]:
## Example utterance used throughout the rest of the notebook
# Annotations recorded for this example:
#   'transcription': 'Turn up the bedroom heat.'
#   'action': 'increase'
#   'object': 'heat'
#   'location': 'bedroom'

idx = 136
example_row = dataset[idx]
audio_path = example_row['path']

In [None]:
# Load the selected WAV file with pydub and show it in the notebook output.
audio = AudioSegment.from_wav(audio_path)
display(audio)

Get a word-level transcription with WhisperX: transcribe the audio, then align the result to obtain word segments.

In [None]:
import whisperx
from tqdm.notebook import tqdm

In [None]:
# `device.type` drops the index ('cuda:0' -> 'cuda'); WhisperX is given the bare type.
# The CTranslate2 backend behind WhisperX only runs float16 on CUDA: asking for
# float16 on a CPU raises a ValueError. Fall back to float32 on CPU ('int8' is a
# lighter CPU alternative at some accuracy cost).
asr_compute_type = 'float16' if device.type == 'cuda' else 'float32'

# Whisper large-v2 produces the (non-aligned) transcription.
asr_model = whisperx.load_model(
    'large-v2',
    device=device.type,
    compute_type=asr_compute_type,
)

# English alignment model plus its metadata, both consumed later by whisperx.align.
asr_alignment_model, metadata = whisperx.load_align_model(
    language_code='en',
    device=device.type,
)

In [None]:
# WhisperX expects mono 16 kHz float32 samples scaled to [-1, 1] (the format
# whisperx.load_audio produces). pydub hands back raw integer PCM values, so a
# plain float cast leaves them at integer scale. Bring the clip to 16 kHz mono
# (both calls are no-ops when the clip already matches) and divide by the sample
# format's full-scale amplitude.
asr_audio = audio.set_frame_rate(16000).set_channels(1)
asr_waveform = (
    np.array(asr_audio.get_array_of_samples(), dtype=np.float32)
    / np.float32(asr_audio.max_possible_amplitude)
)

# Keyed by dataset index so each result stays tied to its example.
nonaligned_transcriptions = {
    idx: asr_model.transcribe(asr_waveform, batch_size=16)
}

nonaligned_transcriptions

In [None]:
# whisperx.align should see the same kind of waveform WhisperX works with
# elsewhere: mono, 16 kHz, float32 in [-1, 1]. pydub yields raw integer PCM,
# so resample/downmix if needed (no-ops when the clip already matches) and scale
# by the full-scale amplitude. This is done once, outside the comprehension.
alignment_audio = audio.set_frame_rate(16000).set_channels(1)
alignment_waveform = (
    np.array(alignment_audio.get_array_of_samples(), dtype=np.float32)
    / np.float32(alignment_audio.max_possible_amplitude)
)

# NOTE: every entry is aligned against the single clip held in `audio`; this is
# only valid while `nonaligned_transcriptions` contains that one example.
transcriptions = {
    sample_idx: whisperx.align(
        nonaligned_transcription["segments"],
        asr_alignment_model,
        metadata,
        alignment_waveform,
        device.type,
        return_char_alignments=False,
    )
    for sample_idx, nonaligned_transcription in nonaligned_transcriptions.items()
}

In [None]:
# Fallback taken only when `transcriptions` is None: reuse the transcriptions
# cached on disk for this dataset, or start from an empty mapping if no cache exists.
if transcriptions is None:
    import pickle

    transcriptions_file = f'./transcriptions_{dataset_name}.pickle'

    if not os.path.exists(transcriptions_file):
        transcriptions = {}
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only read cache
        # files that you produced yourself.
        with open(transcriptions_file, "rb") as handle:
            transcriptions = pickle.load(handle)

In [None]:
# Pick the word segments of the example; an empty/missing transcription
# mapping yields None instead.
if transcriptions:
    word_transcript = transcriptions[idx]['word_segments']
else:
    word_transcript = None

word_transcript

# Explain word importance

In [None]:
# Explain the example with the 'LOO' methodology, passing the word transcript.
# `words_trascript` is spelled exactly as the keyword is used throughout this notebook.
explanation = benchmark.explain(
    methodology='LOO',
    audio_path=audio_path,
    words_trascript=word_transcript,
)

In [None]:
# Tabulate the current explanation (decimals=3) and show it.
loo_table = benchmark.show_table(explanation, decimals=3)
display(loo_table)

In [None]:
# Explain the same example with the 'LIME' methodology. This rebinds
# `explanation`, so later cells see the LIME result rather than the LOO one.
explanation = benchmark.explain(
    methodology='LIME',
    audio_path=audio_path,
    words_trascript=word_transcript,
)

# Tabulate the LIME explanation (decimals=3) and show it.
lime_table = benchmark.show_table(explanation, decimals=3)
display(lime_table)

In [None]:
from ferret import (
    AOPC_Comprehensiveness_Evaluation_Speech,
    AOPC_Sufficiency_Evaluation_Speech,
)

# Both evaluators share the benchmark's model helper and score whatever
# `explanation` currently holds (the LIME result when run top to bottom).

# AOPC comprehensiveness
aopc_compr = AOPC_Comprehensiveness_Evaluation_Speech(benchmark.model_helper)
evaluation_output_c = aopc_compr.compute_evaluation(
    explanation,
    words_trascript=word_transcript,
)

# AOPC sufficiency
aopc_suff = AOPC_Sufficiency_Evaluation_Speech(benchmark.model_helper)
evaluation_output_s = aopc_suff.compute_evaluation(
    explanation,
    words_trascript=word_transcript,
)

(evaluation_output_c, evaluation_output_s)

# Explain paralinguistic impact

In [None]:
# Explain the example with the 'perturb_paraling' methodology (no word transcript is passed).
explain_table = benchmark.explain(
    methodology='perturb_paraling',
    audio_path=audio_path,
)

# Tabulate the result (decimals=2) and show it.
paralinguistic_table = benchmark.show_table(explain_table, decimals=2)
display(paralinguistic_table)

# Show variation

In [None]:
# Perturbation families passed to explain_variations for the example.
perturbation_types = [
    'time stretching',
    'pitch shifting',
    'reverberation',
    'noise',
]

variations_table = benchmark.explain_variations(
    perturbation_types=perturbation_types,
    audio_path=audio_path,
)

In [None]:
# Keep only the perturbations to plot ('reverberation' is left out), preserving
# the order in which `variations_table` lists them.
perturbations_to_plot = ['time stretching', 'pitch shifting', 'noise']
variations_table_plot = {
    name: variations_table[name]
    for name in variations_table
    if name in perturbations_to_plot
}

fig = benchmark.plot_variations(
    variations_table_plot,
    show_diff=True,
    figsize=(4.6, 4.2),
);
# Optional export of the figure:
# fig.savefig(f'example_{dataset_name}_context.pdf', bbox_inches='tight')