# Illustrating all interpretation methods (Saliency, Grad-CAM, LIME, SHAP)

In [1]:
from audiointerp.dataset.esc50 import ESC50dataset
from audiointerp.model.cnn14 import TransferCnn14
from audiointerp.fit import Trainer
from audiointerp.processing.spectrogram import LogMelSTFTSpectrogram, LogSTFTSpectrogram
from audiointerp.interpretation.saliency import SaliencyInterpreter
from audiointerp.interpretation.gradcam import GradCAMInterpreter
from audiointerp.interpretation.shap import SHAPInterpreter
from audiointerp.interpretation.lime import LIMEInterpreter
import torchaudio
import torchaudio.functional as F
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T_audio
import torchvision.transforms as T_vision
import torch
from torch.utils.data import DataLoader
import librosa
import matplotlib.pyplot as plt
import random
import numpy as np
from IPython.display import Audio
from audiointerp.predict import Predict
from audiointerp.metrics import Metrics

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
orig_sr = 44100  # sample rate used to play back the raw .wav files with Audio() below
num_classes = 50  # number of output classes of both classifiers
num_bins_stft = 257  # n_fft // 2 + 1 frequency bins for the STFT front end (n_fft=512)
num_bins_mel = 64  # matches n_mels of the mel front end
test_folds = [5]  # dataset fold(s) used for evaluation
root_dir = "/root/ESC50"  # NOTE(review): absolute dataset path; adjust per machine

In [3]:
sr_stft = 16000  # input rate of the STFT model (Nyquist 8 kHz, the fmax used when plotting it)
sr_mel = 32000  # input rate of the mel model; also the mel extractor's sample_rate

In [4]:
fit_extr_kwargs_stft_test = {
    "n_fft": 512,
    "hop_length": 256,
    "win_length": 512,
    "top_db": None,
    "return_phase": False,
    "return_pre_db": False,
    "return_full_db": False
}

In [5]:
fit_extr_kwargs_mel_test = {
    "n_fft": 1024,
    "hop_length": 320,
    "win_length": 1024,
    "sample_rate": sr_mel,
    "n_mels": 64,
    "f_min": 50,
    "f_max": 14000,
    "top_db": None,
    "return_phase": False,
    "return_pre_db": False,
    "return_full_db": False
}

In [6]:
fit_extr_kwargs_stft = {
    "n_fft": 512,
    "hop_length": 256,
    "win_length": 512,
    "top_db": None,
    "return_phase": True,
    "return_pre_db": True,
    "return_full_db": True
}

In [7]:
fit_extr_kwargs_mel = {
    "n_fft": 1024,
    "hop_length": 320,
    "win_length": 1024,
    "sample_rate": sr_mel,
    "n_mels": 64,
    "f_min": 50,
    "f_max": 14000,
    "top_db": None,
    "return_phase": True,
    "return_pre_db": True,
    "return_full_db": True
}

## Loading audio files and model weights

In [8]:
clean_wav_file = "samples/cat.wav"
white_wav_file = "noises/165058__theundecided__white-noise.wav"
room_wav_file = "noises/203297__mzui__room-tone-office-industrial-ambience-01.wav"
horse_wav_file = "noises/149024__foxen10__horse_whinny.wav"

In [9]:
stft_model_weights = "logstft_cnn14.pth"
mel_model_weights = "logmel_cnn14.pth"

In [10]:
wav_clean, _ = torchaudio.load(clean_wav_file)
wav_white, _ = torchaudio.load(white_wav_file)
wav_room,  _ = torchaudio.load(room_wav_file)
wav_horse, _ = torchaudio.load(horse_wav_file)

In [11]:
Audio(wav_clean, rate=orig_sr)

In [12]:
Audio(wav_white, rate=orig_sr)

In [13]:
Audio(wav_room, rate=orig_sr)

In [14]:
Audio(wav_horse, rate=orig_sr)

## Prepare models

In [15]:
model_stft_kwargs = {"num_classes": num_classes, "num_bins": num_bins_stft}
model_mel_kwargs = {"num_classes": num_classes, "num_bins": num_bins_mel}

In [16]:
feature_extractor_stft_test = LogSTFTSpectrogram(**fit_extr_kwargs_stft_test)
feature_extractor_mel_test = LogMelSTFTSpectrogram(**fit_extr_kwargs_mel_test)
test_data_stft = ESC50dataset(root_dir=root_dir, sr=sr_stft, folds=test_folds, normalize="peak", feature_extractor=feature_extractor_stft_test)
test_data_mel = ESC50dataset(root_dir=root_dir, sr=sr_mel, folds=test_folds, normalize="peak", feature_extractor=feature_extractor_mel_test)
test_loader_kwargs = {"batch_size": 32, "shuffle": False}

