In [5]:
from utils.utils import *
from utils.wavegan import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Hyper-parameters -------------------------------------------------------
# NOTE(review): MODEL_CAPACITY, AUDIO_LENGTH, SAMPLING_RATE, NORMALIZE_AUDIO and
# CHANNELS are not read anywhere in the rest of this cell — presumably they are
# meant for the WaveGAN models / audio loader; confirm before relying on them.
MODEL_CAPACITY=32
AUDIO_LENGTH = 16384 # candidate values: [16384, 32768, 65536]
SAMPLING_RATE = 16000
NORMALIZE_AUDIO = True 
CHANNELS = 1
LEARNING_RATE = 1e-4  # shared by the generator and critic Adam optimizers
Z_DIM = 100  # size of the latent noise vector fed to the generator
NUM_EPOCHS = 100
BATCH_SIZE=63
CRITIC_ITERATIONS = 2  # critic updates per generator update
LAMBDA_GP = 10  # weight of the gradient penalty term in the critic loss


# Dataset and batch loader.
# NOTE(review): `datasets` and `transforms` are not imported in the visible
# cells — presumably they come from the star imports in the first cell; confirm.
# NOTE(review): this loads MNIST images although the notebook's config describes
# audio (AUDIO_LENGTH, SAMPLING_RATE) — verify this is the intended dataset.
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
# comment mnist above and uncomment below for training on CelebA
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
# NOTE(review): Generator, Discriminator, CHANNELS_IMG, FEATURES_GEN and
# FEATURES_CRITIC are not defined in this cell (the config above defines
# CHANNELS and MODEL_CAPACITY instead) — confirm they are provided by the
# star imports, otherwise this raises NameError.
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
# NOTE(review): called once on the top-level module here, whereas the
# initialize_weights defined later in this notebook is written for a single
# layer and used through `module.apply(...)` — confirm which helper is in scope.
initialize_weights(gen)
initialize_weights(critic)

# initialize the optimizers (Adam with betas=(0.0, 0.9) for both networks)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# for tensorboard plotting: a fixed batch of 32 noise vectors so that logged
# samples are comparable across steps
# NOTE(review): SummaryWriter is not imported in the visible cells — confirm.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/GAN_MNIST/real")
writer_fake = SummaryWriter(f"logs/GAN_MNIST/fake")
step = 0  # global step used for the tensorboard image logs

# put both networks in training mode
gen.train()
critic.train()

# WGAN-GP training loop: for every batch the critic is updated
# CRITIC_ITERATIONS times, then the generator is updated once.
for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        # the last batch of an epoch may be smaller than BATCH_SIZE
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        for _ in range(CRITIC_ITERATIONS):
            # fresh noise for every critic update
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            # NOTE(review): gradient_penalty is not defined in the visible
            # cells — presumably provided by the star imports; confirm.
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            # retain_graph=True: `fake` is reused after this loop for the
            # generator update, so the graph through the generator is kept
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # uses the `fake` batch produced in the last critic iteration
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        # (every 100 batches, skipping batch 0)
        if batch_idx % 100 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            # no gradients are needed for the logged samples
            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                # NOTE(review): torchvision is not imported in the visible cells — confirm.
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            # advanced only when a log entry is written
            step += 1

In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


def calculate_discriminator_loss(discriminator, real, generated,device,LAMBDA = 10):
    '''
    Wasserstein critic loss with gradient penalty (WGAN-GP).

    Check https://arxiv.org/pdf/1704.00028.pdf for pseudo code.

    Args:
        discriminator: callable mapping a batch of samples to one score per sample.
        real: batch of real samples; dimension 0 is the batch dimension.
        generated: batch of generated samples with the same shape as ``real``.
        device: device on which the random interpolation weights are created
            (must match the device of ``real`` / ``generated``).
        LAMBDA: gradient penalty coefficient (=10 in the paper).

    Returns:
        tuple ``(loss_GP, loss)``: the penalised loss to back-propagate and the
        plain Wasserstein estimate ``mean(D(generated)) - mean(D(real))``.
    '''
    disc_out_gen = discriminator(generated)
    disc_out_real = discriminator(real)

    # One interpolation weight per sample, broadcast over all remaining
    # dimensions (works for (N, C, L) audio as well as any other rank).
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=device)

    interpolated_sound = (1 - eps) * real + eps * generated
    # autograd.grad needs the input to be part of a graph; that is not the case
    # when both `real` and `generated` are detached tensors.
    if not interpolated_sound.requires_grad:
        interpolated_sound.requires_grad_(True)

    mixed_score = discriminator(interpolated_sound)

    # create_graph=True so that the penalty itself can be differentiated
    # with respect to the discriminator parameters.
    gradients = torch.autograd.grad(
        outputs=mixed_score,
        inputs=interpolated_sound,
        grad_outputs=torch.ones_like(mixed_score),
        create_graph=True,
        retain_graph=True,
    )[0]

    # calculate gradient penalty: push the per-sample gradient norm towards 1
    grad_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
    grad_penalty = LAMBDA * ((grad_norm - 1) ** 2).mean()

    # normal Wasserstein loss
    loss = disc_out_gen.mean() - disc_out_real.mean()
    # adding gradient penalty with param LAMBDA (=10 in paper)
    loss_GP = loss + grad_penalty
    return loss_GP, loss
    
    
    
