In [None]:
# Reload imported modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# import sys
# sys.path.insert(0, '../')
import os
# Expose only GPU index 1 to this process. This is set before TensorFlow is imported below.
os.environ["CUDA_VISIBLE_DEVICES"]="1"


import numpy as np
import tensorflow as tf

from gantools import data
from gantools import utils
from gantools import plot
from gantools.model import WGAN, LapWGAN, UpscalePatchWGAN, UpscalePatchWGANBorders
from gantools.gansystem import GANsystem
from gantools.data import fmap
from gantools import evaluation
import functools
import matplotlib.pyplot as plt
from copy import deepcopy

# Parameters

In [None]:
# Downscaling factor: passed to the dataset loader as `scaling` and used to
# derive the playback/network sampling rate (16000 // downscale) further down.
downscale = 4

# Data handling

Load the data

In [None]:
# Load the piano audio dataset with the downscaling factor chosen above.
dataset = data.load.load_audio_dataset(
    scaling=downscale,
    patch=False,
    spix=4096,
    augmentation=True,
    smooth=4,
    type='piano',
)

In [None]:
# Sanity check: take one item from dataset.iter(10), print its shape,
# then drop the iterator.
batch_iter = dataset.iter(10)
print(next(batch_iter).shape)
del batch_iter

In [None]:
# Fetch the whole dataset, keep index 0 of the last axis and flatten to 1-D.
X = dataset.get_all_data()
X = X[:, :, 0].flatten()

Display the histogram of the signal values after the forward map (counts on a log scale)

In [None]:
# Histogram of the flattened values: 100 bins, log-scaled count axis.
plt.hist(X, bins=100)
plt.yscale('log')
# Report the value range next to the figure.
print('min: {}'.format(np.min(X)))
print('max: {}'.format(np.max(X)))

In [None]:
# Free some memory: X is the flattened copy that was built for the histogram above.
del X

Let us plot 16 samples

In [None]:
# Draw 16 dataset samples (index 0 of the last axis) on a 4 x 4 grid.
plot.audio.plot_signals(
    dataset.get_samples(16)[:, :, 0],
    nx=4,
    ny=4,
)


In [None]:
# Play the first of 16 drawn samples at the rate 16000 // downscale.
plot.audio.play_sound(
    dataset.get_samples(16)[0, :, 0],
    fs=16000 // downscale,
)

# Define parameters for the WGAN

In [None]:
# Root folder and run identifier used to build the summary/checkpoint paths.
global_path = 'saved_results'
time_str = 'piano_8k_patch2'

# Experiment name: model tag and run identifier joined by an underscore.
name = '_'.join(('WGAN', time_str))


## Parameters

In [None]:
bn = False  # value used for every 'batch_norm' entry
md = 64  # base filter count; the discriminator uses 2 * md per layer

# Discriminator settings: five entries per list, all layers configured alike.
params_discriminator = {
    'stride': [2] * 5,
    'nfilter': [2 * md] * 5,
    # Built with a comprehension so that each inner list is a distinct object.
    'shape': [[25] for _ in range(5)],
    'batch_norm': [bn] * 5,
    'full': [],
    'minibatch_reg': False,
    'summary': True,
    'data_size': 1,
}

# Generator settings (uses `bn` and `md` defined earlier in this cell).
params_generator = dict()
params_generator['stride'] = [1, 1, 1, 1, 1]
params_generator['latent_dim'] = 32*32
params_generator['nfilter'] = [md, md, md, md, 1]
params_generator['shape'] = [[25], [25], [25], [25], [25]]
# One entry fewer than 'nfilter'.
# NOTE(review): presumably the last layer takes no batch-norm flag; confirm.
params_generator['batch_norm'] = [bn, bn, bn, bn]
params_generator['full'] = []
params_generator['summary'] = True
params_generator['non_lin'] = tf.nn.tanh
params_generator['data_size'] = 1

# Settings stored under the 'borders' key.
# 'width_full' is assigned exactly once, to its effective value. It is kept as
# the first key so the resulting dict has the same contents and key order as before.
params_generator['borders'] = dict()
params_generator['borders']['width_full'] = 128
params_generator['borders']['nfilter'] = [4, 8, 7]
params_generator['borders']['batch_norm'] = [bn, bn, bn]
params_generator['borders']['shape'] = [[25], [25], [25]]
params_generator['borders']['stride'] = [2, 4, 2]
params_generator['borders']['data_size'] = 1

# Optimization settings: batch size, epsilon and number of epochs.
params_optimization = dict()
params_optimization['batch_size'] = 64
params_optimization['epsilon'] = 1e-8
params_optimization['epoch'] = 10000


# Full configuration: network settings, optimization settings, and the
# logging/checkpoint cadence and locations.
params = {
    'net': {  # All the parameters for the model
        'generator': params_generator,
        'discriminator': params_discriminator,
        'prior_distribution': 'gaussian',
        'shape': [4096, 2],  # Shape of the data handled by the model
        'gamma_gp': 10,  # Gradient penalty
        'upsampling': 4,
        'fs': 16000 // downscale,
    },
    'optimization': params_optimization,
    'summary_every': 100,  # Tensorboard summaries every ** iterations
    'print_every': 50,  # Console summaries every ** iterations
    'save_every': 1000,  # Save the model every ** iterations
    'summary_dir': os.path.join(global_path, name + '_summary/'),
    'save_dir': os.path.join(global_path, name + '_checkpoints/'),
}



In [None]:
# Ask the helper about resuming (first argument True); it returns a resume
# flag and a params dict, which replaces the one built above.
resume, params = utils.test_resume(True, params)
# Set the epoch count on the returned params.
# NOTE(review): presumably needed because test_resume may hand back previously saved params; confirm.
params['optimization']['epoch'] = 10000


# Build the model

In [None]:
# Build the GAN system around the UpscalePatchWGANBorders model with the params above.
wgan = GANsystem(UpscalePatchWGANBorders, params)

# Train the model

In [None]:
# Train on the dataset; `resume` is the flag returned by utils.test_resume.
wgan.train(dataset, resume=resume)

# Generate new samples
To have meaningful statistics, be sure to generate enough samples:
* 2000 : 32 x 32
* 500 : 64 x 64
* 200 : 128 x 128


In [None]:
# Load the model inside a temporary TensorFlow session.
# NOTE(review): the session is closed when this block exits, and the generation
# cell that follows calls wgan.generate without passing a session — confirm
# whether this cell is needed beyond checking that loading works.
with tf.Session() as sess:
    wgan.load(sess=sess)

In [None]:
N = 2000  # Number of samples
# Generate N samples, then drop the singleton axes from the result.
gen_sample = wgan.generate(N=N)
gen_sample = np.squeeze(gen_sample)

Display a few fake samples

In [None]:
# Plot generated samples on a 4 x 4 grid.
# The helper is reached through the imported `plot` package (plot.audio.plot_signals),
# the same call used for the real samples; a bare `plot_signals` name is never
# imported in this notebook and would raise NameError.
plot.audio.plot_signals(gen_sample, nx=4, ny=4);
plt.suptitle("Fake samples");

# Evaluation of the sample quality

In [None]:
# Listen to the first generated sample at the rate 16000 // downscale.
plot.audio.play_sound(
    gen_sample[0, :],
    fs=16000 // downscale,
)