Skip to content

Commit

Permalink
Merge pull request #4566 from Emrys365/egs2_aishell4
Browse files Browse the repository at this point in the history
Support enh_s2t joint training on multi-speaker data
  • Loading branch information
sw005320 committed Aug 31, 2022
2 parents e5d133c + 8af19a3 commit 6d52365
Show file tree
Hide file tree
Showing 50 changed files with 1,463 additions and 300 deletions.
486 changes: 342 additions & 144 deletions egs2/TEMPLATE/enh_asr1/enh_asr.sh

Large diffs are not rendered by default.

206 changes: 206 additions & 0 deletions egs2/TEMPLATE/enh_asr1/scripts/utils/eval_perm_free_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2020 Johns Hopkins University (Xuankai Chang)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import re
import sys
from collections import OrderedDict
from pathlib import Path
from typing import List

import numpy as np
import six
from scipy.optimize import linear_sum_assignment

# Force UTF-8 output on stdout regardless of the environment's locale,
# so transcriptions with non-ASCII characters print without errors.
sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)


def get_parser():
    """Build the command-line parser for permutation-free error evaluation."""
    p = argparse.ArgumentParser(description="evaluate permutation-free error")
    p.add_argument(
        "--num-spkrs",
        type=int,
        default=2,
        help="number of mixed speakers.",
    )
    p.add_argument(
        "--results",
        type=str,
        nargs="+",
        help=(
            "the scores between references and hypotheses, "
            "in ascending order of references (1st) and hypotheses (2nd), "
            "e.g. [r1h1, r1h2, r2h1, r2h2] in 2-speaker-mix case."
        ),
    )
    p.add_argument(
        "--results-dir",
        type=str,
        help="the score dir. ",
    )
    return p


def convert_score(dic, num_spkrs=2) -> List[List[int]]:
    """Extract the [c, s, d, i] count vector for every (ref, hyp) pair.

    Args:
        dic: mapping "r{i}h{j}" -> {"Scores": <string containing 4 integers>}.
        num_spkrs: number of speakers in the mixture.

    Returns:
        A num_spkrs x num_spkrs table; entry [r][h] is the 4-element list
        [correct, substitutions, deletions, insertions] for ref r vs hyp h.
    """
    digit_pat = re.compile(r"\d+")
    table = []
    for r in range(num_spkrs):
        row = []
        for h in range(num_spkrs):
            raw = dic[f"r{r + 1}h{h + 1}"]["Scores"]
            counts = [int(tok) for tok in digit_pat.findall(raw)]
            assert len(counts) == 4  # [c,s,d,i]
            row.append(counts)
        table.append(row)
    return table


def compute_permutation(old_dic, num_spkrs=2):
    """Compute the best hypothesis permutation per utterance.

    For each utterance, assigns hypotheses to references so that the total
    word error rate is minimized (Hungarian algorithm).

    Args:
        old_dic: mapping utt_id -> {"r{i}h{j}": {"Scores": ...}, ...}.
        num_spkrs: number of speakers in the mixture.

    Returns:
        hyp_perms: list of np.ndarray; hyp_perms[b][i] is the hypothesis
            index assigned to reference i for the b-th utterance.
        all_keys: list of utterance ids, in the same order as hyp_perms.
    """
    all_scores, all_keys = [], list(old_dic.keys())
    for scores in old_dic.values():  # collect counts for each utt_id
        all_scores.append(convert_score(scores, num_spkrs))
    all_scores = np.array(all_scores)  # (B, n_ref, n_hyp, 4) = [c, s, d, i]

    # Per-pair error rate: (s + d + i) / (c + s + d), shape (B, n_ref, n_hyp).
    # NOTE: the `np.float` alias was removed in NumPy 1.24; the builtin
    # `float` is the documented replacement and behaves identically here.
    all_error_rates = np.sum(
        all_scores[:, :, :, 1:4], axis=-1, dtype=float
    ) / np.sum(
        all_scores[:, :, :, 0:3], axis=-1, dtype=float
    )

    # The previous version also accumulated the minimum scores per utterance,
    # but never returned them; that dead computation has been dropped.
    hyp_perms = []
    for error_rate in all_error_rates:
        _, col_idx = linear_sum_assignment(error_rate)
        hyp_perms.append(col_idx)

    return hyp_perms, all_keys


