In [3]:
# Install TensorFlow Text and TensorFlow Hub, which the import cell below relies on.
%pip install tensorflow_text tensorflow_hub

UsageError: Line magic function `%pip3` not found.


In [2]:
import tensorflow as tf
import tensorflow_text as text
import tensorflow_hub as hub
import pandas as pd
import numpy as np
from functools import cache
import re

In [3]:
class TypoDetector:
    """
    Typo detector built from a TF Hub BERT preprocessor and its MLM head.

    The composed model takes raw sentences, tokenizes them, and asks the
    masked-language-model head for a vocabulary distribution at every
    position of the packed sequence. For each position it reports the
    probability of the token that is actually there next to the most likely
    token, so improbable tokens can be flagged as typo candidates.
    """

    # Lazily populated per-instance caches (see get_vocab / get_vocab_size).
    # They live on the instance rather than in a functools.cache on the
    # method, because a method-level cache keeps a strong reference to `self`
    # and would pin the loaded hub models in memory.
    _vocab = None
    _vocab_size = None

    def __init__(
            self,
            preprocessor_handle="https://tfhub.dev/tensorflow/bert_zh_preprocess/3",
            encoder_handle="https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/4",
            seq_length: int = 128,
            trainable: bool = True) -> None:
        """
        Load the preprocessor and encoder from TF Hub
        :param preprocessor_handle: handle of the BERT preprocessing model
        :param encoder_handle: handle of the BERT encoder model
        :param seq_length: length the packed token sequences are fixed to
        :param trainable: whether the MLM layer of the model is trainable
        """
        self.preprocessor_handle = preprocessor_handle
        self.encoder_handle = encoder_handle
        self.seq_length = seq_length
        self.trainable = trainable
        self.preprocessor = hub.load(preprocessor_handle)
        self.encoder = hub.load(encoder_handle)
        self._vocab = None
        self._vocab_size = None

    def get_vocab(self) -> list[str]:
        """
        Get the vocab list from the encoder's vocab file.
        The file is read once per instance; later calls return the same list.
        :returns: the vocab list, where the list index is the token id
        """
        if self._vocab is None:
            vocab_filepath = self.encoder.vocab_file.asset_path.numpy().decode(
                "utf-8")
            # Read explicitly as UTF-8: the locale default encoding cannot be
            # relied on to decode a Chinese vocab file.
            with open(vocab_filepath, 'r', encoding='utf-8') as f:
                # Strip only the line terminator, so a final line without a
                # trailing newline keeps its last character. Lines are never
                # dropped because positions must stay aligned with token ids.
                self._vocab = [line.rstrip('\n') for line in f]
        return self._vocab

    def get_vocab_size(self) -> int:
        """
        Get the vocab size reported by the preprocessor's tokenizer.
        The value is looked up once per instance and then reused.
        :returns: vocab size
        """
        if self._vocab_size is None:
            special_tokens = self.preprocessor.tokenize.get_special_tokens_dict()
            self._vocab_size = special_tokens['vocab_size']
        return self._vocab_size

    def get_tokenizer(self) -> "hub.KerasLayer":
        """
        Get the tokenizer as the preprocessor of the bert model
        :returns: the tokenizer wrapped as a Keras layer
        """
        return hub.KerasLayer(self.preprocessor.tokenize)

    def get_model(self) -> "tf.keras.Model":
        """
        Compose and return our detector model
        :returns: a Keras model mapping a batch of strings to a dict with
            'char_token' (token ids of the packed input),
            'char_prob' (probability the MLM head gives to each input token),
            'top1_token' (most likely token id at each position) and
            'top1_prob' (probability of that most likely token)
        """
        # input layer
        text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
        # tokenizer layer
        tokenizer = self.get_tokenizer()
        # pack-inputs layer
        bert_pack_inputs = hub.KerasLayer(
            self.preprocessor.bert_pack_inputs,
            arguments={"seq_length": tf.constant(self.seq_length)})
        # mlm layer
        mlm_layer = hub.KerasLayer(self.encoder.mlm, trainable=self.trainable)
        # string input to tokens
        tokenized_input = tokenizer(text_input)
        # tokens to bert encoder inputs
        encoder_inputs = bert_pack_inputs([tokenized_input])
        # bert encoder inputs to mlm inputs
        mlm_inputs = self.__to_mlm_inputs(encoder_inputs)
        # consume mlm inputs and produce corresponding outputs by the mlm model
        mlm_outputs = mlm_layer(mlm_inputs)
        # convert mlm outputs to probabilities of each char of the input
        logit_probs = self.__to_logit_probs(mlm_outputs)
        char_prob = self.__to_char_prob(encoder_inputs, logit_probs)
        top1_prob, top1_token = self.__to_topk_probs(logit_probs, 1)
        # drop the trailing k=1 axis so both tensors line up with char_prob
        top1_prob = tf.squeeze(top1_prob, axis=-1)
        top1_token = tf.squeeze(top1_token, axis=-1)
        char_token = encoder_inputs['input_word_ids']
        outputs = {
            'char_token': char_token,
            'char_prob': char_prob,
            'top1_token': top1_token,
            'top1_prob': top1_prob,
        }
        # pack the layers as tf model
        return tf.keras.Model(text_input, outputs)

    def __to_logit_probs(self, mlm_outputs: dict) -> "tf.Tensor":
        """
        Convert mlm outputs to the probabilities of each char in vocab
        :param mlm_outputs: outputs of the mlm layer, read via "mlm_logits"
        :returns: shape=(batch_size, seq_length, vocab_size)
        """
        return tf.keras.layers.Softmax()(mlm_outputs["mlm_logits"])

    def __to_topk_probs(self, logit_probs,
                        k: int = 1) -> "tuple[tf.Tensor, tf.Tensor]":
        """
        Convert mlm outputs to the indices of top K chars
        :param logit_probs: returns of self.__to_logit_probs()
        :param k: number of items at the top most to take, optional
        :returns: (values, indices) pair, each with shape=(batch_size, seq_length, k)
        """
        return tf.math.top_k(logit_probs, k)

    def __to_char_prob(self, encoder_inputs: dict, logit_probs) -> "tf.Tensor":
        """
        Convert the mlm outputs to the probabilities of each char of sentences
        :param encoder_inputs: inputs for the bert model
        :param logit_probs: returns of self.__to_logit_probs()
        :returns: shape=(batch_size, seq_length)
        """
        vocab_size = self.get_vocab_size()
        one_hot_token_ids = tf.one_hot(encoder_inputs["input_word_ids"],
                                       vocab_size)
        # The one-hot mask zeroes every vocab entry except the input token,
        # so the max over the vocab axis is exactly that token's probability.
        return tf.reduce_max(tf.multiply(logit_probs, one_hot_token_ids), -1)

    def __to_mlm_inputs(self, encoder_inputs: dict) -> "dict[str, tf.Tensor]":
        """
        Convert bert encoder inputs into mlm inputs
        Every position 0..seq_length-1 is listed as a masked position so the
        mlm head yields a prediction for each token; the token ids themselves
        are passed through unchanged.
        :param encoder_inputs: packed inputs produced by bert_pack_inputs
        :returns: the corresponding inputs for bert's mlm model
        """
        # Broadcasting the (seq_length,) range against the (batch, seq_length)
        # ones gives one row of positions per sentence.
        masked_lm_positions = tf.multiply(
            tf.ones_like(encoder_inputs["input_word_ids"], dtype=tf.int32),
            tf.range(0, self.seq_length))
        return {
            "input_word_ids": encoder_inputs["input_word_ids"],
            "input_mask": encoder_inputs["input_mask"],
            "input_type_ids": encoder_inputs["input_type_ids"],
            "masked_lm_positions": masked_lm_positions,
        }


