# 0. Requirements & Config

In [1]:
# Package requirements for this notebook, one requirement per line.
# Only torch carries a version constraint; the other packages are unpinned.
requirements_str = """
torch>=1.7.0
librosa
audiomentations
pydub
tqdm
einops"""

# Write the list to requirements.txt so it can be passed to `pip install -r`.
with open("requirements.txt", "w") as f:
    f.write(requirements_str)

!pip install -r requirements.txt

Collecting audiomentations (from -r requirements.txt (line 4))
  Downloading audiomentations-0.42.0-py3-none-any.whl.metadata (11 kB)
Collecting numpy-minmax<1,>=0.3.0 (from audiomentations->-r requirements.txt (line 4))
  Downloading numpy_minmax-0.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.0 kB)
Collecting numpy-rms<1,>=0.4.2 (from audiomentations->-r requirements.txt (line 4))
  Downloading numpy_rms-0.6.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.5 kB)
Collecting python-stretch<1,>=0.3.1 (from audiomentations->-r requirements.txt (line 4))
  Downloading python_stretch-0.3.1-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)
Downloading audiomentations-0.42.0-py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.5/86.5 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nump

In [2]:
conf_str = """# sample config to run a demo training of 20 epochs

data_root: ./data/
train_list_file: ./data/training_list.txt
val_list_file: ./data/validation_list.txt
test_list_file: ./data/testing_list.txt
label_map: ./data/label_map.json

pkl_root: ./google_speech_commands_v2/
pkl_train: ./google_speech_commands_v2/train_dataset.pkl
pkl_test: ./google_speech_commands_v2/test_dataset.pkl

exp:
    exp_dir: ./runs
    exp_name: exp-0.0.1
    device: auto
    log_freq: 20    # log every l_f steps
    log_to_file: False
    log_to_stdout: True
    val_freq: 1    # validate every v_f epochs
    n_workers: 1
    pin_memory: True


hparams:
    seed: 0
    batch_size: 512
    n_epochs: 20
    l_smooth: 0.1

    audio:
        sr: 16000
        n_mels: 40
        n_fft: 480
        win_length: 480
        hop_length: 160
        center: False

    model:
        name: kwt-2 # if name is provided below settings will be ignored during model creation
        input_res: [40, 98]
        patch_res: [40, 1]
        num_classes: 35
        mlp_dim: 256
        dim: 64
        heads: 1
        depth: 12
        dropout: 0.0
        emb_dropout: 0.1
        pre_norm: False

    optimizer:
        opt_type: adamw
        opt_kwargs:
          lr: 0.001
          weight_decay: 0.1

    scheduler:
        n_warmup: 10
        max_epochs: 140
        scheduler_type: cosine_annealing

    augment:
        spec_aug:
            n_time_masks: 2
            time_mask_width: 25
            n_freq_masks: 2
            freq_mask_width: 7"""

!mkdir -p configs
# Persist the YAML text held in conf_str so it can be loaded from disk later.
with open("configs/kwt2_colab.yaml", "w+") as f:
  f.write(conf_str)

# Path of the config file written above.
conf_file = "configs/kwt2_colab.yaml"

- `config_parser.py` - `get_config`

In [24]:
import yaml
import os
import torch
import sys


def get_config(config_file: str) -> dict:
    """Load a YAML experiment config and resolve the compute device.

    Args:
        config_file (str): Path to the YAML config file.

    Returns:
        dict: The parsed config. ``config["exp"]["device"]`` is replaced by a
        ``torch.device`` when it was ``"auto"``; any other value is kept as
        written. ``config["hparams"]["device"]`` always mirrors
        ``config["exp"]["device"]``.
    """
    with open(config_file, "r") as f:
        base_config = yaml.load(f, Loader=yaml.FullLoader)

    if base_config["exp"]["device"] == "auto":
        base_config["exp"]["device"] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Mirror the device unconditionally: the training code reads
    # config["hparams"]["device"], which previously only existed for "auto".
    base_config["hparams"]["device"] = base_config["exp"]["device"]

    return base_config

# 1. Define the model

In [3]:
import torch
import torch.fft
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


# Basically vision transformer, ViT that accepts MFCC + SpecAug. Refer to:
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py


class PreNorm(nn.Module):
    """Wrapper that layer-normalizes its input before calling ``fn``: ``fn(norm(x))``."""

    def __init__(self, dim, fn):
        super().__init__()
        # LayerNorm over the last axis, which must have size `dim`.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalize ``x`` and forward it (plus any keyword arguments) to ``fn``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class PostNorm(nn.Module):
    """Wrapper that layer-normalizes the output of ``fn``: ``norm(fn(x))``."""

    def __init__(self, dim, fn):
        super().__init__()
        # LayerNorm over the last axis, which must have size `dim`.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Call ``fn`` on ``x`` (plus any keyword arguments) and normalize the result."""
        transformed = self.fn(x, **kwargs)
        return self.norm(transformed)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps the last axis ``dim -> hidden_dim -> dim``.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Layer order fixes the state-dict keys (net.0.*, net.3.*); keep it stable.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``."""
        projected = self.net(x)
        return projected


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values come from one bias-free linear projection of the
    input. Head outputs are concatenated and projected back to ``dim``, except
    for a single head whose width already equals ``dim``, where the output is
    passed through unchanged.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        # Only a single head of width `dim` can skip the output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend over axis 1 of ``x`` (shape ``(batch, tokens, dim)``)."""
        # The unpacking doubles as a check that x is 3-dimensional.
        b, n, _, n_heads = *x.shape, self.heads

        # Split the fused projection into q, k, v and move the head axis forward.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = n_heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.attend(scores)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        # Merge the heads back into one feature axis.
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)


