In [1]:
!pip install transformers==4.15.0 sentencepiece
!pip install datasets==1.17.0 
!pip install pythainlp
!pip install jiwer
!pip install pytorch-lightning 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers==4.15.0
  Downloading transformers-4.15.0-py3-none-any.whl (3.4 MB)
[K     |████████████████████████████████| 3.4 MB 29.7 MB/s 
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 61.6 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 52.5 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 70.8 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |███████████████████

In [2]:
from transformers import (
    MT5ForConditionalGeneration,
    MT5TokenizerFast,
)
from transformers import AutoModelForMaskedLM, pipeline
from transformers import AutoTokenizer, BertForTokenClassification, AutoModel
import pandas as pd
import torch
import pickle
from tqdm import tqdm
from datasets import load_metric
from pythainlp.benchmarks import word_tokenization
from nltk.translate.gleu_score import sentence_gleu
from jiwer import cer
import os
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from typing import Optional

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [3]:
# tokenizer = AutoTokenizer.from_pretrained('airesearch/wangchanberta-base-att-spm-uncased', model_max_length=512)
tokenizer_tag = pickle.load(open('drive/MyDrive/AIBuilders/json/tokenizer_json_15k.pkl', 'rb'))

FileNotFoundError: ignored

In [None]:
class BertModel(torch.nn.Module):
    def __init__(self):
        super(BertModel, self).__init__()
        self.bert = BertForTokenClassification.from_pretrained('airesearch/wangchanberta-base-att-spm-uncased', num_labels=2)
        self.bert.resize_token_embeddings(len(tokenizer_tag))

    def forward(self, input_id, mask, label):
        output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)
        return output

In [None]:
tagging_model = BertModel()
# model = BertForTokenClassification.from_pretrained('airesearch/wangchanberta-base-att-spm-uncased', num_labels=2)
# model.resize_token_embeddings(len(tokenizer_th))

# FILE = "drive/MyDrive/AIBuilders/tagging.pth"
FILE = "drive/MyDrive/AIBuilders/json/tagging_json_400.pth"
tagging_model.load_state_dict(torch.load(FILE, map_location=torch.device('cpu')))
tagging_model.eval()
tagging_model.cuda()

In [None]:
ids_to_labels = {0: 'f', 1: 'i'}

def evaluate_one_text(model, sentence, mask):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    input_id = torch.Tensor([sentence]).type(torch.int64)
    label_ids = []
    for i in sentence:
      if i == 1 or i == 5 or i == 6:
        label_ids.append(-100)
      else:
        label_ids.append(2)
    label_ids = torch.Tensor([label_ids]).type(torch.int64)
    mask = torch.Tensor([mask]).type(torch.int64)

    if use_cuda:
        model = model.cuda()
        input_id = input_id.cuda()
        mask = mask.cuda()
        label_ids = label_ids.cuda()

    logits = tagging_model(input_id, mask, None)
    logits_clean = logits[0][label_ids != -100]

    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels[i] for i in predictions]
    return prediction_label

In [None]:
model_name="mt5-qg-epoch-4-train-loss-0.009-val-loss-0.2703"
mt5_model = MT5ForConditionalGeneration.from_pretrained(
    f"drive/MyDrive/mt5-thai-qg/{model_name}",
    return_dict=True,
)
tokenizer_mt5 = MT5TokenizerFast.from_pretrained(
    f"drive/MyDrive/mt5-thai-qg/{model_name}"
)

mt5_model.cuda()

