In [1]:
import io
import time
import torch
import numpy as np
import PIL
import torchaudio
import datasets
import matplotlib.pyplot as plt
import einops
from IPython.display import Audio
from types import SimpleNamespace
from torchvision.transforms.v2 import CenterCrop
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent
from IPython.display import Audio as play

In [2]:
# Inference device. The checkpoint is deserialized onto the CPU first and the
# constructed model is moved to `device` afterwards.
device = 'cuda'

# CAUTION: weights_only=False allows torch.load to unpickle arbitrary Python
# objects (presumably required because the checkpoint stores a `config` object
# alongside the weights). Only load checkpoints from a trusted source.
checkpoint = torch.load(
    '../../hf/autocodec/cv_f512c8.pth',
    map_location="cpu",
    weights_only=False,
)
config, state_dict = checkpoint['config'], checkpoint['state_dict']

# Constructor arguments are taken from the saved config; J is log2 of config.F.
codec_kwargs = dict(
    dim=1,
    input_channels=config.input_channels,
    J=int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth=config.encoder_depth,
    encoder_kernel_size=config.encoder_kernel_size,
    decoder_depth=config.decoder_depth,
    lightweight_encode=config.lightweight_encode,
    lightweight_decode=config.lightweight_decode,
)
model = AutoCodecND(**codec_kwargs).to(device)
model.load_state_dict(state_dict)
# Run the network in bfloat16 and switch to inference mode.
model.to(torch.bfloat16)
model.eval();

In [3]:
# Small LibriSpeech validation split used as test material.
librispeech_dummy = datasets.load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
)
# decode=False leaves the 'audio' column undecoded so the raw file bytes can be
# read from each sample's 'audio' entry later on.
librispeech_dummy = librispeech_dummy.cast_column('audio', datasets.Audio(decode=False))

In [4]:
def pad(audio, p=2**16):
    """Right-pad the last axis of `audio` with zeros up to a multiple of `p`.

    Only the trailing (last) axis is padded, and only at its end, so the
    original samples stay at indices ``[..., :L]`` of the result, where ``L``
    is the input length. Any number of leading axes is accepted.

    Args:
        audio: torch.Tensor with at least one dimension; the last axis is the
            one that gets padded.
        p: block size the last axis must be a multiple of (non-zero integer;
            ``p == 0`` raises ZeroDivisionError).

    Returns:
        The padded tensor, or `audio` itself (same object) when its last-axis
        length is already a multiple of `p`.
    """
    length = audio.shape[-1]
    # Number of zeros needed to reach the next multiple of p (0 if already aligned).
    padding_size = -length % p
    if padding_size > 0:
        audio = torch.nn.functional.pad(audio, (0, padding_size), mode='constant', value=0)
    return audio

In [5]:
# Round-trip a single utterance through the codec:
# normalize -> encode -> TIFF-compress the latent -> decode -> score.
sample = librispeech_dummy[0]
with torch.no_grad():
    # Load the raw samples, remove the DC offset and peak-normalize.
    x, fs = torchaudio.load(sample['audio']['bytes'],normalize=False)
    x = x.to(torch.float)
    x = x - x.mean()
    max_abs = x.abs().max()
    x = x / (max_abs + 1e-8)  # epsilon avoids division by zero on a silent clip
    x = x/2  # after peak normalization this keeps |x| <= 0.5
    L = x.shape[-1]  # unpadded length, needed to strip the padding after decoding

    # ---- Encode (timed) ----
    t0 = time.time()
    x_padded = pad(x.unsqueeze(0), 2**16).to(device).to(torch.bfloat16)
    # .cpu() blocks until the GPU work is done, so the encode timing is complete.
    z = model.quantize.compand(model.encode(x_padded)).round().cpu()
    img_list = latent_to_pil(
        einops.rearrange(z,'b c (h w) -> c b h w', h=16),
        n_bits=16,
        C=1
    )
    # Serialize every latent image into its own in-memory TIFF.
    buff_list = []
    for img in img_list:
        buff_list.append(io.BytesIO())
        img.save(buff_list[-1], format='TIFF', compression='tiff_adobe_deflate')
    encode_time = time.time() - t0

    # Compression ratio: 2 bytes per input sample (presumably 16-bit PCM — confirm)
    # divided by the total size of the TIFF buffers.
    CR = 2*x.numel()/sum(len(b.getbuffer()) for b in buff_list)

    # ---- Decode (timed separately) ----
    # Restart the clock: reusing the encode t0 would fold the encode time and the
    # CR computation into decode_time.
    t0 = time.time()
    # NOTE(review): the latent is written with n_bits=16 above but read back with
    # n_bits=8 here. Left unchanged because the helper semantics are not visible
    # from this notebook — confirm the two values are meant to differ.
    latent_decoded = pil_to_latent([PIL.Image.open(b) for b in buff_list], N=1, n_bits=8, C=1)
    latent_decoded = einops.rearrange(latent_decoded, 'c b h w -> b c (h w)')
    x_hat = model.decode(latent_decoded.to(device).to(torch.bfloat16))
    x_hat = x_hat.clamp(-1,1)
    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    decode_time = time.time() - t0

    x_hat = x_hat.to("cpu").to(torch.float)
    # pad() appends zeros at the END of the signal, so the original samples sit in
    # the first L positions. A center crop would shift x_hat by half the padding
    # relative to x and corrupt the error measurement.
    x_hat = x_hat[0][..., :L]
    mse = torch.nn.functional.mse_loss(x,x_hat)
    # 6.02 dB ~= 20*log10(2); presumably references a peak-to-peak range of 2
    # ([-1, 1]) — confirm against the intended PSNR definition.
    PSNR = -10*mse.log10().item() + 6.02