class Transformer(nn.Module):
    """Stack of ``depth`` encoder layers, each an attention block and an MLP block.

    Both sub-blocks are wrapped in ``PreNorm`` (``pre_norm=True``) or
    ``PostNorm`` (``pre_norm=False``) and used with a residual connection.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, pre_norm=True, dropout = 0.):
        super().__init__()
        wrap = PreNorm if pre_norm else PostNorm

        encoder_layers = []
        for _ in range(depth):
            attention = wrap(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            feed_forward = wrap(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            encoder_layers.append(nn.ModuleList([attention, feed_forward]))

        self.layers = nn.ModuleList(encoder_layers)

    def forward(self, x):
        """Run ``x`` through every layer, adding the residual after each sub-block."""
        hidden = x
        for layer in self.layers:
            attention, feed_forward = layer
            hidden = attention(hidden) + hidden
            hidden = feed_forward(hidden) + hidden
        return hidden


class KWT(nn.Module):
    """Keyword Transformer: a ViT-style classifier over spectrogram patches.

    The input is cut into ``patch_res`` patches, each patch is linearly
    embedded, a learned class token and learned position embeddings are added,
    and the sequence is fed through a ``Transformer``. The class token
    (``pool='cls'``) or the mean over all tokens (``pool='mean'``) is then
    classified by a LayerNorm + Linear head.

    Args:
        input_res: Two-element size of the input feature map.
        patch_res: Two-element size of one patch.
        num_classes: Number of output classes.
        dim: Token embedding width.
        depth: Number of transformer layers.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward blocks.
        pool: ``'cls'`` or ``'mean'``.
        channels: Number of input channels.
        dim_head: Width of each attention head.
        dropout: Dropout inside the transformer layers.
        emb_dropout: Dropout applied to the embedded token sequence.
        pre_norm: Use ``PreNorm`` (True) or ``PostNorm`` (False) layers.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, input_res, patch_res, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 1, dim_head = 64, dropout = 0., emb_dropout = 0., pre_norm = True, **kwargs):
        super().__init__()

        num_patches = int(input_res[0]/patch_res[0] * input_res[1]/patch_res[1])
        patch_dim = channels * patch_res[0] * patch_res[1]

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Flatten every patch into a vector, then embed it into `dim` features.
        to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_res[0], p2 = patch_res[1])
        self.to_patch_embedding = nn.Sequential(to_patches, nn.Linear(patch_dim, dim))

        # One position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, pre_norm, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, x):
        """Return class logits for a batch shaped ``(batch, channels, height, width)``."""
        tokens = self.to_patch_embedding(x)
        batch_size, n_patches, _ = tokens.shape

        # Prepend one copy of the class token to every sequence in the batch.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = batch_size)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens += self.pos_embedding[:, :(n_patches + 1)]
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        return self.mlp_head(self.to_latent(pooled))


def kwt_from_name(model_name: str):
    """Build one of the preset KWT variants.

    Args:
        model_name (str): ``"kwt-1"``, ``"kwt-2"`` or ``"kwt-3"``.

    Returns:
        KWT: Model constructed from the preset's hyperparameters.

    Raises:
        AssertionError: If ``model_name`` is not a known preset.
    """
    # Settings shared by every preset.
    shared = {
        "input_res": [40, 98],
        "patch_res": [40, 1],
        "num_classes": 35,
        "depth": 12,
        "dropout": 0.0,
        "emb_dropout": 0.1,
        "pre_norm": False,
    }

    # The presets differ only in (mlp_dim, dim, heads).
    sizes = {
        "kwt-1": (256, 64, 1),
        "kwt-2": (512, 128, 2),
        "kwt-3": (768, 192, 3),
    }

    models = {
        name: {**shared, "mlp_dim": mlp_dim, "dim": dim, "heads": heads}
        for name, (mlp_dim, dim, heads) in sizes.items()
    }

    assert model_name in models.keys(), f"Unsupported model_name {model_name}; must be one of {list(models.keys())}"

    preset = models[model_name]
    return KWT(**preset)

# 2. Dataset functions

## 2.1 Spectrogram augmentation

In [4]:
import numpy as np
import numba as nb
import librosa

#@nb.jit(nopython=True, cache=True)
@nb.jit(nopython=True)
def spec_augment(mel_spec: np.ndarray, n_time_masks: int, time_mask_width: int, n_freq_masks: int, freq_mask_width: int):
    """Zero out random column bands and row bands of a 2-D spectrogram.

    Each mask has a random width in ``[0, mask_width)`` and a random start
    position chosen so that the band fits inside the array.

    The array is modified IN PLACE and also returned; pass a copy when the
    original values must be preserved.

    Args:
        mel_spec (np.ndarray): 2-D array; columns (axis 1) get the "time"
            masks and rows (axis 0) get the "freq" masks.
        n_time_masks (int): Number of column bands to zero.
        time_mask_width (int): Exclusive upper bound on a column band's width.
        n_freq_masks (int): Number of row bands to zero.
        freq_mask_width (int): Exclusive upper bound on a row band's width.

    Returns:
        np.ndarray: The same array object, after masking.
    """
    n_rows, n_cols = mel_spec.shape[0], mel_spec.shape[1]

    # A band can never be wider than the axis it lies on. Clamping keeps
    # `randint(0, size - offset)` valid for small inputs or large widths.
    max_time_width = min(time_mask_width, n_cols)
    max_freq_width = min(freq_mask_width, n_rows)

    # randint(0, 0) is an error, so a non-positive width means "no masking".
    if max_time_width > 0:
        for _ in range(n_time_masks):
            offset = np.random.randint(0, max_time_width)
            begin = np.random.randint(0, n_cols - offset)
            mel_spec[:, begin: begin + offset] = 0.0

    if max_freq_width > 0:
        for _ in range(n_freq_masks):
            offset = np.random.randint(0, max_freq_width)
            begin = np.random.randint(0, n_rows - offset)
            mel_spec[begin: begin + offset, :] = 0.0

    return mel_spec


## 2.2 Dataset wrapper and get loader

In [31]:
from torch.utils.data import ConcatDataset, Dataset
import pickle

def merge_val_test(val_dataset: Dataset, test_dataset: Dataset):
    """Return one dataset holding the validation items followed by the test items."""
    parts = [val_dataset, test_dataset]
    return ConcatDataset(parts)


def save_dataset(dataset, file_path):
    """Pickle ``dataset`` to ``file_path``, creating parent directories if needed.

    Args:
        dataset: Any picklable object.
        file_path (str): Destination file. May be a bare filename, in which
            case the file is written to the current working directory.
    """
    parent_dir = os.path.dirname(file_path)
    # dirname is "" for a bare filename, and os.makedirs("") raises FileNotFoundError.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_path, "wb") as f:
        pickle.dump(dataset, f)
    print(f"Dataset saved to {file_path}.")


def load_dataset(file_path):
    """Unpickle and return the object stored at ``file_path``.

    NOTE: unpickling can execute arbitrary code, so only load files you created
    yourself or otherwise trust.
    """
    with open(file_path, 'rb') as handle:
        loaded = pickle.load(handle)

    print(f"Dataset loaded from {file_path}.")
    return loaded

