# MuseGAN 音乐生成 - 模型训练
## 导入依赖

In [None]:
import os

# Reduce TensorFlow's native (C++) log output. TensorFlow reads this variable
# when its runtime is first loaded, so it must be set BEFORE importing any
# module that pulls TensorFlow in; setting it afterwards has no effect.
# (assumes `models` imports TensorFlow/Keras — TODO confirm)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import matplotlib.pyplot as plt
import numpy as np

from models import MuseGAN
from utils import load_music

## 运行参数

In [None]:
# run params: identify this experiment and where its outputs are written.
SECTION = 'compose'
RUN_ID = '0002'
DATA_NAME = 'chorales'
FILENAME = 'Jsb16thSeparated.npz'
RUN_FOLDER = os.path.join('run', SECTION, '_'.join([RUN_ID, DATA_NAME]))

# Create the run folder and every output sub-folder. exist_ok=True makes this
# idempotent and also fills in any sub-folder missing from an existing run
# folder, so later saves into these directories do not fail.
for sub_dir in ('viz', 'images', 'weights', 'samples'):
    os.makedirs(os.path.join(RUN_FOLDER, sub_dir), exist_ok=True)

# 'build': construct a fresh model and call gan.save(RUN_FOLDER).
# 'load':  construct the model and call gan.load_weights(RUN_FOLDER).
mode = 'build'

## 加载数据

In [None]:
# Batch size used for training.
BATCH_SIZE = 64

# Sample layout, shared by load_music and the MuseGAN constructor.
n_bars, n_steps_per_bar = 2, 16
n_pitches, n_tracks = 84, 4

# load_music returns three arrays; only data_binary is post-processed here.
data_binary, data_ints, raw_data = load_music(DATA_NAME, FILENAME, n_bars, n_steps_per_bar)

# Drop every length-1 axis.
# NOTE(review): np.squeeze removes ALL singleton axes, so with n_bars == 1 it
# would also remove the bar axis — confirm that is acceptable.
data_binary = np.squeeze(data_binary)

## 构建 / 加载模型

In [None]:
# Construct the MuseGAN. input_dim is the per-sample shape of the training
# tensor, i.e. every axis of data_binary after the leading sample axis.
gan = MuseGAN(
    input_dim=data_binary.shape[1:],
    critic_learning_rate=0.001,
    generator_learning_rate=0.001,
    optimizer='adam',
    grad_weight=10,  # presumably the gradient-penalty weight — confirm in models.MuseGAN
    z_dim=32,
    batch_size=BATCH_SIZE,
    n_tracks=n_tracks,
    n_bars=n_bars,
    n_steps_per_bar=n_steps_per_bar,
    n_pitches=n_pitches,
)

# Persist the freshly built model, or restore previously saved weights.
# Any other value is rejected so that a typo in `mode` fails loudly instead
# of silently falling through to loading stale weights.
if mode == 'build':
    gan.save(RUN_FOLDER)
elif mode == 'load':
    gan.load_weights(RUN_FOLDER)
else:
    raise ValueError("mode must be 'build' or 'load', got {!r}".format(mode))


In [None]:
# Print the architecture summary of gan.chords_temp_network.
gan.chords_temp_network.summary()

In [None]:
# Print the architecture summary of the first entry of gan.bar_gen
# (presumably the bar generator for track 0 — confirm in models.MuseGAN).
gan.bar_gen[0].summary()

In [None]:
# Print the architecture summary of the full generator model.
gan.generator.summary()

In [None]:
# Print the architecture summary of the critic model.
gan.critic.summary()

## 训练

In [None]:
# Training schedule.
EPOCHS, PRINT_EVERY_N_BATCHES = 6000, 100

# Reset gan.epoch to 0 before training (presumably the starting epoch used by
# gan.train). NOTE(review): this also resets it after mode == 'load' — confirm
# that restarting the count on resumed runs is intended.
gan.epoch = 0

gan.train(
    data_binary,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    run_folder=RUN_FOLDER,
)


In [None]:
# Plot the loss histories stored on `gan` by training and save the figure to
# <RUN_FOLDER>/images/loss.png.
fig = plt.figure()

# Components 0-2 of each gan.d_losses entry are drawn as separate curves.
# NOTE(review): what each component means is defined in models.MuseGAN
# (presumably total / real / fake critic loss) — confirm before relabelling.
for component, colour in ((0, 'black'), (1, 'green'), (2, 'red')):
    plt.plot(
        [x[component] for x in gan.d_losses],
        color=colour,
        linewidth=0.25,
        label='d_loss[{}]'.format(component),
    )
plt.plot(gan.g_losses, color='orange', linewidth=0.25, label='g_loss')

plt.xlabel('batch', fontsize=16)
plt.ylabel('loss', fontsize=16)
plt.xlim(0, len(gan.d_losses))
# Uncomment to clip the y-axis when early spikes hide later detail:
# plt.ylim(0, 2)

# Without a legend the four colours cannot be told apart.
plt.legend()

# Make sure the output folder exists even if this RUN_FOLDER lacks it, and
# build the path with os.path.join instead of a hard-coded '/'.
images_dir = os.path.join(RUN_FOLDER, 'images')
os.makedirs(images_dir, exist_ok=True)
plt.savefig(os.path.join(images_dir, 'loss.png'))

plt.show()
