This notebook is dedicated to Keyword Spotting (KWS)

In [1]:
### Download file for easier downloading and dataset creation
! wget https://gist.githubusercontent.com/Kirili4ik/6ac5c745ff8dad094e9c464c08f66f3e/raw/63daacc17f52a7d90f7f4166a3f5deef62b165db/dataset_utils.py

'wget' is not recognized as an internal or external command,
operable program or batch file.


### Most imports

In [4]:
#!pip install wandb
!pip install easydict
!pip install --no-deps torchaudio==0.9.0

Collecting torchaudio==0.9.0
  Using cached torchaudio-0.9.0-cp37-cp37m-manylinux1_x86_64.whl (1.9 MB)
Installing collected packages: torchaudio
Successfully installed torchaudio-0.9.0


In [6]:
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
from tqdm import tqdm

# import wandb

import torch
from torch.utils.data import DataLoader
from torch import distributions

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import WeightedRandomSampler
from torch.nn.utils.rnn import pad_sequence

import torchaudio
from IPython import display as display_

from torch.cuda.amp import autocast

### Helper functions

In [7]:
def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


set_seed(21)

In [8]:
def count_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])

### Task

In this notebook we will implement a model for finding a keyword in a stream.

We will implement the version with CRNN because it is easy and improves the model. 
(from https://www.dropbox.com/s/22ah2ba7dug6pzw/KWS_Attention.pdf)

### Configuration

In [9]:
key_word = 'sheila'   # We will use 1 key word -- 'sheila'

In [10]:
#!pip install easydict
from easydict import EasyDict as edict


def make_config(key_word):
    config = {
        'key_word'      : key_word,
        'batch_size'    : 256,
        'learning_rate' : 3e-4,
        'weight_decay'  : 1e-5,
        'num_epochs'    : 35,
        'n_mels'        : 40,         # number of mels for melspectrogram
        'kernel_size'   : (20, 5),    # size of kernel for convolution layer in CRNN
        'stride'        : (8, 2),     # size of stride for convolution layer in CRNN
        'hidden_size'   : 128,        # size of hidden representation in GRU
        'gru_num_layers': 2,          # number of GRU layers in CRNN
        'gru_num_dirs'  : 2,          # number of directions in GRU (2 if bidirectional)
        'num_classes'   : 2,          # number of classes (2 for "no word" or "sheila is in audio")
        'sample_rate'   : 16000,
        'device'        : torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    }

    config = edict(config)
    return config

config = make_config(key_word)
print('keyword is', config.key_word, 'device is', config.device)

keyword is sheila device is cpu


#### Augmentations


In [11]:
class AugsCreation():

    def __init__(self):
        self.background_noises = [
                'speech_commands/_background_noise_/white_noise.wav',
                'speech_commands/_background_noise_/dude_miaowing.wav',
                'speech_commands/_background_noise_/doing_the_dishes.wav',
                'speech_commands/_background_noise_/exercise_bike.wav',
                'speech_commands/_background_noise_/pink_noise.wav',
                'speech_commands/_background_noise_/running_tap.wav'
        ]


    def add_rand_noise(self, audio):
        
        # randomly choose noise
        noise_num = torch.randint(low=0, high=len(self.background_noises), size=(1,)).item()    
        noise = torchaudio.load(self.background_noises[noise_num])[0].squeeze()    
        
        noise_level = torch.Tensor([1])  # [0, 40]

        noise_energy = torch.norm(noise)
        audio_energy = torch.norm(audio)
        alpha = (audio_energy / noise_energy) * torch.pow(10, -noise_level / 20)

        start = torch.randint(low=0, high=int(noise.size(0) - audio.size(0) - 1), size=(1,)).item()
        noise_sample = noise[start : start + audio.size(0)]

        audio_new = audio + alpha * noise_sample
        audio_new.clamp_(-1, 1)
        return audio_new


    def __call__(self, wav):
        aug_num = torch.randint(low=0, high=4, size=(1,)).item()   # choose 1 random aug from augs
        augs = [
            lambda x: x,
            lambda x: (x + distributions.Normal(0, 0.01).sample(x.size())).clamp_(-1, 1),
            lambda x: torchaudio.transforms.Vol(.25)(x),
            lambda x: self.add_rand_noise(x)
        ]
        
        return augs[aug_num](wav)

#### Download, generate lables & create Datasets:

In [12]:
from utils.dataset_utils import DatasetDownloader

dataset_downloader = DatasetDownloader(key_word)
labeled_data, _ = dataset_downloader.generate_labeled_data()

labeled_data.sample(3)

Downloading data...
Ready!

0it [00:00, ?it/s]


Classes: 
Creating labeled dataframe:





ValueError: a must be greater than 0 unless no samples are taken

In [13]:
from sklearn.model_selection import train_test_split

# create 2 dataframes for train/val so we can use augmentations only for train
train_df, val_df = train_test_split(labeled_data, test_size=0.2, stratify=labeled_data['label'],  random_state=21)
train_df, val_df = train_df.reset_index(drop=True), val_df.reset_index(drop=True)


from utils.dataset_utils import TrainDataset

# Sample is a dict of utt, word and label
transform_tr = AugsCreation()
train_set = TrainDataset(df=train_df, kw=config.key_word, transform=transform_tr)
val_set   = TrainDataset(df=val_df,   kw=config.key_word)

print('all train + val samples:', len(train_set)+len(val_set))

ValueError: With n_samples=0, test_size=0.2 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.

#### Sampler for oversampling:

In [None]:
# We should provide to WeightedRandomSampler _weight for every sample_; by default it is 1/len(target)

def get_sampler(target):
    class_sample_count = np.array(
        [len(np.where(target == t)[0]) for t in np.unique(target)])   # for every class count it's number of occ.
    weight = 1. / class_sample_count
    samples_weight = np.array([weight[t] for t in target])

    samples_weight = torch.from_numpy(samples_weight)
    samples_weigth = samples_weight.double()
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    return sampler

In [None]:
train_sampler = get_sampler(train_set.df['label'].values)
val_sampler   = get_sampler(val_set.df['label'].values)

In [None]:
def batch_data(data):
    wavs = []
    labels = []    
        
    for el in data:
        wavs.append(el['utt'])
        labels.append(el['label'])

    # torch.nn.utils.rnn.pad_sequence takes list(Tensors) and returns padded (with 0.0) Tensor
    wavs = pad_sequence(wavs, batch_first=True)    
    labels = torch.Tensor(labels).type(torch.long)
    return wavs, labels

###  Dataloaders

In [None]:
# Here we are obliged to use shuffle=False because of our sampler with randomness inside.

train_loader = DataLoader(train_set, batch_size=config.batch_size,
                          shuffle=False, collate_fn=batch_data, 
                          sampler=train_sampler,
                          num_workers=2, pin_memory=True)

val_loader = DataLoader(val_set, batch_size=config.batch_size,
                        shuffle=False, collate_fn=batch_data, 
                        sampler=val_sampler,
                        num_workers=2, pin_memory=True)

### Creating MelSpecs on GPU for speeeed: 

In [None]:
class LogMelspec():
    
    def __init__(self, is_train, config):
        # with augmentations
        if is_train:
            self.melspec = nn.Sequential(
                torchaudio.transforms.MelSpectrogram(sample_rate=config.sample_rate,  n_mels=config.n_mels),
                torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
                torchaudio.transforms.TimeMasking(time_mask_param=35),
            ).to(config.device)

        # no augmentations
        else:
            self.melspec = torchaudio.transforms.MelSpectrogram(
                sample_rate=config.sample_rate,
                n_mels=config.n_mels
            ).to(config.device)
            

    def __call__(self, batch):
        return torch.log(self.melspec(batch).clamp_(min=1e-9, max=1e9))  # already on device

In [None]:
melspec_train = LogMelspec(is_train=True, config=config)
melspec_val = LogMelspec(is_train=False, config=config)

### Quality measurment functions:

In [None]:
# FA - true: 0, model: 1
# FR - true: 1, model: 0

def count_FA_FR(preds, labels):
    FA = torch.sum(preds[labels == 0])
    FR = torch.sum(labels[preds == 0])
    
    # torch.numel - returns total number of elements in tensor
    return FA.item() / torch.numel(preds), FR.item() / torch.numel(preds)

In [None]:
def get_au_fa_fr(probs, labels):
    sorted_probs, _ = torch.sort(probs)
    sorted_probs = torch.cat((torch.Tensor([0]), sorted_probs, torch.Tensor([1])))
    labels = torch.cat(labels, dim=0)
        
    FAs, FRs = [], []
    for prob in sorted_probs:
        preds = (probs >= prob) * 1
        FA, FR = count_FA_FR(preds, labels)        
        FAs.append(FA)
        FRs.append(FR)
    # plt.plot(FAs, FRs)
    # plt.show()

    # ~ area under curve using trapezoidal rule
    return -np.trapz(FRs, x=FAs)

### Model

In [None]:
# Pay attention to _groups_ param

def SepConv(in_size, out_size, kernel_size, stride, padding=0):
    return nn.Sequential(
        torch.nn.Conv1d(in_size, in_size, kernel_size[1], 
                        stride=stride[1], groups=in_size,
                        padding=padding),
        
        torch.nn.Conv1d(in_size, out_size, kernel_size=1, 
                        stride=stride[0], groups=int(in_size / kernel_size[0]))
    )

In [None]:
class CRNN(nn.Module):
    def __init__(self, config):
        super(CRNN, self).__init__()
            
        self.sepconv = SepConv(in_size=config.n_mels, out_size=config.hidden_size, 
                               kernel_size=config.kernel_size, stride=config.stride)
        
        self.gru = nn.GRU(input_size=config.hidden_size, hidden_size=config.hidden_size, 
                          num_layers=config.gru_num_layers, 
                          dropout=0.1, 
                          bidirectional=True if config.gru_num_dirs==2 else False)

    
    def forward(self, x, hidden):
        x = self.sepconv(x)
        
        # (BS, hidden, seq_len) ->(seq_len, BS, hidden) 
        x = x.permute(2, 0, 1)
        x, hidden = self.gru(x, hidden)
        # x : (seq_len, BS, hidden * num_dirs)
        # hidden : (num_layers * num_dirs, BS, hidden)
                        
        return x, hidden

In [None]:
class AttnMech(nn.Module):
    def __init__(self, config):
        super(AttnMech, self).__init__()
        
        lin_size = config.hidden_size * config.gru_num_dirs
        
        self.Wx_b = nn.Linear(lin_size, lin_size)
        self.Vt   = nn.Linear(lin_size, 1, bias=False)
        

    def forward(self, inputs, data=None):
        
        # count only 1 e_t
        if data is None:
            x = inputs
            x = torch.tanh(self.Wx_b(x))
            e = self.Vt(x)
            return e
        
        # recount attention for full vector e
        e = inputs
        data = data.transpose(0, 1)                # (BS, seq_len, hid_size*num_dirs)
        alphas = F.softmax(e, dim=-1).unsqueeze(1)
        c = torch.matmul(alphas, data).squeeze()   # attetntion_vector
        return c

In [None]:
class FullModel(nn.Module):
    def __init__(self, config, CRNN_model, attn_layer):
        super(FullModel, self).__init__()
        
        self.CRNN_model = CRNN_model
        self.attn_layer = attn_layer
        
        # ll_in_size, ll_out_size = HIDDEN_SIZE * GRU_NUM_DIRS, NUM_CLASSES
        # last layer
        self.U = nn.Linear(config.hidden_size * config.gru_num_dirs, 
                           config.num_classes, bias=False)

        
    def forward(self, batch, hidden):
        output, hidden = self.CRNN_model(batch, hidden)
        # output : (seq_len, BS, hidden * num_dirs)
        # hidden : (num_layers * num_dirs, BS, hidden)
        
        e = []
        for seq_el in output:
            e_t = self.attn_layer(seq_el) # (BS, 1)
            e.append(e_t)
        e = torch.cat(e, dim=1)           # (BS, seq_len)
        
        c = self.attn_layer(e, output)    # attention_vector
        Uc = self.U(c)        
        return Uc               # we will need to get probs, so we use return logits

In [None]:
CRNN_model = CRNN(config)

attn_layer = AttnMech(config)

full_model = FullModel(config, CRNN_model, attn_layer)

full_model = full_model.to(config.device)

print(full_model)

FullModel(
  (CRNN_model): CRNN(
    (sepconv): Sequential(
      (0): Conv1d(40, 40, kernel_size=(5,), stride=(2,), groups=40)
      (1): Conv1d(40, 128, kernel_size=(1,), stride=(8,), groups=2)
    )
    (gru): GRU(128, 128, num_layers=2, dropout=0.1, bidirectional=True)
  )
  (attn_layer): AttnMech(
    (Wx_b): Linear(in_features=256, out_features=256, bias=True)
    (Vt): Linear(in_features=256, out_features=1, bias=False)
  )
  (U): Linear(in_features=256, out_features=2, bias=False)
)


In [None]:
opt = torch.optim.Adam(full_model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

In [None]:
def train_epoch(model, opt, loader, log_melspec, gru_nl, gru_nd, hidden_size, device):
    model.train()
    for i, (batch, labels) in tqdm(enumerate(loader)):
        batch, labels = batch.to(device), labels.to(device)
        batch = log_melspec(batch)

        opt.zero_grad()
        
        # define frist hidden with 0
        hidden = torch.zeros(gru_nl*2, batch.size(0), hidden_size).to(device)    # (num_layers*num_dirs, BS, hidden)
        # run model # with autocast():
        logits = model(batch, hidden)
        probs  = F.softmax(logits, dim=-1)            # we need probabilities so we use softmax & CE separately
        loss   = F.cross_entropy(logits, labels)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        
        opt.step()

        # logging
        argmax_probs = torch.argmax(probs, dim=-1)                
        FA, FR = count_FA_FR(argmax_probs, labels)
        acc = torch.sum(argmax_probs == labels) / torch.numel(argmax_probs)

        print(acc)

In [None]:
@torch.no_grad()
def validation(model, loader, log_melspec, gru_nl, gru_nd, hidden_size, device):
    model.eval()

    val_losses, accs, FAs, FRs = [], [], [], []
    all_probs, all_labels = [], []
    for i, (batch, labels) in tqdm(enumerate(loader)):
        batch, labels = batch.to(device), labels.to(device)
        batch = log_melspec(batch)  

        # define frist hidden with 0
        hidden = torch.zeros(gru_nl*gru_nd, batch.size(0), hidden_size).to(device)    # (num_layers * num_dirs, BS, )
        # run model   # with autocast():
        output = model(batch, hidden)
        probs  = F.softmax(output, dim=-1)            # we need probabilities so we use softmax & CE separately
        loss   = F.cross_entropy(output, labels)

        # logging
        argmax_probs = torch.argmax(probs, dim=-1)
        all_probs.append(probs[:, 1].cpu())
        all_labels.append(labels.cpu())
        val_losses.append(loss.item())
        accs.append(
            torch.sum(argmax_probs == labels).item() /  #???
            torch.numel(argmax_probs)
        )
        FA, FR = count_FA_FR(argmax_probs, labels)
        FAs.append(FA)
        FRs.append(FR)
        
    # area under FA/FR curve for whole loader
    au_fa_fr = get_au_fa_fr(torch.cat(all_probs, dim=0).cpu(), all_labels)
    wandb.log({'mean_val_loss':np.mean(val_losses), 'mean_val_acc':np.mean(accs),
                'mean_val_FA':np.mean(FAs), 'mean_val_FR':np.mean(FRs),
                'au_fa_fr':au_fa_fr})
    
    return np.mean(val_losses)

In [None]:
### TRAIN

for n in range(config.num_epochs):
    
    train_epoch(full_model, opt, train_loader, melspec_train,  
                config.gru_num_layers, config.gru_num_dirs,
                config.hidden_size, config.device)           
        
    validation(full_model, val_loader, melspec_val,
               config.gru_num_layers, config.gru_num_dirs,
               config.hidden_size, config.device)

    print('END OF EPOCH', n)

203it [01:41,  1.99it/s]
51it [00:03, 13.76it/s]


END OF EPOCH 0


203it [01:40,  2.02it/s]
51it [00:03, 15.17it/s]


END OF EPOCH 1


203it [01:41,  1.99it/s]
51it [00:03, 15.54it/s]


END OF EPOCH 2


203it [01:40,  2.01it/s]
51it [00:03, 15.64it/s]


END OF EPOCH 3


203it [01:40,  2.01it/s]
51it [00:03, 14.55it/s]


END OF EPOCH 4


203it [01:40,  2.02it/s]
51it [00:03, 15.59it/s]


END OF EPOCH 5


203it [01:42,  1.98it/s]
51it [00:03, 14.45it/s]


END OF EPOCH 6


203it [01:41,  2.00it/s]
51it [00:03, 14.95it/s]


END OF EPOCH 7


203it [01:42,  1.97it/s]
51it [00:03, 15.46it/s]


END OF EPOCH 8


203it [01:43,  1.96it/s]
51it [00:03, 15.51it/s]


END OF EPOCH 9


203it [01:41,  1.99it/s]
51it [00:03, 14.53it/s]


END OF EPOCH 10


203it [01:40,  2.02it/s]
51it [00:03, 15.50it/s]


END OF EPOCH 11


203it [01:42,  1.98it/s]
51it [00:03, 14.50it/s]


END OF EPOCH 12


203it [01:42,  1.98it/s]
51it [00:03, 15.72it/s]


END OF EPOCH 13


203it [01:43,  1.96it/s]
51it [00:03, 14.63it/s]


END OF EPOCH 14


203it [01:42,  1.97it/s]
51it [00:03, 15.44it/s]


END OF EPOCH 15


203it [01:41,  2.00it/s]
51it [00:03, 14.40it/s]


END OF EPOCH 16


203it [01:41,  2.01it/s]
51it [00:03, 15.51it/s]


END OF EPOCH 17


203it [01:42,  1.99it/s]
51it [00:03, 14.60it/s]


END OF EPOCH 18


203it [01:39,  2.03it/s]
51it [00:03, 15.42it/s]


END OF EPOCH 19


203it [01:41,  2.01it/s]
51it [00:03, 14.76it/s]


END OF EPOCH 20


203it [01:42,  1.98it/s]
51it [00:03, 15.55it/s]


END OF EPOCH 21


203it [01:42,  1.99it/s]
51it [00:03, 14.56it/s]


END OF EPOCH 22


203it [01:40,  2.02it/s]
51it [00:03, 15.65it/s]


END OF EPOCH 23


203it [01:40,  2.01it/s]
51it [00:03, 15.16it/s]


END OF EPOCH 24


203it [01:41,  1.99it/s]
51it [00:03, 15.63it/s]


END OF EPOCH 25


203it [01:42,  1.99it/s]
51it [00:03, 14.44it/s]


END OF EPOCH 26


203it [01:40,  2.02it/s]
51it [00:03, 15.68it/s]


END OF EPOCH 27


203it [01:37,  2.09it/s]
51it [00:03, 14.92it/s]


END OF EPOCH 28


203it [01:39,  2.05it/s]
51it [00:03, 15.67it/s]


END OF EPOCH 29


203it [01:40,  2.03it/s]
51it [00:03, 15.70it/s]


END OF EPOCH 30


203it [01:40,  2.03it/s]
51it [00:03, 15.33it/s]


END OF EPOCH 31


203it [01:39,  2.05it/s]
51it [00:03, 14.81it/s]


END OF EPOCH 32


203it [01:41,  2.01it/s]
51it [00:03, 14.70it/s]


END OF EPOCH 33


203it [01:37,  2.07it/s]
51it [00:03, 15.73it/s]


END OF EPOCH 34


In [None]:
torch.save({
    'model_state_dict': full_model.state_dict(),
}, 'base_35ep')