<a href="https://colab.research.google.com/github/maekawataiki/MusicGeneration/blob/master/Style_Transformer_Cycle_GAN/Music_Style_Transfer_Cycle_GAN_inference_only.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Music Style Transfer with Cycle GAN

This notebook demonstrates inference (style transfer) with a pre-trained CycleGAN model.

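This notebook is for inference only. To train the model on your own dataset, use the training notebook in the [project folder](https://github.com/maekawataiki/MusicGeneration/tree/master/Style_Transformer_Cycle_GAN).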

Based on the paper
[Symbolic Music Genre Transfer with CycleGAN](https://arxiv.org/pdf/1809.07575.pdf).

## Changes from the original paper

- An identity loss is added to the generator loss, which significantly improved convergence.


In [None]:
#@title Clone repository

!git clone https://github.com/sumuzhao/CycleGAN-Music-Style-Transfer-Refactorization
%cd CycleGAN-Music-Style-Transfer-Refactorization
!mkdir models
!pip install pretty_midi pypianoroll

In [None]:
#@title Select model type

# Colab form field choosing the style pair. The download cell checks this
# value to decide which pre-trained checkpoint archive to fetch and unzip.
model_type = 'Classic <-> Pop' #@param ["Classic <-> Pop", "Jazz <-> Pop"]


In [None]:
#@title Download pre-trained model

if model_type == 'Classic <-> Pop':
  !gdown https://drive.google.com/uc?id=10Q2v1Fad0kvdc_fB2ZDritwa7xA1kQ3Y
  !unzip CP_C2CP_P_2020-08-31_base_0.1.zip -d ./models
else:
  !gdown https://drive.google.com/uc?id=124-o9uchvYJOw5wgoefdNeomhBKymUTh
  !unzip JP_J2JP_P_2020-08-28_base_0.01.zip -d ./models

In [None]:
#@title Import Dependencies

import os
import time
import datetime
import copy
from glob import glob
from collections import namedtuple

import pretty_midi
from pypianoroll import Multitrack, Track

from google.colab import files, drive

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers, Input
from tensorflow.keras.optimizers import Adam

import write_midi
from tf2_module import abs_criterion, mae_criterion
from tf2_utils import get_now_datetime, ImagePool, to_binary, load_npy_data, save_midis

In [None]:
#@title Define Model

def get_bar_piano_roll(piano_roll, last_bar_mode=None, bar_length=64):
    """Split a piano roll into consecutive fixed-length bars.

    Args:
        piano_roll: 2-D array of shape (time_steps, pitches).
        last_bar_mode: how to treat a trailing partial bar when ``time_steps``
            is not a multiple of ``bar_length``: ``'fill'`` zero-pads it to a
            full bar, ``'remove'`` drops it. When ``None`` (default) the global
            ``LAST_BAR_MODE`` is used; it is only looked up if a partial bar
            actually exists.
        bar_length: number of time steps per bar (default 64).

    Returns:
        Array of shape (n_bars, bar_length, pitches).

    Raises:
        ValueError: if a partial bar exists and the mode is neither
            ``'fill'`` nor ``'remove'``.
    """
    n_pitches = piano_roll.shape[1]
    remainder = piano_roll.shape[0] % bar_length
    # Compare by value: `is not 0` tests object identity and emits a
    # SyntaxWarning on Python 3.8+.
    if remainder != 0:
        mode = LAST_BAR_MODE if last_bar_mode is None else last_bar_mode
        if mode == 'fill':
            pad = np.zeros((bar_length - remainder, n_pitches))
            piano_roll = np.concatenate((piano_roll, pad), axis=0)
        elif mode == 'remove':
            piano_roll = piano_roll[:-remainder]
        else:
            raise ValueError(
                "last_bar_mode must be 'fill' or 'remove', got {!r}".format(mode))
    return piano_roll.reshape(-1, bar_length, n_pitches)

def padding(x, p=3):
    """Reflect-pad axes 1 and 2 of a 4-D tensor by ``p`` on each side.

    Axes 0 and 3 are left untouched, so a (b, h, w, c) tensor becomes
    (b, h + 2p, w + 2p, c).
    """
    pad_spec = [[0, 0], [p, p], [p, p], [0, 0]]
    return tf.pad(x, pad_spec, mode="REFLECT")

class InstanceNorm(layers.Layer):
    """Instance normalization with a learned per-channel scale and offset.

    Each sample is normalized to zero mean / unit variance over axes 1 and 2,
    then transformed as ``scale * x_hat + offset``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # One scale/offset per channel (last axis).
        channels = input_shape[-1:]
        # Weight names 'SCALE'/'OFFSET' and their creation order are kept
        # unchanged; restoring existing checkpoints may depend on them.
        self.scale = self.add_weight(
            name='SCALE',
            shape=channels,
            initializer=tf.keras.initializers.random_normal(1., 0.02),
            trainable=True,
            dtype=tf.float32,
        )
        self.offset = self.add_weight(
            name='OFFSET',
            shape=channels,
            initializer=tf.keras.initializers.zeros(),
            trainable=True,
            dtype=tf.float32,
        )

    def call(self, x, epsilon=1e-5):
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        centered = x - mean
        normalized = centered * tf.math.rsqrt(variance + epsilon)
        return self.scale * normalized + self.offset

def build_resnet_block(dim, k_init, ks=3, s=1, name='Resnet'):
    """Build a residual block as a standalone Keras ``Model``.

    Computes ``ReLU(x + IN(conv(pad(ReLU(IN(conv(pad(x))))))))``, where each
    ``conv`` is a bias-free ``ks x ks`` convolution with ``dim`` filters and
    ``pad`` is reflect padding of ``(ks - 1) // 2`` on both spatial axes.

    Args:
        dim: channel count; must equal the input's channel count because the
            input is added back to the block output.
        k_init: kernel initializer for both convolutions.
        ks: kernel size; an odd value makes the padding restore the spatial size.
        s: convolution stride; the residual addition only works for ``s=1``
            (otherwise the shapes of ``x`` and ``y`` differ).
        name: name of the returned model.

    Returns:
        A ``Model`` mapping (batch, H, W, dim) -> (batch, H, W, dim) for any H, W.
    """
    # Spatial dims are left open and channels tied to `dim`: a hard-coded
    # (128, 128, 256) input only matched when dim == 256 and never matched the
    # generator's bottleneck feature-map size.
    inputs = Input(shape=(None, None, dim))
    x = inputs

    # For ks = 3, p = 1: pad by 1, then a 'valid' conv restores H and W.
    p = (ks - 1) // 2

    y = layers.Lambda(padding,
                      arguments={'p': p},
                      name='PADDING_1')(x)
    # (batch, H + 2p, W + 2p, dim)

    y = layers.Conv2D(filters=dim,
                      kernel_size=ks,
                      strides=s,
                      padding='valid',
                      kernel_initializer=k_init,
                      use_bias=False)(y)
    y = InstanceNorm(name='IN_1')(y)
    y = layers.ReLU()(y)
    # (batch, H, W, dim)

    y = layers.Lambda(padding,
                      arguments={'p': p},
                      name='PADDING_2')(y)
    # (batch, H + 2p, W + 2p, dim)

    y = layers.Conv2D(filters=dim,
                      kernel_size=ks,
                      strides=s,
                      padding='valid',
                      kernel_initializer=k_init,
                      use_bias=False)(y)
    y = InstanceNorm(name='IN_2')(y)
    # Residual connection, then the final activation.
    y = layers.ReLU()(y + x)
    # (batch, H, W, dim)

    outputs = y
    return Model(inputs=inputs,
                 outputs=outputs,
                 name=name)

def build_discriminator(options, name='Discriminator'):
    """Build the domain discriminator.

    Input shape is ``(options.time_step, options.pitch_range, options.output_nc)``.
    Two stride-2 convolutions followed by a stride-1 one produce a
    ``(ceil(time_step / 4), ceil(pitch_range / 4), 1)`` map of raw scores
    (no final activation), e.g. 16 x 21 x 1 for a 64 x 84 input.
    """
    init = tf.random_normal_initializer(0., 0.02)

    def _conv(filters, stride, layer_name):
        # 7x7, 'same'-padded, bias-free convolution shared by every stage.
        return layers.Conv2D(filters=filters,
                             kernel_size=7,
                             strides=stride,
                             padding='same',
                             kernel_initializer=init,
                             use_bias=False,
                             name=layer_name)

    inputs = Input(shape=(options.time_step,
                          options.pitch_range,
                          options.output_nc))

    # Stage 1: (batch * 32 * 42 * df_dim) for a 64 x 84 input
    h = _conv(options.df_dim, 2, 'CONV2D_1')(inputs)
    h = layers.LeakyReLU(alpha=0.2)(h)

    # Stage 2: (batch * 16 * 21 * 4*df_dim)
    h = _conv(options.df_dim * 4, 2, 'CONV2D_2')(h)
    h = InstanceNorm(name='IN_1')(h)
    h = layers.LeakyReLU(alpha=0.2)(h)

    # Score map: (batch * 16 * 21 * 1)
    scores = _conv(1, 1, 'CONV2D_3')(h)

    return Model(inputs=inputs, outputs=scores, name=name)


def build_generator(options, name='Generator'):
    """Build the encoder / residual-bottleneck / decoder generator.

    Input shape is ``(options.time_step, options.pitch_range, options.output_nc)``.
    The encoder downsamples twice by 2, ten residual blocks run at the
    bottleneck, and the decoder upsamples twice by 2, so for time_step and
    pitch_range divisible by 4 (e.g. 64 x 84) the output shape equals the
    input shape. A final sigmoid keeps every output value in (0, 1).
    """
    init = tf.random_normal_initializer(0., 0.02)

    def _conv_norm_relu(h, conv_cls, filters, kernel_size, strides, pad_mode,
                        conv_name, norm_name):
        # Bias-free (transposed) conv -> InstanceNorm -> ReLU, created in that order.
        h = conv_cls(filters=filters,
                     kernel_size=kernel_size,
                     strides=strides,
                     padding=pad_mode,
                     kernel_initializer=init,
                     use_bias=False,
                     name=conv_name)(h)
        h = InstanceNorm(name=norm_name)(h)
        return layers.ReLU()(h)

    inputs = Input(shape=(options.time_step,
                          options.pitch_range,
                          options.output_nc))

    # Encoder: (64 * 84 * 1) -> pad (70 * 90 * 1) -> (64 * 84 * gf) -> (32 * 42 * 2gf) -> (16 * 21 * 4gf)
    h = layers.Lambda(padding, name='PADDING_1')(inputs)
    h = _conv_norm_relu(h, layers.Conv2D, options.gf_dim, 7, 1, 'valid', 'CONV2D_1', 'IN_1')
    h = _conv_norm_relu(h, layers.Conv2D, options.gf_dim * 2, 3, 2, 'same', 'CONV2D_2', 'IN_2')
    h = _conv_norm_relu(h, layers.Conv2D, options.gf_dim * 4, 3, 2, 'same', 'CONV2D_3', 'IN_3')

    # Bottleneck: ten residual blocks, shape preserved (16 * 21 * 4gf)
    for block_idx in range(10):
        res_block = build_resnet_block(dim=options.gf_dim * 4,
                                       k_init=init,
                                       name='ResNet_Block_{}'.format(block_idx))
        h = res_block(h)

    # Decoder: (32 * 42 * 2gf) -> (64 * 84 * gf) -> pad (70 * 90 * gf) -> (64 * 84 * output_nc)
    h = _conv_norm_relu(h, layers.Conv2DTranspose, options.gf_dim * 2, 3, 2, 'same', 'DECONV2D_1', 'IN_4')
    h = _conv_norm_relu(h, layers.Conv2DTranspose, options.gf_dim, 3, 2, 'same', 'DECONV2D_2', 'IN_5')
    h = layers.Lambda(padding, name='PADDING_2')(h)
    generated = layers.Conv2D(filters=options.output_nc,
                              kernel_size=7,
                              strides=1,
                              padding='valid',
                              kernel_initializer=init,
                              activation='sigmoid',
                              use_bias=False,
                              name='CONV2D_4')(h)

    return Model(inputs=inputs, outputs=generated, name=name)


def build_discriminator_classifier(options, name='Discriminator_Classifier'):
    """Build the two-class (genre) classifier network.

    Input shape is ``(options.time_step, options.pitch_range, options.output_nc)``.
    Strided convolutions collapse the pitch axis (12 then 7) and the time axis
    (4, 2, then 8), so a 64 x 84 input reduces to a 1 x 1 x 2 map, which is
    returned as ``(batch, 2)`` raw logits (no final activation).
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    inputs = Input(shape=(options.time_step,
                          options.pitch_range,
                          options.output_nc))

    x = inputs
    # (batch * 64 * 84 * 1)

    x = layers.Conv2D(filters=options.df_dim,
                      kernel_size=[1, 12],
                      strides=[1, 12],
                      padding='same',
                      kernel_initializer=initializer,
                      use_bias=False,
                      name='CONV2D_1')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # (batch * 64 * 7 * 64)

    x = layers.Conv2D(filters=options.df_dim * 2,
                      kernel_size=[4, 1],
                      strides=[4, 1],
                      padding='same',
                      kernel_initializer=initializer,
                      use_bias=False,
                      name='CONV2D_2')(x)
    x = InstanceNorm(name='IN_1')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # (batch * 16 * 7 * 128)

    x = layers.Conv2D(filters=options.df_dim * 4,
                      kernel_size=[2, 1],
                      strides=[2, 1],
                      padding='same',
                      kernel_initializer=initializer,
                      use_bias=False,
                      name='CONV2D_3')(x)
    x = InstanceNorm(name='IN_2')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # (batch * 8 * 7 * 256)

    x = layers.Conv2D(filters=options.df_dim * 8,
                      kernel_size=[8, 1],
                      strides=[8, 1],
                      padding='same',
                      kernel_initializer=initializer,
                      use_bias=False,
                      name='CONV2D_4')(x)
    x = InstanceNorm(name='IN_3')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # (batch * 1 * 7 * 512)

    x = layers.Conv2D(filters=2,
                      kernel_size=[1, 7],
                      strides=[1, 7],
                      padding='same',
                      kernel_initializer=initializer,
                      use_bias=False,
                      name='CONV2D_5')(x)
    # (batch * 1 * 1 * 2)

    # A Keras Reshape layer keeps the batch axis intact. tf.reshape(x, [-1, 2])
    # on a symbolic Keras tensor would silently fold any leftover spatial
    # positions into the batch axis; Reshape raises on an unexpected shape instead.
    x = layers.Reshape((2,))(x)
    # (batch * 2)

    outputs = x

    return Model(inputs=inputs,
                 outputs=outputs,
                 name=name)

class CycleGAN(object):

    def __init__(self, args):
        """Store hyper-parameters from ``args`` and build the networks.

        Args:
            args: namespace-like object. This method reads batch_size,
                time_step, pitch_range, input_nc, output_nc, lr, L1_lambda,
                gamma, sigma_d, dataset_A_dir, dataset_B_dir, sample_dir,
                model, ngf, ndf, phase and max_size; ``_build_model`` reads
                further attributes (beta1, checkpoint_dir, model_dir).
        """
        self.batch_size = args.batch_size
        self.time_step = args.time_step  # number of time steps
        self.pitch_range = args.pitch_range  # number of pitches
        self.input_c_dim = args.input_nc  # number of input image channels
        self.output_c_dim = args.output_nc  # number of output image channels
        self.lr = args.lr
        self.L1_lambda = args.L1_lambda  # weight of the L1 cycle / identity losses in train()
        self.gamma = args.gamma  # weight of the *_all discriminator loss (non-base models)
        self.sigma_d = args.sigma_d  # std-dev of the noise added to discriminator inputs
        self.dataset_A_dir = args.dataset_A_dir
        self.dataset_B_dir = args.dataset_B_dir
        self.sample_dir = args.sample_dir

        self.model = args.model  # 'base', 'partial' or 'full'
        self.discriminator = build_discriminator
        self.generator = build_generator
        self.criterionGAN = mae_criterion

        OPTIONS = namedtuple('OPTIONS', 'batch_size '
                                        'time_step '
                                        'input_nc '
                                        'output_nc '
                                        'pitch_range '
                                        'gf_dim '
                                        'df_dim '
                                        'is_training')
        # Filled by keyword so each value lands on the field of the same name;
        # a positional tuple whose order differs from the field list silently
        # shifts pitch_range / input_nc / output_nc onto the wrong fields.
        self.options = OPTIONS(batch_size=args.batch_size,
                               time_step=args.time_step,
                               input_nc=args.input_nc,
                               output_nc=args.output_nc,
                               pitch_range=args.pitch_range,
                               gf_dim=args.ngf,
                               df_dim=args.ndf,
                               is_training=args.phase == 'train')

        self.now_datetime = get_now_datetime()
        self.pool = ImagePool(args.max_size)

        self._build_model(args)

        print("initialize model...")

    def _build_model(self, args):
        """Create the networks, their Adam optimizers and the checkpoint objects.

        Always builds generators A2B / B2A and discriminators A / B; when
        ``self.model`` is not ``'base'`` it also builds the ``*_all``
        discriminators and their optimizers. No checkpoint is restored here;
        ``self.checkpoint`` / ``self.checkpoint_manager`` are set up for callers
        to save or restore.

        Args:
            args: provides ``beta1`` (Adam beta_1), ``checkpoint_dir`` and
                ``model_dir`` (optional override of the generated run-directory
                name).
        """
        # Generators
        self.generator_A2B = self.generator(self.options,
                                            name='Generator_A2B')
        self.generator_B2A = self.generator(self.options,
                                            name='Generator_B2A')

        # Discriminators
        self.discriminator_A = self.discriminator(self.options,
                                                  name='Discriminator_A')
        self.discriminator_B = self.discriminator(self.options,
                                                  name='Discriminator_B')

        if self.model != 'base':
            self.discriminator_A_all = self.discriminator(self.options,
                                                          name='Discriminator_A_all')
            self.discriminator_B_all = self.discriminator(self.options,
                                                          name='Discriminator_B_all')

        # Discriminator and Generator Optimizer
        self.DA_optimizer = Adam(self.lr,
                                 beta_1=args.beta1)
        self.DB_optimizer = Adam(self.lr,
                                 beta_1=args.beta1)
        self.GA2B_optimizer = Adam(self.lr,
                                   beta_1=args.beta1)
        self.GB2A_optimizer = Adam(self.lr,
                                   beta_1=args.beta1)

        if self.model != 'base':
            self.DA_all_optimizer = Adam(self.lr,
                                         beta_1=args.beta1)
            self.DB_all_optimizer = Adam(self.lr,
                                         beta_1=args.beta1)

        # Checkpoints: <checkpoint_dir>/<run dir>/cyclegan.model
        model_name = "cyclegan.model"
        model_dir = "{}2{}_{}_{}_{}".format(self.dataset_A_dir,
                                            self.dataset_B_dir,
                                            self.now_datetime,
                                            self.model,
                                            self.sigma_d)
        self.checkpoint_dir = os.path.join(args.checkpoint_dir,
                                           args.model_dir or model_dir,
                                           model_name)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # The keyword names below are the checkpoint keys; they are identical to
        # the per-mode keyword lists this replaces, just assembled once.
        checkpoint_objects = dict(generator_A2B_optimizer=self.GA2B_optimizer,
                                  generator_B2A_optimizer=self.GB2A_optimizer,
                                  discriminator_A_optimizer=self.DA_optimizer,
                                  discriminator_B_optimizer=self.DB_optimizer,
                                  generator_A2B=self.generator_A2B,
                                  generator_B2A=self.generator_B2A,
                                  discriminator_A=self.discriminator_A,
                                  discriminator_B=self.discriminator_B)
        if self.model != 'base':
            checkpoint_objects.update(discriminator_A_all_optimizer=self.DA_all_optimizer,
                                      discriminator_B_all_optimizer=self.DB_all_optimizer,
                                      discriminator_A_all=self.discriminator_A_all,
                                      discriminator_B_all=self.discriminator_B_all)
        self.checkpoint = tf.train.Checkpoint(**checkpoint_objects)

        self.checkpoint_manager = tf.train.CheckpointManager(self.checkpoint,
                                                             self.checkpoint_dir,
                                                             max_to_keep=5)

    def train(self, args):

        # Data from domain A and B, and mixed dataset for partial and full models.
        dataA = glob('./datasets/{}/train/*.*'.format(self.dataset_A_dir))
        dataB = glob('./datasets/{}/train/*.*'.format(self.dataset_B_dir))
        data_mixed = None
        if self.model == 'partial':
            data_mixed = dataA + dataB
        if self.model == 'full':
            data_mixed = glob('./datasets/JCP_mixed/*.*')

        if args.continue_train:
            if self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint):
                print(" [*] Load checkpoint succeeded!")
            else:
                print(" [!] Load checkpoint failed...")

        counter = 1
        start_time = time.time()

        for epoch in range(args.epoch):

            # Shuffle training data
            np.random.shuffle(dataA)
            np.random.shuffle(dataB)
            if self.model != 'base' and data_mixed is not None:
                np.random.shuffle(data_mixed)

            # Get the proper number of batches
            batch_idxs = min(len(dataA), len(dataB)) // self.batch_size

            # learning rate starts to decay when reaching the threshold
            self.lr = self.lr if epoch < args.epoch_step else self.lr * (args.epoch-epoch) / (args.epoch-args.epoch_step)

            for idx in range(batch_idxs):

                # To feed real_data
                batch_files = list(zip(dataA[idx * self.batch_size:(idx + 1) * self.batch_size],
                                       dataB[idx * self.batch_size:(idx + 1) * self.batch_size]))
                batch_samples = [load_npy_data(batch_file) for batch_file in batch_files]
                batch_samples = np.array(batch_samples).astype(np.float32)  # batch_size * 64 * 84 * 2
                real_A, real_B = batch_samples[:, :, :, 0], batch_samples[:, :, :, 1]
                real_A = tf.expand_dims(real_A, -1)
                real_B = tf.expand_dims(real_B, -1)

                # generate gaussian noise for robustness improvement
                gaussian_noise = np.abs(np.random.normal(0,
                                                         self.sigma_d,
                                                         [self.batch_size,
                                                          self.time_step,
                                                          self.pitch_range,
                                                          self.input_c_dim])).astype(np.float32)

                if self.model == 'base':

                    with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape(persistent=True) as disc_tape:

                        fake_B = self.generator_A2B(real_A,
                                                    training=True)
                        cycle_A = self.generator_B2A(fake_B,
                                                     training=True)

                        fake_A = self.generator_B2A(real_B,
                                                    training=True)
                        cycle_B = self.generator_A2B(fake_A,
                                                     training=True)
                        
                        # Added for identity loss
                        same_B = self.generator_A2B(real_B,
                                                    training=True)
                        same_A = self.generator_B2A(real_A,
                                                    training=True)

                        [fake_A_sample, fake_B_sample] = self.pool([fake_A, fake_B])

                        DA_real = self.discriminator_A(real_A + gaussian_noise,
                                                       training=True)
                        DB_real = self.discriminator_B(real_B + gaussian_noise,
                                                       training=True)

                        DA_fake = self.discriminator_A(fake_A + gaussian_noise,
                                                       training=True)
                        DB_fake = self.discriminator_B(fake_B + gaussian_noise,
                                                       training=True)

                        DA_fake_sample = self.discriminator_A(fake_A_sample + gaussian_noise,
                                                              training=True)
                        DB_fake_sample = self.discriminator_B(fake_B_sample + gaussian_noise,
                                                              training=True)

                        # Generator loss
                        cycle_loss = self.L1_lambda * (abs_criterion(real_A, cycle_A) + abs_criterion(real_B, cycle_B))
                        identity_loss_A = self.L1_lambda * 0.5 * abs_criterion(real_A, same_A) # added
                        identity_loss_B = self.L1_lambda * 0.5 * abs_criterion(real_B, same_B) # added
                        g_A2B_loss = self.criterionGAN(DB_fake, tf.ones_like(DB_fake)) + cycle_loss + identity_loss_B
                        g_B2A_loss = self.criterionGAN(DA_fake, tf.ones_like(DA_fake)) + cycle_loss + identity_loss_A
                        g_loss = g_A2B_loss + g_B2A_loss - cycle_loss

                        # Discriminator loss
                        d_A_loss_real = self.criterionGAN(DA_real, tf.ones_like(DA_real))
                        d_A_loss_fake = self.criterionGAN(DA_fake_sample, tf.zeros_like(DA_fake_sample))
                        d_A_loss = (d_A_loss_real + d_A_loss_fake) / 2
                        d_B_loss_real = self.criterionGAN(DB_real, tf.ones_like(DB_real))
                        d_B_loss_fake = self.criterionGAN(DB_fake_sample, tf.zeros_like(DB_fake_sample))
                        d_B_loss = (d_B_loss_real + d_B_loss_fake) / 2
                        d_loss = d_A_loss + d_B_loss

                    # Calculate the gradients for generator and discriminator
                    generator_A2B_gradients = gen_tape.gradient(target=g_A2B_loss,
                                                                sources=self.generator_A2B.trainable_variables)
                    generator_B2A_gradients = gen_tape.gradient(target=g_B2A_loss,
                                                                sources=self.generator_B2A.trainable_variables)

                    discriminator_A_gradients = disc_tape.gradient(target=d_A_loss,
                                                                   sources=self.discriminator_A.trainable_variables)
                    discriminator_B_gradients = disc_tape.gradient(target=d_B_loss,
                                                                   sources=self.discriminator_B.trainable_variables)

                    # Apply the gradients to the optimizer
                    self.GA2B_optimizer.apply_gradients(zip(generator_A2B_gradients,
                                                            self.generator_A2B.trainable_variables))
                    self.GB2A_optimizer.apply_gradients(zip(generator_B2A_gradients,
                                                            self.generator_B2A.trainable_variables))

                    self.DA_optimizer.apply_gradients(zip(discriminator_A_gradients,
                                                          self.discriminator_A.trainable_variables))
                    self.DB_optimizer.apply_gradients(zip(discriminator_B_gradients,
                                                          self.discriminator_B.trainable_variables))

                    print('=================================================================')
                    print(("Epoch: [%2d] [%4d/%4d] time: %4.4f D_loss: %6.2f, G_loss: %6.2f, cycle_loss: %6.2f" %
                           (epoch, idx, batch_idxs, time.time() - start_time, d_loss, g_loss, cycle_loss)))

                else:

                    # To feed real_mixed
                    batch_files_mixed = data_mixed[idx * self.batch_size:(idx + 1) * self.batch_size]
                    batch_samples_mixed = [np.load(batch_file) * 1. for batch_file in batch_files_mixed]
                    real_mixed = np.array(batch_samples_mixed).astype(np.float32)

                    with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape(persistent=True) as disc_tape:

                        fake_B = self.generator_A2B(real_A,
                                                    training=True)
                        cycle_A = self.generator_B2A(fake_B,
                                                     training=True)

                        fake_A = self.generator_B2A(real_B,
                                                    training=True)
                        cycle_B = self.generator_A2B(fake_A,
                                                     training=True)

                        [fake_A_sample, fake_B_sample] = self.pool([fake_A, fake_B])

                        DA_real = self.discriminator_A(real_A + gaussian_noise,
                                                       training=True)
                        DB_real = self.discriminator_B(real_B + gaussian_noise,
                                                       training=True)

                        DA_fake = self.discriminator_A(fake_A + gaussian_noise,
                                                       training=True)
                        DB_fake = self.discriminator_B(fake_B + gaussian_noise,
                                                       training=True)

                        DA_fake_sample = self.discriminator_A(fake_A_sample + gaussian_noise,
                                                              training=True)
                        DB_fake_sample = self.discriminator_B(fake_B_sample + gaussian_noise,
                                                              training=True)

                        DA_real_all = self.discriminator_A_all(real_mixed + gaussian_noise,
                                                               training=True)
                        DB_real_all = self.discriminator_B_all(real_mixed + gaussian_noise,
                                                               training=True)

                        DA_fake_sample_all = self.discriminator_A_all(fake_A_sample + gaussian_noise,
                                                                      training=True)
                        DB_fake_sample_all = self.discriminator_B_all(fake_B_sample + gaussian_noise,
                                                                      training=True)

                        # Generator loss
                        cycle_loss = self.L1_lambda * (abs_criterion(real_A, cycle_A) + abs_criterion(real_B, cycle_B))
                        g_A2B_loss = self.criterionGAN(DB_fake, tf.ones_like(DB_fake)) + cycle_loss
                        g_B2A_loss = self.criterionGAN(DA_fake, tf.ones_like(DA_fake)) + cycle_loss
                        g_loss = g_A2B_loss + g_B2A_loss - cycle_loss

                        # Discriminator loss
                        d_A_loss_real = self.criterionGAN(DA_real, tf.ones_like(DA_real))
                        d_A_loss_fake = self.criterionGAN(DA_fake_sample, tf.zeros_like(DA_fake_sample))
                        d_A_loss = (d_A_loss_real + d_A_loss_fake) / 2
                        d_B_loss_real = self.criterionGAN(DB_real, tf.ones_like(DB_real))
                        d_B_loss_fake = self.criterionGAN(DB_fake_sample, tf.zeros_like(DB_fake_sample))
                        d_B_loss = (d_B_loss_real + d_B_loss_fake) / 2
                        d_loss = d_A_loss + d_B_loss

                        d_A_all_loss_real = self.criterionGAN(DA_real_all, tf.ones_like(DA_real_all))
                        d_A_all_loss_fake = self.criterionGAN(DA_fake_sample_all, tf.zeros_like(DA_fake_sample_all))
                        d_A_all_loss = (d_A_all_loss_real + d_A_all_loss_fake) / 2
                        d_B_all_loss_real = self.criterionGAN(DB_real_all, tf.ones_like(DB_real_all))
                        d_B_all_loss_fake = self.criterionGAN(DB_fake_sample_all, tf.zeros_like(DB_fake_sample_all))
                        d_B_all_loss = (d_B_all_loss_real + d_B_all_loss_fake) / 2
                        d_all_loss = d_A_all_loss + d_B_all_loss
                        D_loss = d_loss + self.gamma * d_all_loss

                    # Calculate the gradients for generator and discriminator
                    generator_A2B_gradients = gen_tape.gradient(target=g_A2B_loss,
                                                                sources=self.generator_A2B.trainable_variables)
                    generator_B2A_gradients = gen_tape.gradient(target=g_B2A_loss,
                                                                sources=self.generator_B2A.trainable_variables)

                    discriminator_A_gradients = disc_tape.gradient(target=d_A_loss,
                                                                   sources=self.discriminator_A.trainable_variables)
                    discriminator_B_gradients = disc_tape.gradient(target=d_B_loss,
                                                                   sources=self.discriminator_B.trainable_variables)

                    discriminator_A_all_gradients = disc_tape.gradient(target=d_A_all_loss,
                                                                   sources=self.discriminator_A_all.trainable_variables)
                    discriminator_B_all_gradients = disc_tape.gradient(target=d_B_all_loss,
                                                                   sources=self.discriminator_B_all.trainable_variables)

                    # Apply the gradients to the optimizer
                    self.GA2B_optimizer.apply_gradients(zip(generator_A2B_gradients,
                                                            self.generator_A2B.trainable_variables))
                    self.GB2A_optimizer.apply_gradients(zip(generator_B2A_gradients,
                                                            self.generator_B2A.trainable_variables))

                    self.DA_optimizer.apply_gradients(zip(discriminator_A_gradients,
                                                          self.discriminator_A.trainable_variables))
                    self.DB_optimizer.apply_gradients(zip(discriminator_B_gradients,
                                                          self.discriminator_B.trainable_variables))

                    self.DA_all_optimizer.apply_gradients(zip(discriminator_A_all_gradients,
                                                              self.discriminator_A_all.trainable_variables))
                    self.DB_all_optimizer.apply_gradients(zip(discriminator_B_all_gradients,
                                                              self.discriminator_B_all.trainable_variables))

                    print('=================================================================')
                    print(("Epoch: [%2d] [%4d/%4d] time: %4.4f D_loss: %6.2f, G_loss: %6.2f" %
                           (epoch, idx, batch_idxs, time.time() - start_time, D_loss, g_loss)))

                counter += 1

                # generate samples during training to track the learning process
                if np.mod(counter, args.print_freq) == 1:
                    sample_dir = os.path.join(self.sample_dir,
                                              '{}2{}_{}_{}_{}'.format(self.dataset_A_dir,
                                                                      self.dataset_B_dir,
                                                                      self.now_datetime,
                                                                      self.model,
                                                                      self.sigma_d))
                    if not os.path.exists(sample_dir):
                        os.makedirs(sample_dir)

                    # to binary, 0 denotes note off, 1 denotes note on
                    samples = [to_binary(real_A, 0.5),
                               to_binary(fake_B, 0.5),
                               to_binary(cycle_A, 0.5),
                               to_binary(real_B, 0.5),
                               to_binary(fake_A, 0.5),
                               to_binary(cycle_B, 0.5)]

                    self.sample_model(samples=samples,
                                      sample_dir=sample_dir,
                                      epoch=epoch,
                                      idx=idx)

                if np.mod(counter, args.save_freq) == 1:
                    self.checkpoint_manager.save(counter)

    def sample_model(self, samples, sample_dir, epoch, idx):
        """Write the six monitoring batches of one training step as MIDI files.

        Args:
            samples: sequence of six binarized batches in the order built by
                ``train``: [real_A, fake_B, cycle_A, real_B, fake_A, cycle_B].
                The first three go to ``A2B/``, the last three to ``B2A/``.
            sample_dir: directory that receives the ``A2B/`` and ``B2A/``
                sub-folders (created if missing).
            epoch: epoch number, zero-padded to 2 digits in the file name.
            idx: batch index, zero-padded to 4 digits in the file name.
        """
        print('generating samples during learning......')

        for direction in ('B2A', 'A2B'):
            os.makedirs(os.path.join(sample_dir, direction), exist_ok=True)

        suffixes = ('origin', 'transfer', 'cycle')
        # (sub-folder, index of its first sample in `samples`)
        for direction, offset in (('A2B', 0), ('B2A', 3)):
            for k, suffix in enumerate(suffixes):
                file_name = '{:02d}_{:04d}_{}.mid'.format(epoch, idx, suffix)
                # os.path.join instead of a './' prefix so that an absolute
                # sample_dir is not silently turned into a relative path.
                save_midis(samples[offset + k],
                           os.path.join(sample_dir, direction, file_name))

    def test(self, args):
        """Run both generators over every clip of a test split and save the results.

        Args:
            args: options object; reads ``which_direction`` ('AtoB' or 'BtoA')
                and ``test_dir``.

        Input files are globbed from ``./datasets/<dataset_dir>/test/`` and
        processed in numeric order of the trailing ``_<int>`` in their names.
        For input number ``k`` (1-based) this writes ``k_origin.mid``,
        ``k_transfer.mid`` and ``k_cycle.mid`` under ``.../<direction>/mid`` and
        the matching ``.npy`` files under ``.../<direction>/npy/{origin,transfer,cycle}``.

        Raises:
            Exception: if ``args.which_direction`` is neither 'AtoB' nor 'BtoA'.
        """
        if args.which_direction == 'AtoB':
            source_dir = self.dataset_A_dir
            forward, backward = self.generator_A2B, self.generator_B2A
        elif args.which_direction == 'BtoA':
            source_dir = self.dataset_B_dir
            forward, backward = self.generator_B2A, self.generator_A2B
        else:
            raise Exception('--which_direction must be AtoB or BtoA')

        sample_files = glob('./datasets/{}/test/*.*'.format(source_dir))
        # File names are expected to end in '_<int>'; sort numerically on it.
        sample_files.sort(key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))

        # Checkpoint.restore() returns a status object that is truthy even when
        # there is nothing to restore, so check for a checkpoint explicitly.
        latest_checkpoint = self.checkpoint_manager.latest_checkpoint
        if latest_checkpoint:
            self.checkpoint.restore(latest_checkpoint)
            print(" [*] Load checkpoint succeeded!")
        else:
            print(" [!] Load checkpoint failed...")

        run_name = '{}2{}_{}_{}_{}'.format(self.dataset_A_dir,
                                           self.dataset_B_dir,
                                           self.now_datetime,
                                           self.model,
                                           self.sigma_d)
        run_dir = os.path.join(args.test_dir, run_name, args.which_direction)
        test_dir_mid = os.path.join(run_dir, 'mid')
        test_dir_npy = os.path.join(run_dir, 'npy')

        kinds = ('origin', 'transfer', 'cycle')
        npy_dirs = {kind: os.path.join(test_dir_npy, kind) for kind in kinds}
        os.makedirs(test_dir_mid, exist_ok=True)
        os.makedirs(test_dir_npy, exist_ok=True)
        for npy_dir in npy_dirs.values():
            os.makedirs(npy_dir, exist_ok=True)

        for file_number, sample_file in enumerate(sample_files, start=1):
            print('Processing midi: ', sample_file)
            sample_npy = np.load(sample_file) * 1.

            # Add batch and channel axes: (1, shape[0], shape[1], 1).
            origin = sample_npy.reshape(1, sample_npy.shape[0], sample_npy.shape[1], 1)
            transfer = forward(origin, training=False)
            cycle = backward(transfer, training=False)

            outputs = {'origin': origin, 'transfer': transfer, 'cycle': cycle}
            for kind in kinds:
                save_midis(outputs[kind],
                           os.path.join(test_dir_mid, '{}_{}.mid'.format(file_number, kind)))
            for kind in kinds:
                np.save(os.path.join(npy_dirs[kind], '{}_{}.npy'.format(file_number, kind)),
                        outputs[kind])

    def inference(self, sample_npys, result_path, which_direction=None):
        """Transfer a list of piano-roll segments and write them as one MIDI file.

        Args:
            sample_npys: non-empty sequence of arrays (e.g. the list returned by
                ``process_midi``). Each is cast to float and reshaped to
                ``(1, shape[0], shape[1], 1)``, so any axes beyond the first two
                must have size 1.
            result_path: path of the MIDI file written via ``save_midis``.
            which_direction: 'AtoB' selects ``generator_A2B``; any other value
                selects ``generator_B2A``. When omitted, falls back to the
                notebook-global ``args.which_direction`` (the original behavior).

        Returns:
            numpy array of the transferred segments concatenated along axis 0.

        Raises:
            ValueError: if ``sample_npys`` is empty (for example, a MIDI file
                whose non-drum tracks produced no non-silent segments).
        """
        if len(sample_npys) == 0:
            raise ValueError('inference() received no segments to transfer; '
                             'the input MIDI may contain no non-drum notes.')
        if which_direction is None:
            # Backward compatibility: this method used to read the
            # notebook-level `args` object directly.
            which_direction = args.which_direction
        generator = self.generator_A2B if which_direction == 'AtoB' else self.generator_B2A

        # Checkpoint.restore() returns a status object that is truthy even when
        # there is nothing to restore, so check for a checkpoint explicitly.
        latest_checkpoint = self.checkpoint_manager.latest_checkpoint
        if latest_checkpoint:
            self.checkpoint.restore(latest_checkpoint)
            print(" [*] Load checkpoint succeeded!")
        else:
            print(" [!] Load checkpoint failed...")

        results = []
        for sample in sample_npys:
            sample_npy = sample * 1.  # e.g. bool -> float
            origin = sample_npy.reshape(1, sample_npy.shape[0], sample_npy.shape[1], 1)
            results.append(generator(origin, training=False))

        # A single concatenate avoids re-copying the growing result on every step.
        # It also returns an ndarray even when there is only one segment.
        result = np.concatenate(results, axis=0)
        save_midis(result, result_path)
        return result

    def test_famous(self, args):
        """Transfer the bundled YMCA example and save the result as MIDI and ``.npy``.

        Reads ``./datasets/famous_songs/P2C/merged_npy/YMCA.npy`` and writes
        ``YMCA.mid`` and ``YMCA.npy`` into ``./datasets/famous_songs/P2C/transfer``,
        which is created if missing.

        Args:
            args: options object; only ``which_direction`` is read ('AtoB' uses
                ``generator_A2B``, anything else uses ``generator_B2A``).
        """
        song = np.load('./datasets/famous_songs/P2C/merged_npy/YMCA.npy')

        # Checkpoint.restore() returns a status object that is truthy even when
        # there is nothing to restore, so check for a checkpoint explicitly.
        latest_checkpoint = self.checkpoint_manager.latest_checkpoint
        if latest_checkpoint:
            self.checkpoint.restore(latest_checkpoint)
            print(" [*] Load checkpoint succeeded!")
        else:
            print(" [!] Load checkpoint failed...")

        generator = self.generator_A2B if args.which_direction == 'AtoB' else self.generator_B2A
        transfer = generator(song, training=False)

        out_dir = './datasets/famous_songs/P2C/transfer'
        # np.save does not create parent directories.
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): meaning of the third save_midis argument (127) is defined
        # in tf2_utils; presumably a note velocity — confirm.
        save_midis(transfer, os.path.join(out_dir, 'YMCA.mid'), 127)
        np.save(os.path.join(out_dir, 'YMCA.npy'), transfer)

# NOTE(review): not read anywhere in this section; presumably consumed by
# get_bar_piano_roll (defined earlier) to decide how an incomplete final bar is
# handled — confirm before changing.
LAST_BAR_MODE = 'remove'

def process_midi(midi_path):
    """Convert a MIDI file into non-silent piano-roll segments for the model.

    All non-drum tracks are merged into one piano roll (drum tracks are
    ignored), split with ``get_bar_piano_roll``, and clipped to pitch indices
    24..107 (84 pitches). The bar axis is then zero-padded to a multiple of 4
    and reshaped to ``(-1, 64, 84, 1)``. That is presumably 4 bars of 16 steps
    at ``beat_resolution=4``; confirm against get_bar_piano_roll.

    Args:
        midi_path: path to the ``.mid`` file.

    Returns:
        list of boolean arrays of shape ``(64, 84, 1)`` (True = note on),
        keeping only segments that contain at least one note.

    Raises:
        ValueError: if the file has no non-drum tracks.
    """
    multitrack = Multitrack(beat_resolution=4)
    multitrack.parse_pretty_midi(pretty_midi.PrettyMIDI(midi_path))

    # Only melodic (non-drum) tracks are transferred.
    piano_track_ids = [i for i, track in enumerate(multitrack.tracks) if not track.is_drum]
    if not piano_track_ids:
        raise ValueError('No non-drum tracks found in {}'.format(midi_path))
    merged = multitrack[piano_track_ids].get_merged_pianoroll()

    # Keep pitch indices 24..107, the model's 84-pitch range.
    bars = get_bar_piano_roll(merged)[:, :, 24:108]
    remainder = bars.shape[0] % 4
    if remainder:
        # Pad with silent bars so the bar count divides evenly into segments.
        pad = np.zeros((4 - remainder,) + bars.shape[1:], dtype=bars.dtype)
        bars = np.concatenate((bars, pad), axis=0)
    segments = bars.reshape(-1, 64, 84, 1) > 0.0

    # Drop all-silent segments; they carry nothing to transfer.
    return [segment for segment in segments if segment.any()]

class Args:
    """Plain option bag standing in for a command-line argument namespace.

    Every option has a class-level default. Keyword arguments given to the
    constructor are stored on the instance and shadow those defaults; any
    keyword is accepted.
    """

    # Grouping below is for readability only (by apparent purpose, inferred
    # from the names); every value is a class-level default.

    # datasets / input geometry
    dataset_A_dir = 'CP_C'
    dataset_B_dir = 'CP_P'
    time_step = 64
    pitch_range = 84
    input_nc = 1
    output_nc = 1

    # training schedule / optimizer
    epoch = 10
    epoch_step = 10
    batch_size = 4
    lr = 0.0002
    beta1 = 0.5
    continue_train = False

    # model configuration
    ngf = 64
    ndf = 64
    model = 'full'
    type = 'classifier'
    which_direction = 'AtoB'
    phase = 'train'

    # loss / regularisation hyper-parameters
    L1_lambda = 10.0
    gamma = 1.0
    max_size = 50
    sigma_c = 0.01
    sigma_d = 0.01

    # output locations and frequencies
    save_freq = 1000
    print_freq = 100
    checkpoint_dir = './checkpoint'
    model_dir = ''
    sample_dir = './samples'
    test_dir = './test'
    log_dir = './log'

    def __init__(self, **kwargs):
        # Store overrides directly in the instance namespace.
        vars(self).update(kwargs)

In [None]:
#@title Upload MIDI file for inference
uploaded = files.upload()
file_names = list(uploaded.keys())
# Fail with an actionable message instead of an IndexError on file_names[0]
# when nothing was uploaded (e.g. the file dialog was cancelled).
if not file_names:
  raise RuntimeError('No MIDI file was uploaded; re-run this cell and choose a .mid file.')
if len(file_names) > 1:
  print('Several files uploaded; using the first one: {}'.format(file_names[0]))
filename = file_names[0]

In [None]:
#@title Perform Inference

direction = 'BtoA' #@param ["AtoB", "BtoA"]

# Checkpoint folder name under ./models for each model type (presumably the
# folders unpacked by the "Download pre-trained model" cell).
model_name = {
    'Classic <-> Pop': 'CP_C2CP_P_2020-08-31_base_0.1',
    'Jazz <-> Pop': 'JP_J2JP_P_2020-08-28_base_0.01'
    }[model_type]

output_path = "../result.mid"
# Keep `args` as a module-level name: CycleGAN.inference falls back to the
# global args.which_direction when no direction is passed.
args = Args(type='cyclegan',
            model='base',
            phase='test',
            checkpoint_dir='models',
            model_dir=model_name,
            which_direction=direction)
# Use the GPU when the runtime has one, otherwise run on the CPU.
device = '/device:GPU:0' if tf.config.list_physical_devices('GPU') else '/device:CPU:0'
with tf.device(device):
  model = CycleGAN(args)
  model.inference(process_midi(filename), output_path)
print("Inference finished")

In [None]:
#@title Download Result
# Send the MIDI written by the inference cell (output_path) to the browser.
files.download(output_path)