In [None]:
!pip install tensorflow_io
!pip install pyyaml h5py  # Required to save models in HDF5 format

In [None]:
import tensorflow as tf
import random
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas
import tensorflow_datasets as tfds
import time
import librosa.display as lidp
import tensorflow_io as tfio
from tensorflow import keras
from tensorflow.keras import backend
from tensorflow.keras import layers
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv1D
from tensorflow.keras.layers import Conv1DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.constraints import Constraint
from numpy import expand_dims
from numpy import mean
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from IPython import display


In [None]:
import json

In [None]:
data, info = tfds.load('nsynth', try_gcs=True, split='train', with_info=True)
assert isinstance(data, tf.data.Dataset)
dataset = data.shuffle(1024).batch(32)
total = len(dataset)
print(len(dataset))
data = data.shuffle(1024).batch(32).repeat()
db_iter = iter(data)
#get data

In [None]:
# Log in to Google Drive so checkpoints, plots and loss histories can be
# written under /content/gdrive/My Drive/WGAN and reused in later sessions.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
class ClipConstraint(Constraint):
    """Keras weight constraint that clamps every weight into [-clip_value, clip_value].

    Attached as ``kernel_constraint`` to the critic's convolution layers so the
    critic's weights stay inside a small hypercube (WGAN weight clipping).
    """

    def __init__(self, clip_value):
        # Half-width of the allowed weight interval.
        self.clip_value = clip_value

    def __call__(self, weights):
        limit = self.clip_value
        return backend.clip(weights, -limit, limit)

    def get_config(self):
        # Lets load_model(..., custom_objects={'ClipConstraint': ...}) rebuild it.
        return dict(clip_value=self.clip_value)

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss: the mean of label * critic score.

    Real samples are labelled -1 and fake ones +1 elsewhere in this notebook.
    Minimising this therefore raises the critic's score on real audio and
    lowers it on generated audio.
    """
    signed_scores = y_true * y_pred
    return backend.mean(signed_scores)

In [None]:
def define_critic(d, c, in_shape=None):
    """Build and compile the WGAN critic for raw 1-D audio.

    Five strided Conv1D layers (kernel 25, stride 4, 'same' padding) each cut
    the sequence length by 4x, while the filter count grows d, 2d, 4d, 8d, 16d.
    Each layer is followed by LeakyReLU(0.2). A Flatten + Dense(1) head emits
    an unbounded (linear) score.

    Args:
        d: base number of convolution filters (model width).
        c: number of audio channels; used to build the default input shape.
        in_shape: (length, channels) of the input. Defaults to (16384, c),
            which is the previous fixed (16384, 1) when c == 1.

    Returns:
        A Sequential model compiled with wasserstein_loss and RMSprop(5e-5).
    """
    if in_shape is None:
        in_shape = (16384, c)
    init = RandomNormal(stddev=0.02)
    # Weight clipping keeps the critic's weights in [-0.01, 0.01].
    const = ClipConstraint(0.01)

    model = Sequential()
    length = in_shape[0]
    for i, mult in enumerate((1, 2, 4, 8, 16)):
        conv_kwargs = dict(strides=4, padding='same',
                           kernel_initializer=init, kernel_constraint=const)
        if i == 0:
            conv_kwargs['input_shape'] = in_shape
        model.add(Conv1D(mult * d, 25, **conv_kwargs))
        # 'same' padding with stride 4 yields ceil(length / 4) time steps.
        length = -(-length // 4)
        assert model.output_shape == (None, length, mult * d)
        model.add(LeakyReLU(alpha=0.2))

    # Scoring head, linear activation.
    model.add(Flatten())
    model.add(Dense(1))

    # `lr` is a deprecated alias that newer tf.keras optimizers reject.
    opt = RMSprop(learning_rate=0.00005)
    model.compile(loss=wasserstein_loss, optimizer=opt)
    # summary() prints by itself and returns None.
    model.summary()
    return model

In [None]:
def define_generator(latent_dim, d, c):
    """Build the (uncompiled) WGAN generator mapping latent vectors to audio.

    A Dense layer projects the latent vector and reshapes it to (16, 16*d).
    Four Conv1DTranspose layers (kernel 25, stride 4, each followed by
    LeakyReLU(0.2)) upsample by 4x while filters shrink 8d -> 4d -> 2d -> d.
    A final tanh Conv1DTranspose produces a (16384, c) waveform.

    Args:
        latent_dim: length of the input noise vector.
        d: base number of filters (model width).
        c: number of output audio channels.

    Returns:
        An uncompiled Sequential model (it is trained through define_gan).
    """
    init = RandomNormal(stddev=0.02)
    model = Sequential()

    # Project and reshape the noise into a short, wide feature map.
    model.add(Dense(256 * d, kernel_initializer=init, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((16, 16 * d)))
    assert model.output_shape == (None, 16, 16 * d)

    steps = 16
    for width in (8 * d, 4 * d, 2 * d, d):
        steps *= 4
        model.add(Conv1DTranspose(width, 25, strides=4, padding='same', kernel_initializer=init))
        assert model.output_shape == (None, steps, width)
        model.add(LeakyReLU(alpha=0.2))

    # Output layer: tanh bounds the generated samples to [-1, 1].
    model.add(Conv1DTranspose(c, 25, strides=4, activation='tanh', padding='same', kernel_initializer=init))
    assert model.output_shape == (None, 16384, c)
    return model

In [None]:
def define_gan(generator, critic):
    """Stack generator and critic into the model used to train the generator.

    The critic is frozen before this model is compiled, so train_on_batch on
    the returned model updates only the generator's weights. In tf.keras 2.x
    `trainable` is captured at compile time. The critic was compiled (while
    trainable) in define_critic, so its own train_on_batch still updates it.

    Args:
        generator: the generator model (first layer of the stack).
        critic: the compiled critic model (second layer of the stack).

    Returns:
        Sequential([generator, critic]) compiled with wasserstein_loss.
    """
    critic.trainable = False
    model = Sequential()
    model.add(generator)
    model.add(critic)
    # `lr` is a deprecated alias that newer tf.keras optimizers reject.
    opt = RMSprop(learning_rate=0.00005)
    model.compile(loss=wasserstein_loss, optimizer=opt)
    # summary() prints by itself and returns None.
    model.summary()
    return model

In [None]:
def generate_real_samples(dataset, n_samples):
    """Pull the next batch of real audio and label it as real (-1).

    Samples are not drawn at random here; any randomness comes from the
    shuffled pipeline feeding `dataset`.

    Args:
        dataset: iterator yielding dicts with an "audio" tensor, presumably
            shaped (batch, samples). TODO confirm against the data cell.
        n_samples: maximum number of clips to return; the batch is truncated
            to this size.

    Returns:
        X: tensor (n, 16384, 1) holding the first 16384 samples of each clip.
        y: numpy array (n, 1) filled with -1, the critic's "real" label.
    """
    batch = next(dataset)
    # At most n_samples clips, first 16384 audio samples of each.
    sound = batch["audio"][:n_samples, :16384]
    n = sound.shape[0]
    X = tf.reshape(sound, [n, 16384, 1])
    y = -ones((n, 1))
    return X, y

In [None]:
def generate_latent_points(latent_dim, n_samples):
    """Return an (n_samples, latent_dim) array of standard-normal latent vectors."""
    # randn with a 2-D shape fills row-major from the same global stream, so
    # this equals drawing latent_dim * n_samples values and reshaping them.
    return randn(n_samples, latent_dim)

In [None]:
def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate n_samples fake clips and label them +1 ("fake") for the critic.

    Returns:
        (X, y): X is generator.predict on fresh latent noise, and y is a
        (n_samples, 1) array of ones.
    """
    noise = generate_latent_points(latent_dim, n_samples)
    fake_audio = generator.predict(noise)
    fake_labels = ones((n_samples, 1))
    return fake_audio, fake_labels

