Skip to content

Commit

Permalink
Use torchaudio melscale 'slaney' instead of librosa in WaveRNN pipeli…
Browse files Browse the repository at this point in the history
…ne preprocessing (pytorch#1444)

* Use torchaudio melscale instead of librosa
  • Loading branch information
discort authored and Caroline Chen committed Apr 30, 2021
1 parent fcfa07a commit 66ff2e6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 31 deletions.
10 changes: 5 additions & 5 deletions examples/pipeline_wavernn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import LinearToMel, NormalizeDB
from processing import NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint


Expand Down Expand Up @@ -269,12 +269,12 @@ def main(args):
}

transforms = torch.nn.Sequential(
torchaudio.transforms.Spectrogram(**melkwargs),
LinearToMel(
torchaudio.transforms.MelSpectrogram(
sample_rate=args.sample_rate,
n_fft=args.n_fft,
n_mels=args.n_freq,
fmin=args.f_min,
f_min=args.f_min,
mel_scale='slaney',
**melkwargs,
),
NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
)
Expand Down
27 changes: 1 addition & 26 deletions examples/pipeline_wavernn/processing.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,7 @@
import librosa
import torch
import torch.nn as nn


# TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved
class LinearToMel(nn.Module):
def __init__(self, sample_rate, n_fft, n_mels, fmin, htk=False, norm="slaney"):
super().__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.n_mels = n_mels
self.fmin = fmin
self.htk = htk
self.norm = norm

def forward(self, specgram):
specgram = librosa.feature.melspectrogram(
S=specgram.squeeze(0).numpy(),
sr=self.sample_rate,
n_fft=self.n_fft,
n_mels=self.n_mels,
fmin=self.fmin,
htk=self.htk,
norm=self.norm,
)
return torch.from_numpy(specgram)


class NormalizeDB(nn.Module):
r"""Normalize the spectrogram with a minimum db value
"""
Expand All @@ -37,7 +12,7 @@ def __init__(self, min_level_db, normalization):
self.normalization = normalization

def forward(self, specgram):
specgram = torch.log10(torch.clamp(specgram, min=1e-5))
specgram = torch.log10(torch.clamp(specgram.squeeze(0), min=1e-5))
if self.normalization:
return torch.clamp(
(self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1
Expand Down

0 comments on commit 66ff2e6

Please sign in to comment.