Merge pull request #5155 from Emrys365/tse
Update TD-SpeakerBeam
mergify[bot] committed May 15, 2023
2 parents 0aa1049 + a93775c commit 6e35c14
Showing 9 changed files with 182 additions and 42 deletions.
2 changes: 1 addition & 1 deletion egs2/librimix/tse1/conf/train.yaml
@@ -0,0 +1,85 @@
optim: adam
max_epoch: 100
batch_type: folded
batch_size: 16
iterator_type: chunk
chunk_length: 48000
# exclude keys "enroll_ref", "enroll_ref1", "enroll_ref2", ...
# from the length consistency check in ChunkIterFactory
chunk_excluded_key_prefixes:
- "enroll_ref"
num_workers: 4
optim_conf:
lr: 1.0e-03
eps: 1.0e-08
weight_decay: 0
unused_parameters: true
patience: 20
accum_grad: 1
grad_clip: 5.0
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
- snr
- max
- - valid
- loss
- min
keep_nbest_models: 1
scheduler: reducelronplateau
scheduler_conf:
mode: min
factor: 0.7
patience: 3

model_conf:
num_spk: 2
share_encoder: true

train_spk2enroll: data/train-100/spk2enroll.json
enroll_segment: 48000
load_spk_embedding: false
load_all_speakers: false

encoder: conv
encoder_conf:
channel: 256
kernel_size: 32
stride: 16
decoder: conv
decoder_conf:
channel: 256
kernel_size: 32
stride: 16
extractor: td_speakerbeam
extractor_conf:
layer: 8
stack: 4
bottleneck_dim: 256
hidden_dim: 512
skip_dim: 256
kernel: 3
causal: False
norm_type: gLN
pre_nonlinear: prelu
nonlinear: relu
# enrollment related
i_adapt_layer: 7
adapt_layer_type: mul
adapt_enroll_dim: 256
use_spk_emb: false

# A list for criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
# The first criterion
- name: snr
conf:
eps: 1.0e-7
wrapper: fixed_order
wrapper_conf:
weight: 1.0
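Note: the `criterions` block above defines the multi-task objective described in the comment, i.e. loss = weight_1 * loss_1 + ... + weight_N * loss_N, with each `weight` defaulting to 1.0. A minimal sketch of that weighted combination (illustrative only; this is not the ESPnet criterion/wrapper code, and the function name is made up):

```python
# Minimal sketch of the weighted multi-task loss described in the config
# comment above; NOT the ESPnet criterion/wrapper implementation.
import torch

def combine_losses(sub_losses, weights=None):
    """Return sum_i weight_i * loss_i; weights default to 1.0 per sub-loss."""
    if weights is None:
        weights = [1.0] * len(sub_losses)
    total = torch.zeros(())
    for w, loss in zip(weights, sub_losses):
        total = total + w * loss
    return total

# With the config above there is a single SNR-based criterion with weight 1.0:
loss = combine_losses([torch.tensor(-12.3)], weights=[1.0])
```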
@@ -1,9 +1,9 @@
optim: adam
max_epoch: 400
max_epoch: 100
batch_type: folded
batch_size: 16
iterator_type: chunk
chunk_length: 24000
chunk_length: 48000
# exclude keys "enroll_ref", "enroll_ref1", "enroll_ref2", ...
# from the length consistency check in ChunkIterFactory
chunk_excluded_key_prefixes:
@@ -39,20 +39,20 @@ model_conf:
share_encoder: true

train_spk2enroll: data/train-100/spk2enroll.json
enroll_segment: 24000
enroll_segment: 48000
load_spk_embedding: false
load_all_speakers: false

encoder: conv
encoder_conf:
channel: 256
kernel_size: 16
stride: 8
kernel_size: 32
stride: 16
decoder: conv
decoder_conf:
channel: 256
kernel_size: 16
stride: 8
kernel_size: 32
stride: 16
extractor: td_speakerbeam
extractor_conf:
layer: 8
@@ -63,11 +63,13 @@ extractor_conf:
kernel: 3
causal: False
norm_type: gLN
pre_nonlinear: prelu
nonlinear: relu
# enrollment related
i_adapt_layer: 7
adapt_layer_type: mul
adapt_enroll_dim: 256
use_spk_emb: false

# A list for criterions
# The overall loss in the multi-task learning will be:
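Both config files exclude keys starting with `enroll_ref` from the length-consistency check in ChunkIterFactory (see the `chunk_excluded_key_prefixes` entries above), presumably because enrollment audio need not share the chunk length of the mixture. A rough sketch of that kind of prefix filtering (illustrative only; not the actual ESPnet iterator code, and the key names are examples):

```python
# Illustrative prefix-based exclusion, not the actual ChunkIterFactory code.
def needs_length_check(key: str, excluded_prefixes=("enroll_ref",)) -> bool:
    """Return True if `key` must match the chunk length of the other inputs."""
    return not any(key.startswith(p) for p in excluded_prefixes)

assert needs_length_check("speech_mix")
assert not needs_length_check("enroll_ref1")  # enrollment may have a different length
```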
4 changes: 0 additions & 4 deletions egs2/librimix/tse1/local/data.sh
@@ -55,10 +55,6 @@ if [[ "$num_spk" != "2" ]] && [[ "$num_spk" != "3" ]]; then
exit 1
fi

if [ ! -e "${LIBRISPEECH}" ]; then
log "Fill the value of 'LIBRISPEECH' of db.sh"
exit 1
fi

librimix=data/LibriMix/libri_mix/Libri2Mix
mkdir -p data/{train,dev,test}
2 changes: 1 addition & 1 deletion egs2/librimix/tse1/run.sh
@@ -5,7 +5,7 @@ set -e
set -u
set -o pipefail

sample_rate=16k # If using 8k, please make sure `spk2enroll.json` points to 8k audios as well
sample_rate=16k # 8k or 16k
min_or_max=min # "min" or "max". This is to determine how the mixtures are generated in local/data.sh.


69 changes: 51 additions & 18 deletions espnet2/enh/extractor/td_speakerbeam_extractor.py
@@ -22,11 +22,14 @@ def __init__(
kernel: int = 3,
causal: bool = False,
norm_type: str = "gLN",
pre_nonlinear: str = "prelu",
nonlinear: str = "relu",
# enrollment related arguments
i_adapt_layer: int = 7,
adapt_layer_type: str = "mul",
adapt_enroll_dim: int = 128,
use_spk_emb: bool = False,
spk_emb_dim: int = 256,
):
"""Time-Domain SpeakerBeam Extractor.
@@ -40,15 +43,22 @@
kernel: int, kernel size.
causal: bool, default False.
norm_type: str, choose from 'BN', 'gLN', 'cLN'
pre_nonlinear: the nonlinear function right before mask estimation
select from 'prelu', 'relu', 'tanh', 'sigmoid', 'linear'
nonlinear: the nonlinear function for mask estimation,
select from 'relu', 'tanh', 'sigmoid'
select from 'relu', 'tanh', 'sigmoid', 'linear'
i_adapt_layer: int, index of adaptation layer
adapt_layer_type: str, type of adaptation layer
see espnet2.enh.layers.adapt_layers for options
adapt_enroll_dim: int, dimensionality of the speaker embedding
use_spk_emb: bool, whether to use speaker embeddings as enrollment
spk_emb_dim: int, dimension of input speaker embeddings
only used when `use_spk_emb` is True
"""
super().__init__()

if pre_nonlinear not in ("sigmoid", "prelu", "relu", "tanh", "linear"):
raise ValueError("Not supporting pre_nonlinear={}".format(pre_nonlinear))
if nonlinear not in ("sigmoid", "relu", "tanh", "linear"):
raise ValueError("Not supporting nonlinear={}".format(nonlinear))

@@ -63,27 +73,39 @@
out_channel=None,
norm_type=norm_type,
causal=causal,
pre_mask_nonlinear=pre_nonlinear,
mask_nonlinear=nonlinear,
i_adapt_layer=i_adapt_layer,
adapt_layer_type=adapt_layer_type,
adapt_enroll_dim=adapt_enroll_dim,
)

# Auxiliary network
self.auxiliary_net = TemporalConvNet(
N=input_dim,
B=bottleneck_dim,
H=hidden_dim,
P=kernel,
X=layer,
R=1,
C=1,
Sc=skip_dim,
out_channel=adapt_enroll_dim if skip_dim is None else adapt_enroll_dim * 2,
norm_type=norm_type,
causal=False,
mask_nonlinear="linear",
)
self.use_spk_emb = use_spk_emb
if use_spk_emb:
self.auxiliary_net = torch.nn.Conv1d(
spk_emb_dim,
adapt_enroll_dim if skip_dim is None else adapt_enroll_dim * 2,
1,
)
else:
self.auxiliary_net = TemporalConvNet(
N=input_dim,
B=bottleneck_dim,
H=hidden_dim,
P=kernel,
X=layer,
R=1,
C=1,
Sc=skip_dim,
out_channel=adapt_enroll_dim
if skip_dim is None
else adapt_enroll_dim * 2,
norm_type=norm_type,
causal=False,
pre_mask_nonlinear=pre_nonlinear,
mask_nonlinear="linear",
)

def forward(
self,
@@ -99,7 +121,7 @@ def forward(
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
ilens (torch.Tensor): input lengths [Batch]
input_aux (torch.Tensor or ComplexTensor): Encoded auxiliary feature
for the target speaker [B, T, N]
for the target speaker [B, T, N] or [B, N]
ilens_aux (torch.Tensor): input lengths of auxiliary input for the
target speaker [Batch]
suffix_tag (str): suffix to append to the keys in `others`
@@ -118,10 +140,21 @@
B, L, N = feature.shape

feature = feature.transpose(1, 2) # B, N, L
aux_feature = aux_feature.transpose(1, 2) # B, N, L'
# NOTE(wangyou): When `self.use_spk_emb` is True, `aux_feature` is assumed to be
# a speaker embedding; otherwise, it is assumed to be an enrollment audio.
if self.use_spk_emb:
# B, N, L'=1
if aux_feature.dim() == 2:
aux_feature = aux_feature.unsqueeze(-1)
elif aux_feature.size(-2) == 1:
assert aux_feature.dim() == 3, aux_feature.shape
aux_feature = aux_feature.transpose(1, 2)
else:
aux_feature = aux_feature.transpose(1, 2) # B, N, L'

enroll_emb = self.auxiliary_net(aux_feature).squeeze(1) # B, N', L'
enroll_emb.masked_fill_(make_pad_mask(ilens_aux, enroll_emb, -1), 0.0)
if not self.use_spk_emb:
enroll_emb.masked_fill_(make_pad_mask(ilens_aux, enroll_emb, -1), 0.0)
enroll_emb = enroll_emb.mean(dim=-1) # B, N'

mask = self.tcn(feature, enroll_emb) # B, N, L
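In short, the new `use_spk_emb` path treats the auxiliary input as a pre-computed speaker embedding of shape [B, N] (or [B, 1, N]) and reshapes it to [B, N, 1] before the 1x1 Conv1d auxiliary network, while the original path still transposes enrollment-audio features from [B, T', N] to [B, N, T'] and applies the padding mask before averaging over time. A condensed, self-contained sketch of just that shape handling (shapes follow the comments in the diff; this is not the full extractor):

```python
# Condensed sketch of the auxiliary-feature shape handling in the diff above.
# Not the full TDSpeakerBeamExtractor; only the reshaping logic is shown.
import torch

def prepare_aux(aux_feature: torch.Tensor, use_spk_emb: bool) -> torch.Tensor:
    if use_spk_emb:
        if aux_feature.dim() == 2:                     # [B, N] speaker embedding
            aux_feature = aux_feature.unsqueeze(-1)    # -> [B, N, 1]
        elif aux_feature.size(-2) == 1:                # [B, 1, N]
            aux_feature = aux_feature.transpose(1, 2)  # -> [B, N, 1]
    else:
        aux_feature = aux_feature.transpose(1, 2)      # [B, T', N] -> [B, N, T']
    return aux_feature

print(prepare_aux(torch.randn(4, 256), use_spk_emb=True).shape)        # [4, 256, 1]
print(prepare_aux(torch.randn(4, 100, 256), use_spk_emb=False).shape)  # [4, 256, 100]
```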
26 changes: 21 additions & 5 deletions espnet2/enh/layers/tcn.py
@@ -31,6 +31,7 @@ def __init__(
out_channel=None,
norm_type="gLN",
causal=False,
pre_mask_nonlinear="linear",
mask_nonlinear="relu",
):
"""Basic Module of tasnet.
@@ -48,6 +49,7 @@
if it is None, `N` will be used instead.
norm_type: BN, gLN, cLN
causal: causal or non-causal
pre_mask_nonlinear: the non-linear function before masknet
mask_nonlinear: use which non-linear function to generate mask
"""
super().__init__()
@@ -94,9 +96,20 @@ def __init__(
# [M, B, K] -> [M, C*N, K]
mask_conv1x1 = nn.Conv1d(B, C * self.out_channel, 1, bias=False)
# Put together (for compatibility with older versions)
self.network = nn.Sequential(
layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1
)
if pre_mask_nonlinear == "linear":
self.network = nn.Sequential(
layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1
)
else:
activ = {
"prelu": nn.PReLU(),
"relu": nn.ReLU(),
"tanh": nn.Tanh(),
"sigmoid": nn.Sigmoid(),
}[pre_mask_nonlinear]
self.network = nn.Sequential(
layer_norm, bottleneck_conv1x1, temporal_conv_net, activ, mask_conv1x1
)

def forward(self, mixture_w):
"""Keep this API same with TasNet.
@@ -110,7 +123,7 @@ def forward(self, mixture_w):
M, N, K = mixture_w.size()
bottleneck = self.network[:2]
tcns = self.network[2]
masknet = self.network[3]
masknet = self.network[3:]
output = bottleneck(mixture_w)
skip_conn = 0.0
for block in tcns:
@@ -158,6 +171,7 @@ def __init__(
out_channel=None,
norm_type="gLN",
causal=False,
pre_mask_nonlinear="prelu",
mask_nonlinear="relu",
i_adapt_layer: int = 7,
adapt_layer_type: str = "mul",
@@ -178,6 +192,7 @@
if it is None, `N` will be used instead.
norm_type: BN, gLN, cLN
causal: causal or non-causal
pre_mask_nonlinear: the non-linear function before masknet
mask_nonlinear: use which non-linear function to generate mask
i_adapt_layer: int, index of the adaptation layer
adapt_layer_type: str, type of adaptation layer
@@ -196,6 +211,7 @@
out_channel=out_channel,
norm_type=norm_type,
causal=causal,
pre_mask_nonlinear=pre_mask_nonlinear,
mask_nonlinear=mask_nonlinear,
)
self.i_adapt_layer = i_adapt_layer
@@ -224,7 +240,7 @@ def forward(self, mixture_w, enroll_emb):

bottleneck = self.network[:2]
tcns = self.network[2]
masknet = self.network[3]
masknet = self.network[3:]
output = bottleneck(mixture_w)
skip_conn = 0.0
for i, block in enumerate(tcns):
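The changes above insert an optional pre-mask activation between the temporal conv blocks and the final 1x1 mask convolution, which is why the mask head is now addressed with the slice `self.network[3:]` rather than the single index `self.network[3]`. A small illustration of that slicing (placeholder modules and arbitrary channel sizes; not the ESPnet TemporalConvNet):

```python
# Why `network[3:]` still works whether or not a pre-mask activation exists.
# Placeholder modules only; this is not the ESPnet TemporalConvNet.
import torch
import torch.nn as nn

B, C = 256, 512  # bottleneck channels / mask channels (arbitrary here)
without_act = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity(),
                            nn.Conv1d(B, C, 1, bias=False))
with_act = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity(),
                         nn.PReLU(), nn.Conv1d(B, C, 1, bias=False))

x = torch.randn(2, B, 10)
# The slice picks up everything after the TCN blocks (activation + conv, or
# just the conv), so forward() does not need to know which case holds.
assert without_act[3:](x).shape == with_act[3:](x).shape == (2, C, 10)
```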
4 changes: 0 additions & 4 deletions test/espnet2/bin/test_enh_inference_streaming.py
@@ -1,4 +1,3 @@
import string
from argparse import ArgumentParser
from pathlib import Path

@@ -11,10 +10,7 @@
get_parser,
main,
)
from espnet2.enh.encoder.stft_encoder import STFTEncoder
from espnet2.tasks.enh import EnhancementTask
from espnet2.tasks.enh_s2t import EnhS2TTask
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump

