In [1]:
from CNN import AudioClassifier
import os
import torch
import random
import numpy as np
from torch.utils.data import DataLoader
from custom_dataset import SpectrogramDataset
from training_pipeline import repeat_training

In [2]:
# Seed every RNG the pipeline draws from (Python `random`, NumPy, PyTorch) so
# that e.g. weight init and DataLoader shuffling repeat under Restart & Run All.
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed seeds the CPU generator (and, in current PyTorch, all
# devices); manual_seed_all makes the multi-GPU case explicit, whereas
# torch.cuda.manual_seed only seeds the *current* GPU. Calling it without CUDA
# is safe (silently ignored).
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# cuDNN may choose non-deterministic convolution algorithms and autotune them
# differently from run to run; pin both for reproducible results. This can
# cost some GPU throughput -- relax these flags if speed matters more.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Build the train/val/test spectrogram datasets from one directory; the split
# is selected via SpectrogramDataset's set_type argument.
data_path = "data/train/audio_transformed"
train_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.TRAIN)
val_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.VAL)
test_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.TEST)

batch_size = 1024
n_workers = 4
# DataLoader rejects a non-None prefetch_factor and persistent_workers=True
# when num_workers == 0, so both are derived from the worker count.
prefetch_factor = 2 if n_workers > 0 else None
persistent_workers = n_workers > 0
# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only
# machines (DataLoader otherwise warns that pinning is unused).
pin_memory = torch.cuda.is_available()


def make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader using the shared loader settings above.

    Args:
        dataset: map-style dataset to batch (here a SpectrogramDataset).
        shuffle: reshuffle every epoch -- True for training, False for eval.

    Returns:
        A torch.utils.data.DataLoader over *dataset*.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=n_workers,
        pin_memory=pin_memory,
        prefetch_factor=prefetch_factor,
        persistent_workers=persistent_workers,
    )


train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

In [4]:
# Training configuration: `repetitions` runs of repeat_training, each for
# `epochs` epochs at learning rate `lr` on `device`. The captured log below
# shows one "training iteration: i of N" header per run.
repetitions = 4
lr = 0.001
epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_dir = "output/models/cnn"
history_dir = "output/history/cnn"

os.makedirs(model_dir, exist_ok=True)
os.makedirs(history_dir, exist_ok=True)

# os.path.join uses the platform separator instead of a hard-coded "/".
model_path = os.path.join(model_dir, "cnn.pth")
history_path = os.path.join(history_dir, "cnn.pkl")

# NOTE(review): every repetition receives the same model_path/history_path;
# whether repeat_training suffixes them per run or overwrites them is not
# visible here -- confirm in training_pipeline.
# Arguments are positional, so their order must match repeat_training's signature.
repeat_training(repetitions, AudioClassifier, lr, model_path, history_path, epochs, train_loader, val_loader, test_loader, device)

training iteration: 1 of 4
starting training...
epoch: 1, training loss: 0.00155712549790222, training accuracy: 62.51565925461948
epoch: 1, validation loss: 0.0015782962023563054, validation accuracy: 62.09179170344219
model saved

epoch: 2, training loss: 0.0014336691930717411, training accuracy: 63.713592233009706
epoch: 2, validation loss: 0.0014897912135717623, validation accuracy: 62.09179170344219
model saved

epoch: 3, training loss: 0.0013390097485938641, training accuracy: 63.82516442217351
epoch: 3, validation loss: 0.0013213789736603806, validation accuracy: 62.68020005884084
model saved

epoch: 4, training loss: 0.001225963864922337, training accuracy: 64.17554024428438
epoch: 4, validation loss: 0.0012701942083168534, validation accuracy: 63.10679611650485
model saved

epoch: 5, training loss: 0.0011721997869787565, training accuracy: 64.69425305355465
epoch: 5, validation loss: 0.0011908375680288802, validation accuracy: 64.13651073845249
model saved

epoch: 6, training 