# BlaschkeNet Training Notebook

In [None]:
import os
from os.path import join
import math
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.optim import SGD, Optimizer
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Training on device: {device}")

Training on device: cuda:0


In [None]:
root_path: str = join("drive", "MyDrive")
# personal_path: str = join("School", "UW BS MS", "WI 24", "MATH 515 Project")  # Elliott's path
personal_path: str =  join("MATH 515 Project")  # Robert's path
data_path: str = join(root_path, personal_path)
if os.path.exists(data_path):
  print(f"Successfuly found data path: \"{data_path}\"")
else:
  raise ValueError(f"Unable to find data path: \"{data_path}\"")

Successfuly found data path: "drive/MyDrive/MATH 515 Project"


# Dataset

In [None]:
win_length_ms: int = 25
overlap_ms: int = 10

sampling_rate: int = 5000
duration: int = 5
waveform_length: int = sampling_rate * duration

hop_length: int = int(sampling_rate * 0.001 * overlap_ms)
win_length: int = int(sampling_rate * 0.001 * win_length_ms)
n_fft: int = int(sampling_rate * 0.001 * win_length)

In [None]:
train_val_split: float = 0.9
num_bins: int = math.ceil(1 + n_fft / 2)
num_frames: int = math.ceil((waveform_length - win_length) / hop_length + 1)

batch_size: int = 4


def get_dataloaders() -> Tuple[DataLoader, DataLoader]:
  dataset = torch.load(join(data_path, "dataset5k_plus.pt"))
  train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])
  return DataLoader(train_dataset, batch_size=batch_size, shuffle=True), DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# Model

In [None]:
import torch
from torch import nn


class BlaschkeNet(nn.Module):
    def __init__(
            self,
            num_bins: int,
            num_frames: int,
            num_classes: int = 3,
            kernel_size: int = 3,
            p: float = 0.25
    ):
        super(BlaschkeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=kernel_size)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=kernel_size)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=p)
        self.flatten = nn.Flatten()
        with torch.no_grad():
            sample_input = torch.zeros(1, 2, num_bins, num_frames)
            sample_output = self.pool1(torch.relu(self.bn1(self.conv1(sample_input))))
            sample_output = self.pool2(torch.relu(self.bn2(self.conv2(sample_output))))
        hidden_dim = int(torch.prod(torch.tensor(sample_output.shape)))
        print(hidden_dim)


        # lol
        hidden_dim = 598272


        self.fc1 = nn.Linear(hidden_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x: torch.Tensor):
        # x.shape = (batch_size, num_bins, num_frames, 2)
        x = x.permute(0, 3, 1, 2)
        # x.shape = (batch_size, 2, num_bins, num_frames)
        # print(x.shape)
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        # print(x.shape)
        x = self.flatten(self.dropout(x))
        # print(x.shape)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        # x.shape = (batch_size, num_classes)
        return torch.softmax(x, dim=1)

# Training

In [None]:
num_epochs: int = 10

def train(model: BlaschkeNet, train_loader: DataLoader, val_loader: DataLoader, optimizer: Optimizer):
  model.to(device)

  train_history = []
  val_history = []
  for _ in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader):
      x = torch.stft(batch[0][:, 1:], n_fft, hop_length, win_length, return_complex=False).to(dtype=torch.float32).to(device)
      # print(x.shape)
      y = batch[0][:, 0].to(dtype=torch.int64).to(device)

      optimizer.zero_grad()

      y_hat = model(x)

      loss = F.cross_entropy(y_hat, y)
      loss.backward()

      optimizer.step()

      train_loss += loss.item()
    train_history.append(train_loss / len(train_loader))

    model.eval()
    val_loss = 0
    with torch.no_grad():
      for batch in tqdm(val_loader):
        x = torch.stft(batch[0][:, 1:], n_fft, hop_length, win_length, return_complex=False).to(dtype=torch.float32).to(device)
        y = batch[0][:, 0].to(dtype=torch.int64).to(device)
        y_hat = model(x)
        loss = F.cross_entropy(y_hat, y)
        val_loss += loss.item()
    val_history.append(val_loss / len(val_loader))

    print(f"Epoch {len(train_history)}, train-loss={train_history[-1]}, val-loss={val_history[-1]}")
    torch.save(model.state_dict(), f"blaschke-net-ckpt-{len(train_history)}")
  return train_history, val_history


