## Download from GitHub

In [1]:
!git clone https://github.com/jerryuhoo/Perceptual-Loss.git
%cd Perceptual-Loss

Cloning into 'Perceptual-Loss'...
remote: Enumerating objects: 29, done.[K
remote: Counting objects: 100% (29/29), done.[K
remote: Compressing objects: 100% (21/21), done.[K
remote: Total 29 (delta 12), reused 20 (delta 6), pack-reused 0[K
Receiving objects: 100% (29/29), 818.28 KiB | 10.91 MiB/s, done.
Resolving deltas: 100% (12/12), done.
/content/Perceptual-Loss


## Display Audio

In [None]:
from IPython.display import Audio

# Inline player for the clip named `v_gt` (presumably the ground-truth reference).
Audio(data='test_wavs/v_gt.wav')

In [3]:
# Inline player for the clip named `v_bad`, to compare by ear with `v_gt`.
Audio(data='test_wavs/v_bad.wav')

In [4]:
# Inline player for the clip named `v_good`, to compare by ear with `v_gt`.
Audio(data='test_wavs/v_good.wav')

## Load audio

In [5]:
import torchaudio
import torch
from psycho_acoustic_loss import (
    psycho_acoustic_loss,
    compute_STFT,
)

In [6]:
# Load the three clips. torchaudio.load returns (waveform, sample_rate), with the
# waveform laid out as [channels, frames] by default, so [0] keeps the first channel.
waveform, sr_mp3 = torchaudio.load("audio_mp3_align.wav")
audio_mp3_align = waveform[0]

waveform, sample_rate = torchaudio.load("audio_original.wav")
audio_original = waveform[0]

waveform, sr_quantized = torchaudio.load("audio_quantized.wav")
audio_quantized = waveform[0]

# Every loss below is computed with this single `sample_rate`; fail fast if the
# files disagree instead of silently using whichever file was loaded last.
assert sr_mp3 == sample_rate == sr_quantized, (
    f"sample rates differ: mp3={sr_mp3}, original={sample_rate}, quantized={sr_quantized}"
)



## MSE loss in time domain

In [7]:
from sklearn.metrics import mean_squared_error

# Compare over the original clip's length. The aligned MP3 is cut to that many
# samples (derived, not hard-coded), mirroring how its STFT frames are trimmed later.
n_samples = audio_original.shape[-1]

mse_quant = mean_squared_error(audio_original, audio_quantized)
mse_mp3 = mean_squared_error(audio_original, audio_mp3_align[:n_samples])

print("mse_quant=", mse_quant)
print("mse_mp3=", mse_mp3)
print("time domain mse loss ratio: mp3/quant", mse_mp3 / mse_quant)

mse_quant= 0.013066127
mse_mp3= 0.00034247426
time domain mse loss ratio: mp3/quant 0.026210848


In [8]:
# load audios
ys_mp3_align = compute_STFT(audio_mp3_align, N=1024).unsqueeze(0).unsqueeze(0)
ys_original = compute_STFT(audio_original, N=1024).unsqueeze(0).unsqueeze(0)
ys_quantized = compute_STFT(audio_quantized, N=1024).unsqueeze(0).unsqueeze(0)
# shape: [batch size, channels, N+1, frame]

ys_mp3_align = ys_mp3_align[:, :, :, :ys_original.shape[-1]]

## MSE loss in frequency domain

In [9]:
# Same metric as the time-domain cell, applied to the STFT tensors after
# squeezing away their singleton batch/channel axes.
ref_spec = ys_original.squeeze()
freq_mse_quant = mean_squared_error(ref_spec, ys_quantized.squeeze())
freq_mse_mp3 = mean_squared_error(ref_spec, ys_mp3_align.squeeze())

for label, value in (("mse_quant=", freq_mse_quant), ("mse_mp3=", freq_mse_mp3)):
    print(label, value)

print("frequency domain mse loss ratio: mp3/quant", freq_mse_mp3 / freq_mse_quant)

mse_quant= 29656.951
mse_mp3= 681.46826
frequency domain mse loss ratio: mp3/quant 0.022978365


## Psycho-acoustic loss with masking-threshold weighting, without LTQ (threshold in quiet)

In [10]:
# Single-file example with default settings: masking-threshold weighting on, LTQ off.
mp3_ploss, quant_ploss = (
    psycho_acoustic_loss(ys_pred_single, ys_original, fs=sample_rate)
    for ys_pred_single in (ys_mp3_align, ys_quantized)
)
print("loss: mp3, original", mp3_ploss.item())
print("loss: quantized, original", quant_ploss.item())

print("PL mse loss ratio: mp3/quant", mp3_ploss / quant_ploss)

loss: mp3, original 0.26612234115600586
loss: quantized, original 106.79635620117188
PL mse loss ratio: mp3/quant tensor(0.0025)


## Psycho-acoustic loss without weighting, without LTQ

In [11]:
# Single-file example with weighting disabled: only the masking-threshold (mt)
# difference between prediction and reference is computed. LTQ stays off.
unweighted_kwargs = dict(fs=sample_rate, use_weighting=False)

mp3_ploss = psycho_acoustic_loss(ys_mp3_align, ys_original, **unweighted_kwargs)
print("loss: mp3, original", mp3_ploss.item())

quant_ploss = psycho_acoustic_loss(ys_quantized, ys_original, **unweighted_kwargs)
print("loss: quantized, original", quant_ploss.item())
print("PL mse loss ratio: mp3/quant", mp3_ploss / quant_ploss)

loss: mp3, original 330.6063537597656
loss: quantized, original 102924.015625
PL mse loss ratio: mp3/quant tensor(0.0032)


## Psycho-acoustic loss with weighting, with LTQ

In [12]:
# Single-file example WITH masking-threshold weighting and WITH the LTQ option enabled.
ltq_weighted_kwargs = dict(fs=sample_rate, use_weighting=True, use_LTQ=True)

mp3_ploss = psycho_acoustic_loss(ys_mp3_align, ys_original, **ltq_weighted_kwargs)
print("loss: mp3, original", mp3_ploss.item())

quant_ploss = psycho_acoustic_loss(ys_quantized, ys_original, **ltq_weighted_kwargs)
print("loss: quantized, original", quant_ploss.item())
print("PL mse loss ratio: mp3/quant", mp3_ploss / quant_ploss)

loss: mp3, original 0.15484167635440826
loss: quantized, original 28.591999053955078
PL mse loss ratio: mp3/quant tensor(0.0054)


## Psycho-acoustic loss without weighting, with LTQ

In [13]:
# Single-file example without weighting (only the masking-threshold (mt)
# difference is computed), with the LTQ option enabled.
mp3_ploss, quant_ploss = (
    psycho_acoustic_loss(
        ys_pred_single, ys_original, fs=sample_rate, use_weighting=False, use_LTQ=True
    )
    for ys_pred_single in (ys_mp3_align, ys_quantized)
)
print("loss: mp3, original", mp3_ploss.item())
print("loss: quantized, original", quant_ploss.item())
print("PL mse loss ratio: mp3/quant", mp3_ploss / quant_ploss)

loss: mp3, original 197.8645477294922
loss: quantized, original 94812.5546875
PL mse loss ratio: mp3/quant tensor(0.0021)


In [14]:
# Multi-batch example: score several predictions against the reference in one call.

# Stack the predictions along the batch dimension -> [n_preds, channels, freq bins, frames].
ys_pred = torch.cat([ys_mp3_align, ys_quantized], dim=0)

# Repeat the reference once per prediction, so the batch sizes match even if
# predictions are added to or removed from the list above.
ys_original_batch = ys_original.repeat(ys_pred.shape[0], 1, 1, 1)

# Compute the loss for all batch entries at once (default settings: weighting on, LTQ off).
# With these inputs the saved output (53.53) equals the mean of the two per-file
# default-setting losses (0.266 and 106.796), so the batch is presumably averaged.
loss = psycho_acoustic_loss(ys_pred, ys_original_batch, fs=sample_rate)
print("loss", loss.item())

loss 53.53124237060547
