We explore the possibility of maximizing the information represented in spectrograms by making the spectrogram basis functions trainable.
A number of experiments are conducted in which we compare the performance of trainable short-time Fourier transform (STFT) and Mel basis functions provided by FastAudio and nnAudio on two tasks: keyword spotting (KWS) and automatic speech recognition (ASR).
A broadcasting-residual network (BC-ResNet) as well as a Simple model (built from a single linear layer) are used for these two tasks.
In our experiments, we explore four different training settings:
- A: both gMel and gSTFT are non-trainable.
- B: gMel is trainable while gSTFT is fixed.
- C: gMel is fixed while gSTFT is trainable.
- D: both gMel and gSTFT are trainable.
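The four settings above boil down to two boolean flags. The dictionary and helper below are an illustrative sketch of this mapping (the names mirror the nnAudio-style `trainable_mel`/`trainable_STFT` arguments used later in this README; they are not repo code):

```python
# Illustrative mapping of the four training settings onto the two
# boolean flags (trainable_mel, trainable_STFT).
SETTINGS = {
    "A": {"trainable_mel": False, "trainable_STFT": False},
    "B": {"trainable_mel": True,  "trainable_STFT": False},
    "C": {"trainable_mel": False, "trainable_STFT": True},
    "D": {"trainable_mel": True,  "trainable_STFT": True},
}

def describe(setting):
    """Return a short human-readable summary of a setting."""
    flags = SETTINGS[setting]
    mel = "trainable" if flags["trainable_mel"] else "fixed"
    stft = "trainable" if flags["trainable_STFT"] else "fixed"
    return f"gMel {mel}, gSTFT {stft}"

print(describe("B"))  # gMel trainable, gSTFT fixed
```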
trainable-STFT-Mel
├── conf
│   ├── model
│   │   ├── BC_ResNet.yaml
│   │   ├── BC_ResNet_ASR.yaml
│   │   ├── BC_ResNet_maskout.yaml
│   │   ├── Linearmodel.yaml
│   │   ├── Linearmodel_ASR.yaml
│   │   └── Linearmodel_maskout.yaml
│   ├── ASR_config.yaml
│   └── KWS_config.yaml
├── models
│   ├── nnAudio_model.py
│   └── fastaudio_model.py
├── tasks
│   ├── speechcommand.py
│   ├── speechcommand_maskout.py
│   ├── Timit.py
│   └── Timit_maskout.py
├── train_KWS_hydra.py
├── train_ASR_hydra.py
├── phonemics_dict
└── requirements.txt
- `conf` contains the `.yaml` configuration files.
- `models` contains the model architectures.
- `tasks` contains the Lightning modules for KWS and ASR.
- `train_KWS_hydra.py` and `train_ASR_hydra.py` are the training scripts for KWS and ASR respectively.
- `phonemics_dict` contains the phoneme labels provided in TIMIT, which are used for phoneme recognition.
Python 3.8.10 is required to run this repo.
You can install all required libraries at once via
pip install -r requirements.txt
python train_KWS_hydra.py
python train_ASR_hydra.py
Note:
- If this is your first time training the model, you need to set `download=True` via
python train_KWS_hydra.py download=True
- If you use the CPU instead of a GPU to train the model, set `gpus` to 0 via
python train_KWS_hydra.py gpus=0
Defaults:
- nnAudio BC_ResNet model: `model=BC_ResNet`
- setting A (both gMel and gSTFT are non-trainable): `model.spec_args.trainable_mel=False model.spec_args.trainable_STFT=False`
- 40 Mel bases: `model.spec_args.n_mels=40`
- 1 GPU
python train_KWS_hydra.py -m gpus=<arg> model=<arg> model.spec_args.trainable_mel=True,False model.spec_args.trainable_STFT=True,False
python train_KWS_hydra.py -m gpus=<arg> model=<arg> model.fastaudio.freeze=True,False model.spec_args.trainable=True,False
`model.fastaudio.freeze` controls the Mel basis functions:
- `model.fastaudio.freeze=True`: Mel non-trainable
- `model.fastaudio.freeze=False`: Mel trainable

`model.spec_args.trainable` controls the STFT:
- `model.spec_args.trainable=True`: STFT trainable
- `model.spec_args.trainable=False`: STFT non-trainable
Note: simply replace `train_KWS_hydra.py` with `train_ASR_hydra.py` for the ASR task.
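Hydra's `-m` (multirun) flag expands each comma-separated argument list into a Cartesian product of runs, so the 2 × 2 sweep over the two nnAudio flags launches four jobs, one per setting A–D. A stdlib sketch of that expansion (illustrative only; Hydra itself performs this internally):

```python
from itertools import product

# Expand the comma-separated flag values the way a Hydra multirun
# grid sweep does: one run per combination.
runs = [
    f"model.spec_args.trainable_mel={mel} model.spec_args.trainable_STFT={stft}"
    for mel, stft in product([True, False], repeat=2)
]
for r in runs:
    print(r)
print(len(runs))  # 4 runs, matching settings A-D
```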
python train_KWS_hydra.py -m gpus=<arg> model=<arg> model.spec_args.n_mels=10,20,30,40
python train_KWS_hydra.py -m gpus=<arg> model=<arg> model.fastaudio.n_mels=10,20,30,40
Note: simply replace `train_KWS_hydra.py` with `train_ASR_hydra.py` for the ASR task.
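Varying `n_mels` changes how finely the frequency axis is partitioned before any training happens. As a minimal stdlib sketch, here is how HTK-style Mel band edges are laid out for different `n_mels`, assuming a 16 kHz sample rate (8 kHz Nyquist); this is only an illustration of the initial Mel spacing, not the repo's filterbank code:

```python
import math

def hz_to_mel(f):
    # HTK-style Hz-to-Mel conversion.
    return 2595.0 * math.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    # Inverse of hz_to_mel.
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_band_edges(n_mels, fmin=0.0, fmax=8000.0):
    """Band-edge frequencies (Hz) for n_mels triangular Mel filters:
    n_mels + 2 points spaced uniformly on the Mel scale."""
    lo, hi = hz_to_mel(fmin), hz_to_mel(fmax)
    step = (hi - lo) / (n_mels + 1)
    return [mel_to_hz(lo + i * step) for i in range(n_mels + 2)]

# Fewer bands -> coarser frequency resolution:
for n in (10, 20, 30, 40):
    edges = mel_band_edges(n)
    print(n, round(edges[1], 1))  # first band edge above fmin
```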
python train_KWS_hydra.py gpus=<arg> model=<arg> model.maskout_start=<arg> model.maskout_end=<arg>
Applicable models:
- KWS nnAudio BC_ResNet
- KWS nnAudio Simple
- ASR nnAudio Simple
Note: simply replace `train_KWS_hydra.py` with `train_ASR_hydra.py` for the ASR task.
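A rough sketch of what a mask-out over a frequency range could look like, assuming `model.maskout_start` and `model.maskout_end` delimit a half-open range of frequency-bin indices to zero out (the actual semantics live in the repo's `*_maskout` configs and Lightning modules):

```python
def mask_out(spec_rows, start, end):
    """Zero out frequency rows with index in [start, end).
    Illustrative only -- stands in for the repo's maskout behaviour."""
    return [
        [0.0] * len(row) if start <= i < end else list(row)
        for i, row in enumerate(spec_rows)
    ]

# Toy spectrogram: 4 frequency bins x 2 time frames.
spec = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
masked = mask_out(spec, 1, 3)
print(masked)  # rows 1 and 2 zeroed
```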
python train_KWS_hydra.py gpus=<arg> model=<arg> model.random_mel=True
Applicable models:
- KWS nnAudio BC_ResNet
- ASR nnAudio BC_ResNet
- KWS nnAudio Simple
- ASR nnAudio Simple
Note: simply replace `train_KWS_hydra.py` with `train_ASR_hydra.py` for the ASR task.
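A stdlib sketch of the idea behind `model.random_mel=True`, assuming it replaces the triangular Mel filters with randomly initialised weights before training (the function, its parameters, and the bin counts below are hypothetical, not the repo's implementation):

```python
import random

def random_mel_filterbank(n_mels, n_fft_bins, seed=0):
    """Illustrative only: an (n_mels x n_fft_bins) randomly initialised
    Mel-basis matrix, standing in for the usual triangular filters."""
    rng = random.Random(seed)  # seeded for reproducibility
    return [[rng.uniform(0.0, 1.0) for _ in range(n_fft_bins)]
            for _ in range(n_mels)]

# e.g. 40 Mel bases over 201 FFT bins (hypothetical n_fft = 400):
fb = random_mel_filterbank(40, 201)
print(len(fb), len(fb[0]))  # 40 201
```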