In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time
import os
import sys
import copy
import datetime
import random
import math
import warnings
warnings.filterwarnings('ignore')

In [11]:
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from torchaudio import models, transforms
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn

In [4]:
import matplotlib
# matplotlib.use('Agg')
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt 

matplotlib.rcParams['lines.linewidth'] = 1
matplotlib.rcParams['lines.markersize'] = 5

In [5]:
import numpy as np
import IPython.display as ipd
from tqdm import tqdm

In [8]:
sys.path.append("/home/geshi/ChaosMining")

from chaosmining.data_utils import ChaosAudioDataset
from chaosmining.utils import check_make_dir
from chaosmining.audio import parse_argument

from captum.attr import IntegratedGradients, Saliency, DeepLift, FeatureAblation, visualization

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
save_flag = False
n_channels = 10

# Noisy Audio Dataset

In [27]:
# Create training and testing split of the data. We do not use validation in this tutorial.
train_set = ChaosAudioDataset('/data/home/geshi/ChaosMining/data/audio/RBFP', "train")
val_set = ChaosAudioDataset('/data/home/geshi/ChaosMining/data/audio/RBFP', "val")

waveform, label, sample_rate = train_set[0]

In [28]:
print("Shape of waveform: {}".format(waveform.size()))
print("Sample rate of waveform: {}".format(sample_rate))
print("class of waveform: {}".format(label))
print("dataset size train {}, val {}".format(len(train_set), len(val_set)))

Shape of waveform: torch.Size([10, 16000])
Sample rate of waveform: 16000
class of waveform: 0
dataset size train 84843, val 9981


In [32]:
labels = sorted(train_set.classes)
print("number of classes", len(labels))

number of classes 35


In [37]:
ipd.Audio(waveform[0].numpy(), rate=sample_rate)

# Generate Synthetic Audio Data

In [38]:
def label_to_index(word):
    # Return the position of the word in labels
    return torch.tensor(labels.index(word))

def index_to_label(index):
    # Return the word corresponding to the index in labels
    # This is the inverse of label_to_index
    return labels[index]

word_start = "yes"
index = label_to_index(word_start)
word_recovered = index_to_label(index)

print(word_start, "-->", index, "-->", word_recovered)

yes --> tensor(33) --> yes


In [55]:
def collate_fn(batch):

    # A data tuple has the form:
    # waveform, sample_rate, label, speaker_id, utterance_number

    tensors, targets = [], []

    # Gather in lists, and encode labels as indices
    for waveform, label, *_ in batch:
        tensors += [waveform]
        targets += [label]

    # Group the list of tensors into a batched tensor
    tensors = torch.stack(tensors)
    targets = torch.tensor(targets)

    return tensors, targets

In [56]:
batch_size = 64

if device == "cuda":
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Define Network

In [57]:
class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = torch.squeeze(x)
        x = self.fc1(x)
        return x

In [58]:
model = M5(n_input=n_channels, n_output=len(labels))
model.to(device)
print(model)

M5(
  (conv1): Conv1d(10, 32, kernel_size=(80,), stride=(16,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=35, bias=True)
)


In [59]:
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)  # reduce the 
criterion = nn.CrossEntropyLoss()

In [63]:
def train(model, epoch, log_interval):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):

        data = data.to(device)
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        output = model(data)

        # negative log-likelihood for a tensor of size (batch x 1 x n_output)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print training stats
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

        # update progress bar
        pbar.update(pbar_update)
        # record loss
        losses.append(loss.item())

In [66]:
def number_of_correct(pred, target):
    # count number of correct predictions
    return pred.squeeze().eq(target).sum().item()


def get_likely_index(tensor):
    # find most likely label index for each element in the batch
    return tensor.argmax(dim=-1)


def test(model, epoch):
    model.eval()
    correct = 0
    for data, target in val_loader:

        data = data.to(device)
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        output = model(data)

        pred = get_likely_index(output)
        correct += number_of_correct(pred, target)

        # update progress bar
        pbar.update(pbar_update)

    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(val_loader.dataset)} ({100. * correct / len(val_loader.dataset):.0f}%)\n")

In [68]:
log_interval = 20
n_epoch = 2

pbar_update = 1 / (len(train_loader) + len(val_loader))
losses = []

