In [1]:
import glob
import csv
import sys, os.path

import sklearn
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision import transforms
#from torch.utils.tensorboard import SummaryWriter
#from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Our modules
import sys
sys.path.append('.')
sys.path.append('..')

from vae import configs, train, plot_utils, models
from vae.data import build_dataloader
from vae.latent_spaces import dimensionality_reduction, plot_spaces
from vae.reconstructions import plot_reconstructions
from vae.models import model_utils

In [3]:
# Target class names for the supervised model, grouped by instrument family.
# The order matters: labels_list (below) assigns one integer per position.
classes = [
    # bowed strings
    'violin', 'viola', 'cello', 'double-bass',
    # woodwinds
    'clarinet', 'bass-clarinet', 'saxophone', 'flute', 'oboe', 'bassoon',
    'contrabassoon', 'cor-anglais',
    # brass (plus 'english-horn', kept at its original position)
    'french-horn', 'trombone', 'trumpet', 'tuba', 'english-horn',
    # plucked strings
    'guitar', 'mandolin', 'banjo',
    # catch-all percussion class
    'chromatic-percussion',
]

# Individual percussion instrument names.
# NOTE(review): presumably these are the instruments folded into the
# 'chromatic-percussion' class above; confirm against the data loader.
chromatic_perc = [
    'agogo-bells', 'banana shaker', 'bass drum', 'bell-tree', 'cabasa',
    'castanets', 'chinese-cymbal', 'clash-cymbals', 'cowbell', 'djembe',
    'djundjun', 'flexatone', 'guiro', 'lemon-shaker', 'motor-horn',
    'ratchet', 'sheeps-toenails', 'sizzle-cymbal', 'sleigh-bells',
    'snare-drum', 'spring-coil', 'squeaker', 'strawberry-shaker', 'surdo',
    'suspended-cymbal', 'swanee-whistle', 'tambourine', 'tam-tam',
    'tenor drum', 'thai gong', 'tom-toms', 'train-whistle', 'triangle',
    'vibraslap', 'washboard', 'whip', 'wind-chimes', 'woodblock',
]

# Integer label for every entry of `classes`: 0 .. len(classes) - 1.
labels_list = list(range(len(classes)))


In [4]:
# Display the integer labels generated for the entries of `classes`.
labels_list

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

In [8]:
# Train the supervised timbre classifier with cross-entropy loss, tracking the
# per-epoch average training/validation loss and checkpointing periodically.

# --- run configuration ---
model_name = 'supervised_timbre'
# NOTE(review): `input` shadows the Python builtin. The name is kept so that
# any other cell referring to it keeps working; consider renaming notebook-wide.
input = 'mel_cut'
epochs = 100


train_dataset, train_dataloader, val_dataset, val_dataloader = build_dataloader.build_dataset(input, model_name)
print('Number of files in the training dataset:', len(train_dataset))
print('Number of files in the validation dataset:', len(val_dataset))

# show configs
configs.show_configs(model=model_name)

# import model
model = model_utils.import_model(model_name, input)

# show model
model_utils.show_model(model, input)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Every batch is moved to `device` below, so the model parameters must live on
# the same device. This is a no-op if import_model already placed the model.
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=configs.ParamsConfig.LEARNING_RATE, weight_decay=1e-5)

# Create the checkpoint directory once, up front. exist_ok avoids the
# check-then-create race and makedirs also handles nested paths.
os.makedirs(configs.ParamsConfig.TRAINED_MODELS_PATH, exist_ok=True)

print('Start training model on', device, '...')

# Per-epoch average losses; both lists are stored in every checkpoint.
train_loss = []
val_loss = []

for epoch in range(epochs):
    # ---------------- training pass ----------------
    model.train()

    running_train_loss = 0.0
    num_batches = 0

    print("Epoch:", epoch)

    # The context manager guarantees the bar is closed even when the cell is
    # interrupted mid-epoch (otherwise a stale bar is left in the output).
    with tqdm(total=len(train_dataloader)) as pbar:
        for image_batch, y in train_dataloader:
            image_batch = image_batch.to(device, dtype=torch.float)
            y = y.to(device, dtype=torch.long)

            # forward pass: model outputs scored against the integer targets
            y_train_pred = model(image_batch)

            # classification (cross-entropy) loss
            total_loss = criterion(y_train_pred, y)

            # backpropagation
            optimizer.zero_grad()
            total_loss.backward()

            # one step of the optimizer (using the gradients from backpropagation)
            optimizer.step()

            running_train_loss += total_loss.item()
            num_batches += 1

            pbar.update()

    # Mean over batches; NaN (instead of ZeroDivisionError) for an empty loader.
    train_loss.append(running_train_loss / num_batches if num_batches else float('nan'))

    print("training_avg_loss={:.2f}\n".format(train_loss[-1]))

    # ---------------- validation pass ----------------
    model.eval()
    running_val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for image_batch, y in val_dataloader:
            image_batch = image_batch.to(device, dtype=torch.float)
            y = y.to(device, dtype=torch.long)

            # forward pass only; gradients are disabled
            y_val_pred = model(image_batch)

            total_val_loss = criterion(y_val_pred, y)

            running_val_loss += total_val_loss.item()
            num_batches += 1

    val_loss.append(running_val_loss / num_batches if num_batches else float('nan'))

    print("val_avg_loss={:.2f}\n".format(val_loss[-1]))

    # Save the model every 10 epochs, and also after the last epoch: with
    # epochs=100 the final epoch is 99, which `epoch % 10 == 0` alone would
    # never save.
    if epoch % 10 == 0 or epoch == epochs - 1:
        checkpoint = {
            'model': model.state_dict(),
            'train loss': train_loss,
            'val loss': val_loss,
        }

        torch.save(checkpoint,
                   os.path.join(configs.ParamsConfig.TRAINED_MODELS_PATH, 'saved_model_' + str(epoch) + "epochs.pth"))

        plot_utils.plot_supervised_losses(model=model, model_name=model_name, trained_epochs=epoch)

  0%|                                                                                          | 0/727 [00:00<?, ?it/s]

Number of files in the training dataset: 11630
Number of files in the validation dataset: 2051
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
MultiheadAttention-1  [[-1, 2, 128], [-1, 22, 22]]               0
            Linear-2                   [-1, 64]         180,288
Total params: 180,288
Trainable params: 180,288
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.94
Params size (MB): 0.69
Estimated Total Size (MB): 1.64
----------------------------------------------------------------
None
Start training model on cuda ...
Epoch: 0


KeyboardInterrupt: 