Skip to content

Commit

Permalink
Merge pull request #5800 from Emrys365/urgent_recipe
Browse files Browse the repository at this point in the history
SE function updates: new models and support for handling various sampling frequencies
  • Loading branch information
mergify[bot] committed Jun 13, 2024
2 parents 293c1cc + e707989 commit 63c4c09
Show file tree
Hide file tree
Showing 23 changed files with 1,234 additions and 61 deletions.
2 changes: 2 additions & 0 deletions ci/test_integration_espnet2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ if python -c 'import torch as t; from packaging.version import parse as L; asser
for t in ${feats_types}; do
echo "==== feats_type=${t} without preprocessor ==="
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-packing false --feats-type "${t}" --ref-num 1 --python "${python}" --enh-args "--num_workers 0"
./run.sh --ngpu 0 --stage 6 --stop-stage 10 --skip-packing false --feats-type "${t}" --ref-num 1 --python "${python}" \
--enh_config conf/train_with_chunk_iterator_debug.yaml --enh-args "--num_workers 0"
done
# Remove generated files in order to reduce the disk usage
rm -rf exp dump data
Expand Down
25 changes: 23 additions & 2 deletions egs2/TEMPLATE/enh1/enh.sh
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ if ! "${skip_train}"; then
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/utt2category,category,text "
fi

# Add the fs information at the end of the data path list
# Add the sampling frequency information at the end of the data path list
if [ -e "${_enh_train_dir}/utt2fs" ] && [ -e "${_enh_valid_dir}/utt2fs" ]; then
log "[INFO] Adding the sampling frequency information (fs) for training"

Expand Down Expand Up @@ -837,6 +837,19 @@ if ! "${skip_eval}"; then
_data_param+="--data_path_and_name_and_type ${_data}/enroll_spk${spk}.scp,enroll_ref${spk},text "
done
fi
# Add the category information at the end of the data path list
if [ -e "${_data}/utt2category" ]; then
log "[INFO] Adding the category information for inference"
log "[WARNING] Please make sure the category information is explicitly processed by the preprocessor defined in '${enh_config}' so that it is converted to an integer"

_data_param+="--data_path_and_name_and_type ${_data}/utt2category,category,text "
fi
# Add the sampling frequency information at the end of the data path list
if [ -e "${_data}/utt2fs" ]; then
log "[INFO] Adding the sampling frequency information for inference"

_data_param+="--data_path_and_name_and_type ${_data}/utt2fs,fs,text_int "
fi
# 1. Split the key file
key_file=${_data}/${_scp}
split_scps=""
Expand Down Expand Up @@ -888,6 +901,14 @@ if ! "${skip_eval}"; then
log "Stage 8: Scoring"
_cmd=${decode_cmd}

if ${gpu_inference}; then
_cmd=${cuda_cmd}
_ngpu=1
else
_cmd=${decode_cmd}
_ngpu=0
fi

# score_obs=true: Scoring for observation signal
# score_obs=false: Scoring for enhanced signal
for score_obs in true false; do
Expand Down Expand Up @@ -944,7 +965,7 @@ if ! "${skip_eval}"; then
# 2. Submit scoring jobs
log "Scoring started... log: '${_logdir}/enh_scoring.*.log'"
# shellcheck disable=SC2086
${_cmd} JOB=1:"${_nj}" "${_logdir}"/enh_scoring.JOB.log \
${_cmd} --gpu "${_ngpu}" JOB=1:"${_nj}" "${_logdir}"/enh_scoring.JOB.log \
${python} -m espnet2.bin.enh_scoring \
--key_file "${_logdir}"/keys.JOB.scp \
--output_dir "${_logdir}"/output.JOB \
Expand Down
41 changes: 41 additions & 0 deletions egs2/mini_an4/enh1/conf/train_with_chunk_iterator_debug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# This is a debug config for CI
encoder: conv
encoder_conf:
channel: 32
kernel_size: 20
stride: 10
decoder: conv
decoder_conf:
channel: 32
kernel_size: 20
stride: 10
separator: tcn
separator_conf:
num_spk: 1
layer: 2
stack: 2
bottleneck_dim: 16
hidden_dim: 48
kernel: 3
causal: False
norm_type: "gLN"
nonlinear: relu

