In [1]:
import io
import time
import torch
import numpy as np
import PIL
import torchaudio
import datasets
import matplotlib.pyplot as plt
import einops
from IPython.display import Audio
from types import SimpleNamespace
from torchvision.transforms.v2 import CenterCrop
from transformers import EncodecModel, AutoProcessor
from IPython.display import Audio as play
from spauq.core.metrics import spauq_eval

In [2]:
def pad(audio, p=2**16):
    """Zero-pad the last axis of a 3-D tensor up to the next multiple of ``p``.

    The padding is split between both ends of the last axis; when the amount
    is odd, the extra element goes on the right.

    Args:
        audio: Tensor with exactly three dimensions; only the last one is padded.
        p: Block size the last dimension should become a multiple of.

    Returns:
        The padded tensor, or ``audio`` itself when no padding is required.
    """
    _, _, length = audio.shape
    shortfall = (p - (length % p)) % p
    if shortfall <= 0:
        # Already aligned: hand back the input untouched.
        return audio
    before = shortfall // 2
    after = shortfall - before
    return torch.nn.functional.pad(audio, (before, after), mode='constant', value=0)


In [3]:
# Load the `danjacobellis/aria_ea_audio_preprocessed` dataset, then switch its
# output format so that examples are returned as float16 NumPy arrays.
dataset = datasets.load_dataset("danjacobellis/aria_ea_audio_preprocessed")
dataset = dataset.with_format("numpy", dtype='float16')

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

In [4]:
# Target device for the codec; everything timed later in the notebook runs there.
device = 'cuda'

# Load the pretrained 48 kHz EnCodec weights, then move the model to `device`.
encodec_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
encodec_model = encodec_model.to(device)

# Matching processor that turns raw audio into model inputs (values + padding mask).
encodec_processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [6]:
def evaluate_quality(sample):
    """Round-trip one multichannel clip through EnCodec and score the result.

    The clip is processed two channels at a time. If the channel count is odd,
    the last channel is duplicated to form a full pair and the duplicate is
    dropped again before scoring.

    Relies on the notebook-level ``encodec_model``, ``encodec_processor`` and
    ``device`` objects.

    Args:
        sample: Mapping with an ``'audio'`` entry holding a 2-D array. It is
            transposed on load, so it is presumably stored as
            (samples, channels) -- TODO confirm against the dataset.

    Returns:
        dict with:
            ``samples``      number of audio samples scored (channels * length),
            ``cr``           compression ratio, counting the source as 16 bits
                             per sample,
            ``encode_time``  seconds spent encoding and serialising the codes,
            ``decode_time``  seconds spent decoding,
            ``psnr``         PSNR in dB (the +6.02 dB term is 20*log10(2), the
                             peak-to-peak range of audio in [-1, 1]),
            ``ssdr``         the ``'SSR'`` entry returned by ``spauq_eval``,
            ``srdr``         the ``'SRR'`` entry returned by ``spauq_eval``.

    Raises:
        AssertionError: if the processor output, cropped back to the original
            shape, does not exactly equal the input audio.
    """

    def _finish_gpu_work(tensor):
        # CUDA kernels are launched asynchronously, so a wall-clock reading is
        # only meaningful once the queued work has actually completed.
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)

    x_orig = torch.tensor(sample['audio']).permute(1,0).to(torch.float)
    n_channels, L = x_orig.shape

    encode_time = 0
    decode_time = 0
    compressed_size_bits = 0

    x = []
    xhat = []
    for c in range(0, n_channels, 2):
        channels = x_orig[c:c+2]
        if channels.shape[0] != 2:
            # Unpaired trailing channel: duplicate it to form a two-channel input.
            channels = torch.cat([channels,channels])
        inputs = encodec_processor(raw_audio=channels, sampling_rate=48000, return_tensors='pt')
        xi = inputs.input_values.to(device)
        padding_mask = inputs.padding_mask.to(device)

        _finish_gpu_work(xi)
        t0 = time.perf_counter()
        with torch.no_grad():
            encoder_outputs = encodec_model.encode(xi, padding_mask)
        # The codes must be extracted *before* they are serialised; reading
        # `codes` ahead of this assignment raises UnboundLocalError.
        codes = encoder_outputs.audio_codes
        scales = encoder_outputs.audio_scales
        # Serialising the codes is counted as part of the encode cost.
        buff = io.BytesIO()
        torch.save(codes,buff)
        _finish_gpu_work(xi)
        encode_time += time.perf_counter() - t0

        # 10 bits per code (presumably a 1024-entry codebook -- TODO confirm)
        # plus 16 bits for each entry of `scales`.
        compressed_size_bits += 10*codes.numel() + 16*len(scales)

        _finish_gpu_work(xi)
        t0 = time.perf_counter()
        with torch.no_grad():
            xhati = encodec_model.decode(codes, scales, padding_mask)[0]
        # Without this wait the timer would stop before the decode finishes.
        _finish_gpu_work(xhati)
        decode_time += time.perf_counter() - t0

        x.append(xi)
        xhat.append(xhati)

    # Stitch the pairs back together, then drop the duplicated channel and any
    # padding so both signals match the original shape.
    x = torch.cat(x,dim=1)[0,:n_channels,:L]
    xhat = torch.cat(xhat,dim=1)[0,:n_channels,:L].clamp(-1,1)
    assert x.cpu().equal(x_orig)

    cr = 16*x.numel() / compressed_size_bits

    mse = torch.nn.functional.mse_loss(x,xhat)
    psnr = -10*mse.log10().item() + 6.02

    SDR = spauq_eval(x.cpu(),xhat.cpu(),fs=48000)
    ssdr = SDR['SSR']
    srdr = SDR['SRR']

    return {
        'samples': x.numel(),
        'cr': cr,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'psnr': psnr,
        'ssdr': ssdr,
        'srdr': srdr,
    }

In [7]:
# Keys of the per-example result dict, in the order they are reported.
metrics = ['samples', 'cr', 'encode_time', 'decode_time', 'psnr', 'ssdr', 'srdr']

In [7]:
# Score every clip in the validation split; `map` adds the dict returned by
# `evaluate_quality` as new columns, which are then exposed as float32 torch
# tensors. To evaluate only a subset, insert `.select(range(n))` before `.map`.
gpu_results = dataset['validation'].map(evaluate_quality).with_format("torch", dtype=torch.float)

Map:   0%|          | 0/1215 [00:00<?, ? examples/s]



In [9]:
# Report the mean of every collected metric over the evaluated clips.
print("mean\n---")
for metric in metrics:
    μ = gpu_results[metric].mean()
    print(f"{metric}: {μ}")

# Mean per-clip throughput in megasamples per second: encode first, then decode.
megasamples = np.array(gpu_results['samples'])/1e6
for timing_column in ('encode_time', 'decode_time'):
    throughput = np.mean(megasamples/np.array(gpu_results[timing_column]))
    print(f"{throughput} MS/sec")

mean
---
samples: 2100000.0
cr: 1013.5135498046875
encode_time: 0.00596621772274375
decode_time: 0.007144085131585598
psnr: 33.12417984008789
ssdr: 9.029350280761719
srdr: 1.9624067544937134
363.17364501953125 MS/sec
305.7591552734375 MS/sec


  print(f"{np.mean(np.array(gpu_results['samples'])/1e6/np.array(gpu_results['encode_time']))} MS/sec")
  print(f"{np.mean(np.array(gpu_results['samples'])/1e6/np.array(gpu_results['decode_time']))} MS/sec")
