# BlaschkeNet Training Notebook

In [1]:
import os
from os.path import join
import math
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.optim import SGD, Optimizer
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Training on device: {device}")

Training on device: cuda:0


In [4]:
root_path: str = join("drive", "MyDrive")
# personal_path: str = join("School", "UW BS MS", "WI 24", "MATH 515 Project")  # Elliott's path
personal_path: str =  join("MATH 515 Project")  # Robert's path
data_path: str = join(root_path, personal_path)
if os.path.exists(data_path):
  print(f"Successfuly found data path: \"{data_path}\"")
else:
  raise ValueError(f"Unable to find data path: \"{data_path}\"")

Successfuly found data path: "drive/MyDrive/MATH 515 Project"


# Dataset

In [11]:
sampling_rate: int = 11250
duration: int = 10
waveform_length: int = sampling_rate * duration

In [12]:
train_val_split: float = 0.9
batch_size: int = 8


def get_dataloaders() -> Tuple[DataLoader, DataLoader]:
  dataset = torch.load(join(data_path, "romance.pt"))
  train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])
  return DataLoader(train_dataset, batch_size=batch_size, shuffle=True), DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# Model

In [13]:
import torch
from torch import nn


class BlaschkeNet(nn.Module):
    def __init__(
            self,
            waveform_length: int,
            num_conv_blocks: int = 2,
            num_classes: int = 3,
            kernel_size: int = 3,
            p: float = 0.5
    ):
        super(BlaschkeNet, self).__init__()
        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(in_channels=1 if i == 0 else 64, out_channels=64, kernel_size=kernel_size),
                nn.ReLU(),
                nn.Dropout(p=p),
                nn.MaxPool1d(kernel_size=2, stride=2)
            )
            for i in range(num_conv_blocks)
        ])
        self.flatten = nn.Flatten()
        with torch.no_grad():
            sample_x = torch.zeros(1, 1, waveform_length)
            for block in self.conv_blocks:
              sample_x = block(sample_x)
            sample_x = self.flatten(sample_x)
        hidden_dim = sample_x.shape[-1]
        self.fc1 = nn.Linear(hidden_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)
        print(hidden_dim)

    def forward(self, x: torch.Tensor):
        # x.shape = (batch_size, waveform_length)
        x = x.unsqueeze(dim=1)
        # x.shape = (batch_size, 1, waveform_length)
        for block in self.conv_blocks:
          x = block(x)
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        # x.shape = (batch_size, num_classes)
        return torch.softmax(x, dim=1)

In [14]:
import torch
from torch import nn


class BlaschkeRNN(nn.Module):
    def __init__(
            self,
            waveform_length: int,
            hidden_dim: int = 512,
            num_classes: int = 3
    ):
        super(BlaschkeRNN, self).__init__()
        self.rnn = nn.RNN(waveform_length, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor):
        x, _ = self.rnn(x)
        x = self.fc(x)
        return torch.softmax(x, dim=1)

# Training

In [16]:
num_epochs: int = 10

def train(model: BlaschkeNet, train_loader: DataLoader, val_loader: DataLoader, optimizer: Optimizer):
  model.to(device)

  train_history = []
  val_history = []
  for _ in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader):
      x = batch[0][:, 1:].to(dtype=torch.float32).to(device)
      y = batch[0][:, 0].to(dtype=torch.int64).to(device)

      optimizer.zero_grad()

      y_hat = model(x)

      loss = F.cross_entropy(y_hat, y)
      loss.backward()

      optimizer.step()

      train_loss += loss.item()
    train_history.append(train_loss / len(train_loader))

    # model.eval()
    val_loss = 0
    accuracy = 0
    # with torch.no_grad():
    for batch in tqdm(val_loader):
      x = batch[0][:, 1:].to(dtype=torch.float32).to(device)
      y = batch[0][:, 0].to(dtype=torch.int64).to(device)
      optimizer.zero_grad()
      y_hat = model(x)
      loss = F.cross_entropy(y_hat, y)
      loss.backward()
      optimizer.step()
      val_loss += loss.item()

      pred = torch.argmax(y_hat, dim=1)
      accuracy += torch.sum((y == pred).to(dtype=torch.float32))

    val_history.append(val_loss / len(val_loader))

    print(f"Epoch {len(train_history)}, train-loss={train_history[-1]}, val-loss={val_history[-1]}, val-acc={accuracy/len(val_loader.dataset)}")
    torch.save(model.state_dict(), f"blaschke-net-ckpt-{len(train_history)}")
  return train_history, val_history


