In [1]:
# Record which GPU the timings in this notebook were measured on (first line of nvidia-smi output only).
!nvidia-smi --query-gpu=name --format=csv,noheader | head -n 1

NVIDIA GeForce RTX 4090


In [2]:
import io
import time
import torch
import numpy as np
import PIL
import torchaudio
import matplotlib.pyplot as plt
from IPython.display import Audio
from torchvision.transforms import ToPILImage, PILToTensor
from datasets import load_dataset, Image
from walloc import walloc
from diffusers.models.autoencoders import AutoencoderOobleck
from spauq.core.metrics import spauq_eval
import cdpam
class Config:
    """Empty placeholder namespace (not referenced by any visible cell)."""

In [3]:
# VAE of Stable Audio Open 1.0, loaded in half precision and put straight into
# inference mode (Module.eval() returns the module itself).
codec = AutoencoderOobleck.from_pretrained(
    "stabilityai/stable-audio-open-1.0", subfolder='vae', torch_dtype=torch.float16
).eval()

  WeightNorm.apply(module, name, dim)


In [4]:
# CDPAM perceptual distance model; stable_audio_compress calls
# cdpam_loss.forward(reference, reconstruction) on every track.
# NOTE(review): the model's device is whatever the cdpam package defaults to, while
# inputs are sent to `device` — confirm they agree.
# NOTE(review): check the cdpam documentation for its expected input sample rate;
# audio here is passed at the dataset's native rate.
cdpam_loss = cdpam.CDPAM()

  state = torch.load(modfolder,map_location="cpu")['state']


In [5]:
# Validation split of the danjacobellis/musdb18HQ dataset from the Hugging Face Hub.
MUSDB = load_dataset(path="danjacobellis/musdb18HQ", split='validation')

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

In [6]:
def pad(audio, p=2**16):
    """Zero-pad the last (time) axis of `audio` on the right to a multiple of `p`.

    Args:
        audio: tensor whose last dimension is the sample axis. Leading dimensions
            (e.g. the (batch, channels, samples) layout used in this notebook) are
            preserved; any rank is accepted.
        p: block length the time axis is rounded up to; must be positive.

    Returns:
        `audio` itself when its length is already a multiple of `p`, otherwise a new
        tensor with trailing zeros appended.

    Raises:
        ValueError: if `p` is not positive.
    """
    if p <= 0:
        raise ValueError(f"p must be positive, got {p}")
    length = audio.shape[-1]
    padding_size = (p - (length % p)) % p
    if padding_size > 0:
        # (0, padding_size) pads only the end of the last dimension.
        audio = torch.nn.functional.pad(audio, (0, padding_size), mode='constant', value=0)
    return audio

In [7]:
# Cap each track at 180 s, assuming 44.1 kHz audio (TODO confirm every track's fs).
max_duration = 44100*180


def _sync(dev):
    """Wait for queued CUDA work on `dev` to finish so wall-clock timings are accurate."""
    if str(dev).startswith("cuda"):
        torch.cuda.synchronize()


def stable_audio_compress(sample):
    """Round-trip one track through the Stable Audio VAE and score the reconstruction.

    Reads the globals `codec`, `device`, `cdpam_loss`, `max_duration` and `pad`.

    Args:
        sample: dataset row whose ``sample['audio']['bytes']`` torchaudio.load can read.

    Returns:
        dict with the float16 latent ('compressed'), encode/decode wall-clock times in
        seconds, bits per input sample ('bps'), the unpadded length 'L', and the
        quality metrics 'PSNR', 'SSDR', 'SRDR' and 'CDPAM'.
    """
    with torch.no_grad():
        x, fs = torchaudio.load(sample['audio']['bytes'])
        x = x[:, :max_duration]
        L = x.shape[-1]

        _sync(device)
        t0 = time.time()
        x_padded = pad(x.unsqueeze(0), 2**16).to(device).to(torch.float16)
        # The copy back to CPU forces the encoder to finish, so this interval is complete.
        Y = codec.encode(x_padded).latent_dist.mode().to(torch.float16).to("cpu")
        encode_time = time.time() - t0

        _sync(device)
        t0 = time.time()
        x_hat = codec.decode(Y.to(torch.float16).to(device)).sample
        x_hat = x_hat[0, :, :L].clamp(-1., 1.)
        # CUDA kernels run asynchronously: without this sync the clock would stop
        # before decoding has actually finished.
        _sync(device)
        decode_time = time.time() - t0

        # Score in float32: x is float32 while the decoder output is float16.
        x_hat = x_hat.float()
        x_hat_cpu = x_hat.cpu()

        # Latents are stored as 16-bit floats.
        bps = 16*Y.numel()/x.numel()
        # Peak-to-peak range is 2 for samples in [-1, 1].
        mse = torch.nn.functional.mse_loss(x, x_hat_cpu).item()
        PSNR = 20*np.log10(2) - 10*np.log10(mse)
        SDR = spauq_eval(x, x_hat_cpu, fs=fs)
        SSDR = SDR['SSR']
        SRDR = SDR['SRR']
        cdpam_dist = cdpam_loss.forward(x.to(device), x_hat).mean().item()

    return {
        'compressed': Y,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'bps': bps,
        'L': L,
        'PSNR': PSNR,
        'SSDR': SSDR,
        'SRDR': SRDR,
        'CDPAM': cdpam_dist
    }

