In [1]:
from utils.utils import *
from utils.specgan import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
import pickle

# Pick the compute device once: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [2]:
# ---------------------------------------------------------------------------
# Run configuration: every tunable value of the notebook lives in this cell.
# ---------------------------------------------------------------------------

# Set GPU to False to force CPU execution even when CUDA was detected.
GPU = True
if not GPU:
    device = 'cpu'

# --- Data ---
DATA_PATH = './data/drums'
SAMPLING_RATE = 16_000
AUDIO_LENGTH = 16_384        # one of [16384, 32768, 65536]
CHANNELS = 1
NORMALIZE_AUDIO = False      # forwarded to the dataset as its `std` argument

# --- Model ---
MODEL_CAPACITY = 64          # passed as `d` to the generator / discriminator
LATENT_NOISE_DIM = 100       # width of the noise vector fed to the generator
LAMBDA_GP = 10               # passed as `LAMBDA` to wasserstein_loss

# --- Training ---
EPOCHS = 500
BATCH_SIZE = 26
TRAIN_DISCRIM = 5            # discriminator updates per generator update
LR_GEN = 1e-4
LR_DISC = 1e-4               # alternative: a larger lr instead of a high TRAIN_DISCRIM
BETA1 = 0.5                  # Adam betas, shared by both optimizers
BETA2 = 0.9


# Dataset and Dataloader
#
# Three ways of feeding the data were tried; only the second one is active.
#
# 1) load into vram:
#    train_set = AudioDataset_ram(DATA_PATH,sample_rate=SAMPLING_RATE,number_samples=AUDIO_LENGTH,extension='wav',std=NORMALIZE_AUDIO,device=device,spectrogram=True)
#
# 2) load into cpu ram (active):
train_set = AudioDataset_ram(
    DATA_PATH,
    sample_rate=SAMPLING_RATE,
    number_samples=AUDIO_LENGTH,
    extension='wav',
    std=NORMALIZE_AUDIO,
    device='cpu',
    spectrogram=True,
)

# 3) load from disk at training time:
#    train_set = AudioDataset(DATA_PATH,sample_rate=SAMPLING_RATE,number_samples=AUDIO_LENGTH,extension='wav',std=NORMALIZE_AUDIO,spectrogram=True)

# Report how many training items were loaded.
print(len(train_set))


loading sample 3000: 100%|█████████████████████████████████████████████████████| 3002/3002 [00:39<00:00, 76.78sample/s]


3002


In [6]:
import os

# Output folders used by the training cell: generator weights, discriminator
# weights, generated samples and the loss history.
SAVE_DIRS = (
    './save/specgen',
    './save/specdisc',
    './save/specfake',
    './save/spechist',
)

# exist_ok=True makes the cell idempotent and avoids the check-then-create race
# of a separate os.path.exists() test.
for save_dir in SAVE_DIRS:
    os.makedirs(save_dir, exist_ok=True)

In [4]:
# Batched, reshuffled iterator over the training set, fed by 4 worker processes.
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# Generator and discriminator, built with the configured capacity / channel
# count and moved to the selected device.
spec_gen = SpecGenerator(d=MODEL_CAPACITY, c=CHANNELS, inplace=True)
spec_gen = spec_gen.to(device)
spec_disc = SpecDiscriminator(d=MODEL_CAPACITY, c=CHANNELS, inplace=True)
spec_disc = spec_disc.to(device)

# Random weight initialisation, then put each network into training mode.
for network in (spec_gen, spec_disc):
    initialize_weights(network)
    network.train()

# Adam optimizers for both the generator and the discriminator; they share the
# same betas and differ only in learning rate.
adam_betas = (BETA1, BETA2)
optimizer_gen = optim.Adam(spec_gen.parameters(), lr=LR_GEN, betas=adam_betas)
optimizer_disc = optim.Adam(spec_disc.parameters(), lr=LR_DISC, betas=adam_betas)

# Optional resume from saved weights. `start` stays at -1 for a fresh run; any
# value > 0 loads the matching checkpoint pair before training continues.
# NOTE(review): the training cell writes files as <prefix>_{epoch}_{step}.pt, so
# `start` fills the epoch slot here and 93 is a hard-coded step - confirm both
# against the files actually present in ./save before resuming.
start=-1

