# MidiNetモデルの学習

- DataLoaderの作成
- Modelの作成
- 学習コードの作成
- 学習経過の可視化

In [1]:
import os, ipdb, pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from pypianoroll import Multitrack, Track
from utils import grid_plot, Timer

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
base_dir = "../datasets/theorytab/midinet"
output_dir = f"{base_dir}/learning"

if not os.path.isdir(output_dir):
    os.mkdir(output_dir)

## DataLoaderの作成

In [3]:
class MidinetDataloader():
    def __init__(self, data_path, pitch_range=[0,128], show_shape=False):
        data = pickle.load(open(data_path,'rb'))
        
        melody, prev, chord = [], [], []
        for m, p, c in data:
            melody.append(m)
            prev.append(p)
            chord.append(c)
        
        self.size = len(melody)
        steps = len(melody[0])
        bottom, top = pitch_range
        
        melody = np.array(melody)[:,:,bottom:top].reshape(self.size, 1, steps, top-bottom)
        prev = np.array(prev)[:,:,bottom:top].reshape(self.size, 1, steps, top-bottom)
        chord = np.array(chord)
        
        if show_shape:
            print("melody shape", melody.shape)
            print("prev shape", prev.shape)
            print("chord shape", chord.shape)
        
        self.x = torch.from_numpy(melody).float()
        self.prev_x   = torch.from_numpy(prev).float()
        self.y  = torch.from_numpy(chord).float()

    def __getitem__(self, index):
        return self.x[index], self.prev_x[index], self.y[index]

    def __len__(self):
        return self.size

In [4]:
def get_dataloader(data_path, batch_size=72, shuffle=True):
    iterator = MidinetDataloader(data_path)
    kwargs = {'num_workers': 4, 'pin_memory': True}
    data_loader = DataLoader(iterator, batch_size=batch_size, shuffle=shuffle, **kwargs)
    print('Data loading is completed.')
    print(f'{len(data_loader)} batches from {len(iterator)} bars are obtained.')
    return data_loader

In [5]:
def to_device(data):
    if not isinstance(data, torch.Tensor):
        data = torch.from_numpy(data)
    if torch.cuda.is_available():
        return data.cuda()
    return data.cpu()

## Modelの作成

#### model用共通関数の作成

In [6]:
def concat_vector(x, y):
    x_0, _, x_2, x_3 = x.shape
    y2 = y.expand(x_0, y.shape[1], x_2, x_3)
    return torch.cat((x, y2),1)
    

def batch_norm(x, eps=1e-05, momentum=0.9, affine=True):
    if x.ndim == 2:
        return nn.BatchNorm1d(x.shape[1], eps=eps, momentum=momentum, affine=affine).cuda()(x)
    elif x.ndim == 3:
        return nn.BatchNorm2d(x.shape[1], eps=eps, momentum=momentum, affine=affine).cuda()(x)
    else:
        return x


def lrelu(x, leak=0.2):
    z = torch.mul(x,leak)
    return torch.max(x, z)

#### Generator
forwardの入力
- z (batch, noise_size) = (72, 113): ランダムノイズ
- prev_x (batch, ch, steps, pitch) = (72, 1, 16, 128): 前の小節
- y (batch, 13): コード，0~11次元はコードの主音，12次元目はmajorかminorかを区別する

forwardの出力
- g_x (batch, ch, steps, pitch)= (72, 1, 16, 128): 生成された今の小節
    
オリジナルのmidinetと同じ．詳しい説明はmidinet_understandingを参照