In [4]:
# Load the TF Hub preprocessor and encoder using the default handles and settings.
detector = TypoDetector()
# Build the Keras model that maps raw sentences to per-token probabilities.
model = detector.get_model()
# Vocab list (index == token id), used later to turn token ids back into text.
vocab = detector.get_vocab()

In [5]:
class SentenceInfoPraser:
    """
    Map the tokens of a sentence back to character spans of that sentence.

    Example:
    sentenceInfoParser = SentenceInfoPraser(vocab)
    sentenceInfoParser.parse_sentences(["「Star War」很好看。"])

    NOTE(review): two adjacent unknown characters with no separator between
    them cannot be told apart here; the first [UNK] takes everything up to
    the next separator (or the end of the sentence).
    """

    def __init__(self, vocab: list[str]):
        """
        :param vocab: the vocab list, where the list index is the token id
        """
        self.vocab = vocab
        self.separentorMatcher = re.compile(r'[ -\/:-@\[-`\{-~]')
        # Ids of the special tokens, taken from the vocab when it lists them
        # and otherwise falling back to the ids this class used before.
        self.pad_token_id = self._special_token_id('[PAD]', 0)
        self.unk_token_id = self._special_token_id('[UNK]', 100)
        self.cls_token_id = self._special_token_id('[CLS]', 101)
        self.sep_token_id = self._special_token_id('[SEP]', 102)

    def _special_token_id(self, name: str, default: int) -> int:
        """
        Look up the id of a special token in the vocab
        :returns: the index of `name` in the vocab, or `default` if absent
        """
        try:
            return self.vocab.index(name)
        except ValueError:
            return default

    def get_words_info(self, sentence: str, tokens: list[int]) -> list[tuple[int, int]]:
        """
        Return the start and end index of each word in the sentence.
        One entry is produced per token up to (not including) [SEP]; the
        entry for [CLS] is an empty span.
        :param sentence: the sentence string
        :param tokens: the token representation of the sentence
        :returns: list of tuple that contains the word's start and end indexes
        """
        last_token_idx = len(tokens) - 1
        words_info = []
        char_ptr = 0
        for token_idx, token in enumerate(tokens):
            if token == self.cls_token_id:
                # [CLS] does not correspond to any text
                words_info.append((char_ptr, char_ptr))
                continue
            if token == self.sep_token_id:
                break
            # Whitespace never becomes a token, so step over it before
            # matching; otherwise every later span would be shifted.
            char_ptr = self._skip_whitespace(sentence, char_ptr)
            if token == self.unk_token_id:
                if token_idx < last_token_idx:
                    next_token = tokens[token_idx + 1]
                else:
                    next_token = None
                char_end_ptr = self._find_unknown_end(
                    sentence, char_ptr, next_token)
            else:
                char_end_ptr = char_ptr + len(self.get_word(token))
            # store the corresponding word position of the token
            words_info.append((char_ptr, char_end_ptr))
            char_ptr = char_end_ptr
        return words_info

    def _skip_whitespace(self, sentence: str, position: int) -> int:
        """
        Return the first index at or after `position` that is not whitespace
        """
        while position < len(sentence) and sentence[position].isspace():
            position += 1
        return position

    def _find_unknown_end(self, sentence: str, start: int, next_token) -> int:
        """
        Find where the text behind an [UNK] token ends.
        :param sentence: the sentence string
        :param start: index where the unknown word starts
        :param next_token: the token after the [UNK], or None if it is last
        :returns: the end index (exclusive) of the unknown word
        """
        sentence_length = len(sentence)
        # Nothing readable follows: the unknown word runs to the end. [SEP]
        # and [PAD] must not be searched for, as their text is not part of
        # the sentence.
        if (next_token is None or next_token == self.sep_token_id
                or next_token == self.pad_token_id):
            return sentence_length
        # The unknown word covers at least one character.
        search_from = min(start + 1, sentence_length)
        end = None
        if next_token != self.unk_token_id:
            next_word = self.get_word(next_token)
            if next_word:
                # Case-insensitive: the vocab may hold the lower-cased form
                # of what the sentence actually contains.
                found = re.compile(re.escape(next_word), re.IGNORECASE).search(
                    sentence, search_from)
                if found is not None:
                    end = found.start()
        if end is None:
            # Next token is also unknown (or its text was not found): stop
            # at the next separator symbol instead.
            match = self.separentorMatcher.search(sentence, search_from)
            end = match.start() if match is not None else sentence_length
        # Do not count whitespace in front of the next word as part of this one.
        while end > search_from and sentence[end - 1].isspace():
            end -= 1
        return end

    def get_word(self, token: int) -> str:
        """
        Convert given token back to the word piece
        :returns: the vocab entry of the token with any '##' removed
        """
        return self.vocab[token].replace('##', '')


