# MuseGAN Training

## imports

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import types

from MuseGAN_original import MuseGAN
from util_music import loaders

from music21 import midi
from music21 import note, stream, duration

In [None]:
# run params
SECTION = 'compose'
RUN_ID = '1000'
DATA_NAME = 'lpd_17'
#FILENAME = 'Jsb16thSeparated.npz' #'lpd_17_cleansed.npz'

# Output directory for this run, e.g. 'run/compose/1000_lpd_17'.
RUN_FOLDER = 'run/{}/{}_{}'.format(SECTION, RUN_ID, DATA_NAME)

# Create the run folder and every output sub-folder independently.
# exist_ok=True means an existing but incomplete run folder (for example one
# missing 'samples') is filled in instead of being skipped entirely.
for sub_folder in ('viz', 'images', 'weights', 'samples'):
    os.makedirs(os.path.join(RUN_FOLDER, sub_folder), exist_ok=True)

mode = 'build'  # 'build' or 'load'

## data

In [None]:
# Batch size and the piano-roll geometry used to shape the data and the model.
BATCH_SIZE = 64
n_bars, n_steps_per_bar = 2, 16    # bars per sample, time steps per bar
n_pitches, n_tracks = 84, 8        # pitch rows, tracks fed to the model

# Alternative loading paths (disabled):
# data_binary_2, data_ints, raw_data = loaders.load_music(DATA_NAME, FILENAME, n_bars, n_steps_per_bar)
# data_binary_2 = np.squeeze(data_binary)
# data_binary = np.load('./run/dataset3.npy')

In [None]:
# Load the pre-built dataset array from disk and report its raw shape.
dataset_path = './run/dataset3.npy'
data_binary = np.load(dataset_path)
print(data_binary.shape)

In [None]:
# Reshape to (samples, n_bars, n_steps_per_bar, n_pitches, 17); numpy infers
# the sample count from -1. The 17-wide last axis presumably holds the lpd_17
# tracks — TODO confirm against the dataset builder.
data_binary = np.reshape(data_binary, (-1, n_bars, n_steps_per_bar, n_pitches, 17))
print(data_binary.shape)

In [None]:
# Optional (disabled): map False -> -1 and every other value -> 1.
# data_binary = np.where(data_binary==False, -1, 1)

## architecture

In [None]:
# Construct the model. input_dim is the data's (n_bars, n_steps_per_bar,
# n_pitches) axes followed by n_tracks, the number of channels that are
# actually trained on. Using n_tracks rather than a literal keeps it in step
# with the n_tracks argument below.
# NOTE(review): grad_weight=0.0 — if this is the gradient-penalty weight the
# penalty is disabled; confirm that is intended.
gan = MuseGAN(input_dim = data_binary.shape[1:-1] + (n_tracks,)
        , critic_learning_rate = 0.01
        , generator_learning_rate = 0.01
        , optimiser = 'adam'
        , grad_weight = 0.0
        , z_dim = 32
        , batch_size = BATCH_SIZE
        , n_tracks = n_tracks
        , n_bars = n_bars
        , n_steps_per_bar = n_steps_per_bar
        , n_pitches = n_pitches
        )

# 'build' writes the freshly constructed model via gan.save; 'load' restores
# previously saved weights. Any other value is a configuration error rather
# than an implicit load.
if mode == 'build':
    gan.save(RUN_FOLDER)
elif mode == 'load':
    gan.load_weights(RUN_FOLDER)
else:
    raise ValueError("mode must be 'build' or 'load', got {!r}".format(mode))

In [None]:
# Print the architecture summary of the chords/temporal sub-network.
chords_temp_net = gan.chords_tempNetwork
chords_temp_net.summary()

In [None]:
# Print the architecture summary of the first entry in gan.barGen.
bar_gen_0 = gan.barGen[0]
bar_gen_0.summary()

In [None]:
# Print the architecture summary of the full generator.
generator_model = gan.generator
generator_model.summary()

In [None]:
# Print the architecture summary of the critic.
critic_model = gan.critic
critic_model.summary()

## training

In [None]:
# Reset the model's epoch counter, then set the training schedule.
gan.epoch = 0
EPOCHS, PRINT_EVERY_N_BATCHES = 500, 10

In [None]:
# Sanity checks on the channels passed to gan.train (the first n_tracks).
track_data = data_binary[..., :n_tracks]
nan_mask = np.isnan(track_data)
print('all values NaN:', bool(np.all(nan_mask)))
print('NaN count:', int(np.count_nonzero(nan_mask)))
# Report how many cells are active instead of printing every index from
# np.where, which can flood the notebook for a full dataset.
print('non-zero cells: {} / {}'.format(np.count_nonzero(track_data), track_data.size))

In [None]:
# True when no cell is active in the training channels. `not np.any(...)`
# works for bool, int and float arrays alike, whereas `~` is bitwise NOT on
# ints (always truthy for 0/1 data) and raises TypeError on floats.
not np.any(data_binary[..., :n_tracks])

In [None]:
# Train on the first n_tracks channels only, so the channel count matches the
# n_tracks value the model was constructed with (instead of a literal 8).
gan.train(
    data_binary[..., :n_tracks]
    , batch_size = BATCH_SIZE
    , epochs = EPOCHS
    , run_folder = RUN_FOLDER
    , print_every_n_batches = PRINT_EVERY_N_BATCHES
)

In [None]:
fig = plt.figure()

# Columns 0-2 of each gan.d_losses entry (their meaning is defined by
# MuseGAN.train, not visible here), then the generator loss; all thin lines.
for column, colour in ((0, 'black'), (1, 'green'), (2, 'red')):
    plt.plot([entry[column] for entry in gan.d_losses], color=colour, linewidth=0.25)
plt.plot(gan.g_losses, color='orange', linewidth=0.25)

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)
plt.xlim(0, len(gan.d_losses))
# plt.ylim(0, 2)

plt.show()

In [None]:
# Echo the run directory as this cell's displayed output.
RUN_FOLDER