In [1]:
from datasets import load_dataset
from io import BytesIO
import torch
import torchaudio
from spauq.core.metrics import spauq_eval
import zlib
import numpy as np
from dance.audio import RateDistortionAutoEncoder
from transformers import EncodecModel, AutoProcessor

In [2]:
dataset = load_dataset("danjacobellis/aria_ea_audio_preprocessed",split='validation').with_format("torch").select(range(0,1200,100))

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

In [3]:
encodec_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
encodec_processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")

  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


In [9]:
def encodec48(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    reference = []
    recovered = []
    size_bytes = []
    for c in range(0,7,2):
        channels = audio[c:c+2]
        if channels.shape[0] != 2:
            channels = torch.cat([channels,channels])
        with torch.no_grad():
            inputs = encodec_processor(raw_audio=channels, sampling_rate=fs, return_tensors='pt')
            encoder_outputs = encodec_model.encode(inputs["input_values"], inputs["padding_mask"])
            size_bytes.append(6.0*10.0*torch.tensor(encoder_outputs.audio_codes.shape).prod().item()/8)
            audio_values = encodec_model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
            audio_values = encodec_model(inputs["input_values"], inputs["padding_mask"]).audio_values
            reference.append(inputs['input_values'][0].detach())
            recovered.append(audio_values[0].detach())
    reference = torch.cat(reference)[0:7]
    recovered = torch.cat(recovered)[0:7]
    eval_output = spauq_eval(reference=reference,estimate=recovered,fs=fs)
    sample['encodec48_SSR'] = eval_output['SSR']
    sample['encodec48_SRR'] = eval_output['SRR']
    sample['encodec48_cr'] = 3*audio.numel()/sum(size_bytes) # 24 bit audio
    sample['encodec48_cr'] = (8/7)*sample['encodec48_cr'] # only 7 channels
    return sample

In [10]:
dataset = dataset.map(encodec48)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]



In [23]:
def mp3_q9_stereo(sample):
    fs = 48000
    audio = sample['audio'].permute(1, 0)
    config = torchaudio.io.CodecConfig(qscale=9)
    size_bytes = []
    recovered = []
    for c in range(0, 7, 2):
        channels = audio[c:c+2]
        if channels.shape[0] != 2:
            channels = torch.cat([channels,channels])
        
        with BytesIO() as f:
            torchaudio.save(f, channels, format="mp3", sample_rate=fs, compression=config)
            size_bytes.append(len(f.getvalue()))
            f.seek(0)
            recovered_channels, _ = torchaudio.load(f, format="mp3")
            recovered.append(recovered_channels)
    recovered = torch.cat(recovered, dim=0)[0:7]
    eval_output = spauq_eval(reference=audio, estimate=recovered, fs=fs)
    sample['mp3_q9_stereo_SSR'] = eval_output['SSR']
    sample['mp3_q9_stereo_SRR'] = eval_output['SRR']
    sample['mp3_q9_stereo_cr'] = 3 * audio.numel() / sum(size_bytes)
    sample['mp3_q9_stereo_cr'] = (8/7)*sample['mp3_q9_stereo_cr']
    return sample

In [24]:
dataset = dataset.map(mp3_q9_stereo)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

In [3]:
def mp3_q9(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    config = torchaudio.io.CodecConfig(qscale=9)
    size_bytes = []
    recovered = []
    for channel in audio:
        with BytesIO() as f:
            torchaudio.save(f, channel.unsqueeze(0), format="mp3", sample_rate=fs, compression=config)
            size_bytes.append(len(f.getvalue()))
            f.seek(0)
            recovered.append(torchaudio.load(f,format="mp3")[0])
    recovered = torch.cat(recovered)
    eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
    sample['mp3_q9_SSR'] = eval_output['SSR']
    sample['mp3_q9_SRR'] = eval_output['SRR']
    sample['mp3_q9_cr'] = 3*audio.numel()/sum(size_bytes) # 24 bit audio
    return sample

In [4]:
dataset = dataset.map(mp3_q9)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]



