In [None]:
!wget https://hf.co/danjacobellis/walloc/resolve/main/Stereo_Li_27c_test2.pth

In [None]:
!wget https://huggingface.co/danjacobellis/LCCL/resolve/main/vss_walloc_20x.pth

In [1]:
import io
import torch
import torchaudio
import torchvision
import numpy as np
from datasets import load_dataset
from walloc import walloc
from IPython.display import display, Audio, Image, update_display, HTML
from torchvision.transforms import ToPILImage
from fastprogress.fastprogress import master_bar, progress_bar
from ioae import IsotropicOobleckAutoencoder
from spauq.core.metrics import spauq_eval
from einops import rearrange
import cdpam
# NOTE(review): nothing below instantiates Config directly; it is presumably defined so that
# torch.load(..., weights_only=False) can unpickle the `config` objects stored in the
# checkpoints (pickled as a class named Config) — confirm before removing.
class Config:
    """Bare attribute container: fields are attached dynamically (``cfg.name = value``)."""

In [2]:
def compress(x_valid, codec):
    """Encode a waveform batch into the codec's latent representation.

    The input is halved, decomposed with ``codec.J`` levels of the codec's
    wavelet analysis, and the coefficients are passed through ``codec.encoder``.
    ``decompress`` undoes the halving on the way back.

    Args:
        x_valid: waveform tensor; presumably (batch, channels, samples) as built
            in ``eval_vss`` — TODO confirm for other callers.
        codec: a walloc ``Codec1D``-like object exposing ``J``,
            ``wavelet_analysis`` and ``encoder``.

    Returns:
        Whatever ``codec.encoder`` produces for the wavelet coefficients.
    """
    half_scale = x_valid / 2
    coefficients = codec.wavelet_analysis(half_scale, J=codec.J)
    return codec.encoder(coefficients)

def decompress(compressed_data, codec):
    """Decode codec latents back to a waveform (inverse of ``compress``).

    Pipeline: ``codec.decoder`` -> wavelet synthesis over ``codec.J`` levels ->
    ``codec.post`` -> ``codec.clamp``, and finally the result is doubled to undo
    the ``/ 2`` applied in ``compress``.

    Args:
        compressed_data: latent tensor, e.g. the output of ``compress`` or of the
            separation model.
        codec: a walloc ``Codec1D``-like object exposing ``J``, ``decoder``,
            ``wavelet_synthesis``, ``post`` and ``clamp``.

    Returns:
        The reconstructed waveform at the original (pre-``compress``) scale.
    """
    latent_features = codec.decoder(compressed_data)
    waveform = codec.wavelet_synthesis(latent_features, codec.J)
    waveform = codec.post(waveform)
    waveform = codec.clamp(waveform)
    return 2 * waveform

In [3]:
device = "cuda"

# Perceptual (CDPAM) distance model; used as a metric in eval_vss.
cdpam_loss = cdpam.CDPAM()

# --- Pre-trained WaLLoC 1-D audio codec ------------------------------------
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects, which can
# execute code from the file; only load checkpoints from a source you trust.
codec_checkpoint = torch.load("Stereo_Li_27c_test2.pth", map_location="cpu", weights_only=False)
codec_config = codec_checkpoint['config']
codec_hparams = (
    "channels", "J", "Ne", "Nd", "latent_dim",
    "latent_bits", "lightweight_encode", "post_filter",
)
codec = walloc.Codec1D(
    **{name: getattr(codec_config, name) for name in codec_hparams}
).to(device)
codec.load_state_dict(codec_checkpoint['model_state_dict'])
codec.eval();

# --- Vocal-separation model operating on the codec's latents ---------------
checkpoint = torch.load("vss_walloc_20x.pth", map_location="cpu", weights_only=False)
config = checkpoint['config']
model = IsotropicOobleckAutoencoder(
    channels=codec_config.latent_dim,
    patch_size=config.patch_size,
    embed_dim=config.embed_dim,
    depth=config.depth,
).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# post filter: replace the codec weights with the copy saved alongside the separation
# model (presumably the jointly trained post filter — confirm).
codec.load_state_dict(checkpoint['codec_state_dict'])
model.eval()

# Parameter count of the separation model, in millions (displayed as the cell output).
sum(param.numel() for param in model.parameters()) / 1e6

  state = torch.load(modfolder,map_location="cpu")['state']
  WeightNorm.apply(module, name, dim)


6.71166

In [4]:
# Validation split of the vocal-separation dataset; each row carries 'audio_mix' and
# 'audio_vocal' audio (read by eval_vss) plus a 'path_mix' field.
valid_dataset = load_dataset("danjacobellis/musdb18hq_vss", split="validation")

In [5]:
def _split_into_frames(signal, n_frames, frame_len, n_channels=2):
    """Cut a (channels, samples) waveform into ``n_frames`` consecutive, non-overlapping frames.

    Returns a float32 tensor of shape (n_frames, n_channels, frame_len). Samples beyond
    ``n_frames * frame_len`` are discarded. A mono signal broadcasts across ``n_channels``,
    matching a per-frame slice assignment into a zero buffer.
    """
    frames = torch.zeros((n_frames, n_channels, frame_len))
    usable = signal[:, :n_frames * frame_len]
    # Raises if `signal` is shorter than n_frames * frame_len (e.g. a truncated stem).
    frames[:] = rearrange(usable, 'C (B L) -> B C L', B=n_frames, L=frame_len)
    return frames