In [17]:
# Use the second GPU as in the recorded outputs, but fall back to the first GPU
# or the CPU so the notebook still runs on machines without "cuda:1".
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

optimizer_cls = optim.Adam
optimizer_kwargs = {"lr": 1e-4}

criterion_cls = nn.CrossEntropyLoss
# Mixup is off: the trainers below receive no training data and are only used to test.
use_mixup = False
mixup_alpha = 0.0

In [18]:
model_trainer_stft = Trainer(
    model_cls=TransferCnn14,
    train_data=None,
    train_loader_kwargs=None,
    criterion_cls=criterion_cls,
    optimizer_cls=optimizer_cls,
    model_kwargs=model_stft_kwargs,
    model_pretrain_weights_path=None,
    optimizer_kwargs=optimizer_kwargs,
    device=device,
    valid_data=None,
    valid_loader_kwargs=None,
    test_data=test_data_stft,
    test_loader_kwargs=test_loader_kwargs,
    use_mixup=use_mixup,
    mixup_alpha=mixup_alpha
)

Random seed set to: 42


In [19]:
model_trainer_mel = Trainer(
    model_cls=TransferCnn14,
    train_data=None,
    train_loader_kwargs=None,
    criterion_cls=criterion_cls,
    optimizer_cls=optimizer_cls,
    model_kwargs=model_mel_kwargs,
    model_pretrain_weights_path=None,
    optimizer_kwargs=optimizer_kwargs,
    device=device,
    valid_data=None,
    valid_loader_kwargs=None,
    test_data=test_data_mel,
    test_loader_kwargs=test_loader_kwargs,
    use_mixup=use_mixup,
    mixup_alpha=mixup_alpha
)

Random seed set to: 42


In [20]:
# map_location lets the checkpoint load even if it was saved on a different device.
model_trainer_stft.model.load_state_dict(torch.load(stft_model_weights, map_location=device))

<All keys matched successfully>

In [21]:
model_trainer_stft.test()

Test Loss: 0.8443, Test Acc: 0.7800


(0.8443098521232605, 0.78)

In [22]:
# map_location lets the checkpoint load even if it was saved on a different device.
model_trainer_mel.model.load_state_dict(torch.load(mel_model_weights, map_location=device))

<All keys matched successfully>

In [23]:
model_trainer_mel.test()

Test Loss: 0.2639, Test Acc: 0.9225


(0.2638887568563223, 0.9225)

In [24]:
model_stft = model_trainer_stft.model
model_mel = model_trainer_mel.model

## Preparing wavs

In [25]:
def rms_normalize(wav):
    """Scale ``wav`` to unit root-mean-square amplitude.

    The RMS is taken over every element of the tensor. A silent input (RMS not
    strictly positive) is returned unchanged to avoid dividing by zero.
    """
    rms_value = wav.pow(2).mean().sqrt()
    if not rms_value > 0.:
        return wav
    return wav / rms_value

def load_audio(path_to_audio, sr):
    """Read an audio file as a mono waveform at sample rate ``sr``.

    Multi-channel files are down-mixed by averaging channels (a leading channel
    axis of size 1 is kept). The signal is resampled with
    ``torchaudio.functional.resample`` only when the file's rate differs from ``sr``.
    """
    waveform, file_sr = torchaudio.load(path_to_audio)

    # Down-mix to one channel while keeping the channel dimension.
    is_multichannel = waveform.shape[0] > 1
    if is_multichannel:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Bring the signal to the requested rate.
    needs_resample = file_sr != sr
    if needs_resample:
        waveform = F.resample(waveform, file_sr, sr)

    return waveform


def fix_length(wav, num_samples):
    """Zero-pad or truncate ``wav`` along its last axis to ``num_samples``.

    Args:
        wav: tensor whose last dimension is time.
        num_samples: target number of samples.

    Returns:
        A tensor with the same leading dimensions as ``wav`` and exactly
        ``num_samples`` samples on the last axis. Zeros are appended at the end
        when padding; truncation keeps the first ``num_samples`` samples. An
        input that already has the right length is returned as-is.
    """
    cur_len = wav.shape[-1]
    if cur_len < num_samples:
        # `F` in this notebook is torchaudio.functional, which has no `pad`;
        # use torch's generic padding op instead.
        wav = torch.nn.functional.pad(wav, (0, num_samples - cur_len))
    elif cur_len > num_samples:
        wav = wav[..., :num_samples]
    return wav


