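# Run experiments

Fine-tune a pretrained Cnn14 (`TransferCnn14`) on ESC-50 log-mel spectrograms (folds 1–3 train, fold 4 validation, fold 5 test), then score saliency and Grad-CAM attributions on the clean test fold and on the same fold contaminated with sea-wave noise (`samples/sea_waves.wav`).

**Summary:** test accuracy drops from 0.92 on clean audio to 0.61 with sea-wave noise.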

In [1]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T_audio
import torchvision.transforms as T_vision
from IPython.display import Audio
from torch.utils.data import DataLoader

from audiointerp.dataset.esc50 import ESC50dataset, ESC50contaminated
from audiointerp.fit import Trainer
from audiointerp.interpretation.gradcam import GradCAMInterpreter
from audiointerp.interpretation.saliency import SaliencyInterpreter
from audiointerp.metrics import Metrics
from audiointerp.model.cnn14 import TransferCnn14
from audiointerp.predict import Predict
from audiointerp.processing.spectrogram import LogMelSTFTSpectrogram

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
TRAINING: bool = False  # True: fine-tune and write "logmel_cnn14.pth"; False: only load that checkpoint

In [3]:
def plot_learning_curves(train_losses, val_losses, train_accs=None, val_accs=None):
    """Plot per-epoch loss (left panel) and accuracy (right panel) curves.

    Args:
        train_losses: Training loss per epoch; its length defines the epoch
            axis ``1 .. len(train_losses)``.
        val_losses: Validation loss per epoch (same length as
            ``train_losses``), or ``None`` / empty to omit the curve.
        train_accs: Optional training accuracy per epoch.
        val_accs: Optional validation accuracy per epoch.

    Every optional series is skipped when it is ``None`` or empty, so empty
    lists and NumPy arrays are handled the same way. A legend is drawn only
    on panels that actually contain labelled curves.
    """
    epochs = range(1, len(train_losses) + 1)

    def _has_values(series):
        # len() instead of truthiness: `if array:` raises for NumPy arrays.
        return series is not None and len(series) > 0

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

    ax_loss.plot(epochs, train_losses, label="Train Loss")
    if _has_values(val_losses):
        ax_loss.plot(epochs, val_losses, label="Val Loss")

    if _has_values(train_accs):
        ax_acc.plot(epochs, train_accs, label="Train Acc")
    if _has_values(val_accs):
        ax_acc.plot(epochs, val_accs, label="Val Acc")

    for ax, ylabel, title in ((ax_loss, "Loss", "Loss Curve"),
                              (ax_acc, "Accuracy", "Accuracy Curve")):
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # Skip the legend on a panel with no labelled artists (e.g. when no
        # accuracies were passed) to avoid matplotlib's empty-legend warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    fig.tight_layout()
    plt.show()

In [4]:
# ESC-50 location. Set the ESC50_ROOT environment variable on other machines
# instead of editing this machine-specific absolute default.
root_dir = os.environ.get("ESC50_ROOT", "/home/yuliya/ESC50")
sr = 32000  # sample rate (Hz) used by the datasets and the log-mel front end
# Fold split: 3 folds for training, 1 for validation, 1 held out for testing.
train_folds = [1, 2, 3]
valid_folds = [4]
test_folds = [5]
assert len({*train_folds, *valid_folds, *test_folds}) == (
    len(train_folds) + len(valid_folds) + len(test_folds)
), "train/valid/test folds overlap"

In [5]:
# Spectrogram settings shared by both LogMelSTFTSpectrogram instances below.
# At 32 kHz a 1024-sample window is ~32 ms and a 320-sample hop is 10 ms.
# They appear to mirror the pretrained Cnn14 front end; keep them in sync
# with the checkpoint (TODO confirm against the weights' training config).
n_fft, win_length, hop_length = 1024, 1024, 320
n_mels = 64               # mel bands; must equal model_kwargs["num_bins"]
f_min, f_max = 50, 14000  # mel filterbank frequency range in Hz
top_db = 80               # presumably the dB dynamic-range clamp; confirm in LogMelSTFTSpectrogram

In [6]:
# Front end for training / evaluation datasets: log-mel magnitude only.
feature_extractor_fit = LogMelSTFTSpectrogram(
    sample_rate=sr,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    n_mels=n_mels,
    f_min=f_min,
    f_max=f_max,
    top_db=top_db,
    return_phase=False,
    return_full_db=False,
)

In [7]:
# Front end handed to Predict for interpretation. Same spectrogram settings as
# the training front end, but it also returns phase and the full dB spectrogram
# (presumably so attributions can be mapped back to audio; confirm in
# LogMelSTFTSpectrogram).
feature_extractor_predict = LogMelSTFTSpectrogram(
    sample_rate=sr,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    n_mels=n_mels,
    f_min=f_min,
    f_max=f_max,
    top_db=top_db,
    return_phase=True,
    return_full_db=True,
)

In [8]:
# SpecAugment-style masking, used only by the training dataset: one frequency
# mask and one time mask, each with mask parameter 20 (bins / frames).
feature_augs = nn.Sequential(
    *(
        T_audio.FrequencyMasking(freq_mask_param=20),
        T_audio.TimeMasking(time_mask_param=20),
    )
)

In [9]:
# All splits share location, sample rate, peak normalisation and the training
# front end; they differ only in folds. Only the training split is augmented.
# The noisy test split is fold 5 contaminated with samples/sea_waves.wav.
_esc50_common_kwargs = dict(
    root_dir=root_dir,
    sr=sr,
    normalize="peak",
    feature_extractor=feature_extractor_fit,
)
train_data = ESC50dataset(folds=train_folds, feature_augs=feature_augs, **_esc50_common_kwargs)
valid_data = ESC50dataset(folds=valid_folds, **_esc50_common_kwargs)
test_data = ESC50dataset(folds=test_folds, **_esc50_common_kwargs)
test_data_noisy = ESC50contaminated(
    folds=test_folds,
    path_to_contaminating_audio="samples/sea_waves.wav",
    **_esc50_common_kwargs,
)

In [10]:
# DataLoader settings: batch size 32 everywhere; only the training loader shuffles.
train_loader_kwargs = dict(batch_size=32, shuffle=True)
valid_loader_kwargs = dict(batch_size=32, shuffle=False)
test_loader_kwargs = dict(batch_size=32, shuffle=False)

In [11]:
# --- model ------------------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_cls = TransferCnn14
model_kwargs = dict(num_classes=50, num_bins=64)  # num_bins matches n_mels
model_pretrain_weights_path = "weights/Cnn14_mAP=0.431.pth"

# --- optimisation -----------------------------------------------------------
optimizer_cls, optimizer_kwargs = optim.Adam, dict(lr=1e-4)
criterion_cls = nn.CrossEntropyLoss

# --- mixup ------------------------------------------------------------------
use_mixup, mixup_alpha = True, 0.2

In [12]:
# The Trainer receives everything it needs: model class + pretrained weights,
# the three data splits with their loader settings, loss/optimizer, and mixup.
# (It also seeds the RNGs; see the "Random seed set to" output.)
model_trainer = Trainer(
    # model
    model_cls=model_cls,
    model_kwargs=model_kwargs,
    model_pretrain_weights_path=model_pretrain_weights_path,
    device=device,
    # data
    train_data=train_data,
    train_loader_kwargs=train_loader_kwargs,
    valid_data=valid_data,
    valid_loader_kwargs=valid_loader_kwargs,
    test_data=test_data,
    test_loader_kwargs=test_loader_kwargs,
    # optimisation
    criterion_cls=criterion_cls,
    optimizer_cls=optimizer_cls,
    optimizer_kwargs=optimizer_kwargs,
    use_mixup=use_mixup,
    mixup_alpha=mixup_alpha,
)

Random seed set to: 42


In [13]:
# Fine-tune for 20 epochs and save the weights (only when TRAINING is set).
if TRAINING:
    (train_losses, train_accs,
     val_losses, val_accs,
     test_loss, test_acc) = model_trainer.train(
        num_epochs=20,
        save_weights_path="logmel_cnn14.pth",
    )

In [14]:
# Learning curves of the run above; the histories only exist when TRAINING is set.
if TRAINING:
    plot_learning_curves(train_losses, val_losses, train_accs=train_accs, val_accs=val_accs)

In [15]:
# Load the fine-tuned checkpoint written by the training cell. map_location
# remaps tensors that were saved from a GPU onto `device`, so this also works on
# a CPU-only machine. (torch.load unpickles: only load checkpoints you trust.)
model_trainer.model.load_state_dict(torch.load("logmel_cnn14.pth", map_location=device))

<All keys matched successfully>

In [16]:
# Evaluate the loaded checkpoint. With no loader argument this presumably uses
# the test_data given to the Trainer (fold 5, clean); prints and returns (loss, acc).
model_trainer.test()

Test Loss: 0.3285, Test Acc: 0.9175


(0.32851228475570676, 0.9175)

In [17]:
# Loader for the sea-wave-contaminated test fold, using the same test_loader_kwargs as the clean test set.
test_loader_noisy = DataLoader(dataset=test_data_noisy, **test_loader_kwargs)

In [18]:
# Same evaluation on fold 5 contaminated with samples/sea_waves.wav, for
# comparison with the clean-test result; prints and returns (loss, acc).
model_trainer.test(test_loader_noisy)

Test Loss: 1.5115, Test Acc: 0.6075


(1.5115180492401123, 0.6075)

In [19]:
# The fine-tuned network used by the interpretation steps below. Displaying it
# prints the full layer tree, which is where layer paths such as
# `base.conv_block6.conv2` (the Grad-CAM target) can be looked up.
model = model_trainer.model
model

TransferCnn14(
  (base): Cnn14(
    (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_block1): ConvBlock(
      (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_block2): ConvBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_block3): ConvBlock(
      (conv1): Conv2d(128, 

___

In [20]:
# Clean fold 5 for interpretation. No feature_extractor here, because Predict is
# given feature_extractor_predict (presumably applied to raw waveforms; confirm).
# One clip per batch, in a fixed order.
test_data_predict = ESC50dataset(root_dir=root_dir, sr=sr, normalize="peak", folds=test_folds)
test_loader_predict = DataLoader(dataset=test_data_predict, shuffle=False, batch_size=1)

In [21]:
# Noisy counterpart of the interpretation set: fold 5 + samples/sea_waves.wav,
# again without a feature extractor and with one clip per batch.
test_data_noisy_predict = ESC50contaminated(
    root_dir=root_dir,
    sr=sr,
    normalize="peak",
    folds=test_folds,
    path_to_contaminating_audio="samples/sea_waves.wav",
)
test_loader_noisy_predict = DataLoader(dataset=test_data_noisy_predict, shuffle=False, batch_size=1)

In [22]:
# One Predict wrapper per attribution method, sharing the model and the
# phase-returning front end. Grad-CAM is attached to conv_block6.conv2.
predict_saliency = Predict(
    model,
    feature_extractor_predict,
    device=device,
    interp_method_cls=SaliencyInterpreter,
    interp_method_kwargs={},
)
predict_gradcam = Predict(
    model,
    feature_extractor_predict,
    device=device,
    interp_method_cls=GradCAMInterpreter,
    interp_method_kwargs={"target_layers": [model.base.conv_block6.conv2]},
)

In [23]:
# Per-sample saliency metrics on the clean test fold; also saved to results/saliency_clean.csv.
results_saliency = predict_saliency.predict_set(
    test_loader_predict,
    results_csv_name='saliency_clean.csv',
    compute_first=True,
)

Results saved as results/saliency_clean.csv


In [28]:
# First 10 per-sample rows of the clean saliency metrics.
results_saliency.iloc[:10]

Unnamed: 0_level_0,minmax,minmax,minmax,minmax,minmax,minmax,minmax,minmax,sigmoid,sigmoid,...,topK_30_pos,topK_30_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos
Unnamed: 0_level_1,FF,AI,AD,AG,FidIn,SPS,COMP,is_correct,FF,AI,...,COMP,is_correct,FF,AI,AD,AG,FidIn,SPS,COMP,is_correct
sample,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
0,0.81126,0.0,66.150238,0.0,1.0,0.529191,9.880007,True,0.262897,0.0,...,9.880007,True,0.813854,0.0,0.000482,0.0,1.0,0.529191,9.880007,True
1,0.04439,100.0,0.0,32.084843,1.0,0.773528,9.174465,False,0.046804,100.0,...,9.174465,False,0.071863,0.0,96.673141,0.0,0.0,0.773528,9.174465,False
2,0.41189,0.0,37.338764,0.0,1.0,0.729575,9.340622,False,0.358921,0.0,...,9.340622,False,0.44235,100.0,0.0,0.002761,1.0,0.729575,9.340622,False
3,0.445009,0.0,95.43322,0.0,0.0,0.838474,8.783148,False,0.121926,0.0,...,8.783148,False,0.446877,0.0,98.4692,0.0,0.0,0.838474,8.783148,False
4,0.758913,0.0,3.211129,0.0,1.0,0.55311,9.825275,True,0.443932,0.0,...,9.825275,True,0.751364,0.0,0.002733,0.0,1.0,0.55311,9.825275,True
5,0.451751,0.0,95.179344,0.0,0.0,0.868279,8.516842,True,0.40131,0.0,...,8.516842,True,0.448153,100.0,0.0,0.005454,1.0,0.868279,8.516842,True
6,0.53497,0.0,99.137337,0.0,0.0,0.557085,9.828486,True,0.155952,0.0,...,9.828486,True,0.508947,0.0,0.001972,0.0,1.0,0.557085,9.828486,True
7,0.877622,0.0,74.332932,0.0,1.0,0.524845,9.898358,True,0.315234,0.0,...,9.898358,True,0.908337,0.0,1.3e-05,0.0,1.0,0.524845,9.898358,True
8,0.604776,0.0,95.71582,0.0,0.0,0.574951,9.786975,True,0.085,0.0,...,9.786975,True,0.604978,0.0,99.909546,0.0,0.0,0.574951,9.786975,True
9,0.306725,0.0,61.044971,0.0,1.0,0.831246,8.810982,True,0.040831,0.0,...,8.810982,True,0.350312,0.0,0.000293,0.0,1.0,0.831246,8.810982,True


In [31]:
# Split clean saliency metrics by prediction correctness. Each top-level group
# (minmax, sigmoid, topK_…) carries its own is_correct column; check that they
# all agree with the minmax one, then drop every is_correct column (not only
# the minmax one) so the split frames contain metrics only.
is_corr = results_saliency[('minmax', 'is_correct')]
assert (results_saliency.xs('is_correct', axis=1, level=1)
        .eq(is_corr, axis=0).to_numpy().all()), "per-group is_correct flags disagree"
results_saliency_correct   = results_saliency[is_corr].drop(columns='is_correct', level=1)
results_saliency_incorrect = results_saliency[~is_corr].drop(columns='is_correct', level=1)

In [32]:
# Mean/std of each numeric metric column (bool is_correct columns are skipped by describe) — clean saliency, all samples.
results_saliency.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.758135,0.220415
minmax,AI,23.500000,42.452980
minmax,AD,34.437916,35.617657
minmax,AG,12.198810,25.076937
minmax,FidIn,0.782500,0.413062
...,...,...,...
topK_50_pos,AD,34.593307,46.926758
topK_50_pos,AG,0.001995,0.004434
topK_50_pos,FidIn,0.650000,0.477567
topK_50_pos,SPS,0.563807,0.081573


In [33]:
# Mean/std of each metric column — clean saliency, correctly classified samples only.
results_saliency_correct.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.791308,0.190732
minmax,AI,22.343325,41.711498
minmax,AD,34.177132,35.504986
minmax,AG,11.677388,24.711126
minmax,FidIn,0.792916,0.405770
...,...,...,...
topK_50_pos,AD,32.404835,46.230015
topK_50_pos,AG,0.002066,0.004499
topK_50_pos,FidIn,0.673025,0.469749
topK_50_pos,SPS,0.561863,0.078915


In [34]:
# Mean/std of each metric column — clean saliency, misclassified samples only.
results_saliency_incorrect.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.389209,0.191586
minmax,AI,36.363636,48.850418
minmax,AD,37.338108,37.288876
minmax,AG,17.997648,28.618521
minmax,FidIn,0.666667,0.478714
...,...,...,...
topK_50_pos,AD,58.931789,48.451023
topK_50_pos,AG,0.001196,0.003589
topK_50_pos,FidIn,0.393939,0.496198
topK_50_pos,SPS,0.585423,0.105914


In [23]:
# Per-sample Grad-CAM metrics on the clean test fold; also saved to results/gradcam_clean.csv.
results_gradcam = predict_gradcam.predict_set(
    test_loader_predict,
    results_csv_name='gradcam_clean.csv',
    compute_first=True,
)

Results saved as results/gradcam_clean.csv


In [24]:
# First 20 per-sample rows of the clean Grad-CAM metrics.
results_gradcam.iloc[:20]

Unnamed: 0_level_0,minmax,minmax,minmax,minmax,minmax,minmax,minmax,minmax,sigmoid,sigmoid,...,topK_30_pos,topK_30_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos
Unnamed: 0_level_1,FF,AI,AD,AG,FidIn,SPS,COMP,is_correct,FF,AI,...,COMP,is_correct,FF,AI,AD,AG,FidIn,SPS,COMP,is_correct
sample,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
0,0.589352,0.0,96.262627,0.0,0.0,0.490846,9.969327,True,0.55191,0.0,...,9.969327,True,0.313187,0.0,58.799671,0.0,1.0,0.490846,9.969327,True
1,3e-06,0.0,70.103493,0.0,0.0,0.0,0.0,False,0.026037,0.0,...,0.0,False,0.053047,0.0,0.004332,0.0,1.0,0.0,0.0,False
2,0.150595,0.0,89.507118,0.0,0.0,0.388541,10.11932,False,0.330146,0.0,...,10.11932,False,-0.019665,0.0,52.115631,0.0,1.0,0.388541,10.11932,False
3,-6e-06,0.0,95.054153,0.0,0.0,0.93457,7.924671,False,0.073101,0.0,...,7.924671,False,0.431718,100.0,0.0,0.001026,1.0,0.93457,7.924671,False
4,0.5265,0.0,97.205643,0.0,0.0,0.496131,9.925839,True,0.519146,0.0,...,9.925839,True,0.422681,0.0,21.207138,0.0,1.0,0.496131,9.925839,True
5,0.371266,0.0,96.574623,0.0,0.0,0.858366,8.590769,True,0.407744,0.0,...,8.590769,True,0.448153,100.0,0.0,0.005454,1.0,0.858366,8.590769,True
6,-0.02535,0.0,93.879433,0.0,0.0,0.871608,8.588024,True,0.113436,0.0,...,8.588024,True,0.508947,0.0,0.001972,0.0,1.0,0.871608,8.588024,True
7,0.558502,0.0,48.779224,0.0,1.0,0.659537,9.537663,True,0.551851,0.0,...,9.537663,True,0.437498,0.0,6.35289,0.0,1.0,0.659537,9.537663,True
8,0.132404,0.0,94.643158,0.0,0.0,0.657255,9.541927,True,0.064075,0.0,...,9.541927,True,0.124052,0.0,84.18119,0.0,0.0,0.657255,9.541927,True
9,0.128112,0.0,94.487564,0.0,0.0,0.863488,8.560882,True,0.198542,100.0,...,8.560882,True,0.350312,0.0,0.000293,0.0,1.0,0.863488,8.560882,True


In [25]:
# Split clean Grad-CAM metrics by prediction correctness. Each top-level group
# carries its own is_correct column; check that they all agree with the minmax
# one, then drop every is_correct column so the split frames hold metrics only.
is_corr = results_gradcam[('minmax', 'is_correct')]
assert (results_gradcam.xs('is_correct', axis=1, level=1)
        .eq(is_corr, axis=0).to_numpy().all()), "per-group is_correct flags disagree"
results_gradcam_correct   = results_gradcam[is_corr].drop(columns='is_correct', level=1)
results_gradcam_incorrect = results_gradcam[~is_corr].drop(columns='is_correct', level=1)

In [26]:
# Mean/std of each numeric metric column — clean Grad-CAM, all samples.
results_gradcam.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.166924,0.235897
minmax,AI,1.250000,11.124157
minmax,AD,69.282318,32.452690
minmax,AG,0.574106,6.282712
minmax,FidIn,0.460000,0.499022
...,...,...,...
topK_50_pos,AD,25.077852,31.316759
topK_50_pos,AG,1.011616,6.974795
topK_50_pos,FidIn,0.887500,0.316376
topK_50_pos,SPS,0.556569,0.244833


In [27]:
# Mean/std of each metric column — clean Grad-CAM, correctly classified samples only.
results_gradcam_correct.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.177787,0.241062
minmax,AI,0.817439,9.016495
minmax,AD,68.352776,32.888042
minmax,AG,0.315252,4.736497
minmax,FidIn,0.490463,0.500592
...,...,...,...
topK_50_pos,AD,24.521065,30.886446
topK_50_pos,AG,0.873674,6.238208
topK_50_pos,FidIn,0.904632,0.294123
topK_50_pos,SPS,0.556951,0.238237


In [28]:
# Mean/std of each metric column — clean Grad-CAM, misclassified samples only.
results_gradcam_incorrect.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.046119,0.113185
minmax,AI,6.060606,24.230585
minmax,AD,79.619919,25.310419
minmax,AG,3.452870,15.042082
minmax,FidIn,0.121212,0.331434
...,...,...,...
topK_50_pos,AD,31.270046,35.699242
topK_50_pos,AG,2.545704,12.603113
topK_50_pos,FidIn,0.696970,0.466694
topK_50_pos,SPS,0.552322,0.313431


In [31]:
# Per-sample saliency metrics on the sea-wave-contaminated test fold; also saved to results/saliency_seawaves_noise.csv.
results_saliency_noisy = predict_saliency.predict_set(
    test_loader_noisy_predict,
    results_csv_name='saliency_seawaves_noise.csv',
    compute_first=True,
)

Results saved as results/saliency_seawaves_noise.csv


In [32]:
# First 20 per-sample rows of the noisy saliency metrics.
results_saliency_noisy.iloc[:20]

Unnamed: 0_level_0,minmax,minmax,minmax,minmax,minmax,minmax,minmax,minmax,sigmoid,sigmoid,...,topK_30_pos,topK_30_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos
Unnamed: 0_level_1,FF,AI,AD,AG,FidIn,SPS,COMP,is_correct,FF,AI,...,COMP,is_correct,FF,AI,AD,AG,FidIn,SPS,COMP,is_correct
sample,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
0,0.679046,0.0,5.650103,0.0,1.0,0.51046,9.919516,True,0.160181,0.0,...,9.919516,True,0.66971,100.0,0.0,0.002188,1.0,0.51046,9.919516,True
1,0.184986,0.0,51.760197,0.0,0.0,0.683472,9.463755,False,0.174259,0.0,...,9.463755,False,0.183753,0.0,96.312798,0.0,0.0,0.683472,9.463755,False
2,0.260286,100.0,0.0,6.984935,0.0,0.545431,9.837088,False,0.188405,0.0,...,9.837088,False,0.289877,0.0,95.231209,0.0,0.0,0.545431,9.837088,False
3,0.464612,100.0,0.0,31.195004,1.0,0.515387,9.915707,False,0.288434,0.0,...,9.915707,False,0.470877,0.0,99.51767,0.0,0.0,0.515387,9.915707,False
4,0.348997,0.0,55.39592,0.0,0.0,0.551653,9.833518,True,0.236015,0.0,...,9.833518,True,0.351971,0.0,99.988434,0.0,0.0,0.551653,9.833518,True
5,0.407829,100.0,0.0,20.734211,1.0,0.519932,9.90491,False,0.251218,0.0,...,9.90491,False,0.408596,0.0,98.962585,0.0,0.0,0.519932,9.90491,False
6,0.261508,0.0,87.073692,0.0,0.0,0.511714,9.921237,True,0.034323,100.0,...,9.921237,True,0.253597,0.0,97.328117,0.0,0.0,0.511714,9.921237,True
7,0.823038,0.0,33.554901,0.0,1.0,0.568191,9.798009,True,0.281812,0.0,...,9.798009,True,0.849337,100.0,0.0,0.006475,1.0,0.568191,9.798009,True
8,0.290546,0.0,86.542419,0.0,0.0,0.533382,9.879832,True,-0.073873,100.0,...,9.879832,True,0.289162,0.0,98.499313,0.0,0.0,0.533382,9.879832,True
9,0.238976,100.0,0.0,50.194286,1.0,0.528078,9.889299,False,0.099571,0.0,...,9.889299,False,0.245043,0.0,98.303421,0.0,0.0,0.528078,9.889299,False


In [33]:
# Split noisy saliency metrics by prediction correctness. Each top-level group
# carries its own is_correct column; check that they all agree with the minmax
# one, then drop every is_correct column so the split frames hold metrics only.
is_corr = results_saliency_noisy[('minmax', 'is_correct')]
assert (results_saliency_noisy.xs('is_correct', axis=1, level=1)
        .eq(is_corr, axis=0).to_numpy().all()), "per-group is_correct flags disagree"
results_saliency_noisy_correct   = results_saliency_noisy[is_corr].drop(columns='is_correct', level=1)
results_saliency_noisy_incorrect = results_saliency_noisy[~is_corr].drop(columns='is_correct', level=1)

In [34]:
# Mean/std of each numeric metric column — noisy saliency, all samples.
results_saliency_noisy.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.600442,0.248891
minmax,AI,47.750000,50.011902
minmax,AD,20.357559,29.850380
minmax,AG,23.702827,30.534906
minmax,FidIn,0.830000,0.376103
...,...,...,...
topK_50_pos,AD,42.816151,48.635948
topK_50_pos,AG,0.001634,0.004144
topK_50_pos,FidIn,0.562500,0.496700
topK_50_pos,SPS,0.526970,0.028248


In [35]:
# Mean/std of each metric column — noisy saliency, correctly classified samples only.
results_saliency_noisy_correct.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.702795,0.240100
minmax,AI,32.921810,47.089920
minmax,AD,27.989511,33.971680
minmax,AG,18.379478,30.785431
minmax,FidIn,0.806584,0.395791
...,...,...,...
topK_50_pos,AD,18.863148,38.634342
topK_50_pos,AG,0.002535,0.005069
topK_50_pos,FidIn,0.806584,0.395791
topK_50_pos,SPS,0.533131,0.030316


In [36]:
# Mean/std of each metric column — noisy saliency, misclassified samples only.
results_saliency_noisy_incorrect.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.442023,0.165640
minmax,AI,70.700638,45.659199
minmax,AD,8.545042,16.048777
minmax,AG,31.942146,28.320709
minmax,FidIn,0.866242,0.341481
...,...,...,...
topK_50_pos,AD,79.889908,38.179180
topK_50_pos,AG,0.000239,0.000913
topK_50_pos,FidIn,0.184713,0.389307
topK_50_pos,SPS,0.517435,0.021552


In [38]:
# Per-sample Grad-CAM metrics on the sea-wave-contaminated test fold; also saved to results/gradcam_seawaves_noise.csv.
results_gradcam_noisy = predict_gradcam.predict_set(
    test_loader_noisy_predict,
    results_csv_name='gradcam_seawaves_noise.csv',
    compute_first=True,
)

Results saved as results/gradcam_seawaves_noise.csv


In [39]:
# First 20 per-sample rows of the noisy Grad-CAM metrics.
results_gradcam_noisy.iloc[:20]

Unnamed: 0_level_0,minmax,minmax,minmax,minmax,minmax,minmax,minmax,minmax,sigmoid,sigmoid,...,topK_30_pos,topK_30_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos,topK_50_pos
Unnamed: 0_level_1,FF,AI,AD,AG,FidIn,SPS,COMP,is_correct,FF,AI,...,COMP,is_correct,FF,AI,AD,AG,FidIn,SPS,COMP,is_correct
sample,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
0,0.476334,0.0,97.802231,0.0,0.0,0.64794,9.56375,True,0.314671,0.0,...,9.56375,True,0.537681,0.0,85.824059,0.0,0.0,0.64794,9.56375,True
1,-4.012883e-05,0.0,91.222382,0.0,0.0,0.0,0.0,False,0.16042,0.0,...,0.0,False,0.172616,100.0,0.0,0.004949,1.0,0.0,0.0,False
2,-0.02345115,0.0,83.154335,0.0,0.0,0.321874,10.106059,False,0.002914,0.0,...,10.106059,False,0.258758,0.0,92.001015,0.0,0.0,0.321874,10.106059,False
3,0.3930327,0.0,58.654831,0.0,1.0,0.277605,10.195515,False,0.386653,0.0,...,10.195515,False,0.333662,0.0,73.643013,0.0,0.0,0.277605,10.195515,False
4,0.2606978,0.0,97.560989,0.0,0.0,0.585009,9.762238,True,0.240305,0.0,...,9.762238,True,0.318437,0.0,29.380491,0.0,1.0,0.585009,9.762238,True
5,0.3950965,0.0,61.800247,0.0,0.0,0.1706,10.292536,False,0.362871,0.0,...,10.292536,False,0.245136,0.0,87.607727,0.0,0.0,0.1706,10.292536,False
6,-1.361966e-05,0.0,95.387383,0.0,0.0,0.0,0.0,True,0.002667,0.0,...,0.0,True,0.251259,100.0,0.0,0.001849,1.0,0.0,0.0,True
7,0.1259353,0.0,63.357067,0.0,1.0,0.751976,9.198015,True,0.443178,0.0,...,9.198015,True,0.849337,100.0,0.0,0.006475,1.0,0.751976,9.198015,True
8,-0.1393031,0.0,96.659081,0.0,0.0,0.771719,9.074892,True,-0.255667,100.0,...,9.074892,True,0.280749,0.0,0.007587,0.0,1.0,0.771719,9.074892,True
9,0.2219657,0.0,59.102833,0.0,0.0,0.225913,10.218206,False,0.188602,0.0,...,10.218206,False,0.132291,0.0,59.431396,0.0,0.0,0.225913,10.218206,False


In [40]:
# Split noisy Grad-CAM metrics by prediction correctness. Each top-level group
# carries its own is_correct column; check that they all agree with the minmax
# one, then drop every is_correct column so the split frames hold metrics only.
is_corr = results_gradcam_noisy[('minmax', 'is_correct')]
assert (results_gradcam_noisy.xs('is_correct', axis=1, level=1)
        .eq(is_corr, axis=0).to_numpy().all()), "per-group is_correct flags disagree"
results_gradcam_noisy_correct   = results_gradcam_noisy[is_corr].drop(columns='is_correct', level=1)
results_gradcam_noisy_incorrect = results_gradcam_noisy[~is_corr].drop(columns='is_correct', level=1)

In [41]:
# Mean/std of each numeric metric column — noisy Grad-CAM, all samples.
results_gradcam_noisy.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.165330,0.232420
minmax,AI,3.250000,17.754593
minmax,AD,69.801506,30.416687
minmax,AG,0.626069,4.165296
minmax,FidIn,0.377500,0.485369
...,...,...,...
topK_50_pos,AD,40.081612,35.412453
topK_50_pos,AG,1.043078,5.713395
topK_50_pos,FidIn,0.647500,0.478347
topK_50_pos,SPS,0.495021,0.224219


In [42]:
# Mean/std of each metric column — noisy Grad-CAM, correctly classified samples only.
results_gradcam_noisy_correct.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.144273,0.235666
minmax,AI,2.880658,16.760778
minmax,AD,69.933640,33.136490
minmax,AG,0.529573,3.816485
minmax,FidIn,0.423868,0.495190
...,...,...,...
topK_50_pos,AD,33.615986,35.610287
topK_50_pos,AG,1.134403,6.211542
topK_50_pos,FidIn,0.769547,0.421992
topK_50_pos,SPS,0.545768,0.218164


In [43]:
# Mean/std of each metric column — noisy Grad-CAM, misclassified samples only.
results_gradcam_noisy_incorrect.describe().transpose().loc[:, ["mean", "std"]]

Unnamed: 0,Unnamed: 1,mean,std
minmax,FF,0.197920,0.224162
minmax,AI,3.821656,19.233219
minmax,AD,69.597000,25.746723
minmax,AG,0.775424,4.662930
minmax,FidIn,0.305732,0.462191
...,...,...,...
topK_50_pos,AD,50.088921,32.777699
topK_50_pos,AG,0.901727,4.858378
topK_50_pos,FidIn,0.458599,0.499877
topK_50_pos,SPS,0.416475,0.211024
