In [1]:
import argparse
import os

import IPython.display as ipd
import soundfile as sf
import torch
import torchaudio

from denoiser import pretrained
from denoiser.demucs import Demucs
from denoiser.demucs import DemucsStreamer

In [2]:
# Prefer the first CUDA GPU; fall back to the CPU when no GPU is visible.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [3]:
# Base URL the pretrained denoiser checkpoints are downloaded from.
ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/"
# Checkpoint of the real-time H=64 model.
MASTER_64_URL = f"{ROOT}master64-8a5dfb4bb92753dd.th"

def _demucs(pretrained, url, **kwargs):
    """Build a ``Demucs(**kwargs)`` model.

    When ``pretrained`` is truthy, the state dict at ``url`` is fetched through
    ``torch.hub`` (mapped onto the CPU) and loaded into the freshly built model.
    """
    net = Demucs(**kwargs)
    if not pretrained:
        return net
    weights = torch.hub.load_state_dict_from_url(url, map_location='cpu')
    net.load_state_dict(weights)
    return net

def get_model(model_path=None):
    """
    Load a local model package or the torch.hub pre-trained model.

    Args:
        model_path: optional path to a serialized Demucs package on disk. When
            ``None``, the pre-trained real-time H=64 model is downloaded (or taken
            from the torch.hub cache) instead.

    Returns:
        The Demucs model with its weights loaded onto the CPU; callers move it to
        their target device themselves.
    """
    if model_path is not None:
        # Imported lazily so the rest of the notebook does not depend on it.
        from denoiser.utils import deserialize_model

        print(f"Loading model from {model_path}")
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
        # machine. NOTE: torch.load unpickles the file -- only load trusted files.
        package = torch.load(model_path, map_location='cpu')
        return deserialize_model(package)

    print("Loading pre-trained real time H=64 model trained on DNS and Valentini.")
    return _demucs(True, MASTER_64_URL, hidden=64)


In [4]:
def get_estimate(model, noisy, streaming=False, dry=0):
    """
    Run the denoiser on noisy audio.

    Args:
        model: the denoising model. For the non-streaming path it is simply called
            on a (batch, channels, time) tensor.
        noisy: noisy audio, either 3-D (batch, channels, time) or 2-D. A 2-D tensor
            is read as (batch, time) -- the same convention Demucs.forward uses for
            2-D input -- and gets a singleton channel axis.
        streaming: if True, run a DemucsStreamer on the first batch item only.
        dry: dry/wet mix; 0 returns the pure estimate, 1 returns the input.

    Returns:
        The enhanced audio as a 3-D tensor (batch is 1 when streaming).
    """
    if noisy.dim() == 2:
        # (batch, time) -> (batch, 1, time). Without this, the dry mix below would
        # broadcast the model's (B, 1, T) output against (B, T) into (B, B, T),
        # and the streaming path would hand a 1-D tensor to DemucsStreamer.feed.
        noisy = noisy.unsqueeze(1)

    # Inference is pinned to a single thread, but the caller's thread count is
    # restored afterwards so the setting does not leak into the rest of the kernel.
    previous_threads = torch.get_num_threads()
    torch.set_num_threads(1)
    try:
        with torch.no_grad():
            if streaming:
                streamer = DemucsStreamer(model, dry=dry)
                estimate = torch.cat([
                    streamer.feed(noisy[0]),
                    streamer.flush()], dim=1)[None]
            else:
                estimate = model(noisy)
                estimate = (1 - dry) * estimate + dry * noisy
    finally:
        torch.set_num_threads(previous_threads)
    return estimate

In [5]:
def enhance(noisy_signal, model=None, streaming=False, dry=0):
    """
    Denoise `noisy_signal` with a Demucs model.

    Args:
        noisy_signal: noisy audio tensor, e.g. the (channels, time) waveform
            returned by torchaudio.load.
        model: an already-loaded model to reuse; when None, get_model() loads the
            default pre-trained model.
        streaming: forwarded to get_estimate.
        dry: forwarded to get_estimate.

    Returns:
        The enhanced audio tensor, located on `device`.
    """
    if model is None:
        model = get_model()
    # The input is moved to `device` below, so the model must live there too --
    # including a caller-supplied model that may still be on another device.
    model = model.to(device)
    model.eval()

    noisy_signal = noisy_signal.to(device)
    return get_estimate(model, noisy_signal, streaming=streaming, dry=dry)

In [26]:
def write(wav, filename, sr=16_000):
    """
    Write a waveform tensor to an audio file with soundfile.

    Args:
        wav: torch tensor shaped (time,) or (channels, time); it may live on any
            device and may require grad.
        filename: output path.
        sr: sample rate in Hz stored in the file.
    """
    # soundfile needs a CPU array, so detach and move the tensor off any GPU.
    wav = wav.detach().cpu()
    # Scale down only when the peak exceeds 1, so the written audio cannot clip;
    # quieter signals are written unchanged.
    wav = wav / max(wav.abs().max().item(), 1)
    data = wav.numpy()
    if data.ndim == 2:
        # torch audio is (channels, time); soundfile expects (frames, channels).
        data = data.T
    sf.write(filename, data, sr)

    
def save_wav(estimate, filename, out_dir, sr=16_000):
    """Save `estimate` into `out_dir` as ``<stem>_enhanced.wav``.

    ``<stem>`` is the base name of `filename` with its last extension removed.
    """
    stem = os.path.basename(filename).rsplit(".", 1)[0]
    target = os.path.join(out_dir, stem) + "_enhanced.wav"
    write(estimate, target, sr=sr)

In [8]:
# torchaudio.load yields the waveform tensor and the file's sample rate in Hz.
(noisy_signal, sr) = torchaudio.load('teste.wav')

In [9]:
# Demucs models are built with sample_rate=16_000 by default (matching the 16 kHz
# defaults of write/save_wav). Run the model at that rate, then resample the result
# back to the file's rate `sr`, so that later cells can play and save at `sr`.
MODEL_SAMPLE_RATE = 16_000
if sr != MODEL_SAMPLE_RATE:
    model_input = torchaudio.functional.resample(noisy_signal, sr, MODEL_SAMPLE_RATE)
    unoise_signal = torchaudio.functional.resample(enhance(model_input), MODEL_SAMPLE_RATE, sr)
else:
    unoise_signal = enhance(noisy_signal)


Loading pre-trained real time H=64 model trained on DNS and Valentini.


Downloading: "https://dl.fbaipublicfiles.com/adiyoss/denoiser/master64-8a5dfb4bb92753dd.th" to /home/fred/.cache/torch/hub/checkpoints/master64-8a5dfb4bb92753dd.th


HBox(children=(FloatProgress(value=0.0, max=134140945.0), HTML(value='')))




In [16]:
# Play the enhanced audio at the sample rate of the signal it came from (`sr`),
# rather than a hard-coded guess. A new name keeps `unoise_signal` untouched,
# so re-running this cell is idempotent.
enhanced_audio = unoise_signal.detach().to('cpu').squeeze().numpy()
ipd.Audio(enhanced_audio, rate=sr)


In [27]:
# Writes ./teste_enhanced.wav at the enhanced signal's sample rate `sr`.
save_wav(unoise_signal.to('cpu').squeeze(), 'teste.wav', './', sr=sr)

torch.Size([185856])
