# Grad-CAM interpretability: predicted-class logit drop on ESC-50

In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio.transforms as T
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torch.utils.data import DataLoader
from tqdm import tqdm

from audiointerp.dataset.esc50 import ESC50dataset
from audiointerp.model.cnn14 import TransferCnn14

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Dataset location and audio settings.
# The ESC-50 root can be overridden with the ESC50_ROOT environment variable so the
# notebook is not tied to one machine; the default keeps the original path.
root_dir = os.environ.get("ESC50_ROOT", "/home/yuliya/ESC50")
sr = 32000  # sample rate passed to the dataset and to the mel filterbank
test_folds = [5]  # folds passed to the dataset for evaluation

In [4]:
# Spectrogram / mel front-end settings, consumed by the torchaudio transforms built below.
n_fft = 1024  # FFT size; gives n_fft // 2 + 1 = 513 frequency bins
hop_length = 320  # samples between successive frames (10 ms at sr = 32000)
win_length = 1024  # analysis window length in samples (32 ms at sr = 32000)
n_mels = 64  # number of mel bands
f_min = 50  # lower edge of the mel filterbank
f_max = 14000  # upper edge of the mel filterbank

In [5]:
# Front-end stages: power spectrogram -> mel filterbank -> decibel scale.
spec = T.Spectrogram(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    power=2.0,
)
mel = T.MelScale(
    n_mels=n_mels,
    sample_rate=sr,
    f_min=f_min,
    f_max=f_max,
    n_stft=n_fft // 2 + 1,  # one-sided FFT bin count
)
amplitude_to_db = T.AmplitudeToDB(stype="power", top_db=80)

In [6]:
# Chain the three transforms into one module; it is handed to the dataset as its feature extractor.
feature_extractor = nn.Sequential(spec, mel, amplitude_to_db)

In [7]:
# Evaluation data only: the held-out fold(s), iterated in a fixed (unshuffled) order.
test_data = ESC50dataset(
    root_dir=root_dir,
    sr=sr,
    folds=test_folds,
    feature_extractor=feature_extractor,
)
test_loader = DataLoader(
    test_data,
    batch_size=32,
    shuffle=False,
)

In [8]:
# Build the model (the argument 50 is presumably the number of ESC-50 classes) and
# restore the trained weights.
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint saved
# on a GPU still loads when the notebook has fallen back to the CPU.
model = TransferCnn14(50)
model.to(device)
model.load_state_dict(torch.load("best.pth", map_location=device))

<All keys matched successfully>

In [9]:
# Cross-entropy over the model outputs; used below to report the test loss.
criterion = nn.CrossEntropyLoss()

In [10]:
def valid_step(model, criterion, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return the mean loss and accuracy.

    Args:
        model: Module mapping a batch of samples to per-class scores, with the
            class dimension at index 1. It is switched to eval mode and left there.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            tensor. Assumed to be a batch mean (the default reduction of
            ``nn.CrossEntropyLoss``), since it is re-weighted by the batch size.
        dataloader: Iterable of ``(samples, labels)`` tensor pairs.
        device: Device each batch is moved to before the forward pass.

    Returns:
        tuple[float, float]: ``(loss, accuracy)`` averaged over all samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0

    with torch.no_grad():
        for samples, labels in dataloader:
            samples = samples.to(device)
            labels = labels.to(device)

            outputs = model(samples)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)

            batch_size = samples.size(0)
            # Weight the batch-mean loss by the batch size so the overall average
            # stays exact when the last batch is smaller than the others.
            loss_sum += loss.item() * batch_size
            # Accumulate plain Python ints instead of a growing device tensor.
            n_correct += (preds == labels).sum().item()
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute loss/accuracy")

    return loss_sum / n_samples, n_correct / n_samples

In [11]:
# Sanity check: score the restored checkpoint on the test loader before interpreting it.
test_loss, test_acc = valid_step(model, criterion, test_loader, device)
print(f"Test loss: {test_loss:.2f}, Test acc: {test_acc:.2f}")

Test loss: 0.57, Test acc: 0.88


In [12]:
# Display the module tree to look up layer names (the Grad-CAM cell targets model.base.conv_block6.conv2).
model

TransferCnn14(
  (base): Cnn14(
    (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_block1): ConvBlock(
      (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_block2): ConvBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_block3): ConvBlock(
      (conv1): Conv2d(128, 

In [29]:
def compute_logit_difference(model, test_loader, device):
    """Measure the drop in the predicted-class logit after Grad-CAM-based attenuation.

    For every batch, the model's predicted class is used as the Grad-CAM target,
    the resulting map is squashed with a sigmoid, the input is multiplied by
    ``1 - mask``, and the predicted-class logit is compared before and after.

    Args:
        model: Classifier exposing ``base.conv_block6.conv2``, which is used as the
            Grad-CAM target layer. It is put in eval mode and moved to ``device``.
        test_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device on which the model and the inputs are evaluated.

    Returns:
        torch.Tensor: CPU tensor of shape ``(num_samples, 1)`` holding
        ``logit_original - logit_masked`` for each sample's predicted class, or an
        empty ``(0, 1)`` tensor when the loader yields no batches.
    """
    model.eval()
    model.to(device)

    target_layers = [model.base.conv_block6.conv2]
    results = []

    # Using GradCAM as a context manager releases the hooks it registers on the
    # model when the block exits, instead of leaving them attached afterwards.
    with GradCAM(model=model, target_layers=target_layers) as gradcam:
        for inputs, _ in test_loader:
            inputs = inputs.to(device).requires_grad_(True)

            # Only the logit values are needed here, so no autograd graph is built.
            # The same tensor serves as the "original" logits (the model is in eval
            # mode, so a second forward pass on the same input would be redundant).
            with torch.no_grad():
                logits_original = model(inputs)
            predicted_class = logits_original.argmax(dim=1)
            targets = [ClassifierOutputTarget(int(pred)) for pred in predicted_class]

            # Grad-CAM runs its own forward/backward pass with gradients enabled.
            attrs = gradcam(input_tensor=inputs, targets=targets)
            # Add a channel axis so the map broadcasts against the input batch.
            attrs = torch.from_numpy(attrs).unsqueeze(1).to(device)
            # NOTE(review): pytorch-grad-cam normally rescales its maps to [0, 1]; if
            # so, sigmoid(attrs) lies in about [0.5, 0.73] and every input bin is
            # scaled by 0.27-0.5 regardless of saliency. Confirm this soft
            # attenuation (rather than using the map directly) is intended.
            mask = torch.sigmoid(attrs)

            with torch.no_grad():
                logits_masked = model(inputs * (1 - mask))

            class_index = predicted_class.unsqueeze(1)
            logit_diff = (
                logits_original.gather(1, class_index)
                - logits_masked.gather(1, class_index)
            )
            results.append(logit_diff.cpu())

    if not results:
        # torch.cat raises on an empty list; return an empty column instead.
        return torch.empty(0, 1)
    return torch.cat(results)

In [30]:
# Per-sample difference between the original and masked predicted-class logits.
results = compute_logit_difference(model, test_loader, device)

In [31]:
# Mean logit difference (original minus masked) over all test samples.
results.mean()

tensor(9.5587)

In [32]:
# Number of samples whose predicted-class logit decreased after masking (positive difference).
torch.sum(results > 0)

tensor(384)