In [1]:
import pickle
import random

import IPython.display as ipd
import numpy as np
import soundfile as sf
import torch
from torch.optim import Adam

from constants import *
from train import train_model
from early_stopper import EarlyStopper

from phase_loss import PhaseLoss
from magnitude_loss import MagnitudeLoss

from polar_dataset import build_data_loaders
from disk_utils import save_model, load_model
from plotter import plot_loss, plot_heatmaps, plot_waves
from predict import predict_polar, get_phases, make_wav

from models.model_18 import Model_18

In [None]:
# Run configuration.
# USE_GPU = True  -> full run: 250 epochs on CUDA when available, else on CPU.
# USE_GPU = False -> quick CPU smoke test: a single epoch, even if CUDA exists.
USE_GPU = False
SEED = 0

device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")
num_epochs = 250 if USE_GPU else 1

# Seed every global RNG the pipeline may draw from (model weight init,
# presumably also the train/test split and DataLoader shuffling) so that
# Restart & Run All reproduces the same results.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

In [2]:
# Train the magnitude half of the polar (magnitude / phase) representation.
part = "magnitude"

# Normalisation statistics produced during feature extraction, indexed as
# min_max[instrument]["min" | "max"][part].
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("dataset/features/min_max.pkl", "rb") as min_max_file:
    min_max = pickle.load(min_max_file)

# Order matters for reproducibility: model construction and loader building
# may both draw from the global RNG, so they stay in this sequence.
model = Model_18()

train_data_loader, test_data_loader = build_data_loaders(
    min_max, part=part, test_size=0.1
)

ney_stats = min_max["ney"]
criterion = MagnitudeLoss(ney_stats["min"][part], ney_stats["max"][part])

optimizer = Adam(model.parameters(), lr=3e-4)
# NOTE(review): positional args are presumably (patience, min_delta) — confirm
# against EarlyStopper's signature.
es = EarlyStopper(7, 3e-6)

In [None]:
# Fit the magnitude model. `history` holds what plot_loss visualises later;
# `es` presumably ends training early once validation loss stops improving.
model, history = train_model(
    model, criterion, optimizer, device,
    train_data_loader, test_data_loader, es,
    num_epochs=num_epochs,
)

In [None]:
model.to("cpu")  # bring the trained weights back to host memory for inference/plotting

In [None]:
# Release cached CUDA memory. The trained model is already on the CPU and
# must be kept: predict_polar below still needs it, so it is NOT deleted here.
# empty_cache() is a no-op when CUDA was never initialised (CPU-only runs).
torch.cuda.empty_cache()

In [None]:
loss_title = "Loss"
plot_loss(history, loss_title, start=0)

In [None]:
# Predicted vs. target magnitudes for up to 12 test examples. The min/max
# stats are presumably used by predict_polar to undo normalisation.
# NOTE(review): confirm predict_polar switches the model to eval mode / no_grad.
# The limit must match the one used for get_phases so the lists line up.
ney_min = min_max["ney"]["min"][part]
ney_max = min_max["ney"]["max"][part]
predictions, targets = predict_polar(
    model, test_data_loader, ney_min, ney_max, limit=12
)

In [None]:
first_prediction, first_target = predictions[0], targets[0]
plot_heatmaps(first_prediction, first_target)

In [None]:
# Only magnitude is modelled here; ground-truth phases from the phase test
# split are combined with both the predicted and the target magnitudes.
# NOTE(review): this assumes build_data_loaders yields the same test split,
# in the same order, for part="phase" as for part="magnitude". If the split
# is random and unseeded, predictions[i] and phases[i] may be different clips.
_unused_phase_train_loader, test_data_loader_phase = build_data_loaders(
    min_max, part="phase", test_size=0.1
)
phases = get_phases(test_data_loader_phase, instrument="ney", limit=12)

In [None]:
# No phase is predicted, so the same ground-truth phase fills both panels.
first_phase = phases[0]
plot_heatmaps(first_phase, first_phase)

In [None]:
# Resynthesise audio from each magnitude set paired with the shared phases.
wave_prediction, wave_target = (
    make_wav(magnitudes, phases) for magnitudes in (predictions, targets)
)

# Sample counts of both signals, shown before they are overlaid.
print(len(wave_prediction), len(wave_target))
plot_waves(wave_target, wave_prediction)

In [None]:
ipd.Audio(data=wave_target, rate=SR)  # ground-truth magnitudes + phases

In [None]:
ipd.Audio(data=wave_prediction, rate=SR)  # predicted magnitudes + ground-truth phases

In [None]:
# Persist both resynthesised signals for listening outside the notebook.
for wav_path, wave in (
    ("z_target.wav", wave_target),
    ("z_prediction.wav", wave_prediction),
):
    sf.write(wav_path, wave, SR, format="wav")