# MidiNetモデルの学習

- DataLoaderの作成
- Modelの作成
- 学習コードの作成
- 学習経過の可視化

In [1]:
import os, ipdb, pickle
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision.utils as vutils
from pypianoroll import Multitrack, Track
from utils import grid_plot, Timer

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Dataset locations, relative to the notebook's working directory.
base_dir = "../datasets/theorytab/midinet"
# Pickle file that is later handed to get_dataloader.
input_file_path = base_dir + "/midinet_natural.pkl"
# Presumably the directory for training outputs (inferred from the name).
output_dir = base_dir + "/learning"

## DataLoaderの作成

In [3]:
batch_size = 72  # mini-batch size passed to get_dataloader when the loader is built

In [4]:
class MidinetDataloader(torch.utils.data.Dataset):
    """Map-style dataset over pickled ``(melody, prev, chord)`` triples.

    Despite the name this is a dataset, not a loader; wrap it in
    ``torch.utils.data.DataLoader`` to obtain batches.

    Args:
        data_path: Path to a pickle file containing an iterable of
            ``(melody, prev, chord)`` triples of array-likes.

    Attributes:
        x: All melodies stacked along a new first axis, as float32.
        prev_x: All previous bars stacked the same way, as float32.
        y: All chords stacked the same way, as float32.
        size: Number of samples (length of the first axis of ``x``).
    """

    def __init__(self, data_path):
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # from a trusted source.
        # The context manager guarantees the file handle is closed.
        with open(data_path, 'rb') as f:
            data = pickle.load(f)

        melody, prev, chord = [], [], []
        for m, p, c in data:
            melody.append(m)
            prev.append(p)
            chord.append(c)

        self.x = torch.from_numpy(np.array(melody)).float()
        self.prev_x = torch.from_numpy(np.array(prev)).float()
        self.y = torch.from_numpy(np.array(chord)).float()
        self.size = self.x.shape[0]

    def __getitem__(self, index):
        """Return the ``(melody, prev, chord)`` tensors stored at ``index``."""
        return self.x[index], self.prev_x[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return self.size

In [5]:
def get_dataloader(data_path, batch_size=72, shuffle=True):
    """Build a ``DataLoader`` over the pickled data at ``data_path``.

    Args:
        data_path: Path forwarded to ``MidinetDataloader``.
        batch_size: Number of samples per mini-batch.
        shuffle: Forwarded to ``DataLoader``'s ``shuffle`` argument.

    Returns:
        A ``DataLoader`` configured with 4 worker processes and pinned memory.
    """
    dataset = MidinetDataloader(data_path)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
        pin_memory=True,
    )
    # Report how much data was loaded.
    print('Data loading is completed.')
    print(f'{len(loader)} batches from {len(dataset)} bars are obtained.')
    return loader

In [6]:
# Build the loader over the pickle file configured in input_file_path.
data_loader = get_dataloader(input_file_path, batch_size=batch_size)

Data loading is completed.
663 batches from 47723 bars are obtained.


## Modelの作成

#### model用共通関数の作成

In [8]:
def concat_vector(x, y):
    """Concatenate ``y`` onto ``x`` along the channel axis (dim 1).

    ``y`` is first broadcast with ``expand`` to ``x``'s batch and spatial
    sizes, so each of its dims 0, 2 and 3 must be 1 or already equal to the
    corresponding size of ``x``.

    Args:
        x: 4-D tensor ``(batch, cx, h, w)``.
        y: 4-D tensor whose channel count ``cy`` is kept as is.

    Returns:
        Tensor of shape ``(batch, cx + cy, h, w)``.
    """
    # The original unpacked into `x0` but then read `x_0`, raising NameError.
    batch, _, height, width = x.shape
    y_expanded = y.expand(batch, y.shape[1], height, width)
    return torch.cat((x, y_expanded), 1)
    

