# Saliency interp

In [1]:
from audiointerp.dataset.esc50 import ESC50dataset
import torch.nn as nn
import torch.nn.functional as TF
import torch.optim as optim
import torchaudio.transforms as T
from IPython.display import Audio
import torch
from torch.utils.data import DataLoader
import os
from collections import OrderedDict
from tqdm import tqdm
from captum.attr import Saliency

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# utility functions
def init_layer(layer):
    """Give a Linear or Convolutional layer Xavier-uniform weights and a zero bias.

    The weight tensor is re-initialised in place. The bias is zeroed only
    when the layer exposes a ``bias`` attribute that is not ``None``.
    """
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, 'bias', None)
    if bias is None:
        # Layers built with bias=False (or without a bias attribute) need nothing more.
        return
    bias.data.fill_(0.)
            
    
def init_bn(bn):
    """Reset a Batchnorm layer's affine parameters: bias to zero, weight to one."""
    for param, value in ((bn.bias, 0.), (bn.weight, 1.)):
        param.data.fill_(value)

In [3]:
# convblock
class ConvBlock(nn.Module):
    """Two 3x3 convolution + batch-norm + ReLU stages followed by 2-D pooling.

    Both convolutions use stride 1 and padding 1 (so they keep the spatial
    size) and have no bias term; each is followed by its own BatchNorm2d.

    Args:
        in_channels: Number of channels of the block input.
        out_channels: Number of channels produced by both convolutions.
    """

    # Pooling modes accepted by ``forward``.
    _POOL_TYPES = ('max', 'avg', 'avg+max')

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)

        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Re-initialise both convolutions and both batch-norm layers."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Run the block and pool the result.

        Args:
            input: 4-D tensor whose dimension 1 has ``in_channels`` entries.
            pool_size: Kernel size handed to the pooling function(s).
            pool_type: ``'max'``, ``'avg'`` or ``'avg+max'`` (sum of both poolings).

        Returns:
            The pooled feature map.

        Raises:
            ValueError: If ``pool_type`` is not one of the supported modes.
        """
        # Validate before any computation so that a bad argument does not
        # first run the convolutions (and update batch-norm statistics).
        if pool_type not in self._POOL_TYPES:
            raise ValueError(
                "Incorrect argument! pool_type must be 'max', 'avg' or "
                "'avg+max', got {!r}".format(pool_type)
            )

        x = input
        x = TF.relu_(self.bn1(self.conv1(x)))
        x = TF.relu_(self.bn2(self.conv2(x)))

        if pool_type == 'max':
            x = TF.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = TF.avg_pool2d(x, kernel_size=pool_size)
        else:  # 'avg+max'
            x1 = TF.avg_pool2d(x, kernel_size=pool_size)
            x2 = TF.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2

        return x

In [4]:
# Cnn14
class Cnn14(nn.Module):
    """Six-stage convolutional embedder built from ``ConvBlock`` modules.

    ``forward`` turns a 4-D input into one 2048-dimensional embedding per
    batch element: batch-norm over input axis 2, six conv blocks with
    dropout, pooling over the two trailing axes, then a ReLU-activated
    2048 -> 2048 linear layer.
    """

    def __init__(self):
        super().__init__()

        # Normalises the 64-entry axis that forward() moves into channel position.
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)

        self.init_weight()

    def init_weight(self):
        """Re-initialise the input batch-norm and the embedding layer."""
        init_bn(self.bn0)
        init_layer(self.fc1)

    def forward(self, input):
        """Compute the embedding for a batch.

        Args:
            input: Tensor laid out as (batch_size, 1, mel_bins, timesteps);
                ``bn0`` requires mel_bins == 64.

        Returns:
            Tensor of shape (batch_size, 2048).
        """
        # Swap the two trailing axes, then bring the 64-entry axis to channel
        # position so BatchNorm2d normalises each of its entries separately.
        x = input.transpose(2, 3).transpose(1, 3)
        x = self.bn0(x).transpose(1, 3)

        # Blocks 1-5 halve both spatial axes; block 6 uses a 1x1 pooling kernel.
        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (2, 2)),
            (self.conv_block5, (2, 2)),
            (self.conv_block6, (1, 1)),
        )
        for block, pool in stages:
            x = block(x, pool_size=pool, pool_type='avg')
            x = TF.dropout(x, p=0.2, training=self.training)

        # Average out the last axis, then combine max and mean over the remaining one.
        x = torch.mean(x, dim=3)
        peak, _ = torch.max(x, dim=2)
        average = torch.mean(x, dim=2)
        x = peak + average

        x = TF.dropout(x, p=0.5, training=self.training)
        x = TF.relu_(self.fc1(x))
        return TF.dropout(x, p=0.5, training=self.training)

In [5]:
# model for transfer
class TransferModel(nn.Module):
    """Classifier made of an embedding backbone plus a linear head.

    Args:
        embedder: Zero-argument callable (e.g. a module class) returning the
            backbone. The backbone must expose ``fc1.out_features``, which
            fixes the input size of the classification head.
        num_classes: Number of logits produced by the head.
    """

    # Checkpoint entries with these key prefixes are not loaded into the backbone.
    _SKIPPED_PREFIXES = ("logmel_extractor", "spectrogram_extractor", "fc_audioset")

    def __init__(self, embedder, num_classes):
        super(TransferModel, self).__init__()

        self.base = embedder()
        emb_dim = self.base.fc1.out_features
        self.classifier = nn.Linear(in_features=emb_dim, out_features=num_classes)

    def load_base_weights(self, path_to_weights):
        """Load backbone weights from a checkpoint holding a ``"model"`` state dict.

        Entries whose key starts with ``logmel_extractor``,
        ``spectrogram_extractor`` or ``fc_audioset`` are dropped; everything
        else is loaded strictly into ``self.base``.

        Args:
            path_to_weights: Path or file-like object accepted by ``torch.load``.
        """
        # map_location="cpu" lets a checkpoint saved on a GPU be read on a
        # CPU-only machine; load_state_dict copies the values into the existing
        # parameters, so the backbone stays on whatever device it already uses.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(path_to_weights, map_location="cpu")

        weights = OrderedDict(
            (key, value)
            for key, value in checkpoint["model"].items()
            if not key.startswith(self._SKIPPED_PREFIXES)
        )

        self.base.load_state_dict(weights)

    def forward(self, input):
        """Return classification logits for ``input``."""
        embedding = self.base(input)
        logits = self.classifier(embedding)

        return logits

In [6]:
# Location of the ESC-50 dataset on disk (machine-specific absolute path).
root_dir = "/home/yuliya/ESC50"
# Sample rate handed to the dataset and to the mel filterbank.
sr = 32000
# Folds passed to ESC50dataset to build the evaluation split.
test_folds = [5]

In [7]:
# Parameters of the spectrogram / mel front-end built in the following cells.
# n_fft, hop_length and win_length go to T.Spectrogram.
n_fft = 1024
hop_length = 320
win_length = 1024
# n_mels, f_min and f_max go to T.MelScale; BatchNorm2d(64) in Cnn14 matches n_mels.
n_mels = 64
f_min = 50
f_max = 14000

In [8]:
# Power spectrogram (power=2.0) using n_fft, win_length and hop_length.
spec = T.Spectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=2.0)
# Mel filterbank over the n_fft // 2 + 1 frequency bins produced by the spectrogram.
mel = T.MelScale(n_mels=n_mels, sample_rate=sr, f_min=f_min, f_max=f_max, n_stft=n_fft // 2 + 1)
# Decibel conversion for power-valued input, with top_db=80.
amplitude_to_db = T.AmplitudeToDB(stype="power", top_db=80)

In [9]:
# Chain spectrogram -> mel scale -> dB into one module that is handed to the dataset.
feature_extractor = nn.Sequential(spec, mel, amplitude_to_db)

In [10]:
# Evaluation split: ESC50dataset restricted to test_folds, using feature_extractor.
test_data = ESC50dataset(root_dir=root_dir, sr=sr, folds=test_folds, feature_extractor=feature_extractor)
# Fixed iteration order (shuffle=False) in batches of 32.
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

In [11]:
def compute_logit_difference(model, test_loader, device):
    """Measure how far masking salient input regions lowers the predicted-class logit.

    For every batch the predicted class is read from the logits of the
    unmasked input, Captum ``Saliency`` attributions are computed for that
    class, squashed by a sigmoid into a mask, and the input is scaled by
    ``1 - mask``. The reported value is the predicted-class logit on the
    original input minus the same logit on the masked input.

    Args:
        model: Module mapping a batch of inputs to logits of shape
            (batch, num_classes). As a side effect it is switched to eval
            mode and moved to ``device``.
        test_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device on which the computation runs.

    Returns:
        CPU tensor of shape (num_samples, 1) with one logit difference per
        sample, in loader order. An empty (0, 1) tensor is returned when the
        loader yields no batches.
    """
    model.eval()
    model.to(device)

    saliency = Saliency(model)
    results = []

    for inputs, _ in test_loader:
        inputs = inputs.to(device).requires_grad_(True)

        # The model is in eval mode, so one forward pass on the unmasked input
        # serves both for picking the class and as the reference logits; no
        # autograd graph is needed for it.
        with torch.no_grad():
            logits_original = model(inputs)
        predicted_class = logits_original.argmax(dim=1)
        class_index = predicted_class.unsqueeze(1)

        # Captum runs its own forward/backward pass to obtain input gradients.
        # NOTE(review): Saliency.attribute defaults to abs=True in the Captum
        # docs, so attributions are non-negative and the sigmoid mask is >= 0.5
        # everywhere - confirm this is the intended masking scheme.
        attributions = saliency.attribute(inputs, target=predicted_class)

        # Evaluation of the masked input does not need gradients either.
        with torch.no_grad():
            mask = torch.sigmoid(attributions)
            logits_masked = model(inputs * (1 - mask))

        logit_diff = logits_original.gather(1, class_index) - \
                     logits_masked.gather(1, class_index)

        results.append(logit_diff.cpu())

    if not results:
        # torch.cat rejects an empty list; report "no samples" explicitly instead.
        return torch.empty(0, 1)

    return torch.cat(results)

In [12]:
# Build the classifier (Cnn14 backbone, 50 output classes) and restore the
# weights stored in best.pth. map_location="cpu" keeps the load working on a
# machine without CUDA even if the checkpoint was saved from a GPU; the model
# is moved to the evaluation device in a later cell.
model = TransferModel(Cnn14, 50)
model.load_state_dict(torch.load("best.pth", map_location="cpu"))

<All keys matched successfully>

In [13]:
# Loss passed to valid_step when evaluating the model.
criterion = nn.CrossEntropyLoss()

In [14]:
def valid_step(model, criterion, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return mean loss and accuracy.

    The model is switched to eval mode and run without gradient tracking.
    It is not moved: the caller must already have placed it on ``device``.

    Args:
        model: Module mapping a batch of inputs to logits of shape
            (batch, num_classes).
        criterion: Loss callable taking ``(outputs, labels)``. Its value is
            weighted by the batch size, i.e. it is assumed to be a per-batch
            mean (the default for ``nn.CrossEntropyLoss``).
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy)`` of Python floats averaged over all samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            # Keep the count as a plain int so the accuracy is an exact ratio.
            running_corrects += int((preds == labels).sum().item())
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute loss and accuracy")

    epoch_loss = running_loss / total_samples
    epoch_acc = running_corrects / total_samples

    return epoch_loss, epoch_acc

In [15]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
# Move the model to the selected device before evaluation.
model = model.to(device)

In [17]:
# Loss and accuracy on the test loader; the bare tuple on the last line makes
# the notebook display both values.
test_loss, test_acc = valid_step(model, criterion, test_loader, device)
test_loss, test_acc

(0.43132887959480287, 0.8775000000000001)

In [18]:
# Per-sample difference between the predicted-class logit on the original and
# on the saliency-masked test inputs.
results = compute_logit_difference(model, test_loader, device)

In [19]:
# Mean logit difference over all evaluated samples (shown as the cell output).
results.mean()

tensor(1.8695)

In [21]:
# Number of samples whose predicted-class logit dropped after masking (difference > 0).
torch.sum(results > 0)

tensor(320)