def make_mix(
    clean_audio_path, contaminating_audio_path,
    sr, duration=5.0, alpha=1.0, peak_normalize=True
):
    """Blend a clean recording with a contaminating recording.

    Both files are loaded as mono at ``sr`` and forced to ``int(sr * duration)``
    samples. Each is RMS-normalized so that ``alpha`` alone sets their relative
    level: the result is ``alpha * clean + (1 - alpha) * noise``. When
    ``peak_normalize`` is set, the mix is then divided by its absolute peak,
    unless the mix is all zeros.
    """
    target_len = int(sr * duration)

    def _prepare(path):
        # load -> fixed length -> unit RMS
        signal = load_audio(path, sr)
        signal = fix_length(signal, target_len)
        return rms_normalize(signal)

    clean_part = _prepare(clean_audio_path)
    noise_part = _prepare(contaminating_audio_path)

    mixture = alpha * clean_part + (1 - alpha) * noise_part

    if not peak_normalize:
        return mixture
    peak = mixture.abs().max()
    return mixture / peak if peak != 0 else mixture

def make_clean(clean_audio_path, sr, duration=5.0, peak_normalize=True):
    """Load a single recording as a fixed-length mono waveform.

    The file is resampled to ``sr`` and zero-padded or truncated to
    ``int(sr * duration)`` samples. By default it is then divided by its
    absolute peak. Unlike ``make_mix``, there is no RMS normalization step.
    """
    signal = fix_length(load_audio(clean_audio_path, sr), int(sr * duration))

    if peak_normalize:
        peak = signal.abs().max()
        if peak != 0:
            signal = signal / peak

    return signal


### Samples for stft model

In [26]:
clean_stft = make_clean(clean_wav_file, sr = sr_stft)

In [27]:
Audio(clean_stft, rate=sr_stft)

In [28]:
white_stft = make_mix(clean_wav_file, white_wav_file, sr=sr_stft, alpha=0.6)

In [29]:
Audio(white_stft, rate=sr_stft)

In [30]:
room_stft = make_mix(clean_wav_file, room_wav_file, sr=sr_stft, alpha=0.6)

In [31]:
Audio(room_stft, rate=sr_stft)

In [32]:
horse_stft = make_mix(clean_wav_file, horse_wav_file, sr=sr_stft, alpha=0.6)

In [33]:
Audio(horse_stft, rate=sr_stft)

### Samples for mel model

In [34]:
clean_mel = make_clean(clean_wav_file, sr = sr_mel)

In [35]:
Audio(clean_mel, rate=sr_mel)

In [36]:
white_mel = make_mix(clean_wav_file, white_wav_file, sr=sr_mel, alpha=0.6)

In [37]:
Audio(white_mel, rate=sr_mel)

In [38]:
room_mel = make_mix(clean_wav_file, room_wav_file, sr=sr_mel, alpha=0.6)

In [39]:
Audio(room_mel, rate=sr_mel)

In [40]:
horse_mel = make_mix(clean_wav_file, horse_wav_file, sr=sr_mel, alpha=0.6)

In [41]:
Audio(horse_mel, rate=sr_mel)

## Predictions for samples