In [8]:
class GoogleSpeechDataset(Dataset):
    """Google Speech Commands V2 dataset that yields MFCC features as numpy arrays.

    Audio is read and converted on every ``__getitem__`` call; nothing is cached.
    Items are ``(mfcc, label_index)`` when a label map was given, else ``mfcc``.
    """

    def __init__(self, data_list: list, audio_settings: dict, label_map: dict = None):
        """
        Args:
            data_list (list): Paths to audio files, written with "/" separators
                and with the class name as the parent directory.
            audio_settings (dict): Keyword arguments for
                ``librosa.feature.melspectrogram``; must contain ``"sr"`` and
                ``"n_mels"``.
            label_map (dict, optional): Maps index -> class name. When omitted
                no labels are produced (inference use).
        """
        super().__init__()

        self.data_list = data_list
        self.audio_settings = audio_settings

        if label_map is None:
            self.label_list = None
        else:
            # Invert to class name -> integer index; stored labels are integers.
            label_2_idx = {name: int(index) for index, name in label_map.items()}
            self.label_list = [label_2_idx[path.split("/")[-2]] for path in data_list]


    def __len__(self):
        return len(self.data_list)


    def __getitem__(self, idx):
        settings = self.audio_settings
        sample_rate = settings["sr"]

        waveform, _ = librosa.load(self.data_list[idx], sr=sample_rate)

        # Pad or trim to exactly `sr` samples (one second of audio), then
        # mel spectrogram -> dB -> MFCC.
        waveform = librosa.util.fix_length(waveform, size=sample_rate)
        mel = librosa.feature.melspectrogram(y=waveform, **settings)
        mfcc = librosa.feature.mfcc(S=librosa.power_to_db(mel), n_mfcc=settings["n_mels"])

        if self.label_list is None:
            return mfcc
        return mfcc, self.label_list[idx]

In [43]:
from torch.utils.data import ConcatDataset, Dataset
import pickle

class PrecomputedSpeechDataset(Dataset):
    """Training-time wrapper around a pickled dataset.

    Mirrors the ``GoogleSpeechDataset`` item API but returns tensors and applies
    SpecAugment on the fly when ``train=True``:
    ``__getitem__ -> (x, label)`` when a label is available, else ``x``.

    The pickle may hold any indexable, iterable collection whose items are
    either ``(x, label)`` tuples or bare ``x`` arrays.
    """

    def __init__(self, pkl_path, aug_settings: dict = None, label_map: dict = None, train=False):
        """
        Args:
            pkl_path: Path passed to ``load_dataset``.
            aug_settings (dict, optional): Augmentation config; only the
                ``"spec_aug"`` entry (kwargs for ``spec_augment``) is used.
            label_map (dict, optional): When given, ``label_list`` is built from
                the labels stored in the pickle. The mapping itself is not
                consulted because stored labels are already integer indices.
            train (bool): Enables augmentation.
        """
        super().__init__()
        self.dataset = load_dataset(pkl_path)   # collection of (x, label) or x
        self.aug_settings = aug_settings
        self.train = train

        if label_map is not None:
            # NOTE: this walks the whole collection once. If the pickle holds a
            # lazy Dataset object, every item is computed here.
            self.label_list = [self._split_sample(sample)[1] for sample in self.dataset]
        else:
            self.label_list = None

    @staticmethod
    def _split_sample(sample):
        """Return ``(x, label)``; ``label`` is None for an unlabeled sample."""
        if isinstance(sample, tuple) and len(sample) == 2:
            return sample
        return sample, None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self._split_sample(self.dataset[idx])

        # augment
        if self.train and self.aug_settings is not None and "spec_aug" in self.aug_settings:
            # spec_augment writes its masks into the array it receives. Work on
            # a copy so stored samples do not accumulate masks across epochs.
            if isinstance(x, np.ndarray):
                x = x.copy()
            x = spec_augment(x, **self.aug_settings["spec_aug"])

        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float()
        if x.ndim == 2:
            x = x.unsqueeze(0)  # add a leading channel axis: (1, n_mfcc, T)

        label = self.label_list[idx] if self.label_list is not None else y
        if label is None:
            label = y

        # An unlabeled sample yields just the features instead of failing in
        # torch.tensor(None).
        if label is None:
            return x
        return x, torch.tensor(label, dtype=torch.long)

In [10]:
def get_loader(dataset, config, train=True):
    """Wrap ``dataset`` in a DataLoader configured from ``config``.

    Args:
        dataset: Dataset to iterate over.
        config (dict): Reads ``hparams.batch_size``, ``exp.n_workers`` and
            ``exp.pin_memory``.
        train (bool): Shuffle the data when truthy.

    Returns:
        DataLoader: The configured loader.
    """
    loader_kwargs = {
        "batch_size": config["hparams"]["batch_size"],
        "num_workers": config["exp"]["n_workers"],
        "pin_memory": config["exp"]["pin_memory"],
        "shuffle": bool(train),
    }
    return DataLoader(dataset, **loader_kwargs)

# 3. Download Google Speech Commands V2 dataset

In [11]:
# Shell script text: creates the directory given as the first argument, then
# downloads the Speech Commands v0.02 archive and extracts it there by piping
# wget's output straight into tar.
# NOTE(review): the literal starts with a newline, so `#!/bin/bash` ends up on
# the second line of the written file rather than the first.
download_gspeech_v2_str = """
#!/bin/bash

data_dir=$1
curr_dir=$PWD

mkdir -p $data_dir

cd $data_dir
wget http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz -O - | tar -xz

cd $curr_dir"""

# Write the script to disk so it can be run from a shell.
with open('download_gspeech_v2.sh', 'w') as f:
    f.write(download_gspeech_v2_str)

!sh ./download_gspeech_v2.sh ./data/

--2025-08-20 12:37:44--  http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 142.250.145.207, 74.125.128.207, 74.125.143.207, ...
Connecting to download.tensorflow.org (download.tensorflow.org)|142.250.145.207|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2428923189 (2.3G) [application/gzip]
Saving to: ‘STDOUT’


2025-08-20 12:38:44 (39.1 MB/s) - written to stdout [2428923189/2428923189]



In [12]:
!ls data

