This notebook is dedicated to Keyword Spotting (KWS)

In [10]:
# Enable the autoreload extension; mode 2 re-imports all modules before each
# cell runs, so edits to the local .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
### Download the helper module used below for dataset download and creation.
# The URL is pinned to a fixed gist revision, so the content never changes;
# -nc skips the download when dataset_utils.py already exists, which keeps
# re-runs of this cell from leaving dataset_utils.py.1, .2, ... copies behind.
! wget -nc https://gist.githubusercontent.com/Kirili4ik/6ac5c745ff8dad094e9c464c08f66f3e/raw/63daacc17f52a7d90f7f4166a3f5deef62b165db/dataset_utils.py
# %pip (rather than !pip) installs into the environment of the running kernel.
#%pip install wandb
# TODO: pin the easydict version once the version used for this run is known.
%pip install easydict
%pip install --no-deps torchaudio==0.9.0

In [32]:
# Explicit imports for names this notebook uses directly (torch, DataLoader),
# so later cells do not depend on what the wildcard import happens to re-export.
import torch
from torch.utils.data import DataLoader

# Project helpers used unqualified in later cells (presumably the source of
# set_seed, make_config, get_sampler and batch_data). Kept after the explicit
# imports so any name it provides still takes precedence, exactly as before.
from utils.utils import *

# Seed the random number generators via the project helper.
set_seed(21)

### Task

In this notebook we will implement a model for finding a keyword in a stream.