In [7]:
class Generator(nn.Module):
    def __init__(self,pitch_range=128):
        super(Generator, self).__init__()
        self.gf_dim  = 64
        self.y_dim   = 13

        self.h1      = nn.ConvTranspose2d(in_channels=157, out_channels=pitch_range, kernel_size=(2,1), stride=(2,2))
        self.h2      = nn.ConvTranspose2d(in_channels=157, out_channels=pitch_range, kernel_size=(2,1), stride=(2,2))
        self.h3      = nn.ConvTranspose2d(in_channels=157, out_channels=pitch_range, kernel_size=(2,1), stride=(2,2))
        self.h4      = nn.ConvTranspose2d(in_channels=157, out_channels=1, kernel_size=(1,pitch_range), stride=(1,2))

        self.h0_prev = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(1,pitch_range), stride=(1,2))
        self.h1_prev = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(2,1), stride=(2,2))
        self.h2_prev = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(2,1), stride=(2,2))
        self.h3_prev = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(2,1), stride=(2,2))

        self.linear1 = nn.Linear(113,1024)
        self.linear2 = nn.Linear(1037,self.gf_dim*2*2*1)

    def forward(self, z, prev_x, y, batch_size):
        
        h0_prev = lrelu(batch_norm(self.h0_prev(prev_x)))   # 72, 16, 16, 1
        h1_prev = lrelu(batch_norm(self.h1_prev(h0_prev)))  # 72, 16, 8, 1
        h2_prev = lrelu(batch_norm(self.h2_prev(h1_prev)))  # 72, 16, 4, 1
        h3_prev = lrelu(batch_norm(self.h3_prev(h2_prev)))  # 72, 16, 2, 1

        yb = y.view(batch_size, self.y_dim, 1, 1)           # 72, 13, 1, 1

        z = torch.cat((z,y),1)                              # 72, 113

        h0 = F.relu(batch_norm(self.linear1(z)))            # 72, 1024
        h0 = torch.cat((h0,y),1)                            # 72, 1037

        h1 = F.relu(batch_norm(self.linear2(h0)))           # 72, 256
        h1 = h1.view(batch_size, self.gf_dim * 2, 2, 1)     # 72, 128, 2, 1
        h1 = concat_vector(h1, yb)                          # 72, 141, 2, 1
        h1 = concat_vector(h1, h3_prev)                     # 72, 157, 2, 1

        h2 = F.relu(batch_norm(self.h1(h1)))                # 72, 128, 4, 1
        h2 = concat_vector(h2, yb)                          # 72, 141, 4, 1
        h2 = concat_vector(h2, h2_prev)                     # 72, 157, 4, 1

        h3 = F.relu(batch_norm(self.h2(h2)))                # 72, 128, 8, 1 
        h3 = concat_vector(h3, yb)                          # 72, 141, 8, 1
        h3 = concat_vector(h3, h1_prev)                     # 72, 157, 8, 1

        h4 = F.relu(batch_norm(self.h3(h3)))                # 72, 128, 16, 1
        h4 = concat_vector(h4, yb)                          # 72, 141, 16, 1
        h4 = concat_vector(h4, h0_prev)                     # 72, 157, 16, 1

        g_x = torch.sigmoid(self.h4(h4))                    # 72, 1, 16, 128

        return g_x

#### Discriminator
forwardの入力
- x (batch, 1, steps, pitch) = (72, 1, 16, 128): real/fake判定を行う小節データ
- y (batch, 13) = (72, 13): コード

forwardの出力
- h3_sigmoid (batch, 1) = (72, 1): 0~1に押し込められたreal/fake判定結果．0はfake, 1はreal
- h3 (batch, 1) = (72, 1): 0~1に押し込められていないreal/fake判定結果
- fm (batch, 1+13, steps, pitch) = (72, 14, 16, 128): 特徴マップ．

オリジナルのmidinetと同じ．詳しい説明はmidinet_understandingを参照

In [8]:
class Discriminator(nn.Module):
    def __init__(self,pitch_range=128):
        super(Discriminator, self).__init__()

        self.df_dim = 64
        self.dfc_dim = 1024
        self.y_dim = 13
        
        # out channels = y_dim +1 
        self.h0_prev = nn.Conv2d(in_channels=14, out_channels=14, kernel_size=(2,pitch_range), stride=(2,2))

        # out channels = df_dim + y_dim
        self.h1_prev = nn.Conv2d(in_channels=27, out_channels=77, kernel_size=(4,1), stride=(2,2))
        self.linear1 = nn.Linear(244, self.dfc_dim)
        self.linear2 = nn.Linear(1037, 1)

    def forward(self, x, y, batch_size):        

        yb = y.view(batch_size,self.y_dim, 1, 1)
        x = concat_vector(x, yb)                    #72, 14, 16, 128
        
        h0 = lrelu(self.h0_prev(x))                 #72, 14, 8, 1
        fm = h0
        h0 = concat_vector(h0, yb)                  #72, 27, 8, 1

        h1 = lrelu(batch_norm(self.h1_prev(h0)))    #72, 77, 3, 1
        h1 = h1.view(batch_size, -1)                #72, 231
        h1 = torch.cat((h1,y), 1)                   #72, 244

        h2 = lrelu(batch_norm(self.linear1(h1)))
        h2 = torch.cat((h2,y), 1)                   #72, 1037

        h3 = self.linear2(h2)
        h3_sigmoid = torch.sigmoid(h3)


        return h3_sigmoid, h3, fm