_background_noise_  five     left     README.md		tree
backward	    follow   LICENSE  right		two
bed		    forward  marvin   seven		up
bird		    four     nine     sheila		validation_list.txt
cat		    go	     no       six		visual
dog		    happy    off      stop		wow
down		    house    on       testing_list.txt	yes
eight		    learn    one      three		zero


The dataset ships with a `validation_list.txt` and a `testing_list.txt` that define the validation and test splits. Below we run the `generate_data_lists` function (the notebook equivalent of the `make_data_list.py` script) to also generate a `training_list.txt`, as well as a `label_map.json` that maps numeric indices to class labels.

### Validate dataset directory

In [14]:
import os

def walk_through_dir(dir_path):
    """Print, for each directory under ``dir_path``, how many subdirectories and files it holds."""
    for current_dir, subdirs, files in os.walk(dir_path):
        n_dirs, n_files = len(subdirs), len(files)
        print(f"There are {n_dirs} directories and {n_files} files in '{current_dir}'.")

data_path = './data/'
walk_through_dir(data_path)

There are 36 directories and 5 files in './data/'.
There are 0 directories and 1579 files in './data/follow'.
There are 0 directories and 2031 files in './data/cat'.
There are 0 directories and 1557 files in './data/forward'.
There are 0 directories and 3880 files in './data/go'.
There are 0 directories and 3860 files in './data/six'.
There are 0 directories and 2123 files in './data/wow'.
There are 0 directories and 4044 files in './data/yes'.
There are 0 directories and 3934 files in './data/nine'.
There are 0 directories and 3801 files in './data/left'.
There are 0 directories and 2014 files in './data/bed'.
There are 0 directories and 2022 files in './data/sheila'.
There are 0 directories and 3890 files in './data/one'.
There are 0 directories and 2113 files in './data/house'.
There are 0 directories and 1575 files in './data/learn'.
There are 0 directories and 3941 files in './data/no'.
There are 0 directories and 2064 files in './data/bird'.
There are 0 directories and 3778 files

In [19]:
testing_list_path = '/content/data/testing_list.txt'

def get_classes(content):
    """Collect the class names found in a newline-separated list of sample paths.

    The class of an entry is the directory that directly contains the file, so
    ``yes/a.wav``, ``./data/yes/a.wav`` and ``/content/data/yes/a.wav`` all map
    to ``yes`` regardless of how deep the path is. An entry without any
    directory component counts as its own class. Blank lines are ignored.

    Args:
        content (str): Text with one "/"-separated path per line.

    Returns:
        set: The distinct class names.
    """
    classes = set()
    for line in content.splitlines():
        entry = line.strip()
        if not entry:
            continue
        parts = entry.split('/')
        # parts[-1] is the file name; the class is the component before it.
        classes.add(parts[-2] if len(parts) >= 2 else parts[0])
    return classes


# Read the test split list and report which classes it covers and how many.
with open(testing_list_path, 'r') as file:
    content = file.read()
    #print(content)
    class_set = get_classes(content)
    print(class_set)
    print(len(class_set))

{'no', 'bird', 'left', 'five', 'house', 'forward', 'bed', 'one', 'dog', 'learn', 'two', 'down', 'three', 'eight', 'nine', 'right', 'zero', 'up', 'off', 'stop', 'on', 'marvin', 'seven', 'four', 'visual', 'follow', 'tree', 'six', 'backward', 'wow', 'go', 'yes', 'sheila', 'cat', 'happy'}
35


## 3.1 Get loader from .pkl file
1. Create three datasets: train, val, and test.
2. Merge the val and test datasets.
3. Save the datasets as `.pkl` files.
4. Load the `.pkl` files and wrap them in data loaders.

In [20]:
import json
import os
#from utils.dataset import get_train_val_test_split
from torch.utils import data
from torch.utils.data import Dataset, DataLoader
import numpy as np
import functools
import librosa
import glob
import os
from tqdm import tqdm
import multiprocessing as mp
import json

#from utils.augment import time_shift, resample, spec_augment
from audiomentations import AddBackgroundNoise


def get_train_val_test_split(root: str, val_file: str, test_file: str):
    """Creates train, val, and test split according to provided val and test files.

    Entries in the list files may be relative to ``root`` (as shipped with the
    dataset, e.g. ``yes/a.wav``) or may already include ``root`` (as written
    back by ``generate_data_lists``). Both forms resolve to the same path, so
    re-running on regenerated lists gives the same split. Blank lines in the
    list files are ignored.

    Args:
        root (str): Path to base directory of the dataset.
        val_file (str): Path to file containing list of validation data files.
        test_file (str): Path to file containing list of test data files.

    Returns:
        train_list (list): Sorted list of paths to training data items.
        val_list (list): Sorted list of paths to validation data items.
        test_list (list): Sorted list of paths to test data items.
        label_map (dict): Mapping of indices to label classes.

    Raises:
        AssertionError: If a file appears in both the val and the test list.
    """

    ####################
    # Labels
    ####################

    # Every sub-directory is a class, except those starting with "_".
    label_list = [label for label in sorted(os.listdir(root)) if os.path.isdir(os.path.join(root, label)) and label[0] != "_"]
    label_map = {idx: label for idx, label in enumerate(label_list)}

    ###################
    # Split
    ###################

    all_files_set = set()
    for label in label_list:
        all_files_set.update(glob.glob(os.path.join(root, label, "*.wav")))

    # Normalized path -> the exact string produced by glob. Different spellings
    # of one file ("./data/x" vs "data/x") are matched through this table.
    known_files = {os.path.normpath(path): path for path in all_files_set}

    def _resolve(entry):
        """Map a list-file entry onto the spelling used in ``all_files_set``."""
        joined = os.path.join(root, entry)
        # Try the root-relative reading first, then the entry as written,
        # which covers lists that already carry the root prefix.
        for candidate in (joined, entry):
            match = known_files.get(os.path.normpath(candidate))
            if match is not None:
                return match
        # Not on disk: keep the root-joined path.
        return joined

    def _read_split(list_file):
        """Read one split file into a set of resolved paths."""
        with open(list_file, "r") as f:
            return {_resolve(line.strip()) for line in f if line.strip()}

    val_files_set = _read_split(val_file)
    test_files_set = _read_split(test_file)

    assert len(val_files_set.intersection(test_files_set)) == 0, "Sanity check: No files should be common between val and test."

    train_files_set = all_files_set - val_files_set - test_files_set

    # Sorted output makes the generated lists reproducible across runs.
    train_list, val_list, test_list = sorted(train_files_set), sorted(val_files_set), sorted(test_files_set)

    print(f"Number of training samples: {len(train_list)}")
    print(f"Number of validation samples: {len(val_list)}")
    print(f"Number of test samples: {len(test_list)}")

    return train_list, val_list, test_list, label_map

