# MuseGAN Training

## imports

In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import types

from models.MuseGAN import MuseGAN
from utils.loaders import load_music

from music21 import midi
from music21 import note, stream, duration

In [2]:
# run params
SECTION = 'compose'
RUN_ID = '0017'
DATA_NAME = 'chorales'
FILENAME = 'Jsb16thSeparated.npz'

# Everything produced by this run is stored under run/<SECTION>/<RUN_ID>_<DATA_NAME>.
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# os.makedirs(..., exist_ok=True) creates the missing 'run/<SECTION>' parents
# and fills in any sub-folder that is absent from an already existing run
# folder, so this cell is safe to re-run after a kernel restart.
for subfolder in ('viz', 'images', 'weights', 'samples'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build' saves a freshly built model to RUN_FOLDER; any other value
# (e.g. 'load') makes the architecture cell reload weights from RUN_FOLDER.
mode = 'build'

## data

In [3]:
# Batch size used both for the model and for training.
BATCH_SIZE = 64

# Sample dimensions; the critic input below is (n_bars, n_steps_per_bar, n_pitches, n_tracks).
n_bars, n_steps_per_bar = 2, 16
n_pitches, n_tracks = 84, 4

# load_music returns three objects; only data_binary is fed to the model in this notebook.
(data_binary, data_ints, raw_data) = load_music(
    DATA_NAME, FILENAME, n_bars, n_steps_per_bar
)

# Drop any length-1 axes from the loaded array.
data_binary = np.squeeze(data_binary)

## architecture

In [4]:
# Build the model; input_dim is the shape of one sample (leading batch axis dropped).
gan = MuseGAN(
    input_dim=data_binary.shape[1:],
    critic_learning_rate=0.001,
    generator_learning_rate=0.001,
    optimiser='adam',
    grad_weight=10,
    z_dim=32,
    batch_size=BATCH_SIZE,
    n_tracks=n_tracks,
    n_bars=n_bars,
    n_steps_per_bar=n_steps_per_bar,
    n_pitches=n_pitches,
)

# Anything other than 'build' restores previously stored weights from the run
# folder; 'build' saves the freshly constructed model there instead.
if mode != 'build':
    gan.load_weights(RUN_FOLDER)
else:
    gan.save(RUN_FOLDER)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [5]:
# Print the layer-by-layer summary (output shapes and parameter counts) of gan.chords_tempNetwork.
gan.chords_tempNetwork.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
temporal_input (InputLayer)  [(None, 32)]              0         
_________________________________________________________________
reshape (Reshape)            (None, 1, 1, 32)          0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 2, 1, 1024)        66560     
_________________________________________________________________
batch_normalization (BatchNo (None, 2, 1, 1024)        4096      
_________________________________________________________________
activation (Activation)      (None, 2, 1, 1024)        0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 2, 1, 32)          32800     
_________________________________________________________________
batch_normalization_1 (Batch (None, 2, 1, 32)          128 

In [6]:
# Print the layer summary of the first entry of gan.barGen.
# NOTE(review): presumably there is one bar generator per track — confirm in MuseGAN.
gan.barGen[0].summary()

Model: "model_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bar_generator_input (InputLa [(None, 128)]             0         
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              132096    
_________________________________________________________________
batch_normalization_10 (Batc (None, 1024)              4096      
_________________________________________________________________
activation_10 (Activation)   (None, 1024)              0         
_________________________________________________________________
reshape_10 (Reshape)         (None, 2, 1, 512)         0         
_________________________________________________________________
conv2d_transpose_10 (Conv2DT (None, 4, 1, 512)         524800    
_________________________________________________________________
batch_normalization_11 (Batc (None, 4, 1, 512)         2048

In [7]:
# Print the layer summary of the full generator model, including layer connectivity.
gan.generator.summary()

Model: "model_10"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
melody_input (InputLayer)       [(None, 4, 32)]      0                                            
__________________________________________________________________________________________________
chords_input (InputLayer)       [(None, 32)]         0                                            
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 32)           0           melody_input[0][0]               
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 32)           0           melody_input[0][0]               
___________________________________________________________________________________________

In [8]:
# Print the layer summary of the critic; its input shape should equal data_binary.shape[1:].
gan.critic.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
critic_input (InputLayer)    [(None, 2, 16, 84, 4)]    0         
_________________________________________________________________
conv3d (Conv3D)              multiple                  1152      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      multiple                  0         
_________________________________________________________________
conv3d_1 (Conv3D)            multiple                  16512     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    multiple                  0         
_________________________________________________________________
conv3d_2 (Conv3D)            multiple                  196736    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    multiple                  0     

## training

In [9]:

# Number of epochs handed to gan.train.
EPOCHS = 6000
# Handed to gan.train as print_every_n_batches.
PRINT_EVERY_N_BATCHES = 10

# Reset the model's epoch counter before training starts.
# NOTE(review): this also runs when mode != 'build' (weights reloaded) — confirm
# that restarting the counter at 0 is intended when resuming a run.
gan.epoch = 0

In [None]:
# Train on the binary data; run artefacts are written relative to RUN_FOLDER.
gan.train(
    data_binary,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
)

0 (5, 1) [D loss: (9.2)(R -0.7, F -0.0, G 1.0)] [G loss: -0.1]
1 (5, 1) [D loss: (-257.6)(R -301.6, F -5.2, G 4.9)] [G loss: -22.4]
2 (5, 1) [D loss: (-324.9)(R -612.0, F -23.7, G 31.1)] [G loss: -150.4]
3 (5, 1) [D loss: (-303.9)(R -639.6, F -25.7, G 36.1)] [G loss: -164.0]
4 (5, 1) [D loss: (-21.5)(R 102.7, F -126.9, G 0.3)] [G loss: -431.1]
5 (5, 1) [D loss: (-301.7)(R -843.6, F 279.6, G 26.2)] [G loss: -554.4]
6 (5, 1) [D loss: (-253.9)(R -670.3, F 271.6, G 14.5)] [G loss: -598.1]
7 (5, 1) [D loss: (-188.4)(R -583.1, F 295.9, G 9.9)] [G loss: -555.3]
8 (5, 1) [D loss: (-125.5)(R -526.0, F 324.9, G 7.6)] [G loss: -465.0]
9 (5, 1) [D loss: (-79.1)(R -432.0, F 302.6, G 5.0)] [G loss: -344.0]
10 (5, 1) [D loss: (-71.7)(R -78.7, F -42.1, G 4.9)] [G loss: -153.3]
11 (5, 1) [D loss: (-29.6)(R 241.1, F -271.6, G 0.1)] [G loss: 437.6]
12 (5, 1) [D loss: (-30.7)(R 441.6, F -480.0, G 0.8)] [G loss: 586.8]
13 (5, 1) [D loss: (-46.3)(R 572.4, F -633.4, G 1.5)] [G loss: 706.0]
14 (5, 1) [D loss:

## training losses

In [None]:
# Plot the loss history recorded on the model during training.
fig, ax = plt.subplots()

# The first three components of every gan.d_losses entry are drawn.
# NOTE(review): the labels assume the component order shown in the training
# log ("D loss: (total)(R ..., F ..., G ...)") — confirm against MuseGAN.train.
ax.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25, label='critic (total)')
ax.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25, label='critic (real)')
ax.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25, label='critic (fake)')
ax.plot(gan.g_losses, color='orange', linewidth=0.25, label='generator')

ax.set_xlabel('batch', fontsize=18)
ax.set_ylabel('loss', fontsize=16)
ax.set_title('MuseGAN training losses')

# Guard against an empty history: identical x-limits make matplotlib warn.
ax.set_xlim(0, max(len(gan.d_losses), 1))
# ax.set_ylim(0, 2)  # uncomment to zoom in on small loss values

# The curves are only 0.25pt wide, so thicken the legend handles to keep the
# colours distinguishable.
legend = ax.legend(loc='best')
for handle in legend.get_lines():
    handle.set_linewidth(1.5)

plt.show()