def batch_norm(x, eps=1e-05, momentum=0.9, affine=True):
    """Normalise ``x`` with a freshly constructed batch-norm layer.

    Dispatch by rank:
      * 2-D ``(N, C)`` or 3-D ``(N, C, L)`` -> ``nn.BatchNorm1d``
      * 4-D ``(N, C, H, W)``               -> ``nn.BatchNorm2d``
      * anything else                      -> returned unchanged

    The layer is created on the same device as ``x``, so the function works on
    CPU as well as GPU.

    NOTE: a new layer is built on every call, so it is always in training mode
    with default affine parameters. The output is therefore plain
    normalisation by the current batch statistics; nothing is learned or
    remembered between calls.

    Args:
        x: Input tensor with channels on dim 1.
        eps: Value added to the variance for numerical stability.
        momentum: Running-statistics momentum (has no lasting effect here,
            because the layer is discarded after the call).
        affine: Whether the temporary layer has affine parameters.

    Returns:
        Tensor of the same shape as ``x``.
    """
    if x.ndim in (2, 3):
        layer = nn.BatchNorm1d(x.shape[1], eps=eps, momentum=momentum, affine=affine)
    elif x.ndim == 4:
        # BatchNorm2d requires 4-D input. The original guarded this branch
        # with ndim == 3, so conv feature maps were never normalised.
        layer = nn.BatchNorm2d(x.shape[1], eps=eps, momentum=momentum, affine=affine)
    else:
        return x
    return layer.to(x.device)(x)


def lrelu(x, leak=0.2):
    """Leaky ReLU: element-wise maximum of ``x`` and ``leak * x``."""
    return torch.max(x, torch.mul(x, leak))

#### generator
forwardの入力
- z (batch, noise_size) = (72, 100): ランダムノイズ．forward内でコードy (13次元) と結合され，(72, 113) になる
- prev_x (batch, ch, steps, pitch) = (72, 1, 16, 128): 前の小節
- y (batch, 13): コード，0~11次元はコードの主音，12次元目はmajorかminorかを区別する

forwardの出力
- g_x (batch, ch, steps, pitch)= (72, 1, 16, 128): 生成された今の小節
    
オリジナルのmidinetと同じ．詳しい説明はmidinet_understandingを参照