In [None]:
def summarize_performance(step, g_model, gan_model, c_model, latent_dim, n_samples=16,
                          out_dir='/content/gdrive/My Drive/WGAN'):
    """Checkpoint the models, then play and plot a batch of generated clips.

    The function does the following, in order:
    - saves a training checkpoint through the global `checkpoint` /
      `checkpoint_prefix` (set up in the checkpoint cell);
    - generates `n_samples` fake clips;
    - shows an audio player and a waveform subplot for each clip;
    - writes the figure to `out_dir/sound_at_epoch_XXXX.png`.

    Args:
        step: training step; the file name uses step + 1.
        g_model: generator used to produce the samples.
        gan_model, c_model: not used here; kept for call-site compatibility.
        latent_dim: size of the generator's latent input.
        n_samples: number of clips to render (one subplot each).
        out_dir: directory the figure is written to.
    """
    checkpoint.save(file_prefix=checkpoint_prefix)

    X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
    n_points = X.shape[1]
    # Sample indices 0..n_points-1 (linspace(0, 16384, 16384) mis-spaced them).
    t = np.arange(n_points)
    # squeeze=False keeps `axes` 2-D even when n_samples == 1.
    fig, axes = plt.subplots(n_samples, 1, figsize=(15, 30 * n_samples / 16), squeeze=False)
    for i in range(n_samples):
        k = np.reshape(X[i, :, 0], (n_points,))
        display.display(display.Audio(k, rate=16000))
        axes[i, 0].plot(t, k)
    filename1 = os.path.join(out_dir, 'sound_at_epoch_%04d.png' % (step + 1))
    fig.savefig(filename1)
    plt.show()
    # This runs repeatedly during training; close the figure so they don't pile up.
    plt.close(fig)

    print('>Saved: %s' % (filename1))

