In [1]:
!pip install transformers==4.15.0 sentencepiece
!pip install datasets==1.17.0 
!pip install pythainlp
!pip install jiwer
!pip install editdistance

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers==4.15.0
  Downloading transformers-4.15.0-py3-none-any.whl (3.4 MB)
[K     |████████████████████████████████| 3.4 MB 16.2 MB/s 
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 59.2 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 58.3 MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 58.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.9.1-py3-none-any.whl (120 kB)
[K     |████████████████████████████████| 120 kB 46.0 MB/s 
Building wheels for 

In [2]:
from transformers import AutoModelForMaskedLM, pipeline
from transformers import AutoTokenizer, BertForTokenClassification, AutoModel
import pandas as pd
import torch
import pickle
from tqdm import tqdm
from datasets import load_metric
from pythainlp.benchmarks import word_tokenization
from nltk.translate.gleu_score import sentence_gleu
from jiwer import cer
import numpy as np
import editdistance

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [3]:
tokenizer = pickle.load(open('drive/MyDrive/AIBuilders/tpth/tokenizer_40k_nova.pkl', 'rb'))

In [4]:
class BertModel(torch.nn.Module):
    def __init__(self):
        super(BertModel, self).__init__()
        self.bert = BertForTokenClassification.from_pretrained('airesearch/wangchanberta-base-att-spm-uncased', num_labels=2)
        self.bert.resize_token_embeddings(len(tokenizer))

    def forward(self, input_id, mask, label):
        output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)
        return output

In [5]:
tagging_model = BertModel()

FILE = "drive/MyDrive/AIBuilders/tpth/tagging_tpth_200.pth"
tagging_model.load_state_dict(torch.load(FILE, map_location=torch.device('cpu')))
tagging_model.eval()

Downloading:   0%|          | 0.00/546 [00:00<?, ?B/s]

You are using a model of type camembert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.


Downloading:   0%|          | 0.00/404M [00:00<?, ?B/s]

Some weights of the model checkpoint at airesearch/wangchanberta-base-att-spm-uncased were not used when initializing BertForTokenClassification: ['roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'lm_head.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'lm_head.layer_norm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.9.atten

BertModel(
  (bert): BertForTokenClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(33660, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (Lay

In [6]:
ids_to_labels = {0: 'f', 1: 'i'}

def evaluate_one_text(model, sentence, mask, labels):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        model = model.cuda()

    input_id = torch.Tensor([sentence]).type(torch.int64)
    label_ids = []
    for i in sentence:
      if i == 1 or i == 5 or i == 6:
        label_ids.append(-100)
      else:
        label_ids.append(2)
    label_ids = torch.Tensor([label_ids]).type(torch.int64)
    mask = torch.Tensor([mask]).type(torch.int64)

    logits = tagging_model(input_id, mask, None)
    logits_clean = logits[0][label_ids != -100]

    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels[i] for i in predictions]
    return prediction_label

In [7]:
mlm_model = AutoModelForMaskedLM.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
mlm_model.resize_token_embeddings(len(tokenizer))

FILE = "drive/MyDrive/AIBuilders/tpth/mlm_tpth_6.pth"
mlm_model.load_state_dict(torch.load(FILE, map_location=torch.device('cpu')))
mlm_model.eval()

CamembertForMaskedLM(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(33660, 768)
      (position_embeddings): Embedding(512, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps

In [8]:
ds_tag = pickle.load(open('drive/MyDrive/AIBuilders/tpth/ner_ds_40k_nova.pkl', 'rb'))
ds_mlm = pickle.load(open('drive/MyDrive/AIBuilders/tpth/mlm_ds_40k_nova.pkl', 'rb'))

In [9]:
ds_mlm_test = ds_mlm[37890:]
ds_mlm = ds_mlm[:37890]
ds_mlm_train, ds_mlm_val = np.split(ds_mlm.sample(frac=1, random_state=42),
                            [int(.9 * len(ds_mlm))])

ds_tag_test = ds_tag[37890:]
ds_tag = ds_tag[:37890]
ds_tag_train, ds_tag_val = np.split(ds_tag.sample(frac=1, random_state=42),
                            [int(.9 * len(ds_tag))])

NUM_SAMPLE = ds_mlm_val.shape[0]

In [10]:
ds_tag_val

Unnamed: 0,text,labels
37250,"{'input_ids': [[tensor(5), tensor(2700), tenso...","[tensor(0), tensor(0), tensor(0), tensor(0), t..."
27081,"{'input_ids': [[tensor(5), tensor(10), tensor(...","[tensor(0), tensor(0), tensor(0), tensor(0), t..."
14908,"{'input_ids': [[tensor(5), tensor(10), tensor(...","[tensor(0), tensor(0), tensor(0), tensor(0), t..."
9306,"{'input_ids': [[tensor(5), tensor(10), tensor(...","[tensor(0), tensor(0), tensor(0), tensor(1), t..."
33724,"{'input_ids': [[tensor(5), tensor(8690), tenso...","[tensor(0), tensor(0), tensor(0), tensor(0), t..."
...,...,...
16850,"{'input_ids': [[tensor(5), tensor(10), tensor(...","[tensor(0), tensor(0), tensor(0), tensor(0), t..."
6265,"{'input_ids': [[tensor(5), tensor(6936), tenso...","[tensor(0), tensor(0), tensor(0), tensor(0), t..."
11284,"{'input_ids': [[tensor(5), tensor(206), tensor...","[tensor(0), tensor(0), tensor(0), tensor(0), t..."
860,"{'input_ids': [[tensor(5), tensor(10167), tens...","[tensor(0), tensor(0), tensor(0), tensor(0), t..."


In [11]:
ds_mlm_val

Unnamed: 0,input_ids,attention_mask,labels
37250,"[5, 2700, 2432, 9892, 10, 89, 13348, 10, 1149,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 2700, 2432, 9892, 10, 89, 13348, 10, 1149,..."
27081,"[5, 10, 13054, 48, 4174, 52, 381, 17, 26489, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 10, 13054, 48, 4174, 52, 381, 17, 26489, 1..."
14908,"[5, 10, 2004, 12609, 3879, 1204, 3590, 4300, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 10, 2004, 12609, 3879, 1204, 3590, 4300, 1..."
9306,"[5, 10, 1417, 25004, 25004, 145, 478, 1809, 25...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 10, 1417, 4618, 265, 145, 478, 1809, 627, ..."
33724,"[5, 8690, 26983, 2168, 12504, 3162, 18958, 20,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 8690, 26983, 2168, 12504, 3162, 18958, 20,..."
...,...,...,...
16850,"[5, 10, 5168, 29429, 10, 182, 181, 6148, 5089,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 10, 5168, 29429, 10, 182, 181, 6148, 5089,..."
6265,"[5, 6936, 10, 74, 10, 551, 5492, 7630, 15207, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 6936, 10, 74, 10, 551, 5492, 7630, 15207, ..."
11284,"[5, 206, 303, 2570, 690, 43, 12639, 25004, 66,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 206, 303, 2570, 690, 43, 12639, 28469, 66,..."
860,"[5, 10167, 1105, 11, 85, 723, 894, 25004, 2500...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 10167, 1105, 11, 85, 723, 894, 881, 265, 5..."


In [12]:
msp_word = pickle.load(open('drive/MyDrive/AIBuilders/tpth/tag_val.pkl', 'rb'))

In [13]:
msp_word

[[('พี่ ๆ', 'พี่ๆ'), ('ดี ๆ', 'ดีดี')],
 [('กลาง ๆ', 'กลางๆ'), ('ท้าย ๆ', 'ท้ายๆ'), ('ต้อง', 'ตัอง')],
 [('นะ', 'น้า')],
 [('ศาสตร์', 'สาด'),
  ('หนัก ๆ', 'หนักๆ'),
  ('โปรโมชัน', 'โปร'),
  ('โปรโมชัน', 'โปร')],
 [('หลาย ๆ', 'หลายๆ'), ('แก้ไข', 'แก้ไร'), ('มาก ๆ', 'มากๆ'), ('อะไร', 'ไร')],
 [('จริง ๆ', 'จริงๆ'), ('อลังการ', 'อลัง')],
 [('น้อง', 'ณ๊อง'),
  ('พิกกี', 'พิคกี้'),
  ('พิกกี', 'พิคคี้'),
  ('ข้าง ๆ', 'ข้างๆ'),
  ('เก๋ ๆ', 'เก๋ๆ'),
  ('จ๊ะ', 'จ้า'),
  ('นะ', 'น๊า')],
 [('ใคร ๆ', 'ใครๆ')],
 [('จริง ๆ', 'จริงๆ'), ('สาว ๆ', 'สาวๆ'), ('นะ', 'น้า')],
 [('แป๊บ', 'แปป'), ('มาก ๆ', 'มากๆ'), ('อย่างไร', 'ยังไง')],
 [('มิ.ย.', 'มิย'),
  ('บวก ๆ', 'บวกๆ'),
  ('เวอร์ชัน', 'เวอร์ชั่น'),
  ('มาก ๆ', 'มากๆ')],
 [('คอนเซ็ปต์', 'คอนเซป'),
  ('มีนาคม', 'มีนา'),
  ('แบตเตอรี', 'แบต'),
  ('เต็ม ๆ', 'เต็มๆ')],
 [('มาก ๆ', 'มากๆ'), ('ปจด.', 'ปจด')],
 [('คอนเสิร์ต', 'คอน'), ('คอนเสิร์ต', 'คอน')],
 [('โปรโมชัน', 'โปรโมชั่น'),
  ('ว่าง ๆ', 'ว่างว่าง'),
  ('อุดหนุน', 'อุหนุน'),
  ('นะคะ', 'นะค่ะ')],
 

In [14]:
def ids_to_tokens(tokenized_text):
  a = tokenizer.convert_ids_to_tokens(tokenized_text)
  # a.remove("<s>")
  # a.remove("</s>")
  # if a[0] == '▁':
  #   a.pop(0)
  return a

In [None]:
for CED in range(1,11):
  print(f"################ CED {CED} ###################")

  # NUM_SAMPLE = 3
  bleu = load_metric("bleu")
  totalacc = 0
  bleu_higher = 0
  bleu_lower = 0
  bleu_equal = 0
  total_f1 = 0
  total_gs_ori = 0
  total_gs_pred = 0
  total_cer_ori = 0
  total_cer_pred = 0

  for sent_id in tqdm(range(1000)):
    chng = []
    text = ds_tag.iloc[sent_id]['text']['input_ids'].squeeze(0).tolist()
    mask = ds_mlm.iloc[sent_id]['attention_mask']
    labels = ds_mlm.iloc[sent_id]['labels']
    text = [k for k in text if k != 1]
    labels = [k for k in labels if k != 1]
    mask = [k for k in mask if k != 0]
    original = [ids_to_tokens(text)]
    references = [[ids_to_tokens(labels)]]
    moi = ''.join(original[0])
    moi = moi.replace("▁", " ")
    # print(f"TEXT: {moi}")
    

    i_f = evaluate_one_text(tagging_model, text, mask, labels)
    predicted_id = text.copy()
    original_id = text.copy()
    i_f_len = len(i_f)
    for j in range(i_f_len):
      if(i_f[j] == 'i'):
        ph = predicted_id[j+1]

        predicted_id[j+1] = 25004
        mlm_input = {'input_ids': torch.Tensor([predicted_id]).type(torch.int64).to(device), 'attention_mask': torch.Tensor([mask]).type(torch.int64).to(device)}
        token_logits = mlm_model(**mlm_input).logits
        mask_token_index = torch.where(mlm_input["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = token_logits[0, mask_token_index, :]
        top_200_tokens = torch.topk(mask_token_logits, 200, dim=1).indices[0].tolist()
        
        CED_flag = 0
        original_token = tokenizer.convert_ids_to_tokens(ph)
        for k in range(200):
          predicted_token = tokenizer.convert_ids_to_tokens(top_200_tokens[k])
          # print(f"Compare: {predicted_token}, {original_token}")
          CED_score = editdistance.eval(predicted_token, original_token)
          if CED_score <= CED:
            chng.append((j, top_200_tokens[k])) 
            CED_flag = 1   
            break
        if CED_flag == 0:
          chng.append((j, ph)) 

        predicted_id[j+1] = ph

    for x,y in chng:
      predicted_id[x+1] = y

    numer = 0
    denom = 0
    TP = 0
    FP = 0
    TN = 0
    FN = 0
    labels_len = len(labels)
    for i in range(labels_len):
      if not predicted_id[i] == original_id[i]: #change
        denom += 1
        if predicted_id[i] == labels[i]:
          numer += 1
          TP += 1
        elif not predicted_id[i] == labels[i]:
          FP += 1
      elif predicted_id[i] == original_id[i]: #no change
        if predicted_id[i] == labels[i]:
          TN += 1
        elif not predicted_id[i] == labels[i]:
          FN += 1
    if denom == 0:
      acc = 0
    else:
      acc = float(numer)/float(denom)
    totalacc += acc

    # print(f"TP:{TP}   TN:{TN}   FP:{FP}   FN:{FN}")
    precision = float(TP) / float(TP+FP) if TP+FP > 0 else 0
    recall = float(TP) / float(TP+FN) if TP+FN > 0 else 0
    f1 = float(2*precision*recall) / float(precision + recall) if precision+recall > 0 else 0
    total_f1 += f1
    

    ans = tokenizer.convert_ids_to_tokens(predicted_id)
    # ans.remove('<s>')
    # ans.remove('</s>')
    if ans[0] == '▁':
      ans.pop(0)
    predictions = [ans]  
    ans = ''.join(ans)
    ans = ans.replace("▁", " ")

    references[0][0] = [k.replace("▁", " ") for k in references[0][0]]
    original[0] = [k.replace("▁", " ") for k in original[0]]
    predictions[0] = [k.replace("▁", " ") for k in predictions[0]]

    bleu_original = bleu.compute(predictions=original, references=references)
    bleu_prediction = bleu.compute(predictions=predictions, references=references)
    
    if bleu_prediction['bleu'] > bleu_original['bleu']:
      bleu_higher += 1
    elif bleu_prediction['bleu'] < bleu_original['bleu']:
      bleu_lower += 1
    elif bleu_prediction['bleu'] == bleu_original['bleu']:
      bleu_equal += 1

    # print(references)
    # print(predictions)
    # print(original)
    gleu_score = sentence_gleu(references[0], predictions[0], min_len=1, max_len=4)
    total_gs_pred += gleu_score
    gleu_score_original = sentence_gleu(references[0], original[0], min_len=1, max_len=4)
    total_gs_ori += gleu_score_original

    cer_text_ori = "".join(original[0])
    cer_text_pred = "".join(predictions[0])
    cer_text_ref = "".join(references[0][0])
    cer_text_ori = cer_text_ori.replace("_", "")
    cer_text_ori = cer_text_ori.replace("▁", "")
    cer_text_pred = cer_text_pred.replace("_", "")
    cer_text_pred = cer_text_pred.replace("▁", "")
    cer_text_ref = cer_text_ref.replace("_", "")
    cer_text_ref = cer_text_ref.replace("▁", "")
    # print(f"ORI: {cer_text_ori}")
    # print(f"PRED: {cer_text_pred}")
    # print(f"REF: {cer_text_ref}")
    cer_original = cer(cer_text_ori, cer_text_ref)
    total_cer_ori += cer_original
    cer_pred = cer(cer_text_pred, cer_text_ref)
    total_cer_pred += cer_pred

    # print(f"GLUE PREDICTED: {gleu_score}")
    # print(f"GLUE ORIGINAL: {gleu_score_original}")
    # print(f"CER PREDICTED: {cer_pred}")
    # print(f"CER ORIGINAL: {cer_original}")

    # print(ans)
    # print(f"ACC: {acc}")
    # print(f"f1: {f1}")
    # print(f"ORIGINAL : {bleu_original}")
    # print(f"PREDICTED: {bleu_prediction}")
    # print("--------------------------------------------")
  print(f"AVG ACC: {float(totalacc/NUM_SAMPLE)}")
  print(f"AVG f1: {float(total_f1/NUM_SAMPLE)}")
  print(f"# HIGHER BLEU PREDICTION: {bleu_higher}")
  print(f"# LOWER BLEU PREDICTION: {bleu_lower}")
  print(f"# EQUAL BLEU PREDICTION: {bleu_equal}")
  print(f"# GLEU PREDICTION: {float(total_gs_pred/NUM_SAMPLE)}")
  print(f"# GLEU ORIGINAL: {float(total_gs_ori/NUM_SAMPLE)}")
  print(f"# CER PREDICTION: {float(total_cer_pred/NUM_SAMPLE)}")
  print(f"# CER ORIGINAL: {float(total_cer_ori/NUM_SAMPLE)}")

################ CED 1 ###################


100%|██████████| 1000/1000 [38:35<00:00,  2.32s/it]


AVG ACC: 0.09588453759079249
AVG f1: 0.06644790852892535
# HIGHER BLEU PREDICTION: 449
# LOWER BLEU PREDICTION: 203
# EQUAL BLEU PREDICTION: 348
# GLEU PREDICTION: 0.23293897774309705
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.0071440804250438115
# CER ORIGINAL: 0.007548613945541157
################ CED 2 ###################


100%|██████████| 1000/1000 [38:38<00:00,  2.32s/it]


AVG ACC: 0.1040929518698186
AVG f1: 0.0863331564964697
# HIGHER BLEU PREDICTION: 486
# LOWER BLEU PREDICTION: 253
# EQUAL BLEU PREDICTION: 261
# GLEU PREDICTION: 0.23417377188984317
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.007147322810350754
# CER ORIGINAL: 0.007548613945541157
################ CED 3 ###################


100%|██████████| 1000/1000 [38:43<00:00,  2.32s/it]


AVG ACC: 0.08229047995136424
AVG f1: 0.07367294274462917
# HIGHER BLEU PREDICTION: 430
# LOWER BLEU PREDICTION: 284
# EQUAL BLEU PREDICTION: 286
# GLEU PREDICTION: 0.23263887307373635
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.008612794658438368
# CER ORIGINAL: 0.007548613945541157
################ CED 4 ###################


100%|██████████| 1000/1000 [38:36<00:00,  2.32s/it]


AVG ACC: 0.07020939618422596
AVG f1: 0.06602067589252178
# HIGHER BLEU PREDICTION: 386
# LOWER BLEU PREDICTION: 329
# EQUAL BLEU PREDICTION: 285
# GLEU PREDICTION: 0.2312079411279323
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.010260014917576754
# CER ORIGINAL: 0.007548613945541157
################ CED 5 ###################


100%|██████████| 1000/1000 [38:41<00:00,  2.32s/it]


AVG ACC: 0.06326691649509039
AVG f1: 0.06006497669422556
# HIGHER BLEU PREDICTION: 354
# LOWER BLEU PREDICTION: 341
# EQUAL BLEU PREDICTION: 305
# GLEU PREDICTION: 0.23047528703159822
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.011190859469563542
# CER ORIGINAL: 0.007548613945541157
################ CED 6 ###################


100%|██████████| 1000/1000 [38:29<00:00,  2.31s/it]


AVG ACC: 0.06082063766322169
AVG f1: 0.05838060982674757
# HIGHER BLEU PREDICTION: 347
# LOWER BLEU PREDICTION: 349
# EQUAL BLEU PREDICTION: 304
# GLEU PREDICTION: 0.2301860394696734
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.01170315355063429
# CER ORIGINAL: 0.007548613945541157
################ CED 7 ###################


100%|██████████| 1000/1000 [38:38<00:00,  2.32s/it]


AVG ACC: 0.05923067386027292
AVG f1: 0.05683291759406664
# HIGHER BLEU PREDICTION: 342
# LOWER BLEU PREDICTION: 355
# EQUAL BLEU PREDICTION: 303
# GLEU PREDICTION: 0.22994937061909926
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.01236770526986985
# CER ORIGINAL: 0.007548613945541157
################ CED 8 ###################


100%|██████████| 1000/1000 [39:00<00:00,  2.34s/it]


AVG ACC: 0.05865884312217141
AVG f1: 0.05645965665073445
# HIGHER BLEU PREDICTION: 341
# LOWER BLEU PREDICTION: 357
# EQUAL BLEU PREDICTION: 302
# GLEU PREDICTION: 0.2299240655681028
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.012477612330357537
# CER ORIGINAL: 0.007548613945541157
################ CED 9 ###################


100%|██████████| 1000/1000 [39:01<00:00,  2.34s/it]


AVG ACC: 0.058586578798125614
AVG f1: 0.05642041703413387
# HIGHER BLEU PREDICTION: 341
# LOWER BLEU PREDICTION: 357
# EQUAL BLEU PREDICTION: 302
# GLEU PREDICTION: 0.22990711944535483
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.012551043912502408
# CER ORIGINAL: 0.007548613945541157
################ CED 10 ###################


100%|██████████| 1000/1000 [38:50<00:00,  2.33s/it]

AVG ACC: 0.058586578798125614
AVG f1: 0.05642041703413387
# HIGHER BLEU PREDICTION: 341
# LOWER BLEU PREDICTION: 357
# EQUAL BLEU PREDICTION: 302
# GLEU PREDICTION: 0.22990711944535483
# GLEU ORIGINAL: 0.22999332709811493
# CER PREDICTION: 0.012558615630287439
# CER ORIGINAL: 0.007548613945541157





In [18]:
msp_word_dict = {}
msp_word_dict_full = {}
msp_word_dict_fill = {}

# NUM_SAMPLE = 5000
NUM_SAMPLE = ds_mlm_test.shape[0]
# for i in range(len_msp_type):
for i in range(NUM_SAMPLE):  
    len_b = len(msp_word[i])
    for j in range(len_b):
      x = msp_word[i][j]
      if not x in msp_word_dict:
        msp_word_dict[x] = 1
        msp_word_dict_full[x] = 0
        msp_word_dict_fill[x] = 0
      else:
        msp_word_dict[x] += 1

msp_word_dict_wrong = msp_word_dict_fill.copy()
msp_word_dict = dict(sorted(msp_word_dict.items(), key=lambda item: item[1], reverse=True))
print(msp_word_dict)
print(msp_word_dict_fill)
print(msp_word_dict_full)

# NUM_SAMPLE = 3
bleu = load_metric("bleu")
totalacc = 0
bleu_higher = 0
bleu_lower = 0
bleu_equal = 0
total_f1 = 0
total_gs_ori = 0
total_gs_pred = 0
total_cer_ori = 0
total_cer_pred = 0

for sent_id in tqdm(range(NUM_SAMPLE)):
  chng = []
  text = ds_tag_test.iloc[sent_id]['text']['input_ids'].squeeze(0).tolist()
  mask = ds_mlm_test.iloc[sent_id]['attention_mask']
  labels = ds_mlm_test.iloc[sent_id]['labels']
  text = [k for k in text if k != 1]
  labels = [k for k in labels if k != 1]
  mask = [k for k in mask if k != 0]
  original = [ids_to_tokens(text)]
  references = [[ids_to_tokens(labels)]]
  moi = ''.join(original[0])
  moi = moi.replace("▁", " ")
  # print(f"TEXT: {moi}")
  

  i_f = evaluate_one_text(tagging_model, text, mask, labels)
  predicted_id = text.copy()
  original_id = text.copy()
  i_f_len = len(i_f)
  for j in range(i_f_len):
    if(i_f[j] == 'i'):
      ph = predicted_id[j+1]

      predicted_id[j+1] = 25004
      mlm_input = {'input_ids': torch.Tensor([predicted_id]).type(torch.int64).to(device), 'attention_mask': torch.Tensor([mask]).type(torch.int64).to(device)}
      token_logits = mlm_model(**mlm_input).logits
      mask_token_index = torch.where(mlm_input["input_ids"] == tokenizer.mask_token_id)[1]
      mask_token_logits = token_logits[0, mask_token_index, :]
      top_200_tokens = torch.topk(mask_token_logits, 200, dim=1).indices[0].tolist()
        
      CED_flag = 0
      original_token = tokenizer.convert_ids_to_tokens(ph)
      for k in range(200):
        predicted_token = tokenizer.convert_ids_to_tokens(top_200_tokens[k])
        # print(f"Compare: {predicted_token}, {original_token}")
        CED_score = editdistance.eval(predicted_token, original_token)
        if CED_score <= 2:
          chng.append((j, top_200_tokens[k])) 
          CED_flag = 1   
          break
      if CED_flag == 0:
        chng.append((j, ph)) 

      
      # print(f"{tokenizer.convert_ids_to_tokens(ph)} => {tokenizer.convert_ids_to_tokens(top_5_tokens[0])}")
      # chng.append((j, top_5_tokens[0]))

      predicted_id[j+1] = ph

  for x,y in chng:
    predicted_id[x+1] = y

  numer = 0
  denom = 0
  TP = 0
  FP = 0
  TN = 0
  FN = 0
  labels_len = len(labels)
  for i in range(labels_len):
    if not predicted_id[i] == original_id[i]: #change
      denom += 1
      if predicted_id[i] == labels[i]:
        numer += 1
        TP += 1
      elif not predicted_id[i] == labels[i]:
        FP += 1
    elif predicted_id[i] == original_id[i]: #no change
      if predicted_id[i] == labels[i]:
        TN += 1
      elif not predicted_id[i] == labels[i]:
        FN += 1
  if denom == 0:
    acc = 0
  else:
    acc = float(numer)/float(denom)
  totalacc += acc

  # print(f"TP:{TP}   TN:{TN}   FP:{FP}   FN:{FN}")
  precision = float(TP) / float(TP+FP) if TP+FP > 0 else 0
  recall = float(TP) / float(TP+FN) if TP+FN > 0 else 0
  f1 = float(2*precision*recall) / float(precision + recall) if precision+recall > 0 else 0
  total_f1 += f1
  

  ans = tokenizer.convert_ids_to_tokens(predicted_id)
  # ans.remove('<s>')
  # ans.remove('</s>')
  if ans[0] == '▁':
    ans.pop(0)
  predictions = [ans]  
  ans = ''.join(ans)
  ans = ans.replace("▁", " ")

  references[0][0] = [k.replace("▁", " ") for k in references[0][0]]
  original[0] = [k.replace("▁", " ") for k in original[0]]
  predictions[0] = [k.replace("▁", " ") for k in predictions[0]]

  bleu_original = bleu.compute(predictions=original, references=references)
  bleu_prediction = bleu.compute(predictions=predictions, references=references)
  
  if bleu_prediction['bleu'] > bleu_original['bleu']:
    bleu_higher += 1
  elif bleu_prediction['bleu'] < bleu_original['bleu']:
    bleu_lower += 1
  elif bleu_prediction['bleu'] == bleu_original['bleu']:
    bleu_equal += 1

  # print(references)
  # print(predictions)
  # print(original)
  gleu_score = sentence_gleu(references[0], predictions[0], min_len=1, max_len=4)
  total_gs_pred += gleu_score
  gleu_score_original = sentence_gleu(references[0], original[0], min_len=1, max_len=4)
  total_gs_ori += gleu_score_original

  cer_text_ori = "".join(original[0])
  cer_text_pred = "".join(predictions[0])
  cer_text_ref = "".join(references[0][0])
  cer_text_ori = cer_text_ori.replace("_", "")
  cer_text_ori = cer_text_ori.replace("▁", "")
  cer_text_pred = cer_text_pred.replace("_", "")
  cer_text_pred = cer_text_pred.replace("▁", "")
  cer_text_ref = cer_text_ref.replace("_", "")
  cer_text_ref = cer_text_ref.replace("▁", "")
  # print(f"ORI: {cer_text_ori}")
  # print(f"PRED: {cer_text_pred}")
  # print(f"REF: {cer_text_ref}")
  cer_original = cer(cer_text_ori, cer_text_ref)
  total_cer_ori += cer_original
  cer_pred = cer(cer_text_pred, cer_text_ref)
  total_cer_pred += cer_pred

  for i in range(labels_len):
    if labels[i] != original_id[i]: #change
      word_original = tokenizer.convert_ids_to_tokens(original_id[i])
      word_original = word_original.replace("▁", "")
      word_predicted = tokenizer.convert_ids_to_tokens(predicted_id[i])
      word_predicted = word_predicted.replace("▁", "")
      word_labels = tokenizer.convert_ids_to_tokens(labels[i])
      word_labels = word_labels.replace("▁", "")
      # if (word_original, word_predicted) in msp_word_dict_fill: #have key
      if (word_original, word_labels) in msp_word_dict_fill: #have key
        # msp_word_dict_full[(word_original, word_predicted)] += 1
        msp_word_dict_full[(word_original, word_labels)] += 1
        if labels[i] == predicted_id[i]: #correct
          # msp_word_dict_fill[(word_original, word_predicted)] += 1
          msp_word_dict_fill[(word_original, word_labels)] += 1
        elif labels[i] != predicted_id[i]: #incorrect
          # msp_word_dict_wrong[(word_original, word_predicted)] += 1
          msp_word_dict_wrong[(word_original, word_labels)] += 1

  # print(f"GLUE PREDICTED: {gleu_score}")
  # print(f"GLUE ORIGINAL: {gleu_score_original}")
  # print(f"CER PREDICTED: {cer_pred}")
  # print(f"CER ORIGINAL: {cer_original}")

  # print(ans)
  # print(f"ACC: {acc}")
  # print(f"f1: {f1}")
  # print(f"ORIGINAL : {bleu_original}")
  # print(f"PREDICTED: {bleu_prediction}")
  # print("--------------------------------------------")
print(f"AVG ACC: {float(totalacc/NUM_SAMPLE)}")
print(f"AVG f1: {float(total_f1/NUM_SAMPLE)}")
print(f"# HIGHER BLEU PREDICTION: {bleu_higher}")
print(f"# LOWER BLEU PREDICTION: {bleu_lower}")
print(f"# EQUAL BLEU PREDICTION: {bleu_equal}")
print(f"# GLEU PREDICTION: {float(total_gs_pred/NUM_SAMPLE)}")
print(f"# GLEU ORIGINAL: {float(total_gs_ori/NUM_SAMPLE)}")
print(f"# CER PREDICTION: {float(total_cer_pred/NUM_SAMPLE)}")
print(f"# CER ORIGINAL: {float(total_cer_ori/NUM_SAMPLE)}")
msp_word_dict_full = dict(sorted(msp_word_dict_full.items(), key=lambda item: item[1], reverse=True))
msp_word_dict_fill = dict(sorted(msp_word_dict_fill.items(), key=lambda item: item[1], reverse=True))
msp_word_dict_wrong = dict(sorted(msp_word_dict_wrong.items(), key=lambda item: item[1], reverse=True))
print(msp_word_dict_full)
print(msp_word_dict_fill)
print(msp_word_dict_wrong)

{('จริง ๆ', 'จริงๆ'): 852, ('เขา', 'เค้า'): 767, ('มาก ๆ', 'มากๆ'): 566, ('หนึ่ง', 'นึง'): 437, ('อะ', 'อ่ะ'): 373, ('ฉัน', 'ชั้น'): 235, ('ไหม', 'มั้ย'): 202, ('แล้ว', 'ละ'): 198, ('ดี ๆ', 'ดีๆ'): 188, ('กู', 'กุ'): 175, ('อะไร', 'ไร'): 169, ('นะ', 'อะ'): 154, ('ก็', 'ก้'): 150, ('น้อง ๆ', 'น้องๆ'): 149, ('หลาย ๆ', 'หลายๆ'): 139, ('อย่างไร', 'ยังไง'): 137, ('ต่าง ๆ', 'ต่างๆ'): 134, ('เหรอ', 'หรอ'): 131, ('เรื่อย ๆ', 'เรื่อยๆ'): 122, ('อี', 'อิ'): 121, ('คอนเสิร์ต', 'คอน'): 115, ('สุด ๆ', 'สุดๆ'): 111, ('จ้ะ', 'จ้า'): 102, ('นะ', 'อ่ะ'): 92, ('แน่ ๆ', 'แน่ๆ'): 91, ('กับ', 'กะ'): 87, ('มาก', 'มากกก'): 87, ('เฉย ๆ', 'เฉยๆ'): 85, ('เด็ก ๆ', 'เด็กๆ'): 82, ('มาก', 'มากก'): 81, ('อื่น ๆ', 'อื่นๆ'): 73, ('เยอะ ๆ', 'เยอะๆ'): 73, ('นะ', 'น้า'): 72, ('ทุก ๆ', 'ทุกๆ'): 70, ('ข้าง ๆ', 'ข้างๆ'): 66, ('เล็ก ๆ', 'เล็กๆ'): 63, ('เป็น', 'เปน'): 63, ('แรก ๆ', 'แรกๆ'): 60, ('อย่างไร', 'ไง'): 58, ('ทั้ง ๆ', 'ทั้งๆ'): 58, ('ค่ะ', 'ค่า'): 57, ('หรือ', 'รึ'): 55, ('เธอ', 'เทอ'): 55, ('อย่างนี้', 'งี้'): 54, 

100%|██████████| 5000/5000 [1:53:41<00:00,  1.36s/it]

AVG ACC: 0.33455844662318823
AVG f1: 0.27775351210990973
# HIGHER BLEU PREDICTION: 2257
# LOWER BLEU PREDICTION: 1457
# EQUAL BLEU PREDICTION: 1286
# GLEU PREDICTION: 0.8731252321814156
# GLEU ORIGINAL: 0.8660844114008937
# CER PREDICTION: 0.03555648121930075
# CER ORIGINAL: 0.03541724276517062
{('ทวิต', 'ทวีต'): 65, ('นี้', 'นี่'): 49, ('เดม', 'เด็ม'): 30, ('ว่ะ', 'วะ'): 27, ('โปรโมท', 'โปรโมต'): 27, ('ละ', 'ล่ะ'): 26, ('เอ้ย', 'เอ๊ย'): 26, ('คะ', 'ค่ะ'): 25, ('ค่ะ', 'คะ'): 24, ('เซต', 'เซ็ต'): 24, ('นั้น', 'นั่น'): 18, ('เว่อร์', 'เวอร์'): 18, ('ที', 'ที่'): 17, ('นั่น', 'นั้น'): 13, ('ฟิล', 'ฟีล'): 9, ('อ่อ', 'อ๋อ'): 9, ('จ้ะ', 'จ๊ะ'): 9, ('ที่', 'ที'): 8, ('หละ', 'ล่ะ'): 8, ('จ๊ะ', 'จ้ะ'): 8, ('ม๊า', 'ม้า'): 8, ('ชุ่มชื้น', 'ชุ่มชื่น'): 8, ('สมมุติ', 'สมมติ'): 7, ('วะ', 'ว่ะ'): 7, ('ดีดี', 'ดีๆ'): 6, ('นี่', 'นี้'): 6, ('ชอบ', 'ชอบ'): 5, ('เข้า', 'เขา'): 4, ('บาง', 'บ้าง'): 4, ('ค่อย', 'คอย'): 4, ('มา', 'มาก'): 3, ('ล่ะ', 'ละ'): 3, ('ใช้', 'ใช่'): 3, ('หน้า', 'น่า'): 3, ('ใช่', 'ใช




In [None]:
# NUM_SAMPLE = 3
bleu = load_metric("bleu")
totalacc = 0
bleu_higher = 0
bleu_lower = 0
bleu_equal = 0
total_f1 = 0
total_gs_ori = 0
total_gs_pred = 0
total_cer_ori = 0
total_cer_pred = 0

for sent_id in tqdm(range(NUM_SAMPLE)):
  chng = []
  text = ds_tag.iloc[sent_id]['text']['input_ids'].squeeze(0).tolist()
  mask = ds_mlm.iloc[sent_id]['attention_mask']
  labels = ds_mlm.iloc[sent_id]['labels']
  text = [k for k in text if k != 1]
  labels = [k for k in labels if k != 1]
  mask = [k for k in mask if k != 0]
  original = [ids_to_tokens(text)]
  references = [[ids_to_tokens(labels)]]
  moi = ''.join(original[0])
  moi = moi.replace("▁", " ")
  # print(f"TEXT: {moi}")
  

  i_f = evaluate_one_text(tagging_model, text, mask, labels)
  predicted_id = text.copy()
  original_id = text.copy()
  i_f_len = len(i_f)
  for j in range(i_f_len):
    if(i_f[j] == 'i'):
      ph = predicted_id[j+1]

      predicted_id[j+1] = 25004
      mlm_input = {'input_ids': torch.Tensor([predicted_id]).type(torch.int64).to(device), 'attention_mask': torch.Tensor([mask]).type(torch.int64).to(device)}
      token_logits = mlm_model(**mlm_input).logits
      mask_token_index = torch.where(mlm_input["input_ids"] == tokenizer.mask_token_id)[1]
      mask_token_logits = token_logits[0, mask_token_index, :]
      top_5_tokens = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()
      # print(f"{tokenizer.convert_ids_to_tokens(ph)} => {tokenizer.convert_ids_to_tokens(top_5_tokens[0])}")
      chng.append((j, top_5_tokens[0]))

      predicted_id[j+1] = ph

  for x,y in chng:
    predicted_id[x+1] = y

  numer = 0
  denom = 0
  TP = 0
  FP = 0
  TN = 0
  FN = 0
  labels_len = len(labels)
  for i in range(labels_len):
    if not predicted_id[i] == original_id[i]: #change
      denom += 1
      if predicted_id[i] == labels[i]:
        numer += 1
        TP += 1
      elif not predicted_id[i] == labels[i]:
        FP += 1
    elif predicted_id[i] == original_id[i]: #no change
      if predicted_id[i] == labels[i]:
        TN += 1
      elif not predicted_id[i] == labels[i]:
        FN += 1
  if denom == 0:
    acc = 0
  else:
    acc = float(numer)/float(denom)
  totalacc += acc

  # print(f"TP:{TP}   TN:{TN}   FP:{FP}   FN:{FN}")
  precision = float(TP) / float(TP+FP) if TP+FP > 0 else 0
  recall = float(TP) / float(TP+FN) if TP+FN > 0 else 0
  f1 = float(2*precision*recall) / float(precision + recall) if precision+recall > 0 else 0
  total_f1 += f1
  

  ans = tokenizer.convert_ids_to_tokens(predicted_id)
  # ans.remove('<s>')
  # ans.remove('</s>')
  if ans[0] == '▁':
    ans.pop(0)
  predictions = [ans]  
  ans = ''.join(ans)
  ans = ans.replace("▁", " ")

  references[0][0] = [k.replace("▁", " ") for k in references[0][0]]
  original[0] = [k.replace("▁", " ") for k in original[0]]
  predictions[0] = [k.replace("▁", " ") for k in predictions[0]]

  bleu_original = bleu.compute(predictions=original, references=references)
  bleu_prediction = bleu.compute(predictions=predictions, references=references)
  
  if bleu_prediction['bleu'] > bleu_original['bleu']:
    bleu_higher += 1
  elif bleu_prediction['bleu'] < bleu_original['bleu']:
    bleu_lower += 1
  elif bleu_prediction['bleu'] == bleu_original['bleu']:
    bleu_equal += 1

  # print(references)
  # print(predictions)
  # print(original)
  gleu_score = sentence_gleu(references[0], predictions[0], min_len=1, max_len=4)
  total_gs_pred += gleu_score
  gleu_score_original = sentence_gleu(references[0], original[0], min_len=1, max_len=4)
  total_gs_ori += gleu_score_original

  cer_text_ori = "".join(original[0])
  cer_text_pred = "".join(predictions[0])
  cer_text_ref = "".join(references[0][0])
  cer_text_ori = cer_text_ori.replace("_", "")
  cer_text_ori = cer_text_ori.replace("▁", "")
  cer_text_pred = cer_text_pred.replace("_", "")
  cer_text_pred = cer_text_pred.replace("▁", "")
  cer_text_ref = cer_text_ref.replace("_", "")
  cer_text_ref = cer_text_ref.replace("▁", "")
  # print(f"ORI: {cer_text_ori}")
  # print(f"PRED: {cer_text_pred}")
  # print(f"REF: {cer_text_ref}")
  cer_original = cer(cer_text_ori, cer_text_ref)
  total_cer_ori += cer_original
  cer_pred = cer(cer_text_pred, cer_text_ref)
  total_cer_pred += cer_pred

  # print(f"GLUE PREDICTED: {gleu_score}")
  # print(f"GLUE ORIGINAL: {gleu_score_original}")
  # print(f"CER PREDICTED: {cer_pred}")
  # print(f"CER ORIGINAL: {cer_original}")

  # print(ans)
  # print(f"ACC: {acc}")
  # print(f"f1: {f1}")
  # print(f"ORIGINAL : {bleu_original}")
  # print(f"PREDICTED: {bleu_prediction}")
  # print("--------------------------------------------")
print(f"AVG ACC: {float(totalacc/NUM_SAMPLE)}")
print(f"AVG f1: {float(total_f1/NUM_SAMPLE)}")
print(f"# HIGHER BLEU PREDICTION: {bleu_higher}")
print(f"# LOWER BLEU PREDICTION: {bleu_lower}")
print(f"# EQUAL BLEU PREDICTION: {bleu_equal}")
print(f"# GLEU PREDICTION: {float(total_gs_pred/NUM_SAMPLE)}")
print(f"# GLEU ORIGINAL: {float(total_gs_ori/NUM_SAMPLE)}")
print(f"# CER PREDICTION: {float(total_cer_pred/NUM_SAMPLE)}")
print(f"# CER ORIGINAL: {float(total_cer_ori/NUM_SAMPLE)}")

100%|██████████| 5000/5000 [3:46:40<00:00,  2.72s/it]

AVG ACC: 0.2575590514065182
AVG f1: 0.24176052649338325
# HIGHER BLEU PREDICTION: 2020
# LOWER BLEU PREDICTION: 1722
# EQUAL BLEU PREDICTION: 1258
# GLEU PREDICTION: 0.8685422572985301
# GLEU ORIGINAL: 0.8660844114008937
# CER PREDICTION: 0.056664659746703316
# CER ORIGINAL: 0.03541724276517062



