In [71]:
from IPython.core.display import display, HTML
import bisect
from collections import defaultdict
import jieba
import jieba_fast
import pickle
import math
import os

In [72]:
class MySearch():
    """Boolean keyword search over a list of documents, ranked with BM25.

    Documents are tokenized with jieba into an inverted index
    (token -> set of document ids). A query such as
    ``"a AND (b or c) not d"`` is translated into a Python set expression,
    evaluated against the index, and the matching documents are ordered by
    their BM25 score for the query terms.

    Attributes
    ----------
    docs : list of str
        Documents as read from the data file; the list position is the
        document id ("tid").
    docs_lower : list of str
        Lower-cased copy of ``docs``; indexing and scoring use this copy.
    search_cache : dict of str -> set of int
        Inverted index mapping a token to the ids of documents containing it.
    doc_count : int
        Corpus size N used by the BM25 idf term (set by ``BM25init``).
    avgdl : float
        Average document length used by BM25 (set by ``BM25init``).
    df : dict of str -> int
        Document frequency per token (set by ``BM25init``).
    stop_word_set : set of str
        Stop-word list. NOTE(review): nothing in this class reads it yet;
        it looks intended for filtering scoring terms -- confirm.

    Methods
    -------
    load_data(filename):
        load documents (``*txt``) or a pickled index (``*dat``) from file.
    save_data(filename):
        pickle documents and index to file.
    pre_cache_all():
        build the inverted index for all documents.
    BM25init(doc_list):
        compute BM25 statistics from ``doc_list``.
    BM25score(doc, query, k1=1.5, b=0.75):
        BM25 score of one document for an iterable of query terms.
    search(query):
        ids of all documents matching a boolean query, best score first.
    render(result_list, keyword):
        display the search results as HTML.
    query_to_set_expression(query):
        convert a boolean query to a set expression (for ``eval``).
    get_word_match(word):
        ids of documents containing every token of ``word``.
    """

    # Query tokens that act as operators, mapped to the Python set operator
    # they become. Both cases are listed because query_to_set_expression can
    # be called directly with a query that was not lower-cased.
    _OPERATORS = {
        'and': '&', 'AND': '&', '+': '&',
        'or': '|', 'OR': '|',
        'not': '-', 'NOT': '-', '-': '-',
    }

    def __init__(self, filename, userdict="dict.txt"):
        """Create a searcher and load its data.

        Parameters
        ----------
        filename : str
            Data file handed to ``load_data``.
        userdict : str or None, optional
            Path of a jieba user dictionary to load before indexing.
            Defaults to ``"dict.txt"`` (the previously hard-coded path);
            pass ``None`` to skip loading one.
        """
        self.docs = []
        self.docs_lower = []
        self.search_cache = defaultdict(set)
        self.doc_count = 0
        self.avgdl = 0
        self.df = defaultdict(int)
        self.stop_word_set = set([' ', '$', '——', '-', '.', ',', '/', '，', '；',
                                  '？', '！', '。', '、', '：', '“', '”', ';',
                                  '@', '是', '让', '了', '的', '啊', '吧'])
        if userdict:
            jieba.load_userdict(userdict)
        self.load_data(filename)

    def load_data(self, filename):
        """Load documents from ``filename``.

        A name ending in ``txt`` is read as UTF-8 text with one document per
        line, and the index is rebuilt. A name ending in ``dat`` is read as a
        pickle of ``(docs, docs_lower, search_cache)`` as written by
        ``save_data``. The extension check is case-insensitive.

        Raises
        ------
        ValueError
            If the name ends in neither ``txt`` nor ``dat``. Previously this
            case was silently ignored and left an empty searcher.
        """
        suffix = filename[-3:].lower()
        if suffix == 'txt':
            with open(filename, 'r', encoding='utf-8') as f:
                self.docs = f.read().split('\n')
            self.docs_lower = [doc.lower() for doc in self.docs]
            self.pre_cache_all()
        elif suffix == 'dat':
            # SECURITY: unpickling can execute arbitrary code. Only load
            # .dat files that this class wrote itself.
            with open(filename, 'rb') as f:
                self.docs, self.docs_lower, self.search_cache = pickle.load(f)
        else:
            raise ValueError(
                "unsupported data file {!r}: expected a name ending in "
                "'txt' or 'dat'".format(filename))

    def save_data(self, filename):
        """Pickle ``(docs, docs_lower, search_cache)`` to ``filename``.

        The file can be read back by ``load_data`` provided its name ends in
        ``dat``.
        """
        with open(filename, 'wb') as f:
            pickle.dump((self.docs, self.docs_lower, self.search_cache), f)

    def pre_cache_all(self):
        """Rebuild the inverted index from ``docs_lower``.

        The index is built into a fresh mapping so that reloading data does
        not keep postings that belong to previously loaded documents.
        """
        index = defaultdict(set)
        for tid, doc in enumerate(self.docs_lower):
            for word in jieba.cut_for_search(doc):
                index[word].add(tid)
        self.search_cache = index

    def query_to_set_expression(self, query):
        """Translate a boolean query into a Python set expression.

        ``and``/``AND``/``+`` become ``&``, ``or``/``OR`` become ``|`` and
        ``not``/``NOT``/``-`` become ``-`` (set difference, so ``not`` needs
        a left operand: ``a not b``). Parentheses are kept. Consecutive
        non-operator words are joined with a space into one phrase and
        emitted as ``self.get_word_match(<phrase>)``.

        The resulting string follows Python operator precedence:
        ``-`` binds tighter than ``&``, which binds tighter than ``|``.

        Parameters
        ----------
        query : str
            The boolean query.

        Returns
        -------
        str
            Expression text; empty when the query holds no tokens.
        """
        tokens = query.replace('(', ' ( ').replace(')', ' ) ').split()
        parts = []
        phrase = []

        def flush_phrase():
            if phrase:
                # repr() yields a correctly escaped string literal, so quotes
                # or backslashes typed in the query cannot break out of the
                # literal in the evaluated expression.
                parts.append("self.get_word_match({!r})".format(' '.join(phrase)))
                del phrase[:]

        for token in tokens:
            if token in ('(', ')'):
                flush_phrase()
                parts.append(token)
            elif token in self._OPERATORS:
                flush_phrase()
                parts.append(self._OPERATORS[token])
            else:
                phrase.append(token)
        flush_phrase()
        return ''.join(parts)

    def get_word_match(self, word):
        """Return the ids of documents containing every token of ``word``.

        ``word`` is tokenized with ``jieba.cut``; whitespace-only tokens are
        skipped so that a multi-word phrase means "all of these words".
        The lookup is case-sensitive against an index built from lower-cased
        documents, so pass lower-cased text (``search`` does).

        Returns
        -------
        set of int
            Always a new set (never ``None`` and never the index's own set),
            empty when any token is unknown or ``word`` has no tokens.
        """
        postings = []
        for term in jieba.cut(word):
            if not term.strip():
                continue
            # .get() instead of [] so unknown terms do not grow the index.
            ids = self.search_cache.get(term)
            if not ids:
                return set()
            postings.append(ids)
        if not postings:
            return set()
        return set(postings[0]).intersection(*postings[1:])

    def BM25init(self, doc_list):
        """Compute the BM25 statistics used by ``BM25score``.

        ``doc_count`` is the size of the whole corpus, while ``df`` and
        ``avgdl`` are computed from ``doc_list``. Statistics from a previous
        call are discarded, so repeated searches do not influence each other.

        Parameters
        ----------
        doc_list : iterable of str
            Documents to derive document frequencies and the average
            length from.
        """
        self.doc_count = len(self.docs_lower)   # N
        self.df = defaultdict(int)
        length_sum = 0
        seen = 0
        for doc in doc_list:
            seen += 1
            length_sum += len(doc)
            for word in set(jieba.cut(doc.lower())):
                self.df[word] += 1  # n
        # Average over the documents actually summed; 0 when there are none.
        self.avgdl = length_sum / seen if seen else 0

    def BM25score(self,doc,query,k1 = 1.5,b=0.75):
        """BM25 score of ``doc`` for the terms in ``query``.

        Term frequency is the number of non-overlapping occurrences of the
        term in ``doc`` (substring count); document length is ``len(doc)``.

        Parameters
        ----------
        doc : str
            Document text.
        query : iterable of str
            Query terms.
        k1, b : float, optional
            BM25 saturation and length-normalization parameters.

        Returns
        -------
        float or int
            Sum of the per-term scores; 0 when no term occurs in ``doc``.
        """
        dl = len(doc)
        # Without statistics (avgdl == 0) fall back to no length penalty
        # instead of dividing by zero.
        length_ratio = dl / self.avgdl if self.avgdl else 0
        norm = k1 * (1 - b + b * length_ratio)
        result = 0
        for word in query:
            if not word:
                continue
            f = doc.count(word)
            if not f:
                continue  # an absent term contributes nothing
            n = self.df.get(word, 0)  # .get(): scoring must not mutate df
            idf = math.log10((self.doc_count - n + 0.5) / (n + 0.5) + 1)
            result += f * (k1 + 1) / (f + norm) * idf
        return result

    def search(self, query):
        """Return the ids of documents matching ``query``, best score first.

        The query is lower-cased, translated by ``query_to_set_expression``
        and evaluated against the index. BM25 statistics are recomputed from
        the matching documents, and each match is scored against the query's
        tokens (operators and parentheses removed). Ties keep ascending
        document-id order.

        Parameters
        ----------
        query : str
            Boolean query, e.g. ``"a AND (b or c)"``.

        Returns
        -------
        list of int
            Document ids; empty for an empty query or when nothing matches.

        Raises
        ------
        SyntaxError, TypeError
            If the query is not a well-formed expression (for example a
            trailing operator, or ``not`` without a left operand).
        """
        query_lower = query.lower()
        expression = self.query_to_set_expression(query_lower)
        if not expression:
            return []
        # The expression contains only set operators, parentheses and
        # self.get_word_match(<string literal>) calls; evaluate it with no
        # builtins and only `self` in scope.
        matched = eval(expression, {"__builtins__": {}}, {"self": self})
        match_tid_list = sorted(matched)
        query_fresh = set(jieba.cut(query_lower)) - set(
            ['(', ')', 'and', 'or', 'not', '-', ' ', '', 'AND', '+', 'OR', 'NOT'])
        matched_docs = [self.docs_lower[tid] for tid in match_tid_list]
        self.BM25init(matched_docs)
        result_list = [(tid, self.BM25score(doc, query_fresh))
                       for tid, doc in zip(match_tid_list, matched_docs)]
        result_list.sort(key=lambda x: x[1], reverse=True)
        return [doc_id for doc_id, _ in result_list]

    def render(self, result_list, keyword):
        """Print the hit count and display the first 100 characters of each hit.

        ``$$$`` in a document is shown as an HTML line break. ``keyword`` is
        accepted for interface compatibility but is not used.
        NOTE(review): document text is inserted into HTML unescaped --
        confirm the data source is trusted.

        Parameters
        ----------
        result_list : list of int
            Document ids, e.g. as returned by ``search``.
        keyword : str
            The query that produced ``result_list`` (unused).
        """
        print("共检索出",len(result_list),"条数据","\n","--------------分割线-------------")
        for count, item in enumerate(result_list, start=1):
            result = self.docs[item].replace('$$$', '<br/>')
            display(HTML("{}、{}......".format(count, result[:100])))

## 运行内容

In [73]:
%%time
# Build the searcher: the constructor loads the jieba user dictionary "dict.txt",
# reads titles.txt (one document per line) and builds the inverted index.
searcher=MySearch('titles.txt')

Wall time: 779 ms


In [104]:
%%time
# Boolean query: documents containing 新能源 and at least one of 特斯拉 / 小鹏.
# search() lower-cases the query, so operator case ('AND' vs 'or') does not matter.
query='新能源 AND (特斯拉 or 小鹏) '
# search_result is a list of document ids ordered by BM25 score, best first.
search_result=searcher.search(query)

Wall time: 70.5 ms


In [105]:
%%time
# Print the hit count and display the first 100 characters of each hit as HTML.
# render() does not use its second argument; the query is passed to match its signature.
searcher.render(search_result,query)

共检索出 4 条数据 
 --------------分割线-------------


Wall time: 19.1 ms
