In [1]:
from transformer import SpeechTransformer
import os
import torch
import random
import numpy as np
from torch.utils.data import DataLoader
from custom_dataset import SpectrogramDataset
from training_pipeline import repeat_training

In [2]:
# Seed every RNG the pipeline touches so repeated runs are comparable.
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed seeds the CPU generator and all CUDA devices. The explicit
# CUDA call uses manual_seed_all (silently ignored without CUDA) rather than
# torch.cuda.manual_seed, which only seeds the current device.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Seeding alone does not make cuDNN reproducible: force deterministic kernels
# and disable the autotuner, which may choose different algorithms per run.
# This can slow training; relax both flags if speed matters more than
# bit-for-bit reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [9]:
# Build the train / validation / test splits from the preprocessed
# spectrograms and wrap each one in a DataLoader. The three loaders share the
# same batching, worker and memory settings, so those live in one dict.
data_path = "data/train/audio_transformed"
train_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.TRAIN)
val_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.VAL)
test_dataset = SpectrogramDataset(data_path, set_type=SpectrogramDataset.TEST)

batch_size = 512
n_workers = 4
# prefetch_factor and persistent_workers only apply when worker processes are
# used; passing prefetch_factor=None with n_workers == 0 relies on the
# torch >= 2.0 DataLoader signature.
prefetch_factor = 2 if n_workers > 0 else None
persistent_workers = n_workers > 0

loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=n_workers,
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only
    # machines instead of requesting it and having it ignored.
    pin_memory=torch.cuda.is_available(),
    prefetch_factor=prefetch_factor,
    persistent_workers=persistent_workers,
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [10]:
# Training configuration for the repeated SpeechTransformer runs.
repetitions = 4
lr = 0.001
epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_dir = "output/models/transformer"
history_dir = "output/history/transformer"

os.makedirs(model_dir, exist_ok=True)
os.makedirs(history_dir, exist_ok=True)

# os.path.join picks the correct separator instead of hard-coding "/".
model_path = os.path.join(model_dir, "transformer.pth")
history_path = os.path.join(history_dir, "transformer.pkl")

# NOTE(review): in the run recorded in this notebook, validation accuracy was
# exactly 62.0918% for every logged epoch and training accuracy stuck at
# 63.7136% from epoch 2 on. That looks like the model collapsing to a single
# prediction; consider a lower lr or a warmup schedule before spending all
# repetitions. TODO investigate.
# Arguments are passed positionally in the order repeat_training (from
# training_pipeline) expects.
repeat_training(repetitions, SpeechTransformer, lr, model_path, history_path, epochs, train_loader, val_loader, test_loader, device)

training iteration: 1 of 4
starting training...
epoch: 1, training loss: 0.0029692551931392577, training accuracy: 63.14594425305356
epoch: 1, validation loss: 0.003057063743800898, validation accuracy: 62.09179170344219
model saved

epoch: 2, training loss: 0.002922780541111002, training accuracy: 63.713592233009706
epoch: 2, validation loss: 0.003053391723290511, validation accuracy: 62.09179170344219
model saved

epoch: 3, training loss: 0.002924942224517796, training accuracy: 63.713592233009706
epoch: 3, validation loss: 0.003064874437985332, validation accuracy: 62.09179170344219

epoch: 4, training loss: 0.0029233512793855163, training accuracy: 63.713592233009706
epoch: 4, validation loss: 0.0030574091841943755, validation accuracy: 62.09179170344219

epoch: 5, training loss: 0.0028905911768375052, training accuracy: 63.713592233009706
epoch: 5, validation loss: 0.0029639611615262054, validation accuracy: 62.09179170344219
model saved

epoch: 6, training loss: 0.002843595614188

KeyboardInterrupt: 