In [None]:
def plot_history(d1_hist, d2_hist, g_hist, count=-1, out_dir='/content/gdrive/My Drive/WGAN'):
    """Plot the loss curves and persist the histories as JSON.

    The figure is saved as `plot_line_plot_loss_XXXX.png` (XXXX = `count`).
    Each history list is dumped to d1hist.txt / d2hist.txt / ghist.txt in
    `out_dir`, where the history-loading cell reads them back.

    Args:
        d1_hist: critic loss on real batches, one entry per training step.
        d2_hist: critic loss on fake batches, one entry per training step.
        g_hist: generator loss, one entry per training step.
        count: number embedded in the figure file name.
        out_dir: destination directory for the figure and the JSON files.
    """
    # Use an explicit figure so we never draw onto a stale current figure.
    fig, ax = plt.subplots()
    ax.plot(d1_hist, label='crit_real')
    ax.plot(d2_hist, label='crit_fake')
    ax.plot(g_hist, label='gen')
    ax.set(title='WGAN training losses', xlabel='training step', ylabel='loss')
    ax.legend()
    fig.savefig(os.path.join(out_dir, 'plot_line_plot_loss_%04d.png' % count))
    plt.show()
    plt.close(fig)

    for hist, fname in ((d1_hist, "d1hist.txt"), (d2_hist, "d2hist.txt"), (g_hist, "ghist.txt")):
        with open(os.path.join(out_dir, fname), "w") as fp:
            json.dump(hist, fp)

In [None]:
# Per-step loss histories: critic on real, critic on fake, generator.
c1_hist, c2_hist, g_hist = [], [], []

In [None]:
# Resume the loss histories saved by plot_history. Missing files (e.g. on a
# first run) yield empty lists instead of aborting the notebook.
HIST_DIR = "/content/gdrive/My Drive/WGAN"


def load_history(name, hist_dir=HIST_DIR):
    """Load a JSON-encoded history list from hist_dir/name, or [] if absent."""
    path = os.path.join(hist_dir, name)
    if not os.path.exists(path):
        print('No saved history at %s; starting empty.' % path)
        return []
    with open(path, "r") as fp:
        return json.load(fp)


c1_hist = load_history("d1hist.txt")
c2_hist = load_history("d2hist.txt")
g_hist = load_history("ghist.txt")

