## Introduction

This is a brief introduction to music generation with **Generative Adversarial Networks** (**GAN**s), modeled on AWS DeepComposer.

The goal of our tutorial is to train a machine learning model on a dataset of Bach compositions so that the model learns to add accompaniments to a single-track input melody.

**What is a GAN?**

A GAN consists of two competing networks: a generator and a discriminator. The generator is a deep neural network that learns to create new synthetic data resembling the distribution of the dataset it was trained on. The discriminator is a deep neural network trained to distinguish real data from synthetic data. The two networks are trained in alternating cycles: the generator learns to produce increasingly realistic data, while the discriminator gets better at telling real data (Bach music) apart from synthetic data.
As a result, the music produced by the generator becomes more realistic over time.

## Dependencies
First, let's install and import the Python packages we will use throughout the tutorial.


In [None]:

# Create the environment
import subprocess
print("Please wait while the required packages are installed...")
ret = subprocess.call('./requirements.sh', shell=True)
if ret == 0:
    print("All the required packages were installed successfully.")
else:
    print("requirements.sh exited with code {}; check its output above.".format(ret))

In [None]:
# IMPORTS
import os

# Make a subset of GPUs visible to the notebook. These variables must be set
# before TensorFlow initializes the GPU, so set them before importing it.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
from PIL import Image
import logging
import pypianoroll
import scipy.stats
import pickle
import music21
from IPython import display
import matplotlib.pyplot as plt

# Configure TensorFlow (this notebook uses the TensorFlow 1.x API,
# e.g. tf.logging and tf.enable_eager_execution)
import tensorflow as tf
print(tf.__version__)
tf.logging.set_verbosity(tf.logging.ERROR)
tf.enable_eager_execution()

# Utils library for plotting, loading and saving MIDI, among other functions
from utils import display_utils, metrics_utils, path_utils, inference_utils, midi_utils

LOGGER = logging.getLogger("gan.train")
%matplotlib inline

## Configuration

Here we configure paths to retrieve our dataset and save our experiments.

In [None]:
root_dir = './Experiments'

# Root directory for this experiment
model_dir = os.path.join(root_dir,'2Bar')    # JSP: 229, Bach: 19199

# Directory to save piano roll plots during training
train_dir = os.path.join(model_dir, 'train')

# Directory to save checkpoints generated during training
check_dir = os.path.join(model_dir, 'preload')

# Directory to save MIDI samples during training
sample_dir = os.path.join(model_dir, 'sample')

# Directory to save samples generated during inference
eval_dir = os.path.join(model_dir, 'eval')

os.makedirs(train_dir, exist_ok=True)
os.makedirs(check_dir, exist_ok=True)
os.makedirs(eval_dir, exist_ok=True)
os.makedirs(sample_dir, exist_ok=True)


## Data Preparation

### Dataset summary

In this tutorial, we use the [`JSB-Chorales-dataset`](http://www-etud.iro.umontreal.ca/~boulanni/icml2012), comprising 229 chorale snippets. A chorale is a hymn that is usually sung with a single voice carrying a simple melody and three lower voices providing harmony. In this dataset, these voices are represented by four piano tracks.

Let's listen to a song from this dataset.

In [None]:
display_utils.playmidi('./original_midi/MIDI-0.mid')

### Data format - piano roll

For the purpose of this tutorial, we represent music from the JSB-Chorales dataset in the piano roll format.

**Piano roll** is a discrete representation of music that many machine learning algorithms can consume directly. A piano roll can be viewed as a two-dimensional grid with "Time" on the horizontal axis and "Pitch" on the vertical axis. A one or zero in a cell indicates whether a note at that pitch is played at that time step.
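
To make the grid concrete, here is a minimal NumPy sketch that builds a piano roll from a list of note events. The `(pitch, start_step, end_step)` event format and the `notes_to_pianoroll` helper are our own illustrative assumptions, not part of the tutorial's `utils`. Note that the `prepare_dataset` function later in this notebook marks its data as taking values in {-1, 1} rather than {0, 1}, so the sketch also shows that mapping.

```python
import numpy as np


def notes_to_pianoroll(notes, n_steps=32, n_pitches=128):
    """Build a (time, pitch) piano roll from (midi_pitch, start_step, end_step) events.

    end_step is exclusive. Cells are 1.0 where a note sounds and 0.0 elsewhere.
    (Hypothetical helper for illustration only.)
    """
    roll = np.zeros((n_steps, n_pitches), dtype=np.float32)
    for pitch, start, end in notes:
        roll[start:end, pitch] = 1.0
    return roll


# Hypothetical melody: C4 (MIDI 60) for one beat (4 steps), then E4 (MIDI 64) for two beats
roll = notes_to_pianoroll([(60, 0, 4), (64, 4, 12)])
print(roll.shape)                  # (32, 128)
print(np.flatnonzero(roll[:, 60]))  # [0 1 2 3]

# Map {0, 1} to {-1, 1}, the value range noted in prepare_dataset
roll_pm1 = 2.0 * roll - 1.0
```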


**Why 32 time steps?**

For the purpose of this tutorial, we sample two non-empty bars from each song in the JSB-Chorales dataset. A **bar** (or **measure**) is a unit of composition; every song in our dataset is in 4/4 time, so each bar contains four beats.

We’ve found that using a resolution of four time steps per beat captures enough of the musical detail in this dataset.

This gives...

$$ \frac{4\ \text{time steps}}{1\ \text{beat}} \times \frac{4\ \text{beats}}{1\ \text{bar}} \times 2\ \text{bars} = 32\ \text{time steps} $$
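
The same arithmetic tells us where each of the 32 time steps falls musically. This sketch (the constant names and the `step_position` helper are our own) maps a time-step index to its bar, beat, and sub-beat position, and shows that a `(32, 128)` piano roll can be viewed as `(bars, beats, steps_per_beat, pitches)` with a plain row-major reshape.

```python
import numpy as np

STEPS_PER_BEAT = 4
BEATS_PER_BAR = 4   # 4/4 time
N_BARS = 2
N_STEPS = STEPS_PER_BEAT * BEATS_PER_BAR * N_BARS  # 32


def step_position(t):
    """Return (bar, beat, step_within_beat) for time step t, all 0-based."""
    bar, rest = divmod(t, STEPS_PER_BEAT * BEATS_PER_BAR)
    beat, sub = divmod(rest, STEPS_PER_BEAT)
    return bar, beat, sub


# View a (32, 128) piano roll as (bars, beats, steps per beat, pitches)
roll = np.zeros((N_STEPS, 128))
structured = roll.reshape(N_BARS, BEATS_PER_BAR, STEPS_PER_BEAT, 128)

print(N_STEPS, structured.shape)  # 32 (2, 4, 4, 128)
print(step_position(21))          # (1, 1, 1)
```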

Let us now load our dataset as a NumPy array. Our dataset comprises 229 samples of 4 tracks (all tracks are piano). Each sample is a 32-time-step snippet of a song, so our dataset has the shape `(num_samples, time_steps, pitch_range, tracks) = (229, 32, 128, 4)`.

In [None]:
training_data = np.load('./dataset/train.npy')
print(training_data.shape)

Let's see a sample of the data we'll feed into our model. The four graphs represent the four tracks.

In [None]:
display_utils.show_pianoroll(training_data)

### Load data 

We now create a TensorFlow dataset object from our NumPy array to feed into our model. The dataset object feeds batches of data into our model. A batch is a subset of the data that is passed through the network before the weights are updated.

In [None]:
# Number of input data samples in a batch
BATCH_SIZE = 64

# Shuffle buffer size for shuffling data
SHUFFLE_BUFFER_SIZE = 1000

# Prefetch PREFETCH_SIZE batches so that there is no idle time between batches
PREFETCH_SIZE = 4

In [None]:
def prepare_dataset(filename):
    """Load the training samples and build a shuffled, repeated, batched dataset."""
    data = np.load(filename)
    # Piano-roll values are in {-1, 1}, matching the generator's tanh output range
    data = np.asarray(data, dtype=np.float32)

    print('data shape = {}'.format(data.shape))

    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.shuffle(SHUFFLE_BUFFER_SIZE).repeat()
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(PREFETCH_SIZE)

    return dataset

dataset = prepare_dataset('./dataset/train.npy')
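
To see what this pipeline yields, here is a rough NumPy model of `shuffle(...).repeat().batch(..., drop_remainder=True)`; `prefetch` only affects timing, so it is omitted. This is an approximation for intuition, not how `tf.data` is implemented. Because `SHUFFLE_BUFFER_SIZE` (1000) exceeds the 229 samples, the whole dataset fits in the shuffle buffer, so we model each epoch as a fresh uniform permutation. The generator name `shuffled_batches` is our own.

```python
import numpy as np


def shuffled_batches(data, batch_size, seed=0):
    """Approximate shuffle(buffer >= len(data)).repeat().batch(batch_size, drop_remainder=True).

    Every pass over `data` is a fresh random permutation. Because repeat() comes
    before batch(), batches are cut from one continuous stream, so a batch can mix
    the end of one epoch with the start of the next and no sample is ever dropped.
    """
    rng = np.random.default_rng(seed)
    buffer = []
    while True:                                   # repeat() with no count: infinite
        for idx in rng.permutation(len(data)):    # one shuffled epoch
            buffer.append(data[idx])
            if len(buffer) == batch_size:
                yield np.stack(buffer)
                buffer = []


batches = shuffled_batches(np.zeros((229, 32, 128, 4), dtype=np.float32), batch_size=64)
print(next(batches).shape)  # (64, 32, 128, 4)
```

With 229 samples and a batch size of 64, the fourth batch holds the last 37 samples of the first epoch plus 27 from the second, which is why `drop_remainder=True` never actually discards data here.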

## Model architecture

The model consists of two networks, a generator and a discriminator (in a WGAN, the discriminator is also called the **critic**). Training proceeds as follows:

* Generator:
    1. The generator takes in a batch of single-track piano rolls (melodies) as input and generates a batch of multi-track piano rolls as output by adding accompaniments to each input melody.
    2. The discriminator then scores these generated piano rolls, estimating how far they deviate from the real data in the training dataset.
    3. The generator uses this feedback from the discriminator to update its weights.
* Discriminator: As the generator gets better at creating music accompaniments, the discriminator needs to be retrained as well.
    1. Train the discriminator with the piano rolls just produced by the generator as fake inputs and an equal number of songs from the original dataset as real inputs.
* Alternate between training these two networks, starting with the critic on the first iteration, until the model converges and produces realistic music.

We use a special type of GAN called the **Wasserstein GAN with Gradient Penalty** (or **WGAN-GP**) to generate music.

### Generator

The generator is adapted from the U-Net architecture. It consists of an encoder that maps the single-track music data (represented as piano roll images) to a lower-dimensional latent space, and a decoder that maps the latent space back to multi-track music data.

Here are the inputs provided to the generator:

**Single-track piano roll input**: A single melody track of shape (32, 128, 1) => (TimeStep, NumPitches, NumTracks) is provided as input to the generator.

**Latent noise vector**: A latent noise tensor z of shape (2, 8, 512) is also passed in. It is concatenated with the encoder's output at the bottleneck and gives each generated output a distinct character, even when the same melody is provided.

Notice from the figure below that the encoding layers of the generator on the left side and the decoding layers on the right side are connected to form a U shape, which gives the U-Net architecture its name.

In this implementation, we build the generator as a simple four-level U-Net: `_conv2d` blocks form the contracting (downsampling) path, and `_deconv2d` blocks form the expansive (upsampling) path, each concatenating its output with the matching encoder output through a skip connection.
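
The two generator inputs can be prepared from a training batch as follows. This NumPy sketch uses random stand-in data in place of a real batch, and draws z from a plain normal distribution for simplicity (the training code uses `tf.random.truncated_normal`). The slicing mirrors `tf.expand_dims(x[..., condition_track_idx], -1)` in the training steps.

```python
import numpy as np

BATCH_SIZE = 64
rng = np.random.default_rng(0)

# Random stand-in for one training batch: (batch, time_steps, pitches, tracks), values in {-1, 1}
batch = rng.choice([-1.0, 1.0], size=(BATCH_SIZE, 32, 128, 4)).astype(np.float32)

# Condition input: keep the melody track and restore a trailing channel axis,
# (64, 32, 128) -> (64, 32, 128, 1)
condition_track_idx = 0
c = np.expand_dims(batch[..., condition_track_idx], -1)

# Latent input: one (2, 8, 512) tensor per sample (plain normal here for simplicity)
z = rng.standard_normal((BATCH_SIZE, 2, 8, 512)).astype(np.float32)

print(c.shape, z.shape)  # (64, 32, 128, 1) (64, 2, 8, 512)
```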

In [None]:
def _conv2d(layer_input, filters, f_size=4, bn=True):
    """Generator Basic Downsampling Block"""
    d = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=2,
                               padding='same')(layer_input)
    d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)
    if bn:
        d = tf.keras.layers.BatchNormalization(momentum=0.8)(d)
    return d


def _deconv2d(layer_input, pre_input, filters, f_size=4, dropout_rate=0):
    """Generator Basic Upsampling Block"""
    u = tf.keras.layers.UpSampling2D(size=2)(layer_input)
    u = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=1,
                               padding='same')(u)
    u = tf.keras.layers.BatchNormalization(momentum=0.8)(u)
    u = tf.keras.layers.ReLU()(u)

    if dropout_rate:
        u = tf.keras.layers.Dropout(dropout_rate)(u)
        
    u = tf.keras.layers.Concatenate()([u, pre_input])
    return u

    
def build_generator(condition_input_shape=(32, 128, 1), filters=64,
                    instruments=4, latent_shape=(2, 8, 512)):
    """Build the generator."""
    c_input = tf.keras.layers.Input(shape=condition_input_shape)
    z_input = tf.keras.layers.Input(shape=latent_shape)

    d1 = _conv2d(c_input, filters, bn=False)
    d2 = _conv2d(d1, filters * 2)
    d3 = _conv2d(d2, filters * 4)
    d4 = _conv2d(d3, filters * 8)

    d4 = tf.keras.layers.Concatenate(axis=-1)([d4, z_input])

    u4 = _deconv2d(d4, d3, filters * 4)
    u5 = _deconv2d(u4, d2, filters * 2)
    u6 = _deconv2d(u5, d1, filters)

    u7 = tf.keras.layers.UpSampling2D(size=2)(u6)
    output = tf.keras.layers.Conv2D(instruments, kernel_size=4, strides=1,
                               padding='same', activation='tanh')(u7)  # 32, 128, 4

    generator = tf.keras.models.Model([c_input, z_input], output, name='Generator')

    return generator

In [None]:
# Models
generator = build_generator()
generator.summary()
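
The output shapes reported by `generator.summary()` can be traced by hand. This pure-Python sketch assumes Keras `'same'` padding (output size = ceil(input / stride)); the function names are our own. It also shows why z has spatial shape (2, 8): concatenating on the channel axis requires z to match the bottleneck `d4` in height and width.

```python
def same_conv_size(size, stride=2):
    """Output length of a 'same'-padded convolution: ceil(size / stride)."""
    return -(-size // stride)


def trace_generator(h=32, w=128, filters=64, instruments=4, latent_channels=512):
    """Return {layer_name: (height, width, channels)} for the U-Net built above."""
    shapes = {}
    enc_channels = []
    # Encoder: four stride-2 _conv2d blocks with 64, 128, 256, 512 filters
    for i, f in enumerate([filters, filters * 2, filters * 4, filters * 8], start=1):
        h, w = same_conv_size(h), same_conv_size(w)
        shapes["d%d" % i] = (h, w, f)
        enc_channels.append(f)
    # Bottleneck: z is concatenated on the channel axis
    shapes["d4+z"] = (h, w, enc_channels[3] + latent_channels)
    # Decoder: each _deconv2d upsamples x2, convolves to f filters, then concatenates a skip
    for name, f, skip in [("u4", filters * 4, enc_channels[2]),
                          ("u5", filters * 2, enc_channels[1]),
                          ("u6", filters, enc_channels[0])]:
        h, w = h * 2, w * 2
        shapes[name] = (h, w, f + skip)
    # Final UpSampling2D + Conv2D(instruments, activation='tanh')
    shapes["output"] = (h * 2, w * 2, instruments)
    return shapes


for name, shape in trace_generator().items():
    print(name, shape)
```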

### Discriminator

The goal of the discriminator is to give the generator feedback about how realistic the generated piano rolls are, so that the generator can learn to produce more realistic data. The discriminator provides this feedback by outputting a scalar score for each piano roll: higher scores mean more real, lower scores mean more fake. Because this is a WGAN critic, the score is an unbounded real number rather than a probability (there is no sigmoid on the final dense layer).
We use a simple architecture for the discriminator: four convolutional layers followed by a dense layer.

In [None]:
def _build_discriminator_layer(layer_input, filters, f_size=4):
    """
    Convolution block that halves the spatial resolution (channels-last layout):

        input:  [batch_size, H, W, in_channels]
        output: [batch_size, H/2, W/2, out_channels]
    """
    d = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=2,
                               padding='same')(layer_input)
    # Critic does not use batch-norm
    d = tf.keras.layers.LeakyReLU(alpha=0.2)(d) 
    return d


def build_discriminator(pianoroll_shape=(32, 128, 4), filters=64):
    """WGAN discriminator(critic)."""
    
    # The critic sees the piano roll together with the melody it was conditioned on
    condition_input_shape = (32, 128, 1)
    groundtruth_pianoroll = tf.keras.layers.Input(shape=pianoroll_shape)
    condition_input = tf.keras.layers.Input(shape=condition_input_shape)
    combined_imgs = tf.keras.layers.Concatenate(axis=-1)([groundtruth_pianoroll, condition_input])

    d1 = _build_discriminator_layer(combined_imgs, filters)
    d2 = _build_discriminator_layer(d1, filters * 2)
    d3 = _build_discriminator_layer(d2, filters * 4)
    d4 = _build_discriminator_layer(d3, filters * 8)

    x = tf.keras.layers.Flatten()(d4)
    logit = tf.keras.layers.Dense(1)(x)

    discriminator = tf.keras.models.Model([groundtruth_pianoroll, condition_input], logit,
                                          name='Critic')
    return discriminator

In [None]:
# Create the Discriminator

discriminator = build_discriminator()
discriminator.summary() # View discriminator architecture.
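
As a sanity check on `discriminator.summary()`, this pure-Python sketch traces the critic's shapes and counts its trainable parameters. It assumes default `Conv2D` and `Dense` settings (biases enabled) and `'same'` padding; under those assumptions the total should match the one the summary reports. The helper names are our own.

```python
def conv_params(kernel, c_in, c_out):
    """Conv2D parameters: kernel*kernel*c_in*c_out weights plus c_out biases."""
    return kernel * kernel * c_in * c_out + c_out


def trace_critic(h=32, w=128, tracks=4, condition_tracks=1, filters=64, kernel=4):
    """Return (final feature-map shape, flattened size, total trainable parameters)."""
    c_in = tracks + condition_tracks          # piano roll concatenated with the melody
    total = 0
    for f in [filters, filters * 2, filters * 4, filters * 8]:
        total += conv_params(kernel, c_in, f)
        h, w, c_in = -(-h // 2), -(-w // 2), f   # stride 2, 'same' padding
    flat = h * w * c_in
    total += flat + 1                         # Dense(1): one weight per feature plus a bias
    return (h, w, c_in), flat, total


print(trace_critic())  # ((2, 8, 512), 8192, 2766785)
```

Almost 80% of these parameters sit in the last convolution (512 filters over 256 input channels), which is typical of this doubling-filters design.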

## Training

We train our models by searching for model parameters which optimize an objective function. For our WGAN-GP, we have special loss functions that we minimize as we alternate between training our generator and critic networks:

*Generator Loss:*
* We use the Wasserstein (generator) loss, which is the negative of the critic's mean score on generated piano rolls. Minimizing it pushes the generator to produce piano rolls that the critic scores as real.
    * $\frac{1}{m} \sum_{i=1}^{m} -D_w(G(z^{i}|c^{i})|c^{i})$

*Critic Loss:*

* We begin with the Wasserstein (critic) loss. Minimizing it maximizes the gap between the critic's scores on real and generated (fake) piano rolls, which is the critic's estimate of the Wasserstein distance between the two distributions.
    * $\frac{1}{m} \sum_{i=1}^{m} [D_w(G(z^{i}|c^{i})|c^{i}) - D_w(x^{i}|c^{i})]$

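* We add a gradient penalty term that penalizes the critic whenever the norm of its gradient with respect to its input deviates from 1. This softly enforces the 1-Lipschitz constraint that the Wasserstein formulation requires, which keeps the critic's feedback well behaved and makes optimization of the generator easier. The gradient is evaluated at random interpolates $\hat{x}^i = \epsilon^i x^i + (1 - \epsilon^i)\,G(z^{i}|c^{i})$ with $\epsilon^i \sim U[0, 1]$, and the term is weighted by $\lambda$ ($\lambda = 10$ in our code):
    * $\lambda \frac{1}{m} \sum_{i=1}^{m}(\lVert \nabla_{\hat{x}^i}D_w(\hat{x}^i|c^{i}) \rVert_2 - 1)^2$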
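
The three terms can be checked numerically on toy data. This NumPy sketch assumes a *linear* critic $D(x) = \langle w, x \rangle$, chosen only because its input gradient is $w$ everywhere, so the gradient penalty has a closed form. The real critic is a deep network, so `compute_gradient_penalty` below uses `tf.GradientTape` instead. All function names here are our own.

```python
import numpy as np


def critic_wasserstein_loss(real_scores, fake_scores):
    """(1/m) sum D(G(z|c)|c) - D(x|c)."""
    return np.mean(fake_scores) - np.mean(real_scores)


def generator_wasserstein_loss(fake_scores):
    """(1/m) sum -D(G(z|c)|c)."""
    return -np.mean(fake_scores)


def interpolate(real, fake, eps):
    """x_hat = eps * x + (1 - eps) * G(z|c), with one eps per sample."""
    eps = eps.reshape((-1,) + (1,) * (real.ndim - 1))
    return eps * real + (1.0 - eps) * fake


def linear_critic_penalty(w, x_hat):
    """Gradient penalty for the toy critic D(x) = <w, x>.

    The gradient of a linear function is w at every point, so every x_hat
    has gradient norm ||w||_2; a real critic needs autodiff here instead.
    """
    grads = np.broadcast_to(w, x_hat.shape).reshape(len(x_hat), -1)
    norms = np.sqrt(np.sum(grads ** 2, axis=1))
    return np.mean((norms - 1.0) ** 2)


rng = np.random.default_rng(0)
real = rng.normal(1.0, 0.1, size=(8, 3))      # toy "real" samples
fake = rng.normal(-1.0, 0.1, size=(8, 3))     # toy "generated" samples
w = np.array([2.0, 0.0, 0.0])                 # ||w||_2 = 2, so the penalty is (2 - 1)^2 = 1
x_hat = interpolate(real, fake, rng.uniform(size=8))

LAMBDA = 10.0                                 # same weight as gp_loss *= 10.0
total_critic_loss = (critic_wasserstein_loss(real @ w, fake @ w)
                     + LAMBDA * linear_critic_penalty(w, x_hat))
print(total_critic_loss)
```

Here the critic separates real from fake well (a negative Wasserstein term), but its gradient norm of 2 violates the target of 1, and the penalty term makes it pay for that.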

In [None]:
# Define the different loss functions

def generator_loss(critic_fake_output):
    """ Wasserstein GAN loss
    (Generator)  -D(G(z|c))
    """
    return -tf.reduce_mean(critic_fake_output)


def wasserstein_loss(critic_real_output, critic_fake_output):
    """ Wasserstein GAN loss
    (Critic)  D(G(z|c)) - D(x|c)
    """
    return tf.reduce_mean(critic_fake_output) - tf.reduce_mean(
        critic_real_output)


def compute_gradient_penalty(critic, x, fake_x):
    """Mean over the batch of (||grad_x_hat critic(x_hat | c)||_2 - 1)^2."""
    # The condition (melody) track is track 0 of the real piano roll
    c = tf.expand_dims(x[..., 0], -1)
    batch_size = x.get_shape().as_list()[0]
    # One interpolation weight per sample, broadcast over the remaining axes
    eps_x = tf.random.uniform(
        [batch_size] + [1] * (len(x.get_shape()) - 1))  # B, 1, 1, 1
    inter = eps_x * x + (1.0 - eps_x) * fake_x

    with tf.GradientTape() as g:
        g.watch(inter)
        disc_inter_output = critic((inter, c), training=True)
    grads = g.gradient(disc_inter_output, inter)
    # Per-sample L2 norm of the gradient (the small constant avoids an infinite sqrt gradient at 0)
    slopes = tf.sqrt(1e-8 + tf.reduce_sum(
        tf.square(grads),
        axis=tf.range(1, grads.get_shape().ndims)))
    gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0))

    return gradient_penalty


With our loss functions defined, we associate them with TensorFlow optimizers to define how our model will search for a good set of model parameters. We use the *Adam* algorithm, a commonly used general-purpose optimizer.
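
For intuition about the `beta_1` and `beta_2` arguments, here is the textbook Adam update (Kingma and Ba) in plain NumPy, applied to a one-dimensional quadratic. It is a sketch for illustration: Keras's implementation may differ in small details such as where epsilon is applied, and the toy objective and `lr=0.1` are our own choices so that it converges quickly.

```python
import numpy as np


def adam_minimize(grad_fn, x0, lr=1e-3, beta_1=0.5, beta_2=0.9, eps=1e-7, steps=1000):
    """Plain-NumPy Adam: bias-corrected moving averages of the gradient and its square."""
    x = np.asarray(x0, dtype=np.float64)
    m = np.zeros_like(x)   # first moment: running mean of gradients (decay beta_1)
    v = np.zeros_like(x)   # second moment: running mean of squared gradients (decay beta_2)
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta_1 * m + (1 - beta_1) * g
        v = beta_2 * v + (1 - beta_2) * g ** 2
        m_hat = m / (1 - beta_1 ** t)          # correct the bias toward zero at start
        v_hat = v / (1 - beta_2 ** t)
        x = x - lr * m_hat / (np.sqrt(v_hat) + eps)
    return x


# Minimize f(x) = (x - 3)^2, whose gradient is 2 (x - 3)
x_opt = adam_minimize(lambda x: 2 * (x - 3.0), x0=0.0, lr=0.1, steps=500)
print(x_opt)
```

Because the update divides by the root of the second moment, the very first step has size of roughly `lr` regardless of the gradient's magnitude. A low `beta_1` such as 0.5 gives the momentum term a short memory, which is a common choice for GAN training.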

In [None]:
# Setup Adam optimizers for both G and D
generator_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5, beta_2=0.9)
critic_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5, beta_2=0.9)

# We define our checkpoint directory and where to save trained checkpoints
ckpt = tf.train.Checkpoint(generator=generator,
                           generator_optimizer=generator_optimizer,
                           critic=discriminator,
                           critic_optimizer=critic_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, check_dir, max_to_keep=5)

Now we define the `generator_train_step` and `critic_train_step` functions. Each runs one optimization step on a batch (forward pass, gradient computation, and weight update) and returns the corresponding loss.

In [None]:
@tf.function
def generator_train_step(x, condition_track_idx=0):

    ############################################
    #(1) Update G network: maximize D(G(z|c))
    ############################################

    # Extract condition track to make real batches pianoroll
    c = tf.expand_dims(x[..., condition_track_idx], -1)

    # Generate batch of latent vectors
    z = tf.random.truncated_normal([BATCH_SIZE, 2, 8, 512])

    with tf.GradientTape() as tape:
        fake_x = generator((c, z), training=True)
        fake_output = discriminator((fake_x,c), training=False)

        # Calculate Generator's loss based on this generated output
        gen_loss = generator_loss(fake_output)

    # Calculate gradients for Generator
    gradients_of_generator = tape.gradient(gen_loss,
                                           generator.trainable_variables)
    # Update Generator
    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables))

    return gen_loss


In [None]:
@tf.function
def critic_train_step(x, condition_track_idx=0):

    ############################################################################
    #(2) Update D network: minimize D(G(z|c)|c) - D(x|c) + 10 * GradientPenalty()
    ############################################################################

    # Extract condition track to make real batches pianoroll
    c = tf.expand_dims(x[..., condition_track_idx], -1)

    # Generate batch of latent vectors
    z = tf.random.truncated_normal([BATCH_SIZE, 2, 8, 512])

    # Generated fake pianoroll
    fake_x = generator((c, z), training=False)


    # Update critic parameters
    with tf.GradientTape() as tape:
        real_output = discriminator((x,c), training=True)
        fake_output = discriminator((fake_x,c), training=True)
        critic_loss =  wasserstein_loss(real_output, fake_output)

    # Calculate the gradients from the real and fake batches
    grads_of_critic = tape.gradient(critic_loss,
                                    discriminator.trainable_variables)

    with tf.GradientTape() as tape:
        gp_loss = compute_gradient_penalty(discriminator, x, fake_x)
        gp_loss *= 10.0

    # Calculate the gradient-penalty gradients and add them to the Wasserstein
    # gradients, keeping each gradient aligned with its variable
    grads_gp = tape.gradient(gp_loss, discriminator.trainable_variables)
    gradients_of_critic = [g if ggp is None else g + ggp
                           for g, ggp in zip(grads_of_critic, grads_gp)]

    # Update Critic
    critic_optimizer.apply_gradients(
        zip(gradients_of_critic, discriminator.trainable_variables))

    return critic_loss + gp_loss


Here we log the losses and metrics which we can use to determine when to stop training. 

In [None]:
# We use load_melody_samples() to load 10 input data samples from our dataset into sample_x 
# and 10 random noise latent vectors into sample_z
sample_x, sample_z = inference_utils.load_melody_samples(n_sample=10)

In [None]:
# Number of iterations to train for
iterations = 1000

# Update critic n times per generator update 
n_dis_updates_per_gen_update = 5

# Determine input track in sample_x that we condition on
condition_track_idx = 0 
sample_c = tf.expand_dims(sample_x[..., condition_track_idx], -1)

Let us now train our model!

In [None]:
# Clear out any old metrics we've collected
metrics_utils.metrics_manager.initialize()

# Keep a running list of various quantities:
c_losses = []
g_losses = []

# Data iterator to iterate over our dataset
it = iter(dataset)

for iteration in range(iterations):

    # Train critic
    for _ in range(n_dis_updates_per_gen_update):
        c_loss = critic_train_step(next(it))

    # Train generator
    g_loss = generator_train_step(next(it))

    # Save Losses for plotting later
    c_losses.append(c_loss)
    g_losses.append(g_loss)

    display.clear_output(wait=True)
    fig = plt.figure(figsize=(15, 5))
    line1, = plt.plot(range(iteration+1), c_losses, 'r')
    line2, = plt.plot(range(iteration+1), g_losses, 'k')
    plt.xlabel('Iterations')
    plt.ylabel('Losses')
    plt.legend((line1, line2), ('C-loss', 'G-loss'))
    display.display(fig)
    plt.close(fig)
    
    # Output training stats
    print('Iteration {}, c_loss={:.2f}, g_loss={:.2f}'.format(iteration, c_loss, g_loss))
    
    # Save checkpoints, music metrics, generated output
    if iteration < 100 or iteration % 50 == 0 :
        # Check how the generator is doing by saving G's samples on fixed_noise
        fake_sample_x = generator((sample_c, sample_z), training=False)
        metrics_utils.metrics_manager.append_metrics_for_iteration(fake_sample_x.numpy(), iteration)

        if iteration % 50 == 0:
            # Save the checkpoint to disk.
            ckpt_manager.save(checkpoint_number=iteration) 
        
            fake_sample_x = fake_sample_x.numpy()
    
            # plot the pianoroll
            display_utils.plot_pianoroll(iteration, sample_x[:4], fake_sample_x[:4], save_dir=train_dir)

            # generate the midi
            destination_path = path_utils.generated_midi_path_for_iteration(iteration, saveto_dir=sample_dir)
            midi_utils.save_pianoroll_as_midi(fake_sample_x[:4], destination_path=destination_path)
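
It helps to know exactly what the loop above leaves on disk. This pure-Python sketch replays its two `if` conditions for 1,000 iterations. It assumes `check_dir` starts empty; `CheckpointManager(max_to_keep=5)` deletes older checkpoints as new ones are saved, so only the last five survive, and the final checkpoint comes from iteration 950 rather than 999.

```python
iterations = 1000

# Same conditions as in the training loop above
metric_iters = [i for i in range(iterations) if i < 100 or i % 50 == 0]
checkpoint_iters = [i for i in range(iterations) if i % 50 == 0]

# With max_to_keep=5, only the five most recent checkpoints remain
kept_on_disk = checkpoint_iters[-5:]

print(len(metric_iters))      # 118
print(len(checkpoint_iters))  # 20
print(kept_on_disk)           # [750, 800, 850, 900, 950]
```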


### We have started training

When using the Wasserstein loss, the critic should be trained close to convergence so that the gradients it provides for the generator update are accurate.

With WGANs, we can simply train the critic several times between generator updates to keep it close to convergence. A typical ratio is five critic updates per generator update, which is the value of `n_dis_updates_per_gen_update` used above.
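
The 5:1 ratio also determines how much data training consumes. Each pass through the loop calls `next(it)` once per critic step and once more for the generator step, so the two networks train on different batches. A quick back-of-the-envelope calculation with the settings used above:

```python
BATCH_SIZE = 64
NUM_SAMPLES = 229
iterations = 1000
n_dis_updates_per_gen_update = 5

# One batch per critic step plus one batch for the generator step
batches_per_iteration = n_dis_updates_per_gen_update + 1
total_batches = iterations * batches_per_iteration
samples_seen = total_batches * BATCH_SIZE
epochs = samples_seen / NUM_SAMPLES

print(total_batches, samples_seen, round(epochs, 1))  # 6000 384000 1676.9
```

In other words, 1,000 iterations mean 5,000 critic updates, 1,000 generator updates, and roughly 1,677 passes over the 229 chorale snippets.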


**How do I know when to stop?**
- The samples meet your expectations.
- The critic loss is no longer improving.
- The expected values of the musical quality metrics converge to the expected values of the same metrics on the training data.
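
The third criterion can be made mechanical. The function below is a hypothetical stopping check, not part of `metrics_utils`; the relative tolerance and window size are arbitrary illustrative choices you would tune for each metric.

```python
def has_converged(history, reference, rel_tol=0.05, window=3):
    """Hypothetical stopping check for one metric.

    Returns True if each of the last `window` recorded values lies within
    rel_tol (relative) of the metric's training-set reference value.
    """
    if len(history) < window:
        return False
    return all(abs(v - reference) <= rel_tol * abs(reference)
               for v in history[-window:])


print(has_converged([0.40, 0.61, 0.58, 0.60], reference=0.6))  # True
print(has_converged([0.20, 0.30, 0.60], reference=0.6))        # False
```

Requiring several consecutive values inside the band guards against stopping on a single lucky evaluation. Note that a reference of exactly 0 leaves a zero tolerance, so such metrics would need an absolute tolerance instead.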

## Evaluate results

Now that we have finished training, let's find out how we did. We will analyze our model in several ways:
1. Examine how the generator and discriminator losses changed while training
2. Understand how certain musical metrics changed while training
3. Visualize generated piano roll output for a fixed input at every iteration and create a video


Let us first restore our last saved checkpoint. The cell below restores the most recent checkpoint in `check_dir`, so if you did not complete training, you can continue with any pre-trained checkpoint stored there.

In [None]:
ckpt = tf.train.Checkpoint(generator=generator)
ckpt_manager = tf.train.CheckpointManager(ckpt, check_dir, max_to_keep=5)

ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
print('Latest checkpoint {} restored.'.format(ckpt_manager.latest_checkpoint))

### Plot losses

In [None]:
display_utils.plot_loss_logs(g_losses, c_losses, figsize=(15, 5), smoothing=0.01)

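Observe how the critic loss (C_loss in the graph) shrinks toward zero as we train. In WGAN-GP, the negative of the critic's Wasserstein term estimates the distance between the real and generated distributions, so the magnitude of the critic loss tends to decrease (almost) monotonically as the generator improves.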

### Plot metrics

In [None]:
metrics_utils.metrics_manager.set_reference_metrics(training_data)
metrics_utils.metrics_manager.plot_metrics()

Each row here corresponds to a different music quality metric and each column denotes an instrument track. 

Observe how the expected values of the different metrics (blue scatter) approach the corresponding training set expected values (red) as the number of iterations increases. You might expect to see diminishing returns as the model converges.
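
To show what "one value per metric per track" means, here is a hypothetical metric in NumPy: the fraction of time steps in which each track plays at least one note, averaged over samples. We have not inspected `metrics_utils`, so this is an illustration of the idea, not necessarily one of the metrics plotted above. It treats cells above 0 as "on", following the {-1, 1} encoding.

```python
import numpy as np


def active_step_ratio(pianorolls, threshold=0.0):
    """Fraction of time steps in which each track plays at least one note.

    pianorolls: array of shape (num_samples, time_steps, pitches, tracks) with
    values in [-1, 1]; a cell counts as 'on' when it exceeds `threshold`.
    Returns one expected value per track, averaged over samples and time.
    """
    on = pianorolls > threshold          # (N, T, P, K) booleans
    step_active = on.any(axis=2)         # (N, T, K): any pitch on at this step
    return step_active.mean(axis=(0, 1)) # (K,)


rolls = np.full((2, 32, 128, 4), -1.0)
rolls[:, :16, 60, 0] = 1.0   # track 0 plays only during the first bar
rolls[:, :, 48, 1] = 1.0     # track 1 plays throughout
print(active_step_ratio(rolls))  # [0.5 1.  0.  0. ]
```

Comparing `active_step_ratio(generated)` with `active_step_ratio(training_data)` track by track is the same kind of comparison as the blue-versus-red plots above.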


### Generated samples during training

The function below helps you probe intermediate samples generated in the training process. Remember that the conditioned input here is sampled from our training data. Let's start by listening to and observing a sample at iteration 0 and then iteration 100. Notice the difference!



In [None]:
# Enter an iteration number (a multiple of 50) and listen to the MIDI saved at that iteration
iteration = 50
midi_file = os.path.join(sample_dir, 'iteration-{}.mid'.format(iteration))
display_utils.playmidi(midi_file)    

In [None]:
# Enter an iteration number (a multiple of 50) and look at the piano rolls generated at that iteration
iteration = 50
pianoroll_png = os.path.join(train_dir, 'sample_iteration_%05d.png' % iteration)
display.Image(filename=pianoroll_png)

Let's see how the generated piano rolls change with the number of iterations.

In [None]:
from IPython.display import Video


display_utils.make_training_video(train_dir)
video_path = "movie.mp4"
Video(video_path)

## Inference 

### Generating accompaniment for custom input

Congratulations! You have trained your very own WGAN-GP to generate music. Let us see how our generator performs on a custom input.

The function below generates a new song based on "Twinkle Twinkle Little Star".

In [None]:
latest_midi = inference_utils.generate_midi(generator, eval_dir, input_midi_file='./input_twinkle_twinkle.mid')


In [None]:
display_utils.playmidi(latest_midi)


We can also take a look at the generated piano rolls for a certain sample, to see how diverse they are!

In [None]:
inference_utils.show_generated_pianorolls(generator, eval_dir, input_midi_file='./input_twinkle_twinkle.mid')
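
The diversity comes from the latent tensor: the same melody paired with different z values yields different accompaniments. This NumPy sketch prepares such a batch. `truncated_normal` mimics `tf.random.truncated_normal` as used in the training steps, which per the TensorFlow documentation re-draws values more than two standard deviations from the mean. The helper names are our own, and we have not checked how `inference_utils.show_generated_pianorolls` samples z internally.

```python
import numpy as np


def truncated_normal(shape, rng, limit=2.0):
    """Standard normal samples, re-drawing any value with |x| > limit."""
    z = rng.standard_normal(shape)
    bad = np.abs(z) > limit
    while bad.any():
        z[bad] = rng.standard_normal(bad.sum())
        bad = np.abs(z) > limit
    return z.astype(np.float32)


def diverse_inputs(melody, n_variations, seed=0):
    """Pair one melody of shape (32, 128, 1) with n different latent tensors."""
    rng = np.random.default_rng(seed)
    c = np.repeat(melody[np.newaxis], n_variations, axis=0)  # (n, 32, 128, 1)
    z = truncated_normal((n_variations, 2, 8, 512), rng)     # (n, 2, 8, 512)
    return c, z


# In the notebook, with the trained generator:
# c, z = diverse_inputs(melody, n_variations=4)
# outputs = generator((c, z), training=False)
```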

With this we have completed our custom GAN for AWS DeepComposer.