In [1]:
import torchaudio
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
import torch
from transformers import AutoTokenizer, AutoModelWithLMHead 
import IPython.display as ipd
from math import log

In [2]:
# model_name = './facebook/wav2vec2-xls-r-300m-vol1_vol2_clean_cleanest_data/checkpoint-1400'
# processor_name = "./facebook/wav2vec2-xls-r-300m-vol1_vol2_clean_cleanest_data/"
# tokenizer = AutoTokenizer.from_pretrained("ckiplab/gpt2-base-chinese")  
# lm_model = AutoModelWithLMHead.from_pretrained("ckiplab/gpt2-base-chinese").to(device)
# model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
# processor = Wav2Vec2Processor.from_pretrained(processor_name)

In [3]:
# # wav -> 台文
# model_name = '/work/u9296553/aics/xls-r-fine-tuning/facebook/wav2vec2-xls-r-300m-vol1_vol2_condenser_data/checkpoint-30000'
# processor_name = '/work/u9296553/aics/xls-r-fine-tuning/facebook/wav2vec2-xls-r-300m-vol1_vol2_condenser_data/'
# tokenizer_name = processor_name
# device = "cuda"
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
# processor = Wav2Vec2Processor.from_pretrained(processor_name)
# # lm_model = AutoModelWithLMHead.from_pretrained("ckiplab/gpt2-base-chinese").to(device)

In [4]:
# #wav -> 台羅
# model_name = '/work/u9296553/aics/xls-r-fine-tuning/facebook/wav2vec2-xls-r-1b-vol1_vol2_condenser_tai_lo_no_spec_augment/checkpoint-57860'
# processor_name = '/work/u9296553/aics/xls-r-fine-tuning/facebook/wav2vec2-xls-r-1b-vol1_vol2_condenser_tai_lo_no_spec_augment/'
# tokenizer_name = processor_name
# device = "cuda"
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
# processor = Wav2Vec2Processor.from_pretrained(processor_name)
# lm_model = AutoModelWithLMHead.from_pretrained("ckiplab/gpt2-base-chinese").to(device)


In [36]:
# wav -> 台羅數字
# model_name = '/work/u9296553/aics/xls-r-fine-tuning/facebook/wav2vec2-xls-r-300m-fintune_aishell_condenser_tailo_number/checkpoint-2880'
# processor_name = '/work/u9296553/aics/xls-r-fine-tuning/facebook/wav2vec2-xls-r-300m-fintune_aishell_condenser_tailo_number/'

# Fine-tuned XLS-R checkpoint and the directory that holds its processor/tokenizer files.
model_name = '/work/u9296553/aics/xls-r-fine-tuning/facebook/wav2vec2-xls-r-300m-aishell_tailo_number/checkpoint-4565'
processor_name = '/work/u9296553/aics/xls-r-fine-tuning/facebook/wav2vec2-xls-r-300m-aishell_tailo_number/'

# The tokenizer is read from the same directory as the processor.
tokenizer_name = processor_name
# Use the GPU when one is visible; otherwise fall back to CPU so the notebook still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)
# NOTE: `lm_model` is not loaded while the next line stays commented out;
# predict_beam() and predict(GPT_FIX=True) both reference it and need it enabled first.
# lm_model = AutoModelWithLMHead.from_pretrained("ckiplab/gpt2-base-chinese").to(device)

In [37]:
def load_file_to_data(file,sampling_rate=16_000):
    """Load an audio file into the dict format consumed by predict()/predict_beam().

    Args:
        file: Audio file path, passed straight to ``torchaudio.load``.
        sampling_rate: Sampling rate the audio is assumed to have. When it is
            not 16 kHz the waveform is resampled from this rate to 16 kHz.
            Integer-like strings such as ``'16000'`` are accepted as well.

    Returns:
        dict with ``"speech"`` (the waveform as a NumPy array, leading
        dimension squeezed when it has size 1) and ``"sampling_rate"``
        (always the integer 16000).
    """
    target_rate = 16_000
    # NOTE(review): the rate reported by torchaudio.load is discarded and the
    # caller-supplied `sampling_rate` is trusted instead -- confirm that is intended.
    speech, _ = torchaudio.load(file)
    waveform = speech.squeeze(0)
    # Compare as integers: the previous string comparison was always true, so
    # the no-resample branch (which returned the rate as a string) never ran.
    if int(sampling_rate) != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_rate)
        waveform = resampler.forward(waveform)
    return {"speech": waveform.numpy(), "sampling_rate": target_rate}

