In [1]:
import io
import time
import torch
import numpy as np
import PIL
import torchaudio
import datasets
import matplotlib.pyplot as plt
import einops
from IPython.display import Audio
from types import SimpleNamespace
from torchvision.transforms.v2 import CenterCrop
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent
from IPython.display import Audio as play
from spauq.core.metrics import spauq_eval

In [2]:
def pad(audio, p=2**16):
    B, C, L = audio.shape
    padding_size = (p - (L % p)) % p
    if padding_size > 0:
        left_pad = padding_size // 2
        right_pad = padding_size - left_pad
        audio = torch.nn.functional.pad(audio, (left_pad, right_pad), mode='constant', value=0)
    return audio


In [3]:
dataset = datasets.load_dataset("danjacobellis/aria_ea_audio_preprocessed").with_format("numpy",dtype='float16')

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

In [4]:
device = 'cuda'
checkpoint = torch.load('../../hf/autocodec/aria_7ch_f128c28.pth', map_location="cpu",weights_only=False)
config = checkpoint['config']
state_dict = checkpoint['state_dict']
model = AutoCodecND(
    dim=1,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth = config.encoder_depth,
    encoder_kernel_size = config.encoder_kernel_size,
    decoder_depth = config.decoder_depth,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
).to(device)
model.load_state_dict(state_dict)
model.to(torch.bfloat16)
model.eval();

In [5]:
def evaluate_quality(sample):

    x = torch.tensor(sample['audio']).permute(1,0)
    L = x.shape[-1]
    
    t0 = time.time()
    x_padded = pad(x.unsqueeze(0), config.F*32).to(device).to(torch.bfloat16)
    with torch.no_grad():
        z = model.quantize.compand(model.encode(x_padded)).round().cpu()
    img_list = latent_to_pil(
        einops.rearrange(z,'b c (h w) -> c b h w', h=32),n_bits=16,C=1)
    buff_list = []
    for img in img_list:
        buff_list.append(io.BytesIO())
        img.save(buff_list[-1], format='TIFF', compression='tiff_adobe_deflate')
    encode_time = time.time() - t0
    
    cr = 2*x.numel()/sum(len(b.getbuffer()) for b in buff_list)
    
    t0 = time.time()
    zhat = pil_to_latent([PIL.Image.open(b) for b in buff_list], N=1, n_bits=8, C=1)
    zhat = z.clone().to(device).to(torch.bfloat16)
    with torch.no_grad():
        xhat = model.decode(zhat)
    xhat = xhat.clamp(-1,1)
    decode_time = time.time() - t0
    
    xhat = xhat.to("cpu").to(torch.float)
    xhat = CenterCrop((config.input_channels,x.shape[1]))(xhat)[0]
    mse = torch.nn.functional.mse_loss(x,xhat)
    psnr = -10*mse.log10().item() + 6.02
    
    SDR = spauq_eval(x.to(torch.float),xhat.to("cpu"),fs=48000)
    ssdr = SDR['SSR']
    srdr = SDR['SRR']

    return {
        'samples': x.numel(),
        'cr': cr,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'psnr': psnr,
        'ssdr': ssdr,
        'srdr': srdr,
    }

In [6]:
metrics = [
    'samples',
    'cr',
    'encode_time',
    'decode_time',
    'psnr',
    'ssdr',
    'srdr',
]

In [7]:
gpu_results = dataset['validation'].select(range(10)).map(evaluate_quality).with_format("torch",dtype=torch.float)

In [8]:
print("mean\n---")
for metric in metrics:
    μ = gpu_results[metric].mean()
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(gpu_results['samples'])/1e6/np.array(gpu_results['encode_time']))} MS/sec")
print(f"{np.mean(np.array(gpu_results['samples'])/1e6/np.array(gpu_results['decode_time']))} MS/sec")

mean
---
samples: 2100000.0
cr: 1013.5135498046875
encode_time: 0.006712436676025391
decode_time: 0.009349226951599121
psnr: 30.780847549438477
ssdr: 11.865995407104492
srdr: 4.528044700622559
328.0713806152344 MS/sec
252.41958618164062 MS/sec


  print(f"{np.mean(np.array(gpu_results['samples'])/1e6/np.array(gpu_results['encode_time']))} MS/sec")
  print(f"{np.mean(np.array(gpu_results['samples'])/1e6/np.array(gpu_results['decode_time']))} MS/sec")
