In [1]:
%load_ext autoreload
%autoreload 2
from CNN import AudioClassifier
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from custom_dataset import SpectrogramDataset, BinaryDataset, create_sampler
from training_pipeline import repeat_training, set_seed, worker_init_fn, plot_results
from collections import Counter
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

## Parameters

In [2]:
# --- Reproducibility -------------------------------------------------------
SEED = 42
set_seed(SEED)

# --- Training schedule (all forwarded to repeat_training) ------------------
repetitions = 4   # number of training runs
lr = 0.001
epochs = 100
tolerance = 10    # passed as `tolerance=`; presumably early-stopping patience — TODO confirm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Sampling / regularisation ---------------------------------------------
alpha = 1                 # second argument of create_sampler; exact meaning lives in custom_dataset
dropout = 0.3             # `drop=` of AudioClassifier
weight_decay = 0.0001
augmented_fraction = 0.5  # `augmented_fraction=` of the training datasets
label_smoothing = 0.1

# --- DataLoader settings ---------------------------------------------------
batch_size = 1024
n_workers = 4
# Both options below only make sense when worker processes are used, so they
# are switched off automatically if n_workers is set to 0.
prefetch_factor = 2 if n_workers > 0 else None
persistent_workers = n_workers > 0

## 10 classes + unknown

In [3]:
# Train / validation / test splits for the "10 classes + unknown" task.
# Only the training split is augmented.
data_path = "data/train/audio_transformed"
train_dataset = SpectrogramDataset(
    data_path,
    set_type=SpectrogramDataset.TRAIN,
    augmentation=True,
    augmented_fraction=augmented_fraction,
)
val_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.VAL)
test_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.TEST)

# Options shared by all three loaders, defined once so they cannot drift apart.
# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only machine
# pin_memory=True brings no benefit and makes PyTorch emit a warning, so it
# follows the device selected in the parameters cell.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=False,  # train order comes from the sampler; val/test stay in dataset order
    num_workers=n_workers,
    pin_memory=device.type == "cuda",
    prefetch_factor=prefetch_factor,
    persistent_workers=persistent_workers,
)

# The training loader draws indices from the project sampler (see
# custom_dataset.create_sampler for how `alpha` is used) and initialises each
# worker through worker_init_fn.
sampler = create_sampler(train_dataset, alpha)
train_loader = DataLoader(
    train_dataset,
    sampler=sampler,
    worker_init_fn=worker_init_fn,
    **loader_kwargs,
)

val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [4]:
def init_cnn_all_classes():
    """Create a new AudioClassifier with 11 outputs (10 classes + unknown).

    Used as the model factory handed to ``repeat_training``; the dropout rate
    comes from the notebook-level ``dropout`` setting.
    """
    model = AudioClassifier(num_classes=11, drop=dropout)
    return model

# Re-seed right before training (set_seed is project code; presumably this
# makes the run independent of randomness consumed by earlier cells).
set_seed(SEED)

# Output locations for the 11-class CNN: model weights and training history.
model_dir = "output/models/all_classes/final/cnn"
history_dir = "output/history/all_classes/final/cnn"
for out_dir in (model_dir, history_dir):
    os.makedirs(out_dir, exist_ok=True)

model_path = f"{model_dir}/cnn.pth"
history_path = f"{history_dir}/cnn.pkl"

# Positional order follows training_pipeline.repeat_training's signature.
repeat_training(
    repetitions,
    init_cnn_all_classes,
    lr,
    model_path,
    history_path,
    epochs,
    train_loader,
    val_loader,
    test_loader,
    device,
    tolerance=tolerance,
    weight_decay=weight_decay,
    label_smoothing=label_smoothing,
)

training iteration: 1 of 4
starting training...
epoch: 1, training loss: 0.002289330211695257, training accuracy: 14.357579079235828, training balanced accuracy: 14.358138315684712
epoch: 1, validation loss: 0.002432755141722591, validation accuracy: 9.944101206237129, validation balanced accuracy: 23.721408018578856
model saved

