Copyright 2017 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# MusicVAE: A hierarchical recurrent variational autoencoder for music.
### ___Adam Roberts and Jesse Engel___

MusicVAE learns a latent space of musical sequences, providing different modes
of interactive musical creation, including:

* Random sampling from the prior distribution.
* Interpolation between existing sequences.
* Manipulation of existing sequences via a [latent constraint model](https://goo.gl/STGMGx).

Examples of these interactions can be generated below, and selections can be heard in our
[YouTube playlist](https://www.youtube.com/playlist?list=PLBUMAYA6kvGU8Cgqh709o5SUvo-zHGTxr).

For short sequences (e.g., 2-bar "loops"), we use a bidirectional LSTM encoder
and LSTM decoder. For longer sequences, we use a novel hierarchical LSTM
decoder, which helps the model learn longer-term structures.

We also model the interdependencies between instruments by training multiple
decoders on the lowest-level embeddings of the hierarchical decoder.

For additional details, check out the code in our GitHub [repository](https://github.com/tensorflow/magenta/tree/master/magenta/models/music_vae) and our [paper](https://nips2017creativity.github.io/doc/Hierarchical_Variational_Autoencoders_for_Music.pdf).
___

This Colab notebook is self-contained and should run natively on Google Cloud. The [code](https://github.com/tensorflow/magenta/tree/master/magenta/models/music_vae) and [checkpoints](http://download.magenta.tensorflow.org/models/music_vae/checkpoints.tar.gz) can be downloaded separately and run locally, which is required if you want to train your own model.

# Basic Instructions

1. Double click on the hidden cells to make them visible, or select "View > Expand Sections" in the menu at the top.
2. Hover over the "`[ ]`" in the top-left corner of each cell and click on the "Play" button to run it, in order.
3. Listen to the generated samples.
4. Make it your own: copy the notebook, modify the code, train your own models, upload your own MIDI, etc.!

# Environment Setup
This step installs the packages needed for sequence synthesis. It will take a few minutes.


In [0]:
import glob

print('Copying checkpoints and example MIDI from GCS. This may take a few minutes...')
!gsutil -q -m cp -R gs://download.magenta.tensorflow.org/models/music_vae/colab/* /content/

print('Installing dependencies...')
!apt-get update -qq && apt-get install -qq libfluidsynth1 fluid-soundfont-gm build-essential libasound2-dev libjack-dev
!pip install -qU pyfluidsynth pretty_midi

if glob.glob('/content/magenta*.whl'):
  !pip install -qU /content/magenta*.whl
else:
  !pip install -qU magenta

# Hack to allow python to pick up the newly-installed fluidsynth lib.
import ctypes.util
def proxy_find_library(lib):
  if lib == 'fluidsynth':
    return 'libfluidsynth.so.1'
  else:
    return ctypes.util.find_library(lib)

ctypes.util.find_library = proxy_find_library


print('Importing libraries and defining some helper functions...')
from google.colab import files
import magenta.music as mm
from magenta.music.sequences_lib import concatenate_sequences
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
import numpy as np
import os
import tensorflow as tf


def play(note_sequence):
  mm.play_sequence(note_sequence, synth=mm.fluidsynth)
  
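def slerp(p0, p1, t):
  """Spherical linear interpolation."""
  # Clip to guard against rounding errors pushing the cosine outside [-1, 1].
  omega = np.arccos(np.clip(
      np.dot(np.squeeze(p0/np.linalg.norm(p0)), np.squeeze(p1/np.linalg.norm(p1))),
      -1.0, 1.0))
  so = np.sin(omega)
  if so == 0:  # Parallel endpoints: fall back to linear interpolation.
    return (1.0 - t) * p0 + t * p1
  return np.sin((1.0-t)*omega) / so * p0 + np.sin(t*omega)/so * p1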

def interpolate(model, start_seq, end_seq, num_steps, max_length=32,
                assert_same_length=True, temperature=0.5, 
                individual_duration=4.0):
  """Interpolates between a start and end sequence."""
  _, mu, _ = model.encode([start_seq, end_seq], assert_same_length)
  z = np.array([slerp(mu[0], mu[1], t) for t in np.linspace(0, 1, num_steps)])
  note_sequences = model.decode(
      length=max_length,
      z=z,
      temperature=temperature)

  print('Start Seq Reconstruction')
  play(note_sequences[0])
  print('End Seq Reconstruction')
  play(note_sequences[-1])
  print('Mean Sequence')
  play(note_sequences[num_steps // 2])
  print('Start -> End Interpolation')
  interp_seq = concatenate_sequences(note_sequences, [individual_duration] * len(note_sequences))
  play(interp_seq)
  mm.plot_sequence(interp_seq)
  return interp_seq if num_steps > 3 else note_sequences[num_steps // 2]

def download(note_sequence, filename):
  mm.sequence_proto_to_midi_file(note_sequence, filename)
  files.download(filename)

print('Done')
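
The `slerp` helper above interpolates along the arc between two latent vectors rather than along the straight line joining them. The standalone sketch below uses plain Python and hypothetical 2-D vectors in place of real latent codes to illustrate the difference: for unit-norm endpoints, the spherical midpoint keeps unit norm, while the linear midpoint shrinks toward the origin.

```python
import math


def norm(v):
  """Euclidean norm of a vector given as a list of floats."""
  return math.sqrt(sum(x * x for x in v))


def lerp(p0, p1, t):
  """Straight-line interpolation between p0 and p1."""
  return [(1.0 - t) * a + t * b for a, b in zip(p0, p1)]


def slerp_list(p0, p1, t):
  """The slerp formula written for plain Python lists."""
  n0, n1 = norm(p0), norm(p1)
  cos_omega = sum((a / n0) * (b / n1) for a, b in zip(p0, p1))
  omega = math.acos(max(-1.0, min(1.0, cos_omega)))
  so = math.sin(omega)
  w0 = math.sin((1.0 - t) * omega) / so
  w1 = math.sin(t * omega) / so
  return [w0 * a + w1 * b for a, b in zip(p0, p1)]


# Hypothetical orthogonal unit vectors (not real MusicVAE latents).
z_start = [1.0, 0.0]
z_end = [0.0, 1.0]

lerp_mid_norm = norm(lerp(z_start, z_end, 0.5))
slerp_mid_norm = norm(slerp_list(z_start, z_end, 0.5))
print('linear midpoint norm:    %.3f' % lerp_mid_norm)
print('spherical midpoint norm: %.3f' % slerp_mid_norm)
```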

# 2-Bar Drums Model

Below are 4 pre-trained models to experiment with. The first 3 map the 61 MIDI drum "pitches" to a reduced set of 9 classes (bass, snare, closed hi-hat, open hi-hat, low tom, mid tom, high tom, crash cymbal, ride cymbal) for a simplified but less expressive output space. The last model uses a [NADE](http://homepages.inf.ed.ac.uk/imurray2/pub/11nade/) to represent all possible MIDI drum "pitches".

* **drums_2bar_oh_lokl**: This *low* KL model was trained for more *realistic* sampling. The output is a one-hot encoding of 2^9 combinations of hits. It has a single-layer bidirectional LSTM encoder with 512 nodes in each direction, a 2-layer LSTM decoder with 256 nodes in each layer, and a Z with 256 dimensions. During training it was given 0 free bits, and had a fixed beta value of 0.8. After 300k steps, the final accuracy is 0.73 and KL divergence is 11 bits.
* **drums_2bar_oh_hikl**: This *high* KL model was trained for *better reconstruction and interpolation*. The output is a one-hot encoding of 2^9 combinations of hits. It has a single-layer bidirectional LSTM encoder with 512 nodes in each direction, a 2-layer LSTM decoder with 256 nodes in each layer, and a Z with 256 dimensions. During training it was given 96 free bits and had a fixed beta value of 0.2. It was trained with scheduled sampling with an inverse sigmoid schedule and a rate of 1000. After 300k steps, the final accuracy is 0.97 and KL divergence is 107 bits.
* **drums_2bar_nade_reduced**: This model outputs a multi-label "pianoroll" with 9 classes. It has a single-layer bidirectional LSTM encoder with 512 nodes in each direction, a 2-layer LSTM-NADE decoder with 512 nodes in each layer and 9-dimensional NADE with 128 hidden units, and a Z with 256 dimensions. During training it was given 96 free bits and had a fixed beta value of 0.2. It was trained with scheduled sampling with an inverse sigmoid schedule and a rate of 1000. After 300k steps, the final accuracy is 0.98 and KL divergence is 110 bits.
* **drums_2bar_nade_full**: This model outputs a multi-label "pianoroll" with 61 classes. It has a single-layer bidirectional LSTM encoder with 512 nodes in each direction, a 2-layer LSTM-NADE decoder with 512 nodes in each layer and 61-dimensional NADE with 128 hidden units, and a Z with 256 dimensions. During training it was given 0 free bits and had a fixed beta value of 0.2. It was trained with scheduled sampling with an inverse sigmoid schedule and a rate of 1000. After 300k steps, the final accuracy is 0.90 and KL divergence is 116 bits.
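
To make the two output representations concrete, the sketch below reduces a set of simultaneous drum hits to the 9 classes and encodes one time step both ways: as a single index into the 2^9 = 512 hit combinations (the one-hot models) and as a 9-element binary vector (the NADE "pianoroll" models). The pitch-to-class table is a small illustrative subset based on General MIDI percussion numbers, and the bit ordering is an arbitrary choice; neither is claimed to be the exact mapping used by the released models.

```python
# Illustrative subset of General MIDI percussion pitches grouped into the
# 9 reduced classes. This table is an assumption for demonstration only.
CLASS_NAMES = ['bass', 'snare', 'closed hi-hat', 'open hi-hat', 'low tom',
               'mid tom', 'high tom', 'crash cymbal', 'ride cymbal']
PITCH_TO_CLASS = {
    35: 0, 36: 0,   # bass drums
    38: 1, 40: 1,   # snares
    42: 2, 44: 2,   # closed / pedal hi-hat
    46: 3,          # open hi-hat
    45: 4,          # low tom
    47: 5, 48: 5,   # mid toms
    50: 6,          # high tom
    49: 7, 57: 7,   # crash cymbals
    51: 8, 59: 8,   # ride cymbals
}


def to_multi_label(pitches):
  """Returns a 9-element 0/1 vector, one entry per drum class."""
  vec = [0] * len(CLASS_NAMES)
  for pitch in pitches:
    vec[PITCH_TO_CLASS[pitch]] = 1
  return vec


def to_one_hot_index(pitches):
  """Returns the index (0..511) of this hit combination, as a bitmask."""
  return sum(bit << i for i, bit in enumerate(to_multi_label(pitches)))


# One time step with bass drum, acoustic snare, and closed hi-hat together.
step_pitches = [36, 38, 42]
multi_label = to_multi_label(step_pitches)
one_hot_index = to_one_hot_index(step_pitches)
print(multi_label)
print(one_hot_index)
```

The one-hot form treats every combination of hits as its own category, whereas the multi-label form keeps one binary decision per class, which is what lets the full model scale to 61 classes without needing 2^61 categories.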

In [0]:
# Load the pre-trained models.

# One-hot encoded.
drums_config = configs.CONFIG_MAP['cat-drums_2bar_small']
drums_2bar_oh_lokl = TrainedModel(drums_config, batch_size=4, checkpoint_dir_or_path='/content/checkpoints/drums_2bar_small.lokl.ckpt')
drums_2bar_oh_hikl = TrainedModel(drums_config, batch_size=4, checkpoint_dir_or_path='/content/checkpoints/drums_2bar_small.hikl.ckpt')

# Multi-label NADE.
drums_nade_reduced_config = configs.CONFIG_MAP['nade-drums_2bar_reduced']
drums_2bar_nade_reduced = TrainedModel(drums_nade_reduced_config, batch_size=4, checkpoint_dir_or_path='/content/checkpoints/drums_2bar_nade.reduced.ckpt')
drums_nade_full_config = configs.CONFIG_MAP['nade-drums_2bar_full']
drums_2bar_nade_full = TrainedModel(drums_nade_full_config, batch_size=4, checkpoint_dir_or_path='/content/checkpoints/drums_2bar_nade.full.ckpt')


## Generate Samples

In [0]:
# Generate 4 samples from the prior of the low KL one-hot model.
drums_2_ol_samples = drums_2bar_oh_lokl.sample(n=4, length=32, temperature=0.5)
for ns in drums_2_ol_samples:
  play(ns)

In [0]:
# Generate 4 samples from the prior of the high KL one-hot model.
for ns in drums_2bar_oh_hikl.sample(n=4, length=32, temperature=0.5):
  play(ns)

In [0]:
# Generate 4 samples from the prior of the 9-class NADE model.
for ns in drums_2bar_nade_reduced.sample(n=4, length=32, temperature=0.5):
  play(ns)

In [0]:
# Generate 4 samples from the prior of the 61-class NADE model.
drums_2_nf_samples = drums_2bar_nade_full.sample(n=4, length=32, temperature=0.5)
for ns in drums_2_nf_samples:
  play(ns)

In [0]:
# Optionally download generated samples from the low KL one-hot model.
for i, ns in enumerate(drums_2_ol_samples):
  download(ns, 'drums_2bar_oh_lokl_sample_%d.mid' % i)

In [0]:
# Optionally download generated samples from the 61-class NADE model.
for i, ns in enumerate(drums_2_nf_samples):
  download(ns, 'drums_2bar_nade_full_sample_%d.mid' % i)
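
Each `sample` call above passes `temperature=0.5`. A common way to apply temperature, and the one assumed in this sketch, is to divide the decoder's logits by the temperature before the softmax, so values below 1.0 sharpen the distribution and make sampling more conservative. The logits here are made-up numbers, not model outputs.

```python
import math


def softmax_with_temperature(logits, temperature):
  """Softmax over logits / temperature, shifted for numerical stability."""
  scaled = [x / temperature for x in logits]
  peak = max(scaled)
  exps = [math.exp(x - peak) for x in scaled]
  total = sum(exps)
  return [e / total for e in exps]


# Hypothetical logits for three possible output tokens.
logits = [2.0, 1.0, 0.0]
probs_default = softmax_with_temperature(logits, 1.0)
probs_cool = softmax_with_temperature(logits, 0.5)
print(['%.3f' % p for p in probs_default])
print(['%.3f' % p for p in probs_cool])
```

Try raising the `temperature` argument in the sampling cells toward 1.0 for more varied output, or lowering it for more predictable output.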

## Generate Interpolations

In [0]:
# Use example MIDI files for interpolation endpoints.
input_drums_midi_data = [
    tf.gfile.Open(fn).read()
    for fn in sorted(tf.gfile.Glob('/content/midi/drums_2bar*.mid'))]

In [0]:
# Optionally upload your own MIDI files to use for interpolation endpoints instead of those provided.
input_drums_midi_data = files.upload().values() or input_drums_midi_data

In [0]:
# Extract drums from MIDI files. This will extract all unique 2-bar drum beats
# using a sliding window with a stride of 1 bar.
drums_input_seqs = [mm.midi_to_sequence_proto(m) for m in input_drums_midi_data]
extracted_beats = []
for ns in drums_input_seqs:
  extracted_beats.extend(drums_nade_full_config.note_sequence_converter.to_notesequences(
      drums_nade_full_config.note_sequence_converter.to_tensors(ns)[1]))
for i, ns in enumerate(extracted_beats):
  print("Beat %d" % i)
  play(ns)
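
The sliding-window extraction described in the cell above can be pictured as follows. This is a simplified stand-in for what the converter does, assuming 16 steps per bar and representing each step as a single hashable token; the real converter operates on `NoteSequence` protos and may apply additional filtering.

```python
STEPS_PER_BAR = 16  # Assumes 4/4 time quantized to 16th notes.


def extract_unique_windows(steps, window_bars=2, stride_bars=1):
  """Slides a window over `steps` and keeps each distinct window once."""
  window = window_bars * STEPS_PER_BAR
  stride = stride_bars * STEPS_PER_BAR
  seen = set()
  unique = []
  for start in range(0, len(steps) - window + 1, stride):
    chunk = tuple(steps[start:start + window])
    if chunk not in seen:
      seen.add(chunk)
      unique.append(chunk)
  return unique


# Hypothetical 4-bar pattern: bars A, B, A, B (one token per step).
bar_a = ['kick'] + ['rest'] * 15
bar_b = ['snare'] + ['rest'] * 15
pattern = bar_a + bar_b + bar_a + bar_b

windows = extract_unique_windows(pattern)
print(len(windows))
```

The three window positions (bars 1-2, 2-3, and 3-4) yield only two distinct beats here, because the first and last windows are identical.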

In [0]:
# Select the start and end beat for interpolation.
start_beat = extracted_beats[0]
end_beat = extracted_beats[1]

In [0]:
# Interpolate between beats over 13 steps with "low" KL model.
# Will not be very accurate and uses a reduced representation.
drums_2bar_interp_oh_lokl = interpolate(drums_2bar_oh_lokl, start_beat, end_beat, num_steps=13, temperature=0.5)

In [0]:
# Interpolate between beats over 13 steps with "high" KL model.
# Will be much more accurate but uses a reduced representation.
drums_2bar_interp_oh_hikl = interpolate(drums_2bar_oh_hikl, start_beat, end_beat, num_steps=13, temperature=0.5)

In [0]:
# Interpolate between beats over 13 steps with the NADE model using the reduced representation.
drums_2bar_interp_nade_reduced = interpolate(drums_2bar_nade_reduced, start_beat, end_beat, num_steps=13, temperature=0.5)

In [0]:
# Interpolate between beats over 13 steps with the NADE model using the full representation.
drums_2bar_interp_nade_full = interpolate(drums_2bar_nade_full, start_beat, end_beat, num_steps=13, temperature=0.5)

In [0]:
# Optionally download interpolation MIDI files.
download(drums_2bar_interp_oh_lokl, 'drums_2bar_interp_oh_lokl.mid')
download(drums_2bar_interp_oh_hikl, 'drums_2bar_interp_oh_hikl.mid')
download(drums_2bar_interp_nade_reduced, 'drums_2bar_interp_nade_reduced.mid')
download(drums_2bar_interp_nade_full, 'drums_2bar_interp_nade_full.mid')

# 2-Bar Melody Model

The pre-trained model consists of a single-layer bidirectional LSTM encoder with 2048 nodes in each direction, a 3-layer LSTM decoder with 2048 nodes in each layer, and Z with 512 dimensions. The model was given 0 free bits, and had its beta value annealed at an exponential rate of 0.99999 from 0 to 0.43 over 200k steps. It was trained with scheduled sampling with an inverse sigmoid schedule and a rate of 1000. The final accuracy is 0.95 and KL divergence is 58 bits.
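
The annealing schedule can be made concrete with a little arithmetic. The sketch below assumes the schedule has the form `beta(step) = max_beta * (1 - rate ** step)`, which is consistent with the description above (exponential rate 0.99999, starting at 0 and approaching 0.43). The exact formula used in training is not stated in this notebook, so treat this as an illustration; the function name is hypothetical.

```python
def annealed_beta(step, rate=0.99999, max_beta=0.43):
  """KL weight after `step` training steps under the assumed schedule."""
  return max_beta * (1.0 - rate ** step)


for step in (0, 50000, 100000, 200000):
  print('step %6d: beta = %.3f' % (step, annealed_beta(step)))
```

Under this assumption beta starts at exactly 0 and rises monotonically toward 0.43 without quite reaching it; after 200k steps it is about 0.37.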

In [0]:
# Load the pre-trained model.
mel_2bar_config = configs.CONFIG_MAP['cat-mel_2bar_big']
mel_2bar = TrainedModel(mel_2bar_config, batch_size=4, checkpoint_dir_or_path='/content/checkpoints/mel_2bar_big.ckpt')

## Generate Samples

In [0]:
# Generate 4 samples from the prior.
mel_2_samples = mel_2bar.sample(n=4, length=32, temperature=0.5)
for ns in mel_2_samples:
  play(ns)

In [0]:
# Optionally download samples.
for i, ns in enumerate(mel_2_samples):
  download(ns, 'mel_2bar_sample_%d.mid' % i)

## Generate Interpolations

In [0]:
# Use example MIDI files for interpolation endpoints.
input_mel_midi_data = [
    tf.gfile.Open(fn).read()
    for fn in sorted(tf.gfile.Glob('/content/midi/mel_2bar*.mid'))]

In [0]:
# Optionally upload your own MIDI files to use for interpolation endpoints instead of those provided.
input_mel_midi_data = files.upload().values() or input_mel_midi_data

In [0]:
# Extract melodies from MIDI files. This will extract all unique 2-bar melodies
# using a sliding window with a stride of 1 bar.
mel_input_seqs = [mm.midi_to_sequence_proto(m) for m in input_mel_midi_data]
extracted_mels = []
for ns in mel_input_seqs:
  extracted_mels.extend(
      mel_2bar_config.note_sequence_converter.to_notesequences(
          mel_2bar_config.note_sequence_converter.to_tensors(ns)[1]))
for i, ns in enumerate(extracted_mels):
  print("Melody %d" % i)
  play(ns)

In [0]:
# Select the start and end melody for interpolation.
start_mel = extracted_mels[0]
end_mel = extracted_mels[1]

In [0]:
# Interpolate between melodies over 15 steps.
mel_2bar_interp = interpolate(mel_2bar, start_mel, end_mel, num_steps=15, temperature=0.5)

In [0]:
# Optionally download interpolation MIDI file.
download(mel_2bar_interp, 'mel_2bar_interp.mid')

# 16-bar Melody Models

The pre-trained hierarchical model consists of a 2-layer stacked bidirectional LSTM encoder with 2048 nodes in each direction for each layer, a 16-step 2-layer LSTM "conductor" decoder with 1024 nodes in each layer, a 2-layer LSTM core decoder with 1024 nodes in each layer, and a Z with 512 dimensions. It was given 256 free bits, and had a fixed beta value of 0.2. After 25k steps, the final accuracy is 0.90 and KL divergence is 277 bits.
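
Free bits set a threshold below which the KL term is not penalized. The following is a minimal sketch of that idea, assuming the KL contribution to the loss is `beta * max(KL - free_bits, 0)`; the function and argument names are hypothetical.

```python
def kl_penalty(kl_bits, free_bits, beta):
  """Weighted KL cost in bits; only the excess over `free_bits` counts."""
  return beta * max(kl_bits - free_bits, 0.0)


# Figures quoted above for the hierarchical 16-bar melody model.
penalty_at_end = kl_penalty(kl_bits=277.0, free_bits=256.0, beta=0.2)
# Any KL at or below the threshold costs nothing.
penalty_below = kl_penalty(kl_bits=200.0, free_bits=256.0, beta=0.2)
print(penalty_at_end, penalty_below)
```

This is consistent with the reported KL divergence of 277 bits sitting just above the 256-bit threshold: only the excess of 21 bits is penalized.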

In [0]:
# Load the pre-trained models.
hierarch_mel_16bar_config = configs.CONFIG_MAP['hiercat-mel_16bar_big']
hierarch_mel_16bar = TrainedModel(hierarch_mel_16bar_config, batch_size=4, checkpoint_dir_or_path='/content/checkpoints/hiercat_mel_16bar_big.ckpt')

flat_mel_16bar_config = configs.CONFIG_MAP['cat-mel_16bar_big']
flat_mel_16bar = TrainedModel(flat_mel_16bar_config, batch_size=4, checkpoint_dir_or_path='/content/checkpoints/cat_mel_16bar_big.ckpt')

## Generate Samples

In [0]:
# Generate 4 samples from the hierarchical prior.
hmel_16_samples = hierarch_mel_16bar.sample(n=4, length=256, temperature=0.5)
for ns in hmel_16_samples:
  play(ns)

In [0]:
# Generate 4 samples from the flat (baseline) prior.
for ns in flat_mel_16bar.sample(n=4, length=256, temperature=0.5):
  play(ns)

In [0]:
# Optionally download hierarchical samples.
for i, ns in enumerate(hmel_16_samples):
  download(ns, 'hierarch_mel_16bar_sample_%d.mid' % i)
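
The `length` arguments used in this notebook follow from the bar count. Assuming 4/4 time quantized to 16 steps per bar and a playback tempo of 120 quarter notes per minute (both are assumptions, not values stated in this notebook), the step counts and clip durations work out as follows:

```python
STEPS_PER_BAR = 16   # Assumed: 4/4 time, 16th-note resolution.
BEATS_PER_BAR = 4
QPM = 120.0          # Assumed playback tempo.


def num_steps(bars):
  """Number of decoder steps needed for `bars` bars."""
  return bars * STEPS_PER_BAR


def duration_seconds(bars):
  """Playback time of `bars` bars at the assumed tempo."""
  return bars * BEATS_PER_BAR * 60.0 / QPM


for bars in (2, 16):
  print('%2d bars: %3d steps, %.1f s'
        % (bars, num_steps(bars), duration_seconds(bars)))
```

These values match the `length=32` and `length=256` arguments passed to `sample`, and the `individual_duration` values of 4.0 and 32 used with `interpolate`.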

## Generate Means

In [0]:
# Use example MIDI files for interpolation endpoints.
input_mel_16_midi_data = [
    tf.gfile.Open(fn).read()
    for fn in sorted(tf.gfile.Glob('/content/midi/mel_16bar*.mid'))]

In [0]:
# Optionally upload your own MIDI files to use for interpolation endpoints instead of those provided.
input_mel_16_midi_data = files.upload().values() or input_mel_16_midi_data

In [0]:
# Extract melodies from MIDI files. This will extract all unique 16-bar melodies
# using a sliding window with a stride of 1 bar.
mel_input_seqs = [mm.midi_to_sequence_proto(m) for m in input_mel_16_midi_data]
extracted_16_mels = []
for ns in mel_input_seqs:
  extracted_16_mels.extend(
      hierarch_mel_16bar_config.note_sequence_converter.to_notesequences(
          hierarch_mel_16bar_config.note_sequence_converter.to_tensors(ns)[1]))
for i, ns in enumerate(extracted_16_mels):
  print("Melody %d" % i)
  play(ns)

In [0]:
# Select the start and end melody for interpolation.
start_16_mel = extracted_16_mels[0]
end_16_mel = extracted_16_mels[1]

In [0]:
# Compute the reconstructions and mean of the two melodies from the hierarchical model.
hierarch_mel_16bar_mean = interpolate(hierarch_mel_16bar, start_16_mel, end_16_mel, num_steps=3, max_length=256, individual_duration=32, temperature=0.5)

In [0]:
# Compute the reconstructions and mean of the two melodies from the flat (baseline) model.
flat_mel_16bar_mean = interpolate(flat_mel_16bar, start_16_mel, end_16_mel, num_steps=3, max_length=256, individual_duration=32, temperature=0.5)

In [0]:
# Optionally download mean MIDI file.
download(hierarch_mel_16bar_mean, 'hierarch_mel_16bar_mean.mid')
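
Calling `interpolate` with `num_steps=3` evaluates the interpolation at t = 0, 0.5, and 1, so the decoded sequences are the start reconstruction, the midpoint ("mean"), and the end reconstruction. The hypothetical helper below mirrors `np.linspace(0, 1, num_steps)` in plain Python to show which index holds the mean.

```python
def interpolation_grid(num_steps):
  """Evenly spaced t values from 0 to 1 inclusive (like np.linspace)."""
  return [i / float(num_steps - 1) for i in range(num_steps)]


grid = interpolation_grid(3)
mean_index = 3 // 2
print(grid, grid[mean_index])
```

Because `interpolate` returns the middle sequence whenever `num_steps` is 3 or fewer, the file downloaded above contains only the mean, not the concatenated interpolation.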

# 16-bar "Trio" Models (lead, bass, drums)

We present two pre-trained models for 16-bar trios: a hierarchical model and a flat (baseline) model.

The pre-trained hierarchical model consists of a 2-layer stacked bidirectional LSTM encoder with 2048 nodes in each direction for each layer, a 16-step 2-layer LSTM "conductor" decoder with 1024 nodes in each layer, 3 (lead, bass, drums) 2-layer LSTM core decoders with 1024 nodes in each layer, and a Z with 512 dimensions. It was given 1024 free bits, and had a fixed beta value of 0.1.  It was trained with scheduled sampling with an inverse sigmoid schedule and a rate of 1000. After 50k steps, the final accuracy is 0.82 for lead, 0.87 for bass, and 0.90 for drums, and the KL divergence is 1027 bits.

The pre-trained flat model consists of a 2-layer stacked bidirectional LSTM encoder with 2048 nodes in each direction for each layer, a 3-layer LSTM decoder with 2048 nodes in each layer, and a Z with 512 dimensions. It was given 1024 free bits, and had a fixed beta value of 0.1.  It was trained with scheduled sampling with an inverse sigmoid schedule and a rate of 1000. After 50k steps, the final accuracy is 0.67 for lead, 0.66 for bass, and 0.79 for drums, and the KL divergence is 1016 bits.
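
Scheduled sampling gradually replaces ground-truth decoder inputs with the model's own predictions during training. As a sketch of the inverse sigmoid schedule, the code below uses the decay from Bengio et al. (2015), where the probability of feeding the ground truth is `k / (k + exp(step / k))`, and takes the rate `k` to be 1000. Whether the rate enters the released training code in exactly this way is an assumption, and the function name is hypothetical.

```python
import math


def model_input_probability(step, k=1000.0):
  """Probability of feeding the model's own prediction at `step`."""
  # Cap the exponent so math.exp cannot overflow for very large steps.
  exponent = min(step / k, 700.0)
  return 1.0 - k / (k + math.exp(exponent))


for step in (0, 5000, 10000, 50000):
  print('step %5d: p = %.4f' % (step, model_input_probability(step)))
```

Under this schedule the decoder sees almost exclusively ground-truth inputs at the start of training and almost exclusively its own outputs well before the 50k-step mark.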

In [0]:
# Load the pre-trained models.
hierarch_trio_16bar_config = configs.CONFIG_MAP['hiercat-trio_16bar_big']
hierarch_trio_16bar = TrainedModel(hierarch_trio_16bar_config, batch_size=4, checkpoint_dir_or_path='/content/checkpoints/hiercat_trio_16bar_big.ckpt')

flat_trio_16bar_config = configs.CONFIG_MAP['cat-trio_16bar_big']
flat_trio_16bar = TrainedModel(flat_trio_16bar_config, batch_size=4, checkpoint_dir_or_path='/content/checkpoints/cat_trio_16bar_big.ckpt')

## Generate Samples

In [0]:
# Generate 4 samples from the hierarchical model prior.
htrio_16_samples = hierarch_trio_16bar.sample(n=4, length=256, temperature=0.5)
for ns in htrio_16_samples:
  play(ns)

In [0]:
# Generate 4 samples from the flat (baseline) model prior.
for ns in flat_trio_16bar.sample(n=4, length=256, temperature=0.5):
  play(ns)

In [0]:
# Optionally download hierarchical samples.
for i, ns in enumerate(htrio_16_samples):
  download(ns, 'hierarch_trio_16bar_sample_%d.mid' % i)

## Generate Means

In [0]:
# Use example MIDI files for interpolation endpoints.
input_trio_midi_data = [
    tf.gfile.Open(fn).read()
    for fn in sorted(tf.gfile.Glob('/content/midi/trio_16bar*.mid'))]

In [0]:
# Optionally upload your own MIDI files to use for interpolation endpoints instead of those provided.
input_trio_midi_data = files.upload().values() or input_trio_midi_data

In [0]:
# Extract trios from MIDI files. This will extract all unique 16-bar trios
# using a sliding window with a stride of 1 bar.
trio_input_seqs = [mm.midi_to_sequence_proto(m) for m in input_trio_midi_data]
extracted_trios = []
for ns in trio_input_seqs:
  extracted_trios.extend(
      hierarch_trio_16bar_config.note_sequence_converter.to_notesequences(
          hierarch_trio_16bar_config.note_sequence_converter.to_tensors(ns)[1]))
for i, ns in enumerate(extracted_trios):
  print("Trio %d" % i)
  play(ns)

In [0]:
# Select the start and end trio for interpolation.
start_trio = extracted_trios[0]
end_trio = extracted_trios[1]

In [0]:
# Compute the hierarchical model reconstructions and mean of the two trios.
hierarch_trio_16bar_mean = interpolate(hierarch_trio_16bar, start_trio, end_trio, num_steps=3, max_length=256, individual_duration=32, temperature=0.5)

In [0]:
# Compute the flat (baseline) model reconstructions and mean of the two trios.
flat_trio_16bar_mean = interpolate(flat_trio_16bar, start_trio, end_trio, num_steps=3, max_length=256, individual_duration=32, temperature=0.5)

In [0]:
# Optionally download mean MIDI file.
download(hierarch_trio_16bar_mean, 'hierarch_trio_16bar_mean.mid')