# Supervoice Vocoder evaluation

This notebook gives you an opportunity to test vocoder. This notebook is runnin on CPU just in case if your GPUs are busy training.

In [1]:
import torch
import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as F
from IPython.display import Audio, display
import matplotlib.pyplot as plt

In [2]:
# Model loading
model = torch.hub.load(repo_or_dir='ex3ndr/supervoice-vocoder', model='bigvsan')

Using cache found in /home/steve/.cache/torch/hub/ex3ndr_supervoice-vocoder_master


## Load file
You can provide any reasonable audio file to check re-synthesing it from Mel Spectogram

In [3]:
from ipywidgets import FileUpload
upload = FileUpload(multiple=False)
upload

FileUpload(value=(), description='Upload')

In [4]:
def load_mono_audio(path):
    # Load audio
    audio, sr = torchaudio.load(path)

    # Resample
    if sr != 24000:
        audio = resampler(sr, 24000, device)(audio)
        sr = 24000

    # Convert to mono
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # Convert to single dimension
    audio = audio[0]

    return audio

# Load
if len(upload.value) == 1:
    with open("eval_vocoder.out", "w+b") as i:
        i.write(upload.value[0].content)
    source = load_mono_audio("eval_vocoder.out")
else:
    source = load_mono_audio("./sample.wav")

## Mel Spectogram
This code shows how to configure mel spectogram that is compatible with this vocoder

In [5]:
def spectogram(src):
    # Hann Window
    window = torch.hann_window(1024)

    # STFT
    stft = torch.stft(src, 
        n_fft = 1024, 
        hop_length = 256, 
        win_length = 1024,
        window = window, 
        center = True,
        return_complex = True
    )

    # magnitudes = stft[..., :-1].abs() ** 2 # Power
    magnitudes = stft[..., :-1].abs() # Amplitude (used by BigVSAN)

    # Mel Bank
    mel_filters = F.melscale_fbanks(
            n_freqs=int(1024 // 2 + 1),
            sample_rate=24000,
            f_min=0,
            f_max=12000,
            n_mels=100,
            norm="slaney",
            mel_scale="slaney"
    ).transpose(-1, -2)
    mel_spec = (mel_filters @ magnitudes)

    # Log
    log_spec = torch.clamp(mel_spec, min=1e-10).log()

    return log_spec

spec = spectogram(source)

## Resynthesing
This code synthesizes sound back, first audio is the source audio (rescaled to match vocoder parameters) and recreated audio

In [7]:
resynth = model(spec.unsqueeze(0)).detach().squeeze(0)
display(Audio(data=source, rate=24000))
display(Audio(data=resynth, rate=24000))

tensor([[-0.0010, -0.0024, -0.0026,  ..., -0.0025, -0.0023, -0.0023]])


In [67]:
e0 = torch.load("./bigvsan.pt", map_location='cpu')['generator']
e1 = torch.load("./bigvsan_no_norm_ok.pt", map_location='cpu')

def manual_remove_weight_norm(state_dict):
    for key in list(state_dict.keys()):  # list(...) to make a copy of keys
        if 'weight_v' in key:
            layer_base = key.rsplit('.', 1)[0]  # Get the base name of the layer
            g_key = f'{layer_base}.weight_g'
            if g_key in state_dict:
                v = state_dict[key]
                g = state_dict[g_key]
                # Compute the recombined weight with proper norm calculation
                norm = v.norm(2, dim=-1, keepdim=True)  # Assuming last dim for norm
                recombined_weight = v * (g / norm.expand_as(v))
                # Update the original weight and remove g and v
                state_dict[layer_base + '.weight'] = recombined_weight
                del state_dict[key]  # Remove weight_v
                del state_dict[g_key]  # Remove weight_g
    return state_dict

e00 = manual_remove_weight_norm(torch.load("./bigvsan.pt", map_location='cpu')['generator'])

In [68]:
for k in e1:
    v = (e1[k] - e00[k]).abs().max()
    if v > 0:
        print(k, v)

conv_pre.weight tensor(1.5578)
ups.0.0.weight tensor(1.5023)
ups.1.0.weight tensor(1.7622)
ups.2.0.weight tensor(1.6154)
ups.3.0.weight tensor(0.5248)
ups.4.0.weight tensor(0.3663)
ups.5.0.weight tensor(0.4274)
resblocks.0.convs1.0.weight tensor(1.8410)
resblocks.0.convs1.1.weight tensor(1.2006)
resblocks.0.convs1.2.weight tensor(1.1234)
resblocks.0.convs2.0.weight tensor(1.1003)
resblocks.0.convs2.1.weight tensor(1.4996)
resblocks.0.convs2.2.weight tensor(2.5157)
resblocks.1.convs1.0.weight tensor(2.2503)
resblocks.1.convs1.1.weight tensor(1.0618)
resblocks.1.convs1.2.weight tensor(1.4948)
resblocks.1.convs2.0.weight tensor(1.2837)
resblocks.1.convs2.1.weight tensor(2.2779)
resblocks.1.convs2.2.weight tensor(3.1590)
resblocks.2.convs1.0.weight tensor(1.7225)
resblocks.2.convs1.1.weight tensor(1.4418)
resblocks.2.convs1.2.weight tensor(1.1042)
resblocks.2.convs2.0.weight tensor(1.4634)
resblocks.2.convs2.1.weight tensor(2.0598)
resblocks.2.convs2.2.weight tensor(2.1984)
resblocks.3.con

In [77]:
v = e0['conv_post.weight_v']
g = e0['conv_post.weight_g']
v_norm = torch.norm(v, p=2, dim=list(range(1, v.dim())), keepdim=True)
# v_norm = v_norm.expand_as(v)
cc = g * (v / v_norm)

print("source")
print(v)
print(g)
print(v_norm)
print(cc)

print("gen")
print(e1['conv_post.weight'])
print(e2['conv_post.weight'])
print(e00['conv_post.weight'])


source
tensor([[[-0.0308, -0.0057,  0.0640,  0.0099, -0.0444, -0.0029,  0.0103],
         [-0.0226, -0.0342, -0.0365, -0.0237,  0.0246,  0.0494,  0.0345],
         [ 0.0065, -0.0014, -0.0038, -0.0020, -0.0052,  0.0123, -0.0054],
         [-0.0122, -0.0232, -0.0290, -0.0345, -0.0298, -0.0241, -0.0140],
         [ 0.0027, -0.0020, -0.0177,  0.0298, -0.0022, -0.0302,  0.0192],
         [-0.0073, -0.0278,  0.0079,  0.0642,  0.0122, -0.0334, -0.0153],
         [-0.0308, -0.0542, -0.0613, -0.0726, -0.0612, -0.0549, -0.0334],
         [-0.0057,  0.0011,  0.0467, -0.0799,  0.0451, -0.0020, -0.0052],
         [ 0.0311, -0.0107, -0.0583, -0.0124,  0.0530,  0.0199, -0.0230],
         [-0.0086,  0.0416, -0.0683,  0.0493, -0.0143, -0.0045,  0.0044],
         [ 0.0170,  0.0303,  0.0361,  0.0265, -0.0302, -0.0474, -0.0320],
         [ 0.0071,  0.0005, -0.0527,  0.0878, -0.0431, -0.0066,  0.0070],
         [ 0.0207,  0.0510, -0.0053, -0.0791, -0.0562,  0.0274,  0.0419],
         [ 0.0134,  0.0303,  0.