In [1]:
import pickle
import soundfile as sf
import IPython.display as ipd

import torch
import torch.nn as nn
import torch.optim as optim

from models.model_1 import Model_1
from magnitude_loss import MagnitudeLoss
from models.sub_pix import UNetWithSubpixel
from models.patch_gan_discriminator import PatchGANDiscriminator

from constants import *
from train_cgan import train_cgan
from dataset import build_data_loaders
from disk_utils import save_model, load_model
from predict import predict_polar, get_phases, make_wav
from plotter import plot_gan_loss, plot_heatmaps, plot_waves

In [2]:
# Toggle for GPU usage; CUDA is only selected when it is actually available.
USE_GPU = True
if torch.cuda.is_available() and USE_GPU:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Min/max statistics, indexed below by instrument, bound ("min"/"max") and
# feature part.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that were produced by this project.
with open("dataset/features/min_max.pkl", "rb") as handle:
    min_max = pickle.load(handle)

part = "db"
mini = min_max["ney"]["min"][part]
maxi = min_max["ney"]["max"][part]

# Keep the second loader under a real name: the prediction cell further down
# reads `test_data_loader`, which used to be discarded as `_` and therefore
# raised NameError on a fresh Restart & Run All.
# NOTE(review): with test_size = 0.0 the held-out loader is presumably empty;
# raise test_size if the prediction cells should see unseen data. TODO confirm
# against build_data_loaders.
test_size = 0.0
train_data_loader, test_data_loader = build_data_loaders(
    min_max, part=part, test_size=test_size)

In [6]:
# Learning rate shared by both Adam optimisers.
lr = 2e-4

# Build the two networks and place them on the selected device.
generator = UNetWithSubpixel(features=32)
generator = generator.to(device)
discriminator = PatchGANDiscriminator(in_channels=2)
discriminator = discriminator.to(device)

# Loss terms handed to train_cgan. Despite its name, `l1_loss` holds the
# project's MagnitudeLoss built from the min/max values loaded earlier.
adversarial_loss = nn.BCELoss()
l1_loss = MagnitudeLoss(mini, maxi)

# One optimiser per network.
optimizer_G = optim.Adam(params=generator.parameters(), lr=lr)
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=lr)

In [7]:
num_epochs = 50

# train_cgan hands back the generator together with a `history` object that
# the loss-plotting cell consumes afterwards.
generator, history = train_cgan(
    device,
    train_data_loader,
    generator,
    discriminator,
    adversarial_loss,
    l1_loss,
    optimizer_G,
    optimizer_D,
    num_epochs,
    0.8,  # NOTE(review): unnamed positional hyper-parameter; meaning not visible here, confirm against train_cgan
)

 - B: 200	 D: 0.109327	 G: 80.285378	 Gadv: 4.257643
 - B: 400	 D: 0.003657	 G: 42.192760	 Gadv: 5.915465
 - B: 600	 D: 0.011283	 G: 36.945305	 Gadv: 5.712330
 - B: 800	 D: 0.001141	 G: 36.403725	 Gadv: 7.910858
E: 001/50	 D: 0.085190	 G: 59.649465
 - B: 200	 D: 0.001280	 G: 31.222176	 Gadv: 9.226945
 - B: 400	 D: 0.000179	 G: 33.537182	 Gadv: 9.955525
 - B: 600	 D: 0.000474	 G: 33.931927	 Gadv: 8.550458
 - B: 800	 D: 0.171145	 G: 29.083483	 Gadv: 2.408933
E: 002/50	 D: 0.065845	 G: 33.387231
 - B: 200	 D: 0.240890	 G: 33.285950	 Gadv: 2.725527
 - B: 400	 D: 0.109944	 G: 37.931999	 Gadv: 4.681175
 - B: 600	 D: 0.454955	 G: 31.730644	 Gadv: 2.158812
 - B: 800	 D: 0.271027	 G: 34.431980	 Gadv: 3.084268
E: 003/50	 D: 0.226953	 G: 33.813724
 - B: 200	 D: 0.128857	 G: 32.387466	 Gadv: 2.611534
 - B: 400	 D: 0.444878	 G: 33.327080	 Gadv: 1.615031
 - B: 600	 D: 0.312119	 G: 36.760483	 Gadv: 1.969227
 - B: 800	 D: 0.433530	 G: 35.260288	 Gadv: 2.554440
E: 004/50	 D: 0.280387	 G: 33.965665
 - B

In [8]:
# Bring both networks back to host memory. The final expression is the
# discriminator, so its module summary is what the cell displays.
generator.cpu()
discriminator.cpu()

PatchGANDiscriminator(
  (model): Sequential(
    (0): Sequential(
      (0): Conv2d(2, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  

In [None]:
# Visualise the training history returned by train_cgan.
plot_gan_loss(
    history,
    "GAN",
    start=0,
)

In [None]:
# Ask PyTorch to release the GPU memory held by its CUDA caching allocator.
# `generator` and `discriminator` are deliberately NOT deleted here: later
# cells still run predictions with the generator and save both models, so
# removing the names made a top-to-bottom run fail with NameError.
torch.cuda.empty_cache()

In [None]:
# Upper bound forwarded as `limit` here and again when the phases are fetched.
pred_limit = 32

# `mini` / `maxi` are the same min_max["ney"][...][part] values looked up when
# the data loaders were built.
predictions, targets = predict_polar(
    generator, test_data_loader, mini, maxi,
    limit=pred_limit, from_db=part == "db",
)

In [None]:
# Compare the first prediction with its target.
plot_heatmaps(
    predictions[0],
    targets[0],
    from_db=part == "db",
)

In [None]:
# Build loaders for the "phase" part with the same split size; only the second
# loader is kept and passed to get_phases.
_, test_data_loader_phase = build_data_loaders(min_max,
                                               part="phase",
                                               test_size=test_size)
phases = get_phases(
    test_data_loader_phase, instrument="ney", limit=pred_limit)

In [None]:
# The first phase entry is passed for both arguments.
first_phase_pair = (phases[0], phases[0])
plot_heatmaps(*first_phase_pair)

In [None]:
# Build waveforms for the predictions and the targets with the same `phases`.
wave_prediction = make_wav(predictions, phases)
wave_target = make_wav(targets, phases)

# Report both lengths (prediction first, then target) before plotting.
print(*(len(wave) for wave in (wave_prediction, wave_target)))
plot_waves(wave_target, wave_prediction)

In [None]:
# Inline audio player for the target waveform.
ipd.Audio(data=wave_target, rate=SR)

In [None]:
# Inline audio player for the predicted waveform.
ipd.Audio(data=wave_prediction, rate=SR)

In [None]:
# Write both waveforms to disk as WAV files at sample rate SR.
sf.write(file="z_target_gan.wav", data=wave_target,
         samplerate=SR, format="wav")
sf.write(file="z_prediction_gan.wav", data=wave_prediction,
         samplerate=SR, format="wav")

In [9]:
# Save both networks through the project's save_model helper.
save_model(
    generator,
    "generator_sp_32_0_8_full",
)
save_model(
    discriminator,
    "discriminator_sp_32_0_8_full",
)