In [None]:
def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=5, n_batch=64, n_critic=5):
    """Run the WGAN training loop.

    Each step trains the critic `n_critic` times, each time on one real and
    one fake half-batch. It then trains the generator once through
    `gan_model`. Losses are appended to the global `c1_hist` / `c2_hist` /
    `g_hist` lists, so repeated calls extend the same history.

    Args:
        g_model: generator.
        c_model: compiled critic.
        gan_model: generator + frozen critic, compiled.
        dataset: iterator of real batches (consumed by generate_real_samples).
        latent_dim: size of the generator's latent input.
        n_epochs: number of steps = global `total` (batches per dataset pass)
            times n_epochs. Each step consumes n_critic real batches.
        n_batch: generator batch size; the critic uses half of it per update.
        n_critic: critic updates per generator update.
    """
    bat_per_epo = total
    n_steps = bat_per_epo * n_epochs
    half_batch = n_batch // 2
    # Summarise three times per "epoch". max() avoids a modulo-by-zero when
    # the dataset has fewer than three batches.
    eval_every = max(1, bat_per_epo // 3)
    for i in range(n_steps):
        c1_tmp, c2_tmp = [], []
        for _ in range(n_critic):
            X_real, y_real = generate_real_samples(dataset, half_batch)
            c1_tmp.append(c_model.train_on_batch(X_real, y_real))
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            c2_tmp.append(c_model.train_on_batch(X_fake, y_fake))
        c1_hist.append(mean(c1_tmp))
        c2_hist.append(mean(c2_tmp))
        # Generator update: fakes carry the "real" label (-1), so minimising
        # the loss raises the critic's score on generated audio.
        X_gan = generate_latent_points(latent_dim, n_batch)
        y_gan = -ones((n_batch, 1))
        g_loss = gan_model.train_on_batch(X_gan, y_gan)
        g_hist.append(g_loss)
        print('>%d, c1=%.3f, c2=%.3f g=%.3f' % (i + 1, c1_hist[-1], c2_hist[-1], g_loss))
        if (i + 1) % eval_every == 0:
            summarize_performance(len(c1_hist), g_model, gan_model, c_model, latent_dim)
            plot_history(c1_hist, c2_hist, g_hist, len(c1_hist))
    # Final loss plot.
    plot_history(c1_hist, c2_hist, g_hist, len(c1_hist))

In [None]:
# Resume from HDF5 snapshots written by an earlier run (step 33132).
latent_dim = 100
custom_objects_map = {'ClipConstraint': ClipConstraint, 'wasserstein_loss': wasserstein_loss}
critic = keras.models.load_model('/content/gdrive/My Drive/WGAN/cmodel_33132.h5', custom_objects=custom_objects_map)
generator = keras.models.load_model('/content/gdrive/My Drive/WGAN/gmodel_33132.h5')
gan_model = keras.models.load_model('/content/gdrive/My Drive/WGAN/ganmodel_33132.h5', custom_objects=custom_objects_map)
# NOTE(review): each load_model call builds an independent model. The critic
# and generator nested inside gan_model therefore do not share weights with
# `critic` / `generator` here. Training `critic` will not change the critic
# that scores the generator inside gan_model. Confirm this is intended, or
# rebuild gan_model with define_gan(generator, critic).

# summary() prints by itself; wrapping it in print() only echoed "None".
critic.summary()
gan_model.summary()

In [None]:
# The generator is the first layer of the Sequential GAN (define_gan adds it
# first). Index it instead of using the auto-generated name "sequential_4",
# which depends on how many models were created earlier in the session.
generator = gan_model.layers[0]
print(generator)

In [None]:
# Play and plot 16 clips from the current generator, then save the figure.
# `step` only tags the output file name.
step = -3
n_show = 16
X, _ = generate_fake_samples(generator, latent_dim, n_show)
# Sample indices 0..N-1 (linspace(0, 16384, 16384) mis-spaced them).
t = np.arange(X.shape[1])
fig, axes = plt.subplots(n_show, 1, figsize=(15, 30))
for i, ax in enumerate(axes):
    k = X[i, :, 0]
    display.display(display.Audio(k, rate=16000))
    ax.plot(t, k)
filename1 = '/content/gdrive/My Drive/WGAN/sound_at_epoch_%04d.png' % (step + 1)
fig.savefig(filename1)
plt.show()
plt.close(fig)

In [None]:
# Continue training the current generator/critic/gan_model on the shared batch iterator.
train(g_model=generator, c_model=critic, gan_model=gan_model, dataset=db_iter, latent_dim=latent_dim)

In [None]:
# Build a fresh WGAN: 64 base filters, mono audio, 100-dimensional latent space.
latent_dim = 100
critic = define_critic(d=64, c=1)
generator = define_generator(latent_dim=latent_dim, d=64, c=1)
gan_model = define_gan(generator=generator, critic=critic)

In [None]:
# Track all three models in one checkpoint stored in the Drive folder.
checkpoint_dir = '/content/gdrive/My Drive/WGAN'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(
    generator=generator,
    critic=critic,
    gan_model=gan_model,
)

In [None]:
# Restore the newest checkpoint in checkpoint_dir. latest_checkpoint returns
# None when there is none, in which case the freshly built weights are kept.
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_ckpt)

In [None]:
# Write a new numbered checkpoint (ckpt-N) alongside the earlier ones.
checkpoint.save(checkpoint_prefix)

In [None]:
# Export the generator on its own so it can be reloaded with keras.models.load_model.
generator_export_path = "/content/gdrive/My Drive/WGAN/bmodel"
generator.save(generator_export_path)

In [None]:
# Reload the exported generator (bmodel) as `model`.
bmodel_path = "/content/gdrive/My Drive/WGAN/bmodel"
model = keras.models.load_model(bmodel_path)

In [None]:
# Train the current models for five passes' worth of steps.
train(g_model=generator, c_model=critic, gan_model=gan_model,
      dataset=db_iter, latent_dim=latent_dim, n_epochs=5)

In [None]:
# Arguments were shifted: the call passed critic as gan_model, 100 as c_model
# and 16 as latent_dim, so the generator was fed 16-d noise.
summarize_performance(9038, generator, gan_model, critic, latent_dim, n_samples=16)

In [None]:
# Re-plot and re-save the accumulated loss histories (file tag 0001).
plot_history(c1_hist, c2_hist, g_hist, count=1)