In [4]:
from utils.utils import *
from utils.wavegan import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
import pickle

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

GPU=True
if not GPU:
    device= 'cpu'
# Data Params
DATA_PATH='./data/drums/train'
AUDIO_LENGTH = 16384 #[16384, 32768, 65536] 
SAMPLING_RATE = 16000
NORMALIZE_AUDIO = True 
CHANNELS = 1

#Model params
LATENT_NOISE_DIM = 100
MODEL_CAPACITY=8
LAMBDA_GP = 10

#Training params
TRAIN_DISCRIM = 1 # how many times to train the discriminator for one generator step
EPOCHS = 100
BATCH_SIZE=3
LR_GEN = 1e-4
LR_DISC = 3e-4 # bigger lr instead of high TRAIN_DISCRIM
BETA1 = 0.5
BETA2 = 0.9


# Dataset and Dataloader
train_set = AudioDataset(DATA_PATH,sample_rate=SAMPLING_RATE,number_samples=AUDIO_LENGTH,extension='wav',std=NORMALIZE_AUDIO)
print(train_set.__len__())

train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=4)

#generator and discriminator
wave_gen = WaveGenerator(d=MODEL_CAPACITY, c=CHANNELS ,inplace=True).to(device)
wave_disc = WaveDiscriminator(d=MODEL_CAPACITY, c=CHANNELS ,inplace=True).to(device)

#random weights init
initialize_weights(wave_gen)
initialize_weights(wave_disc)
wave_gen.train()
wave_disc.train()

#Adam optim for both generator iand discriminator
optimizer_gen = optim.Adam(wave_gen.parameters(), lr=LR_GEN, betas=(BETA1, BETA2))
optimizer_disc = optim.Adam(wave_disc.parameters(), lr=LR_DISC, betas=(BETA1, BETA2))

2350


In [6]:
step = 0
hist=[]
for epoch in range(EPOCHS):
    with tqdm(train_loader, unit="batch") as tepoch: 
        for batch_id, real_audio in enumerate(tepoch):  
            tepoch.set_description(f"Epoch {epoch}")
            real_audio = real_audio.to(device)
            
            #Train Discriminator 
            for train_step in range(TRAIN_DISCRIM):
                noise = torch.randn(real_audio.shape[0], LATENT_NOISE_DIM, 1, 1).to(device)
                print(noise.shape)
                fake_audio = wave_gen(noise)
                disc_real = wave_disc(real).reshape(-1)
                disc_fake = wave_disc(fake).reshape(-1)
                loss_disc = wasserstein_loss(wave_disc, real_audio, fake_audio,device,LAMBDA = LAMBDA_GP)
                wave_disc.zero_grad()
                loss_disc.backward(retain_graph=True)
                optimizer_disc.step()

            # Train the generator!
            all_wasserstein = critic(fake_audio).reshape(-1)
            loss = -torch.mean(all_wasserstein)
            wave_gen.zero_grad()
            loss.backward()
            optimizer_gen.step()

            # Print progress, save stats, and save model
            hist.append([loss,loss_disc])
            if batch_idx % 5 == 0 and batch_idx > 0:
                tepoch.set_postfix(generator_loss=loss, discriminator_loss=loss_disc)
            if batch_idx % 100 == 0 and batch_idx > 0:
                torch.save(wave_gen.state_dict(), './save/gen_'+epoch+'_'+batch_idx+'.pt')
                torch.save(wave_disc.state_dict(), './save/wave_'+epoch+'_'+batch_idx+'.pt')
                with torch.no_grad():
                    fake = wave_gen(noise)
                    torch.save(fake, './save/fake.pt')
                    

            step += 1

Epoch 0:   0%|                                                                              | 0/784 [00:12<?, ?batch/s]

torch.Size([3, 1, 16384])
torch.Size([3, 100, 1, 1])


