In [1]:
%pip install wandb music21



# Import Necessary libraries

- **numpy** (required)
- **tensorflow** (required)
- **wandb (optional)**
- **matplotlib** (optional, may required if generating new scores)

- **music21** (Required for constructing `midi` files)

**Custom utility functions:**

* notes_to_midi()
* draw_score()

In [4]:
import os                # creating, writing & storing files and directories
import numpy as np
import tensorflow as tf
import wandb
import keras
from matplotlib import pyplot as plt
import music21

from keras import (
    layers,
    models,
    optimizers,
    callbacks,
    initializers,
    metrics
)
from tensorflow.python.ops.linalg.linalg_impl import sqrtm
from wandb.integration.keras import WandbMetricsLogger
from google.colab import drive

In [5]:
drive.mount("/content/drive")
dataset_path = "/content/drive/MyDrive/DC_GAN_MUSIC/Jsb16thSeparated.npz"
model_checkpoint_path = "/content/drive/MyDrive/DC_GAN_MUSIC/model_checkpoints"
tf_logs_path = "/content/drive/MyDrive/DC_GAN_MUSIC/logs"
output_dir = "/content/drive/MyDrive/DC_GAN_MUSIC/output"

Mounted at /content/drive


In [43]:
def binarize_output(output_tensor):
    """
    :param output_tensor: A vector of shape (batch_size, bars, steps, pitches, tracks)
    :return: Position of the maximum pitch value from the tracks.
    """
    return np.argmax(output_tensor, axis=3)


def notes_to_midi(output, n_bars: int, n_tracks: int, steps_per_bar: int, filename: str):
    """
    :param output: An output tensor of shape (batch_size, bars, steps, pitches, tracks)
    :param n_bars: no of bars per score i.e (axis=1)
    :param n_tracks: no of tracks per score i.e (axis=4 or -1)
    :param steps_per_bar: no of steps per bar per score i.e (axis=2)
    :param filename: the path or the destination to write into (defaults to output_dir)
    :return: A constructed `.mid` file that you can listen to.
    """

    for scores in range(len(output)):
        max_pitches = binarize_output(output)
        midi_note_score =  max_pitches[scores].reshape(n_bars * steps_per_bar, n_tracks)

        parts = music21.stream.Score()
        parts.append(music21.tempo.MetronomeMark(number=66))
        for i in range(n_tracks):
            last_x = int(midi_note_score[:, i][0])
            note_stream = music21.stream.Part()
            duration = 0
            for pos, track in enumerate(midi_note_score[:, i]):
                x = int(track)
                if (x != last_x or pos % 4 == 0) and pos:
                    note = music21.note.Note(last_x)
                    note.duration = music21.duration.Duration(duration)
                    note_stream.append(note)
                    duration = 0
                last_x = x
                duration += .25
            note = music21.note.Note(last_x)
            note.duration = music21.duration.Duration(duration)
            note_stream.append(note)
            parts.append(note_stream)
       # os.makedirs(output_dir, exist_ok=True)
        parts.write("midi", fp=f"{output_dir}/{filename}_{scores}")



def draw_bar(data, score, bar, part):
    plt.imshow(
        data[score, bar, :, :, part].transpose([1, 0]),
        origin="lower",
        cmap=plt.get_cmap("jet"),
        vmin=-1,
        vmax=1
    )


def draw_score(data, score):
    """

    :param data: A tensor of shape (batch_size, bars, steps, pitches, tracks)
    :param score: The current score
    :return: A matplotlib pyplot of the generated score
    """
    num_bars, num_tracks = data.shape[1], data.shape[-1]

    fig, ax = plt.subplots(num_tracks, num_bars, figsize=(12, 8), sharex=True, sharey=True)
    fig.subplots_adjust(0, 0, .2, 1.5, 0, 0)
    plt.style.use('ggplot')

    for bar in range(num_bars):
        for track in range(num_tracks):
            if num_bars > 1:
                ax[track, bar].imshow(
                    data[score, bar, :, :, track].transpose([1, 0]),
                    origin="lower",
                    cmap=plt.get_cmap("jet")
                )

            else:
                ax[track].imshow(
                    data[score, bar, :, :, track].transpose([1, 0]),
                    origin="lower",
                    cmap=plt.get_cmap("Greys")
                )