In [6]:
# Parser used to map token ids back to character spans of the original sentence.
sentenceInfoParser = SentenceInfoPraser(vocab)

In [11]:
# Sentences fed to the detector. Only the uncommented entry is evaluated; the
# commented-out entries are earlier samples kept for reference.
sentences = [
    # "古語有云：等到潮水退了，就知道誰沒穿褲子。",
    # "古語有云：當地天氣嚴寒，室內亦只有兩、三度，不少衣衫單簿的女歌手都冷得發抖。",
    # "古語有云：《三國演義》一書敍述了魏、蜀、吳三國之間復雜的爭鬥。",
    # "古語有云：摧眠治療早於 1958 年己被美國醫學會 (AMA) 宣佈為正式的精神治療方法之一。",
    # "古語有云：香港大球場舉行了一連三天的國際性球賽，吸引了世界各地球迷蜂湧而至。",
    # "古語有云：請問我於遞交申請表和會費後，何時才會收到會員卡和優惠贈卷呢？",
    # "古語有云：不少人認為燈迷很難猜，其實只要我們掌握其法，就很容易猜到答案的。",
    # "古語有云：他的文章內容混亂，令人看得頭昏腦漲。",
    # "古語有云：他的表現輕佻浮燥，惹人討厭。",
    # "古語有云：他對偶像的鐘愛，已到了盲目的地步。",
    # "古語有云：正確和錯誤，或真與假，可以算是邏輯裏面最基本的慨念。",
    # "【本文獲授權轉載。】",
    # "案件編號：CCDI-602/2016",
    # "【林綸詩其他文章：】",
    # "• 宣布天厨違反第一行為守則",
    # "• 向天厨施加罰款",
    # "• 禁止天厨日後從事違反該守則的相同行為",
    # "• 天厨須推行有效的合規計劃",
    # "• 向天厨收取競委會的訟費及調查費用",
    # "↓↓↓單位直撃↓↓↓",
    # "👇🏻即看3款總評分5星產品👇🏻",
    # "葉劉：難斷定11月會否放寬至「0+7」",
    # "👉",
    # "PCAOB必須能夠訪問所有在其註冊的公眾會計師事務所的審計文件並選擇任何審計業務，而非僅僅部分事務所或部分業務才能夠在內地和香港進行全面的檢查和調查。",
    # "PCAOB將通知審計公司其檢查計畫，包括具體的業務。PCAOB檢查員將於9月中旬前在香港開始他們的檢查工作。",
    # "HKZO有個重點一定要注意，可能會分散注意力。",
    # "入伙後一家四口陸續收到由「The Watcher」寄來的恐怖信件，遇上不友善的鄰居擅闖家中，並躲再食物升降機。",
    # "保險公司Allstate對這兩款產品進行了跌落測試，測試設備為Dropbot，保證拍出認為因素干擾，最終得到的結果是：iPhone 14 Plus更耐摔。",
    "影片只有5秒，但已經有15萬次點擊，歌迷相當關注，亦有人讚賞Jennie敬業，馬上就重新站起來表演，有時常看演出的網民表示，這應該跟耳機漏電有關，更表示「已經不是第一次」，覺得大會跟共公司要改善，照片歌手安全。",

    # "她擔心得笑不停。",
    # "她開心得笑不停。",
]

