In [1]:
import pickle
import soundfile as sf
import IPython.display as ipd

import torch
import torch.nn as nn
import torch.optim as optim

from models.model_1 import Model_1
from magnitude_loss import MagnitudeLoss
from models.sub_pix import UNetWithSubpixel
from models.patch_gan_discriminator import PatchGANDiscriminator

from constants import *
from train_cgan import train_cgan
from dataset import build_data_loaders
from disk_utils import save_model, load_model
from predict import predict_polar, get_phases, make_wav
from plotter import plot_gan_loss, plot_heatmaps, plot_waves

In [2]:
# Toggle for GPU usage; the GPU is only selected when CUDA is actually available.
USE_GPU = True
use_cuda = USE_GPU and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [3]:
# Load the per-instrument min/max statistics used for (de)normalisation.
# NOTE(security): pickle.load can execute arbitrary code — only load files
# produced by this project's own feature-extraction step.
with open("dataset/features/min_max.pkl", "rb") as handle:
    min_max = pickle.load(handle)

# Spectrogram component this notebook trains on, and its value range for "ney".
part = "db"
mini = min_max["ney"]["min"][part]
maxi = min_max["ney"]["max"][part]

# Keep BOTH loaders: the prediction cell further down reads `test_data_loader`,
# which was previously discarded here and therefore undefined on a fresh run.
# NOTE(review): test_size is 0.0, so the held-out split may be empty — confirm
# how build_data_loaders handles this before relying on test_data_loader.
test_size = 0.0
train_data_loader, test_data_loader = build_data_loaders(
    min_max, part=part, test_size=test_size)

In [8]:
# Optimiser hyper-parameters shared by the generator and the discriminator.
lr = 2e-4
adam_betas = (0.5, 0.999)

# Generator and its optimiser.
generator = UNetWithSubpixel(features=32).to(device)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)

# Discriminator and its optimiser.
discriminator = PatchGANDiscriminator(in_channels=2).to(device)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)

# Loss terms: BCE for the adversarial part, and MagnitudeLoss (built from the
# dataset min/max) for the reconstruction part. nn.L1Loss() was the earlier
# choice for `l1_loss` and is mentioned here for reference only.
adversarial_loss = nn.BCELoss()
l1_loss = MagnitudeLoss(mini, maxi)

In [9]:
# Run adversarial training; returns the trained generator and a loss history
# that the plotting cell consumes.
num_epochs = 50
generator, history = train_cgan(
    device,
    train_data_loader,
    generator,
    discriminator,
    adversarial_loss,
    l1_loss,
    optimizer_G,
    optimizer_D,
    num_epochs,
    0.5,  # NOTE(review): meaning is defined by train_cgan — presumably a loss weight; confirm
)

 - B: 200	 D: 0.135813	 G: 27.341438	 Gadv: 2.398979
 - B: 400	 D: 0.027003	 G: 21.736961	 Gadv: 3.930495
 - B: 600	 D: 0.139407	 G: 21.045315	 Gadv: 3.008072
 - B: 800	 D: 0.026792	 G: 18.559715	 Gadv: 3.828878
E: 001/50	 D: 0.197625	 G: 30.947995
 - B: 200	 D: 0.036708	 G: 18.333744	 Gadv: 3.332595
 - B: 400	 D: 0.167410	 G: 20.821486	 Gadv: 2.628448
 - B: 600	 D: 0.216279	 G: 22.721176	 Gadv: 3.401337
 - B: 800	 D: 0.249197	 G: 19.575676	 Gadv: 3.675259
E: 002/50	 D: 0.295046	 G: 20.945536
 - B: 200	 D: 0.208989	 G: 20.210993	 Gadv: 1.893925
 - B: 400	 D: 0.227576	 G: 20.545208	 Gadv: 1.802784
 - B: 600	 D: 0.541844	 G: 19.140469	 Gadv: 2.930863
 - B: 800	 D: 0.211731	 G: 21.680067	 Gadv: 2.748521
