In [1]:
import torch
import torchaudio
from IPython.display import Audio, display
from ipywidgets import FileUpload

from supervoice_enhance.config import config
from supervoice_enhance.model import EnhanceModel
from supervoice_enhance.wrapper import SuperVoiceEnhance

# Loading model

In [2]:
# Pick the GPU when CUDA is available, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fetch the "enhance" entry point from the hub repository, then place it on
# the chosen device and switch it to evaluation mode
model = torch.hub.load('ex3ndr/supervoice-enhance', 'enhance')
model.to(device)
model.eval()
print("OK")

Using cache found in /home/steve/.cache/torch/hub/ex3ndr_supervoice-enhance_main
Using cache found in /home/steve/.cache/torch/hub/ex3ndr_supervoice-vocoder_master
Using cache found in /home/steve/.cache/torch/hub/ex3ndr_supervoice-flow_main


OK


# Loading file
Upload a custom audio file below, or leave the upload empty to use the bundled sample.

In [3]:
# File picker for an optional custom recording; single-file selection only.
# FileUpload is imported with the other dependencies in the first cell.
upload = FileUpload(multiple=False)
upload

FileUpload(value=(), description='Upload')

In [4]:
def load_mono_audio(path, sample_rate=None):
    """Load an audio file as a single-channel, one-dimensional waveform.

    Args:
        path: Location of the audio file; handed unchanged to
            ``torchaudio.load``.
        sample_rate: Target sample rate for the returned waveform. When
            omitted, the rate of the notebook-level ``model`` is used
            (``model.sample_rate``).

    Returns:
        The waveform with its leading (channel) axis removed. Input whose
        sample rate differs from the target is resampled, and input with
        more than one channel is averaged across channels.
    """
    if sample_rate is None:
        # Fall back to the globally loaded model so existing calls keep working
        sample_rate = model.sample_rate

    # Load audio
    audio, sr = torchaudio.load(path)

    # Resample only when the file's rate differs from the target
    if sr != sample_rate:
        audio = torchaudio.transforms.Resample(sr, sample_rate)(audio)

    # Convert to mono by averaging over the first axis
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # Convert to single dimension
    return audio[0]

# Pick the input: exactly one uploaded file wins, otherwise use the bundled sample
if len(upload.value) != 1:
    source = load_mono_audio("./eval/eval_2.wav")
else:
    # Spill the uploaded bytes to disk so the loader can read them by path
    with open("eval.out", "w+b") as uploaded_file:
        uploaded_file.write(upload.value[0].content)
    source = load_mono_audio("eval.out")

# Bring the clip to 5 seconds: shorter clips are padded at the end.
# NOTE(review): for clips longer than 5 seconds padding_length is negative,
# which pad() is expected to treat as trimming the tail - confirm.
target_length = 5 * model.sample_rate
current_length = source.shape[0]
padding_length = target_length - current_length
source = torch.nn.functional.pad(source, (0, padding_length), mode='constant')

# Play the prepared input clip
display(Audio(data=source, rate=model.sample_rate))

# Enhance

In [5]:
# Run the enhancer twice on the same clip, first with 8 steps and then with 32
enhanced_8_step, enhanced_32_step = (
    model.enhance(source, steps=step_count) for step_count in (8, 32)
)

# Play both results, in the same order they were produced
for enhanced_clip in (enhanced_8_step, enhanced_32_step):
    display(Audio(data=enhanced_clip, rate=model.sample_rate))