* [pytorch tutorials](https://pytorch.org/tutorials/)
* [torchaudio](https://pytorch.org/audio/stable/index.html)

In [None]:
# %%bash
# pip install torch==1.12.1
# pip install torchaudio==0.12.1
# pip install omegaconf==2.2.3
# pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git

In [None]:
import os
import random
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Optional, Callable, Dict, List, Any, Tuple

import omegaconf
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchaudio
import thop
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

import IPython.display as ipd

In [None]:
# One seed shared by every random number generator used in this notebook.
SEED = 777

os.environ['PYTHONHASHSEED'] = str(SEED)

# Seed Python, NumPy and PyTorch in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
@dataclass
class Model:
    """Architecture hyper-parameters consumed by ``Conv1dNet``."""
    # kernels, strides and channels are zipped together: one entry per conv stage.
    kernels: Tuple[int]
    strides: Tuple[int]
    channels: Tuple[int]
    # Width of the hidden linear layer in the classifier head.
    hidden_size: int
    # Name of an activation class looked up on ``torch.nn`` (e.g. 'ReLU').
    activation: str

@dataclass
class Optim:
    """Optimisation settings for a training run."""
    # Learning rate handed to the optimizer.
    lr: float
    # Number of epochs the training loop runs.
    n_epochs: int
    # Batch size used by both the train and the validation data loader.
    batch_size: int
        
@dataclass
class Features:
    """Keyword arguments forwarded to ``torchaudio.transforms.MelSpectrogram``."""
    n_fft: int
    win_length: int
    # Also used to size the dummy input when profiling the model.
    hop_length: int
    # Number of mel bins; this is the input channel count of the network.
    n_mels: int

@dataclass
class Augmentations:
    """Parameters of the train-time augmentations."""
    # Passed to ``torchaudio.transforms.FrequencyMasking``.
    freq_mask_param: int
    # Passed to ``torchaudio.transforms.TimeMasking``.
    time_mask_param: int
    # Bounds passed to ``RandomGain``.
    min_gain: float
    max_gain: float

# Dataset

In [None]:
class SpotterDataset(torch.utils.data.Dataset):
    """Keyword-spotting dataset described by a CSV manifest.

    The manifest is read with pandas and must expose a ``path`` column (audio
    file location relative to the manifest's directory) and a ``label`` column
    (keyword name). Each item is
    ``(first_channel_waveform, transform(waveform), label_index)``.
    """

    def __init__(
            self, manifest_path: Path, idx_to_keyword: List[str],
            transform, ids: Optional[List[int]] = None
        ):
        super().__init__()
        self.transform = transform

        rows = pd.read_csv(manifest_path)
        if ids is not None:
            # Restrict the dataset to the requested manifest rows.
            rows = rows.loc[ids]

        root = manifest_path.parent
        self.wav_files = [root / relative for relative in rows.path]

        # Map every keyword to its position in idx_to_keyword.
        index_of = {}
        for position, keyword in enumerate(idx_to_keyword):
            index_of[keyword] = position
        self.labels = [index_of[name] for name in rows.label]

    def __len__(self):
        return len(self.wav_files)

    def __getitem__(self, idx):
        waveform, _sample_rate = torchaudio.load(self.wav_files[idx])
        features = self.transform(waveform)
        return waveform[0], features, self.labels[idx]

In [None]:
def collator(data):
    """Collate dataset items into a batch of features and labels.

    Args:
        data: iterable of ``(wav, features, label)`` items; the waveform is
            ignored.

    Returns:
        ``(specs, labels)``: the per-item features concatenated along dim 0,
        and an int64 tensor of labels.
    """
    specs = []
    labels = []

    for _, features, label in data:
        specs.append(features)
        labels.append(label)

    specs = torch.cat(specs)
    # Build the integer tensor directly: torch.Tensor(...) would go through
    # float32 first and round large integer labels.
    labels = torch.tensor(labels, dtype=torch.long)
    return specs, labels

In [None]:
class SpecScaler(torch.nn.Module):
    """Log-compress a spectrogram after clamping it to ``[1e-9, 1e9]``."""

    def forward(self, x):
        # Out-of-place clamp so the tensor passed in by the caller is not modified.
        return torch.log(x.clamp(1e-9, 1e9))

class RandomGain(torch.nn.Module):
    """Scale a waveform by a random amplitude gain.

    The gain is drawn uniformly between ``min_gain`` and ``max_gain / peak``,
    where ``peak`` is the largest absolute sample of the input, and is applied
    with ``torchaudio.transforms.Vol`` in amplitude mode.
    """

    def __init__(self, min_gain: float=0.5, max_gain: float=1.0):
        super().__init__()
        self.min_gain = min_gain
        self.max_gain = max_gain

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        # Empty or all-zero audio has no peak to normalise by; dividing by it
        # would give an infinite gain and a NaN waveform, so return it as is.
        if audio.numel() == 0:
            return audio
        peak = audio.abs().max().item()
        if peak == 0.0:
            return audio

        gain = random.uniform(self.min_gain, self.max_gain / peak)
        return torchaudio.transforms.Vol(gain, gain_type="amplitude")(audio)

In [None]:
def prepare_dataloaders(
        conf: omegaconf.DictConfig
    ) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Build the train and validation data loaders for an experiment.

    The rows of ``conf.train_manifest`` are shuffled with a generator seeded by
    ``SEED``; the first ``val_fraction`` of them form the validation split and
    the rest the training split. Only the training split is augmented.

    Args:
        conf: experiment config providing ``train_manifest``, ``sample_rate``,
            ``val_fraction``, ``idx_to_keyword``, ``features``, ``augs`` and
            ``optim.batch_size``.

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """
    manifest_path = Path(conf.train_manifest)

    def mel_spectrogram():
        return torchaudio.transforms.MelSpectrogram(sample_rate=conf.sample_rate, **conf.features)

    train_transform = torch.nn.Sequential(
        RandomGain(min_gain=conf.augs.min_gain, max_gain=conf.augs.max_gain),
        mel_spectrogram(),
        torchaudio.transforms.FrequencyMasking(freq_mask_param=conf.augs.freq_mask_param),
        torchaudio.transforms.TimeMasking(time_mask_param=conf.augs.time_mask_param),
        SpecScaler()
    )

    val_transform = torch.nn.Sequential(
        mel_spectrogram(),
        SpecScaler()
    )

    # Only the row count is needed for the split, so read the manifest directly
    # instead of building a full throwaway dataset.
    n_items = len(pd.read_csv(manifest_path))
    val_count = int(n_items * conf.val_fraction)

    # Plain Python ints, so pandas indexing does not depend on tensor handling.
    ids = torch.randperm(n_items, generator=torch.Generator().manual_seed(SEED)).tolist()
    val_ids = ids[:val_count]
    train_ids = ids[val_count:]

    train_dataset = SpotterDataset(
        manifest_path=manifest_path,
        idx_to_keyword=conf.idx_to_keyword,
        transform=train_transform,
        ids=train_ids
    )

    val_dataset = SpotterDataset(
        manifest_path=manifest_path,
        idx_to_keyword=conf.idx_to_keyword,
        transform=val_transform,
        ids=val_ids
    )

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=conf.optim.batch_size,
        shuffle=True,
        collate_fn=collator,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=conf.optim.batch_size,
        shuffle=False,
        collate_fn=collator,
    )

    return train_dataloader, val_dataloader

# Train Loop

In [None]:
def train_one_epoch(model, criterion, optimizer, loader, device, epoch_index, tb_writer, log_interval=100):
    """Train ``model`` for one pass over ``loader`` and log windowed metrics.

    Loss and accuracy are averaged over windows of ``log_interval`` batches and
    written to ``tb_writer`` as 'Loss/train' and 'Accuracy/train'. A trailing
    window shorter than ``log_interval`` is logged as well, so an epoch with
    fewer batches than ``log_interval`` still reports real values.

    Returns:
        ``(loss, accuracy)`` of the last logged window as Python floats, or
        ``(0., 0.)`` when ``loader`` yields no batches.
    """
    model.train()

    window_loss = 0.
    window_correct, window_seen, window_batches = 0, 0, 0
    last_loss, last_acc = 0., 0.
    last_index = -1

    def log_window(batch_index):
        # Averages cover exactly the batches accumulated since the last flush.
        mean_loss = window_loss / window_batches
        accuracy = window_correct / window_seen
        tb_x = epoch_index * len(loader) + batch_index + 1
        tb_writer.add_scalar('Loss/train', mean_loss, tb_x)
        tb_writer.add_scalar('Accuracy/train', accuracy, tb_x)
        return mean_loss, accuracy

    for i, (inputs, labels) in enumerate(loader):

        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        logits = model(inputs)
        preds = logits.argmax(1)

        loss = criterion(logits, labels)
        loss.backward()

        optimizer.step()

        window_loss += loss.item()
        # .item() keeps the counters as Python ints, so accuracy is a float.
        window_correct += (preds == labels).sum().item()
        window_seen += torch.numel(preds)
        window_batches += 1
        last_index = i

        if window_batches == log_interval:
            last_loss, last_acc = log_window(i)
            window_loss = 0.
            window_correct, window_seen, window_batches = 0, 0, 0

    # Flush the trailing partial window instead of silently dropping it.
    if window_batches:
        last_loss, last_acc = log_window(last_index)

    return last_loss, last_acc


@torch.no_grad()
def validation(model, criterion, loader, device, epoch_index, tb_writer):
    """Evaluate ``model`` on ``loader`` and log the results.

    The loss is the mean of the per-batch criterion values; the accuracy is the
    share of correct predictions over all samples. Both are written to
    ``tb_writer`` as 'Loss/val' and 'Accuracy/val' at step ``epoch_index + 1``.

    Returns:
        ``(loss, accuracy)`` as Python floats.

    Raises:
        ZeroDivisionError: if ``loader`` is empty.
    """
    model.eval()

    running_loss = 0.
    running_true_preds, running_preds = 0, 0

    for inputs, labels in loader:

        inputs, labels = inputs.to(device), labels.to(device)

        logits = model(inputs)
        preds = logits.argmax(1)

        running_loss += criterion(logits, labels).item()
        # .item() keeps the counter a Python int, so the accuracy is a float
        # like the loss rather than a 0-d tensor.
        running_true_preds += (preds == labels).sum().item()
        running_preds += torch.numel(preds)

    loss = running_loss / len(loader)
    acc = running_true_preds / running_preds

    tb_x = epoch_index + 1
    tb_writer.add_scalar('Loss/val', loss, tb_x)
    tb_writer.add_scalar('Accuracy/val', acc, tb_x)

    return loss, acc

In [None]:
def init_logger_and_dump_params(conf, model, description=''):
    """Create a run directory, save the config and log the model's complexity.

    The run directory is ``runs/<timestamp>[_<description>]``. It receives a
    ``ckpts`` sub-directory, a ``conf.yaml`` dump of ``conf`` and the
    TensorBoard event files. MACs and parameter count reported by
    ``thop.profile`` are logged as scalars at step 0.

    Returns:
        ``(ckpt_dir, tb_writer)``.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    suffix = '_' + description if description else ''
    run_dir = Path(f"runs/{stamp}{suffix}")

    checkpoints = run_dir / "ckpts"
    checkpoints.mkdir(parents=True)

    with open(run_dir / 'conf.yaml', 'w') as conf_file:
        omegaconf.OmegaConf.save(config=conf, f=conf_file)

    writer = SummaryWriter(run_dir)

    # Random dummy input of shape (1, n_mels, frames); the frame count is
    # presumably that of a one-second clip - confirm against the data.
    n_frames = conf.sample_rate // conf.features.hop_length + 1
    dummy = torch.randn(1, conf.features.n_mels, n_frames).to(conf.device)
    macs, params = thop.profile(model, inputs=(dummy,))

    writer.add_scalar('MACs', macs, 0)
    writer.add_scalar('Params', params, 0)
    return checkpoints, writer

In [None]:
def train_loop(conf, model, criterion, optimizer, train_dataloader, val_dataloader, description=''):
    """Train ``model`` for ``conf.optim.n_epochs`` epochs with validation.

    After every epoch the model is validated; whenever the validation accuracy
    improves on the best value so far, the model's ``state_dict`` is saved into
    the run's checkpoint directory. The TensorBoard writer is closed when the
    loop ends, including when it ends with an exception.
    """
    ckpt_dir, tb_writer = init_logger_and_dump_params(conf, model, description=description)

    best_val_acc = -1.

    try:
        for epoch in range(conf.optim.n_epochs):

            train_one_epoch(
                model, criterion, optimizer, train_dataloader,
                conf.device, epoch, tb_writer
            )

            val_loss, val_acc = validation(
                model, criterion, val_dataloader,
                conf.device, epoch, tb_writer
            )

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(
                    model.state_dict(),
                    ckpt_dir / f"model_epoch_{epoch + 1}_val_acc_{val_acc:.3f}.ckpt"
                )
    finally:
        # Flush pending events and release the event file.
        tb_writer.close()

# Model

In [None]:
class Conv1dNet(torch.nn.Module):
    """1-D convolutional classifier.

    Each stage in ``conf`` contributes: a grouped ``Conv1d`` (one group per
    input channel) -> activation -> 1x1 ``Conv1d`` -> ``BatchNorm1d`` ->
    activation -> ``MaxPool1d(stride)``. The stages are followed by global
    average pooling and a two-layer linear head.

    Args:
        in_features: number of input channels.
        n_classes: number of output logits.
        conf: object with ``kernels``, ``strides``, ``channels`` (one entry per
            stage), ``hidden_size`` and ``activation`` (name of a ``torch.nn``
            class).
    """

    def __init__(self, in_features, n_classes, conf: "omegaconf.DictConfig"):

        super().__init__()

        # A single activation instance is reused at every position.
        activation = getattr(torch.nn, conf.activation)()

        features = in_features

        module_list = []

        for kernel_size, stride, channels in zip(conf.kernels, conf.strides, conf.channels):

            module_list.extend([
                torch.nn.Conv1d(
                    in_channels=features, out_channels=features, kernel_size=kernel_size,
                    stride=stride, groups=features
                ),
                activation,
                torch.nn.Conv1d(in_channels=features, out_channels=channels, kernel_size=1),
                torch.nn.BatchNorm1d(num_features=channels),
                activation,
                torch.nn.MaxPool1d(kernel_size=stride)
            ])

            features = channels

        module_list.extend([
            torch.nn.AdaptiveAvgPool1d(1),
            torch.nn.Flatten(),

            # `features` tracks the current channel count; unlike the loop
            # variable it is defined even when there are no conv stages.
            torch.nn.Linear(features, conf.hidden_size),
            activation,
            torch.nn.Linear(conf.hidden_size, n_classes),
        ])

        self.model = torch.nn.Sequential(*module_list)

    def forward(self, x):
        return self.model(x)

# Data

Download the data from [Kaggle](https://www.kaggle.com/t/830d20b353bd4e0d80630a97835f14a6) and put `train.zip` and `test.zip` next to this notebook.

In [None]:
%%bash
# Remove any previous extraction, then unpack both archives quietly (-q).
rm -rf test train
unzip -q train.zip
unzip -q test.zip

# Experiments

In [None]:
@dataclass
class DistillConfig:
    """Knowledge-distillation settings read by ``train_one_epoch_kd``."""
    # Weight of the distillation term; the hard-label loss gets ``1 - weight``.
    weight: float = 0.1
    # Temperature that divides teacher and student logits before the softmax.
    softmax_temp: float = 10


@dataclass
class ExpConfig:
    """Top-level experiment configuration.

    Nested config sections are created through ``default_factory`` so that
    every ``ExpConfig`` gets its own instances. A dataclass instance used
    directly as a default is shared between all configs and is rejected as a
    mutable default by Python 3.11 and later.
    """

    model: Model

    train_manifest: str = 'train/manifest_.csv'

    sample_rate: int = 16_000
    val_fraction: float = 0.1
    idx_to_keyword: List[str] = ('sber', 'joy', 'afina', 'salut', 'filler')
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    optim: Optim = field(
        default_factory=lambda: Optim(lr=1e-3, n_epochs=20, batch_size=64)
    )
    features: Features = field(
        default_factory=lambda: Features(
            n_fft=400, win_length=400, hop_length=160, n_mels=64
        )
    )
    augs: Augmentations = field(
        default_factory=lambda: Augmentations(
            freq_mask_param=0, time_mask_param=0,
            min_gain=0.5, max_gain=1.0
        )
    )
    distill: DistillConfig = field(default_factory=lambda: DistillConfig())

## Train Teacher

In [None]:
# Teacher architecture: four conv stages, widening from 32 to 128 channels.
teacher_model_cfg = Model(
    kernels=[3, 3, 3, 3],
    strides=[2, 2, 1, 1],
    channels=[32, 32, 64, 128],
    hidden_size=16,
    activation='ReLU'
)
# Experiment defaults plus the teacher model, wrapped as an OmegaConf structured config.
teacher_conf = omegaconf.OmegaConf.structured(ExpConfig(model=teacher_model_cfg))

In [None]:
# Data loaders and a freshly initialised teacher on the configured device.
train_dataloader, val_dataloader = prepare_dataloaders(teacher_conf)
teacher_model = (
    Conv1dNet(
        in_features=teacher_conf.features.n_mels, 
        n_classes=len(teacher_conf.idx_to_keyword),
        conf=teacher_conf.model
    )
    .to(teacher_conf.device)
)
# Adam over all teacher parameters, cross-entropy on the keyword labels.
optimizer = torch.optim.Adam(params=teacher_model.parameters(), lr=teacher_conf.optim.lr)
criterion = torch.nn.CrossEntropyLoss()

# Logs and checkpoints go to a run directory suffixed with 'teacher'.
train_loop(
    teacher_conf, teacher_model, criterion, optimizer,
    train_dataloader, val_dataloader, description='teacher'
)

## Train Student

In [None]:
# Student architecture: a single conv stage with 16 channels.
student_model_cfg = Model(
    kernels=[3],
    strides=[4],
    channels=[16],
    hidden_size=16,
    activation='ReLU'
)

In [None]:
# Baseline: train the student on hard labels only, without the teacher.
student_conf = omegaconf.OmegaConf.structured(ExpConfig(model=student_model_cfg))
train_dataloader, val_dataloader = prepare_dataloaders(student_conf)
student_model = (
    Conv1dNet(
        in_features=student_conf.features.n_mels, 
        n_classes=len(student_conf.idx_to_keyword),
        conf=student_conf.model
    )
    .to(student_conf.device)
)
# Adam over all student parameters, cross-entropy on the keyword labels.
optimizer = torch.optim.Adam(params=student_model.parameters(), lr=student_conf.optim.lr)
criterion = torch.nn.CrossEntropyLoss()

# Logs and checkpoints go to a run directory suffixed with 'student'.
train_loop(
    student_conf, student_model, criterion, optimizer,
    train_dataloader, val_dataloader, description='student'
)

## Distill Teacher Into Student

In [None]:
# Rebuild the teacher architecture, then restore its trained weights.
teacher_model = (
    Conv1dNet(
        in_features=teacher_conf.features.n_mels, 
        n_classes=len(teacher_conf.idx_to_keyword),
        conf=teacher_conf.model
    )
    .to(teacher_conf.device)
)

teacher_ckpt_path = 'runs/20221011_113801_teacher/ckpts/model_epoch_9_val_acc_0.903.ckpt'
teacher_load_result = teacher_model.load_state_dict(
    # map_location lets a checkpoint saved on one device be restored on another.
    torch.load(teacher_ckpt_path, map_location=teacher_conf.device),
    strict=False
)
# strict=False tolerates extra keys in the checkpoint, but a missing key would
# leave part of the teacher randomly initialised without any warning.
if teacher_load_result.missing_keys:
    raise RuntimeError(
        f"Teacher checkpoint {teacher_ckpt_path} is missing weights: "
        f"{teacher_load_result.missing_keys}"
    )
teacher_model.eval()

In [None]:
# Same student architecture, now with the distillation term weighted 0.90
# and a softmax temperature of 10.
student_conf = omegaconf.OmegaConf.structured(
    ExpConfig(
        model=student_model_cfg,
        distill=DistillConfig(weight=0.90, softmax_temp=10)
    )
)
train_dataloader, val_dataloader = prepare_dataloaders(student_conf)

# Fresh student: distillation starts from a new random initialisation.
student_model = (
    Conv1dNet(
        in_features=student_conf.features.n_mels, 
        n_classes=len(student_conf.idx_to_keyword),
        conf=student_conf.model
    )
    .to(student_conf.device)
)

# The optimizer updates the student only; the criterion is the hard-label loss.
optimizer = torch.optim.Adam(params=student_model.parameters(), lr=student_conf.optim.lr)
student_criterion = torch.nn.CrossEntropyLoss()

In [None]:
def train_one_epoch_kd(
        model, teacher_model, distill_conf, student_criterion, optimizer, loader,
        device, epoch_index, tb_writer, log_interval=100
    ):
    """Train the student for one epoch with knowledge distillation.

    The optimised loss is
    ``(1 - weight) * student_loss + weight * kd_loss * T ** 2``, where
    ``student_loss`` is ``student_criterion`` on the hard labels and
    ``kd_loss`` is the KL divergence between the temperature-softened student
    and teacher distributions. ``weight`` and ``T`` come from ``distill_conf``.

    Metrics are averaged over windows of ``log_interval`` batches and written
    to ``tb_writer`` as 'Loss/train', 'Loss_kd/train', 'Loss_sum/train' and
    'Accuracy/train'. A trailing window shorter than ``log_interval`` is logged
    as well.

    Returns:
        ``(student_loss, accuracy)`` of the last logged window as Python
        floats, or ``(0., 0.)`` when ``loader`` yields no batches.
    """
    kd_criterion = torch.nn.KLDivLoss(reduction='batchmean')

    T = distill_conf.softmax_temp

    model.train()
    teacher_model.eval()

    window_loss, window_kd_loss, window_sum_loss = 0., 0., 0.
    window_correct, window_seen, window_batches = 0, 0, 0
    last_loss, last_acc = 0., 0.
    last_index = -1

    def log_window(batch_index):
        # Averages cover exactly the batches accumulated since the last flush.
        mean_loss = window_loss / window_batches
        accuracy = window_correct / window_seen
        tb_x = epoch_index * len(loader) + batch_index + 1
        tb_writer.add_scalar('Loss/train', mean_loss, tb_x)
        tb_writer.add_scalar('Loss_kd/train', window_kd_loss / window_batches, tb_x)
        tb_writer.add_scalar('Loss_sum/train', window_sum_loss / window_batches, tb_x)
        tb_writer.add_scalar('Accuracy/train', accuracy, tb_x)
        return mean_loss, accuracy

    for i, (inputs, labels) in enumerate(loader):

        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        logits = model(inputs)
        preds = logits.argmax(1)
        student_loss = student_criterion(logits, labels)

        # The teacher only provides targets; no gradients flow through it.
        with torch.no_grad():
            teacher_probs = F.softmax(teacher_model(inputs) / T, dim=1)

        # KLDivLoss takes log-probabilities as input and probabilities as target.
        kd_loss = kd_criterion(F.log_softmax(logits / T, dim=1), teacher_probs)

        loss = (1 - distill_conf.weight) * student_loss + distill_conf.weight * kd_loss * T ** 2
        loss.backward()

        optimizer.step()

        window_loss += student_loss.item()
        window_kd_loss += kd_loss.item()
        window_sum_loss += loss.item()
        # .item() keeps the counters as Python ints, so accuracy is a float.
        window_correct += (preds == labels).sum().item()
        window_seen += torch.numel(preds)
        window_batches += 1
        last_index = i

        if window_batches == log_interval:
            last_loss, last_acc = log_window(i)
            window_loss, window_kd_loss, window_sum_loss = 0., 0., 0.
            window_correct, window_seen, window_batches = 0, 0, 0

    # Flush the trailing partial window instead of silently dropping it.
    if window_batches:
        last_loss, last_acc = log_window(last_index)

    return last_loss, last_acc

In [None]:
# Run directory, config dump and TensorBoard writer for the distillation run.
ckpt_dir, tb_writer = init_logger_and_dump_params(student_conf, student_model, description='student_distillation')

best_val_acc = -1.

try:
    for epoch in range(student_conf.optim.n_epochs):

        avg_loss, avg_acc = train_one_epoch_kd(
            student_model, teacher_model, student_conf.distill, student_criterion, optimizer, 
            train_dataloader, student_conf.device, epoch, tb_writer
        )

        # Validation uses the hard-label criterion only.
        val_loss, val_acc = validation(
            student_model, student_criterion, val_dataloader,
            student_conf.device, epoch, tb_writer
        )

        # Save a checkpoint whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(
                student_model.state_dict(),
                ckpt_dir / f"model_epoch_{epoch + 1}_val_acc_{val_acc:.3f}.ckpt"
            )
finally:
    # Flush pending events and release the event file, even if training is interrupted.
    tb_writer.close()

# Error Analysis

In [None]:
# TODO: iterate through val_dataloader
# TODO: plot a confusion matrix
# TODO: listen to the misclassified examples