In [42]:
import sentencepiece as spm
from coherence_model import AlbertCoherenceRank
from transformers import AlbertConfig
import torch

In [2]:
# Stand-alone SentencePiece handle.  No later cell shown here uses it:
# NeuralWordSegmentation loads its own processor from the same file.
s = spm.SentencePieceProcessor(
    model_file='tokenizer_65536.model',
)

In [12]:
# Hyper-parameters for the 4-layer "ALBERT tiny" encoder: 312-d hidden states,
# 128-d embeddings, 12 attention heads, all dropout disabled.  vocab_size is
# 65536, presumably matching tokenizer_65536.model — confirm.
# NOTE(review): "directionality", "pooler_*" and "ln_type" are not AlbertConfig
# constructor parameters in upstream transformers; presumably carried over from
# the original checkpoint's config — confirm whether the project model reads them.
albert_tiny_config = dict(
    attention_probs_dropout_prob=0.0,
    directionality="bidi",
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    hidden_size=312,
    embedding_size=128,
    initializer_range=0.02,
    intermediate_size=1248,
    max_position_embeddings=512,
    num_attention_heads=12,
    num_hidden_layers=4,
    pooler_fc_size=768,
    pooler_num_attention_heads=12,
    pooler_num_fc_layers=3,
    pooler_size_per_head=128,
    pooler_type="first_token_transform",
    type_vocab_size=2,
    vocab_size=65536,
    ln_type="postln",
)

# Wrap the plain hyper-parameter dict in a transformers config object.
config = AlbertConfig(
    **albert_tiny_config,
)

In [14]:
# Load the trained coherence ranker.  sequence_length=128 should agree with the
# fixed length NeuralWordSegmentation pads candidates to, because its scoring
# head is fed the encoder output flattened across all positions.
model = AlbertCoherenceRank.from_pretrained(
    './coherence_model',
    config=config,
    sequence_length=128,
)

In [83]:
class NeuralWordSegmentation:
    """Re-rank SentencePiece n-best segmentations with a neural scorer.

    SentencePiece proposes up to ``nbest`` alternative segmentations of a
    sentence.  Each candidate's token ids are written into a fixed-length,
    zero-padded row; the batch is encoded by ``model.albert`` and the encoder
    output, flattened across positions, is scored by ``model.mlp``.  Candidates
    are returned ordered by that score, highest first.
    """

    def __init__(self, tokenizer_path, model, max_length=128):
        """Load the tokenizer and keep a reference to the scoring model.

        Args:
            tokenizer_path: path to a SentencePiece ``.model`` file.
            model: torch module exposing ``albert`` (encoder) and ``mlp``
                (scoring head), e.g. an ``AlbertCoherenceRank``.
            max_length: length every candidate is padded / truncated to.
                Presumably must equal the ``sequence_length`` the model was
                built with, since ``mlp`` receives the encoder output
                flattened across positions (TODO confirm).  Defaults to 128,
                the value that used to be hard-coded.
        """
        self.spm = spm.SentencePieceProcessor(model_file=tokenizer_path)
        self.model = model
        self.max_length = max_length

    def get(self, sentence, nbest=10):
        """Score the n-best segmentations of ``sentence``.

        Args:
            sentence: raw text to segment.
            nbest: number of candidate segmentations requested from
                SentencePiece (it may return fewer).

        Returns:
            A list of ``(segmentation, score, nbest_rank)`` tuples sorted by
            descending ``score``.  ``segmentation`` is the candidate's pieces
            joined with "/", ``score`` is column 0 of the model output for that
            candidate, and ``nbest_rank`` is the candidate's 1-based position in
            SentencePiece's n-best list.  Returns ``[]`` when SentencePiece
            yields no candidates.

        Candidates longer than ``self.max_length`` tokens are truncated before
        scoring.  Previously they crashed the padding assignment.
        """
        candidates = self.spm.NBestEncodeAsIds(sentence, nbest)
        candidates_words = self.spm.NBestEncodeAsPieces(sentence, nbest)
        if len(candidates) == 0:
            # Avoid pushing a zero-row batch through the model.
            return []

        length = self.max_length
        inputs = torch.zeros((len(candidates), length), dtype=torch.long)
        for row, ids in enumerate(candidates):
            ids = list(ids)[:length]
            inputs[row, :len(ids)] = torch.tensor(ids, dtype=torch.long)
        # NOTE(review): padding positions hold id 0 and no attention mask is
        # passed, exactly as before; confirm this matches how the model was trained.

        # Run on whatever device the model lives on (falls back to CPU for
        # parameter-less models).
        try:
            device = next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cpu')

        self.model.eval()
        # Pure inference: no_grad avoids building an autograd graph.
        with torch.no_grad():
            # assumes model.albert returns a 2-tuple (per-position output, extra);
            # TODO confirm for the transformers version in use.
            albert_out, _ = self.model.albert(inputs.to(device))
            scores = self.model.mlp(torch.flatten(albert_out, start_dim=1)).cpu().numpy()

        results = [
            ("/".join(pieces), scores[rank, 0], rank + 1)
            for rank, pieces in enumerate(candidates_words)
        ]
        return sorted(results, key=lambda item: item[1], reverse=True)