epoch: 2, training loss: 0.0021288441736198615, training accuracy: 22.3614155966176, training balanced accuracy: 22.174287799510676
epoch: 2, validation loss: 0.00232615701308423, validation accuracy: 13.033245072080023, validation balanced accuracy: 31.19589957315798
model saved

epoch: 3, training loss: 0.0019948937178516776, training accuracy: 29.35327278421547, training balanced accuracy: 29.1880138111524
epoch: 3, validation loss: 0.002148007217243371, validation accuracy: 18.181818181818183, validation balanced accuracy: 42.496035043095155
model saved

epoch: 4, training loss: 0.001925592780617237, training accuracy: 32.75328844347009, training balanced 

## 10 classes

In [3]:
# Train / validation / test splits for the 10-class task.
# use_unknown=False is passed to every split; presumably it removes the
# "unknown" class used in the previous section — confirm in custom_dataset.
# Only the training split is augmented.
data_path = "data/train/audio_transformed"
train_dataset = SpectrogramDataset(
    data_path,
    set_type=SpectrogramDataset.TRAIN,
    augmentation=True,
    augmented_fraction=augmented_fraction,
    use_unknown=False,
)
val_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.VAL, use_unknown=False)
test_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.TEST, use_unknown=False)

# Options shared by all three loaders, defined once so they cannot drift apart.
# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only machine
# pin_memory=True brings no benefit and makes PyTorch emit a warning, so it
# follows the device selected in the parameters cell.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=False,  # train order comes from the sampler; val/test stay in dataset order
    num_workers=n_workers,
    pin_memory=device.type == "cuda",
    prefetch_factor=prefetch_factor,
    persistent_workers=persistent_workers,
)

# The training loader draws indices from the project sampler (see
# custom_dataset.create_sampler for how `alpha` is used) and initialises each
# worker through worker_init_fn.
sampler = create_sampler(train_dataset, alpha)
train_loader = DataLoader(
    train_dataset,
    sampler=sampler,
    worker_init_fn=worker_init_fn,
    **loader_kwargs,
)

val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
def init_cnn_without_unknown():
    """Create a new AudioClassifier with 10 outputs (no "unknown" class).

    Used as the model factory handed to ``repeat_training``; the dropout rate
    comes from the notebook-level ``dropout`` setting.
    """
    model = AudioClassifier(num_classes=10, drop=dropout)
    return model

# Re-seed right before training (set_seed is project code; presumably this
# makes the run independent of randomness consumed by earlier cells).
set_seed(SEED)

# Output locations for the 10-class CNN: model weights and training history.
model_dir = "output/models/without_unknown/final/cnn"
history_dir = "output/history/without_unknown/final/cnn"
for out_dir in (model_dir, history_dir):
    os.makedirs(out_dir, exist_ok=True)

model_path = f"{model_dir}/cnn.pth"
history_path = f"{history_dir}/cnn.pkl"

# Positional order follows training_pipeline.repeat_training's signature.
repeat_training(
    repetitions,
    init_cnn_without_unknown,
    lr,
    model_path,
    history_path,
    epochs,
    train_loader,
    val_loader,
    test_loader,
    device,
    tolerance=tolerance,
    weight_decay=weight_decay,
    label_smoothing=label_smoothing,
)

training iteration: 1 of 4
starting training...
epoch: 1, training loss: 0.002330244224117031, training accuracy: 13.313194519365627, training balanced accuracy: 13.294825365324561
epoch: 1, validation loss: 0.0025997006435194447, validation accuracy: 15.289095847885138, validation balanced accuracy: 15.252407750613575
model saved

epoch: 2, training loss: 0.0022907563559457778, training accuracy: 16.096666307044988, training balanced accuracy: 16.123512783981433
epoch: 2, validation loss: 0.0024243471092339586, validation accuracy: 23.088863019014358, validation balanced accuracy: 22.993101910870255
model saved