def main():
  # model = BlaschkeNet(num_bins=3031, num_frames=1003, num_classes=3)
  model = BlaschkeNet(num_bins=num_bins, num_frames=num_frames, num_classes=3, p=0.5)
  model.to(device)
  num_params = sum([torch.sum(torch.tensor(param.size())) for param in model.parameters()])
  print(f"Initializing model with {num_params} parameters")
  train_loader, val_loader = get_dataloaders()
  optimizer = torch.optim.Adam(model.parameters())
  train(model, train_loader, val_loader, optimizer)


main()

606144
Initializing model with 599476 parameters


100%|██████████| 136/136 [00:12<00:00, 10.69it/s]
100%|██████████| 15/15 [00:00<00:00, 45.39it/s]


Epoch 1, train-loss=1.2190155036309187, val-loss=1.1681114355723063


100%|██████████| 136/136 [00:12<00:00, 10.78it/s]
100%|██████████| 15/15 [00:00<00:00, 37.02it/s]


Epoch 2, train-loss=1.2242388865526985, val-loss=1.1681114355723063


100%|██████████| 136/136 [00:12<00:00, 10.57it/s]
100%|██████████| 15/15 [00:00<00:00, 36.19it/s]


Epoch 3, train-loss=1.2242388865526985, val-loss=1.1681114355723063


100%|██████████| 136/136 [00:13<00:00, 10.42it/s]
100%|██████████| 15/15 [00:00<00:00, 44.24it/s]


Epoch 4, train-loss=1.2242388865526985, val-loss=1.1681114355723063


100%|██████████| 136/136 [00:12<00:00, 10.77it/s]
100%|██████████| 15/15 [00:00<00:00, 43.03it/s]


Epoch 5, train-loss=1.2242388865526985, val-loss=1.1681114355723063


100%|██████████| 136/136 [00:12<00:00, 10.74it/s]
100%|██████████| 15/15 [00:00<00:00, 37.87it/s]


Epoch 6, train-loss=1.2242388865526985, val-loss=1.1681114355723063


100%|██████████| 136/136 [00:12<00:00, 10.81it/s]
100%|██████████| 15/15 [00:00<00:00, 38.63it/s]


Epoch 7, train-loss=1.2242388865526985, val-loss=1.1681114355723063


100%|██████████| 136/136 [00:12<00:00, 10.75it/s]
100%|██████████| 15/15 [00:00<00:00, 46.35it/s]


Epoch 8, train-loss=1.2187241806703455, val-loss=1.1681114355723063


100%|██████████| 136/136 [00:12<00:00, 10.75it/s]
100%|██████████| 15/15 [00:00<00:00, 45.40it/s]


Epoch 9, train-loss=1.2187241806703455, val-loss=1.1681114355723063


100%|██████████| 136/136 [00:12<00:00, 10.71it/s]
100%|██████████| 15/15 [00:00<00:00, 45.81it/s]


Epoch 10, train-loss=1.2187241806703455, val-loss=1.1681114355723063


# Evaluation

In [22]:
def evaluate():
  artifact_name = "blaschke-net-ckpt-10"
  model = BlaschkeNet(num_bins=num_bins, num_frames=num_frames, num_classes=3, p=0.5)
  model.to(device)
  num_params = sum([torch.sum(torch.tensor(param.size())) for param in model.parameters()])
  print(f"Initializing model with {num_params} parameters")
  model.load_state_dict(torch.load(artifact_name))
  _, val_loader = get_dataloaders()

  model.eval()
  accuracy = 0
  with torch.no_grad():
    for batch in tqdm(val_loader):
      x = torch.stft(batch[0][:, 1:], n_fft, hop_length, win_length, return_complex=False).to(dtype=torch.float32).to(device)
      y = batch[0][:, 0].to(dtype=torch.int64).to(device)
      y_hat = model(x)
      pred = torch.argmax(y_hat, dim=1)
      accuracy += torch.mean((y == pred).to(dtype=torch.float32))
  accuracy /= len(val_loader)
  print(f"Model artifact {artifact_name} has accuracy {accuracy}")


evaluate()

606144
Initializing model with 599476 parameters


100%|██████████| 15/15 [00:00<00:00, 42.48it/s]

Model artifact blaschke-net-ckpt-10 has accuracy 0.40000003576278687