## 学習コードの作成: original

In [9]:
input_file_path = os.path.join(base_dir, "midinet_original.pkl")
save_dir = os.path.join(output_dir, "original")

if not os.path.isdir(save_dir):
    os.mkdir(save_dir)

save_npa = lambda file_name, npa: np.save(os.path.join(save_dir, file_name), npa)

### ハイパーパラメータの設定

In [10]:
epochs = 20
batch_size = 512
generator_train_times = 2
sample_bar_num = 16 # あとで実装

# for Adam
lr = 0.0002
betas = (0.5, 0.999)

# noise vector size
nz = 100

# a coefficient to real label for discriminator. 0 ~ 1
real_data_worthiness = 0.9

# feature matching coefficients
lambda_1, lambda_2 = 0.1, 0.01 # D, G

### 学習初期化処理  

In [11]:
data_loader = get_dataloader(input_file_path, batch_size=batch_size)
data_size = len(data_loader.dataset)

Data loading is completed.
279 batches from 142740 bars are obtained.


In [12]:
netD = Discriminator()
netG = Generator()

if torch.cuda.is_available():
    netD = netD.cuda()
    netG = netG.cuda()

netD.train()
netG.train()

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=betas)

noise_for_sample = to_device(torch.randn(sample_bar_num, nz))

realD_list, fakeD_list = [], [] # Dのrealデータとfakeデータに対するエポックごとの識別結果平均(realに対しては1に近く，fakeに対しては0に近い方がDが強い)
lossD_list, lossG_list = [], [] # D, Gのロスのエポックごとの誤差

### 学習ループ
オリジナルのコードを若干書き換え  
ノイズベクトルは毎回作り直すことにしてみる

