In [None]:
import datetime
import os
import sys

import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from IPython.core.debugger import set_trace
from torch import nn, optim
from torch.autograd.variable import Variable

In [None]:
sample_rate = 44100
seconds = 30
# Every training example is exactly this many samples long.
clip_len = sample_rate * seconds

placeholder_dataset = []

# sorted() makes the clip order independent of the filesystem's listing order.
for wav_file in sorted(os.listdir("./data")):
  if not wav_file.endswith(".wav"):
    continue
  # sr is fixed, so the returned rate equals the requested sample_rate and can be ignored.
  y, _ = librosa.load(path = os.path.join("./data/", wav_file), sr = sample_rate, mono = True)
  # NOTE(review): this drops every exactly-zero sample anywhere in the signal, not only
  # leading/trailing silence; librosa.effects.trim may be what was intended — confirm.
  y = y[y != 0]
  # Keep only complete clips: a trailing partial clip cannot fill a fixed-width row.
  for start in range(0, y.shape[0] - clip_len + 1, clip_len):
    placeholder_dataset.append(y[start : start + clip_len])

num_subsamples = len(placeholder_dataset)

# np.append returns a new array rather than filling the one passed in, so the matrix
# is built in a single np.stack call instead.
if placeholder_dataset:
  dataset = np.stack(placeholder_dataset).astype(np.float32, copy = False)
else:
  dataset = np.empty((0, clip_len), np.float32)

In [None]:
# Reload the cached clip matrix when it exists, so the slow WAV-decoding cell can be skipped.
# On a first run the cache has not been written yet (the save cell comes later), so the
# in-memory `dataset` is kept instead of failing with FileNotFoundError.
dataset_cache = "./data/dataset.npy"
if os.path.exists(dataset_cache):
  dataset = np.load(dataset_cache)
dataset.shape

In [None]:
# Persist the clip matrix so a later session can reload it instead of re-decoding the WAV files.
np.save(file = "./data/dataset.npy", arr = dataset)

In [None]:
class Discriminator(nn.Module):
  """Sigmoid classifier over flat input vectors.

  Layers: Linear(input_features -> 64) -> LeakyReLU(0.2) -> dropout(0.2)
  -> BatchNorm1d(64) -> Linear(64 -> 32) -> LeakyReLU(0.2) -> dropout(0.2)
  -> Linear(32 -> output_features) -> sigmoid. All linear layers are bias-free.

  Args:
    input_features: size of the last dimension of the input.
    output_features: number of sigmoid outputs per row.
  """

  def __init__(self, input_features, output_features):
    super(Discriminator, self).__init__()
    self.input_features = input_features
    self.output_features = output_features

    # Attribute names are the state_dict keys stored in checkpoints; keep them stable.
    self.l_in = nn.Linear(self.input_features, 64, bias = False)
    self.h1 = nn.Linear(64, 32, bias = False)
    self.batch_norm = nn.BatchNorm1d(64, eps = 1e-03, momentum = 0.5)
    self.l_out = nn.Linear(32, output_features, bias = False)

  def forward(self, x):
    """Map a (batch, input_features) tensor to (batch, output_features) values in [0, 1]."""
    # F.dropout defaults to training=True, which would keep dropping units after .eval();
    # tie it to the module's mode so evaluation is deterministic.
    x = F.dropout(F.leaky_relu(self.l_in(x), 0.2, inplace = True), 0.2, training = self.training)
    x = self.batch_norm(x)
    x = F.dropout(F.leaky_relu(self.h1(x), 0.2, inplace = True), 0.2, training = self.training)
    return torch.sigmoid(self.l_out(x))

