Skip to content

Commit

Permalink
Adopt PyTorch's test util to librosa compatibilities test (pytorch#646)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored and bhargavkathivarapu committed May 19, 2020
1 parent 40338a7 commit 1a8478c
Showing 1 changed file with 80 additions and 86 deletions.
166 changes: 80 additions & 86 deletions test/test_librosa_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest

import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA
Expand All @@ -17,15 +18,8 @@
import common_utils


class _LibrosaMixin:
"""Automatically skip tests if librosa is not available"""
def setUp(self):
super().setUp()
if not IMPORT_LIBROSA:
raise unittest.SkipTest('Librosa not available')


class TestFunctional(_LibrosaMixin, unittest.TestCase):
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
class TestFunctional(TestCase):
"""Test suite for functions in `functional` module."""
def test_griffinlim(self):
# NOTE: This test is flaky without a fixed random seed
Expand All @@ -51,7 +45,7 @@ def test_griffinlim(self):
momentum=momentum, init=init, length=length)
lr_out = torch.from_numpy(lr_out).unsqueeze(0)

torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)

def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None):
librosa_fb = librosa.filters.mel(sr=sample_rate,
Expand All @@ -69,8 +63,8 @@ def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fm
norm=norm)

for i_mel_bank in range(n_mels):
torch.testing.assert_allclose(fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]),
atol=1e-4, rtol=1e-5)
self.assertEqual(
fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]), atol=1e-4, rtol=1e-5)

def test_create_fb(self):
self._test_create_fb()
Expand Down Expand Up @@ -101,7 +95,7 @@ def test_amplitude_to_DB(self):
lr_out = librosa.core.power_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out)

torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)

# Amplitude to DB
multiplier = 20.0
Expand All @@ -110,7 +104,7 @@ def test_amplitude_to_DB(self):
lr_out = librosa.core.amplitude_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out)

torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)


@pytest.mark.parametrize('complex_specgrams', [
Expand Down Expand Up @@ -161,73 +155,73 @@ def _load_audio_asset(*asset_paths, **kwargs):
return sound, sample_rate


def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
sound, sample_rate = _load_audio_asset('sinewave.wav')
sound_librosa = sound.cpu().numpy().squeeze() # (64000)

# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(
n_fft=n_fft, hop_length=hop_length, power=power)
out_librosa, _ = librosa.core.spectrum._spectrogram(
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)

out_torch = spect_transform(sound).squeeze().cpu()
torch.testing.assert_allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)

# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
torch.testing.assert_allclose(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)

# test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
torch.testing.assert_allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3, rtol=1e-5)

mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
torch.testing.assert_allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3, rtol=1e-5)

power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
torch.testing.assert_allclose(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)

# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)

# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in
# https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
#
# librosa_mfcc = librosa.feature.mfcc(
# y=sound_librosa, sr=sample_rate, n_mfcc = n_mfcc,
# hop_length=hop_length, n_fft=n_fft, htk=True, norm=None, n_mels=n_mels)

librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu()

torch.testing.assert_allclose(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)


class TestTransforms(_LibrosaMixin, unittest.TestCase):
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
class TestTransforms(TestCase):
"""Test suite for functions in `transforms` module."""
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
sound, sample_rate = _load_audio_asset('sinewave.wav')
sound_librosa = sound.cpu().numpy().squeeze() # (64000)

# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(
n_fft=n_fft, hop_length=hop_length, power=power)
out_librosa, _ = librosa.core.spectrum._spectrogram(
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)

out_torch = spect_transform(sound).squeeze().cpu()
self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)

# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertEqual(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)

# test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertEqual(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3, rtol=1e-5)

mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
self.assertEqual(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3, rtol=1e-5)

power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
self.assertEqual(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)

# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)

# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in
# https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
#
# librosa_mfcc = librosa.feature.mfcc(
# y=sound_librosa, sr=sample_rate, n_mfcc = n_mfcc,
# hop_length=hop_length, n_fft=n_fft, htk=True, norm=None, n_mels=n_mels)

librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu()

self.assertEqual(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)

def test_basics1(self):
kwargs = {
'n_fft': 400,
Expand All @@ -237,7 +231,7 @@ def test_basics1(self):
'n_mfcc': 40,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)

def test_basics2(self):
kwargs = {
Expand All @@ -248,7 +242,7 @@ def test_basics2(self):
'n_mfcc': 20,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)

# NOTE: Test passes offline, but fails on TravisCI (and CircleCI), see #372.
@unittest.skipIf('CI' in os.environ, 'Test is known to fail on CI')
Expand All @@ -261,7 +255,7 @@ def test_basics3(self):
'n_mfcc': 50,
'sample_rate': 24000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)

def test_basics4(self):
kwargs = {
Expand All @@ -272,7 +266,7 @@ def test_basics4(self):
'n_mfcc': 40,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)

@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope("sox")
Expand All @@ -295,7 +289,7 @@ def test_MelScale(self):
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
# Note: Using relaxed rtol instead of atol
torch.testing.assert_allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)
self.assertEqual(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)

def test_InverseMelScale(self):
"""InverseMelScale transform is comparable to that of librosa"""
Expand Down Expand Up @@ -338,7 +332,7 @@ def test_InverseMelScale(self):
# https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# distance over frequencies.
torch.testing.assert_allclose(spec_ta, spec_lr, atol=threshold, rtol=1e-5)
self.assertEqual(spec_ta, spec_lr, atol=threshold, rtol=1e-5)

threshold = 1700.0
# This threshold was choosen empirically, based on the following observations
Expand Down

0 comments on commit 1a8478c

Please sign in to comment.