In [1]:
import torch
import torch.nn.functional as F
import gc
from IPython.display import Audio, display
from supervoice_enhance.model import EnhanceModel 
from supervoice_enhance.audio import load_mono_audio, spectogram
from supervoice_enhance.config import config
from training.audio import do_reverbrate

In [2]:
# Inference device. CPU is forced here; swap in the commented line to use CUDA when available.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

# Vocoder used by do_vocoder() to generate audio from spectrograms.
vocoder = torch.hub.load(repo_or_dir='ex3ndr/supervoice-vocoder', model='bigvsan')

# Two separate instances of the same hub model: `flow` is handed to EnhanceModel,
# `direct_flow` is kept as-is and sampled directly for comparison.
# Only the first call forces a refresh of the hub cache; the second reuses the
# copy that was just downloaded instead of fetching the same archive again.
flow = torch.hub.load(repo_or_dir='ex3ndr/supervoice-flow', model='flow', force_reload=True)
direct_flow = torch.hub.load(repo_or_dir='ex3ndr/supervoice-flow', model='flow')

# Fine-tuned enhancement weights; loaded onto CPU first and moved with the model below.
checkpoint = torch.load('./output/ft-01.pt', map_location="cpu")
enhance = EnhanceModel(flow, config)
enhance.load_state_dict(checkpoint['model'])

# Move every model to the target device and switch to evaluation mode.
direct_flow.to(device)
direct_flow.eval()
vocoder.to(device)
vocoder.eval()
enhance.to(device)
enhance.eval()
print("OK")

Using cache found in /home/steve/.cache/torch/hub/ex3ndr_supervoice-vocoder_master
Downloading: "https://github.com/ex3ndr/supervoice-flow/zipball/main" to /home/steve/.cache/torch/hub/main.zip
Downloading: "https://github.com/ex3ndr/supervoice-flow/zipball/main" to /home/steve/.cache/torch/hub/main.zip


OK


In [3]:
def do_vocoder(src):
    """Run the global `vocoder` on `src` with autograd disabled and return what it generates."""
    with torch.no_grad():
        generated = vocoder.generate(src)
    return generated

def do_enhance(src, steps = 8):
    """Sample an enhanced spectrogram from the global `enhance` model.

    The input is normalized with `config.audio.norm_mean` / `norm_std`, passed to
    `enhance.sample` as float32, and the result is mapped back to the original
    scale.

    Args:
        src: Spectrogram tensor to enhance (callers in this notebook pass the
            transposed output of `spectogram` - exact layout not verified here).
        steps: Number of sampling steps forwarded to `enhance.sample`.

    Returns:
        The de-normalized prediction as a float32 tensor.
    """
    # Inference only: disable autograd (as do_vocoder does) so no graph is kept
    # across the sampling steps.
    with torch.no_grad():
        normalized = (src - config.audio.norm_mean) / config.audio.norm_std
        pr = enhance.sample(source = normalized.to(torch.float32), steps = steps)
        return ((pr * config.audio.norm_std) + config.audio.norm_mean).to(torch.float32)

def do_flow(src, steps = 8):
    """Sample from the global `direct_flow` model, conditioned on `src`.

    The input is normalized with `config.audio.norm_mean` / `norm_std`, passed to
    `direct_flow.sample` as float32, and the result is mapped back to the
    original scale.

    Args:
        src: Spectrogram tensor (callers in this notebook pass the transposed
            output of `spectogram` - exact layout not verified here).
        steps: Number of sampling steps forwarded to `direct_flow.sample`.

    Returns:
        The de-normalized prediction as a float32 tensor.
    """
    # Inference only: disable autograd (as do_vocoder does) so no graph is kept
    # across the sampling steps.
    with torch.no_grad():
        normalized = (src - config.audio.norm_mean) / config.audio.norm_std
        # sample() returns a pair; only the first element is used here.
        pr, _ = direct_flow.sample(audio = normalized.to(torch.float32), steps = steps)
        return ((pr * config.audio.norm_std) + config.audio.norm_mean).to(torch.float32)

In [4]:
# Clean reference utterance
source_file = "./external_datasets/libritts-r/test-clean/1221/135766/1221_135766_000015_000000.wav"
source_raw = load_mono_audio(source_file, sample_rate = config.audio.sample_rate)

# Degrade it by applying a room impulse response
rir_file = "./external_datasets/rir-1/files/00000008.wav"
rir = load_mono_audio(rir_file, config.audio.sample_rate)
distorted = do_reverbrate(source_raw, rir)

# Extend the distorted signal on the right to 5 seconds' worth of samples
target_length = 5 * config.audio.sample_rate
current_length = distorted.shape[0]
padding_length = target_length - current_length
source = F.pad(distorted, (0, padding_length), mode='constant')

# Spectrogram of the padded signal, using the audio settings from the config
spec = spectogram(
    source,
    sample_rate = config.audio.sample_rate,
    n_fft = config.audio.n_fft,
    n_hop = config.audio.hop_size,
    n_window = config.audio.win_size,
    n_mels = config.audio.n_mels,
    mel_norm = config.audio.mel_norm,
    mel_scale = config.audio.mel_scale,
)

# Round-trip through the vocoder: add a leading batch axis, generate, drop it again
source_rec = do_vocoder(
    spec.to(device).unsqueeze(0)
).squeeze(0)

In [5]:
# Play the clean recording, then the vocoder reconstruction of the distorted one.
display(
    Audio(data=source_raw.cpu(), rate=config.audio.sample_rate),
    Audio(data=source_rec.cpu(), rate=config.audio.sample_rate),
)

In [None]:
# Free unreachable objects before the sampling run.
gc.collect()

# Enhance the spectrogram with 32 sampling steps (the model is fed the transposed layout).
enhanced = do_enhance(spec.to(device).transpose(0, 1), steps = 32)

# Transpose back, vocode with a temporary batch axis, and play the result.
enhanced_rec = do_vocoder(
    enhanced.transpose(0, 1).to(device).unsqueeze(0)
).squeeze(0)
display(Audio(data=enhanced_rec.cpu(), rate=config.audio.sample_rate))

In [7]:
# Baseline: sample from the unmodified flow model with 8 steps.
# Move the spectrogram to `device` first (as the enhance cell does) so this cell
# keeps working when `device` is switched away from CPU.
flow_out = do_flow(spec.to(device).transpose(0, 1), 8)

# Transpose back, vocode with a temporary batch axis, and play the result.
flow_out_rec = do_vocoder(flow_out.transpose(0, 1).to(device).unsqueeze(0)).squeeze(0)
display(Audio(data=flow_out_rec.cpu(), rate=config.audio.sample_rate))