Epoch 0:   0%|                                                                              | 0/784 [00:13<?, ?batch/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (300x1 and 100x2048)

In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


def calculate_discriminator_loss(discriminator, real, generated,device,LAMBDA = 10):
    '''
    Wasserstein loss with Gradient Penalty 
    Check https://arxiv.org/pdf/1704.00028.pdf for pseudo code
    LAMBDA: penalty parameter (=10 in the paper)
    
    '''
        disc_out_gen = discriminator(generated)
        disc_out_real = discriminator(real)
        
        batch_size,C,L=real.shape
        eps=torch.rand((batch_size, 1, 1)).repeat(1, C, L).to(device)
      
        interpolated_sound = (1 - eps) * real + (eps) * generated
        
        mixed_score = discriminator(interpolated_sound)
        
        ones = torch.ones(mixed_score.size()).to(device)
        
        gradients = grad(
            outputs=mixed_score,
            inputs=interpolated_sound,
            grad_outputs=ones,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        
        # calculate gradient penalty
        grad_penalty = (
            LAMBDA
            * ((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
        )
        
        # normal Wasserstein loss
        loss = disc_out_gen.mean() - disc_out_real.mean()
        # adding gradient penalty with param LAMBDA (=10 in paper)
        loss_GP = loss + grad_penalty
        return loss_GP, loss
    
    
    
def get_number_parameters(model):
    
    """
    Prints the number of trainable parameters of the model
    
    """
        
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print(params,'trainable parameters')


def initialize_weights(m,debug=False):
    """
    Weights initializer: initialize weights with mean=0 and std= .02 like in DCGAN
    
    debug=True prints if the layer has been initialized
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1 or classname.find('BatchNorm') != -1:
        nn.init.constant_(m.bias.data, 0)
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if debug:
            print('init',classname)
    else:
        if debug:
            print('noinit',classname)
            
            
            
##################   Generator    ##########################   


       
class WaveGenerator(nn.Module):
    """
    Generator for WaveGAN
    
    d: model size (default 64)
    c: number of channels in the data (default 1)
    inplace: boolean (defaul True) #arg for the Relu 
    
    See page 15 of the Wavegan paper https://openreview.net/pdf?id=ByMVTsR5KQ 
    
    """
    def __init__(self,d=64, c=1 ,inplace=True):
        super(WaveGenerator, self).__init__()
        self.d=d # model size
        self.c=c # = 1 in the paper
        self.dense1= nn.Linear(100, 256*self.d)
        self.padding=11
        self.seq = nn.Sequential(
            # input is dense(Z), going into a convolution
            nn.ReLU(inplace), #out (n,16,16d)
            nn.ConvTranspose1d( 16*self.d, self.d * 8, 25, 4, self.padding,1, bias=True), # (25,16d,4d) | (n,64,8d)
            nn.ReLU(inplace), #no batch norm
         
            nn.ConvTranspose1d(self.d * 8, self.d * 4, 25, 4, self.padding,1, bias=True),#(25, 8d, 4d)| (n, 256, 4d)
            nn.ReLU(inplace),
        
            nn.ConvTranspose1d( self.d * 4,self.d * 2, 25, 4, self.padding,1, bias=True),#(25, 4d, 2d) | (n, 1024, 2d)
            nn.ReLU(inplace),
      
            nn.ConvTranspose1d( self.d * 2, self.d, 25, 4, self.padding,1, bias=True), #(25, 2d, d) | (n, 4096, d)
            nn.ReLU(inplace),
            nn.ConvTranspose1d( self.d, self.c, 25, 4, self.padding,1, bias=True),#(25, d, c) | (n, 16384, c)
            nn.Tanh() # as suggested       
        )

    def forward(self, x):
        #input (n,100)
        x=self.dense1(x) # output (n,256*d)
        x=torch.reshape(x, (-1,16*self.d,16)) # output (n,16,16d)
        
        return self.seq(x) #(n, 16384, c), c=1

    
################### phase suffling ##########################

class PhaseShuffling(nn.Module):
    """
    
    PhaseShuffling layer: shifts the features by a random int value between [-n,n]
    n: shift factor
    
    # paper code in tf https://github.com/chrisdonahue/wavegan/blob/master/wavegan.py
    """
    def __init__(self, n):
        super(PhaseShuffling, self).__init__()
        self.n = n

    def forward(self, x):
        #x:(n batch,channels,xlen)
        if self.n==0:
            return x
        shifts =  int(torch.Tensor(1).random_(0, 2*self.n + 1)) - self.n # same shuffle for data in batch
       
        if shifts > 0:
            return F.pad(x[..., :-shifts], (shifts, 0), mode='reflect')
        else:
            return F.pad(x[..., -shifts:], (0, -shifts), mode='reflect')




##################     Critic      ##########################
    
# See page 15 of the Wavegan paper https://openreview.net/pdf?id=ByMVTsR5KQ        
class WaveDiscriminator(nn.Module):
    """
    Generator for WaveGAN
    
    d: model size (default 64)
    c: number of channels in the data (default 1)
    inplace: boolean (defaul True) #arg for the Leaky Relu  
    
    See page 15 of the Wavegan paper https://openreview.net/pdf?id=ByMVTsR5KQ 
    
    """
    def __init__(self,d=64, c=1,n=2,inplace=True):
        super(WaveDiscriminator, self).__init__()
     
        self.d=d # model size
        self.c=c # = 1 in the paper
        self.padding=11
        leak=0.2
        self.n=n
        
        self.dense= nn.Linear(256*self.d, 1)
        self.seq = nn.Sequential(
            # input is audio or WaveGenerator(z)
            nn.Conv1d( self.c, self.d, 25, 4, self.padding, bias=True), #(n,4096,d)
            nn.LeakyReLU(leak,inplace=inplace),
            PhaseShuffling(n=self.n),
            
            nn.Conv1d( self.d, 2*self.d, 25, 4, self.padding, bias=True),  #(n,1024,2d)
            nn.LeakyReLU(leak,inplace=inplace),
            PhaseShuffling(n=self.n),
            
            nn.Conv1d( self.d * 2, self.d * 4, 25, 4, self.padding, bias=True),  #(n,256,4d)
            nn.LeakyReLU(leak,inplace=inplace),
            PhaseShuffling(n=self.n),
            
            nn.Conv1d( self.d * 4, self.d * 8, 25, 4, self.padding, bias=True),  #(n,64,8d)
            nn.LeakyReLU(leak,inplace=inplace),
            PhaseShuffling(n=self.n),
            
            nn.Conv1d( self.d * 8, self.d * 16, 25, 4, self.padding, bias=True),  #(n,16,16d)
            nn.LeakyReLU(leak,inplace=inplace)
        )
               

    def forward(self, x):
        x=self.seq(x)
        x=torch.reshape(x, (-1,256*self.d)) 
        #print(x.shape)
        return self.dense(x) 


def testing():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create the generator
    waveG = WaveGenerator().to(device)
    #initialize wights
    waveG.apply(initialize_weights)
    print(waveG)

    # Create the Discriminator
    waveD = WaveDiscriminator().to(device)
    #initialize wights
    waveD.apply(initialize_weights)
    print(waveD)


    N, in_channels, L = 8, 1, 16384
    noise_dim = 100
    x = torch.randn((N, in_channels,L))
    waveD = WaveDiscriminator()
    print(waveD(x).shape)
    waveG = WaveGenerator(d=64,c=1)
    z = torch.randn((N, 100))
    print(waveG(z).shape)
    print(waveD(waveG(z)))
    
testing()

WaveGenerator(
  (dense1): Linear(in_features=100, out_features=16384, bias=True)
  (seq): Sequential(
    (0): ReLU(inplace=True)
    (1): ConvTranspose1d(1024, 512, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (2): ReLU(inplace=True)
    (3): ConvTranspose1d(512, 256, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (4): ReLU(inplace=True)
    (5): ConvTranspose1d(256, 128, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (6): ReLU(inplace=True)
    (7): ConvTranspose1d(128, 64, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (8): ReLU(inplace=True)
    (9): ConvTranspose1d(64, 1, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (10): Tanh()
  )
)
WaveDiscriminator(
  (dense): Linear(in_features=16384, out_features=1, bias=True)
  (seq): Sequential(
    (0): Conv1d(1, 64, kernel_size=(25,), stride=(4,), padding=(11,))
    (1): LeakyReLU(negative_slope=0.2, inplace=T