In [9]:
class generator(nn.Module):
    """MidiNet generator: noise, chord and previous bar -> one generated bar.

    The previous bar is encoded by four conv layers (``h*_prev``). Each
    encoder stage is concatenated, together with the chord vector, onto the
    matching stage of a transposed-conv decoder that starts from the noise.

    Args:
        pitch_range: Number of pitches per time step (width of a bar).
    """

    def __init__(self,pitch_range=128):
        super(generator, self).__init__()
        self.gf_dim  = 64
        self.y_dim   = 13

        prev_dim = 16  # channels produced by every previous-bar encoder stage
        # Decoder inputs are [features | chord | prev-bar features] on dim 1.
        # Deriving the sizes keeps the layers consistent when pitch_range is
        # not 128; for the default both values equal the original 157.
        first_in = self.gf_dim * 2 + self.y_dim + prev_dim
        later_in = pitch_range + self.y_dim + prev_dim

        self.h1      = nn.ConvTranspose2d(in_channels=first_in, out_channels=pitch_range, kernel_size=(2,1), stride=(2,2))
        self.h2      = nn.ConvTranspose2d(in_channels=later_in, out_channels=pitch_range, kernel_size=(2,1), stride=(2,2))
        self.h3      = nn.ConvTranspose2d(in_channels=later_in, out_channels=pitch_range, kernel_size=(2,1), stride=(2,2))
        self.h4      = nn.ConvTranspose2d(in_channels=later_in, out_channels=1, kernel_size=(1,pitch_range), stride=(1,2))

        self.h0_prev = nn.Conv2d(in_channels=1, out_channels=prev_dim, kernel_size=(1,pitch_range), stride=(1,2))
        self.h1_prev = nn.Conv2d(in_channels=prev_dim, out_channels=prev_dim, kernel_size=(2,1), stride=(2,2))
        self.h2_prev = nn.Conv2d(in_channels=prev_dim, out_channels=prev_dim, kernel_size=(2,1), stride=(2,2))
        self.h3_prev = nn.Conv2d(in_channels=prev_dim, out_channels=prev_dim, kernel_size=(2,1), stride=(2,2))

        # 100-d noise concatenated with the chord vector -> 113 inputs.
        self.linear1 = nn.Linear(100 + self.y_dim, 1024)
        self.linear2 = nn.Linear(1024 + self.y_dim, self.gf_dim*2*2*1)

    def forward(self, z, prev_x, y ,batch_size, pitch_range):
        """Generate one bar.

        Shapes in the comments below use B = batch_size and the default
        16-step, 128-pitch bar.

        Args:
            z: Noise, ``(B, 100)``.
            prev_x: Previous bar, ``(B, 1, 16, 128)``.
            y: Chord vector, ``(B, 13)``.
            batch_size: B, used to reshape ``y`` and the first decoder input.
            pitch_range: Unused here; kept for call compatibility.

        Returns:
            Generated bar in (0, 1), ``(B, 1, 16, 128)``.
        """
        # Previous-bar encoder.
        h0_prev = lrelu(batch_norm(self.h0_prev(prev_x)))   # B, 16, 16, 1
        h1_prev = lrelu(batch_norm(self.h1_prev(h0_prev)))  # B, 16, 8, 1
        h2_prev = lrelu(batch_norm(self.h2_prev(h1_prev)))  # B, 16, 4, 1
        h3_prev = lrelu(batch_norm(self.h3_prev(h2_prev)))  # B, 16, 2, 1

        yb = y.view(batch_size,  self.y_dim, 1, 1)          # B, 13, 1, 1

        z = torch.cat((z,y),1)                              # B, 113

        # torch.relu replaces F.relu: `F` is never imported in this notebook.
        h0 = torch.relu(batch_norm(self.linear1(z)))        # B, 1024
        h0 = torch.cat((h0,y),1)                            # B, 1037

        h1 = torch.relu(batch_norm(self.linear2(h0)))       # B, 256
        h1 = h1.view(batch_size, self.gf_dim * 2, 2, 1)     # B, 128, 2, 1
        h1 = concat_vector(h1, yb)                          # B, 141, 2, 1
        h1 = concat_vector(h1, h3_prev)                     # B, 157, 2, 1

        h2 = torch.relu(batch_norm(self.h1(h1)))            # B, 128, 4, 1
        h2 = concat_vector(h2, yb)                          # B, 141, 4, 1
        h2 = concat_vector(h2, h2_prev)                     # B, 157, 4, 1

        h3 = torch.relu(batch_norm(self.h2(h2)))            # B, 128, 8, 1
        h3 = concat_vector(h3, yb)                          # B, 141, 8, 1
        h3 = concat_vector(h3, h1_prev)                     # B, 157, 8, 1

        h4 = torch.relu(batch_norm(self.h3(h3)))            # B, 128, 16, 1
        h4 = concat_vector(h4, yb)                          # B, 141, 16, 1
        h4 = concat_vector(h4, h0_prev)                     # B, 157, 16, 1

        g_x = torch.sigmoid(self.h4(h4))                    # B, 1, 16, 128

        return g_x

#### Discriminator

forwardの入力
- x (batch, 1, steps, pitch) = (72, 1, 16, 128): real/fake判定を行う小節データ
- y (batch, 13) = (72, 13): コード

forwardの出力
- h3_sigmoid (batch, 1) = (72, 1): 0~1に押し込められたreal/fake判定結果．0はfake, 1はreal
- h3 (batch, 1) = (72, 1): 0~1に押し込められていないreal/fake判定結果
- fm (batch, 1+13, steps/2, 1) = (72, 14, 8, 1): 特徴マップ．最初の畳み込み層 (h0_prev) のlrelu出力で，Gの特徴マッチングに使う

