## MuseGAN 훈련

### 라이브러리 임포트

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import types
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms

from MuseGAN import Generator, Critic
from utils import load_music

from music21 import midi
from music21 import note, stream, duration

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using {device} device")

In [None]:
# Input data location.
data_folder = '../data/chorales'
file_name = 'Jsb16thSeparated.npz'

# Output destinations (generated samples, figures, model checkpoint).
output_folder = './output/muse_gan'
image_save_folder = './images/muse_gan'
model_save_path = './muse_gan.pth'

# Create the output directories up front; existing directories are fine.
for _folder in (output_folder, image_save_folder):
    os.makedirs(_folder, exist_ok=True)

# 'build' starts from freshly constructed networks; switch to 'load' to
# restore the weights stored at model_save_path.
mode = 'build'
# mode = 'load'

### 데이터 적재

In [None]:
# Data layout and latent-space hyperparameters.
z_dim = 32              # length of each latent noise vector fed to the generator
n_tracks = 4            # number of voices (tracks)
n_pitches = 84          # size of the pitch range
n_steps_per_bar = 16    # sixteenth-note steps per bar
n_bars = 2              # number of bars
batch_size = 64

# load_music returns three arrays; only the binary piano-roll is used for training.
data_binary, data_ints, raw_data = load_music(
    data_folder, file_name, n_bars, n_steps_per_bar
)
data_binary = np.squeeze(data_binary)  # drop every length-1 axis
print(data_binary.shape)

In [None]:
class MyDataset(Dataset):
    """Map-style dataset over an in-memory array of training samples.

    Items are indexed along the first axis of ``data_binary`` and returned as
    ``torch.float32`` tensors.
    """

    def __init__(self, data_binary):
        # torch.tensor always copies and converts to the requested dtype
        # (so bool/int arrays work too); the legacy torch.FloatTensor(...)
        # constructor is discouraged by PyTorch.
        self.data_binary = torch.tensor(data_binary, dtype=torch.float32)
        self.data_len = len(self.data_binary)

    def __getitem__(self, idx):
        return self.data_binary[idx]

    def __len__(self):
        return self.data_len

In [None]:
# Wrap the array in a Dataset and batch it; drop_last=True guarantees every
# batch holds exactly batch_size samples.
dataset = MyDataset(data_binary)
dataset_size = len(dataset)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

### 모델 만들기

In [None]:
# Instantiate both networks on the selected device.
generator = Generator(z_dim, n_tracks, n_bars, n_steps_per_bar, n_pitches).to(device)
critic = Critic(in_channels=n_tracks, n_bars=n_bars).to(device)

if mode == 'load':
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU can be restored on a CPU-only machine.
    loaded_model = torch.load(model_save_path, map_location=device)
    generator.load_state_dict(loaded_model['Generator'])
    critic.load_state_dict(loaded_model['Critic'])

### 모델 훈련

In [None]:
# Training schedule.
num_epochs = 500
n_critic = 5              # critic updates per generator update
gradient_weight = 10      # multiplier applied to the gradient-penalty term

# Learning rates and one Adam optimizer per network.
c_learning_rate = 1e-3
g_learning_rate = 1e-3
c_optimizer = optim.Adam(critic.parameters(), lr=c_learning_rate)
g_optimizer = optim.Adam(generator.parameters(), lr=g_learning_rate)