def plot_score(data, title="Generated Score"):
    """
    Plots a multi-track musical score in a more aesthetic and helpful way.

    :param data: A tensor of shape (bars, steps, pitches, tracks)
    :param title: The title for the plot.
    """
    # --- 1. Prepare the data ---
    # The input 'data' is now expected to be a single score (bars, steps, pitches, tracks)
    score = data
    num_bars, num_steps, num_pitches, num_tracks = score.shape

    # Create a 2D piano roll for each track
    piano_rolls = []
    for track in range(num_tracks):
        # Reshape from (bars, steps, pitches) to (bars * steps, pitches)
        track_data = score[:, :, :, track].reshape(-1, num_pitches)
        piano_rolls.append(track_data.T) # Transpose to have pitches on y-axis and time on x-axis

    # Vertically concatenate the piano rolls for each track
    combined_roll = np.vstack(piano_rolls)


    # --- 2. Create the plot ---
    fig, ax = plt.subplots(figsize=(16, 8))
    plt.style.use('ggplot') # A dark theme often looks better

    # Use a perceptually uniform colormap like 'magma' or 'viridis'
    im = ax.imshow(
        combined_roll,
        aspect='auto',
        origin='lower',
        cmap=plt.get_cmap('magma'),
        interpolation='nearest' # Use 'nearest' to get sharp note edges
    )

    # --- 3. Add helpful labels and titles ---
    # Adjust y-ticks to represent tracks, accounting for concatenated piano rolls
    track_height = num_pitches
    ax.set_yticks([i * track_height + track_height / 2 for i in range(num_tracks)])
    ax.set_yticklabels([f'Track {i+1}' for i in range(num_tracks)])
    ax.set_ylabel("Instrument Tracks")

    # Set x-axis ticks to mark the beginning of each bar
    ax.set_xticks(np.arange(0, num_bars * num_steps, num_steps))
    ax.set_xticklabels([f'Bar {i+1}' for i in range(num_bars)])
    ax.set_xlabel("Time")
    ax.minorticks_on()

    fig.suptitle(title, fontsize=16)

    # --- 4. Clean up aesthetics ---
    # Add horizontal lines to separate tracks
    for i in range(num_tracks):
        ax.axhline(y=(i + 1) * track_height - 0.5, color='gray', linestyle='--', linewidth=0.5)


    # Remove the top and right spines for a cleaner look
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make space for the suptitle
    plt.show()