オリジナルのmidinetと同じ．詳しい説明はmidinet_understandingを参照

In [10]:
class discriminator(nn.Module):
    """MidiNet discriminator: scores a bar as real or fake given its chord.

    Args:
        pitch_range: Number of pitches per time step; used as the width of
            the first convolution kernel.
    """

    def __init__(self,pitch_range=128):
        super().__init__()

        self.df_dim = 64
        self.dfc_dim = 1024
        self.y_dim = 13

        # Input is bar (1 channel) + chord (y_dim channels) = 14 channels.
        self.h0_prev = nn.Conv2d(
            in_channels=14, out_channels=14,
            kernel_size=(2,pitch_range), stride=(2,2),
        )
        # Input is the 14 channels above + chord (13) = 27; output is
        # df_dim + y_dim = 77 channels.
        self.h1_prev = nn.Conv2d(
            in_channels=27, out_channels=77,
            kernel_size=(4,1), stride=(2,2),
        )
        self.linear1 = nn.Linear(244, self.dfc_dim)
        self.linear2 = nn.Linear(1037, 1)

    def forward(self, x, y, batch_size, pitch_range):
        """Score a batch of bars.

        Shapes in the comments use B = batch_size and a 16x128 bar.

        Args:
            x: Bars to judge, ``(B, 1, 16, 128)``.
            y: Chord vectors, ``(B, 13)``.
            batch_size: B, used for reshaping.
            pitch_range: Unused here; kept for call compatibility.

        Returns:
            Tuple ``(probability, logit, fm)`` where ``probability`` is the
            sigmoid of ``logit`` (both ``(B, 1)``) and ``fm`` is the first
            convolution's activation.
        """
        cond = y.view(batch_size,self.y_dim, 1, 1)

        # First conv over [bar | chord] on the channel axis.
        conv0 = lrelu(self.h0_prev(concat_vector(x, cond)))      # B, 14, 8, 1
        fm = conv0

        # Second conv, again conditioned on the chord.
        conv1 = lrelu(batch_norm(self.h1_prev(concat_vector(conv0, cond))))  # B, 77, 3, 1
        flat = torch.cat((conv1.view(batch_size, -1), y), 1)     # B, 244

        hidden = torch.cat((lrelu(batch_norm(self.linear1(flat))), y), 1)    # B, 1037

        logit = self.linear2(hidden)
        return torch.sigmoid(logit), logit, fm

## 学習コードの作成

#### 共通関数の作成

In [None]:
def sigmoid_cross_entropy_with_logits(inputs,labels):
    """Mean binary cross-entropy between raw logits and target labels.

    Uses ``nn.BCEWithLogitsLoss`` with its default settings, so the result is
    a single scalar averaged over all elements.
    """
    return nn.BCEWithLogitsLoss()(inputs, labels)

def reduce_mean(x):
    """Average ``x`` over dim 0, then average that result over its last dim."""
    per_column = torch.mean(x, 0, keepdim=False)
    return torch.mean(per_column, -1, keepdim=False)


def reduce_mean_0(x):
    """Average ``x`` over dim 0 (the dimension is dropped from the result)."""
    return torch.mean(x, 0, keepdim=False)


def l2_loss(x,y):
    """Half of the summed squared difference between ``x`` and ``y``."""
    squared_error_sum = nn.MSELoss(reduction='sum')(x, y)
    return squared_error_sum / 2


In [None]:
# Section switches for this cell: set to 1 to run the matching section below.
is_train = 1   # train the GAN
is_draw = 0    # plot the saved loss curves
is_sample = 0  # generate samples from a trained generator

# Optimisation hyper-parameters.
epochs = 20
lr = 0.0002

# Pitch window; pitch_range evaluates to 128 with these bounds.
check_range_st = 0
check_range_ed = 129
pitch_range = check_range_ed - check_range_st-1

