# Torch-KWT Tutorial

This notebook walks you through training and running inference on Google Speech Commands V2 (35 classes) with the [Torch-KWT](https://github.com/ID56/Torch-KWT) repository.

## Setup

### 1. Clone the repository

In [1]:
!git clone https://github.com/ID56/Torch-KWT.git

Cloning into 'Torch-KWT'...
remote: Enumerating objects: 137, done.
remote: Counting objects: 100% (137/137), done.
remote: Compressing objects: 100% (83/83), done.
remote: Total 137 (delta 64), reused 108 (delta 40), pack-reused 0
Receiving objects: 100% (137/137), 110.33 KiB | 941.00 KiB/s, done.
Resolving deltas: 100% (64/64), done.


In [2]:
cd Torch-KWT/

/content/Torch-KWT


### 2. Install requirements

In [3]:
!pip install -qr requirements.txt

     |████████████████████████████████| 636 kB 7.4 MB/s 
     |████████████████████████████████| 1.8 MB 53.5 MB/s 
     |████████████████████████████████| 138 kB 63.3 MB/s 
     |████████████████████████████████| 133 kB 65.7 MB/s 
     |████████████████████████████████| 170 kB 64.0 MB/s 
     |████████████████████████████████| 97 kB 7.7 MB/s 
     |████████████████████████████████| 63 kB 1.9 MB/s 
     |████████████████████████████████| 62 kB 959 kB/s 
  Building wheel for subprocess32 (setup.py) ... done
  Building wheel for pathtools (setup.py) ... done
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires requests~=2.23.0, but you have requests 2.26.0 which is incompatible.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.


### 3. Download the Google Speech Commands V2 dataset

We'll be saving it to the `./data/` folder.

In [4]:
!sh ./download_gspeech_v2.sh ./data/

--2021-08-04 22:31:13--  http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 142.250.141.128, 2607:f8b0:4023:c0b::80
Connecting to download.tensorflow.org (download.tensorflow.org)|142.250.141.128|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2428923189 (2.3G) [application/gzip]
Saving to: ‘STDOUT’


2021-08-04 22:32:11 (40.1 MB/s) - written to stdout [2428923189/2428923189]



In [5]:
!ls data

_background_noise_  five     left     README.md		tree
backward	    follow   LICENSE  right		two
bed		    forward  marvin   seven		up
bird		    four     nine     sheila		validation_list.txt
cat		    go	     no       six		visual
dog		    happy    off      stop		wow
down		    house    on       testing_list.txt	yes
eight		    learn    one      three		zero


As you can see, the dataset ships with `validation_list.txt` and `testing_list.txt`, which define the validation and test splits. We'll run a simple script, `make_data_list.py`, to generate the remaining `training_list.txt`, along with a `label_map.json` that maps numeric indices to class labels.

In [6]:
!python make_data_list.py -v ./data/validation_list.txt -t ./data/testing_list.txt -d ./data/ -o ./data/

Number of training samples: 84843
Number of validation samples: 9981
Number of test samples: 11005
Saved data lists and label map.
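
As a quick sanity check, the counts printed above can be turned into split proportions. The sketch below only uses the numbers from the script output; the variable names are illustrative.

```python
# Sample counts copied from the make_data_list.py output above.
split_sizes = {"training": 84843, "validation": 9981, "testing": 11005}

total = sum(split_sizes.values())
fractions = {name: count / total for name, count in split_sizes.items()}

for name, frac in fractions.items():
    print(f"{name:>10}: {split_sizes[name]:6d} samples ({frac:.1%})")
print(f"{'total':>10}: {total:6d} samples")
```

This works out to roughly an 80/10/10 split between training, validation, and test data.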


## Using Pre-trained Models for Inference

The Torch-KWT repository currently provides a single checkpoint, for KWT-1, which reaches 95.98% accuracy on the test set. We can use this checkpoint to run inference before we look at training.

### 4. Downloading Pre-Trained Model

In [7]:
!wget -O "kwt1_pretrained.ckpt" "https://drive.google.com/uc?id=1Pglq3kFy9BVFk-bPVsbNuX_fzMGJ5uwy&export=download"

--2021-08-04 22:39:37--  https://drive.google.com/uc?id=1Pglq3kFy9BVFk-bPVsbNuX_fzMGJ5uwy&export=download
Resolving drive.google.com (drive.google.com)... 142.250.141.101, 142.250.141.100, 142.250.141.139, ...
Connecting to drive.google.com (drive.google.com)|142.250.141.101|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://doc-14-0g-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/3r83s3iga8itc7g9jgeq0diiiha4sfhf/1628116725000/07781108439908404460/*/1Pglq3kFy9BVFk-bPVsbNuX_fzMGJ5uwy?e=download [following]
--2021-08-04 22:39:39--  https://doc-14-0g-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/3r83s3iga8itc7g9jgeq0diiiha4sfhf/1628116725000/07781108439908404460/*/1Pglq3kFy9BVFk-bPVsbNuX_fzMGJ5uwy?e=download
Resolving doc-14-0g-docs.googleusercontent.com (doc-14-0g-docs.googleusercontent.com)... 142.250.141.132, 2607:f8b0:4023:c0b::84
Connecting to doc-14-0g-docs.googleusercontent.com (d

### 5. Creating a KWT-1 Model & Loading the Pre-Trained Weights

In [8]:
import torch
from models.kwt import kwt_from_name

model = kwt_from_name("kwt-1")

ckpt = torch.load("kwt1_pretrained.ckpt", map_location="cpu")["model_state_dict"]
model.load_state_dict(ckpt)

<All keys matched successfully>

### 6. Loading Label Map and Setting Audio Params

We'll need the label map we made earlier to see exactly what class is being predicted. We'll also need to set some parameters needed for audio processing.

In [9]:
import json

with open("data/label_map.json", "r") as f:
  label_map = json.load(f)


audio_settings = {
    "sr": 16000,
    "n_mels": 40,
    "n_fft": 480,
    "win_length": 480,
    "hop_length": 160,
    "center": False
}
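
These settings also fix the number of time frames per clip. As a rough sketch, assuming each clip is exactly 1 second long and that framing follows the usual unpadded rule for `center=False`, the frame count can be computed as follows (`num_frames` is a hypothetical helper, not part of the repository):

```python
# Values copied from the audio_settings dict above.
sr = 16000          # samples per second
win_length = 480    # 30 ms analysis window
hop_length = 160    # 10 ms step between windows

def num_frames(n_samples: int, win_length: int, hop_length: int) -> int:
    """Number of full windows that fit when the signal is not padded (center=False)."""
    if n_samples < win_length:
        return 0
    return 1 + (n_samples - win_length) // hop_length

clip_seconds = 1.0  # assumption: clips are padded or cropped to exactly 1 s
n_frames = num_frames(int(sr * clip_seconds), win_length, hop_length)
print(n_frames)
```

Under these assumptions the result is 98 frames, which matches the first entry of `input_res: [98, 40]` in the training config used later in this notebook.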

### 7. Exploring Some Test Set Audio Files

Let's take a few files from the test set and listen to them:

In [13]:
with open("data/testing_list.txt", "r") as f:
  data_list = f.read().rstrip().split()
data_list[:3]  

['./data/left/85d2ac4b_nohash_1.wav',
 './data/seven/e41a903b_nohash_4.wav',
 './data/dog/af130f12_nohash_0.wav']
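
Each path encodes its class as the name of the parent folder, which is handy for checking predictions later. A minimal sketch, where `label_from_path` is a hypothetical helper rather than a repository function:

```python
from pathlib import PurePosixPath

def label_from_path(path: str) -> str:
    """Return the class label, i.e. the name of the folder that contains the clip."""
    return PurePosixPath(path).parent.name

# The three paths shown in the output above.
sample_paths = [
    "./data/left/85d2ac4b_nohash_1.wav",
    "./data/seven/e41a903b_nohash_4.wav",
    "./data/dog/af130f12_nohash_0.wav",
]
labels = [label_from_path(p) for p in sample_paths]
print(labels)
```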

In [17]:
import IPython.display

print(data_list[0])
IPython.display.Audio(data_list[0])

./data/left/85d2ac4b_nohash_1.wav


In [18]:
print(data_list[1])
IPython.display.Audio(data_list[1])

./data/seven/e41a903b_nohash_4.wav


In [19]:
print(data_list[2])
IPython.display.Audio(data_list[2])

./data/dog/af130f12_nohash_0.wav


### 8. Running Inference

The function `cache_item_loader` from the repository can take care of most of the data loading, so we'll use that.

In [20]:
from utils.dataset import cache_item_loader

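@torch.no_grad()
def run_single_inference(model, path, audio_settings, label_map):
  """Loads a short ~1s audio file from the speech commands dataset and runs inference with the provided model."""

  model.eval()  # switch off dropout so that predictions are deterministic

  # cache_level=2 corresponds to "cache specs" in the config, so we get processed features back
  data = cache_item_loader(path, sr=audio_settings["sr"], cache_level=2, audio_settings=audio_settings)
  # add batch and channel dimensions: (*shape) -> (1, 1, *shape)
  data = torch.from_numpy(data).float().reshape(-1, 1, *data.shape)
  pred_cls = model(data).argmax(1).item()

  print(f'Predicted class: {pred_cls} ({label_map[str(pred_cls)]})')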

In [24]:
for i in range(3):
  run_single_inference(model, data_list[i], audio_settings, label_map)

Predicted class: 15 (left)
Predicted class: 23 (seven)
Predicted class: 4 (dog)
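
`argmax` only reports the winning class. To see how confident a prediction is, the raw scores can be passed through a softmax and ranked. The plain-Python sketch below uses made-up scores for a hypothetical 4-class case; with real model output, `torch.softmax(logits, dim=1)` would play the same role.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    """Return (class_index, probability) pairs for the k most likely classes."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Made-up logits for illustration only; these are not outputs of KWT-1.
example_logits = [2.0, 0.5, -1.0, 0.1]
probs = softmax(example_logits)
for idx, p in top_k(probs, k=2):
    print(f"class {idx}: {p:.3f}")
```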


## Training

For training, we only need to provide the config file.

### 9. Setting Up Your Config File

For this example, we'll be using the `sample_configs/base_config.yaml`. In fact, you should be able to use this config to reproduce the results of the provided pretrained KWT-1 checkpoint if you follow the exact settings (training for 140 epochs / ~23000 steps @ batch_size = 512).
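
The "~23000 steps" figure follows from the training set size and batch size. A short sketch of the arithmetic, assuming the last, smaller batch of each epoch is kept:

```python
import math

n_train = 84843     # training samples reported by make_data_list.py
batch_size = 512
n_epochs = 140

# Assumption: the final partial batch of each epoch is not dropped.
steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * n_epochs
print(steps_per_epoch, total_steps)
```

This gives 166 steps per epoch and 23240 steps in total; dropping the partial batch instead would give 165 and 23100, so either way the total is close to 23000.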

You can also use [wandb](https://wandb.ai) to log your runs. Either provide the path to a text file containing your API key (via `wandb_api_key` in the config), or set the `WANDB_API_KEY` environment variable, like:

```
import os

os.environ["WANDB_API_KEY"] = "yourkey"
```

We will not be using wandb in this example, but feel free to try it.
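
If you do want to try it without pasting the key into the notebook, one option is to read it from a file and export it yourself. This is a minimal sketch; `set_wandb_key_from_file` and the file name in the usage line are hypothetical and not part of Torch-KWT or wandb.

```python
import os
from pathlib import Path

def set_wandb_key_from_file(key_file: str) -> bool:
    """Read an API key from a text file and export it as WANDB_API_KEY.

    Returns True if a non-empty key was found. Hypothetical helper,
    not part of Torch-KWT or wandb.
    """
    path = Path(key_file)
    if not path.is_file():
        return False
    key = path.read_text().strip()
    if not key:
        return False
    os.environ["WANDB_API_KEY"] = key
    return True
```

For example, `set_wandb_key_from_file("wandb_key.txt")` would return `False` and leave the environment untouched if the file is missing or empty.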

In [25]:
conf_str = """# sample config to run a demo training of 20 epochs

data_root: ./data/
train_list_file: ./data/training_list.txt
val_list_file: ./data/validation_list.txt
test_list_file: ./data/testing_list.txt
label_map: ./data/label_map.json

exp:
    wandb: False
    wandb_api_key: <path/to/api/key>
    proj_name: torch-kwt-1
    exp_dir: ./runs
    exp_name: exp-0.0.1
    device: auto
    log_freq: 20    # log every l_f steps
    log_to_file: True
    log_to_stdout: True
    val_freq: 5    # validate every v_f epochs
    n_workers: 1
    pin_memory: True
    cache: 2 # 0 -> no cache | 1 -> cache wavs | 2 -> cache specs; stops wav augments
    

hparams:
    seed: 0
    batch_size: 512
    n_epochs: 20
    l_smooth: 0.1

    audio:
        sr: 16000
        n_mels: 40
        n_fft: 480
        win_length: 480
        hop_length: 160
        center: False
    
    model:
        name: # if name is provided below settings will be ignored during model creation   
        input_res: [98, 40]
        patch_res: [40, 1]
        num_classes: 35
        mlp_dim: 256
        dim: 64
        heads: 1
        depth: 12
        dropout: 0.0
        emb_dropout: 0.1
        pre_norm: False

    optimizer:
        opt_type: adamw
        opt_kwargs:
          lr: 0.001
          weight_decay: 0.1
    
    scheduler:
        n_warmup: 10
        max_epochs: 140
        scheduler_type: cosine_annealing

    augment:
        # resample:
            # r_min: 0.85
            # r_max: 1.15
        
        # time_shift:
            # s_min: -0.1
            # s_max: 0.1

        # bg_noise:
            # bg_folder: ./data/_background_noise_/

        spec_aug:
            n_time_masks: 2
            time_mask_width: 25
            n_freq_masks: 2
            freq_mask_width: 7"""

!mkdir -p configs
with open("configs/kwt1_colab.yaml", "w+") as f:
  f.write(conf_str)
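
Note that this config trains for `n_epochs: 20` while the scheduler is configured with `max_epochs: 140`. To get a feel for what that might mean, the sketch below evaluates one plausible schedule: linear warmup over `n_warmup` epochs, then cosine annealing towards zero at `max_epochs`. This is an assumption about the schedule shape, not the repository's implementation (which may, for example, update per step rather than per epoch); `lr_at_epoch` is a hypothetical helper.

```python
import math

base_lr = 0.001     # optimizer.opt_kwargs.lr
n_warmup = 10       # scheduler.n_warmup, assumed to be counted in epochs
max_epochs = 140    # scheduler.max_epochs
n_epochs = 20       # hparams.n_epochs for this demo run

def lr_at_epoch(epoch: int) -> float:
    """Assumed schedule: linear warmup, then cosine decay towards 0 at max_epochs."""
    if epoch < n_warmup:
        return base_lr * (epoch + 1) / n_warmup
    progress = (epoch - n_warmup) / (max_epochs - n_warmup)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))

for epoch in (0, 9, 10, 19, 139):
    print(f"epoch {epoch:3d}: lr = {lr_at_epoch(epoch):.6f}")
```

Under these assumptions, a 20-epoch demo run ends with the learning rate still close to its peak, because the cosine decay is spread over the full 140 epochs.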

### 10. Initiating Training

Make sure you are using a GPU runtime.

To train for the full 140 epochs (~23000 steps) used in the paper on free resources, we need to cut down on disk I/O and audio processing time. So we'll convert all our `.wav` files into MFCCs of shape `(98, 40)` up front and keep them in memory (`cache: 2` in the config). This caching step may take ~6 minutes.

Since we'll be working directly with cached MFCCs, the wav-level augmentations (`resample`, `time_shift`, and `bg_noise`) are disabled; we'll only use spectrogram augmentation (`spec_aug`) with the settings from the paper.
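
To illustrate what the `spec_aug` settings control, here is a minimal plain-Python sketch of SpecAugment-style masking. It is not the repository's implementation: the orientation (rows as time frames), treating the widths as upper bounds, and filling masked regions with zeros are all assumptions, and `spec_augment` is a hypothetical function.

```python
import random

def spec_augment(spec, n_time_masks=2, time_mask_width=25,
                 n_freq_masks=2, freq_mask_width=7, rng=None):
    """Zero out random time and frequency bands of a 2D feature matrix.

    `spec` is a list of rows; rows are assumed to be time frames and
    columns feature bins. Mask widths are treated as upper bounds.
    """
    rng = rng or random.Random()
    n_time, n_freq = len(spec), len(spec[0])
    out = [row[:] for row in spec]  # copy so the input is left untouched

    # Time masks: blank out a run of consecutive frames.
    for _ in range(n_time_masks):
        width = rng.randint(0, min(time_mask_width, n_time))
        start = rng.randint(0, n_time - width)
        for t in range(start, start + width):
            out[t] = [0.0] * n_freq

    # Frequency masks: blank out a run of consecutive feature bins.
    for _ in range(n_freq_masks):
        width = rng.randint(0, min(freq_mask_width, n_freq))
        start = rng.randint(0, n_freq - width)
        for row in out:
            row[start:start + width] = [0.0] * width

    return out

# Dummy (98, 40) input filled with ones, standing in for one cached MFCC.
dummy = [[1.0] * 40 for _ in range(98)]
augmented = spec_augment(dummy, rng=random.Random(0))
masked = sum(v == 0.0 for row in augmented for v in row)
print(f"{masked} of {98 * 40} values masked")
```

Because the masks are drawn at random on every call, each epoch would see a differently masked version of the same cached features.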

In [None]:
!python train.py --conf configs/kwt1_colab.yaml

Set seed 0
Using settings:
 data_root: ./data/
exp:
  cache: 2
  device: &id001 !!python/object/apply:torch.device
  - cuda
  exp_dir: ./runs
  exp_name: exp-0.0.1
  log_freq: 20
  log_to_file: true
  log_to_stdout: true
  n_workers: 1
  pin_memory: true
  proj_name: torch-kwt-1
  save_dir: ./runs/exp-0.0.1
  val_freq: 5
  wandb: false
  wandb_api_key: <path/to/api/key>
hparams:
  audio:
    center: false
    hop_length: 160
    n_fft: 480
    n_mels: 40
    sr: 16000
    win_length: 480
  augment:
    spec_aug:
      freq_mask_width: 7
      n_freq_masks: 2
      n_time_masks: 2
      time_mask_width: 25
  batch_size: 512
  device: *id001
  l_smooth: 0.1
  model:
    depth: 12
    dim: 64
    dropout: 0.0
    emb_dropout: 0.1
    heads: 1
    input_res:
    - 98
    - 40
    mlp_dim: 256
    name: null
    num_classes: 35
    patch_res:
    - 40
    - 1
    pre_norm: false
  n_epochs: 20
  optimizer:
    opt_kwargs:
      lr: 0.001
      weight_decay: 0.1
    opt_type: adamw
  schedul