# Imports and misc

In [None]:
!pip install torchaudio==0.9.1

In [None]:
from typing import Tuple
from tqdm import tqdm
from itertools import islice
import dataclasses

import numpy as np

import torch
import torch.nn.functional as F
from torch import nn
from torch import distributions
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torch.nn.utils.rnn import pad_sequence

import torchaudio
from IPython import display as display_

# Task

In this notebook we implement a model that detects a keyword in a stream of audio.

We implement the CRNN variant because it is simple and improves model quality
(see https://www.dropbox.com/s/22ah2ba7dug6pzw/KWS_Attention.pdf).

In [None]:
@dataclasses.dataclass
class TaskConfig:
    """Hyper-parameters for keyword spotting with an attention CRNN.

    The notebook passes the class itself (not an instance) as ``config``,
    so every field must keep a class-level default.
    """
    # target: a single keyword -- 'sheila'
    keyword: str = 'sheila'
    # optimisation
    batch_size: int = 256
    learning_rate: float = 3e-4
    weight_decay: float = 1e-5
    num_epochs: int = 25
    # features
    n_mels: int = 40
    # model (kernel_size / stride are consumed by SepConv)
    kernel_size: Tuple[int, int] = (20, 5)
    stride: Tuple[int, int] = (8, 2)
    hidden_size: int = 128
    gru_num_layers: int = 2
    bidirectional: bool = False
    num_classes: int = 2
    # audio
    sample_rate: int = 16000
    # runtime: first GPU when available, otherwise CPU
    device: torch.device = (
        torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    )

# Data

In [None]:
!wget https://gist.githubusercontent.com/Kirili4ik/6ac5c745ff8dad094e9c464c08f66f3e/raw/63daacc17f52a7d90f7f4166a3f5deef62b165db/dataset_utils.py

In [None]:
from dataset_utils import DatasetDownloader, TrainDataset

# Fetch the data for our keyword via the helper from dataset_utils (presumably the
# Speech Commands dataset -- the augmentation cell reads files under speech_commands/).
dataset_downloader = DatasetDownloader(TaskConfig.keyword)
# Only the first returned value is used. NOTE(review): assigning to `_` shadows
# IPython's "last output" variable for the rest of the session.
labeled_data, _ = dataset_downloader.generate_labeled_data()

# Peek at 3 random rows (labeled_data is presumably a pandas DataFrame -- it is
# indexed with .iloc and .reset_index further down).
labeled_data.sample(3)

### Augmentations

In [None]:
class AugsCreation:
    """Random waveform-level augmentations for the training set.

    ``__call__`` picks one of four augmentations uniformly at random:
    identity, additive Gaussian noise, volume scaling by 0.25, or mixing in a
    random excerpt of a background-noise recording.
    """

    def __init__(self, noise_level=1.0):
        """
        :param noise_level: target signal-to-noise ratio in dB used by
            ``add_rand_noise`` (an earlier comment hinted at a [0, 40] range)
        """
        self.background_noises = [
            'speech_commands/_background_noise_/white_noise.wav',
            'speech_commands/_background_noise_/dude_miaowing.wav',
            'speech_commands/_background_noise_/doing_the_dishes.wav',
            'speech_commands/_background_noise_/exercise_bike.wav',
            'speech_commands/_background_noise_/pink_noise.wav',
            'speech_commands/_background_noise_/running_tap.wav'
        ]
        self.noise_level = noise_level
        # Decoded noise waveforms keyed by index into background_noises; filled
        # lazily so each file is read from disk at most once per process
        # instead of once per augmented sample.
        self._noise_cache = {}

    def _get_noise(self, noise_num):
        """Return the 1-D waveform of background noise ``noise_num`` (cached)."""
        noise = self._noise_cache.get(noise_num)
        if noise is None:
            noise = torchaudio.load(self.background_noises[noise_num])[0].squeeze()
            self._noise_cache[noise_num] = noise
        return noise

    def add_rand_noise(self, audio):
        """Mix a random background-noise excerpt into ``audio``.

        The excerpt is scaled so that the signal-to-noise ratio equals
        ``self.noise_level`` dB.

        :param audio: 1-D waveform tensor (assumed to lie in [-1, 1])
        :return: a new tensor of the same length, clamped to [-1, 1]
        """
        # randomly choose noise
        noise_num = torch.randint(
            low=0, high=len(self.background_noises), size=(1,)).item()
        noise = self._get_noise(noise_num)

        audio_len = audio.size(0)
        if noise.size(0) < audio_len:
            # Tile short recordings so that a full-length excerpt exists.
            noise = noise.repeat(-(-audio_len // noise.size(0)))

        # Every start in [0, len(noise) - len(audio)] gives a full-length excerpt.
        start = torch.randint(
            low=0, high=noise.size(0) - audio_len + 1, size=(1,)).item()
        noise_sample = noise[start: start + audio_len]

        # Measure the energy of the excerpt that is actually mixed in (not the
        # whole recording), so 20*log10(|audio| / |alpha * noise_sample|) equals
        # noise_level. The clamp avoids dividing by zero on a silent excerpt.
        noise_energy = torch.norm(noise_sample).clamp(min=1e-9)
        audio_energy = torch.norm(audio)
        alpha = (audio_energy / noise_energy) * 10 ** (-self.noise_level / 20)

        audio_new = audio + alpha * noise_sample
        return audio_new.clamp_(-1, 1)

    def __call__(self, wav):
        """Apply one uniformly chosen augmentation to ``wav`` and return the result."""
        augs = [
            lambda x: x,
            lambda x: (x + distributions.Normal(0, 0.01).sample(x.size())).clamp_(-1, 1),
            lambda x: torchaudio.transforms.Vol(.25)(x),
            self.add_rand_noise,
        ]
        # choose 1 random aug from augs
        aug_num = torch.randint(low=0, high=len(augs), size=(1,)).item()
        return augs[aug_num](wav)

In [None]:
# Reproducible 80/20 train/validation split of the labelled utterances.
# A dedicated generator is used so the split does not depend on (or disturb)
# the global torch RNG state.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indexes = torch.randperm(len(labeled_data), generator=split_generator)
n_train = int(len(labeled_data) * TRAIN_FRACTION)
train_indexes = indexes[:n_train]
val_indexes = indexes[n_train:]

# .numpy() hands iloc a plain integer array rather than a torch tensor
train_df = labeled_data.iloc[train_indexes.numpy()].reset_index(drop=True)
val_df = labeled_data.iloc[val_indexes.numpy()].reset_index(drop=True)

In [None]:
# Every sample is a dict with keys 'utt', 'word' and 'label'.
# Waveform augmentations are attached to the training split only.
transform_tr = AugsCreation()
train_set = TrainDataset(
    df=train_df,
    kw=TaskConfig.keyword,
    transform=transform_tr,
)
val_set = TrainDataset(df=val_df, kw=TaskConfig.keyword)

### Sampler for oversampling:

In [None]:
# WeightedRandomSampler needs a weight for every sample. We use the inverse
# frequency of the sample's class, so each class is drawn equally often on average.

def get_sampler(target):
    """Build a class-balancing sampler for the labels in ``target``.

    :param target: 1-D array-like of class labels; labels need not be the
        contiguous integers 0..K-1 (a single class is also fine)
    :return: ``WeightedRandomSampler`` that draws ``len(target)`` indices,
        each with probability proportional to 1 / (size of its class)
    """
    target = np.asarray(target)
    # inverse[i] is the position of target[i]'s class among the unique labels,
    # so counts[inverse] is the size of every sample's own class.
    _, inverse, counts = np.unique(target, return_inverse=True, return_counts=True)
    samples_weight = torch.from_numpy(1. / counts[inverse]).double()
    return WeightedRandomSampler(samples_weight, len(samples_weight))

In [None]:
# One inverse-class-frequency sampler per split, built from its 'label' column.
train_sampler, val_sampler = [
    get_sampler(split.df['label'].values) for split in (train_set, val_set)
]

In [None]:
class Collator:
    """Turn a list of dataset samples into a ``(wavs, labels)`` batch.

    Waveforms of different lengths are right-padded with 0.0 to the longest
    one (``pad_sequence``); labels become a ``torch.long`` tensor.
    """

    def __call__(self, data):
        waveforms = [sample['utt'] for sample in data]
        targets = [sample['label'] for sample in data]

        padded = pad_sequence(waveforms, batch_first=True)   # (BS, max_len)
        return padded, torch.Tensor(targets).long()

### Dataloaders

In [None]:
# The samplers already randomise the order of the samples, so shuffle has
# to stay False when a sampler is passed to DataLoader.
# NOTE(review): val_loader also goes through val_sampler, which presumably
# draws with replacement (WeightedRandomSampler's default), so every
# validation pass sees a different resampled set -- confirm this is intended.

train_loader = DataLoader(
    train_set,
    batch_size=TaskConfig.batch_size,
    shuffle=False,
    collate_fn=Collator(),
    sampler=train_sampler,
)

val_loader = DataLoader(
    val_set,
    batch_size=TaskConfig.batch_size,
    shuffle=False,
    collate_fn=Collator(),
    sampler=val_sampler,
    num_workers=2,
    pin_memory=True,
)

### Computing log-mel spectrograms on the GPU for speed

In [None]:
class LogMelspec:
    """Log-mel spectrogram transform that lives on ``config.device``.

    With ``is_train=True`` the mel spectrogram is followed by random frequency
    masking (``freq_mask_param=15``) and time masking (``time_mask_param=35``);
    otherwise only the plain mel spectrogram is applied.
    """

    def __init__(self, is_train, config):
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_mels=config.n_mels,
        )

        if not is_train:
            # no augmentations
            self.melspec = mel.to(config.device)
            return

        # with augmentations
        self.melspec = nn.Sequential(
            mel,
            torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
            torchaudio.transforms.TimeMasking(time_mask_param=35),
        ).to(config.device)

    def __call__(self, batch):
        """Return ``log(melspec(batch))``; values are clamped to [1e-9, 1e9] before the log."""
        spec = self.melspec(batch)   # batch is expected to already be on config.device
        return torch.log(spec.clamp_(min=1e-9, max=1e9))

In [None]:
# Training features get random frequency/time masking; validation features do not.
melspec_train = LogMelspec(config=TaskConfig, is_train=True)
melspec_val = LogMelspec(config=TaskConfig, is_train=False)

### Quality measurement functions

In [None]:
# False accept (FA): true label 0, model says 1.
# False reject (FR): true label 1, model says 0.

def count_FA_FR(preds, labels):
    """Return ``(FA, FR)`` as fractions of *all* predictions (not per-class rates).

    :param preds: tensor of 0/1 predictions
    :param labels: tensor of 0/1 ground-truth labels with the same shape
    """
    n_total = torch.numel(preds)   # total number of elements in the tensor
    false_accepts = preds[labels == 0].sum().item()
    false_rejects = labels[preds == 0].sum().item()
    return false_accepts / n_total, false_rejects / n_total

In [None]:
def get_au_fa_fr(probs, labels):
    """Area under the FR(FA) curve swept over all decision thresholds.

    For every threshold ``t`` in ``[0, *sorted(probs), 1]`` a sample is
    predicted positive when ``prob >= t``. FA and FR are then computed as in
    ``count_FA_FR`` (counts divided by the total number of samples), and the
    curve is integrated with the trapezoidal rule. Lower is better.

    Runs in O(N log N) via one sort plus prefix sums, instead of re-thresholding
    all N samples for each of the N + 2 thresholds.

    :param probs: 1-D tensor of positive-class (keyword) probabilities
    :param labels: list/tuple of 0/1 label tensors (concatenated here), or a
        single already-concatenated tensor
    :return: the area as a numpy float
    :raises ValueError: if ``probs`` is empty
    """
    probs = probs.detach().cpu().reshape(-1)
    if isinstance(labels, (list, tuple)):
        labels = torch.cat(labels, dim=0)
    labels = labels.detach().cpu().reshape(-1)

    n_total = probs.numel()
    if n_total == 0:
        raise ValueError('get_au_fa_fr needs at least one prediction')

    # After sorting, the samples predicted negative at any threshold t
    # (prob < t) are exactly a prefix of the sorted order.
    sorted_probs, order = torch.sort(probs)
    sorted_labels = labels[order].long()
    thresholds = torch.cat(
        (sorted_probs.new_zeros(1), sorted_probs, sorted_probs.new_ones(1)))
    # n_below[k] = number of samples with prob < thresholds[k]
    n_below = torch.searchsorted(sorted_probs, thresholds)

    # Prefix sums with a leading 0, so index n means "the first n samples".
    zero = torch.zeros(1, dtype=torch.long)
    cum_label = torch.cat((zero, torch.cumsum(sorted_labels, dim=0)))
    cum_neg = torch.cat((zero, torch.cumsum((sorted_labels == 0).long(), dim=0)))

    # FR: labels summed over predicted negatives.
    # FA: label-0 samples predicted positive.
    FRs = cum_label[n_below].double() / n_total
    FAs = (cum_neg[-1] - cum_neg[n_below]).double() / n_total

    # numpy >= 2.0 renamed trapz -> trapezoid; support both.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    # FA decreases as the threshold grows, so the raw integral is negative.
    return -trapezoid(FRs.numpy(), x=FAs.numpy())

### Model

In [None]:
# Depthwise-separable convolution -- pay attention to the _groups_ argument of each layer.

def SepConv(in_size, out_size, kernel_size, stride, padding=0):
    """Build a depthwise + pointwise 1-D convolution stack.

    :param in_size: number of input channels (the notebook passes n_mels)
    :param out_size: number of output channels
    :param kernel_size: pair; ``kernel_size[1]`` is the depthwise kernel width,
        while ``kernel_size[0]`` only sets the pointwise group count
        ``int(in_size / kernel_size[0])``. Conv1d requires that group count to
        divide both ``in_size`` and ``out_size``.
    :param stride: pair; ``stride[1]`` is used by the depthwise conv and
        ``stride[0]`` by the pointwise conv
    :param padding: padding of the depthwise conv only
    :return: ``nn.Sequential(depthwise, pointwise)``

    NOTE(review): both layers are 1-D, so ``stride[0]`` / ``kernel_size[0]`` act
    along the same (last) axis as ``stride[1]``. Confirm this is intended for
    the (20, 5) / (8, 2) values configured in TaskConfig.
    """
    # one filter per input channel
    depthwise = nn.Conv1d(
        in_channels=in_size,
        out_channels=in_size,
        kernel_size=kernel_size[1],
        stride=stride[1],
        padding=padding,
        groups=in_size,
    )
    pointwise_groups = int(in_size / kernel_size[0])
    # 1x1 channel mixing
    pointwise = nn.Conv1d(
        in_channels=in_size,
        out_channels=out_size,
        kernel_size=1,
        stride=stride[0],
        groups=pointwise_groups,
    )
    return nn.Sequential(depthwise, pointwise)

In [None]:
class CRNN(nn.Module):
    """Convolutional front-end (``SepConv``) followed by a GRU encoder."""

    def __init__(self, config):
        super().__init__()

        self.sepconv = SepConv(
            in_size=config.n_mels,
            out_size=config.hidden_size,
            kernel_size=config.kernel_size,
            stride=config.stride,
        )
        self.gru = nn.GRU(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size,
            num_layers=config.gru_num_layers,
            dropout=0.1,
            bidirectional=config.bidirectional,
        )

    def forward(self, x, hidden):
        """Encode ``x`` and return the GRU's ``(outputs, hidden)`` pair.

        :param x: conv input with n_mels channels, presumably (BS, n_mels, time)
        :param hidden: initial GRU state, or None
        :return: outputs of shape (seq_len, BS, hidden * num_dirs) and the final
            state of shape (num_layers * num_dirs, BS, hidden)
        """
        conv_out = self.sepconv(x)                 # (BS, hidden, seq_len)
        time_major = conv_out.permute(2, 0, 1)     # (seq_len, BS, hidden)
        outputs, last_hidden = self.gru(time_major, hidden)
        return outputs, last_hidden

In [None]:
class AttnMech(nn.Module):
    """Additive attention pooling: scores = Vt(tanh(Wx_b(h))), softmax over dim 1."""

    def __init__(self, config):
        super().__init__()
        num_dirs = 2 if config.bidirectional else 1
        features = config.hidden_size * num_dirs

        self.Wx_b = nn.Linear(features, features)
        self.Vt = nn.Linear(features, 1, bias=False)

    def forward(self, hiddens, scores=None, return_context=False):
        """
        :param hiddens: encoder output with shape (BS, seq_len, hidden * num_dirs)
        :param scores: precomputed attention scores (aka energy); computed if None
        :param return_context: if True return the attention-weighted sum of
            ``hiddens`` over dim 1, otherwise return the scores
        """
        if scores is None:
            scores = self.Vt(torch.tanh(self.Wx_b(hiddens)))   # (BS, seq_len, 1)

        if return_context:
            weights = F.softmax(scores, dim=1)
            return torch.sum(weights * hiddens, dim=1)
        return scores


In [None]:
class FullModel(nn.Module):
    """Encoder + attention pooling + bias-free linear classifier."""

    def __init__(self, config, CRNN_model, attn_layer):
        super().__init__()

        self.CRNN_model = CRNN_model
        self.attn_layer = attn_layer

        # last layer: pooled context (hidden * num_dirs) -> num_classes logits
        num_dirs = 2 if config.bidirectional else 1
        self.U = nn.Linear(config.hidden_size * num_dirs, config.num_classes, bias=False)

    def forward(self, batch, hidden=None):
        """Return class logits of shape (BS, num_classes) for ``batch``."""
        encoded, _ = self.CRNN_model(batch, hidden)   # (seq_len, BS, hidden * num_dirs)
        batch_major = encoded.transpose(0, 1)         # (BS, seq_len, hidden * num_dirs)
        context = self.attn_layer(batch_major, return_context=True)
        return self.U(context)

In [None]:
def train_epoch(model, opt, loader, log_melspec, device):
    """Run one optimisation pass over ``loader``.

    :param model: classifier returning logits, presumably of shape (BS, num_classes)
    :param opt: optimizer over ``model``'s parameters
    :param loader: iterable of ``(waveforms, labels)`` batches
    :param log_melspec: feature transform applied on ``device``
    :param device: device that batches and labels are moved to
    :return: accuracy over every sample seen this epoch as a float,
        or ``nan`` if ``loader`` yields no batches
    """
    model.train()
    n_correct, n_seen = 0, 0
    # iterate the loader itself so tqdm can show a total (len(loader))
    for batch, labels in tqdm(loader, desc='train'):
        batch, labels = batch.to(device), labels.to(device)
        batch = log_melspec(batch)

        opt.zero_grad()
        logits = model(batch)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        # bound the global gradient norm at 5 before stepping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        opt.step()

        # argmax(logits) == argmax(softmax(logits)), so no probabilities needed
        preds = logits.detach().argmax(dim=-1)
        n_correct += (preds == labels).sum().item()
        n_seen += labels.numel()

    return n_correct / n_seen if n_seen else float('nan')

In [None]:
@torch.no_grad()
def validation(model, loader, log_melspec, device):
    """Evaluate ``model`` on ``loader`` and return the area under the FA/FR curve.

    Positive-class probabilities (softmax column 1) and labels are collected
    over the whole loader on the CPU, then passed to ``get_au_fa_fr``.

    :param model: classifier returning logits, presumably of shape (BS, num_classes)
    :param loader: iterable of ``(waveforms, labels)`` batches
    :param log_melspec: feature transform applied on ``device``
    :param device: device the model lives on
    :return: whatever ``get_au_fa_fr`` returns (the area; lower is better)
    """
    model.eval()

    all_probs, all_labels = [], []
    # iterate the loader itself so tqdm can show a total (len(loader))
    for batch, labels in tqdm(loader, desc='val'):
        batch = log_melspec(batch.to(device))
        logits = model(batch)
        all_probs.append(F.softmax(logits, dim=-1)[:, 1].cpu())
        all_labels.append(labels.cpu())

    # area under FA/FR curve for the whole loader
    return get_au_fa_fr(torch.cat(all_probs, dim=0), all_labels)

In [None]:
from collections import defaultdict
from IPython.display import clear_output
from matplotlib import pyplot as plt

# Per-epoch metrics keyed by name; the training loop appends to history['val_metric'].
history = defaultdict(list)

In [None]:
# Assemble encoder -> attention pooling -> classifier, then move it to the device.
CRNN_model = CRNN(TaskConfig)
attn_layer = AttnMech(TaskConfig)
full_model = FullModel(TaskConfig, CRNN_model, attn_layer).to(TaskConfig.device)

print(full_model)

opt = torch.optim.Adam(
    full_model.parameters(),
    lr=TaskConfig.learning_rate,
    weight_decay=TaskConfig.weight_decay,
)

In [None]:
# TRAIN: each epoch does one training pass and one validation pass, then
# redraws the curve of the validation metric (area under the FA/FR curve).

for n in range(TaskConfig.num_epochs):

    train_epoch(
        full_model, opt, train_loader, melspec_train, TaskConfig.device,
    )

    au_fa_fr = validation(
        full_model, val_loader, melspec_val, TaskConfig.device,
    )
    history['val_metric'].append(au_fa_fr)

    # replace the previous epoch's plot with the updated curve
    clear_output()
    fig, ax = plt.subplots()
    ax.plot(history['val_metric'])
    ax.set_ylabel('Metric')
    ax.set_xlabel('Epoch')
    ax.grid()
    plt.show()

    print('END OF EPOCH', n)