Add diffusion-base SE model to ESPnet-SE #5572

Merged: 64 commits merged on Jan 17, 2024

LiChenda
Contributor

What?

  • Implement DCUNET in "Welker S, Richter J, Gerkmann T. Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain".
  • New Python files:
├── decoder
├── encoder
├── separator
├── layers
│   └── dcunet.py
├── diffusion
│   ├── __init__.py
│   ├── abs_diffusion.py
│   ├── sampling
│   │   ├── __init__.py
│   │   ├── correctors.py
│   │   └── predictors.py
│   ├── score_based_diffusion.py
│   └── sdes.py
└── diffusion_enh.py

  • Add an enhancement recipe for the WSJ dataset.
  • Update STFT/iSTFT enc/dec with spectrum transform functions (exponent and log transform)

Why?

Extend ESPnet-SE to support diffusion-based generative enhancement models.

Others

  • Work in progress: still debugging and tuning models.

@mergify mergify bot added the ESPnet2 label Nov 28, 2023
@sw005320 sw005320 added SE Speech enhancement New Features labels Nov 28, 2023
@sw005320 sw005320 added this to the v.202312 milestone Nov 28, 2023
@mergify mergify bot added the Installation label Nov 30, 2023
egs2/wsj/derever1/conf/tuning/train_enh_blstm_tf.yaml (outdated, resolved)
egs2/wsj/derever1/conf/tuning/train_enh_sgmse_ncsnpp.yaml (outdated, resolved)
egs2/wsj/derever1/local/convert2wav.sh (outdated, resolved)
egs2/wsj/derever1/local/create_wsj0_reverb.py (outdated, resolved)
egs2/wsj/derever1/run.sh (outdated, resolved)
espnet2/enh/diffusion_enh.py (resolved)
espnet2/enh/espnet_model.py (resolved)
espnet2/enh/espnet_model.py (resolved)
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types

from espnet2.enh.layers.complex_utils import to_complex
from espnet2.layers.inversible_interface import InversibleInterface
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask

is_torch_1_10_plus = V(torch.__version__) >= V("1.10.0")
Collaborator

You don't need to add it because current ESPnet already dropped support for PyTorch versions before 1.12.1.

Contributor Author

See line 93 of espnet2/layers/stft.py. This check is added so that native FFT and STFT support is used on all CPU targets, including ARM.
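
For context, a version gate of this kind can be sketched as follows. This assumes `V` is `packaging.version.parse` (the excerpt above does not show the import) and uses plain version strings in place of `torch.__version__` so the snippet runs without PyTorch. `is_at_least` is a hypothetical helper name, not something from the PR.

```python
from packaging.version import parse as V  # assumption: this is what `V` refers to


def is_at_least(installed: str, minimum: str) -> bool:
    """Return True if `installed` is the same as or newer than `minimum`."""
    return V(installed) >= V(minimum)


# Parsed versions compare numerically, component by component.
print(is_at_least("1.9.0", "1.10.0"))
print(is_at_least("2.1.0+cu118", "1.10.0"))

# A plain string comparison orders "1.9.0" after "1.10.0" because "9" > "1",
# which is why the version strings are parsed first.
print("1.9.0" >= "1.10.0")
```

With a minimum supported PyTorch of 1.12.1, a gate on 1.10.0 would always evaluate to true, which is the point of the reviewer's remark.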

test/espnet2/enh/diffusion/test_score_based_diffusion.py (outdated, resolved)
Contributor Author

LiChenda commented Jan 2, 2024

> Maybe it is better to rename the recipe to egs2/wsj0_reverb/enh1 or egs2/wsj0_chime3/enh1 instead of using wsj?

Done.

LiChenda changed the title from "[WIP] Add diffusion-base SE model to ESPnet-SE" to "Add diffusion-base SE model to ESPnet-SE" on Jan 3, 2024
egs2/wsj0_chime3/enh1/local/data.sh (outdated, resolved)
egs2/wsj0_reverb/enh1/local/data.sh (outdated, resolved)
LiChenda and others added 2 commits January 2, 2024 23:43
Co-authored-by: Wangyou Zhang <C0me_On@163.com>
Co-authored-by: Wangyou Zhang <C0me_On@163.com>
Contributor Author

LiChenda commented Jan 3, 2024

Hi @sw005320, I also asked @popcornell to help review this PR.

Contributor

sw005320 left a comment

I have some high-level comments.

espnet2/tasks/enh.py (outdated, resolved)
espnet2/enh/layers/ncsnpp_utils/upfirdn2d.py (outdated, resolved)
Contributor

It seems that a lot of lines are not tested, according to Codecov. Can you double-check it?

Contributor

It may be OK, as this code is taken mostly as-is from sgmse.

Contributor Author

The files in ncsnpp_utils are taken from the sgmse repo. Some of them are not used, and thus not tested. Should I add tests for the unused code or just remove it, @sw005320?

Contributor

I see.
The question is whether we should keep the unused functions.

  • If we will use them in the future, we can add tests.
  • If not, we can either keep them as they are or remove them.

Contributor Author

I added unit tests for the unused NCSNpp functions as much as I could.

Contributor

It seems that a lot of lines are not tested, according to Codecov. Can you double-check it?

Contributor

It seems that a lot of lines are not tested, according to Codecov. Can you double-check it?

espnet2/enh/encoder/stft_encoder.py (outdated, resolved)
Comment on lines +26 to +27
spec_factor: float = 0.15,
spec_abs_exponent: float = 0.5,
Contributor

Where do these values come from?
Can you explain them?

Contributor

popcornell, Jan 7, 2024

I think they come largely from SGMSE (https://arxiv.org/pdf/2203.17004.pdf), but for the spec factor they used 1/3 there.

Contributor

If the waveform is normalized to unit standard deviation, using 1/3 as the spec factor bounds the STFT max value to ±1.5.
I think using 0.15 instead bounds it to ±0.75, which may actually be better.
Do you normalize the waveform std, @LiChenda?

Contributor

Diffusion is very sensitive to the input min/max range.

Contributor Author

@sw005320, these numbers come from Section V.D of [1]; I added comments to the code. @popcornell, in their journal paper [1] they use 0.15 and 0.5. I did not normalize the waveform std.

[1] J. Richter, S. Welker, J.-M. Lemercier, B. Lay, and T. Gerkmann, “Speech Enhancement and Dereverberation With Diffusion-Based Generative Models,” IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351–2364, 2023.
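
To illustrate what these two parameters do, here is a minimal sketch of the amplitude transform described in [1]: each complex STFT coefficient keeps its phase, while its magnitude is compressed by the exponent (0.5) and then scaled by the factor (0.15). It works on single coefficients in plain Python and is not the ESPnet implementation; the function names `spec_forward` and `spec_backward` are hypothetical.

```python
import cmath

SPEC_FACTOR = 0.15        # scaling factor (beta in [1])
SPEC_ABS_EXPONENT = 0.5   # magnitude compression exponent (alpha in [1])


def spec_forward(c: complex, factor: float = SPEC_FACTOR,
                 exponent: float = SPEC_ABS_EXPONENT) -> complex:
    """Compress the magnitude of one STFT coefficient, keeping its phase."""
    mag, phase = cmath.polar(c)
    return cmath.rect(factor * mag ** exponent, phase)


def spec_backward(c: complex, factor: float = SPEC_FACTOR,
                  exponent: float = SPEC_ABS_EXPONENT) -> complex:
    """Invert spec_forward: undo the scaling, then the compression."""
    mag, phase = cmath.polar(c)
    return cmath.rect((mag / factor) ** (1.0 / exponent), phase)


coeff = complex(3.0, -4.0)            # magnitude 5
compressed = spec_forward(coeff)      # magnitude 0.15 * sqrt(5)
restored = spec_backward(compressed)  # back to (3-4j) up to rounding
print(abs(compressed), restored)
```

An exponent below 1 lifts low-energy coefficients relative to high-energy ones, and the factor then shrinks the overall range of values the score model sees.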

Contributor

Do you normalize by .abs().amax() so that the waveform always stays within -1 and 1?
https://github.com/sp-uhh/sgmse/blob/c6e3291ee56b07792c9d8c7d7d49487b3042e01b/sgmse/data_module.py#L72

Contributor Author

I added the normalize option but don't use it by default in my model training, because I feel it is not reasonable for causal models.

normfac = speech_mix.abs().max() * 1.1 + 1e-5
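
A sketch of how such a factor might be applied, assuming the mixture is divided by `normfac` before enhancement and the output is multiplied back afterwards. The thread shows only the line above, so the rescaling step and the helper names are assumptions, and plain Python lists stand in for tensors.

```python
from typing import Callable, List


def peak_normfac(mix: List[float]) -> float:
    """Peak of the mixture with 10% headroom; 1e-5 avoids division by zero."""
    return max(abs(s) for s in mix) * 1.1 + 1e-5


def enhance_normalized(mix: List[float],
                       enhance: Callable[[List[float]], List[float]]) -> List[float]:
    """Scale the mixture into (-1, 1), enhance it, then restore the scale."""
    normfac = peak_normfac(mix)          # needs the whole utterance
    scaled = [s / normfac for s in mix]
    enhanced = enhance(scaled)
    return [s * normfac for s in enhanced]


mix = [0.2, -3.0, 1.5]
out = enhance_normalized(mix, lambda x: x)  # identity stands in for a model
print(out)
```

Because the peak is taken over the full utterance, the factor is only known after every sample has been seen, which is the concern raised here for causal models.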

Contributor

I agree

sw005320
Contributor

LGTM.
If it is ready, please let me know.
I’ll merge this PR.

LiChenda
Contributor Author

> LGTM. If it is ready, please let me know. I’ll merge this PR.

Hi @sw005320, I added unit tests for the unused NCSNpp functions as much as I could. Please merge it once the CI tests pass.

@sw005320 sw005320 added the auto-merge Enable auto-merge label Jan 16, 2024
@mergify mergify bot merged commit 0dc18d6 into espnet:master Jan 17, 2024
27 checks passed
4 participants