In [1]:
from datasets import load_dataset
from utils import mp3_compress, opus_compress, encodec_compress
from utils import hf_audio_encode
from demucs.separate import Separator
import encodec
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import tempfile
import museval
import cdpam
import torch
import gc

2023-10-25 15:37:05.830677: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-25 15:37:05.849366: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-25 15:37:05.849380: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-25 15:37:05.849395: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-10-25 15:37:05.853454: I tensorflow/core/platform/cpu_feature_g

In [2]:
# Device string handed to the EnCodec model here and to encodec_compress later.
device = "cuda"

# 48 kHz EnCodec model, moved to `device`, with its target bandwidth set to 6.
encodec_model_48_6 = encodec.EncodecModel.encodec_model_48khz().to(device)
encodec_model_48_6.set_target_bandwidth(6);

In [3]:
def mp3_compress_musdb(sample):
    """Round-trip one sample's mixture through ``mp3_compress``.

    The mixture array is converted to a float32 tensor and passed to
    ``mp3_compress`` together with its sampling rate. The audio it returns is
    re-encoded with ``hf_audio_encode`` and written back to
    ``sample['mixture']``; the second returned value is stored under
    ``sample['bps']`` (presumably a bitrate measure -- confirm in ``utils``).

    Args:
        sample: mapping with a ``'mixture'`` entry that holds ``'array'`` and
            ``'sampling_rate'``.

    Returns:
        The same ``sample`` object, updated in place.
    """
    mixture = sample['mixture']
    waveform = torch.tensor(mixture['array']).float()
    sample_rate = mixture['sampling_rate']

    decoded, bitrate = mp3_compress(waveform, sample_rate)

    sample['mixture'] = hf_audio_encode(decoded, sample_rate)
    sample['bps'] = bitrate
    return sample

def opus_compress_musdb(sample):
    """Round-trip one sample's mixture through ``opus_compress``.

    Converts the mixture array to a float32 tensor, calls ``opus_compress``
    with it and the mixture's sampling rate, then stores the re-encoded
    result (``hf_audio_encode``) in ``sample['mixture']`` and the second
    returned value in ``sample['bps']`` (presumably a bitrate measure --
    confirm in ``utils``).

    Args:
        sample: mapping with a ``'mixture'`` entry that holds ``'array'`` and
            ``'sampling_rate'``.

    Returns:
        The same ``sample`` object, updated in place.
    """
    source = sample['mixture']
    samples_f32 = torch.tensor(source['array']).to(dtype=torch.float32)
    rate = source['sampling_rate']

    round_trip = opus_compress(samples_f32, rate)
    decoded, bitrate = round_trip

    sample['mixture'] = hf_audio_encode(decoded, rate)
    sample['bps'] = bitrate
    return sample

def encodec_48_6_compress_musdb(sample):
    """Round-trip one sample's mixture through ``encodec_compress``.

    The mixture array is converted to a float32 tensor and handed to
    ``encodec_compress`` along with its sampling rate and the notebook-level
    ``encodec_model_48_6`` and ``device``. The returned audio is re-encoded
    with ``hf_audio_encode`` into ``sample['mixture']``; the second returned
    value goes to ``sample['bps']`` (presumably a bitrate measure -- confirm
    in ``utils``).

    Args:
        sample: mapping with a ``'mixture'`` entry that holds ``'array'`` and
            ``'sampling_rate'``.

    Returns:
        The same ``sample`` object, updated in place.
    """
    mixture = sample['mixture']
    waveform = torch.tensor(mixture['array']).float()
    sample_rate = mixture['sampling_rate']

    decoded, bitrate = encodec_compress(
        waveform,
        sample_rate,
        encodec_model_48_6,
        device,
    )

    sample['mixture'] = hf_audio_encode(decoded, sample_rate)
    sample['bps'] = bitrate
    return sample

