# Saliency interpretation: predicted-class logit drop after masking salient regions (ESC-50, fold 5)

In [14]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm

from audiointerp.dataset.esc50 import ESC50dataset
from audiointerp.interpretation.saliency import SaliencyInterpreter
from audiointerp.model.cnn14 import TransferCnn14

In [15]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [16]:
# Dataset location. Set the ESC50_ROOT environment variable to point elsewhere
# instead of editing a machine-specific absolute path in the notebook.
root_dir = os.environ.get("ESC50_ROOT", "/home/yuliya/ESC50")
sr = 32000  # sample rate (Hz) passed to the dataset and to the mel filterbank
test_folds = [5]  # ESC-50 fold(s) held out for evaluation

In [17]:
# STFT / mel front-end settings (a 320-sample hop is 10 ms at sr = 32000).
n_fft, win_length, hop_length = 1024, 1024, 320
n_mels = 64
f_min, f_max = 50, 14000  # mel filterbank frequency range in Hz

In [18]:
# Power spectrogram -> mel filterbank -> decibels (dynamic range limited by top_db).
spec = T.Spectrogram(
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    power=2.0,
)
mel = T.MelScale(
    n_mels=n_mels,
    sample_rate=sr,
    f_min=f_min,
    f_max=f_max,
    n_stft=n_fft // 2 + 1,  # number of one-sided STFT frequency bins
)
amplitude_to_db = T.AmplitudeToDB(top_db=80, stype="power")

In [19]:
# Waveform -> log-mel (dB) features as one module that is handed to the dataset.
feature_extractor = nn.Sequential(*(spec, mel, amplitude_to_db))

In [20]:
test_data = ESC50dataset(
    root_dir=root_dir,
    sr=sr,
    folds=test_folds,
    feature_extractor=feature_extractor,
)
# Evaluation only: keep the dataset's order.
test_loader = DataLoader(test_data, shuffle=False, batch_size=32)

In [21]:
model = TransferCnn14(50)  # presumably the number of output classes (ESC-50 has 50)
model.to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# weights_only=True refuses to unpickle arbitrary objects; "best.pth" is
# expected to hold a plain state dict.
model.load_state_dict(torch.load("best.pth", map_location=device, weights_only=True))

<All keys matched successfully>

In [22]:
# Cross-entropy over class logits, averaged over the batch (default reduction="mean").
criterion = torch.nn.CrossEntropyLoss()

In [23]:
def valid_step(model, criterion, dataloader, device):
    """Evaluate ``model`` on every batch of ``dataloader`` without tracking gradients.

    Args:
        model: classifier returning one row of class scores per sample.
        criterion: loss called as ``criterion(outputs, labels)``. It is assumed to
            return the batch *mean* (the ``nn.CrossEntropyLoss`` default), because
            it is re-weighted by batch size below.
        dataloader: iterable of ``(samples, labels)`` batches.
        device: device each batch is moved to.

    Returns:
        ``(loss, accuracy)`` as Python floats, both averaged over all samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total_samples = 0

    with torch.no_grad():
        for samples, labels in dataloader:
            samples = samples.to(device)
            labels = labels.to(device)

            outputs = model(samples)
            loss = criterion(outputs, labels)

            batch_size = samples.size(0)
            preds = outputs.argmax(dim=1)
            # Weight the mean loss by batch size so a short final batch is not over-counted.
            running_loss += loss.item() * batch_size
            # .item() keeps the counter a plain Python int instead of a device tensor.
            running_corrects += (preds == labels).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("valid_step: dataloader yielded no samples")

    return running_loss / total_samples, running_corrects / total_samples

In [24]:
# Loss and accuracy of the loaded checkpoint on the held-out fold(s).
test_loss, test_acc = valid_step(
    model=model, criterion=criterion, dataloader=test_loader, device=device
)
print(f"Test loss: {test_loss:.2f}, Test acc: {test_acc:.2f}")

Test loss: 0.57, Test acc: 0.88


In [25]:
def compute_logit_difference(model, test_loader, device, baseline=0.0, interpreter=None):
    """Measure how much the predicted-class logit drops when salient inputs are masked.

    For each batch:
    - the predicted class is taken from the unmasked input;
    - a mask is obtained from ``interpreter.interpret(inputs)["masks"]``;
    - masked positions are filled with ``baseline``
      (``inputs * (1 - mask) + baseline * mask``);
    - the masked input's logit for that same class is subtracted from the
      original logit.

    Args:
        model: classifier returning one row of logits per sample.
        test_loader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: device the model and inputs are moved to.
        baseline: value written into masked positions. The default ``0.0``
            reproduces the plain ``inputs * (1 - mask)`` masking.
            NOTE(review): if ``inputs`` are dB-scaled log-mel features, 0 is an
            ordinary level rather than silence. A low baseline (e.g. the batch
            minimum) may be the intended "removal"; confirm.
        interpreter: object exposing ``interpret(inputs)`` that returns a dict with
            a ``"masks"`` entry. Defaults to ``SaliencyInterpreter(model)``.

    Returns:
        CPU tensor of shape ``(num_samples, 1)`` holding
        ``logit_original - logit_masked`` for each sample's predicted class.
        Shape is ``(0, 1)`` when the loader is empty.
    """
    model.eval()
    model.to(device)

    saliency = interpreter if interpreter is not None else SaliencyInterpreter(model)
    results = []

    for inputs, _ in test_loader:
        # The interpreter receives a gradient-tracking input, as before.
        inputs = inputs.to(device).requires_grad_(True)

        # Plain evaluation passes need no autograd graph. A single pass replaces
        # the previously duplicated forward of the unmasked input.
        with torch.no_grad():
            logits_original = model(inputs)
        predicted_class = logits_original.argmax(dim=1, keepdim=True)

        mask = saliency.interpret(inputs)["masks"]

        with torch.no_grad():
            masked_inputs = inputs * (1 - mask)
            if baseline:
                masked_inputs = masked_inputs + baseline * mask
            logits_masked = model(masked_inputs)

        logit_diff = logits_original.gather(1, predicted_class) - logits_masked.gather(1, predicted_class)
        results.append(logit_diff.detach().cpu())

    if not results:
        return torch.empty(0, 1)
    return torch.cat(results)

In [26]:
# Per-sample drop in the predicted-class logit after masking salient regions.
results = compute_logit_difference(model=model, test_loader=test_loader, device=device)

In [27]:
# Average logit drop over all evaluated samples.
torch.mean(results)

tensor(6.7708)

In [28]:
# Number of samples whose predicted-class logit decreased after masking.
(results > 0).sum()

tensor(364)