with tqdm(total=n_epoch) as pbar:
    for epoch in range(1, n_epoch + 1):
        train(model, epoch, log_interval)
        test(model, epoch)
        scheduler.step()

  0%|                        | 0.001349527665317139/2 [00:00<05:15, 158.02s/it]



  1%|▏                       | 0.014170040485829956/2 [00:02<04:45, 143.79s/it]



  1%|▎                       | 0.029014844804318474/2 [00:04<04:59, 151.81s/it]



  2%|▌                        | 0.04116059379217272/2 [00:06<04:43, 144.91s/it]



  3%|▋                       | 0.055330634278002666/2 [00:08<04:50, 149.27s/it]



  3%|▊                        | 0.06882591093117411/2 [00:10<05:08, 159.62s/it]



  4%|█                        | 0.08164642375168706/2 [00:12<04:54, 153.66s/it]



  5%|█▏                       | 0.09514170040485857/2 [00:14<04:58, 156.44s/it]



  5%|█▎                       | 0.10931174089068867/2 [00:16<04:52, 154.58s/it]



  6%|█▌                       | 0.12280701754386018/2 [00:19<05:15, 168.05s/it]



  7%|█▊                        | 0.1363022941970317/2 [00:21<05:15, 169.06s/it]



  7%|█▊                       | 0.14912280701754463/2 [00:23<05:02, 163.59s/it]



  8%|██                       | 0.16329284750337472/2 [00:25<04:49, 157.79s/it]



  9%|██▏                      | 0.17678812415654624/2 [00:28<04:51, 159.96s/it]



 10%|██▍                      | 0.19028340080971776/2 [00:30<04:51, 161.29s/it]



 10%|██▌                      | 0.20377867746288927/2 [00:32<04:52, 162.97s/it]



 11%|██▋                      | 0.21659919028340222/2 [00:34<05:05, 171.41s/it]



 12%|███                       | 0.2307692307692323/2 [00:36<04:59, 169.45s/it]



 12%|███                      | 0.24426450742240383/2 [00:39<04:43, 161.75s/it]



 13%|███▎                      | 0.2577597840755753/2 [00:41<05:09, 177.84s/it]



 14%|███▍                     | 0.27125506072874683/2 [00:43<05:03, 175.73s/it]



 14%|███▌                     | 0.28475033738191835/2 [00:46<05:04, 177.25s/it]



 15%|███▋                     | 0.29824561403508987/2 [00:48<05:22, 189.24s/it]



 16%|████                      | 0.3117408906882614/2 [00:50<04:19, 153.91s/it]



 16%|████▏                     | 0.3252361673414329/2 [00:53<05:03, 180.94s/it]



 17%|████▍                     | 0.3387314439946044/2 [00:55<04:19, 156.03s/it]



 18%|████▍                    | 0.35222672064777594/2 [00:57<04:41, 170.84s/it]



 18%|████▌                    | 0.36572199730094745/2 [00:59<04:28, 164.53s/it]



 19%|████▋                    | 0.37921727395411897/2 [01:02<05:04, 188.00s/it]



 20%|█████                     | 0.3927125506072905/2 [01:04<04:20, 161.93s/it]



 20%|█████▍                     | 0.406207827260462/2 [01:07<05:00, 188.61s/it]



 21%|█████▍                    | 0.4197031039136335/2 [01:09<04:35, 174.04s/it]



 22%|█████▍                   | 0.43319838056680504/2 [01:11<04:39, 178.49s/it]



 22%|██████                     | 0.446018893387318/2 [01:14<04:24, 170.44s/it]



 23%|█████▉                    | 0.4601889338731481/2 [01:16<04:42, 183.57s/it]



 24%|██████▏                   | 0.4736842105263196/2 [01:18<04:35, 180.48s/it]



 24%|██████▎                   | 0.4871794871794911/2 [01:21<04:26, 175.89s/it]



 25%|██████▌                   | 0.5006747638326625/2 [01:24<04:31, 181.25s/it]



 26%|██████▋                   | 0.5141700404858329/2 [01:26<04:23, 177.08s/it]



 26%|██████▊                   | 0.5276653171390033/2 [01:28<04:41, 191.04s/it]



 27%|███████                   | 0.5411605937921737/2 [01:31<04:21, 179.55s/it]



 28%|███████▏                  | 0.5546558704453441/2 [01:33<04:32, 188.28s/it]



 28%|███████▍                  | 0.5681511470985146/2 [01:36<04:23, 184.26s/it]



 29%|███████▊                   | 0.581646423751685/2 [01:38<04:58, 210.48s/it]



 30%|███████▋                  | 0.5951417004048554/2 [01:41<04:42, 200.94s/it]



 30%|███████▉                  | 0.6086369770580258/2 [01:44<04:31, 194.92s/it]



 31%|████████                  | 0.6221322537111962/2 [01:47<05:00, 218.43s/it]



 32%|████████▎                 | 0.6356275303643666/2 [01:49<04:38, 203.99s/it]



 32%|████████▊                  | 0.649122807017537/2 [01:52<04:42, 208.78s/it]



 33%|████████▌                 | 0.6626180836707074/2 [01:55<04:38, 208.55s/it]



 34%|████████▊                 | 0.6761133603238778/2 [01:58<04:37, 209.46s/it]



 34%|████████▉                 | 0.6896086369770482/2 [02:00<04:49, 220.84s/it]



 35%|█████████▏                | 0.7031039136302186/2 [02:03<04:18, 199.61s/it]



 36%|█████████▋                 | 0.716599190283389/2 [02:06<04:15, 199.28s/it]



 37%|█████████▍                | 0.7300944669365594/2 [02:09<04:31, 213.49s/it]



 37%|█████████▋                | 0.7435897435897298/2 [02:12<04:26, 212.27s/it]



 38%|█████████▊                | 0.7570850202429003/2 [02:15<04:34, 221.12s/it]



 39%|██████████                | 0.7705802968960707/2 [02:18<04:21, 212.63s/it]



 39%|██████████▏               | 0.7840755735492411/2 [02:21<04:09, 205.25s/it]



 40%|██████████▎               | 0.7975708502024115/2 [02:24<04:29, 223.88s/it]



 41%|██████████▌               | 0.8110661268555819/2 [02:27<04:34, 230.73s/it]



 41%|██████████▋               | 0.8245614035087523/2 [02:30<04:36, 234.92s/it]



 42%|██████████▉               | 0.8380566801619227/2 [02:33<04:20, 224.03s/it]



 43%|███████████               | 0.8515519568150931/2 [02:36<04:36, 240.81s/it]



 43%|███████████▏              | 0.8650472334682635/2 [02:40<04:39, 245.94s/it]



 44%|███████████▍              | 0.8785425101214339/2 [02:43<04:49, 258.34s/it]



 45%|███████████▌              | 0.8913630229419458/2 [02:46<04:56, 267.08s/it]



 50%|█████████████▌             | 1.0020242914979431/2 [02:52<00:51, 51.71s/it]


