<a href="https://colab.research.google.com/github/minimario/vae-bach/blob/main/vae_bach.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [53]:
%pip install --upgrade music21

Collecting music21
[?25l  Downloading https://files.pythonhosted.org/packages/f2/0e/b9bf3530203f6e6ed1f04d4352ac421aef2429ab77c416ff583dd6d58597/music21-6.3.0.tar.gz (19.2MB)
[K     |████████████████████████████████| 19.2MB 1.4MB/s 
Collecting webcolors
  Downloading https://files.pythonhosted.org/packages/12/05/3350559de9714b202e443a9e6312937341bd5f79f4e4f625744295e7dd17/webcolors-1.11.1-py3-none-any.whl
Building wheels for collected packages: music21
  Building wheel for music21 (setup.py) ... [?25l[?25hdone
  Created wheel for music21: filename=music21-6.3.0-cp36-none-any.whl size=21888021 sha256=f34dad481f42726171089a178b2702c0eab6962df92764db71cb0dd2f8367e6d
  Stored in directory: /root/.cache/pip/wheels/02/e8/2c/eed32afec2b6c6f945a17280c4e4df1cf2e8cd15ebe1025680
Successfully built music21
Installing collected packages: webcolors, music21
  Found existing installation: music21 5.5.0
    Uninstalling music21-5.5.0:
      Successfully uninstalled music21-5.5.0
Successfully insta

In [2]:
print(music21.__version__)

6.3.0


In [1]:
import music21
import music21 as m21
from music21 import *
import keras
import numpy as np
paths = corpus.getComposer('bach')

In [3]:
bach = []
for i in paths:
    s = corpus.parse(i)
    bach.append(s)

In [4]:
def transpose(song):
    parts = song.getElementsByClass(music21.stream.Part)
    measures_part0 = parts[0].getElementsByClass(music21.stream.Measure)
    
    key = measures_part0[0].getElementsByClass(music21.key.Key)
    
    if len(key) != 0:
        key = key[0]
    else:
        key = song.analyze("key")

    # get interval for transposition. E.g., Bmaj -> Cmaj
    if key.mode == "major":
        interval = music21.interval.Interval(key.tonic, music21.pitch.Pitch("C"))
    elif key.mode == "minor":
        interval = music21.interval.Interval(key.tonic, music21.pitch.Pitch("A"))

    # transpose song by calculated interval
    tranposed_song = song.transpose(interval)
    return tranposed_song

In [5]:
def encode_song(song, time_step=0.25):
    """Converts a score into a time-series-like music representation. Each item in the encoded list represents 'min_duration'
    quarter lengths. The symbols used at each step are: integers for MIDI notes, 'r' for representing a rest, and '_'
    for representing notes/rests that are carried over into a new time step. Here's a sample encoding:

        ["r", "_", "60", "_", "_", "_", "72" "_"]

    :param song (m21 stream): Piece to encode
    :param time_step (float): Duration of each time step in quarter length
    :return:
    """

    encoded_song = []

    for event in song.flat.notesAndRests:

        # handle notes
        if isinstance(event, m21.note.Note):
            symbol = event.pitch.midi # 60
        # handle rests
        elif isinstance(event, m21.note.Rest):
            symbol = "r"

        # convert the note/rest into time series notation
        steps = int(event.duration.quarterLength / time_step)
        for step in range(steps):

            # if it's the first time we see a note/rest, let's encode it. Otherwise, it means we're carrying the same
            # symbol in a new time step
            if step == 0:
                encoded_song.append(symbol)
            else:
                encoded_song.append("_")

    # cast encoded song to str
    encoded_song = list(map(str, encoded_song))

    return encoded_song


In [24]:
bach_transposed = []
for s in bach:
    bach_transposed.append(transpose(s))

In [25]:
def transpose_octave(encoding, direction):
    """
    takes in a string encoding and transposes it up/down an octave
    direction: one of 'up', 'down'
    """
    transposed_encoding = []
    for i in encoding:
        if i in '_r':
            transposed_encoding.append(i)
        else:
            if direction == 'up':
                transposed_encoding.append(str(int(i)+12))
            else:
                transposed_encoding.append(str(int(i)-12))
    return transposed_encoding

In [26]:
encodings = []

for b in bach_transposed:
    for part in b.parts:
      if part.partName in ['Soprano', 'Tenor', 'Alto', 'Bass']:
        part_encoded = encode_song(part)
        if part.partName == 'Soprano':
            encodings.append(transpose_octave(part_encoded, 'down'))
        elif part.partName == 'Bass':
            encodings.append(transpose_octave(part_encoded, 'up'))
        else:
            encodings.append(part_encoded)

In [51]:
import itertools
mappings = {}
inv_mappings = {}
distinct_symbols = list(set(itertools.chain.from_iterable(encodings)))
for i, s in enumerate(distinct_symbols):
    mappings[s] = i
    inv_mappings[i] = s

In [55]:
def compress(songs):
  return [mappings[symbol] for symbol in songs]

def decompress(ints):
  return [inv_mappings[num] for num in ints]
  

In [29]:
mapped_encodings = list(map(compress, encodings))

In [33]:
len_64_segments = []
for encoding in mapped_encodings:
    for j in range(0, len(encoding), 4):
        if j+64 <= len(encoding) and encoding[j] != mappings['_']:
            len_64_segments.append(encoding[j:j+64])

In [34]:
inputs = np.array(len_64_segments)
inputs_cat = keras.utils.to_categorical(inputs)
print(inputs_cat.shape)

(57336, 64, 45)


In [90]:
import random
r = random.randint(0, len(inputs_cat)-1)
print(inputs[r])

[ 1 39 39 39  1 39 19 39 24 39 39 39 19 39  1 39 43 39 39 39 21 39 24 39
 19 39 39 39 24 39 39 39  1 39  4 39 24 39 39 39  9 39 34 39 16 39 39 39
 16 39 39 39 43 39  1 39  4 39 24 39 16 39 39 39]


In [35]:
inputs_cat = inputs_cat.reshape((inputs_cat.shape[0], -1))
print(inputs_cat.shape)

(57336, 2880)


In [36]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [98]:
input_shape = inputs_cat.shape[1]

latent_dim = 256

encoder_inputs = keras.Input(shape=(input_shape,1))
x = layers.Conv1D(32, 8, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.BatchNormalization()(x)
x = layers.Conv1D(64, 8, activation="relu", strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           [(None, 2880, 1)]    0                                            
__________________________________________________________________________________________________
conv1d_10 (Conv1D)              (None, 1440, 32)     288         input_11[0][0]                   
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 1440, 32)     128         conv1d_10[0][0]                  
__________________________________________________________________________________________________
conv1d_11 (Conv1D)              (None, 720, 64)      16448       batch_normalization_3[0][0]      
____________________________________________________________________________________________

In [99]:

latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(720 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((720, 64))(x)
x = layers.Conv1DTranspose(64, 8, activation="relu", strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv1DTranspose(32, 8, activation="relu", strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
decoder_outputs = layers.Conv1DTranspose(1, 8, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_12 (InputLayer)        [(None, 256)]             0         
_________________________________________________________________
dense_9 (Dense)              (None, 46080)             11842560  
_________________________________________________________________
reshape_4 (Reshape)          (None, 720, 64)           0         
_________________________________________________________________
conv1d_transpose_12 (Conv1DT (None, 1440, 64)          32832     
_________________________________________________________________
batch_normalization_5 (Batch (None, 1440, 64)          256       
_________________________________________________________________
conv1d_transpose_13 (Conv1DT (None, 2880, 32)          16416     
_________________________________________________________________
batch_normalization_6 (Batch (None, 2880, 32)          128 

In [100]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = encoder(data)
            reconstruction = decoder(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            recon_weight = 0.8
            total_loss = recon_weight * reconstruction_loss + (1 - recon_weight) * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

In [101]:
vae_inputs = inputs_cat[:, :, np.newaxis]
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(vae_inputs, epochs=30, batch_size=128)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
 91/448 [=====>........................] - ETA: 35s - loss: 0.0250 - reconstruction_loss: 0.0302 - kl_loss: 0.0042

KeyboardInterrupt: ignored

In [102]:
latent_input = np.random.normal(0, 1, (1, latent_dim))
decoded = decoder.predict(latent_input).reshape((64, 45))
melody = np.argmax(decoded, axis=1)
song = decompress(list(melody))
save_melody(song)

<music21.stream.Stream 0x7f7d4b46dac8>

In [103]:
def save_melody(melody, step_duration=0.25, format="midi", file_name="mel.mid"):
    """Converts a melody into a MIDI file

    :param melody (list of str):
    :param min_duration (float): Duration of each time step in quarter length
    :param file_name (str): Name of midi file
    :return:
    """

    # create a music21 stream
    stream = m21.stream.Stream()

    start_symbol = None
    step_counter = 1

    # parse all the symbols in the melody and create note/rest objects
    for i, symbol in enumerate(melody):
        # handle case in which we have a note/rest
        if symbol != "_" or i + 1 == len(melody):
            # ensure we're dealing with note/rest beyond the first one
            if start_symbol is not None:
                quarter_length_duration = step_duration * step_counter # 0.25 * 4 = 1
                # handle rest
                if start_symbol == "r":
                    m21_event = m21.note.Rest(quarterLength=quarter_length_duration)
                # handle note
                else:
                    m21_event = m21.note.Note(int(start_symbol), quarterLength=quarter_length_duration)
                stream.append(m21_event)
                # reset the step counter
                step_counter = 1
            start_symbol = symbol
        # handle case in which we have a prolongation sign "_"
        else:
            step_counter += 1
    # write the m21 stream to a midi file
    stream.write(format, file_name)
    return stream

In [104]:
print(song)
s = save_melody(song)
for i in s:
  print(i)

['60', '_', '_', '_', '_', '_', '_', '_', '60', '_', '_', '_', '60', '_', '_', '_', '_', '_', '_', '_', '60', '_', '_', '_', '_', '_', '_', '_', '60', '_', '_', '_', '60', '_', '_', '_', '_', '_', '_', '_', '60', '_', '_', '_', '_', '_', '_', '_', '60', '_', '_', '_', '60', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_']
<music21.note.Note C>
<music21.note.Note C>
<music21.note.Note C>
<music21.note.Note C>
<music21.note.Note C>
<music21.note.Note C>
<music21.note.Note C>
<music21.note.Note C>
<music21.note.Note C>
