## 通常のGANによる訓練
通常のGANによるファインチューニング用プログラム

In [None]:
# G and D are expected to have been moved to `device` by the caller
# (more convenient in practice, since preparing the device takes time).
def train_gan(G, D, g_optim, d_optim, dataloader, epoch_num, output_dir, device,
              start_epoch=1, report_period=50, save_epoch_period=5, clear_disp_epoch_period=3, get_status_dict=True):
    """Fine-tune a generator/discriminator pair with the standard GAN loss.

    Each iteration performs one discriminator update followed by one
    generator update, both using ``nn.BCEWithLogitsLoss``; the discriminator
    outputs are therefore treated as logits throughout.

    Args:
        G: Generator, called as ``G(batch.chord)``.
        D: Discriminator, called as ``D(melody, batch.chord)``.
        g_optim: Optimizer stepping the generator.
        d_optim: Optimizer stepping the discriminator.
        dataloader: Iterable of batches exposing ``melody`` and ``chord``
            attributes; must support ``len()`` and expose ``batch_size``.
        epoch_num: Index of the last epoch to run (inclusive).
        output_dir: Directory handed to ``save_model`` for checkpoints.
        device: Device the label tensors and ``batch.chord`` are moved to.
        start_epoch: Index of the first epoch to run.
        report_period: Print a progress line every this many iterations;
            a falsy value (0 / None) disables the per-iteration report.
        save_epoch_period: Save checkpoints every this many epochs;
            a falsy value disables checkpointing.
        clear_disp_epoch_period: Clear the notebook output every this many
            epochs; a falsy value disables clearing.
        get_status_dict: When true, also return the per-epoch history.

    Returns:
        ``(G, D, status_dict)`` when ``get_status_dict`` is true, otherwise
        ``(G, D)``. ``status_dict`` maps ``'g_loss'``, ``'d_loss'``,
        ``'d_acc_real'`` and ``'d_acc_fake'`` to lists of per-epoch floats.
        A ``KeyboardInterrupt`` stops training early and still returns the
        models (and whatever history was collected so far).
    """
    torch.backends.cudnn.benchmark = True

    batch_num = len(dataloader)
    # Keep the per-epoch averages well defined even for an empty loader.
    avg_denom = max(batch_num, 1)

    criterion = nn.BCEWithLogitsLoss(reduction='mean')

    G.train()
    D.train()

    # Label tensors are allocated once and sliced per batch, so a smaller
    # final batch does not need a fresh allocation.
    ones = torch.ones(dataloader.batch_size).to(device)
    zeros = torch.zeros(dataloader.batch_size).to(device)

    status_dict = None
    if get_status_dict:
        status_dict = { 'g_loss': [], 'd_loss': [], "d_acc_real": [], "d_acc_fake": []}

    try:
        for epoch in range(start_epoch, epoch_num+1):

            t_epoch_start = time.time()
            t_iter_start = time.time()

            total_g_loss = 0.0
            total_d_loss = 0.0
            total_d_acc_real = 0.0
            total_d_acc_fake = 0.0

            print(f"----- Epoch {epoch:>3} / {epoch_num:<3} start -----")
            for iteration, batch in enumerate(dataloader, 1):

                batch_size = len(batch.melody)
                label_real = ones[:batch_size]
                label_fake = zeros[:batch_size]

                # Only chord is moved to the device here; melody is left
                # where it is because its one-hot form is what gets used.
                batch.chord = batch.chord.to(device)

                #################
                # Discriminator
                #################
                d_optim.zero_grad()
                g_optim.zero_grad()

                with torch.set_grad_enabled(True):
                    # Score the real sample and a generated one.
                    d_out_real = D(batch.melody, batch.chord)
                    # The generator is not updated in this step, so skip
                    # building its autograd graph.
                    with torch.no_grad():
                        fake_for_d = G(batch.chord)
                    d_out_fake = D(fake_for_d, batch.chord)

                    # Compute the loss.
                    d_loss_real = criterion(d_out_real, label_real)
                    d_loss_fake = criterion(d_out_fake, label_fake)
                    d_loss = d_loss_real + d_loss_fake

                    # Backpropagate and update D.
                    d_loss.backward()
                    d_optim.step()

                # Bookkeeping. D emits logits (see BCEWithLogitsLoss above),
                # so probability 0.5 corresponds to a logit of 0.
                d_loss_batch = d_loss.item()
                d_acc_real = (d_out_real.detach() >= 0).sum().item() / batch_size
                d_acc_fake = (d_out_fake.detach() < 0).sum().item() / batch_size

                total_d_loss += d_loss_batch
                total_d_acc_real += d_acc_real
                total_d_acc_fake += d_acc_fake


                #################
                # Generator
                #################
                d_optim.zero_grad()
                g_optim.zero_grad()

                with torch.set_grad_enabled(True):
                    # Generate a fake sample and have D score it.
                    d_out_fake = D(G(batch.chord), batch.chord)

                    # G is rewarded when D labels its output as real.
                    g_loss = criterion(d_out_fake, label_real)

                    # Backpropagate and update G.
                    g_loss.backward()
                    g_optim.step()

                # Bookkeeping.
                g_loss_batch = g_loss.item()
                total_g_loss += g_loss_batch


                #################
                # Report
                #################
                # A falsy period disables reporting instead of dividing by zero.
                if report_period and iteration % report_period == 0:
                    duration = time.time() - t_iter_start
                    print(f"Iteration {iteration:>5}/{batch_num:<5} | G loss: {g_loss_batch:.6f} | D loss: {d_loss_batch:.6f} | D real: {d_acc_real:.2f} | D fake: {d_acc_fake:.2f} | {duration:.4f} [sec]")
                    t_iter_start = time.time()

            d_loss_epoch = total_d_loss / avg_denom
            g_loss_epoch = total_g_loss / avg_denom
            d_acc_real_epoch = total_d_acc_real / avg_denom
            d_acc_fake_epoch = total_d_acc_fake / avg_denom
            print(f"Epoch {epoch:>3} / {epoch_num:<3} Average | G loss: {g_loss_epoch:.6f} | D loss: {d_loss_epoch:.6f} | D real: {d_acc_real_epoch:.2f} | D fake: {d_acc_fake_epoch:.2f}\n")

            if get_status_dict:
                status_dict['g_loss'].append(g_loss_epoch)
                status_dict['d_loss'].append(d_loss_epoch)
                status_dict['d_acc_real'].append(d_acc_real_epoch)
                status_dict['d_acc_fake'].append(d_acc_fake_epoch)

            epoch_duration = time.time() - t_epoch_start
            print(f"Epoch {epoch:>3} / {epoch_num:<3} finished in {epoch_duration:.4f}[sec]")
            # Estimate the remaining time from the duration of this epoch.
            remain_sec = epoch_duration * (epoch_num - epoch)
            print(f"Remaining Time | {remain_sec/3600:.4f} [hour] | {remain_sec/60:.2f} [min] | {remain_sec:.0f} [sec]\n")

            if clear_disp_epoch_period and epoch % clear_disp_epoch_period == 0:
                clear_output()

            # Save checkpoint models.
            if save_epoch_period and epoch % save_epoch_period == 0:
                print("start saving models")
                # NOTE(review): `config` is not a parameter; it is read from
                # the enclosing notebook scope and must be defined there.
                save_model(config, G, epoch, output_dir)
                save_model(config, D, epoch, output_dir)
                print("")

        print("All Fine-Tuning Finished!")

    except KeyboardInterrupt:
        # Fall through to the common return so the partially trained models
        # (and the history so far) are still handed back.
        print("Keyboard interrupted, but return models.")

    if get_status_dict:
        return G, D, status_dict
    return G, D

