In [1]:
import os
import jieba
import Spider

def score(item, query, tokenize=None):
    """Score a document against a query by weighted keyword frequency.

    Each distinct, non-blank token of the lowercased query contributes
    ``5 * (occurrences in the title) + 3 * (occurrences in the content)``.
    Matching is case-insensitive.

    Args:
        item: document record; ``item[1]`` is read as the title and
            ``item[2]`` as the content (both strings).
        query: raw query string.
        tokenize: optional callable that splits a string into tokens;
            defaults to ``jieba.cut``.

    Returns:
        int: the accumulated score (0 when no token matches).
    """
    if tokenize is None:
        tokenize = jieba.cut
    title = item[1].lower()
    content = item[2].lower()
    # dict.fromkeys de-duplicates while keeping first-seen order, so a term
    # repeated in the query is counted only once. Blank tokens (jieba emits
    # the spaces between words as tokens) are dropped so they do not match
    # every space in the document.
    keywords = dict.fromkeys(
        token.lower() for token in tokenize(query.lower()) if token.strip())
    total = 0
    for keyword in keywords:
        total += title.count(keyword) * 5 + content.count(keyword) * 3
    return total


class MySearcherC8V0:
    """
    Searcher class version upgraded in lesson 7:
    1. __init__() loads a custom jieba user dictionary
    2. build_cache() segments documents with jieba.cut_for_search
    3. search() segments the query
    4. search() fetches the posting set of each query token
    5. search() merges the posting sets by intersection
    6. build_cache() stores postings as sets of doc_id (convenient for set operations)
    7. rank() scores and sorts the candidate documents
    8. score() counts term frequencies for multi-term queries
    """

    def __init__(self, scale: int = 1):
        """Load documents, preprocess them and build the inverted index.

        Args:
            scale: if > 1, the document list is repeated ``scale`` times
                (used to benchmark search speed on a larger corpus).
        """
        self.docs = list()
        self.load_data()
        # Preprocess BEFORE scaling: ``list *= n`` repeats references to the
        # same doc lists, so preprocessing afterwards would append the
        # lowercased text to each underlying doc ``scale`` times.
        self.lower_preprocess()
        if scale > 1:
            self.docs *= scale  # repeat the corpus to test search speed
        self.cache = dict()
        self.vocab = set()
        jieba.load_userdict('./dict.txt')
        self.build_cache()

    def load_data(self, data_file_name='./news_list.pkl'):
        """Load the document list from ``data_file_name``.

        If the file does not exist yet, ``Spider.pickle_save`` is called
        first to create it.

        NOTE(review): Spider.pickle_load presumably unpickles the file; only
        load pickles this project produced itself, since unpickling
        untrusted data can execute arbitrary code.
        """
        if not os.path.exists(data_file_name):
            Spider.pickle_save(data_file_name)
        self.docs = Spider.pickle_load(data_file_name)

    def search(self, query):
        """Return the documents that contain every query token, best first.

        The lowercased query is segmented with ``jieba.cut``. Blank tokens
        (e.g. the spaces between words) are ignored. The candidates are the
        intersection of each token's posting set, so any token missing from
        the index yields no results.

        Returns:
            list of ``[doc_id, score]`` pairs sorted by score, descending.
        """
        result = None
        for keyword in jieba.cut(query.lower()):
            if not keyword.strip():
                continue
            postings = self.cache.get(keyword)
            if postings is None:
                result = set()
                break
            # ``&`` builds a new set, so the cached posting sets are never mutated.
            result = postings if result is None else result & postings
        if result is None:
            result = set()
        return self.rank(query, result)

    def rank(self, query, result_set):
        """Score each candidate doc with ``score()`` and sort by score, descending.

        Returns:
            list of ``[doc_id, score]`` pairs.
        """
        result = [[doc_id, score(self.docs[doc_id], query)]
                  for doc_id in result_set]
        result.sort(key=lambda x: x[1], reverse=True)
        return result

    def build_cache(self):
        """Build the inverted index ``self.cache``: token -> set of doc_ids.

        Each document's ``doc[3]`` (the lowercased "title content" string
        added by ``lower_preprocess``) is segmented with
        ``jieba.cut_for_search``. Every token is also added to ``self.vocab``.
        """
        for doc_id, doc in enumerate(self.docs):
            # dict.fromkeys: visit each distinct token once, in first-seen order.
            for word in dict.fromkeys(jieba.cut_for_search(doc[3])):
                self.cache.setdefault(word, set()).add(doc_id)
                self.vocab.add(word)

    def lower_preprocess(self):
        """Append ``(title + ' ' + content).lower()`` to each doc, as ``doc[3]``.

        The doc lists are mutated in place; ``build_cache`` indexes ``doc[3]``.
        """
        for doc in self.docs:
            doc.append((doc[1] + ' ' + doc[2]).lower())

    def simple_test(self):
        """Smoke test: searching 'tiktok' should return more than one document."""
        assert(len(self.search('tiktok')) > 1)