In [4]:
# Test split of the MUSDB dataset; the first row's mixture supplies the
# sampling rate kept in `fs`.
musdb = load_dataset("danjacobellis/musdb", split='test')
fs = musdb[0]['mixture']['sampling_rate']

# Mixture-only view: the four stem columns are dropped.
musdb_mix = musdb.remove_columns(['drums', 'bass', 'other', 'vocals'])

# Demucs separator built with its default arguments.
separator = Separator()

# Per-sample compression functions, in the order used for all later results.
audio_compression_methods = [
    mp3_compress_musdb,
    opus_compress_musdb,
    encodec_48_6_compress_musdb,
]

In [5]:
# One compressed copy of the mixture dataset per codec, in the same order as
# audio_compression_methods, each returned in torch format.
musdb_compressed = [
    musdb_mix.map(compress_fn).with_format("torch")
    for compress_fn in audio_compression_methods
]

In [6]:
# Distortion of every codec relative to the uncompressed mixtures, measured
# chunk by chunk with CDPAM and with plain MSE.
CHUNK_LEN = 100000  # samples per evaluation chunk
N_CHUNKS = 50       # at most this many leading chunks are scored per track

cdpam_metric = cdpam.CDPAM()
mse_metric = torch.nn.MSELoss()
# Both results are nested lists indexed as [codec][track][chunk].
cdpam_distance = []
mse_distance = []
min_len = 999999999;  # not read in this cell; kept in case another cell relies on it
# Evaluation only: no_grad keeps autograd from recording a graph for each
# metric call, which lowers memory use without changing the values.
with torch.no_grad():
    for dataset in musdb_compressed:
        cdpam_distance.append([])
        mse_distance.append([])
        for i_sample,sample in enumerate(musdb_mix.with_format("torch")):
            compressed_sample = dataset[i_sample]
            sample_rate = sample['mixture']['sampling_rate']
            reference = sample['mixture']['array']
            distorted = compressed_sample['mixture']['array']
            # Compare only the span both signals cover, in case the codec
            # round trip changed the length.
            n_common = min(reference.shape[1], distorted.shape[1])
            cdpam_distance[-1].append([])
            mse_distance[-1].append([])
            for i_chunk in range(N_CHUNKS):
                ind1 = CHUNK_LEN*i_chunk
                if ind1 >= n_common:
                    # Track is shorter than N_CHUNKS chunks: stop rather than
                    # score empty slices (MSE of an empty tensor is NaN).
                    break
                ind2 = min(CHUNK_LEN*(i_chunk+1), n_common)
                ref = reference[:,ind1:ind2]
                dis = distorted[:,ind1:ind2]
                cdpam_distance[-1][-1].append(cdpam_metric.forward(ref,dis).detach().cpu().mean().item())
                mse_distance[-1][-1].append(mse_metric(ref,dis).detach().cpu().mean().item())
                # Release cached GPU memory after every chunk.
                gc.collect()
                torch.cuda.empty_cache()

In [8]:
# Per-codec PSNR in dB: -10*log10 of the MSE averaged over every track and
# chunk (i.e. relative to a peak amplitude of 1).
PSNR = [
    -10 * np.log10(np.mean(codec_mse))
    for codec_mse in mse_distance
]
PSNR

[29.171616564378304, 22.170907366259108, 24.956013231093944]

In [9]:
# Same -10*log10 transform applied to the CDPAM distance averaged over every
# track and chunk, one value per codec.
cdpam_PSNR = [
    -10 * np.log10(np.mean(codec_cdpam))
    for codec_cdpam in cdpam_distance
]
cdpam_PSNR

[38.432159884379466, 36.46644744986545, 45.33439174079839]

In [12]:
# Mean of the 'bps' column of each compressed dataset, one value per codec.
audio_bps = [
    compressed_ds['bps'].mean().item()
    for compressed_ds in musdb_compressed
]
audio_bps

[0.36287301778793335, 0.06615560501813889, 0.06871381402015686]