In [13]:
for epoch in range(epochs):
    sum_lossD, sum_lossG = 0, 0
    sum_realD, sum_fakeD = 0, 0
    
    print(f"start epoch {epoch} / {epochs}")
    with Timer():
        for i, (real, prev, chord) in enumerate(data_loader):

            # バッチ(譜面，前の譜面，コード)をdeviceに渡す  
            real, prev, chord = [to_device(item) for item in [real, prev, chord]]

            # batchの切れ端はサイズが異なる場合があるので注意
            batch_size = real.size(0)

            ############################
            # Dの学習: log(D(x)) + log(1 - D(G(z))) を最大化
            # realデータを1，fakeデータを0と判断させるよう学習
            ###########################

            # Dの勾配の初期化
            netD.zero_grad()

            # realに対する識別結果からクロスエントロピー誤差(目的関数)の値を得る
            d_real, d_logits_real, fm_real = netD(real, chord, batch_size)
            d_real_label = real_data_worthiness * torch.ones_like(d_real)
            d_loss_real = nn.BCEWithLogitsLoss()(d_logits_real, d_real_label) # (72, 1), (72, 1) => scalar tensor

            # Gにノイズベクトル，前の譜面，コードを渡し，fakeデータを作成
            noise = to_device(torch.randn(batch_size, nz))
            fake = netG(noise, prev, chord, batch_size)

            # fakeに対する識別結果からクロスエントロピー誤差(目的関数)の値を得る
            d_fake, d_logits_fake, fm_fake = netD(fake.detach(), chord, batch_size)
            d_fake_label = torch.zeros_like(d_fake)
            d_loss_fake = nn.BCEWithLogitsLoss()(d_logits_fake, d_fake_label) # (72, 1), (72, 1) => scalar tensor

            # 誤差逆伝搬により勾配を更新し，それに基づきDのパラメータを更新する
            lossD = d_loss_real + d_loss_fake
            lossD.backward(retain_graph=True)
            optimizerD.step()

            # 学習記録
            # real, fakeデータに対してそれぞれrealだと識別した割合
            realD, fakeD = d_real.mean().item(), d_fake.mean().item()
            sum_realD += realD
            sum_fakeD += fakeD
            sum_lossD += lossD.item() # Dの学習におけるLoss


            ############################
            # Gの学習 : log(D(G(z)))を最大化
            # fakeデータを1と判断させるよう学習
            ###########################

            for t in range(generator_train_times):

                # Gの勾配の初期化
                netG.zero_grad()

                # Gにノイズベクトル，前の譜面，コードを渡し，fakeデータを作成
                noise = to_device(torch.randn(batch_size, nz))
                fake = netG(noise, prev, chord, batch_size)
                
                # fakeに対して1をラベルとした識別結果からクロスエントロピー誤差(目的関数)の値を得てGの誤差とする
                d_fake, d_logits_fake, fm_fake = netD(fake, chord, batch_size)
                deceive_label = torch.ones_like(d_fake)
                g_loss = nn.BCEWithLogitsLoss()(d_logits_fake, deceive_label) # (72, 1), (72, 1) => scalar tensor

                # Dの特徴マッチング：realとfakeでnetDの初段のreluの出力が近くなるようにする
                features_from_g = torch.mean(fm_fake, 0) # fakeデータに対するDのfeatureの平均値
                features_from_i = torch.mean(fm_real, 0) # realデータに対するDのfeatureの平均値
                # fakeとrealの出すfeatureの違いが大きいほどペナルティを与える
                fm_g_loss1 = nn.MSELoss(reduction='sum')(features_from_g, features_from_i) / 2
                fm_g_loss1 = torch.mul(fm_g_loss1, lambda_1)

                # Gの特徴マッチング：Gがrealに近いデータを生成できるようにする
                mean_image_from_g = torch.mean(fake, 0) # fakeデータの平均値
                mean_image_from_i = torch.mean(real, 0) # realデータの平均値
                # fakeデータとrealデータの違いが大きいほどペナルティを与える
                fm_g_loss2 = nn.MSELoss(reduction='sum')(mean_image_from_g, mean_image_from_i) / 2
                fm_g_loss2 = torch.mul(fm_g_loss2, lambda_2)

                # 誤差逆伝搬により勾配を更新し，それに基づきGのパラメータを更新する
                lossG = g_loss + fm_g_loss1 + fm_g_loss2
                lossG.backward(retain_graph=(t < generator_train_times - 1)) # 最後は計算グラフを放棄
                optimizerG.step()

            # 学習記録
            sum_lossG += lossG.item() # Gの学習におけるLoss
    

        # エポックごとの識別と誤差の記録
        realD_list.append(sum_realD / data_size)
        fakeD_list.append(sum_fakeD / data_size)
        lossD_list.append(sum_lossD / data_size)
        lossG_list.append(sum_lossG / data_size)
        print(f'==> avg lossD: {lossD_list[-1]:.4f} avg lossG: {lossG_list[-1]:.4f}, avg realD: {realD_list[-1]:.4f}, avg fakeD: {fakeD_list[-1]:.4f} ')
    
    # 5エポックごとに具体的なロスと識別の状況を報告し，生成データを画像で記録
    if epoch % 5 == 0:
        print(f'[epoch {epoch}] loss D: {lossD:.4f} loss G: {lossG:.4f} real D: {realD:.4f} fake D: {fakeD:.4f}')
        sample_fake = netG(noise_for_sample, prev[:sample_bar_num], chord[:sample_bar_num], sample_bar_num).detach()
        fake_file_name = f'fake_samples_epoch{epoch:03}.png'
        vutils.save_image(sample_fake, os.path.join(save_dir, fake_file_name), normalize=True)

start epoch 0 / 20
==> avg lossD: 0.0025 avg lossG: 0.0020, avg realD: 0.0009, avg fakeD: 0.0008 
13.466941
[epoch 0] loss D: 1.0889 loss G: 1.2906 real D: 0.6297 fake D: 0.4298
start epoch 1 / 20
==> avg lossD: 0.0020 avg lossG: 0.0024, avg realD: 0.0011, avg fakeD: 0.0006 
13.381445
start epoch 2 / 20
==> avg lossD: 0.0019 avg lossG: 0.0026, avg realD: 0.0012, avg fakeD: 0.0006 
13.457835
start epoch 3 / 20
==> avg lossD: 0.0021 avg lossG: 0.0023, avg realD: 0.0011, avg fakeD: 0.0007 
13.556808
start epoch 4 / 20
==> avg lossD: 0.0022 avg lossG: 0.0022, avg realD: 0.0011, avg fakeD: 0.0007 
13.210067
start epoch 5 / 20
==> avg lossD: 0.0017 avg lossG: 0.0029, avg realD: 0.0012, avg fakeD: 0.0005 
13.417396
[epoch 5] loss D: 0.6689 loss G: 1.7631 real D: 0.7298 fake D: 0.2131
start epoch 6 / 20
==> avg lossD: 0.0017 avg lossG: 0.0031, avg realD: 0.0013, avg fakeD: 0.0005 
11.883087
start epoch 7 / 20
==> avg lossD: 0.0021 avg lossG: 0.0023, avg realD: 0.0011, avg fakeD: 0.0007 
13.566

