In [1]:

from datasets import load_metric, Dataset, load_dataset, load_from_disk
import torch, torchaudio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import json, random, os, re
from tqdm import tqdm
from multiprocessing import Pool
import numpy as np
from torch.utils.data import DataLoader
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [3]:
ds = load_from_disk('./ds_biaobei')
# ds = ds.train_test_split(test_size=0.01)

def prepare_dataset(batch):
    resampled_audio = batch["audio"]["array"]
    batch["input_values"] = processor(resampled_audio, sampling_rate=16000).input_values[0]
    with processor.as_target_processor():
        batch["labels"] = processor(batch['text']).input_ids
    return batch

# ds_test = ds['test'].map(prepare_dataset, remove_columns=ds.column_names["train"], num_proc=32)
ds_test = ds.map(prepare_dataset, num_proc=32)

In [4]:
ds_test[0].keys()

dict_keys(['audio', 'text', 'path', 'trans', 'input_values', 'labels'])

In [5]:
class ExtendedWav2Vec2ForCTC(Wav2Vec2ForCTC):
    """
    In ESPNET there is a LayerNorm layer between encoder output and CTC classification head.
    """
    def __init__(self, config):
        super().__init__(config)
        self.lm_head = torch.nn.Sequential(
                torch.nn.LayerNorm(config.hidden_size),
                self.lm_head
        )

In [6]:
# processor = Wav2Vec2Processor.from_pretrained('./trained2/checkpoint-1000')

model = ExtendedWav2Vec2ForCTC.from_pretrained('./trained_full_tgml_n_fixed_t_ds4/checkpoint-32000')
model.to('cuda')

ExtendedWav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (2): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (3): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05

In [7]:
def map_to_result(batch):
  with torch.no_grad():
    input_values = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)
    logits = model(input_values).logits

  pred_ids = torch.argmax(logits, dim=-1)
  batch["pred_str"] = processor.batch_decode(pred_ids)[0]
  batch["text"] = processor.decode(batch["labels"], group_tokens=False)
  
  batch["pred_str"] = re.sub(r'\d', lambda match: match.group() + ' ', batch["pred_str"])
  batch["text"] = re.sub(r'\d', lambda match: match.group() + ' ', batch["text"])
  return batch

results = ds_test.map(map_to_result, remove_columns=['audio', 'path', 'labels', 'input_values'])
# results = ds_test.map(map_to_result)


Map: 100%|██████████| 9999/9999 [08:00<00:00, 20.79 examples/s]


In [8]:
results

Dataset({
    features: ['text', 'trans', 'pred_str'],
    num_rows: 9999
})

In [9]:
rs = [r for r in results]
with open('./asr_ds4_pred.json', 'w', encoding='utf-8') as f:
    json.dump(rs, f, ensure_ascii=False, indent=4)

In [10]:
cer_metric = load_metric("cer")
cer_metric.compute(predictions=results['pred_str'], references=results['text'])

  cer_metric = load_metric("cer")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


0.011680278199617129

In [11]:
wer_metric = load_metric("wer")
wer_metric.compute(predictions=results['pred_str'], references=results['text'])

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


0.04936323332555541

In [12]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))


In [13]:
show_random_elements(results)