# Fall back to CPU when CUDA is unavailable instead of failing later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# `load_data()` is not defined anywhere in this notebook; reuse the loader
# that was already built with get_dataloader above.
train_loader = data_loader

def _generator_loss(netD, fake, real, real_features, chord, batch_size, pitch_range):
    """Generator objective: adversarial term plus two feature-matching terms.

    Args:
        netD: Discriminator used to score ``fake``.
        fake: Generated bars (still attached to the generator's graph).
        real: Real bars of the same batch.
        real_features: Detached discriminator features of the real bars.
        chord: Chord vectors of the batch.
        batch_size: Number of samples in the batch.
        pitch_range: Forwarded to ``netD``.

    Returns:
        Tuple ``(loss, d_prob)``: the scalar loss to back-propagate and the
        discriminator's probabilities for ``fake``.
    """
    d_prob, d_logits, fake_features = netD(fake, chord, batch_size, pitch_range)

    # Adversarial term: the generator wants fakes to be labelled 1 (real).
    g_loss0 = reduce_mean(sigmoid_cross_entropy_with_logits(d_logits, torch.ones_like(d_prob)))

    # Feature matching on the discriminator's first-layer activations:
    # penalise the gap between batch-mean features of fake and real bars.
    fm_g_loss1 = torch.mul(
        l2_loss(reduce_mean_0(fake_features), reduce_mean_0(real_features)), 0.1)

    # Feature matching in data space: penalise the gap between the
    # batch-mean generated bar and the batch-mean real bar.
    fm_g_loss2 = torch.mul(l2_loss(reduce_mean_0(fake), reduce_mean_0(real)), 0.01)

    return g_loss0 + fm_g_loss1 + fm_g_loss2, d_prob


