In [None]:
import pickle
import soundfile as sf
import IPython.display as ipd

import torch
import torch.nn as nn
import torch.optim as optim

from models.model_1 import Model_1
from magnitude_loss import MagnitudeLoss
from models.patch_gan_discriminator import PatchGANDiscriminator

from constants import *
from train_cgan import train_cgan
from dataset import build_data_loaders
from disk_utils import save_model, load_model
from predict import predict_polar, get_phases, make_wav
from plotter import plot_gan_loss, plot_heatmaps, plot_waves

In [None]:
USE_GPU = True
# Fixed seed so weight initialisation and any torch-driven shuffling/splitting
# are repeatable across Restart & Run All.
SEED = 42

# torch.manual_seed seeds the CPU generator and the RNG of every CUDA device.
torch.manual_seed(SEED)

# Prefer the GPU only when it is both requested and actually available.
device = torch.device(
    "cuda" if USE_GPU and torch.cuda.is_available() else "cpu")

In [None]:
# NOTE: pickle.load can execute arbitrary code - only load trusted, locally
# generated statistics files.
with open("dataset/features/min_max.pkl", "rb") as stats_file:
    min_max = pickle.load(stats_file)

part = "db"  # which feature representation the model is trained on

# Per-part min / max statistics of the "ney" features.
mini, maxi = (min_max["ney"][bound][part] for bound in ("min", "max"))

test_size = 0.05  # presumably the fraction of samples held out for testing
train_data_loader, test_data_loader = build_data_loaders(
    min_max, part=part, test_size=test_size)

In [None]:
# Conditional-GAN networks: a Model_1 generator (1 channel in, 1 channel out) and
# a PatchGAN discriminator over 2 input channels (presumably the conditioning
# input stacked with a real/generated output - confirm).
# Kept in generator-then-discriminator order in case their initialisation draws
# from the global RNG.
generator = Model_1(in_channels=1, out_channels=1, base_features=32).to(device)
discriminator = PatchGANDiscriminator(in_channels=2).to(device)

# nn.BCELoss expects probabilities in [0, 1]; this assumes the discriminator's
# output already goes through a sigmoid - TODO confirm against its forward().
adversarial_loss = nn.BCELoss()
# Reconstruction term. Despite the name (kept because the training cell passes
# it on), this is MagnitudeLoss parameterised by the mini/maxi bounds, not a
# plain nn.L1Loss.
l1_loss = MagnitudeLoss(mini, maxi)

# One Adam optimizer per network, sharing the same learning rate.
lr = 2e-4
optimizer_G, optimizer_D = (
    optim.Adam(network.parameters(), lr=lr)
    for network in (generator, discriminator))

In [None]:
# Train generator and discriminator together; train_cgan hands back the
# generator along with a loss history that is plotted further down.
# NOTE(review): the meaning of the final positional argument (1) is not visible
# here - check train_cgan's signature and consider passing it by keyword.
num_epochs = 50
generator, history = train_cgan(
    device,
    train_data_loader,
    generator,
    discriminator,
    adversarial_loss,
    l1_loss,
    optimizer_G,
    optimizer_D,
    num_epochs,
    1,
)

In [None]:
# Move both networks off the GPU. The last expression's value (the
# discriminator module) is displayed as this cell's output.
cpu_device = torch.device("cpu")
generator.to(cpu_device)
discriminator.to(cpu_device)

In [None]:
# Loss curves recorded by train_cgan (start=0 presumably plots from the first
# recorded epoch).
gan_plot_title = "GAN"
plot_gan_loss(history, gan_plot_title, start=0)

In [None]:
# Release unoccupied memory held by PyTorch's CUDA caching allocator now that
# the networks no longer need to live on the GPU.
# The generator and discriminator are deliberately kept: the prediction cell
# (predict_polar) and the final save_model cell both use them, so deleting them
# here makes a top-to-bottom run fail with NameError.
# NOTE(review): optimizer_G / optimizer_D were built over the GPU-resident
# parameters and may still hold Adam state tensors on the GPU; empty_cache()
# cannot release memory that is still referenced.
torch.cuda.empty_cache()

In [None]:
# Run the trained generator over the test loader (capped by pred_limit), passing
# the same mini/maxi bounds the reconstruction loss was built with.
# NOTE(review): confirm predict_polar puts the generator in eval() mode.
pred_limit = 32
predictions, targets = predict_polar(
    generator, test_data_loader, mini, maxi,
    limit=pred_limit, from_db=(part == "db"))

In [None]:
# Compare the first prediction against its target.
first_prediction, first_target = predictions[0], targets[0]
plot_heatmaps(first_prediction, first_target)

In [None]:
# Build loaders over the phase features; only the test split is kept, so its
# phases can be combined with the predictions when making waveforms.
# NOTE(review): this assumes build_data_loaders returns the same test samples in
# the same order on every call, so phases[i] lines up with predictions[i] - confirm.
_, test_data_loader_phase = build_data_loaders(
    min_max, test_size=test_size, part="phase")
phases = get_phases(test_data_loader_phase, limit=pred_limit, instrument="ney")

In [None]:
# Visualise the first phase item. plot_heatmaps takes two inputs, so the same
# array is passed for both.
# NOTE(review): if a single-panel view is intended, a dedicated helper would be clearer.
first_phase = phases[0]
plot_heatmaps(first_phase, first_phase)

In [None]:
# Turn predictions and targets into waveforms using the same phases, report
# their lengths, and overlay them.
wave_prediction, wave_target = (
    make_wav(features, phases) for features in (predictions, targets))
print(*(len(wave) for wave in (wave_prediction, wave_target)))
plot_waves(wave_target, wave_prediction)

In [None]:
# Inline player for the target waveform at sample rate SR.
ipd.Audio(data=wave_target, rate=SR)

In [None]:
# Inline player for the predicted waveform at sample rate SR.
ipd.Audio(data=wave_prediction, rate=SR)

In [None]:
# Write the target and predicted waveforms to WAV files in the working directory.
for wav_path, waveform in (("z_target.wav", wave_target),
                           ("z_prediction.wav", wave_prediction)):
    sf.write(wav_path, waveform, SR, format="wav")

In [None]:
# Persist both networks under their names using the project's save_model helper.
for model_name, model in (("generator", generator),
                          ("discriminator", discriminator)):
    save_model(model, model_name)