In [None]:
# Re-import edited Python modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Presumably only needed when gantools is not installed and should be
# imported from the parent directory:
# import sys
# sys.path.insert(0, '../')
import os

# Must run before TensorFlow is imported so that only this GPU is visible.
# setdefault keeps a CUDA_VISIBLE_DEVICES value exported before launching
# the notebook instead of silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")


import numpy as np
import tensorflow as tf
from gantools import evaluation
from gantools import data
from gantools import utils
from gantools import plot
from gantools.model import WGAN, LapWGAN
from gantools.gansystem import GANsystem
import matplotlib.pyplot as plt
import pickle
from gantools import blocks

from scipy.io import wavfile
from IPython.core.display import HTML

# Useful functions

In [None]:


# Helper that takes a filename and displays an HTML5 <audio> player for it.

def wavPlayer(filepath):
    """Display an HTML5 audio player for a WAV file.

    Parameters
    ----------
    filepath : str or path-like
        Path of the file to play, relative to the notebook's directory (which
        is not necessarily the kernel's current working directory). It is
        embedded in the player as the URL ``files/<filepath>``.

    The browser needs to be able to play WAV through HTML5. There is no
    autoplay, so files do not start playing when the notebook is opened.
    """
    # Local imports: the module-level IPython import only provides HTML, and
    # the original relied on `display` being implicitly available.
    import html
    from pathlib import PurePath
    from urllib.parse import quote
    from IPython.display import HTML, display

    # Use forward slashes for the URL, percent-encode it, then HTML-escape it
    # so characters such as '"', '&', '#' or spaces cannot break the markup.
    url_path = html.escape(quote(PurePath(filepath).as_posix()))
    src = """
    <head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <title>Simple Test</title>
    </head>
    
    <body>
    <audio controls="controls" style="width:600px" >
      <source src="files/%s" type="audio/wav" />
      Your browser does not support the audio element.
    </audio>
    </body>
    """ % (url_path,)
    display(HTML(src))

def play_sound(x, fs, filename=None):
    """Write ``x`` to a 16-bit PCM WAV file and display an audio player for it.

    Parameters
    ----------
    x : array_like
        Mono signal scaled so that full scale is [-1, 1]. Samples outside that
        range are clipped instead of wrapping around in int16.
    fs : float
        Sampling rate in Hz. Truncated to an int because the WAV header stores
        an integer rate (e.g. 62.5 becomes 62).
    filename : str, optional
        Output path. Defaults to a random ``<int>.wav`` in the current working
        directory, which may overwrite an existing file with that name.
    """
    if filename is None:
        filename = str(np.random.randint(10000)) + '.wav'
    # Scale to the int16 range and clip first: x == 1.0 would otherwise give
    # 32768, which overflows to -32768 when cast.
    pcm = np.clip(np.asarray(x) * 2**15, -2**15, 2**15 - 1).astype(np.int16)
    # Builtin int(): the np.int alias was removed in NumPy 1.24.
    wavfile.write(filename, int(fs), pcm)
    wavPlayer(filename)


def load_gan(savepath, system=GANsystem, model=WGAN):
    """Rebuild a GAN system from the ``params.pkl`` stored in a checkpoint dir.

    Parameters
    ----------
    savepath : str
        Checkpoint directory that contains ``params.pkl``.
    system : callable, optional
        Wrapper class, called as ``system(model, params)``.
    model : class, optional
        Model class handed to ``system``.

    Returns
    -------
    The object returned by ``system(model, params)``, with
    ``params['save_dir']`` set to ``savepath``.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code: only load checkpoints you trust.
    """
    pathparams = os.path.join(savepath, 'params.pkl')
    with open(pathparams, 'rb') as f:
        params = pickle.load(f)
    # Point the system at the directory the params were loaded from,
    # overriding any save_dir stored in the pickle.
    params['save_dir'] = savepath
    return system(model, params)

def plot_signals(sigs,
                 nx=1,
                 ny=1,
                 *args,
                 **kwargs):
    """
    Draw several 1D signals on a grid of subplots sharing the same y-axis.

    Signals are taken in order from ``sigs`` and placed column by column
    (all rows of the first column, then the next column, ...).

    Parameters
    ----------
    sigs : array_like, shape (n, sx)
        Batch of ``n`` signals of length ``sx``. Only the first ``nx*ny`` are
        drawn, but the symmetric y-limits are computed from the whole batch.
    nx : int
        Number of subplot columns.
    ny : int
        Number of subplot rows.
    *args, **kwargs
        Forwarded to ``Axes.plot`` for every signal (e.g. a format string or
        ``color=``).

    Raises
    ------
    ValueError
        If ``sigs`` is not 2-dimensional or holds fewer than ``nx*ny`` signals.
    """
    sigs = np.asarray(sigs)
    if sigs.ndim < 2:
        raise ValueError('The input seems to contain only one signal')
    if sigs.ndim > 2:
        raise ValueError('The input contains to many dimensions')
    nimg = sigs.shape[0]
    if nx * ny > nimg:
        raise ValueError("Not enough signals")

    # squeeze=False always yields a 2D (ny, nx) array of Axes, so the
    # nx == ny == 1 case (otherwise a bare Axes object) needs no special path.
    f, ax = plt.subplots(ny, nx, sharey=True, squeeze=False,
                         figsize=(4 * nx, 3 * ny))
    # Symmetric y-range from the largest amplitude in the batch; fall back to
    # (-1, 1) for an all-zero batch, where (0, 0) would be degenerate.
    lim = np.max(np.abs(sigs))
    ylim = (-lim, lim) if lim > 0 else (-1, 1)
    it = 0
    for i in range(nx):
        for j in range(ny):
            ax[j, i].plot(sigs[it], *args, **kwargs)
            ax[j, i].set_ylim(ylim)
            it += 1


