In [1]:
from datasets import load_dataset
from io import BytesIO
import torch
import torchaudio
from spauq.core.metrics import spauq_eval
import zlib
import numpy as np
from dance.audio import RateDistortionAutoEncoder

In [2]:
dataset = load_dataset("danjacobellis/aria_ea_audio_preprocessed",split='validation').with_format("torch").select(range(0,1200,100))

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

In [3]:
def mp3_q9(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    config = torchaudio.io.CodecConfig(qscale=9)
    size_bytes = []
    recovered = []
    for channel in audio:
        with BytesIO() as f:
            torchaudio.save(f, channel.unsqueeze(0), format="mp3", sample_rate=fs, compression=config)
            size_bytes.append(len(f.getvalue()))
            f.seek(0)
            recovered.append(torchaudio.load(f,format="mp3")[0])
    recovered = torch.cat(recovered)
    eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
    sample['mp3_q9_SSR'] = eval_output['SSR']
    sample['mp3_q9_SRR'] = eval_output['SRR']
    sample['mp3_q9_cr'] = 3*audio.numel()/sum(size_bytes) # 24 bit audio
    return sample

In [4]:
dataset = dataset.map(mp3_q9)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]



In [5]:
def mp3_q1(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    config = torchaudio.io.CodecConfig(qscale=1)
    size_bytes = []
    recovered = []
    for channel in audio:
        with BytesIO() as f:
            torchaudio.save(f, channel.unsqueeze(0), format="mp3", sample_rate=fs, compression=config)
            size_bytes.append(len(f.getvalue()))
            f.seek(0)
            recovered.append(torchaudio.load(f,format="mp3")[0])
    recovered = torch.cat(recovered)
    eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
    sample['mp3_q1_SSR'] = eval_output['SSR']
    sample['mp3_q1_SRR'] = eval_output['SRR']
    sample['mp3_q1_cr'] = 3*audio.numel()/sum(size_bytes) # 24 bit audio
    return sample

In [6]:
dataset = dataset.map(mp3_q1)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

In [7]:
device = "cuda"
dance1_model = RateDistortionAutoEncoder()
checkpoint = torch.load("dance/audio_stage1_20e.pth")
dance1_model.load_state_dict(checkpoint['model_state_dict'])
dance1_model = dance1_model.to(device)

In [8]:
def dance1(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    with torch.no_grad():
        compressed = dance1_model.encode(audio.unsqueeze(0).cuda()).round().to(torch.int8).cpu().numpy()
        original_shape = compressed.shape
        compressed = zlib.compress(compressed.tobytes(),level=9)
        size_bytes = len(compressed)
        recovered = zlib.decompress(compressed)
        recovered = np.frombuffer(recovered, dtype=np.int8)
        recovered = recovered.reshape(original_shape)
        recovered = torch.tensor(recovered).to(torch.float).cuda()
        recovered = dance1_model.decode(recovered).cpu()[0]
        eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
        sample['dance1_SSR'] = eval_output['SSR']
        sample['dance1_SRR'] = eval_output['SRR']
        sample['dance1_cr'] = 3*audio.numel()/size_bytes# 24 bit audio
        return sample

In [9]:
dataset = dataset.map(dance1)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

In [16]:
device = "cuda"
dance2_model = RateDistortionAutoEncoder()
checkpoint = torch.load("dance/audio_stage2_2e.pth")
dance2_model.load_state_dict(checkpoint['model_state_dict'])
dance2_model = dance2_model.to(device)

In [17]:
def dance2(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    with torch.no_grad():
        compressed = dance2_model.encode(audio.unsqueeze(0).cuda()).round().to(torch.int8).cpu().numpy()
        original_shape = compressed.shape
        compressed = zlib.compress(compressed.tobytes(),level=9)
        size_bytes = len(compressed)
        recovered = zlib.decompress(compressed)
        recovered = np.frombuffer(recovered, dtype=np.int8)
        recovered = recovered.reshape(original_shape)
        recovered = torch.tensor(recovered).to(torch.float).cuda()
        recovered = dance2_model.decode(recovered).cpu()[0]
        eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
        sample['dance2_SSR'] = eval_output['SSR']
        sample['dance2_SRR'] = eval_output['SRR']
        sample['dance2_cr'] = 3*audio.numel()/size_bytes# 24 bit audio
        return sample

In [18]:
dataset = dataset.map(dance2)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

In [19]:
metrics = dataset.remove_columns(['audio','seq_name'])
for m in metrics.features:
    print(f"{m}: {metrics[m].mean()} (mean)")
    print(f"{m}: {metrics[m].median()} (median) ")

mp3_q9_SSR: 18.147558212280273 (mean)
mp3_q9_SSR: 15.21203327178955 (median) 
mp3_q9_SRR: 7.068716049194336 (mean)
mp3_q9_SRR: 5.95135498046875 (median) 
mp3_q9_cr: 26.980316162109375 (mean)
mp3_q9_cr: 26.646703720092773 (median) 
mp3_q1_SSR: 34.79498291015625 (mean)
mp3_q1_SSR: 28.872365951538086 (median) 
mp3_q1_SRR: 16.14482879638672 (mean)
mp3_q1_SRR: 13.828468322753906 (median) 
mp3_q1_cr: 8.73799991607666 (mean)
mp3_q1_cr: 8.577275276184082 (median) 
dance1_SSR: 11.579394340515137 (mean)
dance1_SSR: 7.241969108581543 (median) 
dance1_SRR: -0.7877809405326843 (mean)
dance1_SRR: 2.2263686656951904 (median) 
dance1_cr: 1519.7864990234375 (mean)
dance1_cr: 63.121192932128906 (median) 
dance2_SSR: 7.7751288414001465 (mean)
dance2_SSR: 4.414035320281982 (median) 
dance2_SRR: -8.304768562316895 (mean)
dance2_SRR: -1.9673932790756226 (median) 
dance2_cr: 3129.998779296875 (mean)
dance2_cr: 177.8354949951172 (median) 
