# Create data for Vietnamese text correction task.
This notebook create data for training the Hard-masked XLM-R model for Vietnamese text correction task. The pipeline include the following tasks:
### 1. Build the sentence piece tokenizer 

- The sentencepiece tokenizer is choosed to tokenize text data because it can handle unknown token and has configurable vocabulary size. The tokenizer will be trained on all the available text data. More information about the sentencepiece tokenizer please refer to its [homepage](https://github.com/google/sentencepiece) 

### 2. Create miss-spell word from the original data

- The miss-spell words will be created by a synthesize function, which will base on the commonly mistake in Vietnamese language as homophones words, no diacritics, typos, ... Especially, this function will also cover the not-appropriate word position errors case, so that the Detector network can recognize these mistake. 
- All the data will go through the synthesize function several times, each time with a different objective as followed:
    - 1. Wrong words created by homophones, in attemp to artificially create homophones mistakes.
    - 2. Wrong words created by typos, which will mostly create non-words errors. This type of error will be created by randomly replace character to a randomly chosen word. Typos like random capital letters, ..
    - 3. To recognize the non-appropriate position of word, a new objective is to replace a word with a random word. 

### 3. Transform text data into pytorch tensor 

- Each sample data created by the synthesize function is a tuple (raw_sentence, onehot_label, true_sentence), where raw_sentence is the sentence contain mistake words, onehot_label is the onehot array indicate the position of the wrong words, true_sentence is the original sentence which was used to synthesize data.
- All of these samples will be converted into tensor and cached to disk for model training later.

In [1]:
!pip install sentencepiece
!pip install unidecode
import sentencepiece as spm
import os
import numpy as np
import re
import time
from tqdm.notebook import tqdm
import pandas as pd
import nltk
from nltk.tokenize import word_tokenize
import unidecode
import string
from tqdm.notebook import tqdm
from nltk.tokenize.treebank import TreebankWordDetokenizer

nltk.download('punkt')
sentence_tokenizer  =  nltk.data.load('tokenizers/punkt/english.pickle')




Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/14/67/e42bd1181472c95c8cda79305df848264f2a7f62740995a46945d9797b67/sentencepiece-0.1.95-cp36-cp36m-manylinux2014_x86_64.whl (1.2MB)
[K     |████████████████████████████████| 1.2MB 5.8MB/s 
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.95
Collecting unidecode
[?25l  Downloading https://files.pythonhosted.org/packages/74/65/91eab655041e9e92f948cb7302e54962035762ce7b518272ed9d6b269e93/Unidecode-1.1.2-py2.py3-none-any.whl (239kB)
[K     |████████████████████████████████| 245kB 5.7MB/s 
[?25hInstalling collected packages: unidecode
Successfully installed unidecode-1.1.2
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
import logging
def init_logger():
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
init_logger()
logger = logging.getLogger(__name__)

## Build the sentencepiece tokenizer

In [None]:
spm.SentencePieceTrainer.train('--input="/content/drive/MyDrive/nlp_projects/Text_correction/all_sentences.txt" --model_prefix=spm_tokenizer --vocab_size=10000')

- remove diacritics
- replace with random letters
- replace with homophones words
- replace with homophones letters
- replace with teencode
- replace with random words



## Create wrong word at word level

In [None]:
import numpy as np
import re
import unidecode
class SynthesizeData(object):
    """
    Uitils class to create artificial miss-spelled words
    Args: 
        vocab_path: path to vocab file. Vocab file is expected to be a set of words, separate by ' ', no newline charactor.
    """
    def __init__(self, vocab_path,  ):

        self.vocab = open(vocab_path, 'r', encoding = 'utf-8').read().split()
        self.tokenizer = word_tokenize
        self.word_couples = [ ['sương', 'xương'], ['sĩ', 'sỹ'], ['sẽ', 'sẻ'], ['sã', 'sả'], ['sả', 'xả'], ['sẽ', 'sẻ'], ['mùi', 'muồi'], 
                        ['chỉnh', 'chỉn'], ['sữa', 'sửa'], ['chuẩn', 'chẩn'], ['lẻ', 'lẽ'], ['chẳng', 'chẵng'], ['cổ', 'cỗ'], 
                        ['sát', 'xát'], ['cập', 'cặp'], ['truyện', 'chuyện'], ['xá', 'sá'], ['giả', 'dả'], ['đỡ', 'đở'], 
                        ['giữ', 'dữ'], ['giã', 'dã'], ['xảo', 'sảo'], ['kiểm', 'kiễm'], ['cuộc', 'cục'], ['dạng', 'dạn'], 
                        ['tản', 'tảng'], ['ngành', 'nghành'], ['nghề', 'ngề'], ['nổ', 'nỗ'], ['rảnh', 'rãnh'], ['sẵn', 'sẳn'], 
                        ['sáng', 'xán'], ['xuất', 'suất'], ['suôn', 'suông'], ['sử', 'xử'], ['sắc', 'xắc'], ['chữa', 'chửa'], 
                        ['thắn', 'thắng'], ['dỡ', 'dở'], ['trải', 'trãi'], ['trao', 'trau'], ['trung', 'chung'], ['thăm', 'tham'], 
                        ['sét', 'xét'], ['dục', 'giục'], ['tả', 'tã'],['sông', 'xông'], ['sáo', 'xáo'], ['sang', 'xang'], 
                        ['ngã', 'ngả'], ['xuống', 'suống'], ['xuồng', 'suồng'] ]


        self.vn_alphabet = ['a',  'ă', 'â', 'b' ,'c', 'd', 'đ', 'e','ê','g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'ô', 'ơ', 'p', 'q', 'r', 's', 't', 'u', 'ư', 'v', 'x', 'y']
        self.alphabet_len = len(self.vn_alphabet)
        self.char_couples = [['i', 'y'], ['s', 'x'], ['gi', 'd'],
                ['ă', 'â'], ['ch', 'tr'], ['ng', 'n'], 
                ['nh', 'n'], ['ngh', 'ng'], ['ục', 'uộc'], ['o', 'u'], 
                ['ă', 'a'], ['o', 'ô'], ['ả', 'ã'], ['ổ', 'ỗ'], ['ủ', 'ũ'], ['ễ', 'ể'], 
                ['e', 'ê'], ['à', 'ờ'], ['ằ', 'à'], ['ẩn', 'uẩn'],  ['ẽ', 'ẻ'], ['ùi', 'uồi'], ['ă', 'â'], ['ở', 'ỡ'], ['ỹ', 'ỷ'], ['ỉ', 'ĩ'], ['ị', 'ỵ'],
                ['ấ', 'á'],['n', 'l'], ['qu', 'w'], ['ph', 'f'], ['d', 'z'], ['c', 'k'], ['qu', 'q'], ['i','j'], ['gi', 'j'], 
                ]

        self.teencode_dict = {'mình': ['mk', 'mik', 'mjk'], 'vô': ['zô', 'zo', 'vo'], 'vậy':['zậy', 'z', 'zay', 'za'] , 'phải': ['fải', 'fai', ], 'biết': ['bit', 'biet'], 
                              'rồi':['rùi', 'ròi', 'r'], 'bây': ['bi', 'bay'], 'giờ': ['h', ], 'không': ['k', 'ko', 'khong', 'hk', 'hong', 'hông', '0', 'kg', 'kh', ], 
                              'đi': ['di', 'dj', ], 'gì': ['j', ], 'em': ['e', ], 'được': ['dc', 'đc', ], 'tao': ['t'], 'tôi': ['t'], 'chồng': ['ck'], 'vợ':['vk']

        }

        self.all_word_candidates = self.get_all_word_candidates(self.word_couples)
        self.string_all_word_candidates = ' '.join(self.all_word_candidates)
        self.all_char_candidates = self.get_all_char_candidates( )

    def replace_teencode(self, word):
        candidates = self.teencode_dict.get(word, None)
        if candidates is not None:
            chosen_one = 0
            if len(candidates) > 1:
                chosen_one = np.random.randint(0, len(candidates))
            return candidates[chosen_one]
         
    def replace_word_candidate(self, word):
        """
        Return a homophone word of the input word.
        """
        capital_flag = word[0].isupper()
        word = word.lower()
        if capital_flag and word in self.teencode_dict:
            return self.replace_teencode(word).capitalize()
        elif word in self.teencode_dict:
            return self.replace_teencode(word)

        for couple in self.word_couples:
            for i in range(2):
                if couple[i] == word:
                    if i == 0:
                        if capital_flag:
                            return couple[1].capitalize()
                        else:
                            return couple[1]
                    else:
                        if capital_flag:
                            return couple[0].capitalize()
                        else:
                            return couple[0]

    def replace_char_candidate(self,char):
        """
        return a homophone char/subword of the input char.
        """
        for couple in self.char_couples:
            for i in range(2):
                if couple[i] == char:
                    if i == 0:
                        return couple[1]
                    else:
                        return couple[0]

    def get_all_char_candidates(self, ):
        
        all_char_candidates = []
        for couple in self.char_couples:
            all_char_candidates.extend(couple)
        return all_char_candidates


    def get_all_word_candidates(self, word_couples):

        all_word_candidates = []
        for couple in self.word_couples:
            all_word_candidates.extend(couple)
        return all_word_candidates

    def remove_diacritics(self, text, onehot_label):
        """
        Replace word which has diacritics with the same word without diacritics
        Args: 
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this 
            function only choose the word that do not has onehot label == 1.
        return: a list of word tokens has one word that its diacritics was removed, 
                a list of onehot label indicate the position of words that has been modified.
        """
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx] == unidecode.unidecode(text[idx]) or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label


        onehot_label[idx] = 1
        text[idx] = unidecode.unidecode(text[idx])
        return True, text, onehot_label

    def replace_with_random_letter(self, text, onehot_label):
        """
        Replace, add (or remove) a random letter in a random chosen word with a random letter
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this 
            function only choose the word that do not has onehot label == 1. 
        return: a list of word tokens has one word that has been modified, 
                a list of onehot label indicate the position of words that has been modified.
        """
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx].isnumeric() or text[idx] in string.punctuation :
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        # replace, add or remove? 0 is replace, 1 is add, 2 is remove
        coin = np.random.choice([0, 1,2])
        if coin == 0:
            chosen_letter = text[idx][np.random.randint(0, len(text[idx]))]
            replaced = self.vn_alphabet[np.random.randint(0, self.alphabet_len)]
            try:
                text[idx] = re.sub(chosen_letter,replaced , text[idx])
            except:
                return False, text, onehot_label
        elif coin == 1:
            chosen_letter = text[idx][np.random.randint(0, len(text[idx]))]
            replaced = chosen_letter + self.vn_alphabet[np.random.randint(0, self.alphabet_len)]
            try:
                text[idx] = re.sub(chosen_letter,replaced , text[idx])
            except:
                return False, text, onehot_label
        else:
            chosen_letter = text[idx][np.random.randint(0, len(text[idx]))]
            try:
                text[idx] = re.sub(chosen_letter, '', text[idx])  
            except:
                return False, text, onehot_label   

        onehot_label[idx] = 1
        return True, text, onehot_label

    def replace_with_homophone_word(self, text, onehot_label):
        """
        Replace a candidate word (if exist in the word_couple) with its homophone. if successful, return True, else False
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this 
            function only choose the word that do not has onehot label == 1. 
        return: True, text, onehot_label if successful replace, else False, text, onehot_label
        """
        # account for the case that the word in the text is upper case but its lowercase match the candidates list
        candidates = []
        for i in range(len(text)):
            if text[i].lower() in self.all_word_candidates or text[i].lower() in self.teencode_dict.keys():
                candidates.append((i, text[i]))
        
        if len(candidates) == 0:
            return False, text, onehot_label

        idx = np.random.randint(0, len(candidates))
        prevent_loop = 0
        while onehot_label[candidates[idx][0]] == 1:
            idx = np.random.choice(np.arange(0, len(candidates)))
            prevent_loop += 1
            if prevent_loop > 5:
                return False, text, onehot_label

        text[candidates[idx][0]] = self.replace_word_candidate(candidates[idx][1])
        onehot_label[candidates[idx][0]] = 1
        return True, text, onehot_label

    def replace_with_homophone_letter(self, text, onehot_label):
        """
        Replace a subword/letter with its homophones
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this 
            function only choose the word that do not has onehot label == 1. 
        return: True, text, onehot_label if successful replace, else False, None, None
        """
        candidates = []
        for i in range(len(text)):
            for char in self.all_char_candidates:
                if re.search(char, text[i]) is not None:
                    candidates.append((i, char))
                    break

        if len(candidates) == 0:

           return False, text, onehot_label
        else:
            idx = np.random.randint(0, len(candidates))
            prevent_loop = 0
            while onehot_label[candidates[idx][0]] == 1:
                idx = np.random.randint(0, len(candidates))
                prevent_loop += 1
                if prevent_loop  > 5:
                    return False, text, onehot_label

            replaced = self.replace_char_candidate(candidates[idx][1])
            text[candidates[idx][0]] = re.sub(candidates[idx][1], replaced, text[candidates[idx][0]] )

            onehot_label[candidates[idx][0]] = 1
            return True, text, onehot_label


    def replace_with_random_word(self,text, onehot_label):
        """
        Replace a random word in text with a random word in vocab.
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this 
            function only choose the word that do not has onehot label == 1. 
        return: True, text, onehot_label if successful replace, else False, text, onehot_label

        """
        idx = np.random.randint(0, len(text))
        prevent_loop  = 0
        # the idx must not be an already modify token, punctuation or number.
        while onehot_label[idx] == 1 or text[idx].isnumeric() or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(text))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        chosen_idx = np.random.randint(0, len(self.vocab))
        text[idx] = self.vocab[chosen_idx]

        onehot_label[idx] = 1
        return True, text, onehot_label


    def create_wrong_word(self, text, mode):
        """
        Function to create miss-spelled words and its label.
        Args: 
            text: One sentence of text.
            mode: which type of error to create, can be: random_word, homophone_char,
            random_letter, remove_diacritics, 
                                                        
        return: raw_text, onehot_label, text_label
        """
        text = self.tokenizer(text)
        text_label = text.copy()
        onehot_label = [0]*len(text)
        num_wrong = int(np.round(0.15*len(text)))
        
        if mode == 'random_word':
            for i in range(0, num_wrong):
                _, text, onehot_label = self.replace_with_random_word(text, onehot_label) 
                if not _:
                    #logger.info('False to create wrong word with random word!')
                    return False, (text, onehot_label, text_label)

                    
            return True, (text, onehot_label, text_label)

        elif mode == 'homophone_char':
            for i in range(0, num_wrong):
                _, text, onehot_label = self.replace_with_homophone_letter(text, onehot_label)
                if not _:
                    #logger.info('False to create wrong word with homophone_char')  
                    return False, (onehot_label, text_label)   
            return True, (text, onehot_label, text_label)    

        elif mode == 'homophone_word':
            for i in range(0, num_wrong):
                _, text, onehot_label = self.replace_with_homophone_word(text, onehot_label)
                if not _:
                    #logger.info('False to create wrong word with homophone_word, remove with random_word')
                    #_, text, onehot_label = self.replace_with_random_word(text, onehot_label) 
                    return False, (text, onehot_label, text_label)
            return True, (text, onehot_label, text_label)

        elif mode == 'random_letter':
            for i in range(0, num_wrong):
                _, text, onehot_label = self.replace_with_random_letter(text, onehot_label)
                if not _:
                    #logger.info('False to create wrong word with random_letter, remove with random_word') 
                    return False, (text, onehot_label, text_label) 
            return True, (text, onehot_label, text_label)         

        elif mode == 'remove_diacritics':
            for i in range(0 ,num_wrong):
                _, text, onehot_label = self.remove_diacritics(text, onehot_label)
                if not _:
                    #logger.info('False to create wrong word with remove_diacritics, remove with random_word')
                    return False, (text, onehot_label, text_label)
            return True, (text, onehot_label, text_label)




In [None]:
s = ['con sáo sang sông mắc phải cành cây ngã xuống sông',
     'Sơn La lần đầu dán tem đào vườn dân trồng, phân biệt với đào rừng', 
     'Văn bản số 3864 do Phó Chủ tịch UBND huyện Vân Hồ Vũ Thanh Hải ký, trình UBND tỉnh,  cho biết, huyện có 500ha trồng cây đào bán dịp Tết.',
     'Tại xã Lóng Luông có 300ha; xã Vân Hồ trồng 200ha, tất cả đều trồng tập trung trên nương, đồi của người dân sở tại.'
    ]

data = []
for line in s:
    result, tup = synthesizer.create_wrong_word(line, mode = 'remove_diacritics')
    if result:
        data.append(tup)

In [None]:
all_sentences = open('/content/drive/MyDrive/nlp_projects/Text_correction/all_sentences.txt', 'r', encoding = 'utf-8').readlines()


In [None]:
len(all_sentences)

3033438

In [None]:

# vocab_path = '/content/drive/MyDrive/nlp_projects/Text_correction/random_vocab.txt'
synthesizer = SynthesizeData(vocab_path = vocab_path)

random_word_data = []
for line in tqdm(all_sentences):
    _, r = synthesizer.create_wrong_word(line, mode = 'random_word')
    if _:
        random_word_data.append(r)


HBox(children=(FloatProgress(value=0.0, max=3033438.0), HTML(value='')))




In [19]:
def write_data(data_path, data):
    tree = TreebankWordDetokenizer()
    with open(os.path.join(data_path, 'rawtext.txt'), 'w', encoding = 'utf-8') as f:
        for line in tqdm(data, desc = 'Writing rawtext file...'):
            f.write(tree.detokenize(line[0]) + '\n')
    with open(os.path.join(data_path, 'onehot_label.txt'), 'w', encoding = 'utf-8') as f:
        for line in tqdm(data, desc = 'Writing onehot_label file...'):
            f.write(' '.join([str(x) for x in line[1]]) + '\n')
    with open(os.path.join(data_path, 'text_label.txt'), 'w', encoding = 'utf-8') as f:
        for line in tqdm(data, desc = 'Writing text_label file...'):
            f.write(tree.detokenize(line[2]) + '\n')


In [None]:
data_path = '/content/drive/MyDrive/nlp_projects/Text_correction/all_data/random_word_data'
write_data(data_path, random_word_data)

HBox(children=(FloatProgress(value=0.0, description='Writing rawtext file...', max=3031667.0, style=ProgressSt…




HBox(children=(FloatProgress(value=0.0, description='Writing onehot_label file...', max=3031667.0, style=Progr…




HBox(children=(FloatProgress(value=0.0, description='Writing text_label file...', max=3031667.0, style=Progres…




## Create Train test data

In [2]:
def get_data(data_path, num_train, num_dev):
    raw_text = []
    onehot_label = []
    text_label = []
    with open(os.path.join(data_path, 'rawtext.txt'), 'r', encoding = 'utf-8') as f:
        for line in f:
            raw_text.append(line.rstrip())
    with open(os.path.join(data_path, 'onehot_label.txt'), 'r') as f:
        for line in f:
            onehot_label.append([int(x) for x in  line.rstrip().split()])

    with open(os.path.join(data_path, 'text_label.txt'), 'r', encoding = 'utf-8') as f:
        for line in f:
            text_label.append(line.rstrip())

    assert len(raw_text) == len(onehot_label) == len(text_label), 'Error: len data does not match'
    total = num_train + num_dev
    indices = np.random.randint(0, len(onehot_label), total)
    train_data  = []
    dev_data = []
    for idx in indices[:num_train]:
        train_data.append((raw_text[idx], onehot_label[idx], text_label[idx]))
    for idx in indices[num_train:]:
        dev_data.append((raw_text[idx], onehot_label[idx], text_label[idx]))
    
    return train_data, dev_data
        



In [3]:
#random_word: 0.3, random_letter: 0.2, homophone_word: 0.1, remove_diacritcs: 0.2, homophone_char: 0.2
data_path = '/content/drive/MyDrive/nlp_projects/Text_correction/all_data/homophone_char_data'
train_data, dev_data = get_data(data_path, 400000, 4000)

In [6]:
data_path = '/content/drive/MyDrive/nlp_projects/Text_correction/all_data/homophone_word_data'
new_train, new_test = get_data(data_path, 200000, 2000)

In [16]:
train_data.extend(new_train)
dev_data.extend(new_test)

In [9]:
data_path = '/content/drive/MyDrive/nlp_projects/Text_correction/all_data/random_letter_data'
new_train, new_test = get_data(data_path, 400000, 4000)

In [13]:
data_path = '/content/drive/MyDrive/nlp_projects/Text_correction/all_data/random_word_data'
new_train, new_test = get_data(data_path, 600000, 6000)

In [15]:
data_path = '/content/drive/MyDrive/nlp_projects/Text_correction/all_data/removed_diacritics_data'
new_train, new_test = get_data(data_path, 400000, 4000)

In [23]:
#Write train and dev data
def write_train_test(data_path, data):
    with open(os.path.join(data_path, 'rawtext.txt'), 'w', encoding = 'utf-8') as f:
        for line in tqdm(data):
            f.write(line[0] + '\n')
    with open(os.path.join(data_path, 'onehot_label.txt'), 'w') as f:
        for line in tqdm(data):
            f.write(' '.join([str(x) for x in line[1]]) + '\n')
    with open(os.path.join(data_path, 'text_label.txt'), 'w', encoding = 'utf-8') as f:
        for line in tqdm(data):
            f.write(line[2] + '\n')

    

In [25]:
data_path = '/content/drive/MyDrive/nlp_projects/Text_correction/all_data/dev_data'
write_train_test(data_path, dev_data)

HBox(children=(FloatProgress(value=0.0, max=20000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20000.0), HTML(value='')))


