In [1]:
import json, re, os, shutil
from pypinyin_dict.phrase_pinyin_data import large_pinyin
from pypinyin_dict.pinyin_data import zdic
large_pinyin.load()
zdic.load()
from pypinyin import pinyin, lazy_pinyin, Style
from g2pM import G2pM
from g2pw import G2PWConverter
from datasets import load_metric, Dataset, load_dataset, load_from_disk
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

ds = load_from_disk('./ds_biaobei')

def prepare_dataset(batch):
    del batch["audio"]
    return batch

# ds_test = ds['test'].map(prepare_dataset, remove_columns=ds.column_names["train"], num_proc=32)
ds_test = ds.map(prepare_dataset, num_proc=32)


In [3]:
def filter_punc(text):
    return re.sub(r'[,.!?;:、。？！ 《》<>，。_…“”：（）；——]', '', text).replace(' ', '')

In [4]:
tt_pairs = [d for d in ds_test]

In [5]:
g2pm = G2pM()
conv = G2PWConverter(style='pinyin', enable_non_tradional_chinese=True)

In [6]:
g2pm_py = g2pm('你好', tone=True)
g2pm_py

['ni3', 'hao3']

In [7]:
results = {
    'lpy_pred': [],
    'g2pm_pred': [],
    'g2pw_pred': [],
    'gt': []
}

for d in tqdm(ds_test):
    line_py = lazy_pinyin(filter_punc(d['trans']), style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True)
    # line_py = list(filter(lambda x: len(filter_punc(x)) == len(x), line_py))
    g2pm_py = g2pm(filter_punc(d['trans']), tone=True)
    # g2pm_py = list(filter(lambda x: len(filter_punc(x)) == len(x), g2pm_py))
    g2pw_py = conv(filter_punc(d['trans']))[0]
    g2pw_py = list(filter(lambda x: x != None, g2pw_py))
    # print(line_py, gt_py)
    results['lpy_pred'].append(' '.join(line_py))
    results['g2pm_pred'].append(' '.join(g2pm_py))
    results['g2pw_pred'].append(' '.join(g2pw_py))
    results['gt'].append(d['text'])


100%|██████████| 9999/9999 [1:04:02<00:00,  2.60it/s]


In [8]:
cer_metric = load_metric("cer")
print(cer_metric.compute(predictions=results['lpy_pred'], references=results['gt']))
print(cer_metric.compute(predictions=results['g2pm_pred'], references=results['gt']))
print(cer_metric.compute(predictions=results['g2pw_pred'], references=results['gt']))

  cer_metric = load_metric("cer")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


0.015899770970087605
0.024119867231903916
0.01874162161092616


In [9]:
wer_metric = load_metric("wer")
print(wer_metric.compute(predictions=results['lpy_pred'], references=results['gt']))
print(wer_metric.compute(predictions=results['g2pm_pred'], references=results['gt']))
print(wer_metric.compute(predictions=results['g2pw_pred'], references=results['gt']))

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


0.0705785550247461
0.1036019993368293
0.07993663035602441


In [10]:
results['trans'] = []

for d in tqdm(ds_test):
    results['trans'].append(filter_punc(d['trans']))
    

100%|██████████| 9999/9999 [00:00<00:00, 37219.89it/s]


In [11]:
def find_char_indexes(text, char):
    return [i for i, c in enumerate(text) if c == char]

In [12]:
asr_ds4_pred = json.load(open('./asr_ds4_pred.json'))
asr_ds4_pred[0]

{'text': 'ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1 ',
 'trans': '卡尔普陪外孙玩滑梯。',
 'pred_str': 'ka2 er2 pu3 pei2 wai4 sen1 wan2 hua2 ti1 '}

In [13]:
wlen_count = 0

results_a_t = {
    'trans': [],
    'gt': [],
    'pypy': [],
    'pred_py': []
}

for i in range(len(ds_test)):
    # for k in results:
    trans, gt, asr_pred = results['trans'][i], results['gt'][i], asr_ds4_pred[i]['pred_str']
    # print(trans, gt, lpy_pred)
    gts = gt.split(' ')
    asr_preds = asr_pred.strip().split(' ')
    # print(len(trans), len(gts), len(lpy_preds))
    try:
        if len(trans) != len(asr_preds):
            er_indexes = find_char_indexes(trans, '儿')
            pp = 0
            for r in er_indexes:
                if r != 0:
                    if asr_preds[r-1-pp][-2] == 'r' and asr_preds[r-1-pp][0] != 'e':
                        pp += 1
                        trans = trans[:r-pp] + trans[r-pp+1:]
            # print(trans, er_indexes)
    except Exception as e:
        pass
    if len(trans) != len(asr_preds):
        wlen_count += 1
    else:
        # print(line_py)
        pred_pys = []
        assert len(asr_preds) == len(trans)
        line_py = lazy_pinyin(trans, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True)
        for i in range(len(asr_preds)):
            polys = pinyin(trans[i], heteronym=True, style=Style.TONE3, neutral_tone_with_five=True)[0]
            polys_wo_tone = [p[:-1] for p in polys]
            if asr_preds[i][:-1] in polys_wo_tone or (asr_preds[i][-2] == 'r' and asr_preds[i][:-2] in polys_wo_tone):
                pred_pys.append(asr_preds[i])
            else:
                pred_pys.append(line_py[i])
        # print(pred_pys)
        results_a_t['trans'].append(trans)
        results_a_t['pypy'].append(' '.join(line_py))
        results_a_t['gt'].append(gt)
        results_a_t['pred_py'].append(' '.join(pred_pys))
        
# print(wlen_count)

In [14]:
wlen_count / len(ds_test)

0.0185018501850185

In [15]:
print(cer_metric.compute(predictions=results_a_t['pred_py'], references=results_a_t['gt']))
print(wer_metric.compute(predictions=results_a_t['pred_py'], references=results_a_t['gt']))

0.007937489587487821
0.03738030781571164


In [16]:
def rm_tones(py):
    return [' '.join([pp[:-1] for pp in p.split(' ')]) for p in py]

In [17]:
print(cer_metric.compute(predictions=rm_tones(results_a_t['pred_py']), references=rm_tones(results_a_t['gt'])))
print(wer_metric.compute(predictions=rm_tones(results_a_t['pred_py']), references=rm_tones(results_a_t['gt'])))

0.0007617662863103328
0.001973981049781922


In [18]:
print(cer_metric.compute(predictions=results_a_t['pypy'], references=results_a_t['gt']))
print(wer_metric.compute(predictions=results_a_t['pypy'], references=results_a_t['gt']))

0.015173238960212844
0.06834987717451246


In [19]:
print(cer_metric.compute(predictions=rm_tones(results_a_t['pypy']), references=rm_tones(results_a_t['gt'])))
print(wer_metric.compute(predictions=rm_tones(results_a_t['pypy']), references=rm_tones(results_a_t['gt'])))

0.0022473685874134716
0.005696345315084975
