In [1]:
%load_ext autoreload
%autoreload 2
import json
from CVAE import Resnet1DBlock, CVAE, calculate_output_shape_convtranspose
from music_dataset import TorchMusicDataset
import torch.nn as nn
import numpy as np

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from scipy.io.wavfile import write

In [2]:
# Audio settings, passed positionally to TorchMusicDataset in a later cell.
N_SECONDS = 2  # presumably the clip length in seconds — confirm against TorchMusicDataset
DOWNSAMPLE_RATIO = 20  # the generated audio is written at 44100 // DOWNSAMPLE_RATIO Hz

In [3]:
# Training hyperparameters.
EPOCHS = 300  # number of full passes over the dataloader in the training loop
BATCH_SIZE = 2  # used by the DataLoader and also passed to the CVAE constructor
LATENT_DIM = 2  # first argument to CVAE; presumably the latent-space size — confirm in CVAE.py

In [4]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Register DEVICE as torch's default device for this session.
torch.set_default_device(DEVICE)

In [5]:
# Build the dataset from the cleaned audio directory and wrap it in a
# shuffling DataLoader.
dataset = TorchMusicDataset("../cleaned_data", N_SECONDS, DOWNSAMPLE_RATIO)
# The sampler's generator must live on the same device that
# torch.set_default_device() selected. Hard-coding 'cuda' here fails on
# machines without a GPU, where DEVICE falls back to the CPU.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator(device=DEVICE),
)
# Read the per-sample shape from one item (index 1, as before).
n_features, n_timesteps = dataset[1].shape
n_features, n_timesteps

(1, 4410)

In [6]:
# Pull a single batch from the dataloader and stop; `batch` stays bound to it.
# NOTE(review): `batch` is not used by any later cell shown here — this looks
# like a leftover sanity check and could probably be removed.
for batch in dataloader:
    break

In [7]:
# Instantiate the model, then move it to the selected device.
cvae = CVAE(LATENT_DIM, DEVICE, n_timesteps, BATCH_SIZE)
cvae = cvae.to(DEVICE)
# Attach an Adam optimizer (learning rate 3e-4) to the model as `cvae.optimizer`.
cvae.optimizer = optim.Adam(cvae.parameters(), lr=3e-4)

In [8]:
%%time
losses = []
for epoch in range(EPOCHS):
    n_steps = 0
    epoch_loss = 0
    for batch_x in dataloader:
        loss = cvae.training_step(batch_x)
        epoch_loss += loss
        n_steps += 1
    epoch_loss /= n_steps
    print(epoch, epoch_loss)
    losses.append(epoch_loss)
    with open("loss.json", "w") as f:
        json.dump(losses, f)

  return func(*args, **kwargs)


0 0.4234396821260452
1 0.26102894350886346
2 0.17971415422856807
3 0.19547316752374172
4 0.18429199293255805
5 0.19280827149748803
6 0.1600872442126274
7 0.18140337064862253
8 0.16676100827753543
9 0.19960481844842434
10 0.19678015895187856
11 0.14810525439679623
12 0.17210095398128034
13 0.1861790668964386
14 0.1858897364884615
15 0.1699416770040989
16 0.16080526750534774
17 0.18415784694254397
18 0.1939677257835865
19 0.1917008064687252
20 0.18411960527300836
21 0.2163825024664402
22 0.18408796913921832
23 0.18245952300727367
24 0.20931226000189782
25 0.1864846894145012
26 0.18173902556300164
27 0.17388261444866657
28 0.18959308102726935
29 0.16258222706615924
30 0.17344501614570618
31 0.19687197580933571
32 0.1929529318213463
33 0.1768783188611269
34 0.19287771806120874
35 0.20295538626611231
36 0.18366374909877778
37 0.20159560821950437
38 0.19222786702215672
39 0.19800915598869323
40 0.1661548388749361
41 0.20432084575295448
42 0.19686425276100636
43 0.19122443109750747
44 0.18449

In [9]:
# Generate sequences from the model. The argument 2 is presumably the number
# of sequences requested (the displayed shape is (2, 4410)) — confirm in CVAE.py.
generated_sequences = cvae.generate_sequences(2)
generated_sequences.shape

(2, 4410)

In [10]:
# Write the first generated sequence to a WAV file.
output_filename = 'gen_audio_VAE_LSTM.wav'
# Sample rate written to the WAV header: 44100 divided by the downsample ratio.
sample_rate = 44100 // DOWNSAMPLE_RATIO
first_sequence = generated_sequences[0]
write(output_filename, sample_rate, first_sequence)

print(f"Audio saved as {output_filename}")

Audio saved as gen_audio_VAE_LSTM.wav