def read_result(result_file, result_key):
    """Parse one sclite result file into an ordered per-utterance dict.

    Args:
        result_file: path of a text file containing "id: ..." and
            "Scores: ..." lines.
        result_key: label (e.g. "r1h1") under which to nest the parsed fields.

    Returns:
        OrderedDict mapping utt_id -> {result_key: {"Scores": <string>}}.
    """
    id_pattern = re.compile(r"^id: ")
    field_patterns = {"Scores": re.compile(r"^Scores: ")}

    parsed = OrderedDict()
    cur_id, cur_fields = None, {}

    with codecs.open(result_file, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.rstrip()
            tokens = line.split()

            if id_pattern.match(line):
                # A new utterance begins: flush the one collected so far.
                if cur_id:
                    parsed[cur_id] = {result_key: cur_fields}
                    cur_fields = {}
                cur_id = tokens[1]
                # sclite wraps the id in parentheses; strip them.
                if cur_id.startswith("(") and cur_id.endswith(")"):
                    cur_id = cur_id[1:-1]

            for name, pattern in field_patterns.items():
                if pattern.match(line):
                    cur_fields[name] = " ".join(tokens[1:])

    # Flush the trailing utterance, if any fields were collected for it.
    if cur_fields != {}:
        parsed[cur_id] = {result_key: cur_fields}

    return parsed


def merge_results(results):
    """Merge per-pair result dicts into one dict keyed by utterance id.

    Args:
        results: list of OrderedDicts as returned by read_result(); every
            element must cover exactly the same utterance keys.

    Returns:
        OrderedDict mapping utt_id -> merged {result_key: fields} dict
        containing the union of all result files' entries for that utterance.
    """
    # All result files must cover the same utterance set.
    for result in results[1:]:
        assert results[0].keys() == result.keys()

    # Merge entry-by-entry into fresh dicts.
    all_results = OrderedDict()
    for key in results[0].keys():
        # Copy before updating: the previous version called update() on
        # results[0][key] itself, silently mutating the caller's input.
        v = dict(results[0][key])
        for result in results[1:]:
            v.update(result[key])
        all_results[key] = v

    return all_results


def read_trn(file_path):
    """Read a .trn transcription file into an OrderedDict {utt_id: text}.

    Each line has the form "<text> (<utt_id>)"; parentheses around the
    utterance id are stripped when present.
    """
    assert Path(file_path).exists()

    transcripts = OrderedDict()
    with open(file_path, "r", encoding="utf-8") as f:
        for raw in f:
            stripped = raw.strip()
            text, utt_id = stripped.rsplit(maxsplit=1)
            if utt_id.startswith("(") and utt_id.endswith(")"):
                utt_id = utt_id[1:-1]
            transcripts[utt_id] = text
    return transcripts


def reorder_refs_or_hyps(result_dir, num_spkrs, all_keys, hyp_or_ref=None, perms=None):
    """Write a combined <hyp|ref>.trn file with speakers ordered per utterance.

    For "ref", identity permutations are generated (perms must be None);
    for "hyp", the caller supplies one permutation per utterance.
    """
    assert hyp_or_ref in ["hyp", "ref"]
    if hyp_or_ref == "ref":
        assert perms is None
        perms = [np.arange(0, num_spkrs) for _ in all_keys]

    # Load every speaker's transcription file; all must share the same ids.
    per_spk_trns = [
        read_trn(Path(result_dir, f"{hyp_or_ref}_spk{i}.trn"))
        for i in range(1, num_spkrs + 1)
    ]
    for trn in per_spk_trns[1:]:
        assert list(per_spk_trns[0].keys()) == list(trn.keys())

    with open(Path(result_dir, f"{hyp_or_ref}.trn"), "w", encoding="utf-8") as f:
        for idx, (key, perm) in enumerate(zip(per_spk_trns[0].keys(), perms)):
            # TODO: clean this part, because sclite turns all ids into
            # lower-case characters.
            assert key.lower() == all_keys[idx].lower()
            for i in range(num_spkrs):
                f.write(per_spk_trns[perm[i]][key] + f"\t({key}-{i+1})" + "\n")


def main(args):
    """Evaluate permutation-free WER from per-pair sclite result files.

    Reads result_{r..h..}.txt for every (reference, hypothesis) pair from
    args.results_dir, finds the best per-utterance hypothesis permutation,
    and writes reordered hyp.trn / ref.trn files back into args.results_dir.
    """
    # Read results from files.
    # NOTE: plain builtin range() replaces six.moves.range — this file
    # already requires Python 3 (it uses f-strings), so six is unnecessary.
    all_results = []
    for r in range(1, args.num_spkrs + 1):
        for h in range(1, args.num_spkrs + 1):
            key = f"r{r}h{h}"
            result = read_result(
                Path(args.results_dir, f"result_{key}.txt"), result_key=key
            )
            all_results.append(result)

    # Merge the results of every permutation
    results = merge_results(all_results)

    # Get the final results with best permutation
    hyp_perms, all_keys = compute_permutation(results, args.num_spkrs)

    # Use the permutation order to reorder hypotheses file
    # Then output the refs and hyps in a new file by combining all speakers
    reorder_refs_or_hyps(
        args.results_dir,
        args.num_spkrs,
        all_keys,
        "hyp",
        hyp_perms,
    )
    reorder_refs_or_hyps(
        args.results_dir,
        args.num_spkrs,
        all_keys,
        "ref",
    )


if __name__ == "__main__":
    # Parse CLI arguments and run the evaluation.
    main(get_parser().parse_args())
84 changes: 0 additions & 84 deletions egs2/TEMPLATE/enh_asr1/scripts/utils/show_enh_score.sh

This file was deleted.

1 change: 1 addition & 0 deletions egs2/TEMPLATE/enh_asr1/scripts/utils/show_enh_score.sh
8 changes: 7 additions & 1 deletion egs2/chime4/enh_asr1/local/data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,10 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
log "stage 3: Srctexts preparation"
local/chime4_asr_data.sh --stage 2 --stop-stage 2
fi

for dset in data/*; do
if [ -e "${dset}/text" ] && [ ! -e "${dset}/text_spk1" ]; then
ln -s text ${dset}/text_spk1
fi
done
fi
4 changes: 2 additions & 2 deletions egs2/chime4/enh_asr1/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ word_vocab_size=65000
--train_set "${train_set}" \
--valid_set "${valid_set}" \
--test_sets "${test_sets}" \
--bpe_train_text "data/${train_set}/text" \
--lm_train_text "data/${train_set}/text data/local/other_text/text" "$@"
--bpe_train_text "data/${train_set}/text_spk1" \
--lm_train_text "data/${train_set}/text_spk1 data/local/other_text/text" "$@"
3 changes: 3 additions & 0 deletions egs2/mini_an4/enh_asr1/local/data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ EOF
fi
rm data/${x}/${f}.old
done
if [ ! -e "data/${x}/text_spk1" ]; then
ln -s text data/${x}/text_spk1
fi
utils/utt2spk_to_spk2utt.pl data/${x}/utt2spk > data/${x}/spk2utt

cp data/${x}/wav.scp data/${x}/spk1.scp
Expand Down
2 changes: 1 addition & 1 deletion egs2/mini_an4/enh_asr1/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ set -o pipefail
--train_set train_nodev \
--valid_set train_dev \
--test_sets "train_dev test test_seg" \
--lm_train_text "data/train_nodev/text" "$@"
--lm_train_text "data/train_nodev/text_spk1" "$@"
28 changes: 28 additions & 0 deletions egs2/wsj0_2mix_spatialized/enh_asr1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<!-- Generated by scripts/utils/show_asr_result.sh and scripts/utils/show_enh_score.sh -->
# RESULTS
## Environments
- date: `Fri Aug 26 17:20:12 CST 2022`
- python version: `3.8.11 (default, Aug 3 2021, 15:09:35) [GCC 7.5.0]`
- espnet version: `espnet 202207`
- pytorch version: `pytorch 1.7.0`
- Git hash: `277ec3c33d2ca7f47d9d31c84e4dae54ce017bd7`
- Commit date: `Wed Aug 10 13:32:09 2022 -0400`

## enh_asr_train_raw_en_char
### WER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_transformer_normalize_output_wavtrue_lm_lm_train_lm_transformer_en_char_valid.loss.ave_enh_asr_model_valid.acc.ave/tt_spatialized_anechoic_multich_max_16k|6000|98613|92.9|6.0|1.2|1.0|8.1|45.3|

### CER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_transformer_normalize_output_wavtrue_lm_lm_train_lm_transformer_en_char_valid.loss.ave_enh_asr_model_valid.acc.ave/tt_spatialized_anechoic_multich_max_16k|6000|598296|96.7|1.6|1.7|0.9|4.3|48.1|

### Speech Separation Metrics

|dataset|STOI|SAR|SDR|SIR|SI_SNR|
|---|---|---|---|---|---|
|enhanced_tt_spatialized_anechoic_multich_max_16k|95.25|12.03|10.24|21.74|-3.35|

0 comments on commit 6d52365

Please sign in to comment.