In [None]:
# Run standard GAN fine-tuning for 10 epochs and keep the per-epoch history.
epoch_num = 10
G_trained, D_trained, status_dict = train_gan(
    G=G, D=D,
    g_optim=g_optim, d_optim=d_optim,
    dataloader=dataloader,
    epoch_num=epoch_num,
    output_dir=output_dir,
    device=device,
    start_epoch=1,
    # Report about ten times per epoch; clamp to 1 so a loader with fewer
    # than ten batches does not produce a period of 0.
    report_period=max(1, len(dataloader) // 10),
    # Larger than epoch_num, so no checkpoint is written during this run.
    save_epoch_period=100,
)

Statusグラフの表示

In [None]:
def show_status_graph_gan(status_dict, title="A loss graph of G and D"):
    """Plot the per-epoch training history of a GAN run.

    Draws two side-by-side panels: generator/discriminator loss on the left
    and discriminator accuracy on real/fake samples on the right.

    Args:
        status_dict: Mapping with the keys ``'g_loss'``, ``'d_loss'``,
            ``'d_acc_real'`` and ``'d_acc_fake'``, each a sequence with one
            value per epoch.
        title: Heading drawn above both panels.
    """
    g_loss = status_dict['g_loss']
    d_loss = status_dict['d_loss']
    d_acc_real = status_dict['d_acc_real']
    d_acc_fake = status_dict['d_acc_fake']
    # The x axis is the zero-based position of each recorded epoch.
    x = np.arange(len(g_loss))

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 6))
    # `title` was previously accepted but never shown.
    fig.suptitle(title)

    loss_ax.plot(x, g_loss, label='G', linewidth=1.5)
    loss_ax.plot(x, d_loss, label='D', linewidth=1.5)
    loss_ax.legend(loc='upper right')
    loss_ax.set_xlabel('epoch')
    loss_ax.set_ylabel('loss')
    loss_ax.set_title("Loss")

    acc_ax.plot(x, d_acc_real, label='real', linewidth=1.5)
    acc_ax.plot(x, d_acc_fake, label='fake', linewidth=1.5)
    acc_ax.legend(loc='lower right')
    acc_ax.set_xlabel('epoch')
    acc_ax.set_ylabel('acc')
    acc_ax.set_title("Acc of D")

    plt.show()

In [None]:
# Plot the per-epoch losses and discriminator accuracies returned by train_gan.
show_status_graph_gan(status_dict)

生成データの確認

In [None]:
# Generate a melody for one batch with the trained generator and plot one
# sample from it.
# NOTE(review): this calls the loader object directly; a stock
# torch.utils.data.DataLoader is not callable — confirm the project loader
# defines __call__, otherwise use next(iter(dataloader)).
batch = dataloader()
batch_id = 0

# This is inspection only, so no autograd graph is needed.
# NOTE(review): G_trained is still in train mode after train_gan; call
# G_trained.eval() first if it has dropout/batch-norm layers — confirm.
with torch.no_grad():
    g_out = G_trained(batch.chord.to(device))
    g_melody = G_trained.to_ids(g_out).cpu().numpy()

# Pair the generated melody with the chord and metadata of the same sample.
bundle = Bundle({
    'melody': g_melody[batch_id],
    'chord': batch.chord[batch_id],
    'meta': batch.meta[batch_id]
})

# Widen the pitch range to the whole melody vocabulary before plotting.
bundle.meta.melody_pitch_range = [0, config.melody_vocab_size]
ppr = bundle.get_ppr()
grid_plot(ppr, beat_resolution=bundle.meta.beat_resolution)

通常のGANでは，GはDに対して手も足も出ない．  
Adaptive GTTを使えば改善できる可能性がある？