E: 003/50	 D: 0.394236	 G: 21.115144
 - B: 200	 D: 0.281411	 G: 21.913511	 Gadv: 2.469565
 - B: 400	 D: 0.334496	 G: 21.375322	 Gadv: 2.167852
 - B: 600	 D: 0.463213	 G: 19.666477	 Gadv: 0.765149
 - B: 800	 D: 0.003670	 G: 22.816332	 Gadv: 5.698107
E: 004/50	 D: 0.323602	 G: 22.763780
 - B

In [10]:
# Move both networks to the CPU after training. The final expression is kept
# as the discriminator call so the cell still displays its module summary.
cpu_device = torch.device("cpu")
generator.to(cpu_device)
discriminator.to(cpu_device)

PatchGANDiscriminator(
  (model): Sequential(
    (0): Sequential(
      (0): Conv2d(2, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  

In [None]:
# Plot the loss history returned by train_cgan, labelled "GAN".
# NOTE(review): start=0 presumably means "plot from the first recorded entry" — confirm in plotter.
plot_gan_loss(history, "GAN", start=0)

In [None]:
# Release cached CUDA memory now that training has finished.
# The models are intentionally NOT deleted here: `generator` is passed to
# predict_polar and both models are passed to save_model in later cells, so
# deleting them makes a top-to-bottom run of the notebook fail with NameError.
torch.cuda.empty_cache()

In [None]:
# Run the generator on the test loader and collect predictions with their
# matching targets, using the "ney" min/max bounds for the selected part.
# `test_data_loader` must already exist in the kernel; make sure the cell that
# calls build_data_loaders for `part` binds it.
pred_limit = 32
ney_min = min_max["ney"]["min"][part]
ney_max = min_max["ney"]["max"][part]
is_db = part == "db"
predictions, targets = predict_polar(
    generator, test_data_loader, ney_min, ney_max,
    limit=pred_limit, from_db=is_db)

In [None]:
# Show the first prediction next to its target; from_db is True when part == "db".
plot_heatmaps(predictions[0], targets[0], from_db=(part == "db"))

In [None]:
# Build a second pair of loaders for the "phase" part (same test_size) and keep
# only the test loader, then extract up to `pred_limit` phase items for "ney".
_, test_data_loader_phase = build_data_loaders(min_max, part="phase", test_size=test_size)
phases = get_phases(test_data_loader_phase, instrument="ney", limit=pred_limit)

In [None]:
# Visual check of the first phase item; the same array is passed for both
# arguments because there is no separate predicted phase to compare against.
plot_heatmaps(phases[0], phases[0])

In [None]:
# Rebuild waveforms from the predicted and target items, both paired with the
# same `phases`, then report their lengths and plot them together.
wave_prediction = make_wav(predictions, phases)
wave_target = make_wav(targets, phases)
print(f"{len(wave_prediction)} {len(wave_target)}")
plot_waves(wave_target, wave_prediction)

In [None]:
# Inline audio player for the target waveform.
# NOTE(review): SR is not defined in this notebook — presumably the sample rate star-imported from `constants`; confirm.
ipd.Audio(wave_target, rate=SR)

In [None]:
# Inline audio player for the predicted waveform, at the same rate as the target player.
ipd.Audio(wave_prediction, rate=SR)

In [None]:
# Write the target and predicted waveforms to WAV files in the working directory.
wav_outputs = (
    ("z_target_gan.wav", wave_target),
    ("z_prediction_gan.wav", wave_prediction),
)
for wav_path, wave in wav_outputs:
    sf.write(wav_path, wave, SR, format="wav")

In [11]:
# Persist both trained networks under their run-specific checkpoint names.
checkpoints = (
    (generator, "generator_sp_32_0_5"),
    (discriminator, "discriminator_sp_32_0_5"),
)
for model, checkpoint_name in checkpoints:
    save_model(model, checkpoint_name)