def main():
  model = BlaschkeNet(waveform_length=waveform_length, num_conv_blocks=2, p=0.5)
  model.to(device)
  num_params = sum([torch.sum(torch.tensor(param.size())) for param in model.parameters()])
  print(f"Initializing model with {num_params} parameters")
  train_loader, val_loader = get_dataloaders()
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
  train(model, train_loader, val_loader, optimizer)


main()

1799872
Initializing model with 1800973 parameters


100%|██████████| 42/42 [00:09<00:00,  4.33it/s]
100%|██████████| 5/5 [00:01<00:00,  4.53it/s]


Epoch 1, train-loss=1.3341829634848095, val-loss=1.3514448642730712, val-acc=0.2222222238779068


100%|██████████| 42/42 [00:09<00:00,  4.26it/s]
100%|██████████| 5/5 [00:01<00:00,  4.54it/s]


Epoch 2, train-loss=1.3193020281337557, val-loss=1.3514448642730712, val-acc=0.2222222238779068


100%|██████████| 42/42 [00:09<00:00,  4.24it/s]
100%|██████████| 5/5 [00:01<00:00,  4.52it/s]


Epoch 3, train-loss=1.3193020252954393, val-loss=1.3264448642730713, val-acc=0.2222222238779068


100%|██████████| 42/42 [00:09<00:00,  4.38it/s]
100%|██████████| 5/5 [00:01<00:00,  4.43it/s]


Epoch 4, train-loss=1.3193020252954393, val-loss=1.3514448642730712, val-acc=0.2222222238779068


100%|██████████| 42/42 [00:09<00:00,  4.37it/s]
100%|██████████| 5/5 [00:01<00:00,  4.53it/s]


Epoch 5, train-loss=1.3193020281337557, val-loss=1.3264448642730713, val-acc=0.2222222238779068


100%|██████████| 42/42 [00:09<00:00,  4.37it/s]
100%|██████████| 5/5 [00:01<00:00,  4.34it/s]


Epoch 6, train-loss=1.3193020281337557, val-loss=1.3014448642730714, val-acc=0.2222222238779068


100%|██████████| 42/42 [00:09<00:00,  4.30it/s]
100%|██████████| 5/5 [00:01<00:00,  4.51it/s]


Epoch 7, train-loss=1.3193020281337557, val-loss=1.3514448642730712, val-acc=0.2222222238779068


100%|██████████| 42/42 [00:10<00:00,  4.16it/s]
100%|██████████| 5/5 [00:01<00:00,  4.54it/s]


Epoch 8, train-loss=1.3193020224571228, val-loss=1.3264448642730713, val-acc=0.2222222238779068


100%|██████████| 42/42 [00:09<00:00,  4.36it/s]
100%|██████████| 5/5 [00:01<00:00,  4.27it/s]


Epoch 9, train-loss=1.3222782157716297, val-loss=1.3514448642730712, val-acc=0.2222222238779068


100%|██████████| 42/42 [00:10<00:00,  4.12it/s]
100%|██████████| 5/5 [00:01<00:00,  4.46it/s]


Epoch 10, train-loss=1.3222782186099462, val-loss=1.3264448642730713, val-acc=0.2222222238779068


# Evaluation

In [None]:
def evaluate():
  artifact_name = "blaschke-net-ckpt-3"
  model = BlaschkeNet(waveform_length=waveform_length, num_conv_blocks=2, p=0.25)
  model.to(device)
  num_params = sum([torch.sum(torch.tensor(param.size())) for param in model.parameters()])
  print(f"Initializing model with {num_params} parameters")
  model.load_state_dict(torch.load(artifact_name))
  _, val_loader = get_dataloaders()

  model.eval()
  accuracy = 0
  with torch.no_grad():
    for batch in tqdm(val_loader):
      x = batch[0][:, 1:].to(dtype=torch.float32).to(device)
      y = batch[0][:, 0].to(dtype=torch.int64).to(device)
      y_hat = model(x)
      pred = torch.argmax(y_hat, dim=1)
      accuracy += torch.mean((y == pred).to(dtype=torch.float32))
  accuracy /= len(val_loader)
  print(f"Model artifact {artifact_name} has accuracy {accuracy}")


evaluate()