class Generator(nn.Module):
  """Maps latent noise vectors to flat output vectors in [-1, 1].

  Layers: Linear(input_features -> 32) -> ReLU -> BatchNorm1d(32)
  -> Linear(32 -> 64) -> ReLU -> dropout(0.2) -> BatchNorm1d(64)
  -> Linear(64 -> 128) -> ReLU -> dropout(0.2) -> Linear(128 -> output_features) -> tanh.

  Args:
    input_features: latent noise dimension.
    output_features: length of each generated vector.
  """

  def __init__(self, input_features, output_features):
    super(Generator, self).__init__()
    self.input_features = input_features
    self.output_features = output_features

    # Attribute names are the state_dict keys stored in checkpoints; keep them stable.
    self.l_in = nn.Linear(self.input_features, 32)
    self.batch_norm1 = nn.BatchNorm1d(32, eps = 1e-04, momentum = 0.4)
    self.h1 = nn.Linear(32, 64)
    self.batch_norm2 = nn.BatchNorm1d(64, eps = 1e-04, momentum = 0.2)
    self.h2 = nn.Linear(64, 128)
    self.l_out = nn.Linear(128, output_features)

  def forward(self, x):
    """Map a (batch, input_features) tensor to (batch, output_features) values in [-1, 1]."""
    x = F.relu(self.l_in(x), inplace = True)
    x = self.batch_norm1(x)
    # F.dropout defaults to training=True; follow the module's train/eval mode instead
    # so generation in eval mode is deterministic.
    x = F.dropout(F.relu(self.h1(x), inplace = True), 0.2, training = self.training)
    x = self.batch_norm2(x)
    x = F.dropout(F.relu(self.h2(x), inplace = True), 0.2, training = self.training)
    return torch.tanh(self.l_out(x))

