In [1]:
from datasets import load_dataset
from utils import mp3_compress, opus_compress, encodec_compress
from utils import hf_audio_encode
from demucs.separate import Separator
import encodec
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import tempfile
import museval

2023-10-24 17:59:08.658248: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-24 17:59:08.677114: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-24 17:59:08.677129: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-24 17:59:08.677145: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-10-24 17:59:08.681304: I tensorflow/core/platform/cpu_feature_g

In [2]:
# Run EnCodec on the GPU when one is available; fall back to CPU so the
# notebook still executes end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# 48 kHz EnCodec model with a 6 kbps target bandwidth (hence the "_48_6" suffix).
encodec_model_48_6 = encodec.EncodecModel.encodec_model_48khz()
encodec_model_48_6.set_target_bandwidth(6)
encodec_model_48_6.to(device);  # trailing ';' suppresses the module repr in Jupyter

In [3]:
def _compress_mixture(sample, compress_fn, *codec_args):
    """Round-trip a MUSDB sample's mixture through an audio codec.

    Shared implementation behind the ``*_compress_musdb`` wrappers, meant to
    be passed to ``datasets.Dataset.map``.

    Args:
        sample: dataset row whose ``'mixture'`` entry is a decoded audio dict
            holding ``'array'`` and ``'sampling_rate'``.
        compress_fn: codec helper called as
            ``compress_fn(audio, fs, *codec_args)`` and returning a
            ``(decoded_audio, bps)`` pair.
        *codec_args: extra positional arguments forwarded to ``compress_fn``
            (e.g. a model and a device).

    Returns:
        The same ``sample`` dict, mutated in place: ``'mixture'`` is replaced
        by the re-encoded decoded audio and ``'bps'`` holds the value the
        codec reported (presumably bits per second -- TODO confirm in utils).
    """
    mixture = sample['mixture']
    fs = mixture['sampling_rate']
    # Convert straight to float32 in one step.
    audio = torch.tensor(mixture['array'], dtype=torch.float32)
    audio, bps = compress_fn(audio, fs, *codec_args)
    sample['mixture'] = hf_audio_encode(audio, fs)
    sample['bps'] = bps
    return sample


def mp3_compress_musdb(sample):
    """Replace ``sample['mixture']`` with its MP3 round-trip and record ``'bps'``."""
    return _compress_mixture(sample, mp3_compress)


def opus_compress_musdb(sample):
    """Replace ``sample['mixture']`` with its Opus round-trip and record ``'bps'``."""
    return _compress_mixture(sample, opus_compress)


def encodec_48_6_compress_musdb(sample):
    """Replace ``sample['mixture']`` with its EnCodec round-trip and record ``'bps'``.

    Uses the global ``encodec_model_48_6`` model on the global ``device``;
    both are looked up at call time.
    """
    return _compress_mixture(sample, encodec_compress, encodec_model_48_6, device)

In [4]:
# Codec wrappers to benchmark; each maps one dataset row to a copy whose
# mixture has been round-tripped through that codec.
audio_compression_methods = [mp3_compress_musdb, opus_compress_musdb, encodec_48_6_compress_musdb]

models = []  # NOTE(review): not used in the visible cells -- confirm before removing

In [5]:
# MUSDB test split. `musdb_mix` drops the four stem columns so only the
# mixture (and any non-stem columns) is carried into the compression step.
musdb = load_dataset(path="danjacobellis/musdb", split="test")
musdb_mix = musdb.remove_columns(["drums", "bass", "other", "vocals"])

In [7]:
# One compressed copy of the mixture-only dataset per codec, index-aligned
# with `audio_compression_methods`. A comprehension keeps the loop variable
# out of the notebook namespace, and `desc` labels each progress bar.
musdb_compressed = [
    musdb_mix.map(method, desc=method.__name__)
    for method in audio_compression_methods
]

In [9]:
# Expose the full dataset (stems included) as torch tensors for evaluation.
musdb = musdb.with_format(type="torch")

In [13]:
# Demucs separator built entirely from library defaults.
# NOTE(review): if the default configuration applies random time shifts,
# seed the RNGs so the SDR numbers below are reproducible -- confirm.
separator = Separator()

In [36]:
# Baseline: separate one *uncompressed* MUSDB test track (index 2) with
# Demucs and compare against its ground-truth stems.
stem_names = ['drums', 'bass', 'other', 'vocals']  # row order shared by ref and sep
sample = musdb[2]
fs = int(sample['mixture']['sampling_rate'])  # int() also unwraps a 0-d tensor

# Each stem is transposed (dim0, dim1) -> (dim1, dim0) before stacking;
# presumably the dataset stores audio channels-first -- TODO confirm.
ref = torch.stack([sample[name]['array'].permute((1, 0)) for name in stem_names])

# The separator is driven from a file path, so round-trip the mixture through
# a temporary WAV file. Only the separation itself needs the file.
with tempfile.NamedTemporaryFile(suffix='.wav') as wav_file:
    torchaudio.save(wav_file.name, sample['mixture']['array'], fs, format="wav")
    origin, separated = separator.separate_audio_file(wav_file.name)

sep = torch.stack([separated[name].permute((1, 0)) for name in stem_names])

# Evaluation matches ref/sep rows by position, so fail early with a clear
# message if the stacks disagree in shape.
assert ref.shape == sep.shape, (ref.shape, sep.shape)

In [39]:
# Framewise BSS-Eval with 2 s windows and a 1.5 s hop, expressed in samples
# at the track's own sampling rate instead of a hard-coded 44100.
eval_fs = int(sample['mixture']['sampling_rate'])  # int() also unwraps a 0-d tensor
sdr, isr, sir, sar, perm = museval.metrics.bss_eval(
    ref,
    sep,
    window=2 * eval_fs,
    hop=1.5 * eval_fs,
    compute_permutation=False,  # ref/sep rows are built in the same stem order
    filters_len=512,
    framewise_filters=False,
    bsseval_sources_version=False,
)

In [40]:
# Mean SDR (dB) over all stems and frames. nanmean skips frames that museval
# may report as NaN (e.g. where a stem is silent); it equals np.mean when none occur.
np.nanmean(sdr)

7.746611040333569