MT5ForConditionalGeneration(
  (shared): Embedding(250112, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(250112, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedGeluDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (w

In [None]:
ds_tag = pickle.load(open('drive/MyDrive/AIBuilders/json/ner_json_15k_2.pkl', 'rb'))
ds_mlm = pickle.load(open('drive/MyDrive/AIBuilders/json/mlm_json_15k_2.pkl', 'rb'))
test_df = pickle.load(open('drive/MyDrive/AIBuilders/json/test_mt5_no_split_json_15k.pkl', 'rb'))

In [None]:
ds_mlm = ds_mlm[10576:]
ds_tag = ds_tag[10576:]
NUM_SAMPLE = ds_mlm.shape[0]
NUM_SAMPLE

5000

In [None]:
for i in range(10):
  a = ds_tag.iloc[i]['text']['input_ids'].squeeze(0).tolist()
  b = tokenizer_tag.convert_ids_to_tokens(a)
  print("".join(b))

<s>▁จําเป็นต้องเปิดบริการกับสาขาที่เปิดบัญชี▁มั๊ยครับ▁หรือสาขาไหนก็ได้</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

In [None]:
test_df

Unnamed: 0,source_text,target_text
0,แก้คำผิด: จําเป็นต้องเปิดบริการกับสาขาที่เปิด...,จําเป็นต้องเปิดบริการกับสาขาที่เปิดบัญชี ไหมคร...
1,แก้คำผิด: เมื่อวาน<mask>ยังเข้าได้,เมื่อวาน ก็ยังเข้าได้
2,แก้คำผิด: ผม<mask>ต่างประเทศครับ,ผมอยู่ต่างประเทศครับ
3,แก้คำผิด: พอไปสมัครที่<mask><mask>,พอไปสมัครที่ตู้
4,แก้คำผิด: <mask>มีปัญหา<mask><mask>ตอนนี้,แอพมีปัญหาหรือคะตอนนี้
...,...,...
4995,แก้คำผิด: ไม่ทราบว่าทําการสมัครบริการซื้อสินค...,ไม่ทราบว่าทําการสมัครบริการซื้อสินค้าทาง อินเท...
4996,แก้คำผิด: จะ<mask><mask><mask><mask>,จะสมัคร อย่างไรครับ
4997,แก้คำผิด: อยาก<mask>ว่าจะเข้าระบบ <mask><mask>,อยาก รู้ว่าจะเข้าระบบ อย่างไรคะ
4998,แก้คำผิด: เข้าไม่ได้ <mask><mask>ครับ,เข้าไม่ได้ เป็นอะ ไร ครับ


In [None]:
msp_type = pickle.load(open('drive/MyDrive/AIBuilders/json/msp_type.pkl', 'rb'))
msp_type = msp_type[10576:]
msp_word = pickle.load(open('drive/MyDrive/AIBuilders/json/msp_word.pkl', 'rb'))
msp_word = msp_word[10576:]
print(len(msp_type), len(msp_word))

5000 5000


In [None]:
msp_word

[[[('มั๊ย', 'ไหม')]],
 [[('ก้', 'ก็')]],
 [[('อยุ่', 'อยู่')]],
 [[('ตุ้', 'ตู้')]],
 [[('แอฟ', 'แอพ')], [('หรอ', 'หรือ'), ('ค่ะ', 'คะ')]],
 [[('คับ', 'ครับ')], [('รุ้', 'รู้')]],
 [[('คับ', 'ครับ')]],
 [[('เบอร', 'เบอร์')]],
 [[('ใหน', 'ไหน')]],
 [[('อัพเดต', 'อัปเดต')]],
 [[('เช้ค', 'เช็ค')]],
 [[('เบอ', 'เบอร์')], [('คับ', 'ครับ')]],
 [[('แอฟ', 'แอป')], [('อยาก', 'ยาก')]],
 [[('แอพ', 'แอป')]],
 [[('เชค', 'เช็ค')], [('ยังไง', 'อย่างไร')]],
 [[('จ้า', 'จ้ะ')]],
 [[('บัชชี', 'บัญชี')]],
 [[('ยังใง', 'อย่างไร')]],
 [[('คร้าบบ?', 'ครับ')]],
 [[('ตัง', 'สตางค์')], [('เปน', 'เป็น')], [('ค่ะ', 'คะ')]],
 [[('ยังงัย', 'ยังไง')]],
 [[('ยังไง', 'อย่างไร')]],
 [[('อ่อ', 'อ๋อ'), ('คะ', 'ค่ะ')], [('คะ', 'ค่ะ')]],
 [[('คร้า', 'ค่ะ')]],
 [[('ใหม', 'ไหม')]],
 [[('อ่ะ', 'อะ')]],
 [[('ป่าว', 'เปล่า'), ('คัฟ', 'ครับ')]],
 [[('อินเตอร์เน็ต', 'อินเทอร์เน็ต')]],
 [[('อ่ะ', 'อะ')]],
 [[('ก้', 'ก็')]],
 [[('อ่ะ', 'อะ')]],
 [[('อ่อ', 'อ๋อ')], [('มั้ย', 'ไหม')]],
 [[('แอพ', 'แอป'), ('ไหม่', 'ใหม่'), ('หรอ', 'เ

In [None]:
msp_type

[[['morphed', 10, 11]],
 [['misspelled', 1, 2]],
 [['misspelled', 1, 2]],
 [['misspelled', 4, 6]],
 [['misspelled', 1, 2], ['morphed', 3, 5]],
 [['morphed', 1, 2], ['misspelled', 3, 4]],
 [['misspelled', 6, 7]],
 [['misspelled', 2, 4]],
 [['misspelled', 4, 5]],
 [['misspelled', 3, 4]],
 [['misspelled', 1, 4]],
 [['misspelled', 3, 4], ['morphed', 5, 6]],
 [['misspelled', 1, 2], ['misspelled', 4, 5]],
 [['misspelled', 3, 4]],
 [['misspelled', 2, 3], ['morphed', 7, 8]],
 [['morphed', 3, 4]],
 [['misspelled', 3, 6]],
 [['morphed', 4, 7]],
 [['misspelled', 1, 4]],
 [['morphed', 6, 7], ['morphed', 13, 14], ['misspelled', 15, 16]],
 [['misspelled', 2, 3]],
 [['morphed', 3, 4]],
 [['misspelled', 1, 3], ['misspelled', 4, 5]],
 [['morphed', 3, 4]],
 [['misspelled', 6, 7]],
 [['misspelled', 7, 8]],
 [['misspelled', 2, 4]],
 [['misspelled', 4, 6]],
 [['misspelled', 2, 3]],
 [['misspelled', 4, 5]],
 [['misspelled', 3, 4]],
 [['misspelled', 1, 2], ['morphed', 12, 13]],
 [['misspelled', 4, 9]],
 [['m

In [None]:
def predict(text):
    with torch.no_grad():
        input_ids = tokenizer_mt5.encode(text, return_tensors="pt", add_special_tokens=True)

        input_ids = input_ids.cuda()

        # print(input_ids)
        mt5_model.cuda()
        generated_ids = mt5_model.generate(
            input_ids=input_ids,
            num_beams=3,
            max_length=10000,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            top_p=50,
            top_k=20,
            num_return_sequences=1,
        )

        # print(generated_ids)

        preds = [
            tokenizer_mt5.decode(
                g,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for g in generated_ids
        ]
    return preds


In [None]:
def ids_to_tokens(tokenized_text):
  a = tokenizer_mt5.convert_ids_to_tokens(tokenized_text)
  return a

In [None]:
msp_type_dict = {}
msp_type_dict_full = {}
msp_type_dict_fill = {}
sent_id_type = {}
msp_word_dict = {}
msp_word_dict_full = {}
msp_word_dict_fill = {}

NUM_SAMPLE = 5000
len_msp_type = len(msp_type)
# for i in range(len_msp_type):
for i in range(NUM_SAMPLE):
  len_a = len(msp_type[i])
  for j in range(len_a):
    k = msp_type[i][j][0]
    if not k in msp_type_dict:
      msp_type_dict[k] = 1
      msp_type_dict_full[k] = 0
      msp_type_dict_fill[k] = 0
      sent_id_type[k] = []
    else:
      msp_type_dict[k] += 1
    
    len_b = len(msp_word[i][j])
    for k in range(len_b):
      x = msp_word[i][j][k]
      if not x in msp_word_dict:
        msp_word_dict[x] = 1
        msp_word_dict_full[x] = 0
        msp_word_dict_fill[x] = 0
      else:
        msp_word_dict[x] += 1

msp_word_dict_wrong = msp_word_dict_fill.copy()
print(msp_type_dict)
print(msp_type_dict_fill)
print(msp_type_dict_full)
msp_word_dict = dict(sorted(msp_word_dict.items(), key=lambda item: item[1], reverse=True))
print(msp_word_dict)
print(msp_word_dict_fill)
print(msp_word_dict_full)

{'morphed': 1818, 'misspelled': 4509, 'ws': 36, 'other': 38, 'abbreviation': 190, 'new': 1}
{'morphed': 0, 'misspelled': 0, 'ws': 0, 'other': 0, 'abbreviation': 0, 'new': 0}
{'morphed': 0, 'misspelled': 0, 'ws': 0, 'other': 0, 'abbreviation': 0, 'new': 0}
{('อ่ะ', 'อะ'): 626, ('ค่ะ', 'คะ'): 487, ('คับ', 'ครับ'): 466, ('คะ', 'ค่ะ'): 345, ('เบอ', 'เบอร์'): 315, ('มั้ย', 'ไหม'): 205, ('สมัค', 'สมัคร'): 186, ('แอพ', 'แอป'): 170, ('เปน', 'เป็น'): 153, ('ค้ะ', 'คะ'): 148, ('ค้ะ', 'ค่ะ'): 146, ('อ่อ', 'อ๋อ'): 143, ('ยังไง', 'อย่างไร'): 136, ('ก้', 'ก็'): 135, ('ใด้', 'ได้'): 106, ('คัฟ', 'ครับ'): 90, ('มั๊ย', 'ไหม'): 88, ('เรย', 'เลย'): 70, ('โทรศัพ', 'โทรศัพท์'): 70, ('ใหม', 'ไหม'): 68, ('ไง', 'อย่างไร'): 67, ('ไม', 'ไหม'): 67, ('เคดิต', 'เครดิต'): 62, ('บช', 'บัญชี'): 56, ('แอฟ', 'แอป'): 55, ('ตัง', 'สตางค์'): 55, ('ม่', 'ไม่'): 55, ('อยุ่', 'อยู่'): 53, ('บันชี', 'บัญชี'): 53, ('ไหม่', 'ใหม่'): 51, ('อ้ะ', 'อะ'): 49, ('รุ้', 'รู้'): 48, ('ก้อ', 'ก็'): 48, ('แร้ว', 'แล้ว'): 46, ('ใหน', 'ไหน

In [None]:
total_gs_ori = 0
total_gs_pred = 0
total_cer_ori = 0
total_cer_pred = 0

for sent_id in tqdm(range(NUM_SAMPLE)):
  chng = []
  text = ds_tag.iloc[sent_id]['text']['input_ids'].squeeze(0).tolist()
  mask = ds_mlm.iloc[sent_id]['attention_mask']
  labels = test_df.iloc[sent_id]['target_text']
  text = [k for k in text if k != 1]
  mask = [k for k in mask if k != 0]
  labels = tokenizer_mt5(labels)['input_ids']
  # print(text)
  # print(labels)
  # print(mask)
  original = "".join(tokenizer_tag.convert_ids_to_tokens(text))
  original = original.replace("<s>", "")
  original = original.replace("</s>", "")
  original = tokenizer_mt5(original)['input_ids']
  original_id = original.copy()
  original = [ids_to_tokens(original)]
  references = [[ids_to_tokens(labels)]]
  moi = ''.join(original[0])
  moi = moi.replace("▁", " ")
  # print(f"TEXT: {moi}")
  

  i_f = evaluate_one_text(tagging_model, text, mask)
  predicted_tokens = tokenizer_tag.convert_ids_to_tokens(text.copy())

  original_tokens = predicted_tokens.copy()

  i_f_len = len(i_f)
  for j in range(i_f_len):
    if(i_f[j] == 'i'):
      ph = predicted_tokens[j+1]

      predicted_tokens[j+1] = "<mask>"
      # concat_input = "".join(tokenizer_tag.convert_ids_to_tokens(predicted_id))
      concat_input = "".join(predicted_tokens)
      concat_input = concat_input.replace("<s>", "")
      concat_input = concat_input.replace("</s>", "")
      concat_input = concat_input.replace("▁", " ")
      concat_input = concat_input.strip()
      mt5_input = "แก้คำผิด: " + concat_input
      # print(f"INPUT: {concat_input}")
      
      predicted = predict(mt5_input)[0]
      # print(f"PRED: {predicted}")

      # find mask
      a = concat_input
      b = predicted
      while len(a)>0 and len(b)>0 and a[0] == b[0]:
        a = a[1:]
        b = b[1:]
      while len(a)>0 and len(b)>0 and a[-1] == b[-1]:
        a = a[:-1]
        b = b[:-1]
      # print(f"{a} = {b}")
      # a = tokenizer_mt5(predicted)
      # print(f"INPUTID: {input_mt5}")
      # print(tokenizer_mt5.convert_ids_to_tokens(input_mt5))
      # print(f"PRED ID: {a['input_ids'][0]}")
      # b = tokenizer_mt5.convert_ids_to_tokens(a['input_ids'][0])
      # print(b)
      # print(f"LEN ORI: {len(input_mt5)}, LEN PRED: {len(b)}")

      # token_logits = mlm_model(**mlm_input).logits
      # mask_token_index = torch.where(mlm_input["input_ids"] == tokenizer.mask_token_id)[1]
      # mask_token_logits = token_logits[0, mask_token_index, :]
      # top_5_tokens = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()
      # # print(f"{tokenizer.convert_ids_to_tokens(ph)} ({ph}) => {tokenizer.convert_ids_to_tokens(top_5_tokens[0])} ({top_5_tokens[0]})")
      chng.append((j, b))
      predicted_tokens[j+1] = ph
      if (ph, b) in msp_word_dict_fill:
        msp_word_dict_fill[(ph, b)] += 1
        # print(f"({ph},{b})")

  for x,y in chng:
    predicted_tokens[x+1] = y

  labels_len = len(labels)  

  predicted_tokens = [k for k in predicted_tokens if k != "<s>"]
  predicted_tokens = [k for k in predicted_tokens if k != "</s>"]
  ans = tokenizer_mt5("".join(predicted_tokens))['input_ids']
  predicted_id = ans.copy()
  ans = tokenizer_mt5.convert_ids_to_tokens(ans)
  # if "<s>" in ans:
  #   ans.remove('<s>')
  # if "</s>" in ans:
  #   ans.remove('</s>')
  # if ans[0] == '▁':
  #   ans.pop(0)
  predictions = [ans]  
  ans = ''.join(ans)
  ans = ans.replace("▁", " ")

  references[0][0] = [k.replace("▁", " ") for k in references[0][0]]
  original[0] = [k.replace("▁", " ") for k in original[0]]
  predictions[0] = [k.replace("▁", " ") for k in predictions[0]]

  # print(f"REF: {references}")
  # print(f"PRE: {predictions}")
  # print(f"ORI: {original}")
  gleu_score = sentence_gleu(references[0], predictions[0], min_len=1, max_len=4)
  total_gs_pred += gleu_score
  gleu_score_original = sentence_gleu(references[0], original[0], min_len=1, max_len=4)
  total_gs_ori += gleu_score_original

  cer_text_ori = "".join(original[0])
  cer_text_pred = "".join(predictions[0])
  cer_text_ref = "".join(references[0][0])
  cer_text_ori = cer_text_ori.replace("_", "")
  cer_text_ori = cer_text_ori.replace("▁", "")
  cer_text_pred = cer_text_pred.replace("_", "")
  cer_text_pred = cer_text_pred.replace("▁", "")
  cer_text_ref = cer_text_ref.replace("_", "")
  cer_text_ref = cer_text_ref.replace("▁", "")
  # print(f"ORI: {cer_text_ori}")
  # print(f"PRED: {cer_text_pred}")
  # print(f"REF: {cer_text_ref}")
  cer_original = cer(cer_text_ori, cer_text_ref)
  total_cer_ori += cer_original
  cer_pred = cer(cer_text_pred, cer_text_ref)
  total_cer_pred += cer_pred

  # print(f"MSP TYPE: {msp_type[sent_id]}")
  # for i in msp_type[sent_id]:
  #   msp_type_dict_full[i[0]] += 1
  #   flag = 0
  #   a = []
  #   for j in range(i[1]+1, i[2]+1):
  #     a.append((original_id[j], predicted_id[j]))
  #     # print(tokenizer.convert_ids_to_tokens(predicted_id[j]), tokenizer.convert_ids_to_tokens(labels[j]))
  #     if predicted_id[j] != labels[j]:
  #       flag = 1
  #       break
  #   if flag == 0:
  #     msp_type_dict_fill[i[0]] += 1
  #     if sent_id not in sent_id_type[i[0]]: #it may look too little because it only takes full marks
  #       sent_id_type[i[0]].append([sent_id, a])

  # print(f"PREDICTE ID: {predicted_id}")
  # print(f"ORIGINAL ID: {original_id}")
  # print(f"LABELS   ID: {labels}")
  # print(f"PRED: {len(predicted_id)}, ORI: {len(original_id)}, LABEL: {len(labels)}")
  # for i in range(labels_len):
  #   if labels[i] != original_id[i]: #change
  #     word_original = tokenizer_mt5.convert_ids_to_tokens(original_id[i])
  #     word_original = word_original.replace("▁", "")
  #     word_predicted = tokenizer_mt5.convert_ids_to_tokens(predicted_id[i])
  #     word_predicted = word_predicted.replace("▁", "")
  #     word_labels = tokenizer_mt5.convert_ids_to_tokens(labels[i])
  #     word_labels = word_labels.replace("▁", "")
  #     # if (word_original, word_predicted) in msp_word_dict_fill: #have key
  #     if (word_original, word_labels) in msp_word_dict_fill: #have key
  #       # msp_word_dict_full[(word_original, word_predicted)] += 1
  #       msp_word_dict_full[(word_original, word_labels)] += 1
  #       if labels[i] == predicted_id[i]: #correct
  #         # msp_word_dict_fill[(word_original, word_predicted)] += 1
  #         msp_word_dict_fill[(word_original, word_labels)] += 1
  #       elif labels[i] != predicted_id[i]: #incorrect
  #         # msp_word_dict_wrong[(word_original, word_predicted)] += 1
  #         msp_word_dict_wrong[(word_original, word_labels)] += 1

 


  # print(f"GLUE PREDICTED: {gleu_score}")
  # print(f"GLUE ORIGINAL: {gleu_score_original}")
  # print(f"CER PREDICTED: {cer_pred}")
  # print(f"CER ORIGINAL: {cer_original}")
  # print("--------------------------------------------")
# print(f"AVG ACC: {float(totalacc/NUM_SAMPLE)}")
# print(f"AVG f1: {float(total_f1/NUM_SAMPLE)}")
# print(f"# HIGHER BLEU PREDICTION: {bleu_higher}")
# print(f"# LOWER BLEU PREDICTION: {bleu_lower}")
# print(f"# EQUAL BLEU PREDICTION: {bleu_equal}")
print(f"# GLEU PREDICTION: {float(total_gs_pred/NUM_SAMPLE)}")
print(f"# GLEU ORIGINAL: {float(total_gs_ori/NUM_SAMPLE)}")
print(f"# CER PREDICTION: {float(total_cer_pred/NUM_SAMPLE)}")
print(f"# CER ORIGINAL: {float(total_cer_ori/NUM_SAMPLE)}")
# print(msp_type_dict_full)
# print(msp_type_dict_fill)
# msp_word_dict_full = dict(sorted(msp_word_dict_full.items(), key=lambda item: item[1], reverse=True))
msp_word_dict_fill = dict(sorted(msp_word_dict_fill.items(), key=lambda item: item[1], reverse=True))
# msp_word_dict_wrong = dict(sorted(msp_word_dict_wrong.items(), key=lambda item: item[1], reverse=True))
# print(msp_word_dict_full)
print(msp_word_dict)
print(msp_word_dict_fill)
# print(msp_word_dict_wrong)
# print(sent_id_type)

100%|██████████| 5000/5000 [1:47:35<00:00,  1.29s/it]

# GLEU PREDICTION: 0.6518720137207179
# GLEU ORIGINAL: 0.5054915367683877
# CER PREDICTION: 0.16382701521622978
# CER ORIGINAL: 0.09989535645117614
{('อ่ะ', 'อะ'): 626, ('ค่ะ', 'คะ'): 487, ('คับ', 'ครับ'): 466, ('คะ', 'ค่ะ'): 345, ('เบอ', 'เบอร์'): 315, ('มั้ย', 'ไหม'): 205, ('สมัค', 'สมัคร'): 186, ('แอพ', 'แอป'): 170, ('เปน', 'เป็น'): 153, ('ค้ะ', 'คะ'): 148, ('ค้ะ', 'ค่ะ'): 146, ('อ่อ', 'อ๋อ'): 143, ('ยังไง', 'อย่างไร'): 136, ('ก้', 'ก็'): 135, ('ใด้', 'ได้'): 106, ('คัฟ', 'ครับ'): 90, ('มั๊ย', 'ไหม'): 88, ('เรย', 'เลย'): 70, ('โทรศัพ', 'โทรศัพท์'): 70, ('ใหม', 'ไหม'): 68, ('ไง', 'อย่างไร'): 67, ('ไม', 'ไหม'): 67, ('เคดิต', 'เครดิต'): 62, ('บช', 'บัญชี'): 56, ('แอฟ', 'แอป'): 55, ('ตัง', 'สตางค์'): 55, ('ม่', 'ไม่'): 55, ('อยุ่', 'อยู่'): 53, ('บันชี', 'บัญชี'): 53, ('ไหม่', 'ใหม่'): 51, ('อ้ะ', 'อะ'): 49, ('รุ้', 'รู้'): 48, ('ก้อ', 'ก็'): 48, ('แร้ว', 'แล้ว'): 46, ('ใหน', 'ไหน'): 43, ('หรอ', 'หรือ'): 41, ('ตุ้', 'ตู้'): 40, ('ขอบคุน', 'ขอบคุณ'): 40, ('เบอร', 'เบอร์'): 37, ('หรอ', 'เ




In [None]:
tokenizer_mt5.convert_tokens_to_ids(["<mask>", '$'])
# tokenizer_mt5.convert_ids_to_tokens([259, 26461, 25301])
# references = [[1,2,3]]
# predictions = [1,2,3]
# print(sentence_gleu(references, predictions, min_len=1, max_len=4))

# print(cer("0123", "123"))

[2, 1279]

In [None]:
pickle_model_json = {}
pickle_model_json['msp_type_dict_full'] = msp_type_dict_full
pickle_model_json['msp_type_dict_fill'] = msp_type_dict_fill
pickle_model_json['msp_word_dict_full'] = msp_word_dict_full
pickle_model_json['msp_word_dict_fill'] = msp_word_dict_fill
pickle_model_json['msp_word_dict_wrong'] = msp_word_dict_wrong
pickle_model_json['sent_id_type'] = sent_id_type

pickle.dump(pickle_model_json, open('pickle_model_json.pkl', 'wb'))

In [None]:
pickle_model_json['msp_type_dict_full']

{'abbreviation': 190,
 'misspelled': 4509,
 'morphed': 1818,
 'new': 1,
 'other': 38,
 'ws': 36}

In [None]:
for key in sent_id_type:
  print(f"=========[{key}]==========")
  for sent_id, words in sent_id_type[key]:
    text = ds_tag.iloc[sent_id]['text']['input_ids'].squeeze(0).tolist()
    labels = ds_mlm.iloc[sent_id]['labels']
    text = [k for k in text if k != 1]
    labels = [k for k in labels if k != 1]
    original = [ids_to_tokens(text)]
    references = [[ids_to_tokens(labels)]]
    for i in words:
      print(f"MSP: {tokenizer.convert_ids_to_tokens(i[0])} => {tokenizer.convert_ids_to_tokens(i[1])}")
    print(original)
    print(references)
    print("-------------------")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
MSP: เปน => เป็น
[['<s>', '▁ถ้า', 'เปน', 'แบบนี้', 'ทําไงดี', 'คะ', '</s>']]
[[['<s>', '▁ถ้า', 'เป็น', 'แบบนี้', 'ทําไงดี', 'คะ', '</s>']]]
-------------------
MSP: เบอ => เบอร์
[['<s>', '▁', 'เปลี่ยน', 'เบอ', 'ตรงไหน', 'หรอ', 'คะ', '▁', 'หาไม่เจอ', 'ค่ะ', '</s>']]
[[['<s>', '▁', 'เปลี่ยน', 'เบอร์', 'ตรงไหน', 'เหรอ', 'คะ', '▁', 'หาไม่เจอ', 'ค่ะ', '</s>']]]
-------------------
MSP: ค่ะ => คะ
[['<s>', '▁พี่', 'ค่ะ', '.', '</s>']]
[[['<s>', '▁พี่', 'คะ', '.', '</s>']]]
-------------------
MSP: แอพ => แอป
[['<s>', '▁', 'เปลี่ยน', 'มือถือ', 'ละ', 'เข้า', 'แอพ', 'scb', 'ไม่ได้', '▁', 'เครียด', 'เด้อ', '</s>']]
[[['<s>', '▁', 'เปลี่ยน', 'มือถือ', 'ละ', 'เข้า', 'แอป', 'scb', 'ไม่ได้', '▁', 'เครียด', 'เด้อ', '</s>']]]
-------------------
MSP: คะ => ค่ะ
[['<s>', '▁', 'อยากสอบถาม', 'เรื่อง', 'โอน', 'ตัง', 'เข้า', 'พร้อมเพย์', '▁', 'คะ', '</s>']]
[[['<s>', '▁', 'อยากสอบถาม', 'เรื่อง', 'โอน', 'สตางค์', 'เข้า', 'พร้อมเพย์', '▁', 'ค่ะ',

In [None]:
pickle_model_json = pickle.load(open('drive/MyDrive/AIBuilders/json/pickle_model_json.pkl', 'rb'))
msp_type_dict_full = pickle_model_json['msp_type_dict_full']
msp_type_dict_fill = pickle_model_json['msp_type_dict_fill']
msp_word_dict_full = pickle_model_json['msp_word_dict_full']
msp_word_dict_fill = pickle_model_json['msp_word_dict_fill']
msp_word_dict_wrong = pickle_model_json['msp_word_dict_wrong']
sent_id_type = pickle_model_json['sent_id_type']

In [None]:
pickle_model_json

{'msp_type_dict_fill': {'abbreviation': 54,
  'misspelled': 2147,
  'morphed': 868,
  'new': 0,
  'other': 14,
  'ws': 6},
 'msp_type_dict_full': {'abbreviation': 190,
  'misspelled': 4509,
  'morphed': 1818,
  'new': 1,
  'other': 38,
  'ws': 36},
 'msp_word_dict_fill': {(' ', '.'): 0,
  ('.', '.'): 0,
  ('K My play', 'K My pay'): 0,
  ('k', 'k plus'): 0,
  ('k plus', 'k plus'): 0,
  ('plus', 'k plus'): 0,
  ('plusอ่ะคับ', 'อะครับ'): 0,
  ('ก', 'ก'): 0,
  ('ก กรุงไทย', 'กรุงไทย'): 0,
  ('กรุงเทพฯ', 'กรุงเทพ'): 0,
  ('กสิกรช', 'กสิกร'): 0,
  ('กอีก', 'อีก'): 0,
  ('กะ', 'กับ'): 2,
  ('กะ', 'ก็'): 1,
  ('กะตัง', 'สตางค์'): 0,
  ('กับ', 'กลับ'): 0,
  ('กับ', 'กัน'): 0,
  ('กัล', 'กัน'): 0,
  ('กาก', 'การ'): 0,
  ('กุ', 'กู'): 0,
  ('กุ้', 'กู้'): 0,
  ('ก็เรย', 'ก็เลย'): 0,
  ('ก่า', 'กว่า'): 0,
  ('ก้', 'ก็'): 113,
  ('ก้คือ', 'ก็คือ'): 0,
  ('ก้ดี', 'ก็ดี'): 0,
  ('ก้วิธีการ', 'ก็วิธีการ'): 0,
  ('ก้อ', 'ก็'): 32,
  ('ก้เลย', 'ก็เลย'): 0,
  ('ขอ', 'ของ'): 0,
  ('ขอบคุน', 'ขอบคุณ'): 30,

In [None]:
percent_correct = {}
for key in msp_word_dict_full:
  fill = msp_word_dict_fill[key]
  full = msp_word_dict_full[key]
  if full == 0 or full < 30:
    percent_correct[key] = 0
  else:
    percent_correct[key] = fill/full
a = dict(sorted(percent_correct.items(), key=lambda item: item[1], reverse=True))
print(a)

{('มั้ย', 'ไหม'): 0.9560975609756097, ('ไง', 'อย่างไร'): 0.9253731343283582, ('อ่อ', 'อ๋อ'): 0.916083916083916, ('มั๊ย', 'ไหม'): 0.9090909090909091, ('ม่', 'ไม่'): 0.9090909090909091, ('ค้ะ', 'คะ'): 0.8851351351351351, ('ไม', 'ไหม'): 0.8805970149253731, ('เบอ', 'เบอร์'): 0.8721590909090909, ('ใหม', 'ไหม'): 0.8676470588235294, ('ใด้', 'ได้'): 0.8584905660377359, ('เปน', 'เป็น'): 0.8431372549019608, ('ก้', 'ก็'): 0.837037037037037, ('แอฟ', 'แอป'): 0.8363636363636363, ('รุ้', 'รู้'): 0.8333333333333334, ('อ้ะ', 'อะ'): 0.7959183673469388, ('อยุ่', 'อยู่'): 0.7924528301886793, ('ค่ะ', 'คะ'): 0.7893660531697342, ('ยังไง', 'อย่างไร'): 0.7883211678832117, ('อ่ะ', 'อะ'): 0.7875399361022364, ('ใม่', 'ไม่'): 0.78125, ('บช', 'บัญชี'): 0.7678571428571429, ('ใหน', 'ไหน'): 0.7674418604651163, ('ขอบคุน', 'ขอบคุณ'): 0.75, ('ค้ะ', 'ค่ะ'): 0.726027397260274, ('คะ', 'ค่ะ'): 0.7130434782608696, ('เค้า', 'เขา'): 0.6842105263157895, ('ตัง', 'สตางค์'): 0.6727272727272727, ('ก้อ', 'ก็'): 0.6666666666666666, ('

In [None]:
percent_wrong = {}
for key in msp_word_dict_full:
  fill = msp_word_dict_wrong[key]
  full = msp_word_dict_full[key]
  if full == 0 or full < 30:
    percent_wrong[key] = 0
  else:
    percent_wrong[key] = fill/full
a = dict(sorted(percent_wrong.items(), key=lambda item: item[1], reverse=True))
print(a)

{('อะ', 'อ่ะ'): 1.0, ('ป่ะ', 'เปล่า'): 1.0, ('ป่าว', 'เปล่า'): 1.0, ('แอฟ', 'แอพ'): 1.0, ('โทร', 'โทรศัพท์'): 0.9933333333333333, ('หรอ', 'เหรอ'): 0.8918918918918919, ('หรอ', 'หรือ'): 0.8780487804878049, ('ค', 'ค่ะ'): 0.8545454545454545, ('ใช่', 'ใช้'): 0.8125, ('คัฟ', 'ครับ'): 0.7222222222222222, ('ค่า', 'ค่ะ'): 0.6451612903225806, ('คับ', 'ครับ'): 0.6209850107066381, ('เรย', 'เลย'): 0.6142857142857143, ('แร้ว', 'แล้ว'): 0.5652173913043478, ('ค', 'คะ'): 0.4, ('ไช่', 'ใช่'): 0.3548387096774194, ('แอพ', 'แอป'): 0.35294117647058826, ('ก้อ', 'ก็'): 0.3333333333333333, ('ตัง', 'สตางค์'): 0.32727272727272727, ('เค้า', 'เขา'): 0.3157894736842105, ('คะ', 'ค่ะ'): 0.28695652173913044, ('ค้ะ', 'ค่ะ'): 0.273972602739726, ('ขอบคุน', 'ขอบคุณ'): 0.25, ('ใหน', 'ไหน'): 0.23255813953488372, ('บช', 'บัญชี'): 0.23214285714285715, ('ใม่', 'ไม่'): 0.21875, ('อ่ะ', 'อะ'): 0.2124600638977636, ('ยังไง', 'อย่างไร'): 0.2116788321167883, ('ค่ะ', 'คะ'): 0.21063394683026584, ('อยุ่', 'อยู่'): 0.20754716981132076, 

In [None]:
for key in sent_id_type:
  cou = 0
  print(f"=========[{key}]==========")
  for sent_id, words in sent_id_type[key]:
    cou += 1
    if cou == 20: break
    text = ds_tag.iloc[sent_id]['text']['input_ids'].squeeze(0).tolist()
    labels = ds_mlm.iloc[sent_id]['labels']
    text = [k for k in text if k != 1]
    labels = [k for k in labels if k != 1]
    original = [ids_to_tokens(text)]
    references = [[ids_to_tokens(labels)]]
    for i in words:
      print(f"MSP: {tokenizer.convert_ids_to_tokens(i[0])} => {tokenizer.convert_ids_to_tokens(i[1])}")
    print(original)
    print(references)
    print("-------------------")

MSP: มั๊ย => ไหม
[['<s>', '▁', 'จําเป็น', 'ต้อง', 'เปิดบริการ', 'กับ', 'สาขา', 'ที่', 'เ', 'ปิดบัญชี', '▁', 'มั๊ย', 'ครับ', '▁', 'หรือ', 'สาขา', 'ไหนก็ได้', '</s>']]
[[['<s>', '▁', 'จําเป็น', 'ต้อง', 'เปิดบริการ', 'กับ', 'สาขา', 'ที่', 'เ', 'ปิดบัญชี', '▁', 'ไหม', 'ครับ', '▁', 'หรือ', 'สาขา', 'ไหนก็ได้', '</s>']]]
-------------------
MSP: คับ => ครับ
[['<s>', '▁พี่', 'คับ', 'ผมอยาก', 'รุ้', 'ว่า', '</s>']]
[[['<s>', '▁พี่', 'ครับ', 'ผมอยาก', '▁รู้', 'ว่า', '</s>']]]
-------------------
MSP: ยังไง => อย่างไร
[['<s>', '▁แล้วถ้า', 'อยาก', 'เชค', 'ยอด', 'การใช้งาน', 'บัตรเครดิต', 'ต้องทํา', 'ยังไง', 'อะคะ', '</s>']]
[[['<s>', '▁แล้วถ้า', 'อยาก', 'เช็ค', 'ยอด', 'การใช้งาน', 'บัตรเครดิต', 'ต้องทํา', 'อย่างไร', 'อะคะ', '</s>']]]
-------------------
MSP: ตัง => สตางค์
[['<s>', '▁', 'สอบถาม', 'หน่อยค่ะ', '▁', 'บางทีก็', 'โอน', 'ตัง', 'ได้', '▁', 'บางทีก็', 'โอน', 'ไม่ได้', '▁', 'เปน', 'เพราะอะไร', 'ค่ะ', '</s>']]
[[['<s>', '▁', 'สอบถาม', 'หน่อยค่ะ', '▁', 'บางทีก็', 'โอน', 'สตางค์', 'ได้', '▁', 