In [1]:
import sys

# Extend the import search path two directories up so the `academicodec`
# package imported below can be resolved (assumes the notebook sits two levels
# below the repository root — TODO confirm).
sys.path.append('../../')
import json
import torchaudio
import os
import librosa
import warnings
# Suppress FutureWarning and UserWarning messages for the rest of the session
# to keep cell output readable.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
import torch
from academicodec.models.hificodec.vqvae import VQVAE
# NOTE(review): `mel_spectrogram` is only referenced by the commented-out
# evaluation cells at the end of the notebook.
from academicodec.models.hificodec.meldataset import mel_spectrogram
from librosa.util import normalize

In [2]:
# --- Model configuration files (JSON) ---
pretrained_config_path = './config_24k_960d.json'
refiner_config_path = './config_refiner_24k_120d.json'

# --- Checkpoints ---
pretrained_path = '50m_logs_4cb_960d/step_110k'
# pretrained_path = '50m_logs_2cb_960d/g_00078000'
refiner_path = '10m_logs_refiner_64cb_240d/g_00015000'

# --- Audio input and output locations ---
wav_path = './sample.wav'
base_output_path = './base.wav'
refined_output_path = './refined.wav'

# Fail fast, before any model is loaded, if a required file is missing.
# The checks run in the same order as the variables are consumed later.
for required_path in (pretrained_config_path, refiner_config_path, pretrained_path):
    assert required_path and os.path.exists(required_path)

# The refiner checkpoint only has to exist when one is configured ...
if refiner_path:
    assert os.path.exists(refiner_path)

assert wav_path and os.path.exists(wav_path)

# ... but asking for a refined output makes the refiner checkpoint mandatory.
if refined_output_path:
    assert refiner_path and os.path.exists(refiner_path)

In [3]:
# Parse the JSON configuration of the base codec.
with open(pretrained_config_path, 'r') as config_file:
    pretrained_config = json.load(config_file)

# Parse the JSON configuration of the refiner.
with open(refiner_config_path, 'r') as config_file:
    refiner_config = json.load(config_file)

# The refiner config's `sampling_rate` is the sample rate used by every later
# cell for loading and saving audio.
sample_rate = refiner_config['sampling_rate']

In [4]:
# Read the input file, asking librosa for audio at the configured sample rate.
samples, sr = librosa.load(wav_path, sr=sample_rate)
print("wav.shape:",samples.shape)
assert sr == sample_rate

# Normalise with librosa's default settings and scale by 0.95
# (presumably to leave some headroom below full scale — confirm).
scaled_samples = normalize(samples) * 0.95

# Convert to a float tensor and prepend a leading dimension of size 1.
wav = torch.FloatTensor(scaled_samples).unsqueeze(0)

print(wav.size())

wav.shape: (240000,)
torch.Size([1, 240000])


In [5]:
print("Init model and load weights")

# Base codec: built from its JSON config and checkpoint, with the encoder
# included so raw audio can be encoded as well as decoded.
pretrained_model = VQVAE(
    pretrained_config_path,
    ckpt_path=pretrained_path,
    with_encoder=True)
pretrained_model.eval()

# Refiner: constructed the same way from its own config and checkpoint.
refiner_model = VQVAE(
    refiner_config_path,
    ckpt_path=refiner_path,
    with_encoder=True
)
# Both models are used for inference only, so the refiner is switched to
# evaluation mode just like the base model (previously it was left in
# training mode).
refiner_model.eval()

print("Model ready")

Init model and load weights


Model ready


In [6]:
# Round trip through the base model: waveform -> text targets -> waveform.
# This is inference only, so autograd is disabled to avoid building a graph
# and so that `y_hat` does not carry gradient state into later cells.
with torch.no_grad():
    text_targets = pretrained_model.wav_to_text_target(wav)
    y_hat = pretrained_model.text_target_to_wav(text_targets)

In [7]:
# Write the first item of the reconstructed batch to disk.
# The tensor is detached (and moved to CPU) first so saving does not depend on
# whether the reconstruction was computed with autograd enabled.
# The previous bare `y_hat.size()` statement was dropped: it was not the last
# expression of the cell, so its value was never displayed.
torchaudio.save('text_target_and_back.wav', y_hat[0].detach().cpu(), sample_rate, channels_first=True)