def eval_vss(sample):
    """Separate vocals from one track in the compressed domain and score the estimate.

    The mixture is cut into ``config.length_samples``-long frames (trailing samples that
    do not fill a whole frame are dropped). Each frame is encoded with ``codec``, mapped by
    the separation ``model``, and decoded back to audio. The estimate is then compared with
    the reference vocal stem cut into the same frames.

    Args:
        sample: dataset row holding encoded audio under ``sample['audio_mix']['bytes']``
            and ``sample['audio_vocal']['bytes']``.

    Returns:
        dict with ``PSNR`` (dB, peak amplitude 1), ``SSDR`` and ``SRDR`` (the ``'SSR'``
        and ``'SRR'`` entries of ``spauq_eval``), and ``CDPAM`` (mean CDPAM distance).

    Raises:
        ValueError: if the mixture and vocal sample rates differ, or the track is
            shorter than one frame.
    """
    frame_len = config.length_samples
    with torch.no_grad():
        xb, fs_mix = torchaudio.load(sample['audio_mix']['bytes'])
        vb, fs = torchaudio.load(sample['audio_vocal']['bytes'])
        if fs_mix != fs:
            raise ValueError(
                f"sample-rate mismatch: mixture {fs_mix} Hz vs vocals {fs} Hz"
            )

        n_frames = xb.shape[-1] // frame_len
        if n_frames == 0:
            raise ValueError(
                f"track has {xb.shape[-1]} samples, fewer than one frame ({frame_len})"
            )
        x = _split_into_frames(xb, n_frames, frame_len).to(device)
        v = _split_into_frames(vb, n_frames, frame_len).to(device)

        xc = compress(x, codec)
        pred = model(xc)
        v_hat = decompress(pred, codec)

        # NOTE(review): decompress() already doubles its output (undoing compress's / 2),
        # yet the reference is halved here. Unless the model was trained to predict
        # half-amplitude vocals, v and v_hat differ in scale by 2x — confirm against training.
        v = rearrange(v, 'B C L -> C (B L)') / 2
        v_hat = rearrange(v_hat, 'B C L -> C (B L)')
        # Optional variant (disabled): rescale the estimate to the reference's std.
        # v_hat = (v.std()/v_hat.std())*v_hat

        # PSNR in dB assuming a peak amplitude of 1.
        PSNR = -10 * np.log10(torch.nn.functional.mse_loss(v, v_hat).item())
        SDR = spauq_eval(v.cpu(), v_hat.cpu(), fs=fs)
        # NOTE(review): the channels are passed as CDPAM's batch dimension at the track's
        # native rate `fs`; confirm this matches the input format CDPAM expects.
        cdpam_dist = cdpam_loss.forward(v, v_hat).mean().item()

    return {
        'PSNR': PSNR,
        'SSDR': SDR['SSR'],
        'SRDR': SDR['SRR'],
        'CDPAM': cdpam_dist,
    }

In [6]:
eval_results = valid_dataset.map(eval_vss)

# Metrics-only view used for the summary statistics below (drops the audio payloads and path).
eval_only = eval_results.remove_columns(
    ['audio_mix', 'audio_vocal', 'path_mix']
)

Map:   0%|          | 0/235 [00:00<?, ? examples/s]



In [7]:
# without normalization (the std-rescaling line in eval_vss is commented out).
# Medians over tracks, printed in order: PSNR, SSDR, SRDR, then CDPAM as -10*log10(CDPAM)
# so that smaller CDPAM values map to larger numbers.
for metric in ('PSNR', 'SSDR', 'SRDR'):
    print(np.median(eval_only[metric]))
print(np.median(-10*np.log10(eval_only['CDPAM'])))

30.5087168627729
5.270039860281203
-2.4720733922734865
35.89614139790771


In [8]:
# without normalization (the std-rescaling line in eval_vss is commented out).
# Means over tracks, printed in order: PSNR, SSDR, SRDR, then CDPAM as -10*log10(CDPAM)
# so that smaller CDPAM values map to larger numbers.
for metric in ('PSNR', 'SSDR', 'SRDR'):
    print(np.mean(eval_only[metric]))
print(np.mean(-10*np.log10(eval_only['CDPAM'])))

30.878920732075407
5.409454004264736
-16.355997063462834
36.10839206162546


In [9]:
# Publish the full per-track results (audio columns plus metrics) as the validation split.
eval_results.push_to_hub(
    "danjacobellis/MUSDB_vss_walloc_20x",
    split='validation',
)

Uploading the dataset shards:   0%|          | 0/8 [00:00<?, ?it/s]

Map:   0%|          | 0/30 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/30 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/30 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/danjacobellis/MUSDB_vss_walloc_20x/commit/896cf6282b4126355a3bae2deec3c9c7840430a5', commit_message='Upload dataset', commit_description='', oid='896cf6282b4126355a3bae2deec3c9c7840430a5', pr_url=None, pr_revision=None, pr_num=None)