In [1]:
%load_ext autoreload
%autoreload 2

import os
import json
from Pytorch_VAE_LSTM import LSTMVAE
from music_dataset import TorchMusicDataset

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt
import numpy as np

import scipy.io.wavfile as wavfile

In [2]:
# Training / model hyperparameters used by the cells below.
epochs = 400
batch_size = 4
hidden_dim = 100  # passed positionally to LSTMVAE
latent_dim = 100  # passed positionally to LSTMVAE
n_seconds = 10    # passed to TorchMusicDataset; presumably clip length in seconds -- TODO confirm

# Seed the RNGs so the DataLoader shuffle order and weight initialisation are
# repeatable across Restart & Run All (GPU LSTM kernels may still be
# nondeterministic).
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [3]:
# Output directory name (relative to the working directory); the next cell creates it.
# NOTE(review): no other cell shown here writes into this folder -- confirm it is still needed.
evolution_folder = "evolution"

In [4]:
# Ensure the output directory exists. makedirs(exist_ok=True) avoids the
# check-then-create race of exists() + mkdir() and also creates any missing
# parent directories; it raises if the path exists but is not a directory.
os.makedirs(evolution_folder, exist_ok=True)

In [5]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Load the cleaned data and wrap it in a shuffling, batching DataLoader.
dataset = TorchMusicDataset("../cleaned_data", n_seconds)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Infer the model's input geometry from one sample. Index 0 works for any
# non-empty dataset (index 1 would fail on a single-item dataset); samples are
# presumably all the same shape, since the DataLoader stacks them into batches.
# NOTE(review): the recorded output is (44100, 11) with n_seconds = 10; 44100 is
# one second at 44.1 kHz -- confirm the dataset's time axis / sample rate.
n_timesteps, n_features = dataset[0].shape[:2]
n_timesteps, n_features

(44100, 11)

In [7]:
# Shape of the sample at index 1 (indexing invokes __getitem__).
dataset[1].shape

torch.Size([44100, 11])

In [8]:
# Build the LSTM-VAE, move it to the selected device, and attach an Adam
# optimizer with default settings as `vae.optimizer` (presumably consumed by
# vae.training_step -- TODO confirm).
vae = LSTMVAE(n_timesteps, n_features, hidden_dim, latent_dim, device)
vae = vae.to(device)
vae.optimizer = optim.Adam(params=vae.parameters())

In [9]:
%%time
epoch_losses = []
for epoch in range(epochs):
    epoch_loss = 0
    n_steps = 0
    for batch_x in dataloader:
        batch_loss = vae.training_step(batch_x.to(device))
        epoch_loss += batch_loss
        n_steps += 1
    epoch_loss = epoch_loss/n_steps
    epoch_losses.append(epoch_loss)
    print(f"Epoch {epoch} loss: {epoch_loss}")
    if (epoch % 10) == 0:
        torch.save(vae.state_dict(), "model.torch")
        with open("losses.json", "w") as f:
            json.dump(epoch_losses, f)

Epoch 0 loss: 0.056406800326492104
Epoch 1 loss: 0.04300870787352323
Epoch 2 loss: 0.04201727291835206
Epoch 3 loss: 0.04189609006579433
Epoch 4 loss: 0.04187817077285477
Epoch 5 loss: 0.041873248355197054
Epoch 6 loss: 0.04187262002910887
Epoch 7 loss: 0.04187326469059501
Epoch 8 loss: 0.04187192729541234
Epoch 9 loss: 0.041871373594871586
Epoch 10 loss: 0.04187207324164254
Epoch 11 loss: 0.04187260731256434
Epoch 12 loss: 0.04187198148242065
Epoch 13 loss: 0.041871180733931916
Epoch 14 loss: 0.04187163289902466
Epoch 15 loss: 0.04187336938721793
Epoch 16 loss: 0.041871178881930454
Epoch 17 loss: 0.041872623222214835
Epoch 18 loss: 0.04187199448102287
Epoch 19 loss: 0.04187090148351022
Epoch 20 loss: 0.041873045153915885
Epoch 21 loss: 0.041868849407349314
Epoch 22 loss: 0.041871445415807625
Epoch 23 loss: 0.041870018914341924
Epoch 24 loss: 0.04187108224257827
Epoch 25 loss: 0.041871726151023594
Epoch 26 loss: 0.04187081505145345
Epoch 27 loss: 0.04187036700014557
Epoch 28 loss: 0.04

KeyboardInterrupt: 

In [12]:
# Persist the model weights to the working directory after training was interrupted.
torch.save(obj=vae.state_dict(), f="model.torch")

In [13]:
# Persist the per-epoch mean losses collected by the training loop as JSON.
with open("losses.json", mode="w") as losses_file:
    json.dump(obj=epoch_losses, fp=losses_file)

In [14]:
# Marker printed when execution reaches the final cell.
print("END", end="\n")

END