We will implement the CRNN-based version described in [this paper](https://www.dropbox.com/s/22ah2ba7dug6pzw/KWS_Attention.pdf), because it is simple to implement and improves the model.

### Configuration

In [18]:
# The single keyword this notebook works with.
key_word = 'sheila'

# All tunable settings gathered in one place.
config = dict(
    key_word=key_word,
    batch_size=256,
    learning_rate=3e-4,
    weight_decay=1e-5,
    num_epochs=35,
    n_mels=40,              # number of mel bins in the mel spectrogram
    kernel_size=(20, 5),    # kernel size of the CRNN convolution layer
    stride=(8, 2),          # stride of the CRNN convolution layer
    hidden_size=128,        # size of the GRU hidden representation
    gru_num_layers=2,       # number of stacked GRU layers in the CRNN
    gru_num_dirs=2,         # number of GRU directions (2 if bidirectional)
    num_classes=2,          # 2 classes: "no word" or "sheila is in audio"
    sample_rate=16000,
    # Use the first GPU when one is available, otherwise fall back to the CPU.
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
)

# make_config returns the object used from here on; its fields are read as
# attributes (config.key_word, config.device, ...).
config = make_config(key_word, config)
print(f"keyword: '{config.key_word}'\ndevice: {config.device}")

keyword: 'sheila'
device: cuda:0


#### Download, generate labels & create Datasets:

In [19]:
from dataset_utils import DatasetDownloader

# Build the labelled dataframe for the chosen keyword.
dataset_downloader = DatasetDownloader(key_word)
# Only the first of the two returned values is used in this notebook.
labeled_data, _ = dataset_downloader.generate_labeled_data()

# Show a few random rows as a quick sanity check.
labeled_data.sample(n=3)

Downloading data...
Ready!
Classes: bed, bird, cat, dog, down, eight, five, four, go, happy, house, left, marvin, nine, no, off, on, one, right, seven, sheila, six, stop, three, tree, two, up, wow, yes, zero
Creating labeled dataframe:


100%|██████████| 31/31 [06:28<00:00, 12.52s/it]


Unnamed: 0,name,word,label
63125,speech_commands/marvin/c8db14a8_nohash_1.wav,marvin,0
5266,speech_commands/five/1a9afd33_nohash_0.wav,five,0
61600,speech_commands/wow/dc75148d_nohash_0.wav,wow,0


In [24]:
from sklearn.model_selection import train_test_split

from augmentations.augs_creation import AugsCreation
from dataset_utils import TrainDataset

# Stratified 80/20 split into separate train/val dataframes, so that
# augmentations can be applied to the training part only.
train_df, val_df = train_test_split(
    labeled_data,
    test_size=0.2,
    stratify=labeled_data['label'],
    random_state=21,
)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

# Each sample is a dict of utt, word and label.
# Only the training dataset receives the augmentation transform.
transform_tr = AugsCreation()
train_set = TrainDataset(df=train_df, kw=config.key_word, transform=transform_tr)
val_set = TrainDataset(df=val_df, kw=config.key_word)

print('all train + val samples:', len(train_set)+len(val_set))

all train + val samples: 64721


#### Sampler for oversampling:

In [27]:
# One sampler per split, each built from that split's label column
# (train first, then val).
# NOTE(review): the validation split is given a sampler as well, so validation
# batches presumably do not follow the natural class balance — confirm that
# this is intended for the reported validation metrics.
train_sampler, val_sampler = (
    get_sampler(split.df['label'].values) for split in (train_set, val_set)
)

###  Dataloaders

In [29]:
# shuffle has to stay False because an explicit sampler is supplied: the
# sampler (which has randomness inside) decides the order of the samples.
common_loader_args = dict(
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=batch_data,
    num_workers=2,
    pin_memory=True,
)

# The two loaders differ only in the dataset and the sampler they draw from.
train_loader = DataLoader(train_set, sampler=train_sampler, **common_loader_args)
val_loader = DataLoader(val_set, sampler=val_sampler, **common_loader_args)

### Creating MelSpecs on GPU for speed:

In [31]:
from preprocessing.log_mel_spec import LogMelspec

# Two extractors sharing the same config and differing only in the is_train
# flag: one for the training pass, one for validation.
melspec_train, melspec_val = (
    LogMelspec(is_train=train_mode, config=config) for train_mode in (True, False)
)

### Model

In [33]:
from model.model import *

# Build the two sub-modules first, then wrap them in the full model.
CRNN_model = CRNN(config)
attn_layer = AttnMech(config)

# Assemble the full model and move it to the configured device in one step.
full_model = FullModel(config, CRNN_model, attn_layer).to(config.device)

# Print the module tree as a quick check of the architecture.
print(full_model)

FullModel(
  (CRNN_model): CRNN(
    (sepconv): Sequential(
      (0): Conv1d(40, 40, kernel_size=(5,), stride=(2,), groups=40)
      (1): Conv1d(40, 128, kernel_size=(1,), stride=(8,), groups=2)
    )
    (gru): GRU(128, 128, num_layers=2, dropout=0.1, bidirectional=True)
  )
  (attn_layer): AttnMech(
    (Wx_b): Linear(in_features=256, out_features=256, bias=True)
    (Vt): Linear(in_features=256, out_features=1, bias=False)
  )
  (U): Linear(in_features=256, out_features=2, bias=False)
)


In [34]:
# Adam over all model parameters, with the learning rate and weight decay
# taken from the config.
opt = torch.optim.Adam(
    full_model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)

In [None]:
### TRAIN
from train_utils.utils import *

# Trailing positional arguments shared by the training and validation passes.
shared_args = (
    config.gru_num_layers,
    config.gru_num_dirs,
    config.hidden_size,
    config.device,
)

# One training pass followed by one validation pass per epoch.
for epoch in range(config.num_epochs):
    train_epoch(full_model, opt, train_loader, melspec_train, *shared_args)
    validation(full_model, val_loader, melspec_val, *shared_args)

    print('END OF EPOCH', epoch)

In [None]:
# Save the trained weights. The file name is derived from the configured
# epoch count (it is 'base_35ep' for num_epochs=35), so the label in the name
# cannot drift from the number of epochs set in the config.
checkpoint_path = f'base_{config.num_epochs}ep'
torch.save({
    'model_state_dict': full_model.state_dict(),
}, checkpoint_path)

In [36]:
# Colab-only: mount Google Drive at /gdrive so that files can be copied there.
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [39]:
# List the target Drive folder to check that the mount is accessible.
!ls /gdrive/MyDrive/DLA/

KWS_seminar.ipynb  Untitled0.ipynb


In [43]:
# Copy the project's code directories and dataset_utils.py to the Drive folder.
!cp -r augmentations dataset_utils.py preprocessing model train_utils utils /gdrive/MyDrive/DLA/