In [1]:
import random
import torch

from datasets import load_dataset
from IPython.display import Audio as AudioDisplay
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality as pesq
from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility as stoi
from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score as dnsmos
from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment as nisqa
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio as si_snr
from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio as si_sdr

from codec_latent_denoiser import CodecLatentDenoiser, CodecLatentDenoiserProcessor

# Load Model and Dataset

In [2]:
# Notebook-wide configuration: audio sample rate (Hz) plus the identifiers
# passed to `from_pretrained` (model) and `load_dataset` (data).
sampling_rate = 16000
data_path = "JacobLinCool/VoiceBank-DEMAND-16k"
model_path = "gokulkarthik/codec-latent-denoiser-default"

In [3]:
# Restore the processor and the denoiser from the same checkpoint, then put
# the model in evaluation mode since this notebook only runs inference.
processor = CodecLatentDenoiserProcessor.from_pretrained(model_path)
model = CodecLatentDenoiser.from_pretrained(model_path)
model = model.eval()

In [4]:
# Load through the configured `data_path` instead of repeating the literal, so
# changing the dataset in the config cell actually takes effect here.
ds = load_dataset(data_path, num_proc=32)
ds

DatasetDict({
    train: Dataset({
        features: ['id', 'clean', 'noisy'],
        num_rows: 11572
    })
    test: Dataset({
        features: ['id', 'clean', 'noisy'],
        num_rows: 824
    })
})

In [None]:
def compute_score(preds: torch.Tensor, target: torch.Tensor, sampling_rate: int = 16000) -> dict:
    """Score `preds` with reference-based and reference-free speech metrics.

    Args:
        preds: Waveform being evaluated.
        target: Reference waveform used by SI-SNR, SI-SDR, STOI and PESQ.
        sampling_rate: Sample rate forwarded to the metrics as `fs`.

    Returns:
        Dict of rounded scores keyed by metric name. `si_snr`, `si_sdr` and
        `stoi` are rounded to 2 decimals; `pesq`, `dnsmos` and `nisqa` to 1.
        The `pesq` key is only present when `sampling_rate` is 8000 or 16000.
    """
    # Reference-based metrics that are always computed.
    metrics = {
        "si_snr": round(si_snr(preds=preds, target=target).item(), 2),
        "si_sdr": round(si_sdr(preds=preds, target=target).item(), 2),
        "stoi": round(stoi(preds=preds, target=target, fs=sampling_rate).item(), 2),
    }

    # PESQ is only attempted at 16 kHz (wide-band) or 8 kHz (narrow-band).
    if sampling_rate == 16000:
        pesq_mode = "wb"
    elif sampling_rate == 8000:
        pesq_mode = "nb"
    else:
        pesq_mode = None
    if pesq_mode is not None:
        pesq_value = pesq(preds=preds, target=target, fs=sampling_rate, mode=pesq_mode).item()
        metrics["pesq"] = round(pesq_value, 1)

    # Reference-free metrics: keep the last DNSMOS entry and the first NISQA entry
    # (presumably the overall MOS in both cases -- confirm against torchmetrics docs).
    dnsmos_value = dnsmos(preds=preds, fs=sampling_rate, personalized=False)[-1].item()
    metrics["dnsmos"] = round(dnsmos_value, 1)

    nisqa_value = nisqa(preds=preds, fs=sampling_rate)[0].item()
    metrics["nisqa"] = round(nisqa_value, 1)

    return metrics

# Test sample

In [6]:
# Pick a random training example. `randrange` excludes the upper bound, whereas
# `randint(0, len(...))` could return `len(...)` itself and raise an IndexError.
sample_idx = random.randrange(len(ds['train']))
sample = ds['train'][sample_idx]
# Wrap the clean/noisy waveform arrays as tensors for the metrics and the model.
clean = torch.from_numpy(sample['clean']['array'])
noisy = torch.from_numpy(sample['noisy']['array'])
print(clean.shape, noisy.shape)
sample

torch.Size([37986]) torch.Size([37986])


{'id': 'p276_434',
 'clean': {'path': 'p276_434.wav',
  'array': array([0.00140381, 0.00204468, 0.00119019, ..., 0.00854492, 0.00872803,
         0.00787354], shape=(37986,)),
  'sampling_rate': 16000},
 'noisy': {'path': 'p276_434.wav',
  'array': array([-0.03491211, -0.03805542,  0.00653076, ...,  0.04101562,
          0.03640747,  0.03549194], shape=(37986,)),
  'sampling_rate': 16000}}

