Add MixIT support. It is unsupervised only. Semi-supervised config is not available for now. #4619
Conversation
Codecov Report
@@ Coverage Diff @@
## master #4619 +/- ##
==========================================
+ Coverage 83.07% 83.10% +0.03%
==========================================
Files 508 518 +10
Lines 43775 44777 +1002
==========================================
+ Hits 36364 37210 +846
- Misses 7411 7567 +156
def type(self):
    return "mixit"

def forward(self, ref, inf, others={}):
Can we include the constrained MixIT variant for speech enhancement?
for the others: see 4.2 Section in https://arxiv.org/pdf/2006.12701.pdf
Cool. Perhaps it would be better to leave that for later.
@popcornell Thank you for the comments. It is a bit different from the multi-condition training we used to use, because dynamic mixing is now involved, so the input to the model is already a mixture of mixtures. We need some other solution for semi-supervised training.
@@ -737,7 +737,7 @@ if ! "${skip_eval}"; then
        ${_ref_scp} \
        ${_inf_scp} \
        --ref_channel ${ref_channel} \
        --flexible_numspk True
Just a note: I changed this because boolean args cannot be passed this way with type=bool. If any non-empty value is passed, the variable is parsed as True, even when "False" is intended.
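To make the pitfall concrete, here is a minimal, self-contained demonstration (the argument name is reused from the diff above purely for illustration): with `type=bool`, argparse calls `bool()` on the raw string, and any non-empty string, including "False", is truthy.

```python
import argparse

# The problematic pattern: type=bool applies bool() to the raw string.
parser = argparse.ArgumentParser()
parser.add_argument("--flexible_numspk", type=bool, default=False)

args = parser.parse_args(["--flexible_numspk", "False"])
print(args.flexible_numspk)  # → True, because bool("False") is True
```

This is why a dedicated string-to-bool converter (or `action="store_true"`) is needed instead.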
I think it is better to redefine
group.add_argument("--flexible_numspk", type=bool, default=False)
as
from espnet2.utils.types import str2bool
...
group.add_argument("--flexible_numspk", type=str2bool, default=False)
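For reference, a minimal sketch of what such a `str2bool` converter can look like (this is an illustrative version, not necessarily ESPnet's exact implementation in `espnet2.utils.types`):

```python
import argparse

def str2bool(value: str) -> bool:
    """Convert strings like 'true'/'false' into a real boolean for argparse."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--flexible_numspk", type=str2bool, default=False)
print(parser.parse_args(["--flexible_numspk", "False"]).flexible_numspk)  # → False
```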
egs2/TEMPLATE/enh1/enh.sh
spk_num=2              # Number of speakers
dynamic_mixing=false   # Flag for dynamic mixing in speech separation task.
ref_num=2              # Number of references (similar to speakers)
inf_num=               # Number of inferences output by the model
                       # If not specified, it will be the same as ref_num. If specified, it will be overwritten.
Note that all --spk_num arguments specified in enh1 recipes should be replaced after this change (e.g., egs2/chime4/enh1/run.sh).
Also, it is better to clarify the meaning of ref_num in different cases. For example:
- when a MixIT-based training strategy is used, ref_num means the number of reference signals, while each reference signal may contain more than one speaker.
- otherwise, ref_num is equivalent to the number of speakers in normal speech enhancement/separation tasks.
noise_type_num=1
dereverb_ref_num=1

# Training data related
use_dereverb_ref=false
use_noise_ref=false
use_preprocessor=false
Why is this argument removed?
I added the preprocessor_choices in tasks/enh.py to determine whether the preprocessor is used, so this argument is now redundant.
I see. So now this argument is specified in the configuration file.
right.
local/data_supervised.sh ${local_data_args}

sup_train_set="tr_"${min_or_max}_${sample_rate}
train_set="tr_"${min_or_max}_${sample_rate}_w_1spk_utt
What is the meaning of the suffix _w_1spk_utt?
The suffix means "with 1 speaker utterances".
espnet2/bin/enh_scoring.py
-group.add_argument("--flexible_numspk", type=bool, default=False)
+group.add_argument("--flexible_numspk", default=False, action="store_true")
Setting the type to str2bool seems a better option, as we can intuitively show how to use --flexible_numspk by passing True or False.
sounds good. I'll update it.
ref_tensor = torch.stack(ref[:num_ref], dim=1)  # (batch, num_ref, ...)
inf_tensor = torch.stack(inf, dim=1)  # (batch, num_inf, ...)
Does this solver only accept waveforms as input? I think a solver should be independent of the domain that a signal belongs to. Instead, the criterion defines whether time-domain or frequency-domain is expected.
Considering that complex-valued spectra may be used as input, we may need to care about the input data type (e.g., ComplexTensor vs torch.Tensor).
I added support for more input tensor shapes. The main step is in the einsum computation. But basically, the time domain is preferred due to the additivity in MixIT. Is there any special care necessary for ComplexTensor?
Yes, for ComplexTensor, torch_complex.functional.stack should be used instead of torch.stack, and torch_complex.functional.einsum should be used instead of torch.einsum. You could check the data type with isinstance(c, ComplexTensor).
A simple way to automatically handle the data type issue can be like:
from espnet2.enh.layers.complex_utils import einsum, stack
...
stack(...)
einsum(...)
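The idea behind such dispatch helpers can be sketched as follows. This is an illustrative version (helper names mirror the suggestion above, but this is not ESPnet's actual `complex_utils` code): route to torch_complex for ComplexTensor inputs and fall back to native torch otherwise, so callers never branch on dtype themselves.

```python
import torch

# torch_complex is optional in this sketch; without it we only handle
# native torch tensors (real or builtin complex).
try:
    from torch_complex import functional as FC
    from torch_complex.tensor import ComplexTensor
except ImportError:
    FC, ComplexTensor = None, ()  # isinstance(x, ()) is always False

def stack(seq, dim=0):
    """Stack tensors, dispatching ComplexTensor inputs to torch_complex."""
    if FC is not None and isinstance(seq[0], ComplexTensor):
        return FC.stack(seq, dim=dim)
    return torch.stack(seq, dim=dim)

def einsum(equation, *operands):
    """einsum that dispatches ComplexTensor operands to torch_complex."""
    if FC is not None and any(isinstance(op, ComplexTensor) for op in operands):
        return FC.einsum(equation, *operands)
    return torch.einsum(equation, *operands)
```

With this, the solver code can call `stack(...)` and `einsum(...)` uniformly regardless of the input data type.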
It might be better to add more unit tests to cover the case where the input is a complex-valued spectrum.
Now complex is supported.
espnet2/tasks/enh.py
)
use_preprocessor = getattr(args, "preprocessor", None) is not None

if train and use_preprocessor:
Why is the preprocessor only used for training? The preprocessor may have its own rules for handling training and inference data.
I looked at the usages of --use_preprocessor; it looks like it was only used in training, so I changed the logic accordingly. But we can keep it as it was.
We had better keep it for inference as well, in case we do not want to store the mixture audios and instead always mix the reference signals in the preprocessor.
espnet2/train/preprocessor.py
num_spk: int = 2,
num_utts: int = 2,
Maybe num_refs is better, if we want to add a new preprocessor for meeting-style long-form data (containing many utterances in each reference channel) in the future.
I changed it to ref_num, the same as the one used in enh.sh, trying to reduce similar confusing names across all files.
if mixture_source_name is None:
    self.mixture_source_name = f"{speech_ref_name_prefix}1"
else:
    self.mixture_source_name = mixture_source_name
Maybe add some comments to explain the meaning of this argument?
good idea.
I left some comments to discuss the current design.
LGTM!
I just want to make sure: you seem to have changed some option names (ref_num instead of spk_num).
If this does not break compatibility, there is no problem.
If it does, we may accept both option names and add a deprecation message (and remove the old name later).
loss = loss.mean()
perm = torch.index_select(all_mixture_matrix, 0, perm)

if perm.is_complex():
I think perm is always real.
I'm concerned that in lines 77-88, all_mixture_matrix may be complex if the inputs are complex. Hence, perm could be complex.
Is perm here the reordered estimate? It looks like it to me, because it comes after index_select.
If it is the reordered estimate, you might as well leave it complex, as it could be, say, the STFT representation (which should work in theory with MixIT, since it is additive).
Also, if it is the reordered estimate, maybe it is better to rename the variable, no?
No, it is the assignment matrix, not the reordered estimate.
I see your point now. Since you explicitly change the data type of all_mixture_matrix to inf_tensor.dtype, it can become a builtin complex tensor in PyTorch.
Yep it is ok
@pytest.mark.parametrize("inf_num", [4])
def test_MixITSolver_complex_tensor_forward(inf_num):
Could you add an additional test by using the builtin complex tensor in PyTorch when torch 1.9.0+ is used?
May I ask why 1.9.0+? For example, PyTorch 1.8.1 also supports complex.
I think 1.9.0+ added complex support for many torch.linalg operations, making it possible to migrate to native torch for, e.g., beamforming.
Gotcha! I have already applied the check for 1.9.0.
Yes, indeed. It is since 1.9.0 that the complex support has been greatly improved.
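A version gate like the one discussed above can be implemented with a small comparison helper. This is a hypothetical sketch (the helper names are mine, not ESPnet's actual check), using a plain tuple comparison so that, e.g., "1.10.0" correctly compares greater than "1.9.0":

```python
def version_tuple(v: str):
    """Parse a version string like '1.9.0+cu111' into a comparable tuple."""
    core = v.split("+")[0]  # drop local build suffixes such as '+cu111'
    return tuple(int(x) for x in core.split(".")[:3])

def supports_native_complex(torch_version: str) -> bool:
    # per the discussion above, torch.linalg gained broad complex
    # support in PyTorch 1.9.0
    return version_tuple(torch_version) >= (1, 9, 0)

print(supports_native_complex("1.8.1"))        # → False
print(supports_native_complex("1.9.0+cu111"))  # → True
```

In a test suite, such a helper is typically combined with `pytest.mark.skipif` so the complex-tensor test only runs on sufficiently new PyTorch versions.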
It looks fine now. I just left some minor comments.
@sw005320 I have checked all enh recipes and updated the corresponding arguments.
-kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
+kwargs.get(
+    f"speech_ref{spk + 1}",
+    torch.zeros_like(kwargs["speech_ref1"]),
Is this torch.zeros_like(kwargs["speech_ref1"]) for the silent output channel in MixIT?
Yes. In MixIT, the number of references is less than num_spk.
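The padding idea from the diff above can be isolated into a small sketch (the function name and the toy batch are mine, for illustration only): any missing `speech_refN` entry is replaced by a silent signal of the same shape as `speech_ref1`, so extra model output channels can be matched to empty sources.

```python
import torch

def collect_refs(kwargs: dict, num_spk: int):
    """Gather speech references, padding missing ones with silence
    so MixIT can map extra output channels to empty sources."""
    return [
        kwargs.get(f"speech_ref{spk + 1}", torch.zeros_like(kwargs["speech_ref1"]))
        for spk in range(num_spk)
    ]

# toy batch with only 2 references but a 4-channel model
batch = {"speech_ref1": torch.ones(2, 8), "speech_ref2": torch.ones(2, 8)}
refs = collect_refs(batch, num_spk=4)
# refs[2] and refs[3] are all-zero (silent) channels with the same shape
```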
    f"speech_ref{spk + 1}",
    torch.zeros_like(kwargs["speech_ref1"]),
)
for spk in range(self.num_spk)
I see that num_spk is renamed to ref_num in some other files (e.g., espnet2/train/preprocessor.py, enh.sh). Maybe we can also rename it here if it's compatible with the past version.
It may need more work. Changing the name here would impact other files, like enh_inference.py and the enh_s2t tasks probably. I think the pre-trained enhancement models would be influenced, too.
Thanks, @simpleoier, and everyone contributing to this PR!
This PR is to replace the previous one #4263. It is mostly modified on top of the existing files to keep minimum new files.
Dynamic mixing is used to generate mixtures of mixtures (4 speakers). For training data, I also include 8000 1-speaker utterances from tr_min_8k/spk1.scp to generate 2-speaker mixtures. This sounds like pseudo "semi-supervised" training in terms of data usage. Actually, we can consider it as simply using some 1-speaker speech data, because only the MixIT loss is used for training.
For now, only one kind of loss wrapper (MixIT) is used. I don't have an easy solution for using both unsupervised (MixIT) and supervised (PIT) losses in the same training process.
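For readers unfamiliar with the objective, the MixIT loss can be sketched in a few lines (this is a toy reference implementation under an L2 loss, not ESPnet's solver): each of the M estimated sources is assigned to exactly one of the two input mixtures, all 2^M binary assignment matrices are enumerated, and the best re-mixing is kept.

```python
import itertools
import torch

def mixit_loss(est: torch.Tensor, mixtures: torch.Tensor) -> torch.Tensor:
    """est: (num_est, T) estimated sources; mixtures: (2, T) input mixtures."""
    num_est = est.shape[0]
    best = None
    # enumerate all 2^num_est assignment matrices A in {0,1}^(2 x num_est),
    # where each column sums to 1 (each estimate goes to one mixture)
    for choice in itertools.product([0, 1], repeat=num_est):
        A = torch.zeros(2, num_est)
        A[list(choice), list(range(num_est))] = 1.0
        remix = A @ est  # (2, T): estimates re-mixed into two mixtures
        loss = ((remix - mixtures) ** 2).mean()
        if best is None or loss < best:
            best = loss
    return best
```

Because the enumeration grows as 2^M, practical implementations keep M small (e.g., 4 outputs for mixtures of two 2-speaker mixtures) or use an einsum over a precomputed matrix of all assignments, as discussed in the review comments above.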
The preliminary results on wsj0_2mix/mixit_enh1 are as follows. The model was trained for 33 epochs and has not yet converged.