In [2]:
# Alternative dataset (not used below): local TIMIT copy at TIMIT_ROOT\data\lisa\data\timit\raw\TIMIT\TRAIN

%reload_ext autoreload
%autoreload 2

from utils.utils import *
from utils.wavegan import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
import pickle

# Default compute device: CUDA when a GPU is visible to torch, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [18]:
GPU = True  # set to False to force CPU even when CUDA is available
if not GPU:
    device = 'cpu'

# Reproducibility: seed torch's CPU and CUDA generators (e.g. latent noise and
# DataLoader shuffling) so runs can be repeated.
SEED = 0
torch.manual_seed(SEED)

# Data params
DATA_PATH = './data/piano'
AUDIO_LENGTH = 16384  # passed as number_samples; options: [16384, 32768, 65536]
SAMPLING_RATE = 16000
NORMALIZE_AUDIO = False
CHANNELS = 1

# Model params
LATENT_NOISE_DIM = 100
MODEL_CAPACITY = 64
LAMBDA_GP = 10  # gradient-penalty coefficient, passed to wasserstein_loss as LAMBDA

# Training params
TRAIN_DISCRIM = 5  # how many times to train the discriminator for one generator step
EPOCHS = 500
BATCH_SIZE = 32
LR_GEN = 1e-4
LR_DISC = 1e-4  # alternative is bigger lr instead of high TRAIN_DISCRIM
BETA1 = 0.5
BETA2 = 0.9


# Dataset and Dataloader

# Preload every clip onto `device` (VRAM when running on CUDA, CPU RAM otherwise)
train_set = AudioDataset_ram(DATA_PATH,sample_rate=SAMPLING_RATE,number_samples=AUDIO_LENGTH,extension='wav',std=NORMALIZE_AUDIO,device=device)

# Alternative: preload into CPU RAM
#train_set = AudioDataset_ram(DATA_PATH,sample_rate=SAMPLING_RATE,number_samples=AUDIO_LENGTH,extension='wav',std=NORMALIZE_AUDIO,device='cpu')

# Alternative: read clips from disk during training
#train_set = AudioDataset(DATA_PATH,sample_rate=SAMPLING_RATE,number_samples=AUDIO_LENGTH,extension='wav',std=NORMALIZE_AUDIO,start_only=False)

print(len(train_set))


loading sample 1420: 100%|█████████████████████████████████████████████████████| 1426/1426 [00:48<00:00, 29.19sample/s]


1426


In [19]:
from IPython.display import Audio, display
import torchaudio
import torch
import numpy as np
# NOTE(review): torchaudio / numpy are not used in this cell; move them to the
# top import cell (or drop them) if no later cell needs them.

# Grab one training example to listen to. SAMPLING_RATE is taken from the
# config cell and deliberately not redefined here, so the two cannot drift apart.
audio_batch = train_set[0].cpu()

def audio_player(audio_batch, autoplay=False, rate=None):
    """Display one inline IPython audio player per row of `audio_batch`.

    Args:
        audio_batch: torch tensor whose first dimension is iterated; each
            `audio_batch[i]` becomes one player. A 1-D tensor is treated as a
            single clip. Tensors on the GPU or that require grad (e.g. raw
            generator output) are accepted: they are detached and copied to
            CPU before conversion to numpy.
        autoplay: forwarded to IPython's `Audio` widget.
        rate: sample rate in Hz; defaults to the global SAMPLING_RATE.
    """
    if rate is None:
        rate = SAMPLING_RATE
    # .numpy() refuses CUDA tensors and tensors that require grad.
    clips = audio_batch.detach().cpu()
    if clips.dim() == 1:
        clips = clips.unsqueeze(0)
    for clip in clips:
        display(Audio(clip.numpy(), rate=rate, autoplay=autoplay))


audio_player(audio_batch)


In [22]:
train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=0)  # if loading into vram set num_workers to 0

# Generator and discriminator
wave_gen = WaveGenerator(d=MODEL_CAPACITY, c=CHANNELS, inplace=True).to(device)
wave_disc = WaveDiscriminator(d=MODEL_CAPACITY, c=CHANNELS, inplace=True).to(device)

# Random weight initialisation (helper from the star-imported utils modules)
initialize_weights(wave_gen)
initialize_weights(wave_disc)
wave_gen.train()
wave_disc.train()

# Adam optimizers for both generator and discriminator
optimizer_gen = optim.Adam(wave_gen.parameters(), lr=LR_GEN, betas=(BETA1, BETA2))
optimizer_disc = optim.Adam(wave_disc.parameters(), lr=LR_DISC, betas=(BETA1, BETA2))