# Parameters

In [None]:
# Root directory holding one 'WGAN_nsynth_<nsamples>_checkpoints' folder per scale.
globalpath = '../saved_results/nsynth/'


In [None]:
# Resolution cascade: downsampling factors 4**4 ... 4**0 (most downsampled
# first), with the matching sampling rate in Hz and signal length per scale.
scalings = np.power(4, np.arange(4, -1, -1))
fs = 16000 / scalings
nsamples = 2**15 // scalings
Nsamples = 100  # number of signals generated per scale
samples = [None] * scalings.size  # placeholder, one slot per scale

# Load real data

In [None]:
dataset = data.load.load_nsynth_dataset()
# Draw as many real signals as fake ones are generated (Nsamples) so the
# real/fake comparisons below use equally sized batches.
sample_real_final = dataset.get_samples(Nsamples)
# Real counterpart at every scale: downsample by each factor in `scalings`.
samples_real = [blocks.np_downsample_1d(sample_real_final, scaling)
                for scaling in scalings]

# Generate new samples through the full GAN cascade

In [None]:
# Generate through the whole cascade: a WGAN at the first (coarsest) scale,
# then at every finer scale a LapWGAN conditioned (X_down=) on the previous
# scale's generated output. The list is rebuilt from scratch so this cell is
# idempotent and does not rely on a pre-filled `samples` placeholder.
samples = []
for n in range(len(scalings)):
    savepath = os.path.join(globalpath, 'WGAN_nsynth_{}_checkpoints'.format(nsamples[n]))
    if n == 0:
        obj = load_gan(savepath, model=WGAN)
        samples.append(obj.generate(N=Nsamples))
    else:
        obj = load_gan(savepath, model=LapWGAN)
        samples.append(obj.generate(X_down=samples[n - 1]))
samples = [np.squeeze(sample) for sample in samples]

In [None]:
# Overview grids of the first 4 signals per scale: all generated scales first,
# then all real ones.
for title_fmt, collection in (('Fake, scaling {}', samples),
                              ('Real, scaling {}', samples_real)):
    for scaling, sample in zip(scalings, collection):
        plot_signals(sample, nx=4, ny=1)
        plt.suptitle(title_fmt.format(scaling))

In [None]:
# Plot and play generated signal number n_select at every scale.
n_select = 14
for n, scaling in enumerate(scalings):
    print('Downsampling {}'.format(scaling))
    cfs = fs[n]
    sig = samples[n][n_select]
    x = np.arange(len(sig)) / cfs  # time axis in seconds
    plt.figure(figsize=(7, 3))
    plt.plot(x, sig)
    plt.title('Fake, scaling {}'.format(scaling))
    plt.xlabel('Time [s]')
    # Render now: with the inline backend the figure would otherwise only be
    # drawn after the loop, detached from its label and audio player.
    plt.show()
    play_sound(sig, cfs)

In [None]:
# Plot and play real signal number n_select at every scale.
n_select = 14
for n, scaling in enumerate(scalings):
    print('Downsampling {}'.format(scaling))
    cfs = fs[n]
    sig = samples_real[n][n_select]
    x = np.arange(len(sig)) / cfs  # time axis in seconds
    plt.figure(figsize=(7, 3))
    plt.plot(x, sig)
    plt.title('Real, scaling {}'.format(scaling))
    plt.xlabel('Time [s]')
    # Render now: with the inline backend the figure would otherwise only be
    # drawn after the loop, detached from its label and audio player.
    plt.show()
    play_sound(sig, cfs)

# Step-by-step evaluation

In [None]:
# Step-by-step: every LapWGAN is conditioned on the *real* signal of the
# previous scale rather than on generated output. The first scale has no
# GAN step here and simply reuses the real data.
samples_step = [samples_real[0]]
for n in range(1, len(scalings)):
    savepath = os.path.join(globalpath, 'WGAN_nsynth_{}_checkpoints'.format(nsamples[n]))
    obj = load_gan(savepath, model=LapWGAN)
    X_down = samples_real[n - 1]
    # Add a trailing channel axis: (N, L, ...) -> (N, L, 1).
    conditioning = np.reshape(X_down, (*X_down.shape[:2], 1))
    samples_step.append(obj.generate(X_down=conditioning))
samples_step = [np.squeeze(sample) for sample in samples_step]

In [None]:
# For every scale: grid of step-by-step outputs, then grid of real signals.
for scaling, fake, real in zip(scalings, samples_step, samples_real):
    for title_fmt, sigs in (('Fake, scaling {}', fake),
                            ('Real, scaling {}', real)):
        plot_signals(sigs, nx=4, ny=1)
        plt.suptitle(title_fmt.format(scaling))


In [None]:
# For each scale, play step-by-step output and real signal number n_select,
# then plot the step-by-step output.
# Note: at the first (coarsest) scale samples_step holds the real signal itself.
n_select = 14
for n, scaling in enumerate(scalings):
    cfs = fs[n]
    sig = samples_step[n][n_select]
    x = np.arange(len(sig)) / cfs  # time axis in seconds
    print('Downsampling {} - fake'.format(scaling))
    play_sound(sig, cfs)
    print('Downsampling {} - real'.format(scaling))
    play_sound(samples_real[n][n_select], cfs)

    plt.figure(figsize=(7, 3))
    plt.plot(x, sig)
    plt.title('Fake (step-by-step), scaling {}'.format(scaling))
    plt.xlabel('Time [s]')
    # Render now so each figure stays next to its players; the inline backend
    # would otherwise draw all figures after the loop.
    plt.show()