def generate_data_lists(val_list_file, test_list_file, data_root, out_dir):
    """Compute the train/val/test split and write it, plus the label map, to ``out_dir``.

    Writes ``training_list.txt``, ``validation_list.txt`` and
    ``testing_list.txt`` (one path per line, no trailing newline) and
    ``label_map.json``.

    NOTE: if ``out_dir`` is the directory holding ``val_list_file`` /
    ``test_list_file`` and they use these same file names, the input lists are
    overwritten.

    Args:
        val_list_file: Path to the validation list to read.
        test_list_file: Path to the test list to read.
        data_root: Dataset root directory.
        out_dir: Existing directory to write the four files into.
    """
    train_list, val_list, test_list, label_map = get_train_val_test_split(data_root, val_list_file, test_list_file)

    list_files = (
        ("training_list.txt", train_list),
        ("validation_list.txt", val_list),
        ("testing_list.txt", test_list),
    )
    for file_name, paths in list_files:
        with open(os.path.join(out_dir, file_name), "w+") as f:
            f.write("\n".join(paths))

    with open(os.path.join(out_dir, "label_map.json"), "w+") as f:
        json.dump(label_map, f)

    print("Saved data lists and label map.")


In [21]:
generate_data_lists('./data/validation_list.txt', './data/testing_list.txt', './data/', './data/')

Number of training samples: 84843
Number of validation samples: 9981
Number of test samples: 11005
Saved data lists and label map.


In [29]:
def create_pkl_data(conf_file):
    """Build the train and merged val+test datasets and pickle them.

    Reads the label map and the three file lists named in the config, wraps each
    list in a ``GoogleSpeechDataset``, merges val and test, and saves the train
    dataset to ``config["pkl_train"]`` and the merged one to ``config["pkl_test"]``.

    NOTE(review): the dataset objects themselves are pickled, so the files hold
    the path lists and settings. Features are still computed whenever an item
    is accessed after loading.

    Args:
        conf_file: Path to the YAML config file.
    """
    config = get_config(conf_file)

    def read_lines(path):
        """Return the lines of ``path``, ignoring trailing whitespace."""
        with open(path, "r") as f:
            return f.read().rstrip().split("\n")

    with open(config["label_map"], "r") as f:
        label_map = json.load(f)

    train_list = read_lines(config["train_list_file"])
    test_list = read_lines(config["test_list_file"])
    val_list = read_lines(config["val_list_file"])

    audio_settings = config["hparams"]["audio"]

    def build_dataset(paths):
        """Wrap one path list with the shared audio settings and label map."""
        return GoogleSpeechDataset(data_list=paths,
                                   audio_settings=audio_settings,
                                   label_map=label_map)

    train_first_dataset = build_dataset(train_list)
    test_first_dataset = build_dataset(test_list)
    val_first_dataset = build_dataset(val_list)

    # Validation items come first in the merged evaluation set.
    test_merged_dataset = merge_val_test(val_first_dataset, test_first_dataset)

    save_dataset(train_first_dataset, config["pkl_train"])
    save_dataset(test_merged_dataset, config["pkl_test"])

In [32]:
create_pkl_data(conf_file)

Dataset saved to ./google_speech_commands_v2/train_dataset.pkl.
Dataset saved to ./google_speech_commands_v2/test_dataset.pkl.


# Set up for training functions

- `misc.py` Miscellaneous helper functions.

In [33]:
"""Miscellaneous helper functions."""

import torch
from torch import nn, optim
import numpy as np
import random
import os
import wandb


def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Also asks cuDNN for deterministic kernels and turns off its auto-tuner,
    which could otherwise select different algorithms from run to run.

    Args:
        seed (int): Seed value applied to every generator.
    """
    random.seed(seed)
    # Recorded for child processes; the hash seed of the running interpreter
    # is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f'Set seed {seed}')


def count_params(model: nn.Module) -> int:
    """Return the total number of scalar elements across all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.data.numel()
    return total


def calc_step(epoch: int, n_batches: int, batch_index: int) -> int:
    """Return the 1-based global step for a 1-based ``epoch`` and 0-based ``batch_index``."""
    steps_in_finished_epochs = (epoch - 1) * n_batches
    step_in_current_epoch = 1 + batch_index
    return steps_in_finished_epochs + step_in_current_epoch


def log(log_dict: dict, step: int, config: dict) -> None:
    """Format a metrics dict as one line and send it to the enabled sinks.

    The line reads ``Step: <step> | key: value | ...``. It is appended to
    ``<exp.save_dir>/training_log.txt`` when ``exp.log_to_file`` is set and
    printed when ``exp.log_to_stdout`` is set.

    Args:
        log_dict (dict): Log metric dict.
        step (int): Current step.
        config (dict): Config dict.
    """
    metrics = " | ".join(f"{name}: {value}" for name, value in log_dict.items())
    log_message = f"Step: {step} | " + metrics

    exp_config = config["exp"]

    # Disk sink.
    if exp_config["log_to_file"]:
        log_path = os.path.join(exp_config["save_dir"], "training_log.txt")
        with open(log_path, "a+") as f:
            f.write(log_message + "\n")

    # Console sink.
    if exp_config["log_to_stdout"]:
        print(log_message)


def get_model(model_config: dict) -> nn.Module:
    """Create the model described by ``model_config``.

    Args:
        model_config (dict): Model settings. If it has a non-None ``"name"``
            the matching preset is built with ``kwt_from_name`` and all other
            entries are ignored. Otherwise the whole dict is passed to ``KWT``
            as keyword arguments.

    Returns:
        nn.Module: The constructed model.
    """
    # .get() so that a config without a "name" key means "no preset"
    # instead of raising KeyError.
    model_name = model_config.get("name")
    if model_name is not None:
        return kwt_from_name(model_name)

    return KWT(**model_config)