def predict_beam(data,beamsize=3):
    """Decode one prepared utterance with a beam search that mixes acoustic and LM scores.

    Relies on the notebook-level globals ``processor``, ``model``, ``tokenizer``,
    ``device`` and ``lm_model``. In the visible cells ``lm_model`` is only
    created in commented-out lines, so it has to be loaded before calling this.

    Args:
        data: dict with ``"speech"`` and ``"sampling_rate"`` entries, as
            produced by ``load_file_to_data``.
        beamsize: number of best-scoring candidate sequences kept per step.

    Returns:
        A flat list of decoded strings: for every item in the model's logits
        batch, one string per surviving beam.
    """
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    decoded_results = []
    for logit in logits:
        # Each beam is [token_ids, score]; scores accumulate -log(prob), so lower is better.
        sequences = [[[], 1.0]]
        pred_ids = torch.argmax(logit, dim=-1)
        # Keep only frames whose best id is >= 1 (id 0 is presumably the CTC blank/pad -- TODO confirm).
        mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        # Per-frame distribution over the vocabulary for the retained frames.
        voice_prob = torch.nn.functional.softmax((torch.masked_select(logit, mask).view(-1,vocab_size)),dim=-1)
        # NOTE(review): if no frame is retained, `.tolist()[-1]` below looks like it would fail -- confirm.
        while True:
            all_candidates = list()
            exceed = False
            for seq in sequences:
                tokens, score = seq
                # LM input is the CLS id followed by the tokens chosen so far.
                gpt_input = torch.tensor([tokenizer.cls_token_id]+tokens).to(device)
                gpt_prob = torch.nn.functional.softmax(lm_model(gpt_input).logits, dim=-1)[:len(gpt_input),:]
                # Once the LM input is as long as the retained frames, this is the last expansion.
                if len(gpt_input) >= len(voice_prob):
                    exceed = True
                # Element-wise product of LM and acoustic probabilities; this assumes both
                # models share one vocabulary size -- TODO confirm.
                comb_pred_ids = gpt_prob*voice_prob[:len(gpt_input)]
                v,i = torch.topk(comb_pred_ids,50,dim=-1)
                # Only the last position's top-50 ids extend the current beam.
                for tok_id,tok_prob in zip(i.tolist()[-1],v.tolist()[-1]):
                    candidate = [tokens + [tok_id], score + -log(tok_prob)]
                    all_candidates.append(candidate)
            # Ascending sort keeps the `beamsize` lowest-score candidates.
            ordered = sorted(all_candidates, key=lambda tup: tup[1])
            sequences = ordered[:beamsize]
            if exceed:
                break

        for i in sequences:
            decoded_results.append(processor.decode(i[0]))

    return decoded_results

def predict(data, GPT_FIX=False):
    """Greedy-decode one prepared utterance, optionally re-weighting with a language model.

    Relies on the notebook-level globals ``processor``, ``model`` and ``device``;
    with ``GPT_FIX=True`` it also needs ``tokenizer`` and ``lm_model`` (the
    latter is only created in commented-out lines in the visible cells).

    Args:
        data: dict with ``"speech"`` and ``"sampling_rate"`` entries, as
            produced by ``load_file_to_data``.
        GPT_FIX: when true, multiply the acoustic probabilities by the LM
            probabilities before taking the argmax.

    Returns:
        A list with one decoded string per item in the model's logits batch.
    """
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    
    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        # Keep only frames whose best id is >= 1 (id 0 is presumably the CTC blank/pad -- TODO confirm).
        # NOTE(review): the dropped frames are removed before decoding; if processor.decode
        # collapses consecutive repeats, genuinely repeated tokens that were separated
        # only by dropped frames would be merged into one -- confirm against the outputs.
        mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        # Per-frame distribution over the vocabulary for the retained frames.
        voice_prob = torch.nn.functional.softmax((torch.masked_select(logit, mask).view(-1,vocab_size)),dim=-1)
        if GPT_FIX:
            # LM input is the CLS id followed by every predicted id greater than 0.
            gpt_input = torch.cat((torch.tensor([tokenizer.cls_token_id]).to(device),pred_ids[pred_ids>0]), 0)
            # Truncate the LM output to the number of retained frames so the product lines up.
            gpt_prob = torch.nn.functional.softmax(lm_model(gpt_input).logits, dim=-1)[:voice_prob.size()[0],:]
            comb_pred_ids = torch.argmax(gpt_prob*voice_prob, dim=-1)
        else: 
            comb_pred_ids = torch.argmax(voice_prob, dim=-1)
        # for wer
        pred_str = processor.decode(comb_pred_ids, skip_special_tokens=False, spaces_between_special_tokens=True)
        decoded_results.append(pred_str)
        # for cer
        # decoded_results.append(processor.decode(comb_pred_ids, skip_special_tokens=True))

    return decoded_results

In [38]:
import editdistance as ed
import csv
def cer_cal(groundtruth, hypothesis):
    """Corpus-level character error rate of ``hypothesis`` against ``groundtruth``.

    Both arguments are sequences of strings compared pairwise (case-insensitive).
    The total edit distance is divided by the total number of reference
    characters. Pairs are formed with ``zip``, so extra items in the longer
    sequence are ignored.

    Returns:
        float error rate. When there are no reference characters at all the
        result is ``0.0`` if there were no errors either, otherwise ``inf``
        (previously this case raised ``ZeroDivisionError``).
    """
    err = 0.0
    tot = 0
    for p, t in zip(hypothesis, groundtruth):
        err += float(ed.eval(p.lower(), t.lower()))
        tot += len(t)
    if tot == 0:
        # Nothing to normalise by: the rate is only well-defined when nothing went wrong.
        return 0.0 if err == 0 else float("inf")
    return err / tot