# Optional resume. The training cell saves weights as
#   ./save/wavegen/gen_<epoch>_<step>.pt and ./save/wavedisc/wave_<epoch>_<step>.pt
# Set `start` to the <epoch> of the checkpoint to load (-1 = train from scratch)
# and RESUME_STEP to its <step> suffix.
start = -1
RESUME_STEP = 93
if start >= 0:  # epoch 0 checkpoints are valid too
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    wave_disc.load_state_dict(torch.load('./save/wavedisc/wave_' + str(start) + '_' + str(RESUME_STEP) + '.pt', map_location=device))
    wave_gen.load_state_dict(torch.load('./save/wavegen/gen_' + str(start) + '_' + str(RESUME_STEP) + '.pt', map_location=device))

In [None]:
# Training loop: for every real batch the discriminator is updated
# TRAIN_DISCRIM times with wasserstein_loss, then the generator is updated once.
# At the end of every epoch the loss history, both models' weights and a batch
# of generated audio are written under ./save/.
import os
import pickle

# open() and torch.save() do not create missing folders; create them up front.
for save_dir in ('./save/wavehist', './save/wavegen', './save/wavedisc', './save/wavefake'):
    os.makedirs(save_dir, exist_ok=True)

step = start + 1  # for restart from saved weights
# NOTE(review): `start` is the epoch index used to pick a checkpoint in the
# previous cell, while `step` counts generator updates, and `epoch` below
# restarts at 0 on resume -- confirm this numbering is intended.
hist = []  # one [generator_loss, discriminator_loss] pair per generator update

# Fixed latent batch so the audio saved each epoch is comparable across epochs.
monitor_noise = torch.randn(BATCH_SIZE, LATENT_NOISE_DIM, device=device)

for epoch in range(EPOCHS):
    with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch}") as tepoch:
        for batch_id, real_audio in enumerate(tepoch):
            real_audio = real_audio.to(device)
            n_real = real_audio.shape[0]  # the last batch may be smaller than BATCH_SIZE

            # Train discriminator. Each iteration builds a fresh generator graph,
            # so a plain backward() is enough (no retain_graph needed).
            for _ in range(TRAIN_DISCRIM):
                noise = torch.randn(n_real, LATENT_NOISE_DIM, device=device)
                fake_audio = wave_gen(noise)
                loss_disc = wasserstein_loss(wave_disc, real_audio, fake_audio, device, LAMBDA=LAMBDA_GP)
                optimizer_disc.zero_grad()
                loss_disc.backward()
                optimizer_disc.step()

            # Train generator on a freshly generated batch. fake_audio above is
            # not detached, so any generator gradients left by the critic steps
            # are cleared by optimizer_gen.zero_grad() before this backward.
            noise = torch.randn(n_real, LATENT_NOISE_DIM, device=device)
            fake_audio = wave_gen(noise)
            critic_on_fake = wave_disc(fake_audio).reshape(-1)
            loss = -torch.mean(critic_on_fake)
            optimizer_gen.zero_grad()
            loss.backward()
            optimizer_gen.step()
            step += 1

            # Record progress
            hist.append([loss.item(), loss_disc.item()])
            if batch_id % 5 == 0 and batch_id > 0:
                tepoch.set_postfix(gen_loss=loss.item(), disc_loss=loss_disc.item())

    # End of epoch: save loss history, weights and a batch of generated audio.
    with open('./save/wavehist/hist_' + str(step) + '_' + str(epoch) + '_' + str(batch_id) + '.pkl', 'wb') as f:
        pickle.dump(hist, f)
    torch.save(wave_gen.state_dict(), './save/wavegen/gen_' + str(epoch) + '_' + str(step) + '.pt')
    torch.save(wave_disc.state_dict(), './save/wavedisc/wave_' + str(epoch) + '_' + str(step) + '.pt')
    with torch.no_grad():
        fake = wave_gen(monitor_noise)
    torch.save(fake, './save/wavefake/fake_' + str(epoch) + '_' + str(step) + '.pt')


Epoch 0: 100%|███████████████████████████████████████| 45/45 [01:18<00:00,  1.74s/batch, disc_loss=4.6, gen_loss=-8.08]
Epoch 1: 100%|█████████████████████████████████████| 45/45 [01:10<00:00,  1.57s/batch, disc_loss=0.877, gen_loss=-12.5]
Epoch 2: 100%|██████████████████████████████████████| 45/45 [01:08<00:00,  1.52s/batch, disc_loss=-2.98, gen_loss=4.39]
Epoch 3: 100%|██████████████████████████████████████| 45/45 [01:08<00:00,  1.52s/batch, disc_loss=-1.45, gen_loss=16.6]
Epoch 4: 100%|█████████████████████████████████████| 45/45 [01:08<00:00,  1.52s/batch, disc_loss=-16.8, gen_loss=-43.9]
Epoch 5: 100%|█████████████████████████████████████| 45/45 [01:08<00:00,  1.51s/batch, disc_loss=0.465, gen_loss=-48.3]
Epoch 6: 100%|█████████████████████████████████████| 45/45 [01:08<00:00,  1.51s/batch, disc_loss=-2.92, gen_loss=-7.86]
Epoch 7: 100%|██████████████████████████████████████| 45/45 [01:07<00:00,  1.51s/batch, disc_loss=0.339, gen_loss=2.18]
Epoch 8: 100%|██████████████████████████

