In [2]:
# Restrict this kernel to a single GPU. This has to run before torch initialises
# CUDA (i.e. before any `.cuda()` call further down); adjust the index for your machine.
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_VISIBLE_DEVICES=3


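# Inference process of WaveGrad

This notebook does four things:
- loads the latest WaveGrad checkpoint trained with the 50-iteration configuration;
- draws a few utterances from the test set;
- resynthesizes each one from its mel spectrogram;
- plays the generated audio inline.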

In [2]:
import json
import random
import sys

import IPython.display as ipd
import numpy as np
import torch
from tqdm import tqdm

# Make the parent directory importable so the project-local modules below resolve.
sys.path.insert(0, '..')

import utils
from model import WaveGrad
from data import AudioDataset
from train import MelSpectrogramFixed

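**Load configuration**

The file paths stored in the config are prefixed with `../` so that they resolve from this notebook's directory.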

In [3]:
# Which experiment configuration to run inference with (path relative to this notebook).
CONFIG_NAME = '50iters'
CONFIG_PATH = f'../configs/{CONFIG_NAME}.json'

In [4]:
# Parse the JSON config into the project's attribute-style wrapper.
with open(CONFIG_PATH) as config_file:
    config = utils.ConfigWrapper(**json.load(config_file))

# The stored paths are relative to the parent directory, so prefix each one with '../'.
for path_key in ('logdir', 'train_filelist_path', 'test_filelist_path'):
    relative_path = getattr(config.training_config, path_key)
    setattr(config.training_config, path_key, f'../{relative_path}')

config

{'model_config': {'noise_schedule': {'n_iter': 50, 'betas_range': [0.0001, 0.05]}, 'factors': [5, 5, 3, 2, 2], 'upsampling_preconv_out_channels': 768, 'upsampling_out_channels': [512, 512, 256, 128, 128], 'upsampling_dilations': [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]], 'downsampling_preconv_out_channels': 32, 'downsampling_out_channels': [128, 128, 256, 512], 'downsampling_dilations': [[1, 2, 4], [1, 2, 4], [1, 2, 4], [1, 2, 4]]}, 'data_config': {'sample_rate': 22050, 'n_fft': 1024, 'win_length': 1024, 'hop_length': 300, 'f_min': 80.0, 'f_max': 8000, 'n_mels': 80}, 'training_config': {'logdir': '../logs/50iters', 'continue_training': False, 'train_filelist_path': '../filelists/train.txt', 'test_filelist_path': '../filelists/test.txt', 'batch_size': 48, 'segment_length': 7200, 'lr': 0.001, 'grad_clip_threshold': 1, 'scheduler_step_size': 1, 'scheduler_gamma': 0.9, 'n_epoch': 100000000, 'n_samples_to_test': 4, 'test_interval': 1}}

**Initialize the model**

In [5]:
# Build the vocoder from the loaded config, move it to the GPU and report its size.
model = WaveGrad(config)
model = model.cuda()
n_params = model.nparams
print(f'Number of parameters: {n_params}')

Number of parameters: 15810401


In [6]:
# Restore weights from the newest checkpoint found in the training log directory.
# load_latest_checkpoint returns three values; only the model is used here.
# NOTE(review): binding `_` shadows IPython's last-output variable in this session.
model, _, _ = utils.load_latest_checkpoint(config.training_config.logdir, model)

Latest checkpoint: ../logs/50iters/checkpoint_3780.pt


**Initialize the test dataset and the mel-spectrogram transform**

In [7]:
# Non-training (test) split of the audio data.
dataset = AudioDataset(config, training=False)

# Mel-spectrogram front end, parameterised entirely by the data section of the
# config and placed on the GPU next to the model.
data_cfg = config.data_config
mel_fn = MelSpectrogramFixed(
    sample_rate=data_cfg.sample_rate,
    n_fft=data_cfg.n_fft,
    win_length=data_cfg.win_length,
    hop_length=data_cfg.hop_length,
    f_min=data_cfg.f_min,
    f_max=data_cfg.f_max,
    n_mels=data_cfg.n_mels,
    window_fn=torch.hann_window,
).cuda()

In [8]:
# Reproducibility: both the choice of test utterances and the diffusion sampling
# below are stochastic. Which generator each one draws from is not visible here,
# so every common RNG is seeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

TEST_BATCH_SIZE = 4

# Draw TEST_BATCH_SIZE utterances from the test set.
test_batch = dataset.sample_test_batch(TEST_BATCH_SIZE)

**Inference**

In [9]:
STORE_INTERMEDIATE_STATE = False

# Inference only: put the model in eval mode (disables train-time behaviour such
# as dropout / batch-norm updates, if the model has any).
model.eval()

test_preds = []
# Run without autograd so no computation graph is kept alive for the generated signals.
with torch.no_grad():
    for test_sample in tqdm(test_batch):
        # `[None]` prepends a leading axis of size 1 (presumably the batch axis) before the mel transform.
        mel = mel_fn(test_sample[None].cuda())
        outputs = model.sample_subregions_parallel(
            mel,
            store_intermediate_states=STORE_INTERMEDIATE_STATE
        )
        test_preds.append(outputs)

100%|██████████| 4/4 [00:24<00:00,  6.22s/it]


In [10]:
# Render every generated waveform as an inline audio player at the config's sample rate.
playback_rate = config.data_config.sample_rate
for waveform in test_preds:
    samples = waveform.squeeze().cpu()
    ipd.display(ipd.Audio(samples, rate=playback_rate))