# Section 1: training
if is_train == 1 :
    netG = generator(pitch_range).to(device)
    netD = discriminator(pitch_range).to(device)

    netD.train()
    netG.train()
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

    nz = 100            # noise dimensionality
    n_g_steps = 2       # generator updates per discriminator update
    log_interval = 100  # print the running losses every this many batches
    model_dir = '../models'
    # Create output locations up front; saving into a missing directory fails.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # Fixed noise so that the periodically saved samples are comparable.
    fixed_noise = torch.randn(batch_size, nz, device=device)

    average_lossD = 0
    average_lossG = 0
    average_D_x   = 0
    average_D_G_z = 0

    lossD_list =  []      # per-epoch average discriminator loss
    lossD_list_all = []   # per-batch discriminator loss
    lossG_list =  []      # per-epoch average generator loss
    lossG_list_all = []   # per-batch generator loss (last generator step)
    D_x_list = []         # per-epoch average D(x)
    D_G_z_list = []       # per-epoch average D(G(z))

    # The sums below hold per-batch means, so average over batches.
    n_batches = max(len(train_loader), 1)

    for epoch in range(epochs):
        sum_lossD = 0.0
        sum_lossG = 0.0
        sum_D_x   = 0.0
        sum_D_G_z = 0.0
        for i, (data,prev_data,chord) in enumerate(train_loader, 0):
            # Move the batch (bar, previous bar, chord) to the device.
            real = data.to(device)
            prev_bar = prev_data.to(device)
            chord_dev = chord.to(device)
            # The last batch may be smaller; do not clobber the global
            # batch_size setting.
            cur_batch = real.size(0)

            ############################
            # (1) Update D: maximise log(D(x)) + log(1 - D(G(z)))
            ###########################
            netD.zero_grad()

            # Real bars. The target is 0.9 rather than 1 (presumably
            # one-sided label smoothing to keep D from becoming too strong).
            D, D_logits, fm = netD(real, chord_dev, cur_batch, pitch_range)
            d_loss_real = reduce_mean(
                sigmoid_cross_entropy_with_logits(D_logits, 0.9*torch.ones_like(D)))
            d_loss_real.backward()

            # Keep the real features only as a constant target for the
            # generator's feature matching. Detaching avoids back-propagating
            # through D's graph after optimizerD.step() has modified D's
            # weights in place.
            fm_real = fm.detach()

            D_x = D.mean().item()
            sum_D_x += D_x

            # Fake bars: detach so this pass does not touch G's gradients.
            noise = torch.randn(cur_batch, nz, device=device)
            fake = netG(noise, prev_bar, chord_dev, cur_batch, pitch_range)
            D_, D_logits_, _ = netD(fake.detach(), chord_dev, cur_batch, pitch_range)
            d_loss_fake = reduce_mean(
                sigmoid_cross_entropy_with_logits(D_logits_, torch.zeros_like(D_)))
            d_loss_fake.backward()
            D_G_z1 = D_.mean().item()  # D(G(z)) before the generator updates

            errD = (d_loss_real + d_loss_fake).item()
            lossD_list_all.append(errD)
            sum_lossD += errD

            optimizerD.step()

            ############################
            # (2) Update G n_g_steps times: maximise log(D(G(z)))
            ###########################
            for g_step in range(n_g_steps):
                if g_step > 0:
                    # The previous step changed G's weights in place, so the
                    # old `fake` graph is stale. Regenerate from the same
                    # noise with the updated generator.
                    fake = netG(noise, prev_bar, chord_dev, cur_batch, pitch_range)
                netG.zero_grad()
                errG_t, D_ = _generator_loss(
                    netD, fake, real, fm_real, chord_dev, cur_batch, pitch_range)
                errG_t.backward()
                optimizerG.step()

            # Record the statistics of the last generator step as plain
            # floats, so no autograd graph is kept alive and np.save works.
            errG = errG_t.item()
            lossG_list_all.append(errG)
            sum_lossG += errG
            D_G_z2 = D_.mean().item()
            sum_D_G_z += D_G_z2

            # Periodic progress report. The original condition was
            # `epoch % 5 == 0`, which printed every batch of those epochs.
            if i % log_interval == 0:
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                      % (epoch, epochs, i, len(train_loader),
                         errD, errG, D_x, D_G_z1, D_G_z2))

            # Save real and generated bars as images for visual comparison.
            if i % 100 == 0:
                vutils.save_image(real,
                        '%s/real_samples.png' % output_dir,
                        normalize=True)
                # Use as many fixed-noise rows as the batch provides.
                n_vis = min(cur_batch, fixed_noise.size(0))
                with torch.no_grad():
                    fake_fixed = netG(fixed_noise[:n_vis], prev_bar[:n_vis],
                                      chord_dev[:n_vis], n_vis, pitch_range)
                vutils.save_image(fake_fixed,
                        '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch),
                        normalize=True)

        # Per-epoch averages.
        average_lossD = sum_lossD / n_batches
        average_lossG = sum_lossG / n_batches
        average_D_x = sum_D_x / n_batches
        average_D_G_z = sum_D_G_z / n_batches

        lossD_list.append(average_lossD)
        lossG_list.append(average_lossG)
        D_x_list.append(average_D_x)
        D_G_z_list.append(average_D_G_z)

        print('==> Epoch: {} Average lossD: {:.10f} average_lossG: {:.10f},average D(x): {:.10f},average D(G(z)): {:.10f} '.format(
          epoch, average_lossD,average_lossG,average_D_x, average_D_G_z))

    # Save the recorded curves.
    np.save('lossD_list.npy',lossD_list)
    np.save('lossG_list.npy',lossG_list)
    np.save('lossD_list_all.npy',lossD_list_all)
    np.save('lossG_list_all.npy',lossG_list_all)
    np.save('D_x_list.npy',D_x_list)
    np.save('D_G_z_list.npy',D_G_z_list)

    # Save the trained weights, tagged with the index of the last epoch.
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (model_dir, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (model_dir, epoch))



