In [1]:
from datasets import load_dataset
from io import BytesIO
import torch
import torchaudio
from spauq.core.metrics import spauq_eval
import zlib
import numpy as np

In [2]:
dataset = load_dataset("danjacobellis/aria_ea_audio_preprocessed",split='validation').with_format("torch").select(range(0,1200,10))

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

In [3]:
def mp3(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    config = torchaudio.io.CodecConfig(qscale=9)
    size_bytes = []
    recovered = []
    for channel in audio:
        with BytesIO() as f:
            torchaudio.save(f, channel.unsqueeze(0), format="mp3", sample_rate=fs, compression=config)
            size_bytes.append(len(f.getvalue()))
            f.seek(0)
            recovered.append(torchaudio.load(f,format="mp3")[0])
    recovered = torch.cat(recovered)
    eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
    sample['mp3_SSR'] = eval_output['SSR']
    sample['mp3_SRR'] = eval_output['SRR']
    sample['mp3_cr'] = 3*audio.numel()/sum(size_bytes) # 24 bit audio
    return sample

In [4]:
dataset = dataset.map(mp3)

Map:   0%|          | 0/120 [00:00<?, ? examples/s]



In [5]:
from dance.audio import RateDistortionAutoEncoder
device = "cuda"
dance_model = RateDistortionAutoEncoder()
checkpoint = torch.load("dance/audio_stage2_2e.pth")
dance_model.load_state_dict(checkpoint['model_state_dict'])
dance_model = dance_model.to(device)

In [6]:
def dance(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    with torch.no_grad():
        compressed = dance_model.encode(audio.unsqueeze(0).cuda()).round().to(torch.int8).cpu().numpy()
        original_shape = compressed.shape
        compressed = zlib.compress(compressed.tobytes(),level=9)
        size_bytes = len(compressed)
        recovered = zlib.decompress(compressed)
        recovered = np.frombuffer(recovered, dtype=np.int8)
        recovered = recovered.reshape(original_shape)
        recovered = torch.tensor(recovered).to(torch.float).cuda()
        recovered = dance_model.decode(recovered).cpu()[0]
        eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
        sample['dance_SSR'] = eval_output['SSR']
        sample['dance_SRR'] = eval_output['SRR']
        sample['dance_cr'] = 3*audio.numel()/size_bytes# 24 bit audio
        return sample

In [7]:
dataset = dataset.map(dance)

Map:   0%|          | 0/120 [00:00<?, ? examples/s]



In [8]:
metrics = dataset.remove_columns(['audio','seq_name'])
for m in metrics.features:
    print(f"{m}: {metrics[m].mean()} (mean)")
    print(f"{m}: {metrics[m].median()} (median) ")

mp3_SSR: 19.104087829589844 (mean)
mp3_SSR: 18.616052627563477 (median) 
mp3_SRR: 7.199904441833496 (mean)
mp3_SRR: 7.280512809753418 (median) 
mp3_cr: 26.87841796875 (mean)
mp3_cr: 26.595407485961914 (median) 
dance_SSR: 8.892154693603516 (mean)
dance_SSR: 6.188332557678223 (median) 
dance_SRR: -4.051605701446533 (mean)
dance_SRR: 0.4496538043022156 (median) 
dance_cr: 1599.774169921875 (mean)
dance_cr: 154.68093872070312 (median) 
