# Saliency interp

In [1]:
from audiointerp.dataset.esc50 import ESC50dataset
from audiointerp.model.cnn14 import TransferCnn14
from audiointerp.interpretation.saliency import SaliencyInterpreter
from audiointerp.metrics import Metrics
from audiointerp.processing.spectrogram import LogMelSTFTSpectrogram
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
root_dir = "/home/yuliya/ESC50"
sr = 32000
test_folds = [5]

In [4]:
n_fft = 1024
hop_length = 320
win_length = 1024
n_mels = 64
f_min = 50
f_max = 14000
top_db = 80

In [5]:
feature_extractor = LogMelSTFTSpectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=2.0,
                                          n_mels=n_mels, sample_rate=sr, f_min=f_min, f_max=f_max, top_db=80,
                                          return_phase=False, return_full_db=False)

In [6]:
test_data = ESC50dataset(root_dir=root_dir, sr=sr, folds=test_folds, feature_extractor=feature_extractor)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

In [7]:
model = TransferCnn14(50)
model.to(device)
model.load_state_dict(torch.load("best.pth"))

<All keys matched successfully>

In [8]:
criterion = nn.CrossEntropyLoss()

In [9]:
def valid_step(model, criterion, dataloader, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total_samples = 0
    
    with torch.no_grad():
        for samples, labels in dataloader:
            samples = samples.to(device)
            labels = labels.to(device)
            
            outputs = model(samples)
            loss = criterion(outputs, labels)
            
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * samples.size(0)
            running_corrects += torch.sum(preds == labels.data)
            total_samples += samples.size(0)
    
    epoch_loss = running_loss / total_samples
    epoch_acc = running_corrects.double() / total_samples
    
    return epoch_loss, epoch_acc.item()

In [10]:
test_loss, test_acc = valid_step(model, criterion, test_loader, device)
print(f"Test loss: {test_loss:.2f}, Test acc: {test_acc:.2f}")

Test loss: 0.59, Test acc: 0.85


In [11]:
mm = 0
dd = 0
cc = 0

In [12]:
def interpret_and_evaluate(model, test_loader, device):
    model.eval()
    model.to(device)
    
    saliency = SaliencyInterpreter(model)
    results = {
        "FF": [],
        "AI": [],
        "AD": [],
        "AG": [],
        "FidIn": [],
        "SPS": [],
        "COMP": []
    }
    
    for inputs, _ in test_loader:
        inputs = inputs.to(device).requires_grad_(True)
        
        logits = model(inputs)
        predicted_class = logits.argmax(dim=1)

        intr, mask = saliency.interpret(inputs).values()

        unmasked_inputs = (inputs - inputs.amin(dim=(1, 2, 3), keepdim=True)) * (1 - mask) + inputs.amin(dim=(1, 2, 3), keepdim=True)
        masked_inputs = (inputs - inputs.amin(dim=(1, 2, 3), keepdim=True)) * mask + inputs.amin(dim=(1, 2, 3), keepdim=True)

    
        with torch.no_grad():
            logits_original = model(inputs)
            logits_masked = model(masked_inputs)
            logits_unmasked = model(unmasked_inputs)
        
        
        ff = Metrics.compute_FF(logits=logits_original, logits_out=logits_unmasked)
        ai = Metrics.compute_AI(logits=logits_original, logits_in=logits_masked)
        ad = Metrics.compute_AD(logits=logits_original, logits_in=logits_masked)
        ag = Metrics.compute_AG(logits=logits_original, logits_in=logits_masked)
        fidin = Metrics.compute_FidIn(logits=logits_original, logits_in=logits_masked)
        sps = Metrics.compute_SPS(inputs, mask, logits, device)
        comp = Metrics.compute_COMP(inputs, mask, logits, device)

        results["FF"].append(ff.cpu())
        results["AI"].append(ai.cpu())
        results["AD"].append(ad.cpu())
        results["AG"].append(ag.cpu())
        results["FidIn"].append(fidin.cpu())
        results["SPS"].append(torch.tensor(sps).cpu())
        results["COMP"].append(torch.tensor(comp).cpu())
    
    for m in results:
        results[m] = torch.cat(results[m])

    return results

In [15]:
results = interpret_and_evaluate(model, test_loader, device)

In [16]:
results

{'FF': tensor([ 4.3904e+00,  2.6829e+00,  3.6084e+00,  5.6793e+00,  5.9127e+00,
          3.9587e+00,  9.3576e-01,  1.7623e+00,  3.3140e+00,  2.1677e+00,
          5.7285e-01,  6.0738e+00,  1.8183e+00,  1.2579e+00,  2.4334e+00,
          1.9704e+00,  5.6164e+00,  4.5580e+00,  4.5498e+00,  1.6505e+00,
          5.5468e-01,  5.1951e+00,  6.0582e+00,  4.0452e+00,  2.2318e+00,
          5.4364e+00,  2.7789e+00, -6.1002e-03,  5.3026e+00,  4.0927e+00,
          1.5172e+00,  1.6669e+00,  1.3521e+00,  1.5379e+00, -5.0180e-01,
          3.2094e-02,  3.4421e+00,  1.3970e+00,  2.6572e+00, -4.2262e-01,
         -6.3466e-01,  6.0423e-01,  5.0393e+00,  5.1226e+00,  5.3212e+00,
          7.1753e+00,  6.1805e-02, -8.9615e-01,  1.6967e-01,  6.7778e+00,
          2.4765e-01,  8.2382e-02,  3.5105e+00, -2.8991e+00,  1.6720e+00,
          2.9569e+00,  5.4435e+00,  2.3113e+00,  1.0384e+01,  7.1298e-01,
          1.6891e+00,  6.3732e+00,  3.6298e+00,  1.7211e+00,  3.9837e+00,
          1.6318e+00,  5.6164e+0

In [17]:
for m in results:
    print(f"{m}: {results[m].mean().item()}")

FF: 2.3136532306671143
AI: 1.0
AD: 166.3847198486328
AG: 0.17359690368175507
FidIn: 0.07750000059604645
SPS: 0.5885762199361089
COMP: 9.701299477654178