# Section 2: plot the per-epoch loss curves stored in the .npy files.
if is_draw == 1:
    lossD_print = np.load('lossD_list.npy')
    lossG_print = np.load('lossG_list.npy')
    length = lossG_print.shape[0]

    # One x position per recorded value: 0, 1, ..., length-1.
    x = np.linspace(0, length-1, length)

    fig, ax = plt.subplots()
    for curve, curve_label in ((lossD_print, ' lossD'), (lossG_print, ' lossG')):
        ax.plot(x, curve, label=curve_label, linewidth=1.5)

    ax.legend(loc='upper right')
    ax.set_xlabel('data')
    ax.set_ylabel('loss')
    # NOTE: the directory part of this path is a placeholder and must be
    # replaced with a real location before enabling this section.
    fig.savefig('where you want to save/lr='+ str(lr) +'_epoch='+str(epochs)+'.png')



# Section 3: generate sample songs with a trained generator.
# NOTE(review): this section is disabled by default (is_sample == 0) and
# cannot run as written; the notes below mark what has to be resolved first.
if is_sample == 1:
    batch_size = 8
    nz = 100
    n_bars = 7

    # Load the test data. The file names are placeholders to be filled in.
    X_te = np.load('your testing x') # first bars
    prev_X_te = np.load('your testing prev x') # previous bars
    prev_X_te = prev_X_te[:,:,check_range_st:check_range_ed,:]
    y_te    = np.load('yourd chord') # chords

    # Prepare the DataLoader.
    # NOTE(review): get_dataloader as defined in this notebook takes
    # (data_path, batch_size, shuffle) and already returns a DataLoader, so
    # this call passes arrays where a path and a batch size are expected,
    # and the result is wrapped in a second DataLoader below.
    test_iter = get_dataloader(X_te,prev_X_te,y_te)
    kwargs = {'num_workers': 4, 'pin_memory': True}# if args.cuda else {}
    test_loader = DataLoader(test_iter, batch_size=batch_size, shuffle=False, **kwargs)

    # Build the generator used for sampling and load the trained weights.
    # NOTE(review): sample_generator is not defined in the cells visible
    # here, and 'your model' is a placeholder path.
    netG = sample_generator()
    netG.load_state_dict(torch.load('your model'))

    # Outer loop: one generated song per batch of the test loader.
    output_songs = []
    output_chords = []
    for i, (data,prev_data,chord) in enumerate(test_loader, 0):
        # The song starts with the first real bar of the batch.
        list_song = []
        first_bar = data[0].view(1,1,16,128)
        list_song.append(first_bar)

        list_chord = []
        first_chord = chord[0].view(1,13).numpy()
        list_chord.append(first_chord)
        # One noise vector per bar; only the first n_bars rows are used.
        noise = torch.randn(batch_size, nz)

        # Inner loop: generate n_bars bars, one at a time.
        for bar in range(n_bars):
            z = noise[bar].view(1,nz)
            y = chord[bar].view(1,13)

            if bar == 0:
                # The first generated bar is conditioned on the real bar.
                prev = data[0].view(1,1,16,128)
            else:
                # Later bars are conditioned on an earlier entry of the song.
                # NOTE(review): list_song[0] is the real seed bar, so
                # list_song[bar-1] is the seed bar again when bar == 1,
                # not the bar generated at bar == 0. Confirm whether
                # list_song[bar] was intended.
                prev = list_song[bar-1].view(1,1,16,128)

            # Generate the current bar from the noise, conditioned on the
            # previous bar and the current chord.
            sample = netG(z, prev, y, 1,pitch_range)

            # Record the bar and its chord.
            list_song.append(sample)
            list_chord.append(y.numpy())

        # Record the finished song.
        print('num of output_songs: {}'.format(len(output_songs)))
        output_songs.append(list_song)
        output_chords.append(list_chord)

    # Save all generated songs and their chords.
    np.save('output_songs.npy',np.asarray(output_songs))
    np.save('output_chords.npy',np.asarray(output_chords))

    print('creation completed, check out what I make!')