def get_number_parameters(model):
    """
    Prints the number of trainable parameters of the model.

    Only parameters with ``requires_grad=True`` are counted. The count is
    always an integer, including for models that hold 0-dim (scalar)
    parameters.

    model: a torch.nn.Module
    """
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(params,'trainable parameters')


def initialize_weights(m,debug=False):
    """
    Weights initializer: initialize weights with mean=0 and std= .02 like in DCGAN

    Intended to be used through ``module.apply(initialize_weights)`` so that it
    is called once per sub-module. Layers whose class name contains 'Conv',
    'Linear' or 'BatchNorm' get their bias set to 0 and their weight drawn from
    N(0, 0.02). A missing bias or weight (e.g. ``bias=False`` or
    ``affine=False``) is skipped instead of raising.

    m: the module to initialize
    debug=True prints if the layer has been initialized
    """
    classname = m.__class__.__name__
    if any(tag in classname for tag in ('Conv', 'Linear', 'BatchNorm')):
        bias = getattr(m, 'bias', None)
        weight = getattr(m, 'weight', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
        if debug:
            print('init',classname)
    elif debug:
        print('noinit',classname)
            
            
            
##################   Generator    ##########################   


       
class WaveGenerator(nn.Module):
    """
    Generator for WaveGAN

    d: model size (default 64)
    c: number of channels in the data (default 1)
    inplace: boolean (default True) #arg for the Relu
    latent_dim: size of the input noise vector (default 100)

    A dense layer projects the noise to 16 time steps with 16*d channels; five
    transposed convolutions (kernel 25, stride 4) then each multiply the length
    by 4, giving 16 * 4**5 = 16384 output samples.

    See page 15 of the Wavegan paper https://openreview.net/pdf?id=ByMVTsR5KQ

    """
    def __init__(self,d=64, c=1 ,inplace=True, latent_dim=100):
        super(WaveGenerator, self).__init__()
        self.d=d # model size
        self.c=c # = 1 in the paper
        self.latent_dim=latent_dim
        self.dense1= nn.Linear(self.latent_dim, 256*self.d)
        self.padding=11

        def upsample(in_channels, out_channels):
            # kernel 25, stride 4, padding 11, output_padding 1 -> length * 4
            return nn.ConvTranspose1d(
                in_channels, out_channels, 25, 4, self.padding, 1, bias=True
            )

        # Shapes below are PyTorch channel-first: (batch, channels, length).
        self.seq = nn.Sequential(
            # input is dense(Z) reshaped to (n, 16d, 16)
            nn.ReLU(inplace),
            upsample(16 * self.d, 8 * self.d),  # (n, 8d, 64)
            nn.ReLU(inplace), #no batch norm

            upsample(8 * self.d, 4 * self.d),   # (n, 4d, 256)
            nn.ReLU(inplace),

            upsample(4 * self.d, 2 * self.d),   # (n, 2d, 1024)
            nn.ReLU(inplace),

            upsample(2 * self.d, self.d),       # (n, d, 4096)
            nn.ReLU(inplace),
            upsample(self.d, self.c),           # (n, c, 16384)
            nn.Tanh() # squashes the samples into [-1, 1]
        )

    def forward(self, x):
        """Map noise of shape (n, latent_dim) to audio of shape (n, c, 16384)."""
        x=self.dense1(x) # (n, 256*d)
        x=torch.reshape(x, (-1,16*self.d,16)) # (n, 16d, 16)

        return self.seq(x) # (n, c, 16384)

    
################### phase shuffling ##########################

class PhaseShuffling(nn.Module):
    """
    Phase shuffle layer.

    Shifts the last dimension of the input by a random integer drawn from
    [-n, n]; the positions vacated by the shift are filled by reflection, so
    the output has the same shape as the input. With n == 0 the input is
    returned unchanged.

    n: shift factor

    # paper code in tf https://github.com/chrisdonahue/wavegan/blob/master/wavegan.py
    """
    def __init__(self, n):
        super().__init__()
        self.n = n

    def forward(self, x):
        # x: (n batch, channels, xlen)
        if self.n == 0:
            return x

        # one draw per call: every sample of the batch gets the same shift
        drawn = torch.Tensor(1).random_(0, 2*self.n + 1)
        offset = int(drawn) - self.n

        if offset > 0:
            # shift right: drop the tail, reflect-pad on the left
            kept, pads = x[..., :-offset], (offset, 0)
        else:
            # shift left (or no shift): drop the head, reflect-pad on the right
            kept, pads = x[..., -offset:], (0, -offset)
        return F.pad(kept, pads, mode='reflect')




##################     Critic      ##########################
    
# See page 15 of the Wavegan paper https://openreview.net/pdf?id=ByMVTsR5KQ        
class WaveDiscriminator(nn.Module):
    """
    Discriminator (critic) for WaveGAN

    d: model size (default 64)
    c: number of channels in the data (default 1)
    n: shift factor of the PhaseShuffling layers (default 2)
    inplace: boolean (default True) #arg for the Leaky Relu

    Five strided convolutions (kernel 25, stride 4) each divide the length by
    4. The final dense layer expects 256*d features per sample, i.e. an input
    of 16384 samples; any other input length raises a shape error in the dense
    layer instead of being silently folded into the batch dimension.

    See page 15 of the Wavegan paper https://openreview.net/pdf?id=ByMVTsR5KQ

    """
    def __init__(self,d=64, c=1,n=2,inplace=True):
        super(WaveDiscriminator, self).__init__()

        self.d=d # model size
        self.c=c # = 1 in the paper
        self.padding=11
        leak=0.2
        self.n=n

        self.dense= nn.Linear(256*self.d, 1)
        # Shapes below are PyTorch channel-first (batch, channels, length)
        # for an input of length 16384.
        self.seq = nn.Sequential(
            # input is audio or WaveGenerator(z): (n, c, 16384)
            nn.Conv1d( self.c, self.d, 25, 4, self.padding, bias=True), #(n,d,4096)
            nn.LeakyReLU(leak,inplace=inplace),
            PhaseShuffling(n=self.n),

            nn.Conv1d( self.d, 2*self.d, 25, 4, self.padding, bias=True),  #(n,2d,1024)
            nn.LeakyReLU(leak,inplace=inplace),
            PhaseShuffling(n=self.n),

            nn.Conv1d( self.d * 2, self.d * 4, 25, 4, self.padding, bias=True),  #(n,4d,256)
            nn.LeakyReLU(leak,inplace=inplace),
            PhaseShuffling(n=self.n),

            nn.Conv1d( self.d * 4, self.d * 8, 25, 4, self.padding, bias=True),  #(n,8d,64)
            nn.LeakyReLU(leak,inplace=inplace),
            PhaseShuffling(n=self.n),

            # no phase shuffle after the last convolution
            nn.Conv1d( self.d * 8, self.d * 16, 25, 4, self.padding, bias=True),  #(n,16d,16)
            nn.LeakyReLU(leak,inplace=inplace)
        )


    def forward(self, x):
        """Score a batch of shape (n, c, 16384); returns a tensor of shape (n, 1)."""
        features=self.seq(x)
        if features.dim() == 2:
            # unbatched (channels, length) input: treat it as a batch of one
            features=features.unsqueeze(0)
        # Flatten per sample and keep the batch dimension fixed. reshape(-1, 256*d)
        # would turn a longer input into extra batch rows without any error.
        features=features.reshape(features.size(0), -1)
        return self.dense(features)


def testing():
    """
    Smoke test for the WaveGAN models.

    Builds the generator and the discriminator on the available device,
    applies the DCGAN-style weight initialization, prints both architectures
    and checks the output shapes on a random batch. The same initialized,
    device-placed models are used for the forward passes.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create the generator
    waveG = WaveGenerator(d=64,c=1).to(device)
    # initialize weights
    waveG.apply(initialize_weights)
    print(waveG)

    # Create the Discriminator
    waveD = WaveDiscriminator().to(device)
    # initialize weights
    waveD.apply(initialize_weights)
    print(waveD)

    N, in_channels, L = 8, 1, 16384
    noise_dim = 100

    # shapes only: no gradients needed
    with torch.no_grad():
        x = torch.randn((N, in_channels, L), device=device)
        print(waveD(x).shape)

        z = torch.randn((N, noise_dim), device=device)
        fake = waveG(z)
        print(fake.shape)
        print(waveD(fake))

testing()

WaveGenerator(
  (dense1): Linear(in_features=100, out_features=16384, bias=True)
  (seq): Sequential(
    (0): ReLU(inplace=True)
    (1): ConvTranspose1d(1024, 512, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (2): ReLU(inplace=True)
    (3): ConvTranspose1d(512, 256, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (4): ReLU(inplace=True)
    (5): ConvTranspose1d(256, 128, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (6): ReLU(inplace=True)
    (7): ConvTranspose1d(128, 64, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (8): ReLU(inplace=True)
    (9): ConvTranspose1d(64, 1, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (10): Tanh()
  )
)
WaveDiscriminator(
  (dense): Linear(in_features=16384, out_features=1, bias=True)
  (seq): Sequential(
    (0): Conv1d(1, 64, kernel_size=(25,), stride=(4,), padding=(11,))
    (1): LeakyReLU(negative_slope=0.2, inplace=T