In [7]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mvedantgade2006[0m ([33mvedantgade2006-secant[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
from google.colab import userdata
userdata.get('WANDB_API_KEY')

wandb.init(
    project='DC-GAN-Music',
    name="DC-GAN-Music",
    config={
        "z_dim": 32,
        "critic_steps": 5,
        "gp_weight": 10,
        "generator_lr": .001,
        "critic_lr": .001,
        "beta1": .5,
        "beta2": .9,
        "epochs": 5000,
        "batch_size": 64,
        "dataset": "JS Bach Chorale Dataset",
        "architecture": "GAN"
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mvedantgade2006[0m ([33mvedantgade2006-secant[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
#Define HyperParameter Constants
TRAIN_BATCH_SIZE = 64
N_BARS = 2               # NO OF BARS PER TRACK
N_STEPS_PER_BAR = 16     # NO OF TIME-STEPS PER BAR
MAX_PITCH = 83           # MAXIMUM MIDI PITCH USED
N_PITCHES = MAX_PITCH + 1
Z_DIM = 32               # DIMENSION OF THE LATENT SPACE

CRITIC_STEPS = 5
GP_WEIGHT = 10
CRITIC_LR = .0001
GENERATOR_LR = .0001
ADAM_BETA1 = .5
ADAM_BETA2 = .99
NUM_EPOCHS = 500
LOAD_MODEL = False

In [11]:
#DATASET_PATH='src/Jsb16thSeparated.npz'

with np.load(dataset_path, encoding="bytes", allow_pickle=True) as f:
    train_data = f['train']
    test_data =  f['test']


N_SONGS = len(train_data)
print(f"There are {N_SONGS} songs in train set.")
chorale = train_data[0]
N_BEATS, N_TRACKS = chorale.shape
print(f"{N_BEATS} beats and {N_TRACKS} tracks.")
print(f"\nTrain Chorale: 0")
print("\n", chorale[:10])


N_SONGS2 = len(test_data)
print(f"There are {N_SONGS} songs in test set.")
chorale2 = test_data[0]
N_BEATS, N_TRACKS = chorale2.shape
print(f"{N_BEATS} beats and {N_TRACKS} tracks.")
print(f"\nTest Chorale: 0")
print("\n", chorale2[:10])


There are 229 songs in train set.
192 beats and 4 tracks.

Train Chorale: 0

 [[74. 70. 65. 58.]
 [74. 70. 65. 58.]
 [74. 70. 65. 58.]
 [74. 70. 65. 58.]
 [75. 70. 58. 55.]
 [75. 70. 58. 55.]
 [75. 70. 60. 55.]
 [75. 70. 60. 55.]
 [77. 69. 62. 50.]
 [77. 69. 62. 50.]]
There are 229 songs in test set.
228 beats and 4 tracks.

Test Chorale: 0

 [[65. 60. 57. 53.]
 [65. 60. 57. 53.]
 [65. 60. 57. 53.]
 [65. 60. 57. 53.]
 [72. 60. 55. 52.]
 [72. 60. 55. 52.]
 [70. 60. 55. 52.]
 [70. 60. 55. 52.]
 [69. 60. 53. 53.]
 [69. 60. 53. 53.]]


In [12]:
two_bars = np.array([x[:(N_BARS * N_STEPS_PER_BAR)] for x in train_data])
two_bars = np.array(np.nan_to_num(two_bars, nan=MAX_PITCH), dtype=np.int32)
two_bars = two_bars.reshape([N_SONGS, N_BARS, N_STEPS_PER_BAR, N_TRACKS])
print(f"2 bars shape (no of songs * no of bars * steps per bar * no of tracks): {two_bars.shape}")

binary_data = np.eye(N_PITCHES)[two_bars]
#print(binary_data.shape)
#binary_data[binary_data == 0] = -1
binary_data = np.where(binary_data == 0, -1, binary_data)
#print(binary_data)
binary_data = binary_data.transpose([0, 1, 2, 4, 3])
#print(binary_data)
print(f"\nTransformed Train Binary data shape (songs * bars * steps * pitches, tracks):\n{binary_data.shape}")


# Pre-process Test & Validation subset to match train data.

two_bars_test = np.array([x[:(N_BARS * N_STEPS_PER_BAR)] for x in test_data])
two_bars_test = np.array(np.nan_to_num(two_bars_test, nan=MAX_PITCH), dtype=np.int32)
two_bars_test = two_bars_test.reshape([len(test_data), N_BARS, N_STEPS_PER_BAR, N_TRACKS])
print(f"\n2 bars shape (no of songs * no of bars * steps per bar * no of tracks): {two_bars_test.shape}")
test_binary_data = np.eye(N_PITCHES)[two_bars_test]
test_binary_data = np.where(test_binary_data == 0, -1, test_binary_data)
test_binary_data = test_binary_data.transpose([0, 1, 2, 4, 3])

print(f"\nTransformed Test Binary data shape (songs * bars * steps * pitches, tracks):\n{test_binary_data.shape}")


2 bars shape (no of songs * no of bars * steps per bar * no of tracks): (229, 2, 16, 4)

Transformed Train Binary data shape (songs * bars * steps * pitches, tracks):
(229, 2, 16, 84, 4)

2 bars shape (no of songs * no of bars * steps per bar * no of tracks): (77, 2, 16, 4)

Transformed Test Binary data shape (songs * bars * steps * pitches, tracks):
(77, 2, 16, 84, 4)


In [None]:
# Convert the test set into a tensorflow dataset for validation and FID callback access.
dummy_y = tf.zeros(shape=len(test_binary_data), dtype=tf.float64)
validation_set = tf.data.Dataset.from_tensor_slices((test_binary_data,  dummy_y))
validation_set.batch(TRAIN_BATCH_SIZE)
list(validation_set.as_numpy_iterator())

In [14]:
# Define some helper functions for our GAN network
initializer = initializers.RandomNormal(mean=0.0, stddev=0.02)

def conv(x, f, k, s, p):
    """
    :param x: A tensor of shape (batch_size, num_bars, steps_per_bar, num_pitches, num_tracks).
    :param f: Number of filters.
    :param k: Defines the size of the convolutional filter that strides along the input data.
    :param s: How many strides (the no of data values across all dimensions) the kernel passes through.
    :param p: Padding used by the kernel, can be either of 'valid','same' or 'zeros'.
    ----
    :return: A 3d convoluted tensor using leaky relu as activation function.
    """
    x = layers.Conv3D(
        filters=f,
        kernel_size=k,
        strides=s,
        padding=p,
        kernel_initializer=initializer
    )(x)

    x = layers.LeakyReLU()(x)
    return x



def conv_t(x, f, k, s, a, p, bn):
    """
    :param x: A tensor of shape (batch_size, num_bars, steps_per_bar, num_pitches).
    :param f: Number of convolutional filters to use.
    :param k: Defines the size of the convolutional filter that strides along the input data.
    :param s: How many strides (the no of data values across all dimensions) the kernel passes through.
    :param a: Activation function to use can be either of relu or tanh (GAN relative).
    :param p: Padding used by the kernel, can be either of 'valid','same' or 'zeros'.
    :param bn: Whether to use batch normalization or not.
    :return:
    A 2d Convolution Transposed Tensor of shape (batch_size, num_bars, steps_per_bar, num_pitches).
    """
    x = layers.Conv2DTranspose(
        filters=f,
        kernel_size=k,
        padding=p,
        strides=s,
        kernel_initializer=initializer
    )(x)
    if bn:
        x = layers.BatchNormalization(momentum=.9)(x)
    x = layers.Activation(a)(x)
    return x

## **Temporal Network**

A Convolutional Network block used to generate an noise latent vector which will be used to handle chord & melody vectors of the Generator.

---
## **Bar Generator**

Used to generate noise vectors sampled from the standard normal distribution N(0, 1).
Converts Input noise vector of shape 128 (4x32) into a vector that's used by the generator input to generate each of chords, style, melody & groove for the score.


In [15]:
def temporal_network():
    input_layer = layers.Input(shape=(Z_DIM,), name="temp_input")
    x = layers.Reshape([1,1,Z_DIM], name="temp_reshape")(input_layer)
    x = conv_t(x, f=1024, k=(2,1), s=1, a="relu", p="valid", bn=True)
    x = conv_t(x, f=Z_DIM, k=1, s=1, a="relu", p="valid", bn=True)
    output_layer = layers.Reshape([N_BARS, Z_DIM])(x)
    return models.Model(inputs=input_layer, outputs=output_layer)


#temporal_network().summary()


def bar_generator():
    input_layer = layers.Input(shape=(Z_DIM*4,), name="bar_gen_input") # each of chord, style, groove & melody

    x = layers.Dense(1024)(input_layer)
    x = layers.BatchNormalization(momentum=.9)(x)
    x = layers.Activation("relu")(x)
    x = layers.Reshape([2, 1, 512])(x)

    x = conv_t(x, f=512 , k=(2,1) , s=(2,1) , a='relu',p='same', bn=True)
    x = conv_t(x, f=256 , k=(2,1) , s=(2,1) , a='relu',p='same', bn=True)
    x = conv_t(x, f=256 , k=(2,1) , s=(2,1) , a='relu',p='same', bn=True)
    x = conv_t(x, f=256 , k=(1,7) , s=(1,7) , a='relu',p='same', bn=True)
    x = conv_t(x, f=1, k=(1,12) , s=(1,12) , a='tanh',p='same', bn=False)
    output_layer = layers.Reshape([1, N_STEPS_PER_BAR, N_PITCHES, 1], name="BarGeneratorOutput")(x)

    return models.Model(inputs=input_layer, outputs=output_layer)


#bar_generator().summary()

In [16]:
def Generator():
    """
     :returns:
      A joined noise vector sampled from Z_DIM each consisting of [chords, style, melody, groove]:

     :input:

     - chords: Controls the general progression of music
              (rhythm, pitch, tempo, time_steps) across each bar & track.
              Useful for providing control over each bar and track.\n

     - style: Controls the overall behaviour of the song/music.
             Is applied throughout the musical sheet.

     - melody: Use to control track specific or bar relative features of
              melody including change in pitch value, time, quavers
              etc. Relative to each track.

     - groove: An additional noise vector of shape (n_tracks * z_dim).
              Used to control bar relative or track relative features
              and is applied throughout the music sheet altering the
              overall style of the track across the song.
    """

    chords_input = layers.Input(shape=(Z_DIM,), name="chords_input")
    style_input = layers.Input(shape=(Z_DIM,), name="style_input")
    melody_input = layers.Input(shape=(N_TRACKS, Z_DIM), name="melody_input")
    groove_input = layers.Input(shape=(N_TRACKS, Z_DIM), name="groove_input")

    # CHORDS -> TEMPORAL NETWORK
    chords_tempNetwork = temporal_network()
    chords_over_time = chords_tempNetwork(chords_input)  # [n_bars, z_dim]

    # MELODY -> TEMPORAL NETWORK
    melody_over_time = [None] * N_TRACKS  # list of n_tracks [n_bars z_dim] tensors
    melody_tempNetwork = [None] * N_TRACKS
    for track in range(N_TRACKS):
        melody_tempNetwork[track] = temporal_network()
        melody_track = layers.Lambda(lambda x, track = track: x[:, track, :])(
            melody_input
        )
        melody_over_time[track] = melody_tempNetwork[track](melody_track)

    # CREATE BAR GENERATOR FOR EACH TRACK
    barGen = [None] * N_TRACKS
    for track in range(N_TRACKS):
        barGen[track] = bar_generator()

    # CREATE OUTPUT FOR EVERY TRACK AND BAR
    bars_output = [None] * N_BARS
    c = [None] * N_BARS
    for bar in range(N_BARS):
        track_output = [None] * N_TRACKS

        c[bar] = layers.Lambda(lambda x, bar=bar: x[:, bar, :])(
            chords_over_time
        )  # [z_dim]
        s = style_input  # [z_dim]

        for track in range(N_TRACKS):
            melody = layers.Lambda(lambda x, bar=bar: x[:, bar, :])(
                melody_over_time[track]
            )  # [z_dim]
            groove = layers.Lambda(lambda x, track=track: x[:, track, :])(
                groove_input
            )  # [z_dim]

            z_input = layers.Concatenate(
                axis=1
            )([c[bar], s, melody, groove])

            track_output[track] = barGen[track](z_input)

        bars_output[bar] = layers.Concatenate(axis=-1)(track_output)

    generator_output = layers.Concatenate(axis=1)(bars_output)

    return models.Model(
        inputs=[chords_input, style_input, melody_input, groove_input],
        outputs=generator_output,
        name="generator_model"
    )

In [17]:
def Critic():
    critic_input = layers.Input(shape=(N_BARS, N_STEPS_PER_BAR, N_PITCHES, N_TRACKS), name="critic_input")

    x = conv(x=critic_input, f=128, k=(2,1,1), s=1, p="valid")
    x = conv(x=x, f=128, k=(N_BARS-1,1,1), s=1, p="valid")
    x = conv(x=x, f=128, k=(1,1,12), s=(1,1,12), p="same")
    x = conv(x=x, f=128, k=(1,1,7), s=(1,1,7), p="same")
    x = conv(x=x, f=128, k=(1,2,1), s=(1,2,1), p="same")
    x = conv(x=x, f=128, k=(1,2,1), s=(1,2,1), p="same")
    x = conv(x=x, f=256, k=(1,4,1), s=(1,2,1), p="same")
    x = conv(x=x, f=512, k=(1,3,1), s=(1,2,1), p="same")

    x = layers.Flatten(name="flatten")(x)
    x = layers.Dense(1024, kernel_initializer=initializer)(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(.2)(x)

    critic_output = layers.Dense(1, activation=None, kernel_initializer=initializer)(x)

    return models.Model(inputs=critic_input, outputs=critic_output, name="critic_model")


Critic().summary()

In [18]:
class MuseGAN(models.Model):
    def __init__(self, generator_model, critic_model, latent_dim, critic_steps, gp_weight, **kwargs):
        super(MuseGAN, self).__init__(**kwargs)e
        self.generator_model = generator_model
        self.critic_model = critic_model
        self.latent_dim = latent_dim
        self.critic_steps = critic_steps
        self.gp_weight = gp_weight
        feature_layer = self.critic_model.get_layer('flatten').output # Capture the gradients before passing it to the dense layer.
        self.fid_feature_extractor = models.Model(
            inputs=self.critic_model.input,  # None, bars, steps, pitches, tracks
            outputs=feature_layer,
            name="fid_feature_extractor"
        )
        self.fid_metric = metrics.Mean(name="fid_score")
        self._built = False


    def compile(self, critic_optimizer, gen_optimizer):
        super(MuseGAN, self).compile()
        self.critic_optimizer = critic_optimizer
        self.gen_optimizer = gen_optimizer

        self.critic_wassertein_loss_metric = metrics.Mean(name="critic_wassertein_loss")
        self.critic_gradient_penalty_metric = metrics.Mean(name="critic_gradient_penalty")

        self.critic_loss_metric = metrics.Mean(name="critic_loss")
        self.gen_loss_metric = metrics.Mean(name="gen_loss")



    @property
    def metrics(self):
        super(MuseGAN, self).metrics
        return [
            self.critic_loss_metric,
            self.critic_wassertein_loss_metric,
            self.critic_gradient_penalty_metric,
            self.gen_loss_metric,
            self.fid_metric
        ]


    def gradient_penalty(self, batch_size, real_images, fake_images):
        """
        :usage:
        -------
        * The core idea of WGAN-GP is to force the norm (or magnitude) of this
        gradient to be exactly 1.
        * This ensures the critic's feedback is smooth and doesn't change too
        erratically.
        * The code calculates how far the gradient's norm is from 1, squares it,
        and this becomes the penalty.
        * This penalty is then added to the critic's loss, pushing the critic to
        learn a smoother scoring function.
        -------
        """
        alpha = tf.random.normal(shape=(batch_size, 1, 1, 1, 1), mean=0.0, stddev=1.0)

        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            predictions = self.critic_model(interpolated, training=True)

        gradients = gp_tape.gradient(predictions, [interpolated])[0]
        rms_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1,2,3,4]))
        gp = tf.reduce_mean((rms_norm - 1.0) ** 2)
        return gp



    def calculate_fid(self, real_images, fake_images):
        '''
        Calculates the FID (Frechet Inception Distance) between generated & target samples.
        '''
        # Get the feature activations
        real_features = self.fid_feature_extractor.predict(real_images)
        fake_features = self.fid_feature_extractor.predict(fake_images)

        mu_real, sigma_real = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
        mu_fake, sigma_fake = np.mean(fake_features, axis=0), np.cov(fake_features, rowvar=False)
        sum_sq_diff = np.sum((mu_real - mu_fake) ** 2.0)
        cov_sqrt = sqrtm(sigma_real.dot(sigma_fake))
        # Check for complex numbers and take the real part
        if np.iscomplexobj(cov_sqrt):
            cov_sqrt = cov_sqrt.real

        fid = sum_sq_diff + np.trace(sigma_real + sigma_fake - 2.0 * cov_sqrt)
        return fid


    def generate_random_latent_vectors(self, batch_size):
        return [
            tf.random.normal(shape=(batch_size, self.latent_dim)),
            tf.random.normal(shape=(batch_size, self.latent_dim)),
            tf.random.normal(shape=(batch_size, N_TRACKS, self.latent_dim)),
            tf.random.normal(shape=(batch_size, N_TRACKS, self.latent_dim))
        ]


    @tf.function(jit_compile=True)
    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]

        # Build the optimizers on the first training step
        if not self._built:
            self.critic_optimizer.build(self.critic_model.trainable_variables)
            self.gen_optimizer.build(self.generator_model.trainable_variables)
            self._built = True

        for step in range(self.critic_steps):
            random_latent_vectors = self.generate_random_latent_vectors(batch_size)

            with tf.GradientTape() as tape:
                fake_images = self.generator_model(random_latent_vectors, training=True)
                fake_predictions = self.critic_model(fake_images, training=True)
                real_predictions = self.critic_model(real_images, training=True)

                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                critic_loss = c_wass_loss + gp * self.gp_weight

            critic_gradients = tape.gradient(critic_loss, self.critic_model.trainable_variables)
            self.critic_optimizer.apply_gradients(zip(critic_gradients, self.critic_model.trainable_variables))

            random_latent_vectors = self.generate_random_latent_vectors(batch_size)

            with tf.GradientTape() as tape:
                fake_images = self.generator_model(random_latent_vectors, training=True)
                fake_predictions = self.critic_model(fake_images, training=True)
                gen_loss = -tf.reduce_mean(fake_predictions)

            gen_gradients = tape.gradient(gen_loss, self.generator_model.trainable_variables)
            self.gen_optimizer.apply_gradients(zip(gen_gradients, self.generator_model.trainable_variables))
            # Update metrics
            self.critic_loss_metric.update_state(critic_loss)
            self.critic_wassertein_loss_metric.update_state(c_wass_loss)
            self.critic_gradient_penalty_metric.update_state(gp)
            self.gen_loss_metric.update_state(gen_loss)

            return {m.name: m.result() for m in self.metrics}
        return None


    def generate_piano_roll(self, num_scores):
        crlv = tf.random.normal(shape=(num_scores, Z_DIM))
        srlv = tf.random.normal(shape=(num_scores, Z_DIM))
        mrlv = tf.random.normal(shape=(num_scores,N_TRACKS,Z_DIM))
        grlv = tf.random.normal(shape=(num_scores,N_TRACKS,Z_DIM))
        random_latent_vectors = [crlv,srlv,mrlv,grlv]

        return self.generator_model(random_latent_vectors).numpy()

In [19]:
class FIDCallback(keras.callbacks.Callback):
    def __init__(self, validation_data, num_samples=1024):
        super(FIDCallback, self).__init__()
        self.real_samples = validation_data.take(num_samples // TRAIN_BATCH_SIZE).map(lambda x, y: x)
        self.num_samples = num_samples


    def on_epoch_end(self, epoch, logs=None):
        super(FIDCallback, self).on_epoch_end(epoch, logs)
        print(f"\nCalculating FID score for epoch {epoch + 1}...")

        crlv = tf.random.normal(shape=(self.num_samples, Z_DIM))
        srlv = tf.random.normal(shape=(self.num_samples, Z_DIM))
        mrlv = tf.random.normal(shape=(self.num_samples, N_TRACKS, Z_DIM))
        grlv = tf.random.normal(shape=(self.num_samples, N_TRACKS, Z_DIM))
        random_latent_vectors = [crlv, srlv, mrlv, grlv]
        fake_music = self.model.generator_model.predict(random_latent_vectors)


        real_music_list = list(self.real_samples.as_numpy_iterator())
        real_music = np.concatenate(real_music_list, axis=0)
        # Ensure we have the same number of samples
        num_to_use = min(len(real_music), len(fake_music))

        # Reshape real_music to include the N_BARS dimension if it's missing
        if real_music.ndim == 4:
            real_music = real_music.reshape(-1, N_BARS, N_STEPS_PER_BAR, N_PITCHES, N_TRACKS)


        fid_score = self.model.calculate_fid(
            real_music[:num_to_use],
            fake_music[:num_to_use]
        )
        print(f"FID Score: {fid_score:.4f}")

        self.model.fid_metric.update_state(fid_score)
        if wandb.run:
            wandb.log({"epoch_fid_score": fid_score}, commit=False)  # commit=False to log with other epoch metrics

In [20]:
gen_model = Generator()
critic_model = Critic()

dc_gan = MuseGAN(
    generator_model=gen_model,
    critic_model=critic_model,
    latent_dim=Z_DIM,
    critic_steps=CRITIC_STEPS,
    gp_weight=GP_WEIGHT
)

if LOAD_MODEL:
    os.makedirs(model_checkpoint_path, exist_ok=True)
    model_weights_location = os.path.join(model_checkpoint_path, "dc_gan_weights.ckpt")
    dc_gan.load_weights(model_weights_location)


dc_gan.compile(
    critic_optimizer=optimizers.Adam(
        learning_rate=CRITIC_LR,
        beta_1=ADAM_BETA1,
        beta_2=ADAM_BETA2
    ),
    gen_optimizer=optimizers.Adam(
        learning_rate=GENERATOR_LR,
        beta_1=ADAM_BETA1,
        beta_2=ADAM_BETA2
    )
)

In [21]:
class MusicGen(callbacks.Callback):
    def __init__(self, num_scores):
        super(MusicGen, self).__init__()
        self.num_scores = num_scores

    def on_epoch_end(self, epoch, logs=None):
        super(MusicGen, self).on_epoch_end(epoch, logs)
        if epoch % 1 == 0:
            generated_music = self.model.generate_piano_roll(self.num_scores)
            notes_to_midi(
                generated_music,
                N_BARS,
                N_TRACKS,
                N_STEPS_PER_BAR,
                filename="output_" + str(epoch).zfill(4)
            )

In [22]:
# Define CallBacks to Use:
music_gen_callback = MusicGen(1)
model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath=os.path.join(model_checkpoint_path, "model.keras"),
    save_freq=10,
    verbose=0
)
#os.makedirs(tf_logs_path, exist_ok=True)
tensorboard_callback = callbacks.TensorBoard(log_dir=tf_logs_path)
wndb_callback = WandbMetricsLogger()
fid_callback = FIDCallback(validation_data=validation_set)


In [23]:
history = dc_gan.fit(
    binary_data,
    epochs=NUM_EPOCHS,
    batch_size=TRAIN_BATCH_SIZE,
    callbacks=[model_checkpoint_callback, tensorboard_callback, music_gen_callback, fid_callback, wndb_callback]
)
# exit wandb
wandb.finish()

Epoch 1/500
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12s/step - critic_gradient_penalty: 1.0000 - critic_loss: 9.9999 - critic_wassertein_loss: -1.0504e-05 - fid_score: 0.0000e+00 - gen_loss: 1.5863e-04  
Calculating FID score for epoch 1...
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 12ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 623ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 680ms/step
FID Score: nan
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m131s[0m 14s/step - critic_gradient_penalty: 1.0000 - critic_loss: 9.9999 - critic_wassertein_loss: -1.3100e-05 - fid_score: 0.0000e+00 - gen_loss: 1.6621e-04
Epoch 2/500
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 261ms/step - critic_gradient_penalty: 1.0000 - critic_loss: 9.9997 - critic_wassertein_loss: -1.8716e-04 - fid_score: 0.0000e+00 - gen_loss: 3.4681e-04
Calculating FID score for epoch 2...
[1m32/32[0m [32m━━━━━━

  return saving_lib.save_model(model, filepath)


[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 653ms/step - critic_gradient_penalty: 0.9999 - critic_loss: 9.9973 - critic_wassertein_loss: -0.0022 - fid_score: 0.0000e+00 - gen_loss: 0.0014
Calculating FID score for epoch 3...
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
FID Score: nan
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 2s/step - critic_gradient_penalty: 0.9999 - critic_loss: 9.9970 - critic_wassertein_loss: -0.0024 - fid_score: 0.0000e+00 - gen_loss: 0.0016   
Epoch 4/500
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 263ms/step - critic_gradient_penalty: 0.9996 - critic_loss: 9.9764 - critic_wassertein_loss: -0.0198 - fid_score: 0.0000e+00 - gen_loss: 0.0082
Calculating FID score for epoch 4...
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m

0,1
epoch/critic_gradient_penalty,█▁▁▁▁▁▁▁▁▁▄▂▂▂▂▂▂▂▂▁▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂▂▃▂▂▂
epoch/critic_loss,███▄▅▃▂▂▁▁▂▃▃▄▄▄▅▅▅▅▅▅▅▅▅▄▅▅▄▄▄▄▄▄▄▄▄▄▄▃
epoch/critic_wassertein_loss,▁▅▆████▇▅▆▇██████▇██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▆▇▇
epoch/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
epoch/fid_score,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch/gen_loss,▅▁▂▂▂▃▂▄▆▅▆▆▆▆▅▇▆▆▇▆▆▆▆▅▅▆▅▆▆▆▇▆▆▆▇▆█▇██
epoch/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch/critic_gradient_penalty,0.30095
epoch/critic_loss,-12.80465
epoch/critic_wassertein_loss,-15.81412
epoch/epoch,499.0
epoch/fid_score,0.0
epoch/gen_loss,52.27059
epoch/learning_rate,0.001
epoch_fid_score,


In [1]:
# Generate new scores

def generate_new_scores(noise_vector, mode: str, file_name: str, num_scores: int = 1):
    np.random.seed(10752)
    # Create all base random latent vectors
    random_chords = np.random.normal(size=(num_scores, Z_DIM))
    random_style = np.random.normal(size=(num_scores, Z_DIM))
    random_melody = np.random.normal(size=(num_scores, N_TRACKS, Z_DIM))
    random_groove = np.random.normal(size=(num_scores, N_TRACKS, Z_DIM))

    # Replace the specific latent vector based on the mode
    if mode == 'chords':
        chords_lv = noise_vector
        style_lv = random_style
        melody_lv = random_melody
        groove_lv = random_groove

    elif mode == 'style':
        chords_lv = random_chords
        style_lv = noise_vector
        melody_lv = random_melody
        groove_lv = random_groove

    elif mode == 'melody':
        chords_lv = random_chords
        style_lv = random_style
        melody_lv = noise_vector
        groove_lv = random_groove

    elif mode == 'groove':
        chords_lv = random_chords
        style_lv = random_style
        melody_lv = random_melody
        groove_lv = noise_vector

    else:
        raise ValueError(f"Mode should be either of: chords, style, melody, groove. Got {mode} instead.")

    # Assemble the latent vectors in the correct order for the generator
    latent_vectors = [chords_lv, style_lv, melody_lv, groove_lv]
    print(f"{mode} latent noise vector has been generated!")


    generated_music_scores = gen_model(latent_vectors).numpy()

    # Select the first generated score for plotting and MIDI conversion
    generated_music_score = generated_music_scores[0]

    plot_score(generated_music_score)
    notes_to_midi(generated_music_scores, N_BARS, N_TRACKS, N_STEPS_PER_BAR, f"{file_name}")



chords_change = np.random.normal(size=(1, Z_DIM))
generate_new_scores(chords_change, "chords", "changed_chord_midi", 2)

style_change = np.random.normal(size=(1, Z_DIM))
generate_new_scores(style_change, "style", "changed_style_midi", 2)

melody_change = np.random.normal(size=(1, N_TRACKS, Z_DIM))
generate_new_scores(melody_change, "melody", "changed_melody_midi", 2)

groove_change = np.random.normal(size=(1, N_TRACKS, Z_DIM))
generate_new_scores(groove_change, "groove", "changed_groove_midi", 2)

NameError: name 'np' is not defined

In [3]:
import tensorflow as tf

# List all available physical devices
physical_devices = tf.config.list_physical_devices()
print("Available devices:", physical_devices)

# Check if a GPU is available
gpu_available = tf.config.list_physical_devices('GPU')
if gpu_available:
    print("GPU is available and will be used by TensorFlow.")
else:
    print("No GPU available. TensorFlow will use the CPU.")

Available devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU is available and will be used by TensorFlow.
Available devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU is available and will be used by TensorFlow.