class GAN():
  """Adversarial trainer pairing a Generator with a Discriminator over flat waveform vectors.

  Args:
    dataset: anything ``torch.utils.data.DataLoader`` accepts; each batch is
      reshaped to ``(N, song_features)`` during training.
    batch_size: DataLoader mini-batch size.
    shuffle: whether the DataLoader reshuffles every epoch.
    song_features: length of one flattened waveform (discriminator input size
      and generator output size).
    noise_vector_latent_dim: size of the generator's latent noise vector.
    num_output_samples: number of waveforms generated at each evaluation.
  """

  def __init__(self, dataset, batch_size, shuffle, song_features, noise_vector_latent_dim, num_output_samples):
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    self.dataset = dataset
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.song_features = song_features

    self.data_loader = torch.utils.data.DataLoader(self.dataset, batch_size = self.batch_size, shuffle = self.shuffle)
    self.num_batches = len(self.data_loader)

    self.noise_vector_latent_dim = noise_vector_latent_dim
    self.num_output_samples = num_output_samples

    self.discriminator = Discriminator(input_features = song_features, output_features = 1).to(self.device)
    self.generator = Generator(input_features = noise_vector_latent_dim, output_features = song_features).to(self.device)

    self.d_opt = optim.RMSprop(self.discriminator.parameters(), lr = 0.001, alpha = 0.7, eps = 1e-05, weight_decay = 1e-03)
    self.g_opt = optim.RMSprop(self.generator.parameters(), lr = 0.001, alpha = 0.7, eps = 1e-05, weight_decay = 1e-03)

    # CPU copies of the generator's output at each evaluation, oldest first.
    self.samples = []

    self.BCELoss = nn.BCELoss().to(self.device)

  def _smoothed_labels(self, n):
    """Return an (n, 1) float32 tensor of 0.9 "real" targets on self.device (label smoothing)."""
    return torch.full((n, 1), 0.9, dtype = torch.float32, device = self.device)

  def train_disc(self, opt, real, fake, step):
    """Take one discriminator optimisation step.

    Real rows are scored against smoothed 0.9 targets and fake rows against 0.

    Args:
      opt: optimizer over the discriminator's parameters.
      real: batch of real vectors on self.device.
      fake: batch of generated vectors (expected to be detached from the generator).
      step: unused; kept for interface compatibility.

    Returns:
      (error_real, error_fake, error_real + error_fake) as scalar tensors.
    """
    opt.zero_grad()

    pred_real = self.discriminator(real)
    error_real = self.BCELoss(pred_real, self._smoothed_labels(real.size(0)))
    error_real.backward()

    pred_fake = self.discriminator(fake)
    error_fake = self.BCELoss(pred_fake, torch.zeros(fake.size(0), 1, device = self.device))
    error_fake.backward()

    opt.step()

    return error_real, error_fake, error_real + error_fake

  def train_gen(self, opt, fake, step):
    """Take one generator step, pushing the discriminator to score ``fake`` as real (0.9).

    Args:
      opt: optimizer over the generator's parameters.
      fake: generated batch still attached to the generator's graph.
      step: unused; kept for interface compatibility.

    Returns:
      The generator's BCE loss as a scalar tensor.
    """
    opt.zero_grad()

    pred_fake = self.discriminator(fake)
    error_fake = self.BCELoss(pred_fake, self._smoothed_labels(fake.size(0)))
    error_fake.backward()

    opt.step()

    return error_fake

  def noise(self,  N):
    """Return an (N, noise_vector_latent_dim) standard-normal tensor on self.device."""
    x = torch.randn((N, self.noise_vector_latent_dim))
    return x.to(self.device)

  def challenge_discriminator(self, real: torch.Tensor, noise_size: int, rate: float):
    """With probability ``rate``, add 0.2 * N(0, 1) noise (length ``noise_size``) to ``real``.

    The noise vector is broadcast over the batch. Returns ``real`` on self.device,
    unchanged when no noise is added.
    """
    real = real.to(self.device)
    # Uniform draw in [0, 1): rate=0 never adds noise and rate=1 always does.
    if np.random.rand() < rate:
      return real + 0.2 * torch.randn(noise_size, device = self.device)
    return real

  def vec2wave(self, vec, size):
    """Reshape a batch of generator outputs to (batch, size)."""
    return vec.view(vec.size(0), size)

  @staticmethod
  def _noisify_real_rate(step):
    """Probability of noising the real batch at ``step``: 0.7 / 0.5 / 0.3 every 1000 / 100 / 50 steps, else 0.01."""
    if step % 1000 == 0:
      return 0.7
    if step % 100 == 0:
      return 0.5
    if step % 50 == 0:
      return 0.3
    return 0.01

  def _record_sample(self, test_noise, epoch):
    """Generate waveforms from ``test_noise``, append them to self.samples and save them as .npy."""
    sys.stdout.write("\r" + "Updating list of samples")
    was_training = self.generator.training
    # eval() + no_grad: the fixed test noise must not update BatchNorm running stats,
    # and no autograd graph is needed for a snapshot.
    self.generator.eval()
    with torch.no_grad():
      sample = self.vec2wave(self.generator(test_noise), self.song_features).cpu()
    self.generator.train(was_training)

    self.samples.append(sample)
    os.makedirs("./djenerated_samples_raw", exist_ok = True)
    np.save(f"./djenerated_samples_raw/{self.num_output_samples}_samples_at_epoch_{epoch + 1}.npy", sample.numpy())

  def _save_checkpoints(self, epoch):
    """Write discriminator and generator checkpoints (epoch, model and optimizer state) under ./models."""
    sys.stdout.write("\r" + "Saving Discriminator model | Saving Generator model")
    os.makedirs("./models", exist_ok = True)
    torch.save(
      {
        "epoch" : epoch,
        "model_state_dict" : self.discriminator.state_dict(),
        "optimizer_state_dict" : self.d_opt.state_dict()
      },
      "./models/discriminator.pth")

    torch.save(
      {
        "epoch" : epoch,
        "model_state_dict" : self.generator.state_dict(),
        "optimizer_state_dict" : self.g_opt.state_dict()
      },
      "./models/generator.pth")

  def train(self, epochs, start_epoch, eval_every, save_every):
    """Run adversarial training for epochs ``start_epoch`` .. ``epochs - 1``.

    After the first batch of every epoch whose 1-based number is a multiple of
    ``eval_every``, samples are generated from a fixed noise batch; on multiples
    of ``save_every``, checkpoints are written.
    """
    step = 0

    # Fixed latent batch so samples from different epochs are directly comparable.
    test_noise = self.noise(self.num_output_samples)

    torch.backends.cudnn.benchmark = True

    sys.stdout.write("\r" + "Going into train mode")
    self.discriminator.train()
    self.generator.train()

    for epoch in range(start_epoch, epochs):
      for n_batch, real in enumerate(self.data_loader):
        N = real.size(0)
        if N < 2:
          # BatchNorm1d raises on a single-sample batch in training mode; skip such a
          # trailing batch instead of aborting the run.
          continue
        step += 1

        real = real.view(N, self.song_features).to(self.device)
        real = self.challenge_discriminator(real = real, noise_size = self.song_features, rate = self._noisify_real_rate(step))

        # Detached fake batch: the discriminator step must not backprop into the generator.
        fake = self.generator(self.noise(N)).detach()
        d_error_real, d_error_fake, d_error_total = self.train_disc(self.d_opt, real, fake, step)

        # Fresh fake batch, still attached, for the generator step.
        fake = self.generator(self.noise(N))
        g_error = self.train_gen(self.g_opt, fake, step)

        sys.stdout.write("\r" + f"d_error_real = {d_error_real.item():.2f} -> d_error_fake = {d_error_fake.item():.2f} -> d_error_total = {d_error_total.item():.2f} -> g_error = {g_error.item():.2f} -> epoch = {epoch + 1} -> batch = {n_batch + 1} / {self.num_batches}")

        if n_batch == 0 and (epoch + 1) % eval_every == 0:
          self._record_sample(test_noise, epoch)

        if n_batch == 0 and (epoch + 1) % save_every == 0:
          self._save_checkpoints(epoch)

  def resume_training(self, epochs, eval_every, save_every):
    """Restore models and optimizers from ./models and continue training up to ``epochs``.

    Training restarts at the epoch stored in the discriminator checkpoint. Checkpoints
    are written after that epoch's first batch, so the epoch is run again.
    """
    sys.stdout.write("\r" + "Loading checkpoints")
    # map_location lets a checkpoint written on a GPU machine be resumed on CPU (and vice versa).
    discriminator_checkpoint = torch.load("./models/discriminator.pth", map_location = self.device)
    generator_checkpoint = torch.load("./models/generator.pth", map_location = self.device)

    sys.stdout.write("\r" + "Getting most recent epoch")
    start_epoch = discriminator_checkpoint['epoch']

    sys.stdout.write("\r" + "Loading optimizers")
    self.d_opt.load_state_dict(discriminator_checkpoint['optimizer_state_dict'])
    self.g_opt.load_state_dict(generator_checkpoint['optimizer_state_dict'])

    sys.stdout.write("\r" + "Loading models")
    self.discriminator.load_state_dict(discriminator_checkpoint['model_state_dict'])
    self.generator.load_state_dict(generator_checkpoint['model_state_dict'])

    self.discriminator = self.discriminator.to(self.device)
    self.generator = self.generator.to(self.device)

    # train() switches both networks to training mode itself.
    self.train(epochs = epochs, start_epoch = start_epoch, eval_every = eval_every, save_every = save_every)

  def load_generator(self):
    """Load generator weights from ./models/generator.pth and return the generator."""
    generator_checkpoint = torch.load("./models/generator.pth", map_location = self.device)
    self.generator.load_state_dict(generator_checkpoint['model_state_dict'])
    return self.generator

  def get_all_generated_samples(self):
    """Return the list of sample batches recorded during training."""
    return self.samples

In [None]:
# Each training example is one flattened clip of sample_rate * seconds values;
# that length is both the generator's output size and the discriminator's input size.
gan = GAN(
  dataset = dataset,
  song_features = sample_rate * seconds,
  batch_size = 9,
  shuffle = True,
  noise_vector_latent_dim = 100,
  num_output_samples = 9,
)

In [None]:
# Fresh run from epoch 0: generate samples every 1000 epochs and checkpoint every 2500.
gan.train(epochs = 100000, start_epoch = 0, eval_every = 1000, save_every = 2500)