
Add MixIT support. It is unsupervised only. Semi-supervised config is not available for now. #4619

Merged (2 commits) on Sep 14, 2022

Conversation

@simpleoier (Collaborator) commented Sep 4, 2022

This PR replaces the previous one, #4263. It is mostly modified on top of the existing files to keep the number of new files to a minimum.
Dynamic mixing is used to generate mixtures of mixtures (4 speakers). For the training data, I also include 8000 1-speaker utterances from tr_min_8k/spk1.scp to generate 2-speaker mixtures. This sounds like pseudo "semi-supervised" training in terms of data usage. In fact, we can consider it as using some 1-speaker speech data, because I only use the MixIT loss for training.

For now, only one kind of loss wrapper (MixIT) is used. I don't have an easy solution for using both the unsupervised (MixIT) and supervised (PIT) losses in the same training process.
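For readers unfamiliar with the wrapper, the core idea of the MixIT loss can be sketched as follows. This is an illustrative toy version with an MSE criterion, not the espnet2 mixit_solver implementation:

```python
# Toy MixIT loss: each of the num_inf estimated sources is assigned to
# exactly one of the num_ref reference mixtures; the loss of the best
# assignment (whose re-mixed estimates best match the references) is
# returned. MSE stands in for the actual training criterion.
import itertools

import numpy as np

def mixit_loss(refs, ests):
    """refs: (num_ref, T) reference mixtures; ests: (num_inf, T) estimates."""
    num_ref, num_inf = refs.shape[0], ests.shape[0]
    best = np.inf
    # Enumerate all num_ref ** num_inf binary assignment matrices.
    for assign in itertools.product(range(num_ref), repeat=num_inf):
        mix_matrix = np.zeros((num_ref, num_inf))
        mix_matrix[list(assign), range(num_inf)] = 1.0  # one-hot columns
        remix = mix_matrix @ ests                        # (num_ref, T)
        best = min(best, np.mean((remix - refs) ** 2))
    return best
```

With 4 outputs and 2 mixtures this enumerates 2^4 = 16 assignments, which is one reason the number of output channels is kept small in practice.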

The preliminary results on wsj0_2mix/mixit_enh1 are as follows. The model has been trained for 33 epochs and has not stopped yet.

# RESULTS
## Environments
- date: `Mon Sep  5 14:55:27 EDT 2022`
- python version: `3.9.12 (main, Apr  5 2022, 06:56:58)  [GCC 7.5.0]`
- espnet version: `espnet 202207`
- pytorch version: `pytorch 1.10.1`
- Git hash: `6d5236553b7fb3e653907c447bbbbb0790a013f9`
  - Commit date: `Wed Aug 31 08:17:56 2022 -0400`


## enh_train_enh_mixit_conv_tasnet_raw

config: conf/tuning/train_enh_mixit_conv_tasnet.yaml

|dataset|STOI|SAR|SDR|SIR|SI_SNR|
|---|---|---|---|---|---|
|enhanced_cv_min_8k|91.43|14.55|13.96|24.12|13.34|
|enhanced_tt_min_8k|91.32|13.68|12.91|22.61|12.25|

@simpleoier (Collaborator Author)

@LiChenda I would also like to remove this, and specify the mixture source name in the config file, as in egs2/wsj0_2mix/enh1/conf/tuning/train_enh_mixit_conv_tasnet.yaml. Can you confirm that again?

@codecov bot commented Sep 4, 2022

Codecov Report

Merging #4619 (d7f8047) into master (6d52365) will increase coverage by 0.03%.
The diff coverage is 98.61%.

@@            Coverage Diff             @@
##           master    #4619      +/-   ##
==========================================
+ Coverage   83.07%   83.10%   +0.03%     
==========================================
  Files         508      518      +10     
  Lines       43775    44777    +1002     
==========================================
+ Hits        36364    37210     +846     
- Misses       7411     7567     +156     
| Flag | Coverage Δ |
|---|---|
| test_integration_espnet1 | 66.36% <ø> (ø) |
| test_integration_espnet2 | 49.33% <45.83%> (-0.20%) ⬇️ |
| test_python | 70.94% <86.11%> (+0.32%) ⬆️ |
| test_utils | 23.28% <ø> (ø) |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|---|---|
| espnet2/tasks/enh.py | 97.41% <92.30%> (-1.92%) ⬇️ |
| espnet2/bin/enh_scoring.py | 88.76% <100.00%> (+0.25%) ⬆️ |
| espnet2/enh/espnet_model.py | 86.63% <100.00%> (+0.06%) ⬆️ |
| espnet2/enh/loss/wrappers/mixit_solver.py | 100.00% <100.00%> (ø) |
| espnet2/train/preprocessor.py | 36.65% <100.00%> (-1.22%) ⬇️ |
| espnet2/asr_transducer/beam_search_transducer.py | 97.95% <0.00%> (-1.23%) ⬇️ |
| espnet2/tasks/asr_transducer.py | 100.00% <0.00%> (ø) |
| espnet2/asr_transducer/activation.py | 100.00% <0.00%> (ø) |
| ... and 23 more | |


@popcornell (Contributor)

Regarding the different losses for different data: maybe it is related to multi-condition training, as in #4566. In theory, utt2category (as @Emrys365 points out) can be leveraged to select which loss to apply. However, I do not know the status of this feature in that PR.
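As a hypothetical sketch (the "category" field and the wrapper names below are illustrative assumptions, not the espnet2 API), utt2category-based loss selection could look like:

```python
# Route each batch to a loss wrapper based on its utt2category label:
# PIT for supervised utterances, MixIT for unsupervised ones.
def compute_loss(batch, refs, estimates, wrappers):
    """wrappers: e.g. {"supervised": pit_solver, "unsupervised": mixit_solver}."""
    category = batch.get("category", "unsupervised")
    return wrappers[category](refs, estimates)
```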

def type(self):
    return "mixit"

def forward(self, ref, inf, others={}):
Contributor:

Can we include the constrained MixIT variant for speech enhancement?

Contributor:

For the others: see Section 4.2 in https://arxiv.org/pdf/2006.12701.pdf

Collaborator Author:

Cool. Perhaps it would be better to leave it for later.

@simpleoier (Collaborator Author) commented Sep 5, 2022

> Regarding the different losses for different data: maybe it is related to multi-condition training, as in #4566. In theory, utt2category (as @Emrys365 points out) can be leveraged to select which loss to apply. However, I do not know the status of this feature in that PR.

@popcornell Thank you for the comments. It is a bit different from the multi-condition training we used before, because dynamic mixing is now involved, so the input to the model is already a mixture of mixtures. We need some other solution for semi-supervised training.

@@ -737,7 +737,7 @@ if ! "${skip_eval}"; then
${_ref_scp} \
${_inf_scp} \
--ref_channel ${ref_channel} \
--flexible_numspk True
Collaborator Author:

Just a note: I changed this because boolean args cannot be passed this way. With `type=bool`, any non-empty string value is parsed as True, even when "False" is intended.

Collaborator:

I think it is better to redefine

group.add_argument("--flexible_numspk", type=bool, default=False)

as

from espnet2.utils.types import str2bool

...
group.add_argument("--flexible_numspk", type=str2bool, default=False)
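For context, a typical str2bool helper looks like the sketch below (the actual espnet2.utils.types.str2bool may differ in details); unlike `type=bool`, it parses the string "False" as False:

```python
import argparse

def str2bool(value: str) -> bool:
    """Parse common boolean spellings; reject anything else."""
    if value.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if value.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--flexible_numspk", type=str2bool, default=False)
```

Note that `bool("False")` is True, which is exactly the pitfall described above.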

@LiChenda (Contributor) commented Sep 6, 2022

> @LiChenda I would also like to remove this, and specify the mixture source name in the config file, as in egs2/wsj0_2mix/enh1/conf/tuning/train_enh_mixit_conv_tasnet.yaml. Can you confirm that again?

As we discussed, you can remove it.

Comment on lines 61 to 63
spk_num=2 # Number of speakers
dynamic_mixing=false # Flag for dynamic mixing in speech separation task.
ref_num=2 # Number of references (similar to speakers)
inf_num= # Number of inferences output by the model
# If not specified, it will be the same as ref_num. If specified, it will be overwritten.
@Emrys365 (Collaborator) commented Sep 6, 2022

Note that all --spk_num arguments specified in enh1 recipes should be replaced after this change. (e.g., egs2/chime4/enh1/run.sh)

Collaborator:

Also, it is better to clarify the meaning of ref_num in the different cases.
For example:

  • when MixIT-based training strategy is used, ref_num means the number of reference signals, while each reference signal may contain more than one speaker.
  • otherwise, ref_num is equivalent to the number of speakers in normal speech enhancement/separation tasks.

noise_type_num=1
dereverb_ref_num=1

# Training data related
use_dereverb_ref=false
use_noise_ref=false
use_preprocessor=false
Collaborator:

Why is this argument removed?

Collaborator Author:

I added preprocessor_choices in tasks/enh.py to determine whether the preprocessor is used, so this argument is redundant.

Collaborator:

I see. So now this argument is specified in the configuration file.

Collaborator Author:

Right.

local/data_supervised.sh ${local_data_args}

sup_train_set="tr_"${min_or_max}_${sample_rate}
train_set="tr_"${min_or_max}_${sample_rate}_w_1spk_utt
@Emrys365 (Collaborator) commented Sep 6, 2022

What is the meaning of the suffix _w_1spk_utt?

Collaborator Author:

The suffix means "with 1 speaker utterances".

Comment on lines 169 to 170
group.add_argument("--flexible_numspk", type=bool, default=False)
group.add_argument("--flexible_numspk", default=False, action="store_true")
Collaborator:

Setting the type to str2bool seems a better option, as we can intuitively show how to use --flexible_numspk by passing True or False.

Collaborator Author:

Sounds good. I'll update it.

Comment on lines 45 to 46
ref_tensor = torch.stack(ref[:num_ref], dim=1) # (batch, num_ref, ...)
inf_tensor = torch.stack(inf, dim=1) # (batch, num_inf, ...)
Collaborator:

Does this solver only accept waveforms as input? I think a solver should be independent of the domain that a signal belongs to. Instead, the criterion defines whether time-domain or frequency-domain is expected.

@Emrys365 (Collaborator) commented Sep 6, 2022

Considering that complex-valued spectra may be used as input, we may need to care about the input data type (e.g., ComplexTensor vs torch.Tensor).

Collaborator Author:

I added support for more input tensor shapes; the main step is in the einsum computation. But basically, the time domain is preferred due to the additivity assumption in MixIT. Is there any special care necessary for ComplexTensor?
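The einsum step mentioned here can be illustrated in isolation (a numpy stand-in for the torch code; the shapes and names are assumptions, not the actual solver internals):

```python
# Apply a batch of (num_ref x num_inf) assignment matrices to the
# estimated sources to form the re-mixed signals that are compared
# against the reference mixtures.
import numpy as np

batch, num_ref, num_inf, length = 2, 2, 4, 8
mix_matrix = np.zeros((batch, num_ref, num_inf))
mix_matrix[:, 0, :2] = 1.0  # estimates 0 and 1 -> mixture 0
mix_matrix[:, 1, 2:] = 1.0  # estimates 2 and 3 -> mixture 1
est = np.ones((batch, num_inf, length))
remix = np.einsum("brn,bnt->brt", mix_matrix, est)  # (batch, num_ref, length)
```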

@Emrys365 (Collaborator) commented Sep 7, 2022

Yes, for ComplexTensor, torch_complex.functional.stack should be used instead of torch.stack, and torch_complex.functional.einsum should be used instead of torch.einsum.

You could check the data type by isinstance(c, ComplexTensor).

A simple way to automatically handle the data type issue can be like:

from espnet2.enh.layers.complex_utils import einsum, stack

...
stack(...)
einsum(...)
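The dispatch behind such complex_utils wrappers can be sketched generically (the backend classes below are stand-ins for torch and torch_complex.functional, purely for illustration):

```python
# Route an operation to a complex-aware backend whenever any input is a
# ComplexTensor-like object, otherwise to the plain backend.
def make_dispatcher(op_name, plain_backend, complex_backend, complex_type):
    def op(*tensors, **kwargs):
        if any(isinstance(t, complex_type) for t in tensors):
            return getattr(complex_backend, op_name)(*tensors, **kwargs)
        return getattr(plain_backend, op_name)(*tensors, **kwargs)
    return op

class FakeComplex:  # stand-in for torch_complex.ComplexTensor
    pass

class PlainBackend:  # stand-in for torch
    @staticmethod
    def stack(*xs):
        return ("plain", len(xs))

class ComplexBackend:  # stand-in for torch_complex.functional
    @staticmethod
    def stack(*xs):
        return ("complex", len(xs))

stack = make_dispatcher("stack", PlainBackend, ComplexBackend, FakeComplex)
```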

Collaborator:

It might be better to add more unit tests to cover the case where the input is a complex-valued spectrum.

Collaborator Author:

Complex input is now supported.

)
use_preprocessor = getattr(args, "preprocessor", None) is not None

if train and use_preprocessor:
Collaborator:

Why is the preprocessor only used for training? The preprocessor may have their own rules for handling training and inference data.

Collaborator Author:

I looked at the usages of --use_preprocessor; it looks like it was only used in training, so I changed the logic accordingly. But we can keep it as it was.

Collaborator:

It might be better to keep it for inference as well, in case we do not want to store the mixture audios and always mix the reference signals in the preprocessor.

num_spk: int = 2,
num_utts: int = 2,
@Emrys365 (Collaborator) commented Sep 6, 2022

Maybe num_refs is better, if we want to add a new preprocessor for meeting-style long-form data (containing many utterances in each reference channel) in the future.

Collaborator Author:

I changed it to ref_num, the same name used in enh.sh, trying to reduce similarly confusing names across all the files.

Comment on lines +522 to +526
if mixture_source_name is None:
    self.mixture_source_name = f"{speech_ref_name_prefix}1"
else:
    self.mixture_source_name = mixture_source_name
Collaborator:

Maybe add some comments to explain the meaning of this argument?

Collaborator Author:

Good idea.

@Emrys365 (Collaborator) left a comment

I left some comments to discuss the current design.

@sw005320 sw005320 added this to the v.202209 milestone Sep 7, 2022
@simpleoier (Collaborator Author)

@Emrys365 @sw005320 Can you please check this PR again?

@sw005320 (Contributor) left a comment

LGTM!
I just want to make sure: you seem to have changed some option names (ref-num instead of spk-num).
If this does not break compatibility, there is no problem.
If it does, we may accept both option names and add a deprecation message (and remove the old one later).

loss = loss.mean()
perm = torch.index_select(all_mixture_matrix, 0, perm)

if perm.is_complex():
Collaborator:

I think perm is always real.

Collaborator Author:

I'm concerned that in lines 77-88, all_mixture_matrix may be torch.complex if the inputs are complex. Hence, perm could be complex.

Contributor:

Is perm here the reordered estimate? It looks like it to me, because it comes after index_select.

Contributor:

If it is the reordered estimate, you might as well leave it complex, as it can be, say, the STFT representation (which should work in theory with MixIT, since MixIT is additive).
Also, if it is the reordered estimate, maybe it is better to change the name of the variable, no?

Collaborator Author:

No, it is the assignment matrix, not the reordered estimate.

Collaborator:

I see your point now. Since you explicitly change the data type of all_mixture_matrix to inf_tensor.dtype, it can become a builtin complex tensor in PyTorch.

Contributor:

Yep, it is OK.



@pytest.mark.parametrize("inf_num", [4])
def test_MixITSolver_complex_tensor_forward(inf_num):
Collaborator:

Could you add an additional test by using the builtin complex tensor in PyTorch when torch 1.9.0+ is used?

Collaborator Author:

May I ask why 1.9.0+? For example, PyTorch 1.8.1 also supports complex tensors.

Contributor:

I think 1.9.0+ added complex support for many torch.linalg operations, making it possible to migrate to native torch for, e.g., beamforming.

Collaborator Author:

Gotcha! I have already applied the 1.9.0 check.
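Such a version gate boils down to a comparison like the following sketch (the helper name is illustrative; espnet has its own version utilities):

```python
# Return True when the given torch version string is at least 1.9.0,
# ignoring local build suffixes like "+cu113".
def is_torch_1_9_plus(version: str) -> bool:
    parts = version.split("+")[0].split(".")
    return (int(parts[0]), int(parts[1])) >= (1, 9)
```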

Collaborator:

Yes, indeed. It is since 1.9.0 that the complex support has been greatly improved.

@Emrys365 (Collaborator) left a comment

It looks fine now. I just left some minor comments.

@simpleoier (Collaborator Author) commented Sep 12, 2022

> LGTM! I just want to make sure: you seem to have changed some option names (ref-num instead of spk-num). If this does not break compatibility, there is no problem. If it does, we may accept both option names and add a deprecation message (and remove the old one later).

@sw005320 I have checked all the enh recipes and updated the corresponding arguments.

kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
kwargs.get(
f"speech_ref{spk + 1}",
torch.zeros_like(kwargs["speech_ref1"]),
Contributor:

Is this torch.zeros_like(kwargs["speech_ref1"]) for the silent output channel in MixIT?

Collaborator Author:

Yes. In MixIT, the number of references is less than num_spk.
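The zero-filling behaviour can be shown with a small numpy stand-in (the kwargs keys mirror the snippet above; the helper itself is illustrative, not the espnet2 code):

```python
# Build a reference list of length num_spk, substituting silence
# (zeros shaped like speech_ref1) for any missing reference channel.
import numpy as np

def gather_refs(kwargs, num_spk):
    return [
        kwargs.get(f"speech_ref{i + 1}", np.zeros_like(kwargs["speech_ref1"]))
        for i in range(num_spk)
    ]
```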

f"speech_ref{spk + 1}",
torch.zeros_like(kwargs["speech_ref1"]),
)
for spk in range(self.num_spk)
Contributor:

I see that num_spk is renamed to ref_num in some other files (e.g., espnet2/train/preprocessor.py, enh.sh). Maybe we can also rename it here, if it's compatible with the past version.

Collaborator Author:

It may need more work. Changing the name here would probably impact other files, like enh_inference.py and the enh_s2t tasks. I think the pre-trained enhancement models would be affected, too.

@sw005320 sw005320 merged commit a7bd652 into espnet:master Sep 14, 2022
@sw005320 (Contributor)

Thanks, @simpleoier, and everyone contributing to this PR!