In [2]:
# Build the searcher and time it. The constructor loads (or first creates via
# Spider) the pickled news list, then builds the inverted index.
%time searcher_v0 = MySearcherC8V0()

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\10633\AppData\Local\Temp\jieba.cache
Loading model cost 0.857 seconds.
Prefix dict has been built successfully.


Wall time: 3.56 s


In [3]:
def highlight(item, query: str, side_len: int = 12, tokenize=None) -> str:
    """Build a result snippet: the title, ``'<br/>'``, then content excerpts.

    Each distinct, non-blank token of the lowercased query is located
    (first occurrence only, case-insensitively) in the content ``item[2]``.
    A window of ``side_len`` characters on each side of every hit is cut out,
    and hits closer than ``2 * side_len`` share one window. An excerpt that
    does not reach the start/end of the content gets ``'...'`` on that side.
    Tokens that do not occur in the content (e.g. title-only matches) are
    ignored. If no token occurs, the first ``2 * side_len`` characters are
    shown as a preview instead.

    Args:
        item: document record; ``item[1]`` is read as the title and
            ``item[2]`` as the content.
        query: raw query string.
        side_len: characters of context kept on each side of a hit.
        tokenize: optional callable that splits a string into tokens;
            defaults to ``jieba.cut``.

    Returns:
        str: ``title + '<br/>' + excerpts``.

    NOTE(review): the title and content are inserted without HTML escaping;
    if this string is rendered as HTML, consider ``html.escape`` on them.
    """
    if tokenize is None:
        tokenize = jieba.cut
    title, content = item[1], item[2]
    # Indices found in the lowercased text are used to slice the original.
    # This assumes lower() preserves length (true for almost all text; a few
    # characters such as 'İ' expand when lowercased).
    content_lower = content.lower()
    content_len = len(content_lower)

    hits = set()
    for token in tokenize(query.lower()):
        if not token.strip():
            continue  # skip blank tokens such as the spaces between words
        idx = content_lower.find(token.lower())
        if idx >= 0:  # -1 means "not in content" and must not become a window
            hits.add(idx)
    positions = sorted(hits)

    if not positions:
        # Nothing to anchor on: show a leading preview of the content.
        end_pos = min(side_len * 2, content_len)
        end_ellipsis = '...' if end_pos < content_len else ''
        return title + '<br/>' + content[:end_pos] + end_ellipsis

    segments = []
    i = 0
    while i < len(positions):
        start_pos = max(positions[i] - side_len, 0)
        # Absorb following hits close enough to share this window.
        while i + 1 < len(positions) and positions[i + 1] - positions[i] < side_len * 2:
            i += 1
        end_pos = min(positions[i] + side_len, content_len)
        start_ellipsis = '...' if start_pos > 0 else ''
        end_ellipsis = '...' if end_pos < content_len else ''
        segments.append(start_ellipsis + content[start_pos:end_pos] + end_ellipsis)
        i += 1
    return title + '<br/>' + ''.join(segments)

In [None]:
class MySearcherC8V1(MySearcherC8V0):
    """Next iteration of the searcher; for now it inherits MySearcherC8V0 unchanged."""