def wer_cal(groundtruth, hypothesis):
    """Corpus-level word error rate of ``hypothesis`` against ``groundtruth``.

    Both arguments are sequences of strings compared pairwise (case-insensitive).
    Words are separated by any run of whitespace, so repeated or leading/trailing
    spaces do not create empty "words" that would be counted as errors. Pairs are
    formed with ``zip``, so extra items in the longer sequence are ignored.

    Returns:
        float error rate: total word-level edit distance divided by the total
        number of reference words. When there are no reference words at all
        the result is ``0.0`` if there were no errors either, otherwise ``inf``.
    """
    err = 0.0
    tot = 0
    for p, t in zip(hypothesis, groundtruth):
        hyp_words = p.lower().split()
        ref_words = t.lower().split()

        err += float(ed.eval(hyp_words, ref_words))
        tot += len(ref_words)

    if tot == 0:
        # Nothing to normalise by: the rate is only well-defined when nothing went wrong.
        return 0.0 if err == 0 else float("inf")
    return err / tot




In [8]:
import json
import re
chars_to_ignore_regex = r"[¥•＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､　、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·．﹑︰〈〉─《﹖﹣﹂﹁﹔！？｡。＂＃＄％＆＇（）＊＋，﹐－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏.．!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"

# eval_csv_path = '/work/u9296553/aics/data/vol1_vol2_condenser_test_tai_lo.csv'
# eval_csv_path = '/work/u9296553/aics/data/vol1_condenser_test_tai_lo_number.csv'



# Per-utterance references (three transcription styles) and the ASR output.
# All four lists are appended once per CSV row, so they stay index-aligned.
groundtruth_tailo = []
groundtruth_tailo_number = []
hypothesis = []
tai_wen_label = []


# `eval_csv_path` must be defined before this cell runs (see the commented-out
# candidates above). The context manager closes the CSV even if the loop raises.
with open(eval_csv_path, newline='') as file:
    csvreader = csv.reader(file)
    next(csvreader, None)  # skip the header row
    # start=1 keeps the progress counter equal to the CSV row index.
    for i, row in enumerate(csvreader, start=1):
        wav_path = row[0]

        # Derive the transcript JSON path from the wav path: drop the last
        # '-'-separated segment of the file stem, then rewrite two path fragments.
        wav_id = wav_path.replace('.wav', '')
        json_id = '-'.join(wav_id.split('-')[:-1])
        json_path = json_id.replace('-test-master', '-test-key-master').replace('/condenser/wav', '/json')
        json_path +=  '.json'
        with open(json_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)

        # Reference 1: the '漢羅台文' field with punctuation stripped.
        tai_wen = json_data['漢羅台文'].replace(',', '')
        tai_wen = re.sub(chars_to_ignore_regex, '', tai_wen).lower().replace("’", "'")
        tai_wen_label.append(tai_wen)

        # Reference 2: the '台羅' field, hyphens turned into spaces.
        tai_lo = json_data['台羅'].replace(',', '')
        tai_lo = tai_lo.replace('--',' ').replace('-', ' ')
        tai_lo = re.sub(chars_to_ignore_regex, '', tai_lo).lower().replace("’", "'")

        # Reference 3: the '台羅數字調' field; punctuation becomes a space and
        # whitespace is then collapsed to single spaces.
        tai_lo_number = json_data['台羅數字調'].replace(',', ' ')
        tai_lo_number = tai_lo_number.replace('--',' ').replace('-', ' ').replace("”", " ")
        tai_lo_number = re.sub(chars_to_ignore_regex, ' ', tai_lo_number).lower()
        tai_lo_number_list = tai_lo_number.split()
        tai_lo_number = ' '.join(tai_lo_number_list)

        groundtruth_tailo.append(tai_lo)
        groundtruth_tailo_number.append(tai_lo_number)

        # Run the ASR model on the utterance and keep its transcript.
        vdata = load_file_to_data(wav_path)
        pred = predict(vdata, GPT_FIX=False)

        pred = ''.join(pred)
        # Represent unknown tokens with a single character.
        pred = pred.replace('[UNK]', '@')

        hypothesis.append(pred)

        print(f'\r {i}', end='')

# wer = wer_cal(groundtruth, hypothesis)
# cer = cer_cal(groundtruth, hypothesis)
# print()
# print(wer)
# print(cer)


# print(len(groundtruth), len(hypothesis))


 5837

In [42]:
groundtruth = groundtruth_tailo_number
# groundtruth = groundtruth_tailo
# groundtruth = tai_wen_label

In [43]:
print(len(groundtruth), len(hypothesis))
print(groundtruth[:10])
print(hypothesis[:10])

