In [1]:
import io
import time
import torch
import numpy as np
import PIL
import torchaudio
import matplotlib.pyplot as plt
import tempfile
from einops import rearrange
from IPython.display import Audio
from torchvision.transforms import ToPILImage, PILToTensor
from datasets import load_dataset, Image
from spauq.core.metrics import spauq_eval
import cdpam
class Config:
    """Empty placeholder class; it defines no attributes or methods of its own."""

In [2]:
# Build the CDPAM object once with its default arguments; opus_compress_cpu
# calls its forward() to score each reconstruction against the reference.
cdpam_loss = cdpam.CDPAM()

  state = torch.load(modfolder,map_location="cpu")['state']


In [3]:
# Load the 'validation' split of the danjacobellis/musdb_segments dataset.
# opus_compress_cpu reads each example's sample['audio_mix']['bytes'] field.
MUSDB = load_dataset("danjacobellis/musdb_segments", split='validation')

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/21 [00:00<?, ?it/s]

In [4]:
def pad(audio, p=2**16):
    """Zero-pad the last dimension of ``audio`` up to the next multiple of ``p``.

    Only the size of the last dimension is inspected, so the tensor may have
    any number of leading dimensions (the previous implementation required
    exactly three).

    Args:
        audio: Tensor to pad along its last dimension.
        p: Block size; the returned last dimension is a multiple of ``p``.

    Returns:
        ``audio`` unchanged when its last dimension is already a multiple of
        ``p``; otherwise a new tensor with zeros appended at the end of the
        last dimension.
    """
    length = audio.shape[-1]
    # The outer modulo maps the "already aligned" case (p - 0) back to 0.
    padding_size = (p - (length % p)) % p
    if padding_size > 0:
        audio = torch.nn.functional.pad(audio, (0, padding_size), mode='constant', value=0)
    return audio

In [5]:
def opus_compress_cpu(sample):
    """Round-trip one dataset example through Opus (12 kbps) and score the result.

    The waveform is mean-removed, peak-normalised and halved, zero-padded to a
    multiple of 2**16 samples, encoded to Opus in memory, decoded again from a
    temporary file, resampled back to the source rate and cropped to the
    original length before the metrics are computed.

    Relies on the module-level names ``pad``, ``cdpam_loss`` and ``device``.

    Args:
        sample: Mapping for one example; ``sample['audio_mix']['bytes']`` is
            handed to ``torchaudio.load`` as the audio source.

    Returns:
        dict with keys:
            'compressed': the encoded Opus stream as ``bytes``.
            'cpu_encode_time': seconds spent in the in-memory encode.
            'cpu_decode_time': seconds spent loading, resampling and cropping.
            'bps': encoded bits divided by the number of elements of the
                unpadded waveform.
            'L': size of the last dimension of the unpadded waveform.
            'PSNR': ``-10*log10(MSE)`` between reference and reconstruction.
            'SSDR', 'SRDR': the 'SSR' and 'SRR' entries returned by ``spauq_eval``.
            'CDPAM': mean of ``cdpam_loss.forward`` over its output.
    """
    with torch.no_grad():
        x, fs = torchaudio.load(sample['audio_mix']['bytes'],normalize=False)
        x = x.to(torch.float)
        x = x - x.mean()
        max_abs = x.abs().max()
        # The epsilon keeps an all-zero clip from dividing by zero.
        x = x / (max_abs + 1e-8)
        # After this step every sample lies within [-0.5, 0.5].
        x = x/2
        L = x.shape[-1]
        # pad() is called with a leading batch axis, which is dropped again here.
        x_padded = pad(x.unsqueeze(0), 2**16)[0]

        t0 = time.time()
        buff = io.BytesIO()
        torchaudio.save(
            uri=buff,
            src=x_padded,
            sample_rate=fs,
            channels_first=True,
            format='opus',
            encoding='OPUS',
            compression=torchaudio.io.CodecConfig(bit_rate=12000)
        )
        encode_time = time.time() - t0
        # getvalue() yields an independent bytes object; getbuffer() would
        # return a memoryview that stays tied to the BytesIO instance.
        opus_bytes = buff.getvalue()

        # Decode from a file inside a TemporaryDirectory so the file is removed
        # on exit (the previous delete=False temp file was never cleaned up).
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = f"{temp_dir}/clip.opus"
            with open(temp_file_path, 'wb') as temp_file:
                temp_file.write(opus_bytes)

            t0 = time.time()
            x_hat, fs2 = torchaudio.load(temp_file_path)
            # Bring the decoded audio back to the source sample rate.
            x_hat = torchaudio.transforms.Resample(fs2,fs)(x_hat)
            # Drop the samples that correspond to the zero padding.
            x_hat = x_hat[:,:L]
            decode_time = time.time() - t0

        # NOTE: the numerator covers the padded signal, the denominator the
        # unpadded one (all elements of x).
        bps = 8*len(opus_bytes)/(x.numel())
        # With a reference peak of 1 this is the PSNR in dB.
        PSNR = - 10*np.log10(torch.nn.functional.mse_loss(x,x_hat.to("cpu")))
        SDR = spauq_eval(x,x_hat.to("cpu"),fs=fs)
        SSDR = SDR['SSR']
        SRDR = SDR['SRR']
        # Both inputs go to the same device; the result is not named `cdpam`
        # so the imported module of that name is not shadowed.
        cdpam_score = cdpam_loss.forward(x.to(device),x_hat.to(device)).mean().item()

    return {
        'compressed': opus_bytes,
        'cpu_encode_time': encode_time,
        'cpu_decode_time': decode_time,
        'bps': bps,
        'L': L,
        'PSNR': PSNR,
        'SSDR': SSDR,
        'SRDR': SRDR,
        'CDPAM': cdpam_score
    }

In [6]:
# Device that opus_compress_cpu moves tensors to before the CDPAM call.
device = "cpu"

# Apply the Opus round trip to every example of MUSDB; the returned dict keys
# become columns of the resulting dataset.
cpu = MUSDB.map(opus_compress_cpu, writer_batch_size=16)

Map:   0%|          | 0/262 [00:00<?, ? examples/s]



In [7]:
# Columns produced by opus_compress_cpu that are averaged in the summary below.
metrics = [
    # rate
    'bps',
    # quality scores
    'PSNR', 'SSDR', 'SRDR', 'CDPAM',
    # timings (seconds)
    'cpu_encode_time', 'cpu_decode_time',
]

In [8]:
# Print the mean of each metric column over the whole mapped dataset.
for metric in metrics:
    column = cpu[metric]
    mean_value = np.mean(column)
    print(f"{metric}: {mean_value}")

bps: 0.13497361336045594
PSNR: 30.370485604264353
SSDR: 16.734496307564836
SRDR: 5.034322195279302
CDPAM: 0.00010327670883659704
cpu_encode_time: 0.3653500262107558
cpu_decode_time: 0.04150824692413097


In [9]:
# Upload the mapped dataset (compressed bytes plus per-example metrics) to the
# Hub repository danjacobellis/MUSDB_opus_12kbps under the 'validation' split.
cpu.push_to_hub("danjacobellis/MUSDB_opus_12kbps",split='validation')

Uploading the dataset shards:   0%|          | 0/23 [00:00<?, ?it/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/danjacobellis/MUSDB_opus_12kbps/commit/6a242ef9a89dfd93897ddde23a31c302cf6f0b45', commit_message='Upload dataset', commit_description='', oid='6a242ef9a89dfd93897ddde23a31c302cf6f0b45', pr_url=None, pr_revision=None, pr_num=None)