def save_model(epoch: int, val_acc: float, save_path: str, net: nn.Module, optimizer : optim.Optimizer = None, log_file : str = None) -> None:
    """Write a checkpoint to ``save_path`` and report it.

    The checkpoint is a dict with the keys ``epoch``, ``val_acc``,
    ``model_state_dict`` and ``optimizer_state_dict`` (None when no optimizer
    is given).

    Args:
        epoch (int): Epoch number stored in the checkpoint.
        val_acc (float): Accuracy stored in the checkpoint and shown in the message.
        save_path (str): Destination file.
        net (nn.Module): Model whose state dict is saved.
        optimizer (optim.Optimizer, optional): Optimizer whose state dict is saved.
        log_file (str, optional): If given, the message is also appended to this file.
    """
    ckpt_dict = {"epoch": epoch, "val_acc": val_acc}
    ckpt_dict["model_state_dict"] = net.state_dict()
    if optimizer is None:
        ckpt_dict["optimizer_state_dict"] = None
    else:
        ckpt_dict["optimizer_state_dict"] = optimizer.state_dict()

    torch.save(ckpt_dict, save_path)

    log_message = f"Saved {save_path} with accuracy {val_acc}."
    print(log_message)

    if log_file is None:
        return

    with open(log_file, "a+") as f:
        f.write(log_message + "\n")


- LR Scheduler - Cosine
- Optimizer

In [34]:
from torch import optim, nn
from torch.optim import lr_scheduler