In [None]:
device = "cuda"
codec = codec.to(device)

# Warm-up: the first CUDA calls pay one-time costs (lazy context/kernel initialisation)
# that would otherwise be charged to the first track's timings. Run one track and
# discard the result before measuring.
stable_audio_compress(MUSDB[0])

gpu = MUSDB.map(
    stable_audio_compress,
    writer_batch_size=16,
    # Always recompute: a cached map result from an earlier run would return stale timings.
    load_from_cache_file=False,
)



Map:   0%|          | 0/250 [00:00<?, ? examples/s]



In [None]:
max_duration = 44100*180


def stable_audio_compress_cpu(sample):
    """Measure encode and decode wall-clock time for one track (no quality metrics).

    Follows the same steps as the metric-computing pipeline: truncate to
    `max_duration`, zero-pad to a multiple of 2**16, encode, round the latent
    through float16, decode, then crop and clamp. The reconstruction is discarded.

    Reads the globals `codec`, `device` and `pad`.

    Returns:
        dict with 'cpu_encode_time' and 'cpu_decode_time' in seconds.
    """
    clock = time.time
    with torch.no_grad():
        audio = torchaudio.load(sample['audio']['bytes'])[0][:, :max_duration]
        n_samples = audio.shape[-1]

        start = clock()
        latents = codec.encode(
            pad(audio.unsqueeze(0), 2**16).to(device).to(torch.float)
        ).latent_dist.mode().to(torch.float16)
        cpu_encode_time = clock() - start

        start = clock()
        decoded = codec.decode(latents.to(torch.float).to(device)).sample
        decoded[0, :, :n_samples].clamp(-1., 1.)
        cpu_decode_time = clock() - start

    return dict(cpu_encode_time=cpu_encode_time, cpu_decode_time=cpu_decode_time)

In [None]:
device = "cpu"
# Reload the VAE in float32 for the CPU benchmark.
# NOTE: this rebinds the global `codec` (and `device`), so stable_audio_compress would
# now run on this float32 CPU model if re-executed.
codec = AutoencoderOobleck.from_pretrained(
    "stabilityai/stable-audio-open-1.0",
    subfolder='vae',
    torch_dtype=torch.float
)
codec.eval();
codec = codec.to(device)

# Warm-up run, discarded, so one-time initialisation cost is not charged to the first track.
stable_audio_compress_cpu(MUSDB[0])

# Always recompute: a cached map result from an earlier run would return stale timings.
cpu = MUSDB.map(stable_audio_compress_cpu, writer_batch_size=16, load_from_cache_file=False)

# add_column pairs rows positionally; both maps iterate MUSDB in the same order.
assert len(cpu) == len(gpu), "CPU and GPU results must cover the same tracks"
combined = gpu.add_column('cpu_encode_time', cpu['cpu_encode_time'])
combined = combined.add_column('cpu_decode_time', cpu['cpu_decode_time'])

In [None]:
# Per-track result columns to average: GPU timings, bitrate, quality scores, CPU timings.
metrics = (
    'encode_time decode_time bps PSNR SSDR SRDR CDPAM '
    'cpu_encode_time cpu_decode_time'
).split()

In [None]:
# Report the mean of every collected metric over all tracks.
for metric, μ in ((name, np.mean(combined[name])) for name in metrics):
    print(f"{metric}: {μ}")

In [None]:
# Publish the combined per-track results as the 'validation' split of the Hub dataset.
combined.push_to_hub(repo_id="danjacobellis/MUSDB_stable_audio_fp16", split='validation')