In [None]:
def gradient_penalty_loss(discriminator, real_images, fake_images):
    """Return the WGAN-GP gradient penalty for ``discriminator``.

    A random point is sampled on the straight line between each real/fake
    pair. The discriminator is evaluated there, and the squared deviation of
    the per-sample L2 gradient norm from 1 is averaged over the batch.

    Args:
        discriminator: callable mapping a batch tensor to per-sample scores.
        real_images: batch of real samples, shape ``(batch, ...)`` (any rank).
        fake_images: batch of generated samples, same shape as ``real_images``.

    Returns:
        A scalar tensor. It is built with ``create_graph=True`` so it can be
        back-propagated into the discriminator's parameters.
    """
    batch = real_images.size(0)
    # One mixing coefficient per sample, broadcast over all remaining axes
    # whatever the input rank. No gradient is needed with respect to alpha.
    alpha = torch.rand(
        (batch,) + (1,) * (real_images.dim() - 1),
        device=real_images.device,
        dtype=real_images.dtype,
    )
    # Detach both endpoints so the interpolated batch is the only graph leaf.
    interpolates = (
        alpha * real_images.detach() + (1 - alpha) * fake_images.detach()
    ).requires_grad_(True)

    model_interpolates = discriminator(interpolates)
    grad_outputs = torch.ones_like(model_interpolates)

    gradients = torch.autograd.grad(
        outputs=model_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (unlike view) also works if the gradient is non-contiguous.
    gradients = gradients.reshape(batch, -1)
    return torch.mean((gradients.norm(2, dim=1) - 1) ** 2)

In [None]:
def wasserstein_loss(y_pred, y_target):
    """Return the Wasserstein loss ``-mean(y_pred * y_target)``.

    A target of +1 gives ``-mean(y_pred)``; a target of -1 gives
    ``mean(y_pred)``.
    """
    weighted = y_pred * y_target
    return -weighted.mean()

In [None]:
# Per-epoch loss history (sample-weighted means), used for plotting later.
g_losses = []
c_losses = []
c_losses_real = []
c_losses_fake = []
grad_penalty_losses = []


def _sample_noise(n):
    """Draw the four generator noise inputs for ``n`` samples.

    Returns (chords, style, melody, groove). The first two have shape
    ``(n, z_dim)`` and the last two have shape ``(n, n_tracks, z_dim)``.
    """
    chords = torch.randn(n, z_dim, device=device)
    style = torch.randn(n, z_dim, device=device)
    melody = torch.randn(n, n_tracks, z_dim, device=device)
    groove = torch.randn(n, n_tracks, z_dim, device=device)
    return chords, style, melody, groove


def _set_requires_grad(module, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``module``."""
    for p in module.parameters():
        p.requires_grad = flag


if len(dataloader) == 0:
    # With drop_last=True, a dataset smaller than batch_size yields no batches;
    # the per-epoch averages below would then divide by zero.
    raise ValueError('dataloader yields no batches; reduce batch_size or add data')

for epoch in range(num_epochs):
    epoch_c_loss = 0.0
    epoch_c_loss_real = 0.0
    epoch_c_loss_fake = 0.0
    epoch_grad_penalty_loss = 0.0
    epoch_g_loss = 0.0

    num_inputs = 0
    num_G_inputs = 0

    for inputs in dataloader:
        inputs = inputs.to(device)
        n = inputs.size(0)

        # ---- Critic: n_critic updates on the same real batch ----
        for _ in range(n_critic):
            critic.zero_grad()

            # Real samples are scored against a target of +1.
            output = critic(inputs)
            c_loss_real = wasserstein_loss(output, torch.ones_like(output))

            # Fake samples are scored against -1; no_grad keeps the generator
            # out of the critic's graph.
            with torch.no_grad():
                fake = generator(*_sample_noise(n))
            output = critic(fake)
            c_loss_fake = wasserstein_loss(output, -torch.ones_like(output))

            gradient_penalty = gradient_penalty_loss(critic, inputs, fake) * gradient_weight

            c_loss = c_loss_real + c_loss_fake + gradient_penalty
            # The graph is rebuilt on every iteration, so there is nothing to
            # retain; retain_graph=True only held its buffers in memory longer.
            c_loss.backward()
            c_optimizer.step()

            epoch_c_loss += c_loss.item() * n
            epoch_c_loss_real += c_loss_real.item() * n
            epoch_c_loss_fake += c_loss_fake.item() * n
            epoch_grad_penalty_loss += gradient_penalty.item() * n
            num_inputs += n

        # ---- Generator: one update against a frozen critic ----
        # try/finally guarantees the critic is unfrozen even if this step is
        # interrupted; otherwise a notebook re-run would silently stop
        # updating the critic.
        _set_requires_grad(critic, False)
        try:
            generator.zero_grad()
            g_output = generator(*_sample_noise(batch_size))
            output = critic(g_output)
            g_loss = wasserstein_loss(output, torch.ones_like(output))
            g_loss.backward()
            g_optimizer.step()
        finally:
            _set_requires_grad(critic, True)

        epoch_g_loss += g_loss.item() * batch_size
        num_G_inputs += batch_size

    epoch_c_loss /= num_inputs
    epoch_c_loss_real /= num_inputs
    epoch_c_loss_fake /= num_inputs
    epoch_grad_penalty_loss /= num_inputs
    epoch_g_loss /= num_G_inputs

    c_losses.append(epoch_c_loss)
    c_losses_real.append(epoch_c_loss_real)
    c_losses_fake.append(epoch_c_loss_fake)
    grad_penalty_losses.append(epoch_grad_penalty_loss)
    g_losses.append(epoch_g_loss)

    print('%d [C loss: (%.4f)(R %.4f, F %.4f, G %.4f)] [G loss: %.4f]' %
            (epoch + 1, epoch_c_loss, epoch_c_loss_real, epoch_c_loss_fake, epoch_grad_penalty_loss, epoch_g_loss))

In [None]:
# Save both networks' weights into a single checkpoint file.
models = dict(
    Generator=generator.state_dict(),
    Critic=critic.state_dict(),
)
torch.save(models, model_save_path)

In [None]:
# Plot the per-epoch loss curves with a legend and save the figure.
fig, ax = plt.subplots(figsize=(20, 10))

# (values, colour, legend label); the critic total is real + fake + penalty.
curves = [
    (c_losses, 'black', 'critic (total)'),
    (c_losses_real, 'green', 'critic (real)'),
    (c_losses_fake, 'red', 'critic (fake)'),
    (grad_penalty_losses, 'blue', 'gradient penalty'),
    (g_losses, 'orange', 'generator'),
]
for values, color, label in curves:
    ax.plot(values, color=color, linewidth=1, label=label)

ax.set_xlabel('epoch', fontsize=18)
ax.set_ylabel('loss', fontsize=18)
ax.legend(fontsize=14)

fig.savefig(os.path.join(image_save_folder, 'train_loss_graph.png'))