In [9]:
# Same reconstruction as the one-step round trip, but split in two stages:
# text targets -> latent image, then latent image -> waveform through the
# model's generator. Inference only, so autograd is disabled.
with torch.no_grad():
    latent_image = pretrained_model.text_target_to_latent_image(text_targets)
    print(latent_image.size())
    y_hat = pretrained_model.generator(latent_image)

# NOTE: this writes to the same path as the one-step round trip cell and
# therefore replaces that file.
torchaudio.save('text_target_and_back.wav', y_hat[0].detach().cpu(), sample_rate, channels_first=True)

torch.Size([1, 512, 250])


In [6]:
# Round trip through the base model via acoustic tokens, then pass the
# reconstruction through the refiner. Inference only, so autograd is disabled.
with torch.no_grad():
    acoustic_token = pretrained_model.wav_to_acoustic_token(wav)
    generated_output = pretrained_model.acoustic_token_to_wav(acoustic_token)

    # NOTE(review): the refiner is fed `generated_output[0]`, i.e. without the
    # leading dimension — confirm this matches the input shape it expects.
    refined_generated_output = refiner_model(generated_output[0])

torchaudio.save('debug.wav', generated_output[0].detach().cpu(), sample_rate, channels_first=True)

# Save the refiner's output. The original cell wrote `generated_output[0]`
# here as well, so 'debug_refined.wav' was a copy of 'debug.wav'.
# Assumes the refiner output has the same layout as `generated_output`
# (leading batch dimension) — TODO confirm.
torchaudio.save('debug_refined.wav', refined_generated_output[0].detach().cpu(), sample_rate, channels_first=True)

In [None]:
# Disabled cell: export of the base and refined reconstructions to the
# configured output paths.
# NOTE(review): `base_y_hat` and `refined_y_hat` are not assigned by any
# visible cell; define them (or switch to the current variable names) before
# re-enabling this code.
# if base_output_path:
#     torchaudio.save(base_output_path, base_y_hat[0], sample_rate, channels_first=True)
# if refined_output_path:
#     torchaudio.save(refined_output_path, refined_y_hat[0], sample_rate, channels_first=True)

In [None]:
# Disabled cell: MSE and L1 distance between the mel spectrogram of the input
# `wav` and that of the base reconstruction.
# NOTE(review): `base_y_hat` is not assigned by any visible cell. The
# positional arguments (1024, 80, 24000, 256, 1024, 0, 8000) are presumably
# n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax — confirm
# against `mel_spectrogram`; the 24000 is hard-coded rather than taken from
# `sample_rate`.
# print(torch.nn.functional.mse_loss(mel_spectrogram(
#     wav.squeeze(1).cpu(), 1024, 80,
#     24000, 256, 1024,
#     0, 8000), mel_spectrogram(
#     base_y_hat.squeeze(1).detach(), 1024, 80,
#     24000, 256, 1024,
#     0, 8000)).numpy())

# print(torch.nn.functional.l1_loss(mel_spectrogram(
#     wav.squeeze(1).cpu(), 1024, 80,
#     24000, 256, 1024,
#     0, 8000), mel_spectrogram(
#     base_y_hat.squeeze(1).detach(), 1024, 80,
#     24000, 256, 1024,
#     0, 8000)).numpy())

In [None]:
# Disabled cell: the same MSE and L1 mel-spectrogram comparison against the
# input `wav`, but for the refined reconstruction, guarded by `refiner_model`.
# NOTE(review): `refined_y_hat` is not assigned by any visible cell; define it
# (or switch to the current variable name) before re-enabling this code.
# if refiner_model:
#     print(torch.nn.functional.mse_loss(mel_spectrogram(
#         wav.squeeze(1).cpu(), 1024, 80,
#         24000, 256, 1024,
#         0, 8000), mel_spectrogram(
#         refined_y_hat.squeeze(1).detach(), 1024, 80,
#         24000, 256, 1024,
#         0, 8000)).numpy())

#     print(torch.nn.functional.l1_loss(mel_spectrogram(
#         wav.squeeze(1).cpu(), 1024, 80,
#         24000, 256, 1024,
#         0, 8000), mel_spectrogram(
#         refined_y_hat.squeeze(1).detach(), 1024, 80,
#         24000, 256, 1024,
#         0, 8000)).numpy())