Test Epoch: 1	Accuracy: 5772/9981 (58%)



 51%|██████████████▏             | 1.016194331983772/2 [02:53<01:09, 70.79s/it]



 52%|██████████████▍             | 1.030364372469601/2 [02:54<01:11, 74.14s/it]



 52%|██████████████             | 1.0425101214574544/2 [02:55<01:18, 81.74s/it]



 53%|██████████████▎            | 1.0566801619432833/2 [02:56<01:11, 75.80s/it]



 54%|██████████████▍            | 1.0701754385964537/2 [02:58<01:14, 80.03s/it]



 54%|███████████████▏            | 1.082321187584307/2 [02:59<01:18, 85.05s/it]



 55%|███████████████▎            | 1.096491228070136/2 [03:00<01:16, 84.50s/it]



 55%|██████████████▉            | 1.1093117408906479/2 [03:01<01:13, 82.59s/it]



 56%|███████████████▏           | 1.1241565452091353/2 [03:02<01:09, 79.20s/it]



 57%|███████████████▎           | 1.1376518218623057/2 [03:03<01:07, 78.33s/it]



 58%|███████████████▌           | 1.1511470985154761/2 [03:04<01:05, 76.99s/it]



 58%|███████████████▋           | 1.1646423751686465/2 [03:05<01:00, 71.87s/it]



 59%|███████████████▉           | 1.1788124156544755/2 [03:06<00:58, 71.38s/it]



 60%|████████████████           | 1.1909581646423288/2 [03:07<00:59, 73.53s/it]



 60%|████████████████▎          | 1.2044534412954992/2 [03:08<00:59, 74.40s/it]



 61%|████████████████▍          | 1.2186234817813282/2 [03:09<00:58, 74.45s/it]



 62%|█████████████████▊           | 1.23144399460184/2 [03:10<00:56, 73.12s/it]



 62%|████████████████▊          | 1.2449392712550105/2 [03:11<00:55, 72.89s/it]



 63%|████████████████▉          | 1.2577597840755224/2 [03:12<00:56, 75.91s/it]



 64%|█████████████████▏         | 1.2719298245613513/2 [03:13<00:56, 77.17s/it]



 64%|█████████████████▎         | 1.2860998650471802/2 [03:14<00:54, 76.65s/it]



 65%|█████████████████▌         | 1.2995951417003506/2 [03:15<00:54, 77.45s/it]



 66%|██████████████████▍         | 1.313090418353521/2 [03:16<00:53, 77.46s/it]



 66%|█████████████████▉         | 1.3265856950066914/2 [03:17<00:50, 74.34s/it]



 67%|██████████████████         | 1.3394062078272033/2 [03:18<00:50, 77.14s/it]



 68%|██████████████████▎        | 1.3529014844803737/2 [03:19<00:49, 76.26s/it]



 68%|██████████████████▍        | 1.3670715249662027/2 [03:20<00:51, 80.61s/it]



 69%|██████████████████▋        | 1.3798920377867145/2 [03:21<00:49, 79.39s/it]



 70%|███████████████████▌        | 1.393387314439885/2 [03:22<00:43, 72.52s/it]



 70%|███████████████████        | 1.4075573549257139/2 [03:23<00:42, 71.16s/it]



 71%|███████████████████▏       | 1.4217273954115428/2 [03:24<00:42, 73.02s/it]



 72%|███████████████████▎       | 1.4338731443993962/2 [03:25<00:43, 76.65s/it]



 72%|███████████████████▌       | 1.4473684210525666/2 [03:26<00:48, 88.60s/it]



 73%|███████████████████▋       | 1.4615384615383955/2 [03:28<00:46, 85.67s/it]



 74%|████████████████████▋       | 1.475033738191566/2 [03:29<00:46, 89.52s/it]



 74%|████████████████████       | 1.4885290148447363/2 [03:30<00:43, 85.05s/it]



 75%|████████████████████▎      | 1.5006747638325897/2 [03:31<00:43, 87.22s/it]



 76%|████████████████████▍      | 1.5155195681510771/2 [03:32<00:42, 87.63s/it]



 76%|████████████████████▋      | 1.5290148448042475/2 [03:34<00:40, 86.88s/it]



 77%|█████████████████████▌      | 1.542510121457418/2 [03:35<00:39, 86.53s/it]



 78%|█████████████████████      | 1.5560053981105884/2 [03:36<00:39, 89.22s/it]



 78%|█████████████████████▏     | 1.5695006747637588/2 [03:37<00:38, 89.17s/it]



 79%|█████████████████████▎     | 1.5829959514169292/2 [03:38<00:36, 86.42s/it]



 80%|█████████████████████▌     | 1.5964912280700996/2 [03:40<00:36, 90.00s/it]



 80%|███████████████████████▎     | 1.60998650472327/2 [03:41<00:34, 87.70s/it]



 81%|█████████████████████▉     | 1.6234817813764404/2 [03:42<00:32, 87.15s/it]



 82%|██████████████████████     | 1.6369770580296108/2 [03:43<00:33, 91.82s/it]



 83%|██████████████████████▎    | 1.6504723346827812/2 [03:44<00:29, 85.29s/it]



 83%|██████████████████████▍    | 1.6639676113359516/2 [03:46<00:29, 89.11s/it]



 84%|███████████████████████▍    | 1.677462887989122/2 [03:47<00:27, 85.51s/it]



 84%|██████████████████████▊    | 1.6896086369769754/2 [03:48<00:29, 93.64s/it]



 85%|███████████████████████    | 1.7044534412954628/2 [03:49<00:26, 88.92s/it]



 86%|███████████████████████▏   | 1.7179487179486332/2 [03:50<00:24, 87.46s/it]



 87%|███████████████████████▎   | 1.7314439946018036/2 [03:51<00:24, 90.33s/it]



 87%|████████████████████████▍   | 1.744939271254974/2 [03:53<00:23, 90.77s/it]



 88%|███████████████████████▋   | 1.7584345479081445/2 [03:54<00:22, 93.78s/it]



 89%|███████████████████████▉   | 1.7719298245613149/2 [03:55<00:20, 89.09s/it]



 89%|████████████████████████   | 1.7854251012144853/2 [03:56<00:19, 91.01s/it]



 90%|████████████████████████▎  | 1.7989203778676557/2 [03:57<00:18, 90.43s/it]



 91%|█████████████████████████▎  | 1.812415654520826/2 [03:59<00:16, 89.51s/it]



 91%|████████████████████████▋  | 1.8259109311739965/2 [04:00<00:15, 87.28s/it]



 92%|█████████████████████████▊  | 1.839406207827167/2 [04:01<00:14, 87.67s/it]



 93%|████████████████████████▉  | 1.8515519568150203/2 [04:02<00:13, 89.32s/it]



 93%|█████████████████████████▏ | 1.8657219973008492/2 [04:03<00:12, 95.00s/it]



 94%|█████████████████████████▎ | 1.8792172739540196/2 [04:05<00:10, 89.94s/it]



 95%|███████████████████████████▍ | 1.89271255060719/2 [04:06<00:09, 88.70s/it]