Epoch 68: 100%|████████████████████████████████████| 45/45 [01:07<00:00,  1.51s/batch, disc_loss=-2.85, gen_loss=0.667]
Epoch 69: 100%|████████████████████████████████████| 45/45 [01:07<00:00,  1.51s/batch, disc_loss=-2.75, gen_loss=0.417]
Epoch 70: 100%|█████████████████████████████████████| 45/45 [01:07<00:00,  1.50s/batch, disc_loss=-2.69, gen_loss=1.01]
Epoch 71: 100%|████████████████████████████████████| 45/45 [01:08<00:00,  1.51s/batch, disc_loss=-2.72, gen_loss=-1.16]
Epoch 72: 100%|█████████████████████████████████████| 45/45 [01:07<00:00,  1.51s/batch, disc_loss=-2.4, gen_loss=0.901]
Epoch 73: 100%|████████████████████████████████████| 45/45 [01:07<00:00,  1.50s/batch, disc_loss=-2.66, gen_loss=-.694]
Epoch 74: 100%|████████████████████████████████████| 45/45 [01:07<00:00,  1.51s/batch, disc_loss=-2.88, gen_loss=0.473]
Epoch 75: 100%|████████████████████████████████████| 45/45 [01:07<00:00,  1.51s/batch, disc_loss=-2.74, gen_loss=-.375]
Epoch 76: 100%|█████████████████████████

Epoch 136: 100%|███████████████████████████████████| 45/45 [01:07<00:00,  1.51s/batch, disc_loss=-2.52, gen_loss=0.713]
Epoch 137: 100%|███████████████████████████████████| 45/45 [01:07<00:00,  1.51s/batch, disc_loss=-2.56, gen_loss=-.243]
Epoch 138: 100%|███████████████████████████████████| 45/45 [01:08<00:00,  1.51s/batch, disc_loss=-2.51, gen_loss=0.945]
Epoch 139: 100%|███████████████████████████████████| 45/45 [01:08<00:00,  1.51s/batch, disc_loss=-2.65, gen_loss=-.284]
Epoch 140: 100%|███████████████████████████████████| 45/45 [01:08<00:00,  1.51s/batch, disc_loss=-2.61, gen_loss=0.505]
Epoch 141: 100%|███████████████████████████████████| 45/45 [01:07<00:00,  1.51s/batch, disc_loss=-2.61, gen_loss=-.111]
Epoch 142: 100%|████████████████████████████████████| 45/45 [01:08<00:00,  1.52s/batch, disc_loss=-2.39, gen_loss=0.44]
Epoch 143: 100%|███████████████████████████████████| 45/45 [01:08<00:00,  1.51s/batch, disc_loss=-1.93, gen_loss=-.762]
Epoch 144: 100%|████████████████████████

Epoch 204: 100%|███████████████████████████████████| 45/45 [01:09<00:00,  1.54s/batch, disc_loss=-2.35, gen_loss=0.505]
Epoch 205: 100%|███████████████████████████████████| 45/45 [01:08<00:00,  1.52s/batch, disc_loss=-2.72, gen_loss=0.474]
Epoch 206: 100%|███████████████████████████████████| 45/45 [01:08<00:00,  1.52s/batch, disc_loss=-3.08, gen_loss=-.383]
Epoch 207: 100%|████████████████████████████████████| 45/45 [01:09<00:00,  1.55s/batch, disc_loss=-2.87, gen_loss=1.27]
Epoch 208: 100%|███████████████████████████████████| 45/45 [01:09<00:00,  1.54s/batch, disc_loss=-2.72, gen_loss=0.705]
Epoch 209: 100%|████████████████████████████████████| 45/45 [01:08<00:00,  1.52s/batch, disc_loss=-2.79, gen_loss=1.42]
Epoch 210: 100%|██████████████████████████████████| 45/45 [01:08<00:00,  1.52s/batch, disc_loss=-2.27, gen_loss=0.0644]
Epoch 211: 100%|████████████████████████████████████| 45/45 [01:08<00:00,  1.51s/batch, disc_loss=-2.52, gen_loss=1.26]
Epoch 212: 100%|████████████████████████