In [60]:
import torch
import pickle
import librosa
import numpy as np
import soundfile as sf
from constants import *
import IPython.display as ipd
from disk_utils import load_model
from pathlib import Path, PurePath
from plotter import plot_loss, plot_heatmaps, plot_waves

In [53]:
# Per-instrument audio folders. These two names are not referenced by the
# visible cells; the export loop builds the same paths from AUDIO_DIR.
NEY_AUDIO_FEATURES_DIR = "dataset/audio_features/ney/"
GTR_AUDIO_FEATURES_DIR = "dataset/audio_features/gtr/"

In [54]:
# Sorted lists of the sub-directories found directly under each instrument's
# feature folder (plain files are skipped).
# NOTE(review): GTR_FEATURE_DIR / NEY_FEATURE_DIR are not defined in the
# visible cells -- presumably they come from `from constants import *`; confirm.
gtr_feature_paths = sorted(
    entry for entry in Path(GTR_FEATURE_DIR).iterdir() if entry.is_dir())

ney_feature_paths = sorted(
    entry for entry in Path(NEY_FEATURE_DIR).iterdir() if entry.is_dir())

In [55]:
# Load the pickled min/max statistics that the export loop uses to scale and
# un-scale the dB spectrograms.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open("dataset/features/min_max.pkl", "rb") as stats_file:
    min_max = pickle.load(stats_file)

In [57]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Load the saved generator and move it onto the chosen device.
# NOTE(review): confirm that load_model returns the network in eval mode;
# nothing in this notebook calls model.eval().
model = load_model("generator_sp_32_0_8_full")
model = model.to(device)

In [None]:
def _sorted_chunk_names(feature_dir):
    """Return the names of the ``chunk*`` entries inside ``feature_dir``.

    The names are ordered numerically by their second underscore-separated
    field (``name.split("_")[1]``), so e.g. index 10 sorts after index 9.
    """
    return sorted(
        (entry.name for entry in feature_dir.iterdir()
         if entry.name.startswith("chunk")),
        key=lambda name: int(name.split("_")[1]))


# Kept as a notebook-level name for any later cell that reads it; it is no
# longer used to build file indices (see the running counter below).
num_dirs = len(gtr_feature_paths)
AUDIO_DIR = "dataset/audio_features/"

# Make sure the output folders exist before any wave file is written.
for instrument in ("gtr", "ney"):
    Path(AUDIO_DIR, instrument).mkdir(parents=True, exist_ok=True)

# Loop-invariant min/max statistics used for (un-)scaling the dB values.
gtr_db_min = min_max["gtr"]["min"]["db"]
gtr_db_range = min_max["gtr"]["max"]["db"] - gtr_db_min
ney_db_min = min_max["ney"]["min"]["db"]
ney_db_range = min_max["ney"]["max"]["db"] - ney_db_min

# Running output index. The previous `i * num_dirs + j` formula produced the
# same index for different (directory, chunk) pairs whenever a directory held
# more chunks than there are directories, silently overwriting wave files.
next_idx = 0

# NOTE: zip() stops at the shorter sequence, so unmatched directories or
# chunks on either side are skipped without warning.
for gtr_fp, ney_fp in zip(gtr_feature_paths, ney_feature_paths):
    gtr_files = _sorted_chunk_names(gtr_fp)
    ney_files = _sorted_chunk_names(ney_fp)

    for gtr_file, ney_file in zip(gtr_files, ney_files):
        # NOTE: pickle.load can execute arbitrary code -- trusted files only.
        with open(gtr_fp / gtr_file, "rb") as handle:
            gtr_chunk = pickle.load(handle)

        with open(ney_fp / ney_file, "rb") as handle:
            ney_chunk = pickle.load(handle)

        gtr_phase = gtr_chunk["phase"]
        # 0 - 1 scale using the guitar statistics
        gtr_db = (gtr_chunk["db"] - gtr_db_min) / gtr_db_range

        with torch.no_grad():
            # Two leading singleton axes are added before the model call
            # (presumably batch and channel -- TODO confirm against the model).
            model_input = torch.from_numpy(
                np.array([np.expand_dims(gtr_db, axis=0)],
                         dtype=np.float32)).to(device)
            predicted_db = model(model_input)[0]
            predicted_db = predicted_db.cpu().numpy().squeeze(axis=0)

        # un-scale using the ney statistics
        predicted_db = predicted_db * ney_db_range + ney_db_min

        # back to magnitude
        predicted_mag = librosa.db_to_amplitude(predicted_db)

        # reconstruct the predicted signal, re-using the guitar chunk's phase
        pred_signal = librosa.istft(predicted_mag * np.exp(1j * gtr_phase),
                                    n_fft=N_FFT, hop_length=HOP)

        # overload protection: normalise only when the peak exceeds 1.0
        signal_max = np.max(np.abs(pred_signal))
        if signal_max > 1.0:
            pred_signal = pred_signal / signal_max

        # write wave files: the prediction under gtr/, the ney chunk's stored
        # signal under ney/, both sharing the same index
        idx = next_idx
        next_idx += 1
        file_trail = f"{idx:03d}.wav"
        sf.write(AUDIO_DIR + "gtr/gtr_" + file_trail,
                 pred_signal, SR, format="wav")
        ney_signal = ney_chunk["signal"]
        sf.write(AUDIO_DIR + "ney/ney_" + file_trail,
                 ney_signal, SR, format="wav")

    print(str(gtr_fp), "done!")