In [22]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


# Inference process of WaveGrad

In [23]:
import sys
sys.path.insert(0, '..')

import json
import IPython.display as ipd

import torch

from tqdm import tqdm

import utils
from model import WaveGrad
from data import AudioDataset, MelSpectrogramFixed

**Load configuration**

In [24]:
CONFIG_PATH='../configs/default.json'

with open(CONFIG_PATH) as f:
    config = utils.ConfigWrapper(**json.load(f))
config.training_config.logdir = f'../{config.training_config.logdir}'
config.training_config.train_filelist_path = f'../{config.training_config.train_filelist_path}'
config.training_config.test_filelist_path = f'../{config.training_config.test_filelist_path}'
config

{'model_config': {'factors': [5, 5, 3, 2, 2], 'upsampling_preconv_out_channels': 768, 'upsampling_out_channels': [512, 512, 256, 128, 128], 'upsampling_dilations': [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]], 'downsampling_preconv_out_channels': 32, 'downsampling_out_channels': [128, 128, 256, 512], 'downsampling_dilations': [[1, 2, 4], [1, 2, 4], [1, 2, 4], [1, 2, 4]]}, 'data_config': {'sample_rate': 22050, 'n_fft': 1024, 'win_length': 1024, 'hop_length': 300, 'f_min': 80.0, 'f_max': 8000, 'n_mels': 80}, 'training_config': {'logdir': '../logs/default', 'continue_training': False, 'train_filelist_path': '../filelists/train.txt', 'test_filelist_path': '../filelists/test.txt', 'batch_size': 48, 'segment_length': 7200, 'lr': 0.001, 'grad_clip_threshold': 1, 'scheduler_step_size': 1, 'scheduler_gamma': 0.9, 'n_epoch': 100000000, 'n_samples_to_test': 4, 'test_interval': 1, 'training_noise_schedule': {'n_iter': 1000, 'betas_range': [1e-06, 0.01]}, 'test_noise_sched

**Initialize the model**

In [25]:
model = WaveGrad(config).cuda()
print(f'Number of parameters: {model.nparams}')

Number of parameters: 15810401


In [26]:
model, _, _ = utils.load_latest_checkpoint(config.training_config.logdir, model)

Latest checkpoint: ../logs/default/checkpoint_2430.pt


**Initialize the dataset**

In [27]:
dataset = AudioDataset(config, training=False)
mel_fn = MelSpectrogramFixed(
    sample_rate=config.data_config.sample_rate,
    n_fft=config.data_config.n_fft,
    win_length=config.data_config.win_length,
    hop_length=config.data_config.hop_length,
    f_min=config.data_config.f_min,
    f_max=config.data_config.f_max,
    n_mels=config.data_config.n_mels,
    window_fn=torch.hann_window
).cuda()

In [28]:
TEST_BATCH_SIZE=1

# Sample test batch from test set 
test_batch = dataset.sample_test_batch(TEST_BATCH_SIZE)

**Set noise schedule**

In [31]:
SCHEDULES={
    '1000': {'n_iter': 1000, 'betas_range': (1e-6, 0.01)},
    '50': {'n_iter': 50, 'betas_range': (1e-6, 0.01)},
    '25': {'n_iter': 25, 'betas_range': (1e-6, 0.01)},
    '6': {'n_iter': 6, 'betas_range': (1e-6, 0.01)}
}
SHEDULE_TO_SET='6'

In [32]:
model.set_new_noise_schedule(
    n_iter=SCHEDULES[SHEDULE_TO_SET]['n_iter'],
    betas_range=SCHEDULES[SHEDULE_TO_SET]['betas_range']
)

**Inference**

In [33]:
STORE_INTERMEDIATE_STATES=False

test_preds = []
for test_sample in tqdm(test_batch):
    mel = mel_fn(test_sample[None].cuda())
    outputs = model.sample(
        mel,
        store_intermediate_states=STORE_INTERMEDIATE_STATES
    )
    test_preds.append(outputs)

100%|██████████| 1/1 [00:00<00:00,  4.09it/s]


In [34]:
for signal in test_preds:
    ipd.display(ipd.Audio(signal.squeeze().cpu(), rate=config.data_config.sample_rate))

**Compute RTF**

In [35]:
from benchmark import estimate_average_rtf_on_filelist

In [36]:
rtf_stats = estimate_average_rtf_on_filelist(
    '../filelists/test.txt', config, model, verbose=True
)
rtf_stats

100%|██████████| 100/100 [00:35<00:00,  2.80it/s]

DEVICE: cuda:0. average_rtf=0.04843389221471901, std=0.0034350928764161864





{'rtfs': [0.06296673097826086,
  0.04721706818181818,
  0.046860896551724145,
  0.04679387749003985,
  0.04584513807531381,
  0.046751732590529255,
  0.04686109090909091,
  0.04917216818181818,
  0.047020308739255015,
  0.04671767071320183,
  0.0629121059602649,
  0.04805081114327062,
  0.04591394360385144,
  0.045648056603773586,
  0.04623591,
  0.04660519382022472,
  0.05044431927710843,
  0.04572951854974705,
  0.04790507272727273,
  0.05283974137931034,
  0.053012107594936714,
  0.04935762096774194,
  0.04957418502202643,
  0.04783632515337423,
  0.04878446202531646,
  0.05290257216494845,
  0.04684448039215686,
  0.0524715625,
  0.04709630132450331,
  0.04642452336448598,
  0.04676413092979127,
  0.0462585448943662,
  0.04690550764818355,
  0.046106943750000004,
  0.05122658910891089,
  0.046242908955223885,
  0.04661723031496062,
  0.04577087562189055,
  0.04588364654002714,
  0.04707576642335766,
  0.048482760383386586,
  0.04639666448445171,
  0.05084658870967742,
  0.046113390