In [5]:
def mp3_q1(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    config = torchaudio.io.CodecConfig(qscale=1)
    size_bytes = []
    recovered = []
    for channel in audio:
        with BytesIO() as f:
            torchaudio.save(f, channel.unsqueeze(0), format="mp3", sample_rate=fs, compression=config)
            size_bytes.append(len(f.getvalue()))
            f.seek(0)
            recovered.append(torchaudio.load(f,format="mp3")[0])
    recovered = torch.cat(recovered)
    eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
    sample['mp3_q1_SSR'] = eval_output['SSR']
    sample['mp3_q1_SRR'] = eval_output['SRR']
    sample['mp3_q1_cr'] = 3*audio.numel()/sum(size_bytes) # 24 bit audio
    return sample

In [6]:
dataset = dataset.map(mp3_q1)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

In [7]:
device = "cuda"
dance1_model = RateDistortionAutoEncoder()
checkpoint = torch.load("dance/audio_stage1_20e.pth")
dance1_model.load_state_dict(checkpoint['model_state_dict'])
dance1_model = dance1_model.to(device)

In [8]:
def dance1(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    with torch.no_grad():
        compressed = dance1_model.encode(audio.unsqueeze(0).cuda()).round().to(torch.int8).cpu().numpy()
        original_shape = compressed.shape
        compressed = zlib.compress(compressed.tobytes(),level=9)
        size_bytes = len(compressed)
        recovered = zlib.decompress(compressed)
        recovered = np.frombuffer(recovered, dtype=np.int8)
        recovered = recovered.reshape(original_shape)
        recovered = torch.tensor(recovered).to(torch.float).cuda()
        recovered = dance1_model.decode(recovered).cpu()[0]
        eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
        sample['dance1_SSR'] = eval_output['SSR']
        sample['dance1_SRR'] = eval_output['SRR']
        sample['dance1_cr'] = 3*audio.numel()/size_bytes# 24 bit audio
        return sample

In [9]:
dataset = dataset.map(dance1)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

In [16]:
device = "cuda"
dance2_model = RateDistortionAutoEncoder()
checkpoint = torch.load("dance/audio_stage2_2e.pth")
dance2_model.load_state_dict(checkpoint['model_state_dict'])
dance2_model = dance2_model.to(device)

In [17]:
def dance2(sample):
    fs=48000
    audio = sample['audio'].permute(1,0)
    with torch.no_grad():
        compressed = dance2_model.encode(audio.unsqueeze(0).cuda()).round().to(torch.int8).cpu().numpy()
        original_shape = compressed.shape
        compressed = zlib.compress(compressed.tobytes(),level=9)
        size_bytes = len(compressed)
        recovered = zlib.decompress(compressed)
        recovered = np.frombuffer(recovered, dtype=np.int8)
        recovered = recovered.reshape(original_shape)
        recovered = torch.tensor(recovered).to(torch.float).cuda()
        recovered = dance2_model.decode(recovered).cpu()[0]
        eval_output = spauq_eval(reference=audio,estimate=recovered,fs=fs)
        sample['dance2_SSR'] = eval_output['SSR']
        sample['dance2_SRR'] = eval_output['SRR']
        sample['dance2_cr'] = 3*audio.numel()/size_bytes# 24 bit audio
        return sample

In [18]:
dataset = dataset.map(dance2)

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

In [25]:
metrics = dataset.remove_columns(['audio','seq_name'])
for m in metrics.features:
    print(f"{m}: {metrics[m].mean()} (mean)")
    print(f"{m}: {metrics[m].median()} (median) ")

encodec48_SSR: 2.7921228408813477 (mean)
encodec48_SSR: 1.3160523176193237 (median) 
encodec48_SRR: -11.988890647888184 (mean)
encodec48_SRR: -12.811376571655273 (median) 
encodec48_cr: 114.28572845458984 (mean)
encodec48_cr: 114.28571319580078 (median) 
mp3_q9_stereo_SSR: 18.158334732055664 (mean)
mp3_q9_stereo_SSR: 14.851905822753906 (median) 
mp3_q9_stereo_SRR: 7.02269172668457 (mean)
mp3_q9_stereo_SRR: 5.88826847076416 (median) 
mp3_q9_stereo_cr: 31.54193115234375 (mean)
mp3_q9_stereo_cr: 31.690698623657227 (median) 