100%|█████████████████████████▉| 1.9999999999998948/2 [04:12<00:00, 126.03s/it]


Test Epoch: 2	Accuracy: 6009/9981 (60%)






In [126]:
def predict(tensor):
    # Use the model to predict the label of the waveform
    tensor = tensor.to(device)
    tensor = model(tensor.unsqueeze(0))
    tensor = get_likely_index(tensor)
    tensor = index_to_label(tensor.squeeze())
    return tensor


waveform, _, sample_rate = train_set[-1]
ipd.Audio(waveform.numpy(), rate=sample_rate)

print(f"Expected: {utterance}. Predicted: {predict(waveform)}.")

Expected: 16000. Predicted: backward.


In [70]:
ig = IntegratedGradients(model)
sa = Saliency(model)
dl = DeepLift(model)
fa = FeatureAblation(model)

In [71]:
dataiter = iter(val_loader)
data, target = next(dataiter)
data = data.to(device)
target = target.to(device)

In [72]:
outputs = model(data)

In [73]:
outputs.shape

torch.Size([64, 35])

In [74]:
preds = outputs.argmax(dim=-1)

In [75]:
preds.shape

torch.Size([64])

## SA

In [76]:
sa_attr = sa.attribute(data, preds)

In [78]:
sa_attr.shape

