In [1]:
import io
import time
import torch
import numpy as np
import PIL
import torchaudio
import matplotlib.pyplot as plt
import opuspy
import tempfile
from einops import rearrange
from IPython.display import Audio
from torchvision.transforms import ToPILImage, PILToTensor
from datasets import load_dataset, Image
from spauq.core.metrics import spauq_eval
import cdpam
class Config: pass

In [2]:
cdpam_loss = cdpam.CDPAM()

  state = torch.load(modfolder,map_location="cpu")['state']


In [3]:
MUSDB = load_dataset("danjacobellis/musdb18HQ", split='validation')

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

In [4]:
def pad(audio, p=2**16):
    B,C,L = audio.shape
    padding_size = (p - (L % p)) % p
    if padding_size > 0:
        audio = torch.nn.functional.pad(audio, (0, padding_size), mode='constant', value=0)
    return audio

In [5]:
sample = MUSDB[5]
device="cpu"
max_duration = 44100*180

with torch.no_grad():
    x, fs = torchaudio.load(sample['audio']['bytes'])
    x = x[:,:max_duration]
    L = x.shape[-1]
    x_padded = pad(x.unsqueeze(0), 2**16)[0]
    
    t0 = time.time()
    buff = io.BytesIO()
    torchaudio.save(
        uri=buff,
        src=x_padded,
        sample_rate=fs,
        channels_first=True,
        format='opus',
        encoding='OPUS',
        compression=torchaudio.io.CodecConfig(bit_rate=12000)
    )
    opus_bytes = buff.getbuffer()
    encode_time = time.time() - t0

    with tempfile.NamedTemporaryFile(delete=False, suffix='.opus') as temp_file:
        temp_file.write(opus_bytes)
        temp_file_path = temp_file.name
    
    t0 = time.time()
    x_hat, fs2 = torchaudio.load(temp_file_path)
    x_hat = torchaudio.transforms.Resample(fs2,fs)(x_hat)
    x_hat = x_hat[:,:L]
    decode_time = time.time() - t0

    bps = 8*len(opus_bytes)/(x.numel())
    PSNR = 20*np.log10(2) - 10*np.log10(torch.nn.functional.mse_loss(x,x_hat.to("cpu")))
    SDR = spauq_eval(x,x_hat.to("cpu"),fs=fs)
    SSDR = SDR['SSR']
    SRDR = SDR['SRR']
    cdpam = cdpam_loss.forward(x.to(device),x_hat).mean().item()



In [12]:
Audio(x.numpy()[:,:500000], rate=44100)

In [13]:
Audio(x_hat.numpy()[:,:500000], rate=44100)

In [7]:
max_duration = 44100*180
def walloc_compress(sample):
    with torch.no_grad():
        x, fs = torchaudio.load(sample['audio']['bytes'])
        x = x[:,:max_duration]
        L = x.shape[-1]
        x_padded = pad(x.unsqueeze(0), 2**16).to(device)
        t0 = time.time()
        X = codec.wavelet_analysis(x_padded,codec.J)
        Y = codec.encoder(X)
        ℓ = Y.shape[-1]
        Y = pad(Y,256)
        Y = rearrange(Y, 'b c (w h) -> b c w h', h=256).to("cpu")
        webp = walloc.latent_to_pil(Y,codec.latent_bits,3)[0]
        buff = io.BytesIO()
        webp.save(buff, format='WEBP', lossless=True)
        webp_bytes = buff.getbuffer()
        encode_time = time.time() - t0
    
        t0 = time.time()
        Y = walloc.pil_to_latent([PIL.Image.open(buff)], codec.latent_dim, codec.latent_bits, 3).to(device)
        X_hat = codec.decoder(rearrange(Y.to(device), 'b c h w -> b c (h w)')[:,:,:ℓ])
        x_hat = codec.wavelet_synthesis(X_hat,codec.J)
        x_hat = codec.post(x_hat)
        x_hat = x_hat[0,:,:L].clamp(-1., 1.)
        decode_time = time.time() - t0
    
        bps = 8*len(webp_bytes)/(x.numel())
        PSNR = 20*np.log10(2) - 10*np.log10(torch.nn.functional.mse_loss(x,x_hat.to("cpu")))
        SDR = spauq_eval(x,x_hat.to("cpu"),fs=fs)
        SSDR = SDR['SSR']
        SRDR = SDR['SRR']
        cdpam = cdpam_loss.forward(x.to(device),x_hat).mean().item()
        
    return {
        'compressed': webp_bytes,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'bps': bps,
        'L': L,
        'PSNR': PSNR,
        'SSDR': SSDR,
        'SRDR': SRDR,
        'CDPAM': cdpam
    }