class WarmUpLR(lr_scheduler._LRScheduler):
    """Linear warm-up: the learning rate grows from 0 in proportion to the step count."""

    def __init__(self, optimizer: optim.Optimizer, total_iters: int, last_epoch: int = -1):
        """
        Args:
            optimizer (optim.Optimizer): Wrapped optimizer.
            total_iters (int): Number of scheduler steps after which the base
                learning rate is reached.
            last_epoch (int): Index of the last step; -1 starts from scratch.
        """
        # Must be set before the base initializer runs, because it calls get_lr().
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * last_epoch / total_iters`` for every parameter group."""
        # The small epsilon avoids a division by zero when total_iters is 0.
        denominator = self.total_iters + 1e-8
        warmed_up = []
        for base_lr in self.base_lrs:
            warmed_up.append(base_lr * self.last_epoch / denominator)
        return warmed_up


def get_scheduler(optimizer: optim.Optimizer, scheduler_type: str, T_max: int) -> lr_scheduler._LRScheduler:
    """Create the learning-rate scheduler named by ``scheduler_type``.

    Args:
        optimizer (optim.Optimizer): Optimizer to schedule.
        scheduler_type (str): Only ``"cosine_annealing"`` is supported.
        T_max (int): ``T_max`` passed to ``CosineAnnealingLR``.

    Returns:
        lr_scheduler._LRScheduler: Cosine annealing scheduler with ``eta_min=1e-8``.

    Raises:
        ValueError: For any other ``scheduler_type``.
    """
    if scheduler_type != "cosine_annealing":
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    return lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=1e-8)

from torch import nn, optim


def get_optimizer(net: nn.Module, opt_config: dict) -> optim.Optimizer:
    """Create the optimizer described by ``opt_config`` for the parameters of ``net``.

    Args:
        net (nn.Module): Model to optimize.
        opt_config (dict): ``opt_type`` (only ``"adamw"`` is supported) and
            ``opt_kwargs``, which are forwarded to the optimizer constructor.

    Returns:
        optim.Optimizer: The constructed optimizer.

    Raises:
        ValueError: For any other ``opt_type``.
    """
    if opt_config["opt_type"] != "adamw":
        raise ValueError(f'Unsupported optimizer {opt_config["opt_type"]}')

    return optim.AdamW(net.parameters(), **opt_config["opt_kwargs"])


- LabelSmoothingLoss

In [35]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class gets probability ``1 - smoothing`` and the remaining
    ``smoothing`` is spread evenly over the other ``num_classes - 1`` classes.
    """

    def __init__(self, num_classes: int, smoothing : float = 0.1, dim : int = -1):
        """
        Args:
            num_classes (int): Number of classes (size of the class axis).
            smoothing (float): Probability mass moved away from the true class;
                must satisfy ``0 <= smoothing < 1``.
            dim (int): Axis of ``pred`` that indexes the classes.
        """
        super().__init__()

        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = num_classes
        self.dim = dim

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean smoothed cross-entropy.

        Args:
            pred (torch.Tensor): Unnormalized scores with the classes on ``self.dim``.
            target (torch.Tensor): Integer class indices, shaped like ``pred``
                without the class axis.

        Returns:
            torch.Tensor: Scalar loss.
        """
        assert 0 <= self.smoothing < 1
        pred = pred.log_softmax(dim=self.dim)

        with torch.no_grad():
            true_dist = torch.full_like(pred, self.smoothing / (self.cls - 1))
            # Scatter along the configured class axis (not a fixed axis 1), so
            # inputs with extra leading axes such as (batch, time, classes)
            # work with the default dim=-1.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

# 4. Training

- `trainer.py` - `train` and `evaluate` function



In [44]:
import torch
from torch import nn, optim
from typing import Callable, Tuple
from torch.utils.data import DataLoader
#from utils.misc import log, save_model
import os
import time
from tqdm import tqdm


def train_single_batch(net: nn.Module, data: torch.Tensor, targets: torch.Tensor, optimizer: optim.Optimizer, criterion: Callable, device: torch.device) -> Tuple[float, int]:
    """Run one optimization step on a single batch.

    Args:
        net (nn.Module): Model instance.
        data (torch.Tensor): Input batch.
        targets (torch.Tensor): Class indices for the batch.
        optimizer (optim.Optimizer): Optimizer instance.
        criterion (Callable): Loss function.
        device (torch.device): Device the batch is moved to.

    Returns:
        Tuple[float, int]: The batch loss and the number of correct
        predictions, both measured before the parameter update.
    """
    inputs = data.to(device)
    labels = targets.to(device)

    optimizer.zero_grad()
    logits = net(inputs)
    batch_loss = criterion(logits, labels)
    batch_loss.backward()
    optimizer.step()

    n_correct = logits.argmax(1).eq(labels).sum()
    return batch_loss.item(), n_correct.item()


@torch.no_grad()
def evaluate(net: nn.Module, criterion: Callable, dataloader: DataLoader, device: torch.device) -> Tuple[float, float]:
    """Performs inference.

    The model is put in eval mode for the pass and then returned to whichever
    mode (train or eval) it was in when the function was called.

    Args:
        net (nn.Module): Model instance.
        criterion (Callable): Loss function.
        dataloader (DataLoader): Test or validation loader.
        device (torch.device): Device.

    Returns:
        accuracy (float): Correct predictions divided by the dataset size
            (0.0 for an empty dataset).
        float: Mean of the per-batch losses (0.0 when there are no batches).
    """

    was_training = net.training
    net.eval()
    correct = 0
    running_loss = 0.0

    try:
        for data, targets in tqdm(dataloader):
            data, targets = data.to(device), targets.to(device)
            out = net(data)
            correct += out.argmax(1).eq(targets).sum().item()
            loss = criterion(out, targets)
            running_loss += loss.item()
    finally:
        # Restore the caller's mode even if a batch raised.
        net.train(was_training)

    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)

    # Guard the divisions so an empty loader does not raise ZeroDivisionError.
    accuracy = correct / n_samples if n_samples else 0.0
    avg_loss = running_loss / n_batches if n_batches else 0.0
    return accuracy, avg_loss


def train(net: nn.Module, optimizer: optim.Optimizer, criterion: Callable, trainloader: DataLoader, schedulers: dict, config: dict) -> None:
    """Trains model.

    Runs ``hparams.n_epochs`` epochs. After every epoch the epoch summary is
    logged and, when the training accuracy improved, ``best.pth`` is written
    to ``exp.save_dir``. ``last.pth`` is always written at the end.

    Args:
        net (nn.Module): Model instance.
        optimizer (optim.Optimizer): Optimizer instance.
        criterion (Callable): Loss function.
        trainloader (DataLoader): Training data loader.
        schedulers (dict): Dict containing schedulers under the keys
            ``"warmup"`` and ``"scheduler"`` (either may be None).
        config (dict): Config dict.
    """

    step = 0
    best_acc = 0.0
    device = config["hparams"]["device"]
    save_dir = config["exp"]["save_dir"]

    # Only hand a log path to save_model when file logging is enabled, so the
    # `log_to_file` switch also covers the checkpoint messages.
    log_file = os.path.join(save_dir, "training_log.txt") if config["exp"]["log_to_file"] else None

    # Defaults for the final checkpoint when n_epochs is 0 and the loop body
    # never runs (-1 means no epoch completed).
    epoch, train_acc = -1, 0.0

    ############################
    # start training
    ############################
    net.train()

    for epoch in range(config["hparams"]["n_epochs"]):
        t0 = time.time()
        running_loss = 0.0
        correct = 0

        for data, targets in trainloader:

            # Schedulers are stepped once per batch: warm-up during the first
            # n_warmup epochs, the main scheduler afterwards.
            if schedulers["warmup"] is not None and epoch < config["hparams"]["scheduler"]["n_warmup"]:
                schedulers["warmup"].step()

            elif schedulers["scheduler"] is not None:
                schedulers["scheduler"].step()

            ####################
            # optimization step
            ####################

            loss, corr = train_single_batch(net, data, targets, optimizer, criterion, device)
            running_loss += loss
            correct += corr

            if not step % config["exp"]["log_freq"]:
                log_dict = {"epoch": epoch, "loss": loss, "lr": optimizer.param_groups[0]["lr"]}
                log(log_dict, step, config)

            step += 1

        #######################
        # epoch complete
        #######################
        train_acc = correct / len(trainloader.dataset)
        log_dict = {
            "epoch": epoch,
            "time_per_epoch": time.time() - t0,
            "train_acc": train_acc,
            "avg_loss_per_ep": running_loss / len(trainloader)
        }
        log(log_dict, step, config)

        # save best model based on train acc
        if train_acc > best_acc:
            best_acc = train_acc
            save_path = os.path.join(save_dir, "best.pth")
            save_model(epoch, train_acc, save_path, net, optimizer, log_file)

    ###########################
    # training complete
    ###########################

    # save final ckpt
    save_path = os.path.join(save_dir, "last.pth")
    save_model(epoch, train_acc, save_path, net, optimizer, log_file)


`train.py`

In [45]:
import json
import os
import yaml
import torch
import wandb
import random
import time

from argparse import ArgumentParser
# from config_parser import get_config
# from utils.loss import LabelSmoothingLoss
# from utils.opt import get_optimizer
# from utils.scheduler import WarmUpLR, get_scheduler
# from utils.trainer import train, evaluate
# from utils.dataset import get_train_val_test_split, get_loader
# from utils.misc import seed_everything, count_params, get_model, calc_step, log

def training_pipeline(config):
    """Initiates and executes all the steps involved with model training.

    Steps, in order:
        1. Load the label map and create the run directory
           (``exp_dir/exp_name``); the path is stored back into
           ``config["exp"]["save_dir"]``.
        2. Dump the resolved config to stdout and to ``settings.txt``.
        3. Build the training dataset/loader, model, loss, optimizer and
           LR schedulers.
        4. Call ``train``.
        5. Evaluate on the test set twice: once with the model as left in
           memory by training, once after loading ``best.pth``.

    Args:
        config (dict) - Dict containing various settings for the training run.
            It is mutated in place (``config["exp"]["save_dir"]`` is added).
    """
    start_func_time = time.time()

    hparams = config["hparams"]
    device = hparams["device"]

    # Get label map
    with open(config["label_map"], "r") as f:
        label_map = json.load(f)

    save_dir = os.path.join(config["exp"]["exp_dir"], config["exp"]["exp_name"])
    config["exp"]["save_dir"] = save_dir
    os.makedirs(save_dir, exist_ok=True)

    ######################################
    # save hyperparameters for current run
    ######################################

    config_str = yaml.dump(config)
    print("Using settings:\n", config_str)

    with open(os.path.join(save_dir, "settings.txt"), "w+") as f:
        f.write(config_str)

    #####################################
    # initialize training items
    #####################################

    # data
    train_dataset = PrecomputedSpeechDataset(
        pkl_path=config["pkl_train"],
        aug_settings=hparams["augment"],
        label_map=label_map,
        train=True  # for data_augment
    )

    trainloader = get_loader(train_dataset, config, train=True)  # for shuffle
    steps_per_epoch = len(trainloader)

    # model
    model = get_model(hparams["model"])
    model = model.to(device)
    print(f"Created model with {count_params(model)} parameters.")

    # loss
    if hparams["l_smooth"]:
        criterion = LabelSmoothingLoss(
            num_classes=hparams["model"]["num_classes"],
            smoothing=hparams["l_smooth"],
        )
    else:
        # Use the fully qualified name: this cell imports `torch` but never
        # binds a bare `nn`, so `nn.CrossEntropyLoss()` would rely on a name
        # leaked from another cell.
        criterion = torch.nn.CrossEntropyLoss()

    # optimizer
    optimizer = get_optimizer(model, hparams["optimizer"])

    # lr scheduler
    schedulers = {
        "warmup": None,
        "scheduler": None
    }

    sched_cfg = hparams["scheduler"]

    if sched_cfg["n_warmup"]:
        schedulers["warmup"] = WarmUpLR(optimizer, total_iters=steps_per_epoch * sched_cfg["n_warmup"])

    if sched_cfg["scheduler_type"] is not None:
        # Scheduler length is derived from scheduler.max_epochs (not n_epochs),
        # minus the warmup epochs, and is at least one epoch long.
        total_iters = steps_per_epoch * max(1, (sched_cfg["max_epochs"] - sched_cfg["n_warmup"]))
        schedulers["scheduler"] = get_scheduler(optimizer, sched_cfg["scheduler_type"], total_iters)

    #####################################
    # Training Run
    #####################################
    end_func_time = time.time()
    print("Initiating training.")
    print(f"Takes {end_func_time - start_func_time} seconds to start training.")
    # No validation loader is built in this pipeline; train() is called without one.
    train(model, optimizer, criterion, trainloader, schedulers, config)

    #####################################
    # Final Test
    #####################################

    test_dataset = PrecomputedSpeechDataset(
        pkl_path=config["pkl_test"],
        aug_settings=hparams["augment"],
        label_map=label_map,
        train=False
    )
    testloader = get_loader(test_dataset, config, train=False)
    final_step = calc_step(hparams["n_epochs"] + 1, steps_per_epoch, steps_per_epoch - 1)

    # evaluating the final state (last.pth)
    print("Evaluating last ckpt")
    test_acc, test_loss = evaluate(model, criterion, testloader, device)
    log_dict = {
        "test_loss_last": test_loss,
        "test_acc_last": test_acc
    }
    log(log_dict, final_step, config)

    # evaluating the best state (best.pth)
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the one used for this run (e.g. CUDA -> CPU).
    ckpt = torch.load(os.path.join(save_dir, "best.pth"), map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    print("Best ckpt loaded - by train_acc.")

    test_acc, test_loss = evaluate(model, criterion, testloader, device)
    log_dict = {
        "test_loss_best": test_loss,
        "test_acc_best": test_acc
    }
    log(log_dict, final_step, config)

def main(conf_file):
    """Load a config, seed the run, and launch the training pipeline.

    Args:
        conf_file: Passed unchanged to ``get_config`` to obtain the config dict.
    """
    cfg = get_config(conf_file)
    run_seed = cfg["hparams"]["seed"]
    seed_everything(run_seed)
    training_pipeline(cfg)


Set up a sample config for KWT-2 with a 20-epoch demo training run
- **cache: here we set it to 0 (no cache) and enable data augmentation**.
- The paper trains for 140 epochs / 23000 steps.

In [46]:
# Run the whole pipeline: config loading, seeding, training and final test evaluation.
# NOTE(review): `conf_file` is not defined in this cell; presumably it is set in an earlier cell — confirm.
main(conf_file)

Set seed 0
Using settings:
 data_root: ./data/
exp:
  device: &id001 !!python/object/apply:torch.device
  - cuda
  exp_dir: ./runs
  exp_name: exp-0.0.1
  log_freq: 20
  log_to_file: false
  log_to_stdout: true
  n_workers: 1
  pin_memory: true
  save_dir: ./runs/exp-0.0.1
  val_freq: 1
hparams:
  audio:
    center: false
    hop_length: 160
    n_fft: 480
    n_mels: 40
    sr: 16000
    win_length: 480
  augment:
    spec_aug:
      freq_mask_width: 7
      n_freq_masks: 2
      n_time_masks: 2
      time_mask_width: 25
  batch_size: 512
  device: *id001
  l_smooth: 0.1
  model:
    depth: 12
    dim: 64
    dropout: 0.0
    emb_dropout: 0.1
    heads: 1
    input_res:
    - 40
    - 98
    mlp_dim: 256
    name: kwt-2
    num_classes: 35
    patch_res:
    - 40
    - 1
    pre_norm: false
  n_epochs: 20
  optimizer:
    opt_kwargs:
      lr: 0.001
      weight_decay: 0.1
    opt_type: adamw
  scheduler:
    max_epochs: 140
    n_warmup: 10
    scheduler_type: cosine_annealing
  seed



Step: 0 | epoch: 0 | loss: 3.7855935096740723 | lr: 6.024096385505879e-07
Step: 20 | epoch: 0 | loss: 3.657698392868042 | lr: 1.2650602409562347e-05
Step: 40 | epoch: 0 | loss: 3.534853935241699 | lr: 2.4698795180574107e-05
Step: 60 | epoch: 0 | loss: 3.523815631866455 | lr: 3.6746987951585866e-05
Step: 80 | epoch: 0 | loss: 3.49392032623291 | lr: 4.879518072259762e-05
Step: 100 | epoch: 0 | loss: 3.516051769256592 | lr: 6.0843373493609386e-05
Step: 120 | epoch: 0 | loss: 3.49010968208313 | lr: 7.289156626462113e-05
Step: 140 | epoch: 0 | loss: 3.4255590438842773 | lr: 8.493975903563291e-05
Step: 160 | epoch: 0 | loss: 3.3385000228881836 | lr: 9.698795180664466e-05
Step: 166 | epoch: 0 | time_per_epoch: 350.80592131614685 | train_acc: 0.05147154155322183 | avg_loss_per_ep: 3.513685789452978
Saved ./runs/exp-0.0.1/best.pth with accuracy 0.05147154155322183.
Step: 180 | epoch: 1 | loss: 3.21820068359375 | lr: 0.00010903614457765641
Step: 200 | epoch: 1 | loss: 3.2096359729766846 | lr: 0.

100%|██████████| 41/41 [01:09<00:00,  1.70s/it]


Step: 3486 | test_loss_last: 1.4416520944455775 | test_acc_last: 0.7336319451062613
Best ckpt loaded.


100%|██████████| 41/41 [01:11<00:00,  1.74s/it]

Step: 3486 | test_loss_best: 1.3598168681307536 | test_acc_best: 0.7709901839321452