torch.Size([64, 10, 16000])

In [95]:
avg_sa_attr = sa_attr.abs().mean(-1).mean(0).detach().cpu().numpy()

In [118]:
avg_sa_attr/np.linalg.norm(avg_sa_attr, 1)

array([0.6493612 , 0.03742831, 0.04198129, 0.03836196, 0.03845372,
       0.0374977 , 0.04018834, 0.0392105 , 0.0384088 , 0.03910828],
      dtype=float32)

## DL

In [97]:
dl_attr = dl.attribute(data, torch.zeros_like(data).to(device), preds)

In [98]:
dl_attr.shape

torch.Size([64, 10, 16000])

In [99]:
avg_dl_attr = dl_attr.abs().mean(-1).mean(0).detach().cpu().numpy()

In [119]:
avg_dl_attr/np.linalg.norm(avg_dl_attr, 1)

array([0.30607808, 0.07396897, 0.08260281, 0.07592342, 0.07581228,
       0.07408911, 0.07898825, 0.07813543, 0.07642604, 0.07797565],
      dtype=float32)

## IG

In [106]:
ig_attr = ig.attribute(data, torch.zeros_like(data).to(device), preds)

In [107]:
ig_attr.shape

torch.Size([64, 10, 16000])

In [108]:
avg_ig_attr = ig_attr.abs().mean(-1).mean(0).detach().cpu().numpy()

In [120]:
avg_ig_attr/np.linalg.norm(avg_ig_attr, 1)

array([0.38590728, 0.06523575, 0.07424971, 0.0675805 , 0.06741073,
       0.06510134, 0.07118871, 0.06848693, 0.06662644, 0.06821263])

## FA

In [121]:
feature_mask = np.arange(n_channels)

In [135]:
feature_mask = feature_mask[np.newaxis,:,np.newaxis]

In [136]:
feature_mask = feature_mask.repeat(sample_rate, axis=-1).repeat(batch_size, axis=0)

In [139]:
feature_mask=torch.from_numpy(feature_mask)

In [140]:
fa_attr = fa.attribute(data, torch.zeros_like(data).to(device), target=preds, feature_mask=feature_mask.to(device))

In [141]:
fa_attr.shape

torch.Size([64, 10, 16000])

In [142]:
avg_fa_attr = fa_attr.abs().mean(-1).mean(0).detach().cpu().numpy()

In [None]:
avg_fa_attr/np.linalg.norm(avg_fa_attr, 1)