In [84]:
# Segmenter that re-ranks SentencePiece n-best candidates with the model above.
nws = NeuralWordSegmentation(
    tokenizer_path='tokenizer_65536.model',
    model=model,
)

In [102]:
# Ambiguous split: 生/性病 vs 生性/病 (both appear among the candidates).
nws.get('兒子生性病母倍感安慰', nbest=10)

[('▁/兒子/生/性病/母/倍/感/安慰', -0.52289635, 1),
 ('▁/兒子/生性/病/母/倍/感/安慰', -0.62185246, 2),
 ('▁/兒/子/生/性/病/母/倍/感/安慰', -0.67673385, 8),
 ('▁/兒/子/生/性病/母/倍/感/安慰', -0.6780272, 4),
 ('▁/兒/子/生性/病/母/倍/感/安慰', -0.75988257, 5)]

In [123]:
# Ambiguous split: 獅子山下/體現 vs 獅子山/下體/現.
nws.get('獅子山下體現香港精神', nbest=10)


[('▁/獅子山/下體/現/香港/精神', 0.110433176, 5),
 ('▁/獅子山/下/體/現/香港/精神', 0.016519196, 8),
 ('▁/獅子/山/下/體現/香港/精神', -0.036805652, 7),
 ('▁/獅子山下/體/現/香港/精神', -0.14911321, 3),
 ('▁/獅子山下/體現/香港/精神', -0.15852737, 1),
 ('▁/獅子山/下/體現/香港/精神', -0.20499955, 2),
 ('▁/獅子山下/體/現/香港/精/神', -0.20978086, 10),
 ('▁/獅子山/下/體現/香港/精/神', -0.22237615, 9),
 ('▁/獅子山下/體現/香港/精/神', -0.3028409, 4),
 ('▁/獅子山下/體現/香/港/精神', -0.37603474, 6)]

In [109]:
# Ambiguous split: 花/生長 vs 花生/長.
nws.get('花生長在屋後的田裡', nbest=10)

[('▁/花/生長/在/屋/後的/田/裡', 0.08546053, 3),
 ('▁/花/生長/在/屋/後/的/田/裡', -0.15445843, 4),
 ('▁/花/生/長/在/屋/後/的/田/裡', -0.22414203, 6),
 ('▁/花/生/長/在/屋/後的/田/裡', -0.47418493, 5),
 ('▁/花生/長/在/屋/後的/田/裡', -0.81693584, 1),
 ('▁/花生/長/在/屋/後/的/田/裡', -0.92735624, 2)]

In [124]:
# Ambiguous split: 照/顧客 vs 照顧/客.
nws.get('照顧客嘅要求設計產品', nbest=10)

[('▁/照/顧客/嘅/要求/設計/產品', 0.0071468055, 3),
 ('▁/照/顧/客/嘅/要求/設計/產品', -0.06025382, 9),
 ('▁/照/顧客/嘅/要/求/設計/產品', -0.220116, 6),
 ('▁照/顧/客/嘅/要求/設計/產品', -0.22720684, 7),
 ('▁照/顧客/嘅/要/求/設計/產品', -0.8018307, 5),
 ('▁照/顧客/嘅/要求/設計/產品', -0.835267, 2),
 ('▁/照顧/客/嘅/要求/設計/產品', -0.8555057, 1),
 ('▁/照顧/客/嘅/要/求/設計/產品', -0.95422095, 4),
 ('▁/照顧/客/嘅/要求/設計/產/品', -0.9880646, 10),
 ('▁/照顧/客/嘅/要求/設/計/產品', -1.0709113, 8)]

In [125]:
# Ambiguous split: 要學/生活 vs 要/學生/活得.
nws.get('要學生活得更有意義', nbest=10)


[('▁/要/學/生活/得更/有意義', 0.36772978, 5),
 ('▁/要/學生/活得/更/有意義', 0.30958253, 6),
 ('▁/要/學生/活/得更/有意義', 0.28081858, 9),
 ('▁/要學/生活/得更/有/意義', -0.056980528, 10),
 ('▁/要學/生活/得/更/有意義', -0.32899457, 4),
 ('▁要/學生/活得/更/有意義', -0.33722532, 3),
 ('▁要/學/生活/得/更/有意義', -0.37182134, 8),
 ('▁/要學/生活/得更/有意義', -0.48700887, 1),
 ('▁要/學/生活/得更/有意義', -0.5989682, 2),
 ('▁要/學生/活/得更/有意義', -0.67934483, 7)]