Exception in thread Thread-15:
Traceback (most recent call last):
  File "/root/.pyenv/versions/3.7.3/lib/python3.7/threading.py", line 917, in _bootstrap_inner
    self.run()
  File "/root/.pyenv/versions/3.7.3/lib/python3.7/threading.py", line 865, in run
    self._target(*self._args, **self._kwargs)
  File "/root/midinet-followup/.venv/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py", line 21, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/root/.pyenv/versions/3.7.3/lib/python3.7/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/root/midinet-followup/.venv/lib/python3.7/site-packages/torch/multiprocessing/reductions.py", line 284, in rebuild_storage_fd
    fd = df.detach()
  File "/root/.pyenv/versions/3.7.3/lib/python3.7/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/root/.pyenv/versions/3.7.3/lib/python3.7/mu

RuntimeError: DataLoader worker (pid(s) 3543) exited unexpectedly

モデルと記録の保存

In [None]:
save_npa('realD_list.npy', realD_list)
save_npa('fakeD_list.npy', fakeD_list)
save_npa('lossD_list.npy', lossD_list)
save_npa('lossG_list.npy', lossG_list)
torch.save(netG.state_dict(), os.path.join(save_dir, f'netG_epoch_{epoch}.pth'))
torch.save(netD.state_dict(), os.path.join(save_dir, f'netD_epoch_{epoch}.pth')

### 誤差グラフの表示

In [None]:
lossD_print = np.load(os.path.join(save_dir, 'lossD_list.npy'))
lossG_print = np.load(os.path.join(save_dir, 'lossG_list.npy'))

length = lossG_print.shape[0]
x = np.asarray(np.linspace(0, length-1, length))
plt.figure()
plt.plot(x, lossD_print,label=' lossD',linewidth=1.5)
plt.plot(x, lossG_print,label=' lossG',linewidth=1.5)
plt.legend(loc='upper right')
plt.xlabel('data')
plt.ylabel('loss')
plt.savefig(os.path.join(save_dir, f'loss_graph_lr={lr}_epoch={epochs}.png'))
plt.show()

### サンプルの作成

In [None]:
batch_size = 8
nz = 100
n_bars = 7

In [None]:
# データの取得
X_te = np.load('your testing x') # 最初の小節
prev_X_te = np.load('your testing prev x') # 前の小節
prev_X_te = prev_X_te[:,:,check_range_st:check_range_ed,:]
y_te    = np.load('yourd chord') # コード

# DataLoaderの準備
test_iter = get_dataloader(X_te,prev_X_te,y_te)
kwargs = {'num_workers': 4, 'pin_memory': True}# if args.cuda else {}
test_loader = DataLoader(test_iter, batch_size=batch_size, shuffle=False, **kwargs)

# サンプル生成用のGを用意し，訓練済みパラメータを読み込ませる
netG = sample_generator()
netG.load_state_dict(torch.load('your model'))

# サンプルの生成ループ
output_songs = []
output_chords = []
for i, (data,prev_data,chord) in enumerate(test_loader, 0):
    list_song = []
    first_bar = data[0].view(1,1,16,128)
    list_song.append(first_bar)

    list_chord = []
    first_chord = chord[0].view(1,13).numpy()
    list_chord.append(first_chord)
    noise = torch.randn(batch_size, nz)

    # 小節生成ループ
    for bar in range(n_bars):
        z = noise[bar].view(1,nz)
        y = chord[bar].view(1,13)

        if bar == 0:
            # 最初の小節はrealデータを使う
            prev = data[0].view(1,1,16,128)
        else:
            # 2小節目からは前の小節を条件にする
            prev = list_song[bar-1].view(1,1,16,128)

        # ランダムノイズを基に，前の小節と今のコードを条件として渡して，今の小節を生成
        sample = netG(z, prev, y, 1, pitch_range)

        # 小節を記録
        list_song.append(sample)
        list_chord.append(y.numpy())

    # 生成された曲を記録
    print('num of output_songs: {}'.format(len(output_songs)))
    output_songs.append(list_song)
    output_chords.append(list_chord)

# 生成された曲の保存
np.save('output_songs.npy',np.asarray(output_songs))
np.save('output_chords.npy',np.asarray(output_chords))

print('creation completed, check out what I make!')