In [42]:
def get_prediction(model, sample, fit_extr, device=None):
    """Return the arg-max class index predicted by ``model`` for one waveform.

    Args:
        model: classifier whose output is arg-maxed along dim 1 (assumed to be
            logits of shape (batch, num_classes)).
        sample: single waveform tensor; a batch axis is added before feature
            extraction.
        fit_extr: feature extractor called on the batched waveform. The first
            element of its return value is fed to the model.
        device: where to run the model. Defaults to the device of the model's
            parameters (CPU if it has none) instead of relying on a global.

    Returns:
        A tensor of shape (1,) holding the predicted class index.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()  # disable dropout / use running BatchNorm statistics for inference
    try:
        with torch.no_grad():  # no autograd graph is needed just to read the argmax
            logits = model(fit_extr(sample.unsqueeze(0))[0].to(device))
    finally:
        model.train(was_training)  # restore the caller's train/eval mode
    return torch.argmax(logits, 1)

In [43]:
feature_extractor_stft = LogSTFTSpectrogram(**fit_extr_kwargs_stft)
feature_extractor_mel = LogMelSTFTSpectrogram(**fit_extr_kwargs_mel)

### STFT model

In [44]:
get_prediction(model_stft, clean_stft, feature_extractor_stft)

tensor([5], device='cuda:1')

In [45]:
get_prediction(model_stft, white_stft, feature_extractor_stft)

tensor([5], device='cuda:1')

In [46]:
get_prediction(model_stft, room_stft, feature_extractor_stft)

tensor([5], device='cuda:1')

In [47]:
get_prediction(model_stft, horse_stft, feature_extractor_stft)

tensor([6], device='cuda:1')

### Mel model

In [48]:
get_prediction(model_mel, clean_mel, feature_extractor_mel)

tensor([5], device='cuda:1')

In [49]:
get_prediction(model_mel, white_mel, feature_extractor_mel)

tensor([5], device='cuda:1')

In [50]:
get_prediction(model_mel, room_mel, feature_extractor_mel)

tensor([5], device='cuda:1')

In [51]:
get_prediction(model_mel, horse_mel, feature_extractor_mel)

tensor([5], device='cuda:1')

## Computing attributions

In [52]:
silence_val = -100.  # fill value for masked-out spectrogram bins (passed to Predict.predict and used by apply_mask)

In [53]:
shap_background_folds = [1, 2, 3]  # folds (disjoint from test_folds) from which SHAP background examples are drawn

In [54]:
def get_balanced_background(dataloader, num_samples_per_class=2, device="cpu", num_classes=None):
    """Collect up to ``num_samples_per_class`` examples of every label.

    Iterates ``dataloader`` in order and keeps the first examples seen for each
    label. The result is used as the SHAP background set.

    Args:
        dataloader: iterable of ``(batch_x, batch_y)`` pairs with scalar labels.
        num_samples_per_class: how many examples to keep per label.
        device: device the returned tensor is moved to.
        num_classes: optional number of distinct labels. When given, iteration
            stops after the first batch in which all ``num_classes`` labels are
            full. This skips feature extraction for the rest of the dataset. For
            a non-shuffled loader the selected examples are the same as with a
            full pass.

    Returns:
        The kept examples concatenated along dim 0, grouped by label in order
        of first appearance.

    Raises:
        ValueError: if no example was collected (e.g. an empty dataloader).
    """
    from collections import defaultdict
    class_to_samples = defaultdict(list)

    def _all_classes_full():
        return (
            num_classes is not None
            and len(class_to_samples) >= num_classes
            and all(len(v) >= num_samples_per_class for v in class_to_samples.values())
        )

    for batch_x, batch_y in dataloader:
        for x, y in zip(batch_x, batch_y):
            bucket = class_to_samples[y.item()]
            if len(bucket) < num_samples_per_class:
                bucket.append(x.unsqueeze(0))
        if _all_classes_full():
            break

    background_tensors = [t for tensors in class_to_samples.values() for t in tensors]
    if not background_tensors:
        raise ValueError("dataloader yielded no samples for the background set")
    return torch.cat(background_tensors, dim=0).to(device)

In [55]:
train_data_shap_stft = ESC50dataset(root_dir=root_dir, sr=sr_stft, folds=shap_background_folds, normalize="peak", feature_extractor=feature_extractor_stft_test)
train_loader_shap_stft = DataLoader(train_data_shap_stft, batch_size=100, shuffle=False)
shap_background_stft = get_balanced_background(train_loader_shap_stft, num_samples_per_class=2, device=device)

In [56]:
train_data_shap_mel = ESC50dataset(root_dir=root_dir, sr=sr_mel, folds=shap_background_folds, normalize="peak", feature_extractor=feature_extractor_mel_test)
train_loader_shap_mel = DataLoader(train_data_shap_mel, batch_size=100, shuffle=False)
shap_background_mel = get_balanced_background(train_loader_shap_mel, num_samples_per_class=2, device=device)

In [57]:
predict_saliency_stft = Predict(model_stft, feature_extractor_stft, interp_method_cls=SaliencyInterpreter, interp_method_kwargs={}, device=device)
predict_gradcam_stft = Predict(model_stft, feature_extractor_stft, interp_method_cls=GradCAMInterpreter, interp_method_kwargs={"target_layers": [model_stft.base.conv_block6.conv2]}, device=device)
predict_lime_stft = Predict(model_stft, feature_extractor_stft, interp_method_cls=LIMEInterpreter, interp_method_kwargs={"num_samples": 1000}, device=device)
predict_shap_stft = Predict(model_stft, feature_extractor_stft, interp_method_cls=SHAPInterpreter, interp_method_kwargs={"background_data": shap_background_stft}, device=device)

In [58]:
predict_saliency_mel = Predict(model_mel, feature_extractor_mel, interp_method_cls=SaliencyInterpreter, interp_method_kwargs={}, device=device)
predict_gradcam_mel = Predict(model_mel, feature_extractor_mel, interp_method_cls=GradCAMInterpreter, interp_method_kwargs={"target_layers": [model_mel.base.conv_block6.conv2]}, device=device)
predict_lime_mel = Predict(model_mel, feature_extractor_mel, interp_method_cls=LIMEInterpreter, interp_method_kwargs={"num_samples": 1000}, device=device)
predict_shap_mel = Predict(model_mel, feature_extractor_mel, interp_method_cls=SHAPInterpreter, interp_method_kwargs={"background_data": shap_background_mel}, device=device)

In [59]:
predict_saliency_stft.predict(
    wav=clean_stft, wav_name="clean_saliency", sr=sr_stft, feature_type="stft",
    silence_val=silence_val, fmin=0, fmax=8000,
    save_root="predictions_illust", model_type="stft"
)

{'FF': tensor(0.9744),
 'AI': tensor(0.),
 'AD': tensor(99.9995),
 'AG': tensor(0.),
 'FidIn': tensor([0.]),
 'SPS': tensor([0.5000], dtype=torch.float64),
 'COMP': tensor([10.6021], dtype=torch.float64)}

In [60]:
predict_gradcam_stft.predict(
    wav=clean_stft, wav_name="clean_gradcam", sr=sr_stft, feature_type="stft",
    silence_val=silence_val, fmin=0, fmax=8000,
    save_root="predictions_illust", model_type="stft"
)

{'FF': tensor(0.9463),
 'AI': tensor(100.),
 'AD': tensor(0.),
 'AG': tensor(0.0310),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.]),
 'COMP': tensor([0.])}

In [61]:
predict_lime_stft.predict(
    wav=clean_stft, wav_name="clean_lime", sr=sr_stft, feature_type="stft",
    silence_val=silence_val, fmin=0, fmax=8000,
    save_root="predictions_illust", model_type="stft"
)

100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:13<00:00, 73.16it/s]


{'FF': tensor(0.3539),
 'AI': tensor(100.),
 'AD': tensor(0.),
 'AG': tensor(54.9044),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.4953], dtype=torch.float64),
 'COMP': tensor([10.6114], dtype=torch.float64)}

In [62]:
predict_shap_stft.predict(
    wav=clean_stft, wav_name="clean_shap", sr=sr_stft, feature_type="stft",
    silence_val=silence_val, fmin=0, fmax=8000,
    save_root="predictions_illust", model_type="stft"
)

Done extracting shap values


{'FF': tensor(0.9744),
 'AI': tensor(0.),
 'AD': tensor(99.9987),
 'AG': tensor(0.),
 'FidIn': tensor([0.]),
 'SPS': tensor([0.5000], dtype=torch.float64),
 'COMP': tensor([10.6021], dtype=torch.float64)}

In [63]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_saliency_mel.predict(
    wav=clean_mel, wav_name="clean_saliency", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)



{'FF': tensor(0.9941),
 'AI': tensor(0.),
 'AD': tensor(99.5241),
 'AG': tensor(0.),
 'FidIn': tensor([0.]),
 'SPS': tensor([0.5000], dtype=torch.float64),
 'COMP': tensor([9.6823], dtype=torch.float64)}

In [64]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_gradcam_mel.predict(
    wav=clean_mel, wav_name="clean_gradcam", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)



{'FF': tensor(0.9865),
 'AI': tensor(0.),
 'AD': tensor(0.),
 'AG': tensor(0.),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.]),
 'COMP': tensor([0.])}

In [65]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_lime_mel.predict(
    wav=clean_mel, wav_name="clean_lime", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 203.90it/s]




{'FF': tensor(0.9538),
 'AI': tensor(0.),
 'AD': tensor(0.),
 'AG': tensor(0.),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.4832], dtype=torch.float64),
 'COMP': tensor([9.7154], dtype=torch.float64)}

In [66]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_shap_mel.predict(
    wav=clean_mel, wav_name="clean_shap", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)

Done extracting shap values


{'FF': tensor(0.9974),
 'AI': tensor(0.),
 'AD': tensor(99.7335),
 'AG': tensor(0.),
 'FidIn': tensor([0.]),
 'SPS': tensor([0.5000], dtype=torch.float64),
 'COMP': tensor([9.6823], dtype=torch.float64)}

In [67]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_saliency_mel.predict(
    wav=white_mel, wav_name="white_saliency", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)



{'FF': tensor(0.9570),
 'AI': tensor(0.),
 'AD': tensor(0.0030),
 'AG': tensor(0.),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.]),
 'COMP': tensor([0.])}

In [68]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_gradcam_mel.predict(
    wav=white_mel, wav_name="white_gradcam", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)



{'FF': tensor(-0.0271),
 'AI': tensor(0.),
 'AD': tensor(99.8838),
 'AG': tensor(0.),
 'FidIn': tensor([0.]),
 'SPS': tensor([0.4986], dtype=torch.float64),
 'COMP': tensor([9.6852], dtype=torch.float64)}

In [69]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_lime_mel.predict(
    wav=white_mel, wav_name="white_lime", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 249.16it/s]




{'FF': tensor(0.9702),
 'AI': tensor(100.),
 'AD': tensor(0.),
 'AG': tensor(65.3960),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.4845], dtype=torch.float64),
 'COMP': tensor([9.7129], dtype=torch.float64)}

In [70]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_shap_mel.predict(
    wav=white_mel, wav_name="white_shap", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)

Done extracting shap values


{'FF': tensor(0.9700),
 'AI': tensor(0.),
 'AD': tensor(99.9358),
 'AG': tensor(0.),
 'FidIn': tensor([0.]),
 'SPS': tensor([0.5000], dtype=torch.float64),
 'COMP': tensor([9.6823], dtype=torch.float64)}

In [71]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_saliency_mel.predict(
    wav=room_mel, wav_name="room_saliency", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)



{'FF': tensor(0.9865),
 'AI': tensor(0.),
 'AD': tensor(0.),
 'AG': tensor(0.),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.]),
 'COMP': tensor([0.])}

In [72]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_gradcam_mel.predict(
    wav=room_mel, wav_name="room_gradcam", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)



{'FF': tensor(0.9865),
 'AI': tensor(0.),
 'AD': tensor(0.),
 'AG': tensor(0.),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.]),
 'COMP': tensor([0.])}

In [73]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_lime_mel.predict(
    wav=room_mel, wav_name="room_lime", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 242.40it/s]




{'FF': tensor(0.9955),
 'AI': tensor(0.),
 'AD': tensor(2.4134),
 'AG': tensor(0.),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.4821], dtype=torch.float64),
 'COMP': tensor([9.7175], dtype=torch.float64)}

In [74]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_shap_mel.predict(
    wav=room_mel, wav_name="room_shap", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)

Done extracting shap values


{'FF': tensor(0.9991),
 'AI': tensor(0.),
 'AD': tensor(99.9325),
 'AG': tensor(0.),
 'FidIn': tensor([0.]),
 'SPS': tensor([0.5000], dtype=torch.float64),
 'COMP': tensor([9.6823], dtype=torch.float64)}

In [75]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_saliency_mel.predict(
    wav=horse_mel, wav_name="horse_saliency", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)



{'FF': tensor(0.9970),
 'AI': tensor(0.),
 'AD': tensor(99.9815),
 'AG': tensor(0.),
 'FidIn': tensor([0.]),
 'SPS': tensor([0.5000], dtype=torch.float64),
 'COMP': tensor([9.6823], dtype=torch.float64)}

In [76]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_gradcam_mel.predict(
    wav=horse_mel, wav_name="horse_gradcam", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)



{'FF': tensor(0.9838),
 'AI': tensor(100.),
 'AD': tensor(0.),
 'AG': tensor(0.0284),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.]),
 'COMP': tensor([0.])}

In [77]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_lime_mel.predict(
    wav=horse_mel, wav_name="horse_lime", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 243.29it/s]




{'FF': tensor(0.9972),
 'AI': tensor(100.),
 'AD': tensor(0.),
 'AG': tensor(19.3776),
 'FidIn': tensor([1.]),
 'SPS': tensor([0.4895], dtype=torch.float64),
 'COMP': tensor([9.7031], dtype=torch.float64)}

In [78]:
# The mel inputs were built at sr_mel (32 kHz), so pass that rate rather than sr_stft.
predict_shap_mel.predict(
    wav=horse_mel, wav_name="horse_shap", sr=sr_mel, feature_type="mel",
    silence_val=silence_val, fmin=50, fmax=14000,
    save_root="predictions_illust", model_type="mel"
)

Done extracting shap values


{'FF': tensor(0.9972),
 'AI': tensor(0.),
 'AD': tensor(99.9882),
 'AG': tensor(0.),
 'FidIn': tensor([0.]),
 'SPS': tensor([0.5000], dtype=torch.float64),
 'COMP': tensor([9.6823], dtype=torch.float64)}

## Summary figure

In [101]:
def _plot_spec(ax, spec, title, is_mel, sr, fmin, fmax, vmin, vmax, cmap="magma"):
    if isinstance(spec, torch.Tensor):
        spec = spec.detach().cpu().numpy()

    n_bands, n_frames = spec.shape

    if is_mel:
        freqs = librosa.mel_frequencies(n_mels=n_bands, fmin=fmin, fmax=fmax)
    else:
        freqs = np.linspace(0, sr / 2, n_bands)

    im = ax.imshow(spec, aspect="auto", origin="lower", cmap=cmap, vmin=vmin, vmax=vmax)
    yt = np.linspace(0, n_bands - 1, 6, dtype=int)
    ax.set_yticks(yt)
    ax.set_yticklabels([f"{round(freqs[i])}" for i in yt], fontsize=4)
    ax.set_xticks([])
    ax.set_title(title, fontsize=8)
    return im

In [102]:
def apply_mask(inputs, mask, silence_val=None):
    """Blend ``inputs`` toward a floor value according to ``mask``.

    Computes ``(inputs - floor) * mask + floor``: a mask value of 1 keeps the
    input and 0 replaces it with the floor. The floor is ``silence_val`` when
    given, built as a (1, 1, 1, 1) tensor with ``inputs``' dtype and device.
    Otherwise it is the minimum of ``inputs`` over its last three axes.
    """
    if silence_val is not None:
        floor = torch.tensor(
            silence_val, dtype=inputs.dtype, device=inputs.device
        ).view(1, 1, 1, 1)
    else:
        floor = inputs.amin(dim=(-3, -2, -1), keepdim=True)

    return (inputs - floor) * mask + floor

In [103]:
def make_five_row_plot(
    clean_stft, white_stft, room_stft, horse_stft,
    clean_mel, white_mel, room_mel, horse_mel,
    preds_stft, preds_mel, cfg_stft, cfg_mel,
    save_path="figure.png",
    silence_val=-100.,
):
    """Save a grid of original vs. attribution-masked spectrograms.

    Rows: the STFT model on ``clean_stft``, then the mel model on the clean,
    white-noise, room-tone and horse mixtures. Columns: the unmasked
    spectrogram, then the Saliency (top-K 50 positive), Grad-CAM and LIME masks
    (binary and min-max variants) applied via ``apply_mask``. All panels share
    one colour range and a single colorbar.

    Each interpretation is computed exactly once. That result is used both for
    the shared colour range and for drawing. Recomputing it would double the
    cost, and for sampling-based methods such as LIME it could draw masks that
    differ from the ones the range was computed from.

    Args:
        clean_stft, clean_mel, white_mel, room_mel, horse_mel: waveforms to plot.
        white_stft, room_stft, horse_stft: accepted for call compatibility;
            not plotted.
        preds_stft, preds_mel: dicts mapping "Saliency", "GradCAM" and "LIME"
            (other keys are ignored) to objects exposing ``feature_extractor``,
            ``device`` and ``interpretator.interpret(spec, ret_masks=True)``.
        cfg_stft, cfg_mel: keyword arguments forwarded to ``_plot_spec``
            (``sr``, ``fmin``, ``fmax``, ``is_mel``).
        save_path: output image path.
        silence_val: floor written into masked-out bins. The default matches the
            notebook's ``silence_val`` setting, which this function used to read
            as a global.
    """
    mask_names = {
        "Saliency-topK50": ("Saliency", "topK_50_pos"),
        "GradCAM-bin": ("GradCAM", "bin"),
        "GradCAM-minmax": ("GradCAM", "minmax"),
        "LIME-bin": ("LIME", "bin"),
        "LIME-minmax": ("LIME", "minmax"),
    }
    col_order = ["Original"] + list(mask_names)
    rows = [
        ("STFT clean", clean_stft, preds_stft, cfg_stft),
        ("Mel clean", clean_mel, preds_mel, cfg_mel),
        ("Mel white", white_mel, preds_mel, cfg_mel),
        ("Mel room", room_mel, preds_mel, cfg_mel),
        ("Mel horse", horse_mel, preds_mel, cfg_mel),
    ]

    # panels[r] is a list of (panel title, 2-D spectrogram), one per column.
    panels = []
    for row_lbl, wav_row, pred_dict, _ in rows:
        pred_any = next(iter(pred_dict.values()))
        spec_db, *_ = pred_any.feature_extractor(
            wav_row.to(pred_any.device).unsqueeze(0))
        row_panels = [(row_lbl, spec_db[0][0].detach().cpu())]

        for col in col_order[1:]:
            method, mask_key = mask_names[col]
            pred = pred_dict[method]
            spec_db, *_ = pred.feature_extractor(
                wav_row.to(pred.device).unsqueeze(0))
            _, masks = pred.interpretator.interpret(spec_db, ret_masks=True)
            masked_db = apply_mask(spec_db, masks[mask_key], silence_val=silence_val)
            # Detach/move to CPU: only needed for plotting, frees device memory.
            row_panels.append((col, masked_db[0][0].detach().cpu()))
        panels.append(row_panels)

    all_specs = [spec for row_panels in panels for _, spec in row_panels]
    global_min = min(spec.min().item() for spec in all_specs)
    global_max = max(spec.max().item() for spec in all_specs)

    fig, axes = plt.subplots(len(rows), len(col_order),
                             figsize=(20, 12), dpi=300)
    try:
        for r, (row, row_panels) in enumerate(zip(rows, panels)):
            cfg = row[3]
            for c, (title, spec) in enumerate(row_panels):
                _plot_spec(axes[r, c], spec, title,
                           **cfg, vmin=global_min, vmax=global_max)

        fig.colorbar(
            axes[0, 0].images[0], ax=axes.ravel().tolist(),
            shrink=0.6, label="Energy (dB)"
        )
        fig.savefig(save_path)
    finally:
        # Always release the large 300-dpi figure, even if drawing fails.
        plt.close(fig)
    print(f"Saved {save_path}")

In [104]:
predict_saliency_stft = Predict(model_stft, feature_extractor_stft, interp_method_cls=SaliencyInterpreter, interp_method_kwargs={}, device=device)
predict_gradcam_stft = Predict(model_stft, feature_extractor_stft, interp_method_cls=GradCAMInterpreter, interp_method_kwargs={"target_layers": [model_stft.base.conv_block6.conv2]}, device=device)
predict_lime_stft = Predict(model_stft, feature_extractor_stft, interp_method_cls=LIMEInterpreter, interp_method_kwargs={"num_samples": 1000}, device=device)
predict_shap_stft = Predict(model_stft, feature_extractor_stft, interp_method_cls=SHAPInterpreter, interp_method_kwargs={"background_data": shap_background_stft}, device=device)

In [105]:
predict_saliency_mel = Predict(model_mel, feature_extractor_mel, interp_method_cls=SaliencyInterpreter, interp_method_kwargs={}, device=device)
predict_gradcam_mel = Predict(model_mel, feature_extractor_mel, interp_method_cls=GradCAMInterpreter, interp_method_kwargs={"target_layers": [model_mel.base.conv_block6.conv2]}, device=device)
predict_lime_mel = Predict(model_mel, feature_extractor_mel, interp_method_cls=LIMEInterpreter, interp_method_kwargs={"num_samples": 1000}, device=device)
predict_shap_mel = Predict(model_mel, feature_extractor_mel, interp_method_cls=SHAPInterpreter, interp_method_kwargs={"background_data": shap_background_mel}, device=device)

In [106]:
preds_stft = {"Saliency": predict_saliency_stft, "GradCAM": predict_gradcam_stft, "LIME": predict_lime_stft, "SHAP": predict_shap_stft}
preds_mel  = {"Saliency": predict_saliency_mel,  "GradCAM": predict_gradcam_mel,  "LIME": predict_lime_mel, "SHAP": predict_shap_mel}

In [107]:
make_five_row_plot(
    clean_stft=clean_stft,
    white_stft=white_stft,
    room_stft=room_stft,
    horse_stft=horse_stft,
    clean_mel=clean_mel,
    white_mel=white_mel,
    room_mel=room_mel,
    horse_mel=horse_mel,
    preds_stft= preds_stft,
    preds_mel= preds_mel,
    cfg_stft= dict(sr=16000, fmin=0,  fmax=8000, is_mel=False),
    cfg_mel= dict(sr=32000, fmin=50,  fmax=14000, is_mel=True),
    save_path = "figure.png",
)

100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:13<00:00, 73.41it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:13<00:00, 73.29it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 215.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 215.31it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 248.92it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 248.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 242.14it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 242.03it/s]
100%|███████████████████████████

Saved figure.png