criterions:
# The first criterion
- name: mse_td
conf: {}
# the wrapper for the current criterion
# for single-talker case, we simplely use fixed_order wrapper
wrapper: fixed_order
wrapper_conf:
weight: 1.0

max_epoch: 1
batch_type: folded
batch_size: 2
iterator_type: chunk
chunk_length: 25 # 0.5s
chunk_default_fs: 50 # GCD among all possible sampling frequencies
chunk_max_abs_length: 100000 # max number of samples per chunk for all sampling frequencies (reduce this value if OOM occurs)
chunk_discard_short_samples: false
53 changes: 46 additions & 7 deletions espnet2/bin/enh_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import humanfriendly
import numpy as np
import torch
import torchaudio
import yaml
from tqdm import trange
from typeguard import typechecked
Expand All @@ -17,6 +18,9 @@
from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE
from espnet2.enh.loss.criterions.time_domain import SISNRLoss
from espnet2.enh.loss.wrappers.pit_solver import PITSolver
from espnet2.enh.separator.bsrnn_separator import BSRNNSeparator
from espnet2.enh.separator.tfgridnetv3_separator import TFGridNetV3
from espnet2.enh.separator.uses_separator import USESSeparator
from espnet2.fileio.sound_scp import SoundScpWriter
from espnet2.tasks.enh import EnhancementTask
from espnet2.tasks.enh_s2t import EnhS2TTask
Expand Down Expand Up @@ -116,6 +120,7 @@ def __init__(

# 1. Build Enh model

self.sfi_processing = False # sampling-frequency-independent (SFI)
if inference_config is None:
enh_model, enh_train_args = task.build_model_from_file(
train_config, model_file, device
Expand Down Expand Up @@ -150,6 +155,10 @@ def __init__(
if enh_s2t_task:
enh_model = enh_model.enh_model
enh_model.to(dtype=getattr(torch, dtype)).eval()
if isinstance(
enh_model.separator, ((BSRNNSeparator, USESSeparator, TFGridNetV3))
):
self.sfi_processing = True

self.device = device
self.dtype = dtype
Expand Down Expand Up @@ -217,6 +226,22 @@ def __call__(
[batch_size], dtype=torch.long, fill_value=speech_mix.size(1)
)

lengths0 = lengths
if self.sfi_processing:
fs_ = fs
else:
fs_ = None
if self.enh_model.always_forward_in_48k:
lengths = lengths.new_tensor(
[
torchaudio.functional.resample(
torch.randn(L, device="meta"), fs, 48000
).size(0)
for L in lengths
]
)
speech_mix = torchaudio.functional.resample(speech_mix, fs, 48000)

# a. To device
speech_mix = to_device(speech_mix, device=self.device)
lengths = to_device(lengths, device=self.device)
Expand Down Expand Up @@ -244,7 +269,7 @@ def __call__(
raise ValueError(f"Category '{category}' is not listed in self.categories")

additional = {}
if category is not None:
if category is not None and self.enh_model.categories:
cat = self.enh_model.categories[category[0].item()]
print(f"category: {cat}", flush=True)
if cat.endswith("_reverb"):
Expand Down Expand Up @@ -279,13 +304,13 @@ def __call__(
[batch_size], dtype=torch.long, fill_value=T
)
# b. Enhancement/Separation Forward
feats, f_lens = self.enh_model.encoder(speech_seg, lengths_seg)
feats, f_lens = self.enh_model.encoder(speech_seg, lengths_seg, fs=fs_)
if isinstance(self.enh_model, ESPnetDiffusionModel):
feats = [self.enh_model.enhance(feats)]
else:
feats, _, _ = self.enh_model.separator(feats, f_lens, additional)
processed_wav = [
self.enh_model.decoder(f, lengths_seg)[0] for f in feats
self.enh_model.decoder(f, lengths_seg, fs=fs_)[0] for f in feats
]
if speech_seg.dim() > 2:
# multi-channel speech
Expand Down Expand Up @@ -339,12 +364,12 @@ def __call__(
waves = torch.unbind(waves, dim=0)
else:
# b. Enhancement/Separation Forward
feats, f_lens = self.enh_model.encoder(speech_mix, lengths)
feats, f_lens = self.enh_model.encoder(speech_mix, lengths, fs=fs_)
if isinstance(self.enh_model, ESPnetDiffusionModel):
feats = [self.enh_model.enhance(feats)]
else:
feats, _, _ = self.enh_model.separator(feats, f_lens, additional)
waves = [self.enh_model.decoder(f, lengths)[0] for f in feats]
waves = [self.enh_model.decoder(f, lengths, fs=fs_)[0] for f in feats]

###################################
# De-normalize the signal variance
Expand All @@ -357,6 +382,12 @@ def __call__(
mix_std_ = mix_std_.squeeze(2)
waves = [w * mix_std_ for w in waves]

if not self.sfi_processing and self.enh_model.always_forward_in_48k:
waves = [
torchaudio.functional.resample(sp, 48000, fs)[..., : lengths0.max()]
for sp in waves
]

assert len(waves) == self.num_spk, len(waves) == self.num_spk
assert len(waves[0]) == batch_size, (len(waves[0]), batch_size)
if self.normalize_output_wav:
Expand Down Expand Up @@ -524,10 +555,18 @@ def inference(
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}

waves = separate_speech(**batch, fs=fs)
if "utt2fs" in batch:
# All samples must have the same sampling rate
assert all([fs_ == batch["utt2fs"][0].item() for fs_ in batch["utt2fs"]])
fs_ = batch.pop("utt2fs")[0].item()
logging.info(f"Swichting to fs={fs_}Hz")
else:
fs_ = fs

waves = separate_speech(**batch, fs=fs_)
for spk, w in enumerate(waves):
for b in range(batch_size):
writers[spk][keys[b]] = fs, w[b]
writers[spk][keys[b]] = fs_, w[b]

for writer in writers:
writer.close()
Expand Down
23 changes: 21 additions & 2 deletions espnet2/bin/enh_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Dict, List, Union

import librosa
import numpy as np
import torch
from mir_eval.separation import bss_eval_sources
Expand Down Expand Up @@ -223,17 +224,35 @@ def scoring(
if pesq:
if sample_rate == 8000:
mode = "nb"
ref_ = ref[i]
inf_ = inf[int(perm[i])]
elif sample_rate == 16000:
mode = "wb"
ref_ = ref[i]
inf_ = inf[int(perm[i])]
elif sample_rate > 16000:
mode = "wb"
ref_ = librosa.resample(
ref[i], orig_sr=sample_rate, target_sr=16000
)
inf_ = librosa.resample(
inf[int(perm[i])], orig_sr=sample_rate, target_sr=16000
)
sample_rate = 16000
logging.warning(
"The sample rate is higher than 16000 Hz. "
"PESQ is calculated in the wideband mode and "
"the signal is resampled to 16 kHz."
)
else:
raise ValueError(
"sample rate must be 8000 or 16000 for PESQ evaluation, "
f"but got {sample_rate}"
)
pesq_score = pesq(
sample_rate,
ref[i],
inf[int(perm[i])],
ref_,
inf_,
mode=mode,
on_error=PesqError.RETURN_VALUES,
)
Expand Down
4 changes: 2 additions & 2 deletions espnet2/diar/separator/tcn_separator_nomask.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from distutils.version import LooseVersion
from typing import Tuple, Union

import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.diar.layers.tcn_nomask import TemporalConvNet
from espnet2.enh.layers.complex_utils import is_complex
from espnet2.enh.separator.abs_separator import AbsSeparator

is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class TCNSeparatorNomask(AbsSeparator):
Expand Down
1 change: 0 additions & 1 deletion espnet2/enh/decoder/stft_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def _reconfig_for_fs(self, fs):
Args:
fs (int): new sampling rate
"""
assert fs % self.default_fs == 0 or self.default_fs % fs == 0
self.stft.n_fft = self.n_fft * fs // self.default_fs
self.stft.win_length = self.win_length * fs // self.default_fs
self.stft.hop_length = self.hop_length * fs // self.default_fs
Expand Down
1 change: 0 additions & 1 deletion espnet2/enh/encoder/stft_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def _reconfig_for_fs(self, fs):
Args:
fs (int): new sampling rate
""" # noqa: H405
assert fs % self.default_fs == 0 or self.default_fs % fs == 0
self.stft.n_fft = self.n_fft * fs // self.default_fs
self.stft.win_length = self.win_length * fs // self.default_fs
self.stft.hop_length = self.hop_length * fs // self.default_fs
Expand Down
Loading

0 comments on commit 63c4c09

Please sign in to comment.