# Run the whole batch through the detector; `outputs` is the dict assembled in
# TypoDetector.get_model() with keys char_token, char_prob, top1_token, top1_prob.
outputs = model(tf.constant(sentences))

In [12]:
# Token ids of the special tokens, as assumed throughout this notebook.
PAD_TOKEN_ID, CLS_TOKEN_ID, SEP_TOKEN_ID = 0, 101, 102

# For every sentence, print one line per token: the token's position and id,
# the text it covers with the probability the model gives it, and the model's
# most likely token at that position with its probability.
for sentence, char_tokens, char_probs, top1_tokens, top1_probs in zip(
        sentences,
        outputs["char_token"], outputs["char_prob"], outputs["top1_token"],
        outputs["top1_prob"]):
    tokens = list(char_tokens.numpy())
    # Crop the padding. A sentence that fills every slot of the sequence has
    # no padding at all, so only crop when a padding token is present.
    if PAD_TOKEN_ID in tokens:
        tokens = tokens[:tokens.index(PAD_TOKEN_ID)]
    words_info = sentenceInfoParser.get_words_info(sentence, tokens)
    for token_idx, char_token in enumerate(tokens):
        if char_token == CLS_TOKEN_ID:
            continue
        if char_token == SEP_TOKEN_ID:
            break
        word_start_idx, word_end_idx = words_info[token_idx]
        word = sentence[word_start_idx:word_end_idx]
        char_prob = char_probs[token_idx]
        top1_token = top1_tokens[token_idx]
        top1_char = vocab[top1_token]
        top1_prob = top1_probs[token_idx]
        print(f"#{token_idx:0>2}({char_token}) {word:>5}({char_prob:>7.2%}): {top1_char:>3}({top1_prob:>7.2%})")

#01(2512)     影(  4.39%):   。( 27.04%)
#02(4275)     片(100.00%):   片(100.00%)
#03(1372)     只( 99.98%):   只( 99.98%)
#04(3300)     有(100.00%):   有(100.00%)
#05(126)     5( 99.96%):   5( 99.96%)
#06(4907)     秒(100.00%):   秒(100.00%)
#07(8024)     ，( 99.99%):   ，( 99.99%)
#08(852)     但( 99.99%):   但( 99.99%)
#09(2347)     已(100.00%):   已(100.00%)
#10(5195)     經(100.00%):   經(100.00%)
#11(3300)     有(100.00%):   有(100.00%)
#12(8115)    15( 99.97%):  15( 99.97%)
#13(5857)     萬(100.00%):   萬(100.00%)
#14(3613)     次(100.00%):   次(100.00%)
#15(7953)     點(100.00%):   點(100.00%)
#16(3080)     擊(100.00%):   擊(100.00%)
#17(8024)     ，( 99.96%):   ，( 99.96%)
#18(3625)     歌(100.00%):   歌(100.00%)
#19(6837)     迷(100.00%):   迷(100.00%)
#20(4685)     相(100.00%):   相(100.00%)
#21(4534)     當(100.00%):   當(100.00%)
#22(7302)     關(100.00%):   關(100.00%)
#23(3800)     注(100.00%):   注(100.00%)
#24(8024)     ，( 99.99%):   ，( 99.99%)
#25(771)     亦( 99.99%):   亦( 99.99%)
#26(3300)     有(100.00%):   