# Discriminator weights are written to './save/specdisc/spec_<epoch>_<step>.pt'
# by the training cell, so the resume path must use the same 'spec_' prefix.
DISC_CKPT_TEMPLATE = './save/specdisc/spec_{}_93.pt'
GEN_CKPT_TEMPLATE = './save/specgen/gen_{}_93.pt'

if start>0:
    # map_location lets a checkpoint saved on one device be restored on the
    # device selected for this session (e.g. CUDA-saved weights on a CPU box).
    spec_disc.load_state_dict(torch.load(DISC_CKPT_TEMPLATE.format(start), map_location=device))
    spec_gen.load_state_dict(torch.load(GEN_CKPT_TEMPLATE.format(start), map_location=device))

In [None]:
#training
# Adversarial training loop: for every real batch the discriminator is updated
# TRAIN_DISCRIM times, then the generator is updated once. Loss history,
# weights and a batch of generated samples are written to ./save periodically.
CHECKPOINT_EVERY = 5   # epochs between two history / weight / sample dumps
MAX_STEPS = 30000      # generator-step budget, checked right after a dump

step = start+1 # for restart from saved weights
hist=[]  # one [generator loss, discriminator loss] pair per generator step
for epoch in range(EPOCHS):
    with tqdm(train_loader, unit="batch") as tepoch:
        for batch_id, real_audio in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch}")
            real_audio = real_audio.to(device)

            #Train Discriminator
            for train_step in range(TRAIN_DISCRIM):
                # Fresh noise per discriminator update, sized to the current
                # batch so a smaller final batch still works.
                noise = torch.randn(real_audio.shape[0], LATENT_NOISE_DIM).to(device)
                fake_audio = spec_gen(noise)
                # wasserstein_loss is handed the discriminator and both batches,
                # so no separate discriminator forward pass is needed here.
                loss_disc = wasserstein_loss(spec_disc, real_audio, fake_audio,device,LAMBDA = LAMBDA_GP,spec_gan=True)
                spec_disc.zero_grad()
                # retain_graph=True: fake_audio is not detached and the generator
                # step below backpropagates through this same generator graph.
                loss_disc.backward(retain_graph=True)
                optimizer_disc.step()

            # Train the generator!
            # Re-score the last fake batch with the just-updated discriminator.
            all_wasserstein = spec_disc(fake_audio).reshape(-1)
            loss = -torch.mean(all_wasserstein)
            # Also clears the generator gradients accumulated by the
            # discriminator backward passes above.
            spec_gen.zero_grad()
            loss.backward()
            optimizer_gen.step()
            step += 1
            # Print progress, save stats, and save model
            hist.append([loss.item(),loss_disc.item()])
            if batch_id % 5 == 0 and batch_id > 0:
                tepoch.set_postfix(gen_loss=loss.item(), disc_loss=loss_disc.item())

    # Dump every CHECKPOINT_EVERY epochs, and also after the very last epoch so
    # the final weights are not lost when EPOCHS - 1 is not a multiple of it.
    if epoch % CHECKPOINT_EVERY == 0 or epoch == EPOCHS - 1:
        with open('./save/spechist/hist_'+str(step)+'_'+str(epoch)+'_'+str(batch_id)+'.pkl', 'wb') as f:
            pickle.dump(hist, f)
        torch.save(spec_gen.state_dict(), './save/specgen/gen_'+str(epoch)+'_'+str(step)+'.pt')
        torch.save(spec_disc.state_dict(), './save/specdisc/spec_'+str(epoch)+'_'+str(step)+'.pt')
        with torch.no_grad():
            # Reuses the last noise batch drawn in the inner loop, so the number
            # of saved samples equals the size of the epoch's final batch.
            fake = spec_gen(noise)
            torch.save(fake, './save/specfake/fake_'+str(epoch)+'_'+str(step)+'.pt')
        # The budget is only checked here, so training always stops right after
        # a checkpoint has been written.
        if step>MAX_STEPS:
            break


            


Epoch 0:  22%|████████▎                            | 26/116 [00:51<01:51,  1.24s/batch, disc_loss=-22.5, gen_loss=38.8]