In [1]:
%reload_ext autoreload
%autoreload 2

from utils.utils import *
from utils.wavegan import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
import pickle

device = "cuda" if torch.cuda.is_available() else "cpu"

GPU=True
if not GPU:
    device= 'cpu'
# Data Params
DATA_PATH='./eecs-audio-data/nsynth/nsynth-train/audio'
AUDIO_LENGTH = 16384 #[16384, 32768, 65536] 
SAMPLING_RATE = 16000
NORMALIZE_AUDIO = False 
CHANNELS = 1

#Model params
LATENT_NOISE_DIM = 100
MODEL_CAPACITY=64
LAMBDA_GP = 10

#Training params
TRAIN_DISCRIM = 5 # how many times to train the discriminator for one generator step
EPOCHS = 500
BATCH_SIZE=64
LR_GEN = 1e-4
LR_DISC = 1e-4 # alternative is bigger lr instead of high TRAIN_DISCRIM
BETA1 = 0.5
BETA2 = 0.9

In [2]:
# Dataset and Dataloader
train_set = AudioDataset(DATA_PATH,sample_rate=SAMPLING_RATE,number_samples=AUDIO_LENGTH,extension='wav',std=NORMALIZE_AUDIO)
print(train_set.__len__())
train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=6)

272185


  cpuset_checked))


In [3]:
#generator and discriminator
wave_gen = WaveGenerator(d=MODEL_CAPACITY, c=CHANNELS ,inplace=True).to(device)
wave_disc = WaveDiscriminator(d=MODEL_CAPACITY, c=CHANNELS ,inplace=True).to(device)

#random weights init
initialize_weights(wave_gen)
initialize_weights(wave_disc)
wave_gen
wave_disc
wave_gen.train()
wave_disc.train()

#Adam optim for both generator iand discriminator
optimizer_gen = optim.Adam(wave_gen.parameters(), lr=LR_GEN, betas=(BETA1, BETA2))
optimizer_disc = optim.Adam(wave_disc.parameters(), lr=LR_DISC, betas=(BETA1, BETA2))

In [None]:
#training
import pickle
step = 0
hist=[]
for epoch in range(EPOCHS):
    with tqdm(train_loader, unit="batch") as tepoch: 
        for batch_id, real_audio in enumerate(tepoch):  
            tepoch.set_description(f"Epoch {epoch}")
            real_audio = real_audio.to(device)
            
            #Train Discriminator 
            for train_step in range(TRAIN_DISCRIM):
                noise = torch.randn(real_audio.shape[0], LATENT_NOISE_DIM).to(device)
                #print(noise.shape)
                fake_audio = wave_gen(noise)
                disc_real = wave_disc(real_audio).reshape(-1)
                disc_fake = wave_disc(fake_audio).reshape(-1)
                loss_disc = wasserstein_loss(wave_disc, real_audio, fake_audio,device,LAMBDA = LAMBDA_GP)
                wave_disc.zero_grad()
                loss_disc.backward(retain_graph=True)
                optimizer_disc.step()

            # Train the generator!
            all_wasserstein = wave_disc(fake_audio).reshape(-1)
            loss = -torch.mean(all_wasserstein)
            wave_gen.zero_grad()
            loss.backward()
            optimizer_gen.step()

            # Print progress, save stats, and save model
            hist.append([loss.item(),loss_disc.item()])
            if batch_id % 10 == 0 and batch_id > 0:
                tepoch.set_postfix(gen_loss=loss.item(), disc_loss=loss_disc.item())
            
        torch.save(wave_gen.state_dict(), './save/wavegen/gen_'+str(step)+'_'+str(epoch)+'_'+str(batch_id)+'.pt')
        torch.save(wave_disc.state_dict(), './save/wavedisc/wave_'+str(step)+'_'+str(epoch)+'_'+str(batch_id)+'.pt')
        with torch.no_grad():
            fake = wave_gen(noise)
            torch.save(fake, './save/wavefake/fake_'+str(step)+'_'+str(epoch)+'_'+str(batch_id)+'.pt')
                    

            step += 1

Epoch 0: 100%|█████████████████████████████████████| 94/94 [02:34<00:00,  1.64s/batch, disc_loss=-6.34, gen_loss=-17.1]
Epoch 1: 100%|█████████████████████████████████████| 94/94 [02:31<00:00,  1.61s/batch, disc_loss=-9.01, gen_loss=-17.2]
Epoch 2: 100%|█████████████████████████████████████| 94/94 [02:33<00:00,  1.63s/batch, disc_loss=-7.43, gen_loss=-5.67]
Epoch 3: 100%|█████████████████████████████████████| 94/94 [02:33<00:00,  1.63s/batch, disc_loss=-7.67, gen_loss=-4.14]
Epoch 4: 100%|████████████████████████████████████████| 94/94 [02:33<00:00,  1.63s/batch, disc_loss=-9.6, gen_loss=5.4]
Epoch 5: 100%|█████████████████████████████████████| 94/94 [02:33<00:00,  1.63s/batch, disc_loss=-7.49, gen_loss=-6.21]
Epoch 6: 100%|█████████████████████████████████████| 94/94 [02:33<00:00,  1.64s/batch, disc_loss=-5.84, gen_loss=0.804]
Epoch 7: 100%|█████████████████████████████████████| 94/94 [02:34<00:00,  1.64s/batch, disc_loss=-4.89, gen_loss=-.241]
Epoch 8: 100%|██████████████████████████