In [None]:
import pickle
import soundfile as sf
import IPython.display as ipd

import torch
import torch.nn as nn
import torch.optim as optim
from models.unet_gen import UNetGenerator
from models.patch_gan_discriminator import PatchGANDiscriminator

from constants import *
from train_cgan import train_cgan
from polar_dataset import build_data_loaders
from disk_utils import save_model, load_model
from predict import predict_polar, get_phases, make_wav
from plotter import plot_loss, plot_heatmaps, plot_waves

In [None]:
# Run configuration shared by the cells below.
part = "magnitude"   # feature part used for the min/max lookup and the data loaders
USE_GPU = True       # set to False to force the CPU even when CUDA is available
num_epochs = 20
test_size = 0.1      # forwarded to build_data_loaders as `test_size`

device = torch.device("cuda" if torch.cuda.is_available() and USE_GPU
                      else "cpu")

# Shorten the run whenever training will happen on the CPU. The check is made
# on the device that was actually selected, so a machine without CUDA also
# gets the one-epoch run instead of 20 epochs on the CPU.
if device.type != "cuda":
    num_epochs = 1

In [None]:
# Load the min/max statistics that build_data_loaders and predict_polar use.
# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# this file when it comes from a trusted source.
with open("dataset/features/min_max.pkl", "rb") as handle:
    min_max = pickle.load(handle)

# Build the loaders for the configured `part` rather than a hard-coded
# "magnitude", so they stay consistent with the min_max[...][part] lookups
# made when predicting.
train_data_loader, test_data_loader = build_data_loaders(
    min_max, part=part, test_size=test_size)

In [None]:
# Learning rate shared by both Adam optimizers.
lr = 2e-4

# Networks: the generator is built with 1 input / 1 output channel, the
# discriminator with 2 input channels; both are moved to `device`.
generator = UNetGenerator(
    in_channels=1,
    out_channels=1,
).to(device)
discriminator = PatchGANDiscriminator(
    in_channels=2,
).to(device)

# One Adam optimizer per network, with identical hyper-parameters.
optimizer_G = optim.Adam(
    generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(
    discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Loss objects handed to train_cgan.
adversarial_loss = nn.BCELoss()
l1_loss = nn.L1Loss()

In [None]:
# Train for `num_epochs` epochs on the training loader.
train_cgan(
    device,
    train_data_loader,
    generator,
    discriminator,
    adversarial_loss,
    l1_loss,
    optimizer_G,
    optimizer_D,
    num_epochs,
)

In [None]:
# Move both networks back to the CPU.
generator.cpu()
discriminator.cpu()

In [None]:
# Release GPU memory held by PyTorch's caching allocator.
# The models are intentionally NOT deleted here: later cells still pass
# `generator` to predict_polar and both networks to save_model, so deleting
# them made the notebook fail with a NameError on a top-to-bottom run.
torch.cuda.empty_cache()

In [None]:
# Limit passed to predict_polar here and to get_phases further down.
pred_limit = 32

# Call predict_polar on the test loader, giving it the "ney" min/max values
# stored for the configured `part`.
predictions, targets = predict_polar(
    generator, test_data_loader,
    min_max["ney"]["min"][part], min_max["ney"]["max"][part],
    limit=pred_limit, from_db=False,
)

In [None]:
# Plot the first prediction next to its target.
plot_heatmaps(
    predictions[0],
    targets[0],
)

In [None]:
# Build the loaders again for the "phase" part; only the test loader is kept.
# NOTE(review): this assumes build_data_loaders produces the same train/test
# split as the earlier call, so that phases line up with the predictions.
# Confirm.
_, test_data_loader_phase = build_data_loaders(
    min_max,
    part="phase",
    test_size=test_size,
)

# Fetch the "ney" phases, using the same limit as the predictions.
phases = get_phases(
    test_data_loader_phase, instrument="ney", limit=pred_limit)

In [None]:
# Plot the first phase entry; the same array is passed for both arguments.
plot_heatmaps(
    phases[0],
    phases[0],
)

In [None]:
# Build one waveform from the predictions and one from the targets, both
# with the same phases.
wave_prediction = make_wav(predictions, phases)
wave_target = make_wav(targets, phases)

# Print both lengths (prediction first), then plot the two waveforms.
print(*(len(wave) for wave in (wave_prediction, wave_target)))
plot_waves(wave_target, wave_prediction)

In [None]:
# Audio player for the target waveform at sample rate SR.
ipd.Audio(data=wave_target, rate=SR)

In [None]:
# Audio player for the predicted waveform at sample rate SR.
ipd.Audio(data=wave_prediction, rate=SR)

In [None]:
# Write both waveforms to wav files in the working directory.
sf.write("z_target.wav", wave_target,
         samplerate=SR, format="wav")
sf.write("z_prediction.wav", wave_prediction,
         samplerate=SR, format="wav")

In [None]:
# Save both networks under the names "generator" and "discriminator".
save_model(
    generator,
    "generator",
)
save_model(
    discriminator,
    "discriminator",
)