Unnamed: 0,text,trans,pred_str
0,zhu4 ni3 xing4 fu2 mei2 man3 yo5,祝你幸福美满哟！,zhu4 ni3 xing4 fu2 mei2 man4 you2
1,qiang2 shang4 gao4 shi4 wei4 neng2 shuo1 ming2 ma2 tong3 de5 shang4 shu4 hua1 shao4 gong1 neng2 hui4 xiao1 hao4 duo1 shao3 dian4 liang4,墙上告示未能说明，马桶的上述花哨功能会消耗多少电量。,qiang2 shang4 gao4 shi5 wei4 neng2 shuo1 ming2 ma2 tong3 de5 shang4 shu4 hua1 shao4 gong1 neng2 hui4 xiao1 hao4 duo1 shao3 dian4 liang4
2,wo2 hen3 le4 yi4 wei4 zhu3 ren2 fu2 wu4,我很乐意为主人服务。,wo2 hen3 le4 yi4 wei4 zhu3 ren2 fu2 wu4
3,chu1 ru4 ci3 dao4 de5 tou2 zi1 zhe3 qie4 mo4 mang2 mu4 gen1 jin4,初入此道的投资者切莫盲目跟进。,chu1 ru4 ci3 dao4 de5 tou2 zi1 zhe3 qie4 mo4 mang2 mu4 gen1 jin4
4,cong2 wu3 dao4 qi1 qian2 hou4 ye3 jiu4 shi4 liu4 ge5 yue4 de5 shi2 jian1,从五到七，前后也就是六个月的时间。,cong2 wu3 dao4 qi1 qian2 hou4 ye3 jiu4 shi2 liu4 ge4 yue4 de5 shi2 jian1
5,jiang1 yi4 mei2 ji1 dan4 fang4 zai4 dao4 li4 de5 bo1 li5 bei1 di3 bu4 ye3 mei2 you3 dian1 bo3 rang4 qi2 die1 luo4,将一枚鸡蛋放在倒立的玻璃杯底部，也没有颠簸让其跌落。,jiang1 yi4 mei2 ji1 dan4 fang4 zai4 dao4 li4 de5 bo1 li5 bei1 di3 bu4 ye3 mei2 you3 dian1 bo1 rang4 qi2 die1 luo4
6,jie2 zhi4 wei3 pan2 ren2 min2 bi4 dui4 mei3 yuan2 bao4 jia4 wei2 liu4 dian3 san1 liu4 yi1 jiu3 yuan2,截至尾盘，人民币兑美元报价为六点三六一九元。,jie2 zhi4 wei3 pan2 ren2 min2 bi4 dui4 mei3 yuan2 bao4 jia4 wei2 liu4 dian3 san1 liu4 yi1 jiu3 yuan2
7,nan2 zi3 guai1 guai1 shu4 shou3 jiu4 qin2,男子乖乖束手就擒。,nan2 zi5 guai1 guai1 shu4 shou3 jiu4 qin2
8,ka2 ta2 er3 ye2 biao3 shi4 jiang1 xiang4 ba1 lin2 pai4 bing1,卡塔尔也表示将向巴林派兵。,ka2 ta3 er3 ye3 biao3 shi4 jiang1 xiang4 ba1 lin2 pai4 bing1
9,shao4 nian2 zhun3 bei4 xia4 shui3 lao1 qiu2 ke3 shi4 gang1 yi2 xia4 shui3 jiu4 dao3 zai4 shui3 zhong1,少年准备下水捞球，可是刚一下水，就倒在水中。,shao4 nian2 zhun3 bei4 xia4 shui3 lao1 qiu2 ke3 shi4 gang1 yi2 xia4 shui3 jiu4 dao3 zai4 shui3 zhong1


In [14]:

with torch.no_grad():
  logits = model(torch.tensor(ds_test[:1]["input_values"], device="cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)

# convert ids to tokens
" ".join(processor.tokenizer.convert_ids_to_tokens(pred_ids[0].tolist()))


'[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ka2 [PAD] [PAD] [PAD] [PAD] er2 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] pu3 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] pei2 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] wai4 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] sen1 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] wan2 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] hua2 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ti1 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [15]:
results

Dataset({
    features: ['text', 'trans', 'pred_str'],
    num_rows: 9999
})

In [16]:
pred, gt = results['pred_str'], results['text']

In [17]:
confusion_list = []
total_p = 0
no_err = 0
len_err = 0

for i in range(len(pred)):
    ps = pred[i].split(' ')
    gs = gt[i].split(' ')
    if len(ps) != len(gs):
        # print('!!', ps, gs)
        len_err += 1
    else:
        for j in range(len(ps)):
            if ps[j] != gs[j]:
                confusion_list.append((ps[j], gs[j]))
                no_err += 1
            total_p += 1

In [18]:
tone_err = 0
g_err = 0

for c in confusion_list:
    a, b = c
    if a[:-1] == b[:-1]:
        tone_err += 1
    elif len(a) >= 4 and len(b) >= 4 and a[:2] == b[:2] and abs(len(b) - len(a)) == 1 and (a[-2] == 'g' or b[-2] == 'g'):
        g_err += 1
    else:
        print(a, b)
        pass

sen1 sun1
pa4 ta4
chan2 zhan2
ke4 kuo4
she2 re2
wa1 wanr1
yue4 yuan4
san1 suan1
xie2 xian2
qu3 chi1
mou3 mao3
sen1 sun1
qia4 qir4
hai3 ai3
men2 menr2
que4 quan3
mai4 nai4
tian1 tie1
bi2 bie2
yue1 ye1
shen1 shan1
san4 sa4
meng4 weng4
cheng1 chong1
lu4 lou4
ci2 cuo2
su2 song2
zhui1 zhun1
zhui1 zhun1
dao4 daor4
yan2 ren2
wu3 mu3
yue4 yuanr4
sao3 sou3
lao3 lou3
wu2 wo4
wei3 rui3
juan4 zhen4
zu1 zuo1
zheng1 zhang1
shou2 shao2
chu1 chuo1
zhe2 zhou5
xue3 xuan3
man4 mai4
jinr1 jir1
neng2 nong2
zou3 sou3
kang1 hang1
kong1 gong1
kang1 hang1
ti4 tie4
ti4 tie4
ti4 tie4
nve4 nie4
ni4 nie4
lou2 le5
qing1 jing1
le4 ne4
jia2 dia2
da3 dia3
can1 cuan1
zheng1 zhong1
sheng1 shang1
yu2 yue1
kao2 kai2
jie1 jian1
wan2 an2
mi4 bi4
o5 wo5
jie4 jian4
wo3 wang2
ya5 you1
xun4 xuan4
er2 ang2
zi5 ze4
man4 mai4
wan2 wanr2
cao3 zao3
wei3 wen3
yue1 yuan1
yi4 yir4
zun1 zhun1
wa5 a5
yong3 rong3
ce4 cuo4
jie3 jian3
xin4 xun4
qie4 qian4
er4 a1
san4 sai4
yue4 ri4
zhe2 re2
can1 cen1
zi1 ci1
zhuang4 shuang3
wei3 wai3
man4 me

In [19]:
len_results = len(results)

In [20]:
len_err, total_p, no_err, tone_err, g_err

(103, 170960, 7797, 5370, 480)

In [21]:
len_err/len_results, tone_err/no_err, g_err/no_err

(0.010301030103010301, 0.6887264332435552, 0.06156213928434013)

In [22]:
no_err, confusion_list

(7797,
 [('sen1', 'sun1'),
  ('bo2', 'bo3'),
  ('pa4', 'ta4'),
  ('yu3', 'yu2'),
  ('chan2', 'zhan2'),
  ('tui3', 'tui2'),
  ('yi3', 'yi2'),
  ('ni2', 'ni1'),
  ('shai4', 'shai1'),
  ('shai4', 'shai1'),
  ('man3', 'man2'),
  ('ke4', 'kuo4'),
  ('she2', 're2'),
  ('ye2', 'ye3'),
  ('you1', 'you5'),
  ('dian4', 'dian3'),
  ('bing4', 'bin4'),
  ('yu3', 'yu2'),
  ('wa1', 'wanr1'),
  ('yue4', 'yuan4'),
  ('xiang1', 'xiao1'),
  ('san1', 'suan1'),
  ('xie2', 'xian2'),
  ('wo3', 'wo2'),
  ('qu3', 'chi1'),
  ('mou3', 'mao3'),
  ('sen1', 'sun1'),
  ('qia4', 'qir4'),
  ('hai3', 'ai3'),
  ('ying3', 'ying2'),
  ('shen1', 'sheng1'),
  ('kong1', 'kou1'),
  ('men2', 'menr2'),
  ('ling4', 'lin4'),
  ('que4', 'quan3'),
  ('mai4', 'nai4'),
  ('qie4', 'qie1'),
  ('qie4', 'qie1'),
  ('kong3', 'kong2'),
  ('zi5', 'zi2'),
  ('tian1', 'tie1'),
  ('bi2', 'bie2'),
  ('yue1', 'ye1'),
  ('wo3', 'wo2'),
  ('wo3', 'wo2'),
  ('sao1', 'sao5'),
  ('shen1', 'shan1'),
  ('san4', 'sa4'),
  ('ma3', 'ma2'),
  ('yuan2', 'yu