In [7]:
# Playback widget for the clean reference recording.
AudioDisplay(data=clean, rate=sampling_rate)

In [8]:
# Playback widget for the noisy recording of the same utterance.
AudioDisplay(data=noisy, rate=sampling_rate)

In [9]:
# Baseline: score the unprocessed noisy input against the clean reference.
# Pass the configured sample rate instead of relying on the function's default.
print(compute_score(preds=noisy, target=clean, sampling_rate=sampling_rate))

{'si_snr': 4.53, 'si_sdr': 4.53, 'stoi': 0.76, 'pesq': 1.1, 'dnsmos': 2.1, 'nisqa': 2.0}


In [10]:
# Sanity checks, all using the configured sample rate rather than the default.
# A signal scored against itself:
print(compute_score(preds=noisy, target=noisy, sampling_rate=sampling_rate))
print(compute_score(preds=clean, target=clean, sampling_rate=sampling_rate))
# Roles swapped: the clean signal scored against the noisy one.
print(compute_score(preds=clean, target=noisy, sampling_rate=sampling_rate))

{'si_snr': 182.12, 'si_sdr': 182.12, 'stoi': 1.0, 'pesq': 4.6, 'dnsmos': 2.1, 'nisqa': 2.0}


{'si_snr': 180.86, 'si_sdr': 180.86, 'stoi': 1.0, 'pesq': 4.6, 'dnsmos': 3.1, 'nisqa': 4.5}
{'si_snr': 4.53, 'si_sdr': 4.53, 'stoi': 0.68, 'pesq': 1.1, 'dnsmos': 3.1, 'nisqa': 4.5}


# Codec Latent Denoiser

In [11]:
# Run the noisy waveform through the model twice on the same processed input:
# once with the denoiser disabled and once with it enabled, decoding both times.
with torch.inference_mode():
    inputs = processor(noisy)["input_values"]
    outputs = model(inputs, denoise=False, decode=True)
    outputs_denoised = model(inputs, denoise=True, decode=True)

# Shape report; reading shapes does not need the inference-mode context.
print("noisy:", noisy.shape)
print("inputs:", inputs.shape)
print("audio_embeddings:", outputs_denoised.audio_embeddings.shape)
print("audio_generated:", outputs_denoised.audio_generated.shape)

noisy: torch.Size([37986])
inputs: torch.Size([1, 1, 38080])
audio_embeddings: torch.Size([1, 1024, 119])
audio_generated: torch.Size([1, 1, 38080])


In [12]:
# The decoded audio and the original waveform can differ in length, so copy
# only the overlapping prefix into zero-filled buffers shaped like `noisy`.
T_min = min(outputs_denoised.audio_generated.shape[-1], noisy.shape[-1])

noisy_generated = torch.zeros_like(noisy)
noisy_generated[:T_min] = outputs.audio_generated[0, 0, :T_min]

noisy_denoised_generated = torch.zeros_like(noisy)
noisy_denoised_generated[:T_min] = outputs_denoised.audio_generated[0, 0, :T_min]

In [13]:
# Clean reference, repeated here for side-by-side listening.
AudioDisplay(data=clean, rate=sampling_rate)

In [14]:
# Noisy input that was fed to the model.
AudioDisplay(data=noisy, rate=sampling_rate)

In [15]:
# Model output decoded with `denoise=False`.
AudioDisplay(data=noisy_generated, rate=sampling_rate)

In [16]:
# Model output decoded with `denoise=True`.
AudioDisplay(data=noisy_denoised_generated, rate=sampling_rate)

In [17]:
# Score each signal against the clean reference, in this order: the raw noisy
# input, the decode with denoise=False, and the decode with denoise=True.
# The configured sample rate is passed explicitly instead of using the default.
print(compute_score(preds=noisy, target=clean, sampling_rate=sampling_rate))
print(compute_score(preds=noisy_generated, target=clean, sampling_rate=sampling_rate))
print(compute_score(preds=noisy_denoised_generated, target=clean, sampling_rate=sampling_rate))

{'si_snr': 4.53, 'si_sdr': 4.53, 'stoi': 0.76, 'pesq': 1.1, 'dnsmos': 2.1, 'nisqa': 2.0}
{'si_snr': -5.78, 'si_sdr': -5.78, 'stoi': 0.74, 'pesq': 1.1, 'dnsmos': 2.3, 'nisqa': 1.8}
{'si_snr': -4.07, 'si_sdr': -4.17, 'stoi': 0.66, 'pesq': 1.2, 'dnsmos': 2.0, 'nisqa': 1.1}