In [8]:
device = "cuda"
codec = codec.to(device)
gpu = MUSDB.map(
    walloc_compress,
    writer_batch_size=16,
)
gpu = gpu.cast_column('compressed',Image())

Map:   0%|          | 0/250 [00:00<?, ? examples/s]



In [9]:
def walloc_compress_cpu(sample):
    with torch.no_grad():
        x, fs = torchaudio.load(sample['audio']['bytes'])
        x = x[:,:max_duration]
        L = x.shape[-1]
        x_padded = pad(x.unsqueeze(0), 2**16).to(device)
        t0 = time.time()
        X = codec.wavelet_analysis(x_padded,codec.J)
        Y = codec.encoder(X)
        ℓ = Y.shape[-1]
        Y = pad(Y,256)
        Y = rearrange(Y, 'b c (w h) -> b c w h', h=256).to("cpu")
        webp = walloc.latent_to_pil(Y,codec.latent_bits,3)[0]
        buff = io.BytesIO()
        webp.save(buff, format='WEBP', lossless=True)
        webp_bytes = buff.getbuffer()
        encode_time = time.time() - t0
    
        t0 = time.time()
        Y = walloc.pil_to_latent([PIL.Image.open(buff)], codec.latent_dim, codec.latent_bits, 3).to(device)
        X_hat = codec.decoder(rearrange(Y.to(device), 'b c h w -> b c (h w)')[:,:,:ℓ])
        x_hat = codec.wavelet_synthesis(X_hat,codec.J)
        x_hat = codec.post(x_hat)
        x_hat = x_hat[0,:,:L].clamp(-1., 1.)
        decode_time = time.time() - t0
        
    return {
        'cpu_encode_time': encode_time,
        'cpu_decode_time': decode_time,
    }

In [10]:
device = "cpu"
codec = codec.to(device)
cpu = MUSDB.map(walloc_compress_cpu, writer_batch_size=16)
combined = gpu.add_column('cpu_encode_time',cpu['cpu_encode_time'])
combined = combined.add_column('cpu_decode_time',cpu['cpu_decode_time'])

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

In [11]:
metrics = [
    'encode_time',
    'decode_time',
    'bps',
    'PSNR',
    'SSDR',
    'SRDR',
    'CDPAM',
    'cpu_encode_time',
    'cpu_decode_time',
]

In [12]:
for metric in metrics:
    μ = np.mean(combined[metric])
    print(f"{metric}: {μ}")

encode_time: 0.23302008724212647
decode_time: 0.02391925811767578
bps: 0.8884749047115007
PSNR: 45.11973628997803
SSDR: 32.92267796734984
SRDR: 12.69360004136512
CDPAM: 6.876126027145802e-05
cpu_encode_time: 0.5509592161178589
cpu_decode_time: 2.262235514640808


In [13]:
combined.push_to_hub("danjacobellis/MUSDB_Stereo_Li_192c_J8_nf8",split='validation')

Uploading the dataset shards:   0%|          | 0/23 [00:00<?, ?it/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/danjacobellis/MUSDB_Stereo_Li_192c_J8_nf8/commit/ea703312d8a323dba3968da1dce64687131df8ba', commit_message='Upload dataset', commit_description='', oid='ea703312d8a323dba3968da1dce64687131df8ba', pr_url=None, pr_revision=None, pr_num=None)

---

In [11]:
# L = 1048064
# sample = MUSDB[5]
# x, fs = torchaudio.load(sample['audio']['bytes'])
# x = x[:,:L]
# codec = codec.to("cuda")
# x = x.to("cuda").unsqueeze(0)
# with torch.no_grad():
#     X = codec.wavelet_analysis(x,codec.J)
#     Y = codec.encoder(X)
#     Z = codec.decoder(Y)
#     z = codec.wavelet_synthesis(Z,codec.J)
#     z = codec.post(z)
# x = x.to("cpu")[0]
# z = z.to("cpu")[0]

In [12]:
# Audio(x.numpy(),rate=44100)

In [13]:
# Audio(z.numpy(),rate=44100)

In [14]:
# start, end = 56500, 57000
# plt.figure(figsize=(8, 4), dpi=180)
# plt.plot(x[0, start:end], alpha=0.5, c='b', label='Ch.1 (Uncompressed)')
# plt.plot(z[0, start:end], alpha=0.5, c='g', label='Ch.1 (WaLLoC)')
# plt.plot(x[1, start:end], alpha=0.5, c='r', label='Ch.2 (Uncompressed)')
# plt.plot(z[1, start:end], alpha=0.5, c='purple', label='Ch.2 (WaLLoC)')

# plt.ylim([-1,0.6])
# plt.legend(loc='lower center')
# plt.box(False)
# plt.xticks([])
# plt.yticks([])
# # plt.savefig("test.svg")