In [14]:
import encodec
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from evaluate import evaluator
from log_wmse_audio_quality import calculate_log_wmse
from transformers import pipeline

from utils import mp3_compress, opus_compress, encodec_compress
from utils import hf_audio_encode

In [2]:
def _compress_and_encode(sample, compress_fn):
    """Run one codec over a dataset sample and store the result back in it.

    Shared implementation for the per-codec wrappers below, which previously
    each carried an identical copy of this body.

    Args:
        sample: A dataset row whose ``sample['audio']`` provides ``'array'``
            (a tensor, since ``unsqueeze`` is called on it) and
            ``'sampling_rate'``.
        compress_fn: Callable ``(audio, fs) -> (audio, bps)`` performing the
            codec round trip.

    Returns:
        The same ``sample`` object, modified in place: ``'audio'`` is replaced
        by the output of ``hf_audio_encode`` and ``'bps'`` holds the second
        value returned by ``compress_fn``.
    """
    fs = sample['audio']['sampling_rate']
    # The codec helpers are called with a leading size-1 dimension added
    # (presumably a channel/batch axis -- confirm against utils).
    audio = sample['audio']['array'].unsqueeze(0)
    audio, bps = compress_fn(audio, fs)
    sample['audio'] = hf_audio_encode(audio, fs)
    sample['bps'] = bps
    return sample


def mp3_compress_cv(sample):
    """Compress one sample with ``mp3_compress``; see ``_compress_and_encode``."""
    return _compress_and_encode(sample, mp3_compress)


def opus_compress_cv(sample):
    """Compress one sample with ``opus_compress``; see ``_compress_and_encode``."""
    return _compress_and_encode(sample, opus_compress)


# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs end to end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
encodec_model_48_3 = encodec.EncodecModel.encodec_model_48khz()
# NOTE(review): the target bandwidth is set to 6, while the `_3` suffix in the
# model/function names suggests 3 -- confirm which one is intended. The names
# are kept unchanged because other cells reference them.
encodec_model_48_3.set_target_bandwidth(6)
encodec_model_48_3.to(device)


def _encodec_48_3_codec(audio, fs):
    """Adapt ``encodec_compress`` to the ``(audio, fs) -> (audio, bps)`` shape."""
    return encodec_compress(audio, fs, encodec_model_48_3, device)


def encodec_48_3_compress(sample):
    """Compress one sample with the 48 kHz EnCodec model configured above."""
    return _compress_and_encode(sample, _encodec_48_3_codec)

In [3]:
# Codecs under comparison. The order matters: later cells rely on the i-th
# codec here corresponding to common_voice[i + 1].
audio_compression_methods = [mp3_compress_cv, opus_compress_cv, encodec_48_3_compress]

In [4]:
# First 802 rows of the English Common Voice 11.0 validation split, returned
# as torch tensors (the compression helpers call tensor methods on the audio).
common_voice = load_dataset(
    path="mozilla-foundation/common_voice_11_0",
    name="en",
    split="validation[:802]",
).with_format("torch")

In [5]:
# Row indices dropped from the evaluation set.
# NOTE(review): the reason these two rows are excluded is not recorded here.
exclude_idx = [362, 711]

# Build the list of kept indices once (ascending, same order as before) and
# hand `select` a concrete list instead of a generator that rebuilt
# `set(exclude_idx)` for every index.
#
# After this cell `common_voice` is a one-element list holding the filtered
# reference dataset; compressed variants are added after it by a later cell.
# Because the name is rebound from a Dataset to a list, this cell must be
# preceded by a fresh load of the dataset when re-run.
common_voice = [
    common_voice.select(
        sorted(set(range(len(common_voice))) - set(exclude_idx))
    )
]



In [None]:
# common_voice[0] is the uncompressed reference; slots 1.. hold one compressed
# copy per codec, in the order of `audio_compression_methods`.
# Slice assignment (instead of repeated `append`) makes the cell safe to
# re-run: existing compressed copies are replaced rather than duplicated,
# which would otherwise break the positional alignment later cells rely on.
common_voice[1:] = [
    common_voice[0].map(method) for method in audio_compression_methods
]

In [23]:
# log_wmse[k][i] is the score of sample i of common_voice[k], measured against
# sample i of the reference dataset common_voice[0]. Entry k == 0 therefore
# compares the reference with itself.
reference_ds = common_voice[0]
log_wmse = []
for candidate_ds in common_voice:
    scores = []
    log_wmse.append(scores)
    for idx, ref_sample in enumerate(reference_ds):
        candidate_sample = candidate_ds[idx]
        rate = ref_sample['audio']['sampling_rate']
        ref_audio = ref_sample['audio']['array'].numpy()
        candidate_audio = candidate_sample['audio']['array'].numpy()
        # The reference is passed as both of the first two arguments and the
        # candidate as the third, exactly as in the original call.
        scores.append(
            calculate_log_wmse(ref_audio, ref_audio, candidate_audio, rate)
        )

In [38]:
# Mean score per compressed dataset; index 0 (the reference) is skipped.
avg_quality = [np.mean(scores) for scores in log_wmse[1:]]

In [42]:
# Mean log-WMSE per compressed dataset, in the order of common_voice[1:].
avg_quality

[11.787493639651426, 4.458682799052945, 6.812554373030236]

In [35]:
# Mean of the `bps` column for each compressed dataset (reference skipped).
audio_bps = [
    compressed_ds['bps'].mean().item() for compressed_ds in common_voice[1:]
]

In [43]:
# Mean `bps` per compressed dataset, in the order of common_voice[1:].
audio_bps

[0.6700001358985901, 0.14700429141521454, 0.12622013688087463]