7176 7176
['tsit4 tit8 tsiap4 ing2 hiang2 tioh8 pang5 te7 san2 e5 be2 te7 tsing5 hing5', 'sui1 bong2 i2 king1 huat4 hing5 e5 siann5 tau5 tse3 kng3 e5 hing5 pun2 hu3 sik4 long2 si7 tsing3 siong5 e5', 'koh4 beh4 ui7 kok4 lai7 gua7 un7 tong7 uan5 hian3 siong7 1 tiunn1 pak4 kiann1 siong7 sui2 e5 mia5 phinn3', 'in1 ui7 kuat4 ting7 tshu2 siau1 tiau1 thian1 un1 e5 thoo2 te7 sing5 pau1 hap8 tong5 sua2 tsai1 thoo5 kha1 e5 tshiu7 tsai1', 'tshuan1 tsi1 su1 siang7 pan1 si5 kan1 tua3 tshai2 gu5 lok8 sin5 pi3 jin5 phok8 kng1 si7 tsu1 tsing3 ki3', 'in1 tsha1 put4 to1 long2 si7 kiu2 khong3 au7', 'huat4 hing5 kui1 boo5 iau2 e7 un2 poo7 khok4 tua7', 'pi7 king2 hong1 khak4 jim7 tso3 sin1 long5 tshuan1 tong2 tsi1 poo7 su1 ki3 moo5 ka1 bun5', 'sui1 jian5 kang1 sin3 poo7 tsin1 kin2 to7 san1 tiau7 liau2 au7 puann3 ku3 ue7', 'pak4 kiann1 beh4 ke3 siok8 kai2 tsin3 tshu3 theh8 san2 giap8 hua3 thui1 tsin3 hong1 sik4']
['tsit4   tit8   tsiap4   ing2   hiang2   tioh8   pang5   te7   san2 e5 be2   tsing5   hing5', 

In [64]:
# hypothesis = [x.replace(' ', '').strip() for x in hypothesis]
# print(groundtruth[:10])
# print(hypothesis[:10])

# cer = cer_cal(groundtruth, hypothesis)
# print(cer)

['我的護照號碼是九七八二八二空九空', '這盤是啥物菜', '阿鐘仔你這个囡仔是按怎', '淑珠緊猛揤掉', '趁無食', '敢有人知影文鐘厝裡的情形', '著愛久久矣才有一領新衫通穿', '霆無三聲', '恬恬咧共老婦人人鬥整理紙坯', '想欲留下親情']
['我的護照號碼是九七八二八二空九空', '這盤是啥物菜', '阿整仔你tsit个囡仔是按怎', '設主緊猛試牢', '趁無食', '敢有人知影文精厝內的精形', '著hi久仔才有錢領先三通穿', '但無三聲', '恬咧共老婦人鬥整理紙皮', '想欲留下成前']
0.33161528736749973


In [44]:
hypothesis = [' '.join(x.split()) for x in hypothesis]
print(groundtruth[:10])
print(hypothesis[:10])

['tsit4 tit8 tsiap4 ing2 hiang2 tioh8 pang5 te7 san2 e5 be2 te7 tsing5 hing5', 'sui1 bong2 i2 king1 huat4 hing5 e5 siann5 tau5 tse3 kng3 e5 hing5 pun2 hu3 sik4 long2 si7 tsing3 siong5 e5', 'koh4 beh4 ui7 kok4 lai7 gua7 un7 tong7 uan5 hian3 siong7 1 tiunn1 pak4 kiann1 siong7 sui2 e5 mia5 phinn3', 'in1 ui7 kuat4 ting7 tshu2 siau1 tiau1 thian1 un1 e5 thoo2 te7 sing5 pau1 hap8 tong5 sua2 tsai1 thoo5 kha1 e5 tshiu7 tsai1', 'tshuan1 tsi1 su1 siang7 pan1 si5 kan1 tua3 tshai2 gu5 lok8 sin5 pi3 jin5 phok8 kng1 si7 tsu1 tsing3 ki3', 'in1 tsha1 put4 to1 long2 si7 kiu2 khong3 au7', 'huat4 hing5 kui1 boo5 iau2 e7 un2 poo7 khok4 tua7', 'pi7 king2 hong1 khak4 jim7 tso3 sin1 long5 tshuan1 tong2 tsi1 poo7 su1 ki3 moo5 ka1 bun5', 'sui1 jian5 kang1 sin3 poo7 tsin1 kin2 to7 san1 tiau7 liau2 au7 puann3 ku3 ue7', 'pak4 kiann1 beh4 ke3 siok8 kai2 tsin3 tshu3 theh8 san2 giap8 hua3 thui1 tsin3 hong1 sik4']
['tsit4 tit8 tsiap4 ing2 hiang2 tioh8 pang5 te7 san2 e5 be2 tsing5 hing5', 'sui1 jian5 i2 huat4 hing5 e5 

In [45]:
wer = wer_cal(groundtruth, hypothesis)
cer = cer_cal(groundtruth, hypothesis)
print(wer)
print(cer)


0.1910180272663603
0.12250532049469438


In [22]:
# Vocabulary of the ASR tokenizer.
vocab = tokenizer.get_vocab()
print(len(vocab))

# Every distinct whitespace-separated token found in the reference transcripts.
gt_tokens = {token for sentence in groundtruth for token in sentence.split()}
print(len(gt_tokens))

# Reference tokens the tokenizer vocabulary has no entry for.
OOV_set = {token for token in gt_tokens if token not in vocab}
print(len(OOV_set))
print(OOV_set)

23039
1413
121
{'\x08luh4', 'thut4', 'onn1', 'phih8', '妳si7', 'ngai7', 'khir3', 'oo9', 'me2', 'tsau7', 'oo33', 'tue2', 'siannh4', 'mi7', 'khiap4', 'uah4', 'bai51', 'nngh4', 'gua9', 'too55', 'tsak4', 'tsuinn7', 'ber2', 'poh4', 'alid', 'hann5', 'tsak8', 'luah8', 'tsher6', 'sam3', '川端康成', 'serh4', 'thit4', 'meh4', 'ker2', 'mi1', '塑化劑', 'kim7', 'nylon', 'khinn1', 'thun5', 'gio7', 'ma3', 'hinn3', '寶特瓶', 'tsam1', 'mngh8', 'liang1', 'pue2', 'lir2', 'her2', 'thiau2', 'liu7', 'lin7', '兩岸一家親', 'hah4', 'khang9', 'tser6', 'liau3', 'thoo1', '習近平', 'luah4', 'hia5', 'ji̍\x08t', 'thin5', '謝謝', 'sue2', 'ter3', 'pui7', 'nooh4', 'haih4', 'tiak8', 'ker3', 'hann9', 'berh4', '奈須氏', 'lap4', 'khuh4', 'the1', 'hiam3', 'hngh4', 'giau3', 'hannh4', 'jiau3', 'ter6', 'khuainnh4', 'huann2', 'painn2', 'sa2', 'tshiak8', 'sap8', 'terry', 'jiau5', 'khia1', 'ke5', 'tshe2', 'lim7', 'lio2', 'lang1', 'siunn5', 'honnh4', 'penn3', 'gim5', 'tshinn5', 'gim7', 'her3', 'mainnh', 'kinn5', 'lio7', 'luh4', 'tam3', '陶瓷', 'sann2', 'ph

In [26]:
# Show what the model produced wherever the reference token is out of vocabulary.
# Only pairs with the same token count are inspected, so tokens can be compared
# position by position; all other pairs are skipped.
for hp, gt in zip(hypothesis, groundtruth):
    hp_tokens, gt_tokens = hp.split(), gt.split()
    if len(hp_tokens) == len(gt_tokens):
        for hp_token, gt_token in zip(hp_tokens, gt_tokens):
            if gt_token in OOV_set:
                print('---OOV occur---')
                print(f'{gt_token} -> {hp_token}')



---OOV occur---
alid -> lit8
---OOV occur---
liang1 -> liang5
---OOV occur---
liang1 -> liang5
---OOV occur---
liang1 -> liang5
---OOV occur---
ma3 -> ma7
---OOV occur---
liang1 -> liang5
---OOV occur---
liang1 -> liang5
---OOV occur---
liang1 -> liang5
---OOV occur---
siunn5 -> tshiu5
---OOV occur---
sap8 -> sat4
---OOV occur---
sap8 -> sau3
---OOV occur---
tam3 -> tan3
---OOV occur---
thin5 -> thing5
---OOV occur---
sam3 -> sam1
---OOV occur---
sam3 -> sam1
---OOV occur---
lin7 -> un7
---OOV occur---
hiam3 -> t
---OOV occur---
tenn1 -> te7
---OOV occur---
khiap4 -> khiam3
---OOV occur---
tsin5 -> tsin1
---OOV occur---
honnh4 -> ho
---OOV occur---
thun5 -> tun5
---OOV occur---
jiau5 -> ing5
---OOV occur---
tiak8 -> tik4
---OOV occur---
tshe2 -> tshe1
---OOV occur---
jit -> jit8
---OOV occur---
oo9 -> mau7
---OOV occur---
ke5 -> ki5
---OOV occur---
ke5 -> ke1
---OOV occur---
ke5 -> ki5
---OOV occur---
ke5 -> ki5
---OOV occur---
tsak8 -> tsap8
---OOV occur---
phann7 -> ann7
---OOV occur

In [47]:
normal_groundtruth = []
normal_hypothesis = []

replace_map = {
    'á' : 'a',
    'à' : 'a',
    'â' : 'a',
    'ǎ' : 'a',
    'a̋' : 'a',
    'ā' : 'a',
    'a̍h' : 'ah',
    'a̍' : 'a',

    'é' : 'e',
    'è' : 'e',
    'ê' : 'e',
    'ě' : 'e',
    'ē' : 'e',
    'e̋' : 'e',
    'e̍h' : 'eh',
    'e̍' : 'e',


    'í' : 'i',
    'ì' : 'i',
    'î' : 'i',
    'ǐ' : 'i',
    'ī' : 'i',
    'i̋' : 'i',
    'i̍h' : 'ih',
    'i̍' : 'i',


    'ó' : 'o',
    'ò' : 'o',
    'ô' : 'o',
    'ǒ' : 'o',
    'ō' : 'o',
    'ő' : 'o',
    'o̍h' : 'oh',
    'o̍' : 'o',


    'ú' : 'u',
    'ù' : 'u',
    'û' : 'u',
    'ǔ' : 'u',
    'ū' : 'u',
    'ű' : 'u',
    'u̍h' : 'uh',
    'u̍' : 'u',


    'ḿ' : 'm',
    'm̀' : 'm',
    'm̂' : 'm',
    'm̌' : 'm',
    'm̄' : 'm',
    'm̋' : 'm',
    'm̍h' : 'mh',
    'm̍' : 'm',


    'ń' : 'n',
    'ǹ' : 'n',
    'n̂' : 'n',
    'ň' : 'n',
    'n̄' : 'n',
    'n̋' : 'n',
    'n̍h' : 'nh',
    'n̍' : 'n',

}


def _strip_tone_marks(text):
    """Apply every `replace_map` substitution to `text`, in the map's insertion order."""
    for marked, plain in replace_map.items():
        text = text.replace(marked, plain)
    return text


# Normalise reference and hypothesis with the same substitutions, index by index.
for idx in range(len(groundtruth)):
    reference_text, hypothesis_text = groundtruth[idx], hypothesis[idx]
    normal_groundtruth.append(_strip_tone_marks(reference_text))
    normal_hypothesis.append(_strip_tone_marks(hypothesis_text))



In [48]:
print(len(normal_groundtruth), len(normal_hypothesis))

wer = wer_cal(normal_groundtruth, normal_hypothesis)
cer = cer_cal(normal_groundtruth, normal_hypothesis)
print()
print(wer)
print(cer)

2752 2752

0.10916782197575628
0.04215573115510965


In [17]:
# for i in range(10):
#     print(normal_groundtruth[i])
#     print(normal_hypothesis[i])
#     print(wer_cal([normal_groundtruth[i]], [normal_hypothesis[i]]))
#     input()


In [18]:
# from transformers import MBartForConditionalGeneration
# translate_model_name = '/work/u9296553/aics/seq2seq/tailo_to_taiwen_origin_tokenizer_with_unk_resize/checkpoint-1448'
# src_lang="en_XX"
# tgt_lang="zh_CN"
# trans_tokenizer = AutoTokenizer.from_pretrained(translate_model_name, src_lang=src_lang, tgt_lang=tgt_lang)
# device = "cuda"
# translate_model = MBartForConditionalGeneration.from_pretrained(translate_model_name).to(device)

In [27]:
from transformers import BertTokenizer, BartForConditionalGeneration
# Checkpoint directory of the fine-tuned BART translation model; the tokenizer
# is loaded from the same directory.
translate_model_name = '/work/u9296553/aics/seq2seq/checkpoints/fine-tune-TGB-data-on-TAT-tailo-number/checkpoint-1448'
# translate_model_name = '/work/u9296553/aics/seq2seq/checkpoints/bart_base_chinese_tailo_token_based/checkpoint-2172'
trans_tokenizer = BertTokenizer.from_pretrained(translate_model_name, use_fast=True)
# Use the GPU when one is visible; otherwise fall back to CPU so the notebook still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
translate_model = BartForConditionalGeneration.from_pretrained(translate_model_name).to(device)

In [28]:
hypothesis[0]

'gua2 e5 hoo7 tsiau3 ho7 be2 si7 kiu2 tshit4 pat4 ji7 pat4 ji7 khong3 kiu2 khong3'

In [30]:
# Translate the ASR hypotheses in fixed-size batches.
predictions_tai_wen = []
batch_size = 64
for i in range(0, len(hypothesis), batch_size):
    input_texts = hypothesis[i: i + batch_size]

    encoded_input = trans_tokenizer(
        input_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=48,
    )
    encoded_input = encoded_input.to(device)
    # generate() is called without token_type_ids, so drop them when the tokenizer emits them.
    if 'token_type_ids' in encoded_input:
        del encoded_input['token_type_ids']
    generated_tokens = translate_model.generate(**encoded_input, early_stopping=True, max_length=48)

    # Strip CLS, SEP, PAD and UNK ids from every generated sequence before decoding.
    for special_id in (
        trans_tokenizer.cls_token_id,
        trans_tokenizer.sep_token_id,
        trans_tokenizer.pad_token_id,
        trans_tokenizer.unk_token_id,
    ):
        generated_tokens = [seq[seq != special_id] for seq in generated_tokens]

    output_texts = trans_tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    predictions_tai_wen.extend(output_texts)

    # Progress indicator: start offset of the batch just finished.
    print(f'\r {i}', end='')

print()
print(len(predictions_tai_wen))

 5824
5837


In [31]:
print(len(predictions_tai_wen))
print(len(tai_wen_label))

print(predictions_tai_wen[:10])
print(tai_wen_label[:10])

5837
5837
['我 的 護 照 號 碼 是 九 七 八 二 八 二 空 九 空', '這 半 是 啥 物 菜', '阿 指 仔 你 tsit 个 囡 仔 是 按 怎', '束 子 me 試 牢', '趁 無 食', '敢 有 人 知 影 文 鐘 厝 來 的 情 形', '著 愛 久 久 仔 才 有 一 領 先 三 通 清', '淡 無 三 聲', '恬 恬 咧 共 老 婦 人 人 鬥 整 理 紙 批', '想 欲 留 學 親 情']
['我的護照號碼是九七八二八二空九空', '這盤是啥物菜', '阿鐘仔你這个囡仔是按怎', '淑珠緊猛揤掉', '趁無食', '敢有人知影文鐘厝裡的情形', '著愛久久矣才有一領新衫通穿', '霆無三聲', '恬恬咧共老婦人人鬥整理紙坯', '想欲留下親情']


In [32]:
# Tokenise every reference with the translation tokenizer and re-join the tokens
# with single spaces, i.e. the space-separated form wer_cal() splits on.
tai_wen_label_tokens = [' '.join(trans_tokenizer.tokenize(s)) for s in tai_wen_label]

In [33]:
print(predictions_tai_wen[:10])
print(tai_wen_label_tokens[:10])
wer = wer_cal(tai_wen_label_tokens, predictions_tai_wen)
print(wer)

['我 的 護 照 號 碼 是 九 七 八 二 八 二 空 九 空', '這 半 是 啥 物 菜', '阿 指 仔 你 tsit 个 囡 仔 是 按 怎', '束 子 me 試 牢', '趁 無 食', '敢 有 人 知 影 文 鐘 厝 來 的 情 形', '著 愛 久 久 仔 才 有 一 領 先 三 通 清', '淡 無 三 聲', '恬 恬 咧 共 老 婦 人 人 鬥 整 理 紙 批', '想 欲 留 學 親 情']
['我 的 護 照 號 碼 是 九 七 八 二 八 二 空 九 空', '這 盤 是 啥 物 菜', '阿 鐘 仔 你 這 个 囡 仔 是 按 怎', '淑 珠 緊 猛 揤 掉', '趁 無 食', '敢 有 人 知 影 文 鐘 厝 裡 的 情 形', '著 愛 久 久 矣 才 有 一 領 新 衫 通 穿', '霆 無 三 聲', '恬 恬 咧 共 老 婦 人 人 鬥 整 理 紙 坯', '想 欲 留 下 親 情']
0.2268914935429534


In [34]:
predictions_tai_wen = [p.replace(' ','') for p in predictions_tai_wen]
print(predictions_tai_wen[:10])
print(tai_wen_label[:10])

['我的護照號碼是九七八二八二空九空', '這半是啥物菜', '阿指仔你tsit个囡仔是按怎', '束子me試牢', '趁無食', '敢有人知影文鐘厝來的情形', '著愛久久仔才有一領先三通清', '淡無三聲', '恬恬咧共老婦人人鬥整理紙批', '想欲留學親情']
['我的護照號碼是九七八二八二空九空', '這盤是啥物菜', '阿鐘仔你這个囡仔是按怎', '淑珠緊猛揤掉', '趁無食', '敢有人知影文鐘厝裡的情形', '著愛久久矣才有一領新衫通穿', '霆無三聲', '恬恬咧共老婦人人鬥整理紙坯', '想欲留下親情']


In [35]:
cer = cer_cal(tai_wen_label, predictions_tai_wen)
print(cer)

0.29097271511810247


In [47]:
# Translate the reference transcripts (not the ASR output), one utterance at a time.
predictions_tai_wen_input_label  = []

for i in range(len(groundtruth)):
    input_text = groundtruth[i]

    encoded_input = trans_tokenizer(input_text, return_tensors="pt")
    encoded_input = encoded_input.to(device)
    # Match the batched translation cell: generate() is called without token_type_ids.
    if 'token_type_ids' in encoded_input:
        del encoded_input['token_type_ids']
    # NOTE(review): forced_bos_token_id=21210 is a hard-coded id that the batched
    # translation cell no longer passes -- confirm it is valid for trans_tokenizer.
    generated_tokens = translate_model.generate(**encoded_input, forced_bos_token_id=21210, early_stopping=True, max_length=48)
    # batch_decode returns one string per generated sequence; they are concatenated.
    output_text = trans_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    output_text = ''.join(output_text)
    predictions_tai_wen_input_label.append(output_text)

    print(f'\r {i}', end='')

print()

{'input_ids': tensor([[   48,    30,    32,    70,  1298,   175,   217,    39,    67,    62,
           185,    82,    34,    71,    82,    34,    71,   174,    67,    62,
           174,     2, 21209]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
tensor([[    0, 21210,   413,  2572,   677,  1021,  5600,     2]],
       device='cuda:0')


KeyboardInterrupt: Interrupted by user

In [44]:
cer = cer_cal(tai_wen_label, predictions_tai_wen_input_label)
print(cer)

0.9688691547133684


In [48]:
for i in range(len(hypothesis)):
    print(groundtruth[i])
    print(predictions_tai_wen_input_label[i])
    
    print(tai_wen_label[i])
    input()

guá ê hōo tsiò hō bé sī kiú tshit pat jī pat jī khòng kiú khòng
二試七字偏遠
我的護照號碼是九七八二八二空九空
tsit puânn sī siánn mi̍h tshài
菜in囝
這盤是啥物菜
a tsing á lí tsit ê gín á sī án tsuánn
囡仔七減二囡仔
阿鐘仔你這个囡仔是按怎
siok tsu kín mé tshi̍h tiāu
才有法度上無相借問
淑珠緊猛揤掉
thàn bô tsia̍h
趁
趁無食
kám ū lâng tsai iánn bûn tsing tshù lí ê tsîng hîng
知影絕竅你是毋是做 庄跤人較抾
敢有人知影文鐘厝裡的情形


KeyboardInterrupt: Interrupted by user

In [None]:
# Translate the reference transcripts, one utterance at a time. This repeats the
# earlier reference-translation cell but iterates over len(hypothesis).
predictions_tai_wen_input_label  = []

for i in range(len(hypothesis)):
    tai_lo_label = groundtruth[i]
    input_text = tai_lo_label

    encoded_input = trans_tokenizer(input_text, return_tensors="pt")
    encoded_input = encoded_input.to(device)
    # Match the batched translation cell: generate() is called without token_type_ids.
    if 'token_type_ids' in encoded_input:
        del encoded_input['token_type_ids']
    # NOTE(review): forced_bos_token_id=21210 is a hard-coded id that the batched
    # translation cell no longer passes -- confirm it is valid for trans_tokenizer.
    generated_tokens = translate_model.generate(**encoded_input, forced_bos_token_id=21210, early_stopping=True, max_length=48)
    # batch_decode returns one string per generated sequence; they are concatenated.
    output_text = trans_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    output_text = ''.join(output_text)
    predictions_tai_wen_input_label.append(output_text)

    print(f'\r {i}', end='')

print()

In [33]:
print(len(predictions_tai_wen))
print(len(tai_wen_label))
print(len(hypothesis))
print(len(groundtruth))

5837
5837
5837


In [38]:
for i in range(len(hypothesis)):
    print(hypothesis[i])
    print(groundtruth[i])
    
    print(predictions_tai_wen[i])
    print(predictions_tai_wen_input_label[i])
    print(tai_wen_label[i])
    input()

guá ê hōo tsio̍h hō bé sī kiú tshit pat jī pat jī khòng kiú khòng
guá ê hōo tsiò hō bé sī kiú tshit pat jī pat jī khòng kiú khòng
二逼試七
二試七字偏遠
我的護照號碼是九七八二八二空九空
tsit puânn sī siánn mih tshài
tsit puânn sī siánn mi̍h tshài
菜體的菜
菜in囝
這盤是啥物菜
a tsing á lí tsit ê gín á sī án tsuánn
a tsing á lí tsit ê gín á sī án tsuánn
囡仔七減二囡仔
囡仔七減二囡仔
阿鐘仔你這个囡仔是按怎
sio̍k tsu kim mé tshì tiâu
siok tsu kín mé tshi̍h tiāu
頂下立場
才有法度上無相借問
淑珠緊猛揤掉
thàn bô tsia̍h
thàn bô tsia̍h
趁
趁
趁無食
kám m̄ ū lâng tsai iánn bûn tsing tshù lâi ê tsîng hîng
kám ū lâng tsai iánn bûn tsing tshù lí ê tsîng hîng
幾个延伸khoo查某的
知影絕竅你是毋是做 庄跤人較抾
敢有人知影文鐘厝裡的情形
tio̍h ài kú á tsiah ū tsi̍t iánn sin sann thang tsing
tio̍h ài kú kú ah tsiah ū tsi̍t niá sin sann thang tshīng
tī船熟似逼趕的心肝
贊成三認真教囝講矣
著愛久久矣才有一領新衫通穿
tām bô sann siá
tân bô sann siann
無初見面
無聲
霆無三聲
tiàm leh kā lāu hū lîn lâng tàu tsing lí tsuā phe
tiām tiām leh kā lāu hū jîn lâng tàu tsíng lí tsuá phe
佇咧真正予無leh懸的山頂
國民政府leh舉辦就是台語地質 小姐我beh
恬恬咧共老婦人人鬥整理紙坯
siūnn beh lâu ha̍k tshing tsîng
siūnn beh l

KeyboardInterrupt: Interrupted by user

In [34]:
cer = cer_cal(tai_wen_label, predictions_tai_wen)
print(cer)

0.9704372513741478


In [39]:
cer = cer_cal(tai_wen_label, predictions_tai_wen_input_label)
print(cer)

0.9688691547133684
