# VAD model training pipeline

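Main training pipeline: load LibriSpeech with voice-activity (VAD) labels read from the `silero_vad_512_timestamp` timestamps, define the training loop, accuracy metric and focal loss, then train a TDNN model.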

In [11]:
!cd C:\\Users\\hekto\\PycharmProjects\\MyThesis\\code
%pwd

'C:\\Users\\hekto\\PycharmProjects\\MyThesis\\code'

In [14]:
from dataclasses import dataclass, field
from typing import Callable

import numpy as np
import torch
from torch import Tensor
from torch.nn import BCELoss
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, LRScheduler
from torch.utils.data import DataLoader, Dataset

from dataset import LibriSpeech
from torch_framework.config import Config, default_config
import torch_framework.models.baseline as base
import torch_framework.models.models as m
from torch_framework.dataset import LoadVADFromTimestamps
from utils.gui import GUI, ProgressBar

In [20]:
def train_model(model: torch.nn.Module, data_loader: DataLoader, loss_fn: Callable, train_config: "TrainConfig",
                metrics: list[Callable], verbose: int = 1):

    print(f"trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    if verbose == 2:
        gui = GUI(100, "Training log")

        ax = gui.add_subplot(2, 1, 1)
        ax.set_ylim(0, .12)
        ax.set_title("loss")
        gui.add_line(1, "loss")

        ax = gui.add_subplot(2, 1, 2)
        ax.set_ylim(0, 100)
        ax.set_title("accuracy")
        gui.add_line(2, "acc")
        ax.set_xlabel("epoch")

        gui.config()

    epoch, it = 0, 0
    pf = {"SNR": "%.2f", "lr": "%.2e", "train loss": "%.3f", "train accuracy": "%.2f"}
    model.train(True)
    while True:
        avg_loss, avg_acc = 0, 0
        progress_bar = ProgressBar(data_loader, epoch, colour="BLUE", postfix_fmt=pf) if verbose > 0 else data_loader
        for audio, vad, ann in progress_bar:
            # Zero your gradients for every batch!
            train_config.optimizer.zero_grad()

            pred = model(audio)
            if pred.shape[-1] == vad.shape[-1] + 1:
                pred = pred[:, :-1]
            assert pred.shape == vad.shape, f"{pred.shape=}, {vad.shape=}"

            loss = loss_fn(pred, vad)
            loss.backward()

            train_config.optimizer.step()

            for s in train_config.schedulers:
                s.step()

            if verbose > 0:
                # vad_pred = prob_to_vad(pred, audio.shape[-1])[None, :]
                acc = binary_accuracy(pred, vad)
                # prog_bar.set_postfix_str(pf(loss.item(), acc))
                progress_bar.set_value("lr", train_config.optimizer.param_groups[-1]['lr'])
                progress_bar.set_value("train loss", loss.item(), alpha=.5)
                progress_bar.set_value("train accuracy", 100 * acc, alpha=.5)
                avg_loss += loss.item()
                avg_acc += acc

                if verbose == 2:
                    gui.add_data("loss", it, loss.item())
                    gui.add_data("acc", it, 100 * acc)

            if verbose == 2:
                gui.update()

            it += 1

        print("train loss: %.2f, train acc: %.2f" % (avg_loss / len(data_loader), 100 * avg_acc / len(data_loader)) + "%")

        epoch += 1


Config

In [23]:
# Project default configuration; passed to the dataset (LibriSpeech) and to the model (TDNN) below.
cfg = default_config()

@dataclass
class TrainConfig:
    optimizer: Optimizer = None

    schedulers: list[LRScheduler] = field(default_factory=lambda: [])
    batch_size: int = 8
    shuffle: bool = True

    def add_scheduler(self, scheduler: LRScheduler):
        self.schedulers.append(scheduler)

Setting up the dataset

In [24]:
# VAD labels are loaded from stored timestamps; the argument presumably names the label
# source (Silero VAD, 512 window?) — TODO confirm against LoadVADFromTimestamps.
labels = LoadVADFromTimestamps("silero_vad_512_timestamp")
# size=None selects the full dataset (585 items in the recorded output).
ls_data = LibriSpeech(labels=labels, size=None, config=cfg)

# Batch size / shuffling come from TrainConfig; the optimizer is attached later, once the model exists.
train_config = TrainConfig()
# num_workers=0 loads batches in the main process; persistent_workers=True would be rejected
# by DataLoader in that case, so it stays False.
data_loader = DataLoader(ls_data, batch_size=train_config.batch_size, shuffle=train_config.shuffle, num_workers=0,
                         collate_fn=ls_data.default_collate_fn, persistent_workers=False)

size = 585, full dataset: 585


Metric and loss functions

In [17]:
def binary_accuracy(y_pred: Tensor, y_true: Tensor) -> float:
    """Fraction of elements where ``round(y_pred)`` equals ``y_true``.

    Both tensors must be 2-D with identical shapes, and ``y_pred`` must be finite.
    The hit rate is computed per row and then averaged over rows.
    """
    assert y_pred.dim() == 2
    assert y_true.shape == y_pred.shape, f"{y_true.shape=}, {y_pred.shape=}"
    # Rejects NaN as well as +/-inf.
    assert torch.isfinite(y_pred).all()

    hits = y_pred.round().eq(y_true)
    row_accuracy = hits.sum(dim=1) / y_true.shape[1]
    return float(row_accuracy.mean())


def focal_loss(alpha: float = .5, gamma: float = 0.):
    """Build a binary focal-loss callable ``fn(y_pred, y_true) -> Tensor`` (mean over all elements).

    ``y_pred`` holds probabilities; it is clamped to ``[1e-6, 1 - 1e-6]`` so the logs stay finite.
    With ``gamma=0`` this is alpha-weighted binary cross-entropy (0.5 x BCE at the defaults).

    NOTE(review): ``1 - alpha`` weights the positive (y_true == 1) term and ``alpha`` weights the
    negative term. This is the reverse of the alpha_t convention in Lin et al. / torchvision's
    ``sigmoid_focal_loss``. The two are identical only at alpha = 0.5.
    """
    def _focal(y_pred: Tensor, y_true: Tensor):
        eps = 1e-6
        p = torch.clamp(y_pred, eps, 1 - eps)
        positive_term = y_true * (1 - alpha) * (1 - p) ** gamma * torch.log(p)
        negative_term = (1 - y_true) * alpha * p ** gamma * torch.log(1 - p)
        return -(positive_term + negative_term).mean()

    return _focal

Define the model, configure the optimizer and train

In [26]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

model = m.TDNN(cfg)
print(model)

train_config.optimizer = Adam(model.parameters(), lr=1e-2)
# NOTE(review): train_model steps every scheduler once per *batch*, so these milestones are
# iteration counts. With 74 batches/epoch (see the output below), lr drops twice within epoch 0.
# If epoch milestones were intended, scale them by len(data_loader) — TODO confirm.
train_config.schedulers = [
    MultiStepLR(train_config.optimizer, [40, 60], gamma=0.1)
]

# `data_loader` is built from `ls_data` / `train_config` in the dataset-setup cell; reused here
# instead of being constructed a second time.
train_model(model, data_loader,
            loss_fn=focal_loss(.5, 2.),
            train_config=train_config,
            metrics=[binary_accuracy],
            verbose=1)

TDNN(
  (l0): STFT(nfft=1024, out=view_as_real)
  (l1): TDLayer(513, in_channels=2, out_channels=1, kernel_size=4)
  (norm1): InstanceNorm1d(255, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (l2): TDLayer(255, in_channels=1, out_channels=1, kernel_size=4)
  (norm2): InstanceNorm1d(126, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (l3): TDLayer(126, in_channels=1, out_channels=1, kernel_size=3)
  (norm3): InstanceNorm1d(62, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (lstm): LSTM(62, 16)
  (l_end): Conv1d(16, 1, kernel_size=(1,), stride=(1,))
)
trainable params: 11483


epoch 0:   0%|[34m                    [0m|0/74 [00:00, ?it/s] -- [0mException ignored in: <generator object tqdm.__iter__ at 0x000001F13A1A7480>
Traceback (most recent call last):
  File "C:\Users\hekto\PycharmProjects\MyThesis\code\venv\Lib\site-packages\tqdm\std.py", line 1197, in __iter__
    self.close()
  File "C:\Users\hekto\PycharmProjects\MyThesis\code\venv\Lib\site-packages\tqdm\std.py", line 1275, in close
    pos = abs(self.pos)
          ^^^^^^^^^^^^^
KeyboardInterrupt: 


KeyboardInterrupt: 