In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
from hydra import compose, initialize
from hydra.utils import instantiate

from dac import DAC
from audiotools import AudioSignal
from IPython.display import Audio, display

# Hydra keeps process-wide global state, so calling `initialize` again in the same
# kernel (e.g. when this cell is re-run) raises ValueError. Treat that case as
# "already set up" so the notebook can be re-executed top to bottom.
try:
    initialize(version_base="1.3", config_path="config")
except ValueError:
    print("Hydra seems to be already initialized")

In [None]:
# Directory holding the downloaded checkpoints. Later cells read
# `ndac/<variant>/800k/dac/weights.pth` and `flowdec/<MODEL>/step=800000.ckpt` from it.
CKPT_DIR = 'checkpoints/'  # Adjust path as needed

# Choose the model variant
MODEL = 'flowdec_75m'  # flowdec_75m or flowdec_25s

# Fail fast with an actionable message instead of an opaque KeyError/compose error later.
AVAILABLE_MODELS = ('flowdec_75m', 'flowdec_25s')
assert os.path.isdir(CKPT_DIR), "CKPT_DIR seems to not exist on your system. Did you download our checkpoints and set this path correctly?"
assert MODEL in AVAILABLE_MODELS, f"Unknown MODEL {MODEL!r}; expected one of {AVAILABLE_MODELS}."

In [None]:
# Load the underlying (N)DAC codec that matches the chosen FlowDec variant.
ndac_model = {'flowdec_75m': 'ndac-75', 'flowdec_25s': 'ndac-25'}[MODEL]
dac_model = DAC.load(os.path.join(CKPT_DIR, f'ndac/{ndac_model}/800k/dac/weights.pth'))
dac_model.to('cuda')
dac_model.eval()

# Load the FlowDec model.
conf = compose(config_name=MODEL)
# NOTE(security): torch>=2.6 defaults to weights_only=True. That mode presumably rejects
# the non-tensor metadata that Lightning-style training checkpoints (like this one,
# with 'state_dict' / '_pl_ema_state_dict' entries) usually carry. So we opt out
# explicitly. weights_only=False unpickles arbitrary objects: only load checkpoints
# from a source you trust.
ckpt = torch.load(
    os.path.join(CKPT_DIR, f'flowdec/{MODEL}/step=800000.ckpt'),
    map_location='cpu',
    weights_only=False,
)

# The checkpoint stores raw weights under 'state_dict' and EMA weights under
# '_pl_ema_state_dict'. EMA weights are the default; set this to False to use the raw ones.
use_ema_weights = True
state_dict_key = '_pl_ema_state_dict' if use_ema_weights else 'state_dict'
if state_dict_key not in ckpt:
    raise KeyError(
        f"Checkpoint has no {state_dict_key!r} entry; available keys: {list(ckpt.keys())}"
    )

model = instantiate(conf['model'])
model.load_state_dict(ckpt[state_dict_key])
model.cuda()
model.eval()

In [None]:
# --- Inference configuration ---

# Audio file to enhance.
wav_path = "testfile.wav"
assert os.path.isfile(wav_path), "wav_path seems to not exist on your system. Did you set it correctly?"

# `nq`: how many quantizers (codebooks) the (N)DAC encoder uses.
#  * flowdec_75m was trained with nq in [10, 8, 6, 4], i.e. [7.5, 6.0, 4.5, 3.0] kbps.
#  * flowdec_25s was trained with nq = 10 only, i.e. 4.0 kbps.
nq = 10

# FlowDec ODE solver and number of solver steps.
# 'midpoint' is generally preferable to 'euler'; it evaluates the model twice per
# step (NFE = 2*N), so the default N=3 gives NFE=6, a good speed/quality tradeoff.
solver, N = 'midpoint', 3

# Run inference. You can ignore the printed message "Your vector field does not have `nn.Parameters` to optimize."
with torch.inference_mode():
    signal = AudioSignal(wav_path)
    # Input sample rate before resampling; kept in case you want to resample outputs back.
    sr_orig = signal.sample_rate
    # Both calls update `signal` in place (their return values are not used).
    signal.resample(dac_model.sample_rate)
    signal.to(dac_model.device)
    # Length of the (resampled) input, before any codec padding.
    n_samples = signal.audio_data.shape[-1]

    x = dac_model.preprocess(signal.audio_data, signal.sample_rate)
    z, codes, latents, _, _ = dac_model.encode(x, n_quantizers=nq)
    xhat_ndac = dac_model.decode(z)

    xhat_flowdec = model.enhance(xhat_ndac, N=N, solver=solver)

    # In upstream DAC, `preprocess` right-pads the input to a multiple of the hop length,
    # so decoded outputs can be longer than the input. Trim them so all three signals
    # line up (a no-op if they are not longer). Doing this before the peak check also
    # keeps padding out of the normalization decision.
    xhat_ndac = xhat_ndac[..., :n_samples]
    xhat_flowdec = xhat_flowdec[..., :n_samples]

    # Peak-normalize only if the enhanced output would clip outside [-1, 1].
    peak = xhat_flowdec.abs().max()
    if peak > 1.0:
        print("Prevented clipping")
        xhat_flowdec = xhat_flowdec / peak

In [None]:
# Audio players, in order: original input, (N)DAC reconstruction, FlowDec-enhanced output.
# All three use `signal.sample_rate`, which the inference cell set to the codec rate.
players = [
    Audio(audio.cpu()[0], rate=signal.sample_rate, normalize=False)
    for audio in (signal.audio_data, xhat_ndac, xhat_flowdec)
]
display(*players)