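# Run interpretation experiments: CNN14 on log-STFT features (ESC-50)

Loads the trained `TransferCnn14` checkpoint (`logstft_cnn14.pth`) and checks its accuracy on the test fold. It then computes saliency, Grad-CAM, LIME and SHAP attributions for every test clip and saves them as CSVs under `results/cnn14_logstft/`.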

In [1]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T_audio
import torchvision.transforms as T_vision
from IPython.display import Audio
from torch.utils.data import DataLoader

from audiointerp.dataset.esc50 import ESC50dataset
from audiointerp.fit import Trainer
from audiointerp.interpretation.gradcam import GradCAMInterpreter
from audiointerp.interpretation.lime import LIMEInterpreter
from audiointerp.interpretation.saliency import SaliencyInterpreter
from audiointerp.interpretation.shap import SHAPInterpreter
from audiointerp.metrics import Metrics
from audiointerp.model.cnn14 import TransferCnn14
from audiointerp.predict import Predict
from audiointerp.processing.spectrogram import LogSTFTSpectrogram

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset location. Override it with the ESC50_ROOT environment variable instead of
# editing this cell; the default is the path used for the recorded outputs.
root_dir = os.environ.get("ESC50_ROOT", "/root/ESC50")
sr = 16000  # sample rate passed to ESC50dataset(sr=...)

# Fold ids for each split. They must not overlap, or test clips leak into training.
train_folds = [1, 2, 3]
valid_folds = [4]
test_folds = [5]
assert len(set(train_folds + valid_folds + test_folds)) == len(train_folds) + len(valid_folds) + len(test_folds), \
    "train/valid/test folds must be disjoint"

In [3]:
# STFT parameters shared by every LogSTFTSpectrogram built in this notebook.
n_fft = 512               # FFT size
win_length = n_fft        # analysis window spans the whole FFT frame
hop_length = n_fft // 2   # hop of half a window, i.e. 50% frame overlap
top_db = None             # forwarded to LogSTFTSpectrogram; None presumably means no dB floor clipping — confirm

In [4]:
# Feature extractor used by the evaluation dataset. The phase output
# (return_phase) and full-dB output (return_full_db) are both switched off here.
feature_extractor = LogSTFTSpectrogram(
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    top_db=top_db,
    return_phase=False,
    return_full_db=False,
)

In [5]:
# Held-out test split ("peak" normalization); features come from `feature_extractor`.
test_data = ESC50dataset(
    root_dir=root_dir,
    sr=sr,
    folds=test_folds,
    normalize="peak",
    feature_extractor=feature_extractor,
)
test_loader_kwargs = dict(batch_size=32, shuffle=False)

In [6]:
# Use the second GPU when it exists (this matches the recorded run). Otherwise fall
# back to the default GPU, then to CPU, so the notebook also runs on smaller machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_cls = TransferCnn14
# num_bins follows the STFT config instead of a magic 257. This assumes
# LogSTFTSpectrogram returns a one-sided STFT with n_fft // 2 + 1 bins
# (257 for n_fft=512) — confirm.
model_kwargs = {"num_classes": 50, "num_bins": n_fft // 2 + 1}
model_pretrain_weights_path = "weights/Cnn14_mAP=0.431.pth"

optimizer_cls = optim.Adam
optimizer_kwargs = {"lr": 1e-4}

criterion_cls = nn.CrossEntropyLoss
use_mixup = False
mixup_alpha = 0.0

In [7]:
# The Trainer is used for evaluation only in the cells shown: no train/valid
# data is supplied, and only test() is called below.
model_trainer = Trainer(
    # model
    model_cls=model_cls,
    model_kwargs=model_kwargs,
    model_pretrain_weights_path=model_pretrain_weights_path,
    device=device,
    # optimisation
    criterion_cls=criterion_cls,
    optimizer_cls=optimizer_cls,
    optimizer_kwargs=optimizer_kwargs,
    use_mixup=use_mixup,
    mixup_alpha=mixup_alpha,
    # data
    train_data=None,
    train_loader_kwargs=None,
    valid_data=None,
    valid_loader_kwargs=None,
    test_data=test_data,
    test_loader_kwargs=test_loader_kwargs,
)

Random seed set to: 42


In [8]:
# Load the trained weights. map_location remaps the stored tensors onto `device`,
# so a checkpoint saved on a different GPU (or on CPU) still loads here.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
model_trainer.model.load_state_dict(torch.load("logstft_cnn14.pth", map_location=device))

<All keys matched successfully>

In [9]:
# Sanity-check the restored model on the test fold before interpreting it.
# The metrics are bound to names so later cells don't need `_` / Out[...].
test_loss, test_acc = model_trainer.test()
test_loss, test_acc

Test Loss: 0.8443, Test Acc: 0.7800


(0.8443098521232605, 0.78)

In [10]:
# Switch the model to inference mode before computing attributions, so layers
# such as dropout/BatchNorm behave deterministically. nn.Module.eval() returns
# the module itself.
model = model_trainer.model.eval()

___

## Interpretation: saliency, Grad-CAM, LIME and SHAP on the test fold

In [11]:
# Fill value passed to predict_set(silence_val=...). Presumably this is the
# log-spectrogram (dB) level used for "silenced" regions — confirm against Predict.
silence_val = -100.0

In [12]:
# The SHAP background must come from the training folds only (never valid/test)
# to avoid leakage. Deriving it from train_folds keeps the two in sync if the split changes.
shap_background_folds = list(train_folds)

In [13]:
def get_balanced_background(dataloader, num_samples_per_class=2, device="cpu", num_classes=None):
    """Collect up to ``num_samples_per_class`` inputs per label as a background batch.

    Iterates ``dataloader`` and keeps the first ``num_samples_per_class`` inputs
    seen for each label. The kept inputs are concatenated, grouped by label, in
    the order each label is first encountered.

    Args:
        dataloader: Iterable yielding ``(batch_x, batch_y)`` pairs. Each element
            of ``batch_y`` must support ``.item()`` (scalar labels).
        num_samples_per_class: Maximum number of inputs kept per label.
        device: Device the returned tensor is moved to.
        num_classes: Optional total number of labels. When given, iteration stops
            after the first batch at which every label has reached
            ``num_samples_per_class`` inputs, instead of scanning the whole loader.
            If the value is smaller than the real number of labels, some labels
            may be missing.

    Returns:
        Tensor of shape ``(n_collected, *batch_x.shape[1:])`` on ``device``.

    Raises:
        ValueError: If nothing was collected (empty loader, or
            ``num_samples_per_class <= 0``).
    """
    from collections import defaultdict

    class_to_samples = defaultdict(list)
    num_full_classes = 0

    for batch_x, batch_y in dataloader:
        for x, y in zip(batch_x, batch_y):
            samples = class_to_samples[y.item()]
            if len(samples) < num_samples_per_class:
                # clone(): a bare slice is a view that would keep the whole
                # batch's storage alive until the final torch.cat.
                samples.append(x.unsqueeze(0).clone())
                if len(samples) == num_samples_per_class:
                    num_full_classes += 1
        # Stop at a batch boundary once every label is full. This avoids loading
        # and featurizing the rest of the dataset.
        if num_classes is not None and num_full_classes >= num_classes:
            break

    background_tensors = [t for samples in class_to_samples.values() for t in samples]
    if not background_tensors:
        raise ValueError(
            "No background samples collected: empty dataloader or num_samples_per_class <= 0."
        )
    return torch.cat(background_tensors, dim=0).to(device)

In [14]:
# Second extractor, used by Predict. It has the same STFT parameters as
# `feature_extractor` but also returns phase and the full-dB spectrogram
# (presumably so masked spectrograms can be turned back into audio — confirm in Predict).
feature_extractor_predict = LogSTFTSpectrogram(
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    top_db=top_db,
    return_phase=True,
    return_full_db=True,
)

In [15]:
# Test clips without a feature extractor, fed one at a time. Predict receives
# `feature_extractor_predict` and presumably featurizes them itself.
test_data_predict = ESC50dataset(
    root_dir=root_dir, sr=sr, folds=test_folds, normalize="peak",
)
test_loader_predict = DataLoader(dataset=test_data_predict, batch_size=1, shuffle=False)

# SHAP background: featurized clips from the background folds, 2 per label, on `device`.
train_data_shap = ESC50dataset(
    root_dir=root_dir, sr=sr, folds=shap_background_folds, normalize="peak",
    feature_extractor=feature_extractor,
)
train_loader_shap = DataLoader(dataset=train_data_shap, batch_size=100, shuffle=False)
shap_background = get_balanced_background(
    train_loader_shap, num_samples_per_class=2, device=device,
)

In [16]:
# One Predict wrapper per attribution method. All of them share the model,
# `feature_extractor_predict` and the device.
predict_saliency = Predict(
    model, feature_extractor_predict,
    interp_method_cls=SaliencyInterpreter,
    interp_method_kwargs=dict(),
    device=device,
)
predict_gradcam = Predict(
    model, feature_extractor_predict,
    interp_method_cls=GradCAMInterpreter,
    # Grad-CAM target: conv2 of the backbone's conv_block6 (presumably the last conv layer).
    interp_method_kwargs=dict(target_layers=[model.base.conv_block6.conv2]),
    device=device,
)
predict_lime = Predict(
    model, feature_extractor_predict,
    interp_method_cls=LIMEInterpreter,
    interp_method_kwargs=dict(num_samples=1000),
    device=device,
)
predict_shap = Predict(
    model, feature_extractor_predict,
    interp_method_cls=SHAPInterpreter,
    interp_method_kwargs=dict(background_data=shap_background),
    device=device,
)

In [17]:
# Saliency attributions for every test clip, saved under results/cnn14_logstft/.
results_saliency = predict_saliency.predict_set(
    test_loader_predict,
    'saliency_clean.csv',
    compute_first=True,
    silence_val=silence_val,
    model_type="cnn14_logstft",
    save_dir="results",
)

Все CSV-файлы сохранены в results/cnn14_logstft/saliency_clean/csvs


In [18]:
# Grad-CAM attributions for every test clip, saved under results/cnn14_logstft/.
results_gradcam = predict_gradcam.predict_set(
    test_loader_predict,
    'gradcam_clean.csv',
    compute_first=True,
    silence_val=silence_val,
    model_type="cnn14_logstft",
    save_dir="results",
)

Все CSV-файлы сохранены в results/cnn14_logstft/gradcam_clean/csvs


In [19]:
# LIME attributions for every test clip (1000 perturbation samples each),
# saved under results/cnn14_logstft/.
results_lime = predict_lime.predict_set(
    test_loader_predict,
    'lime_clean.csv',
    compute_first=True,
    silence_val=silence_val,
    model_type="cnn14_logstft",
    save_dir="results",
)

100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:11<00:00, 86.97it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:15<00:00, 64.95it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 69.76it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:16<00:00, 60.80it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 83.01it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:17<00:00, 58.08it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:11<00:00, 87.95it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:11<00:00, 87.01it/s]
100%|███████████████████████████

Все CSV-файлы сохранены в results/cnn14_logstft/lime_clean/csvs


In [None]:
# SHAP attributions for every test clip (against the balanced training
# background), saved under results/cnn14_logstft/.
results_shap = predict_shap.predict_set(
    test_loader_predict,
    'shap_clean.csv',
    compute_first=True,
    silence_val=silence_val,
    model_type="cnn14_logstft",
    save_dir="results",
)

Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap values
Done extracting shap