epoch: 3, training loss: 0.0022450334651015675, training accuracy: 19.014996223972382, training balanced accuracy: 19.051850149748248
epoch: 3, validation loss: 0.002408158052878922, validation accuracy: 22.894838960031045, validation balanced accuracy: 22.8438396988337
model saved

epoch: 4, training loss: 0.002181763414400837, training accuracy: 21.631243931384184, training 

KeyboardInterrupt: 

## Binary case

In [3]:
# Train / validation / test splits for the binary task, built with
# BinaryDataset (split selectors are the SpectrogramDataset constants).
# Only the training split is augmented.
data_path = "data/train/audio_transformed"
train_dataset = BinaryDataset(
    data_path,
    set_type=SpectrogramDataset.TRAIN,
    augmentation=True,
    augmented_fraction=augmented_fraction,
)
val_dataset = BinaryDataset(data_path, set_type=SpectrogramDataset.VAL)
test_dataset = BinaryDataset(data_path, set_type=SpectrogramDataset.TEST)

# Options shared by all three loaders, defined once so they cannot drift apart.
# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only machine
# pin_memory=True brings no benefit and makes PyTorch emit a warning, so it
# follows the device selected in the parameters cell.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=False,  # train order comes from the sampler; val/test stay in dataset order
    num_workers=n_workers,
    pin_memory=device.type == "cuda",
    prefetch_factor=prefetch_factor,
    persistent_workers=persistent_workers,
)

# The training loader draws indices from the project sampler (see
# custom_dataset.create_sampler for how `alpha` is used) and initialises each
# worker through worker_init_fn.
sampler = create_sampler(train_dataset, alpha)
train_loader = DataLoader(
    train_dataset,
    sampler=sampler,
    worker_init_fn=worker_init_fn,
    **loader_kwargs,
)

val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [7]:
def init_cnn_binary():
    """Create a new AudioClassifier with 2 outputs for the binary task.

    Used as the model factory handed to ``repeat_training``; the dropout rate
    comes from the notebook-level ``dropout`` setting.
    """
    model = AudioClassifier(num_classes=2, drop=dropout)
    return model

# Re-seed right before training (set_seed is project code; presumably this
# makes the run independent of randomness consumed by earlier cells).
set_seed(SEED)

# Output locations for the binary CNN: model weights and training history.
model_dir = "output/models/binary/final/cnn"
history_dir = "output/history/binary/final/cnn"
for out_dir in (model_dir, history_dir):
    os.makedirs(out_dir, exist_ok=True)

model_path = f"{model_dir}/cnn.pth"
history_path = f"{history_dir}/cnn.pkl"

# Positional order follows training_pipeline.repeat_training's signature.
repeat_training(
    repetitions,
    init_cnn_binary,
    lr,
    model_path,
    history_path,
    epochs,
    train_loader,
    val_loader,
    test_loader,
    device,
    tolerance=tolerance,
    weight_decay=weight_decay,
    label_smoothing=label_smoothing,
)

training iteration: 1 of 4
starting training...
epoch: 1, training loss: 0.000677801558973645, training accuracy: 51.44456623864704, training balanced accuracy: 51.44461379113672
epoch: 1, validation loss: 0.0007163865508735793, validation accuracy: 49.07325684024713, validation balanced accuracy: 53.927656467923704
model saved

epoch: 2, training loss: 0.0006770348491953938, training accuracy: 52.403695584090194, training balanced accuracy: 52.398821742415855
epoch: 2, validation loss: 0.0007207519203538438, validation accuracy: 47.74933804060018, validation balanced accuracy: 53.82883795998664

epoch: 3, training loss: 0.0006747268691000952, training accuracy: 53.562480425931724, training balanced accuracy: 53.464036979752485
epoch: 3, validation loss: 0.0007029189787110219, validation accuracy: 56.29596940276552, validation balanced accuracy: 51.99046344859769
model saved




KeyboardInterrupt

