In [2]:
import utils
import jieba
import numpy as np
from typing import *
from sklearn.feature_extraction.text import CountVectorizer
import pandas as pd

In [10]:
class Classifier:
    """
    Text classifier backed by a pre-trained model directory.

    ``utils.load_model`` supplies the vocabulary, the class priors ``p_c``, the
    per-class word probabilities ``p_w_c`` and the label names. A text is
    segmented with ``utils.jieba_segment``, turned into word counts by a
    ``CountVectorizer`` restricted to the training vocabulary, and scored as
    ``log p(c) + sum_w count(w) * log p(w|c)`` (multinomial naive-Bayes form).
    """

    def __init__(self, model_dir: str):
        # The CountVectorizer settings must match those used at training time,
        # otherwise the token -> column mapping will not line up with p_w_c.
        vocab, self.p_c, self.p_w_c, self.labels = utils.load_model(model_dir)
        self.cv = CountVectorizer(analyzer="word", token_pattern=r"(?u)\b\w+\b", vocabulary=vocab)
        # The log-probabilities are constant for a loaded model. Compute them
        # once here instead of re-taking np.log of the whole p_w_c matrix on
        # every predict_text call.
        self._log_p_c = np.log(self.p_c).reshape(-1, 1)
        self._log_p_w_c = np.log(self.p_w_c)

    @staticmethod
    def _to_text(text: Any) -> str:
        """Coerce ``text`` to ``str``; jieba rejects non-unicode input with a ValueError."""
        if text is None:
            return ""
        if isinstance(text, str):
            return text
        if isinstance(text, (bytes, bytearray)):
            try:
                return bytes(text).decode("utf-8")
            except UnicodeDecodeError:
                # Fall back to GBK, a common legacy encoding for Chinese text.
                return bytes(text).decode("gbk", "ignore")
        if isinstance(text, float) and np.isnan(text):
            # Missing cells read by pandas arrive as float NaN.
            return ""
        return str(text)

    def _select_top(self, prob: np.ndarray, top_n: Optional[int], with_prob: bool):
        """Return the ``top_n`` labels of 1-D ``prob`` (one entry per class), highest first."""
        # None keeps the old "slice [:None] returns everything" behaviour;
        # negative values are clamped instead of silently dropping the last label.
        limit = None if top_n is None else max(top_n, 0)
        order = prob.argsort()[::-1][:limit]
        if with_prob:
            return [(self.labels[i], float(prob[i])) for i in order]
        return [self.labels[i] for i in order]

    def predict_text(self, text: str, top_n: int = 1,
                     with_prob: bool = False) -> Union[List[str], List[Tuple[str, float]]]:
        """
        Classify a single text.

        :param text: text to classify. Non-``str`` values are coerced first
            (bytes are decoded, None/NaN become ""), because jieba raises
            ``ValueError: jieba: the input parameter should be unicode.`` otherwise.
        :param top_n: number of highest-probability labels to return (default 1;
            negative values are treated as 0, None returns all labels).
        :param with_prob: when True, return ``(label, probability)`` pairs
            instead of bare labels.
        :return: up to ``top_n`` labels sorted by descending probability, or
            ``[(label, probability)]`` when ``with_prob`` is True.
        """
        seg = " ".join(utils.jieba_segment(self._to_text(text), mode="search"))
        text_vec = self.cv.transform([seg])
        # (1, vocab) counts @ (vocab, n_classes) log p(w|c) -> per-class log-likelihood.
        log_p_d_c = text_vec @ self._log_p_w_c
        # Column vector of unnormalised log-posteriors, one row per class.
        log_p_c_d = self._log_p_c + log_p_d_c.T
        # NOTE(review): assumes utils.softmax normalises over axis 0 (classes) — confirm.
        prob = utils.softmax(log_p_c_d)
        return self._select_top(np.asarray(prob)[:, 0], top_n, with_prob)


In [7]:
def preprocess_prediction_data(data):
    """
    Wrap each raw text in a dict keyed by ``"text_a"``.

    :param data: iterable of texts
    :return: list of ``{"text_a": text}`` dicts, in input order
    """
    return [{"text_a": item} for item in data]

In [5]:
# Load the comments to classify: a single tab-separated column, no header row.
# The formatting cell below reads `test`, so this must run on a fresh kernel.
# dtype=str + keep_default_na=False keep every cell a str (no NaN or numbers),
# because jieba raises "the input parameter should be unicode" on non-str input.
test = pd.read_table('resources/cropus/iphone13_comment.txt', sep='\t', header=None,
                     dtype=str, keep_default_na=False)
test.columns = ["text_a"]

In [8]:
# Turn the test-set comments into prediction examples
data1 = list(test["text_a"])
examples = preprocess_prediction_data(data1)

In [13]:
if __name__ == "__main__":
    # Smoke test: load the trained model and show the top-3 labels for one comment.
    model_dir = "resources/model/classification_model"
    sample_comment = '输入一句评论'
    cls = Classifier(model_dir)
    res = cls.predict_text(sample_comment, top_n=3)
    print(res)

ValueError: jieba: the input parameter should be unicode.