In [131]:
from IPython.core.display import display, HTML
import bisect
from collections import defaultdict
import jieba
import pickle
from math import sqrt, log
from collections import defaultdict    
    
class BM25():
    """Okapi BM25 scorer built from a list of raw document strings.

    Attributes:
        doc_count: number of documents passed to the constructor.
        avgdl: average document length in characters (``len(doc)``),
            or 0 when no documents were given.
        df: mapping term -> number of documents whose jieba segmentation
            contains that term.
    """

    # Tokens that are boolean-query syntax rather than search terms.
    _QUERY_SYNTAX = frozenset(
        ['(', ')', 'and', 'AND', '+', 'or', 'OR', 'NOT', 'not', '-', ' ', ''])

    def __init__(self, doc_list):
        """Collect document frequencies and the average document length.

        Args:
            doc_list: list of document strings; may be empty.
        """
        self.doc_count = len(doc_list)
        self.avgdl = 0
        self.df = defaultdict(int)
        for doc in doc_list:
            # set(): a term counts once per document, however often it occurs.
            for word in set(jieba.cut(doc)):
                self.df[word] += 1
            self.avgdl += len(doc)
        # An empty corpus used to raise ZeroDivisionError here.
        if self.doc_count:
            self.avgdl /= self.doc_count

    def score(self, q, doc):
        """Return the BM25 score of ``doc`` for query ``q``.

        Args:
            q: query string; it is lower-cased and segmented with jieba,
                and boolean-syntax tokens are dropped.
            doc: document string; it is lower-cased and segmented with jieba.

        Returns:
            Sum over the distinct query terms of idf * saturated term
            frequency (k1 = 1.5, b = 0.75).
        """
        k1 = 1.5
        b = 0.75
        query_new = set(jieba.cut(q.lower())) - self._QUERY_SYNTAX
        word_list_doc = list(jieba.cut(doc.lower()))
        # Length is measured in characters, matching how avgdl is computed.
        dl = len(doc)
        if self.avgdl:
            length_norm = k1 * (1 - b + b * dl / self.avgdl)
        else:
            # No usable average length: skip the length correction.
            length_norm = k1 * (1 - b)
        result = 0
        for keyword in query_new:
            f = word_list_doc.count(keyword)
            # .get() so that scoring does not insert new keys into df.
            df = self.df.get(keyword, 0)
            idf = log((self.doc_count - df + 0.5) / (df + 0.5) + 1)
            result += idf * ((f * (k1 + 1)) / (f + length_norm))
        return result
    

class MySearchC6V0():
    """Boolean-query search engine with coarse (cosine) and fine (BM25) ranking.

    Version history
    ---------------
    C3V0: Base class for Search Engine.
    C3V1: Data multiplication added.
    C3V2: Sorting optimization.
    C3V3: Add lowered version of docs.
    C3V4: For long doc.
    C3V5: Caching search results.
    C3V6: Pre-caching all words in docs.
    C3V7: Add Serialize/UnSerialize.
    C4V1: Add basic Bool query support
    C4V2: Add wordseg to get_word_match()
    ----------------C5V0-----------------
    C5V1: Use VSMTFIDF.score() as score
    C5V2: Use BM25.score() as score
    C5V3: Use MiddleRank -> BM25
    ----------------C6V0-----------------

    Attributes
    ----------
    docs : list of str
        documents as read from the data file
    docs_lower : list of str
        lower-cased copy of ``docs`` used for matching
    doc_word_dict : list of dict
        per-document relative term frequencies, indexed by document id
    search_cache : defaultdict(set)
        inverted index, term -> set of document ids
    multi_factor : int
        data multiplication factor(default 1)

    Methods
    -------
    load_data(filename):
        load data from file.
    save_data(filename):
        save data to file
    pre_cache_all():
        Pre-caching all words in docs.
    highlight(text, keyword, ori_text):
        highlight keyword inside ori_text.
    score(text, keyword):
        count occurrences of keyword in text.
    get_word_match(word):
        get doc set containing every jieba term of word.
    search(query, num=15):
        get top num search results of a query.
    render(result_list, keyword):
        output search results with highlight.
    query_to_set_expression(query):
        convert bool query to set expression(for eval process).
    mid_score(query, tid):
        get middle-rank score of doc(tid) according to query
    cosine(vec1, vec2):
        get cosine similarity between vec1 and vec2
    dot(vec1, vec2):
        get element-wise products of vec1 and vec2
    """

    # Tokens that are boolean-query syntax rather than search terms.
    _QUERY_SYNTAX = ('(', ')', 'and', 'AND', '+', 'or', 'OR', 'NOT', 'not', '+', '-', ' ', '')

    def __init__(self, filename, multi_factor=1):
        """Create the engine and load (or build) its index from ``filename``."""
        self.docs = []
        self.docs_lower = []
        self.doc_word_dict = []  # per-document term -> relative frequency
        self.search_cache = defaultdict(set)
        self.multi_factor = multi_factor
        self.load_data(filename)

    def highlight(self, text, keyword, ori_text):
        """Wrap occurrences of ``keyword`` in ``ori_text`` in a red <span>.

        Args:
            text: lower-cased version of ``ori_text``, used to locate the keyword.
            keyword: lower-cased keyword to look for.
            ori_text: original-case text that is returned.

        Returns:
            ``ori_text`` with the matched spelling highlighted, or ``ori_text``
            unchanged when the keyword is empty or absent.
        """
        idx = text.find(keyword) if keyword else -1
        if idx < 0:
            # Previously the lower-cased ``text`` was returned here, so
            # non-matching results were displayed without their original case.
            return ori_text
        # NOTE(review): assumes lower-casing did not change character offsets.
        ori_keyword = ori_text[idx:idx + len(keyword)]
        return ori_text.replace(ori_keyword, f'<span style="color:red">{ori_keyword}</span>')

    def score(self, text, keyword):
        """Return how many times ``keyword`` occurs in ``text``."""
        return text.count(keyword)

    def query_to_set_expression(self, query):
        """Translate a boolean query into a Python set expression string.

        Brackets are kept, and/+ become ``&``, or becomes ``|``, not/- becomes
        ``-``; every run of consecutive plain words becomes one
        ``self.get_word_match(...)`` call.

        Args:
            query: boolean query string.

        Returns:
            Source text meant to be passed to ``eval`` inside a method
            (it refers to ``self``).
        """
        query_new_parts = []
        all_parts = query.replace('(', ' ( ').replace(')', ' ) ').split()
        last_idx = len(all_parts) - 1
        pending = []  # consecutive plain words forming one phrase
        for idx, part in enumerate(all_parts):
            if part in ('(', ')'):
                query_new_parts.append(part)
            elif part in ('and', 'AND', '+'):
                query_new_parts.append('&')
            elif part in ('or', 'OR'):
                query_new_parts.append('|')
            elif part in ('not', 'NOT', '-'):
                query_new_parts.append('-')
            else:
                pending.append(part)
                if idx == last_idx or all_parts[idx + 1] in self._QUERY_SYNTAX:
                    phrase = ' '.join(pending)
                    # repr() yields a correctly escaped literal, so quotes or
                    # backslashes in the query cannot break out of the string
                    # in the eval'd expression.
                    query_new_parts.append(f"self.get_word_match({phrase!r})")
                    pending = []
        return ''.join(query_new_parts)

    def get_word_match(self, word):
        """Return ids of documents containing every jieba term of ``word``.

        Args:
            word: word or phrase; it is segmented with ``jieba.cut``.

        Returns:
            Set of document ids; empty when ``word`` yields no terms.
        """
        result = None
        for term in jieba.cut(word):
            # .get() avoids inserting empty postings for unknown terms.
            postings = self.search_cache.get(term, set())
            result = postings if result is None else result & postings
            if not result:
                break
        return set() if result is None else result

    def search(self, query, num=15):
        """Return the ids of the top ``num`` documents for a boolean query.

        Args:
            query: boolean query string (case-insensitive).
            num: maximum number of ids to return.

        Returns:
            List of document ids ordered by descending BM25 score; empty when
            the query is blank or matches nothing.
        """
        query_lower = query.lower()
        # Candidate selection: evaluate the boolean query on the inverted index.
        query_set_expression = self.query_to_set_expression(query_lower)
        if not query_set_expression.strip():
            return []
        # SECURITY: the expression is derived from user input and run with
        # eval(); it is only safe as long as query_to_set_expression() emits
        # nothing but operators, brackets and escaped string literals.
        match_tid_list = list(eval(query_set_expression))
        if not match_tid_list:
            # BM25 cannot be built from an empty corpus.
            return []

        query_new = ' '.join(set(jieba.cut(query_lower)) - set(self._QUERY_SYNTAX))

        # Coarse ranking: cosine similarity on term frequencies.
        mid_tid_list = [(tid, self.mid_score(query_new, tid)) for tid in match_tid_list]
        mid_tid_list.sort(key=lambda x: x[1], reverse=True)

        # Fine ranking with BM25.
        # NOTE(review): statistics come from the top num + 5 candidates only,
        # but every candidate is scored - confirm whether scoring was meant to
        # be limited to the same slice.
        bm25_model = BM25([self.docs_lower[tid] for tid, _ in mid_tid_list[:num + 5]])
        result_list = [(tid, bm25_model.score(query_new, self.docs_lower[tid])) for tid, _ in mid_tid_list]
        result_list.sort(key=lambda x: x[1], reverse=True)

        return [doc_id for doc_id, _ in result_list[:num]]

    def pre_cache_all(self):
        """Build the inverted index and the per-document term frequencies."""
        for tid, doc in enumerate(self.docs_lower):
            doc_tf_dict = defaultdict(int)
            doc_word_count = 0
            for word in jieba.cut_for_search(doc):
                self.search_cache[word].add(tid)
                doc_tf_dict[word] += 1
                doc_word_count += 1
            # Turn raw counts into relative frequencies.
            for word in doc_tf_dict:
                doc_tf_dict[word] /= doc_word_count
            self.doc_word_dict.append(doc_tf_dict)

    def render(self, result_list, keyword):
        """Display each result as numbered HTML, cut to 150 characters.

        Args:
            result_list: document ids to show, in display order.
            keyword: text to highlight (compared case-insensitively).
        """
        for count, item in enumerate(result_list, 1):
            result = self.highlight(
                self.docs_lower[item],
                keyword.lower(),
                self.docs[item]
            ).replace('$$$', '<br/>')  # '$$$' is turned into a line break
            display(HTML(f"{count}、{result[:150]}......"))

    def mid_score(self, query, tid):
        """Return the cosine similarity between the query and document ``tid``.

        Args:
            query: whitespace-separated query terms.
            tid: document id.

        Returns:
            Cosine of the all-ones query vector and the document's relative
            term frequencies over the query vocabulary.
        """
        vocabulary = list(set(query.split()))
        q_vec = [1] * len(vocabulary)
        doc_tf = self.doc_word_dict[tid]
        # .get() so that scoring does not add zero entries to the stored dict.
        d_vec = [doc_tf.get(word, 0) for word in vocabulary]
        return self.cosine(q_vec, d_vec)

    def dot(self, vec1, vec2):
        """Return the list of element-wise products (callers sum it).

        Both vectors are expected to have the same length.
        """
        return [a * b for a, b in zip(vec1, vec2)]

    def cosine(self, vec1, vec2):
        """Return the cosine similarity of two vectors, 0.0 if either is all-zero."""
        norm = sqrt(sum(self.dot(vec1, vec1))) * sqrt(sum(self.dot(vec2, vec2)))
        if norm == 0:
            # Happens when a document shares no term with the query vocabulary.
            return 0.0
        return sum(self.dot(vec1, vec2)) / norm

    def load_data(self, filename):
        """Load documents from a ``txt`` file or a full index from a ``dat`` pickle.

        For ``txt`` the file is split on newlines, repeated ``multi_factor``
        times and indexed. Other extensions are ignored.

        Args:
            filename: path ending in ``txt`` or ``dat``.
        """
        if filename[-3:] == 'txt':
            # NOTE(review): no explicit encoding, so the platform default is used.
            with open(filename, 'r') as f:
                self.docs = f.read().split('\n')
            self.docs_lower = [doc.lower() for doc in self.docs]
            self.docs = self.docs * self.multi_factor
            self.docs_lower = self.docs_lower * self.multi_factor
            self.pre_cache_all()
        elif filename[-3:] == 'dat':
            # SECURITY: pickle.load can execute arbitrary code; only load
            # files produced by save_data() from a trusted source.
            with open(filename, 'rb') as f:
                self.docs, self.docs_lower, self.search_cache, self.doc_word_dict = pickle.load(f)

    def save_data(self, filename):
        """Pickle the documents and both index structures to ``filename``."""
        with open(filename, 'wb') as f:
            pickle.dump((self.docs, self.docs_lower, self.search_cache, self.doc_word_dict), f)
            

In [132]:
# Build the C6V0 engine from the plain-text corpus (multi_factor=1: no duplication).
searcher = MySearchC6V0('titles_l.txt', 1)

In [134]:
# Postings for the literal term '0-3'; the recorded output below is an empty set.
searcher.search_cache['0-3']

set()

In [133]:
# Search for the score-like token '0-3' and render up to 10 hits.
query = '0-3'
search_result = searcher.search(query, num=10)
searcher.render(search_result, query)

In [138]:
# Show how jieba's search-mode segmentation splits a hyphenated number
# (recorded output below: '11 | - | 29').
' | '.join(jieba.cut_for_search('11-29'))

'11 | - | 29'

### 对中英/数混排文档分词粒度的思考   
### 解决方案：  
#### 对中英文/数字片段分别处理  

In [139]:
import string

def parse_doc(doc):
    """Split a document into terms, keeping ASCII-like runs in one piece.

    A run of ASCII letters, digits and the characters ``- : . ： /`` is kept
    as a single term (for example a score such as ``0-5`` or a URL). Any
    other run is segmented with ``jieba.cut_for_search``. Every space ends
    the current run and is emitted as its own ``' '`` term.

    Args:
        doc: document text.

    Returns:
        List of terms in document order.
    """
    tokens = []
    pending = ''
    # None: no run is open; True: ASCII-like run; False: any other run.
    prev_kind = None

    def flush(keep_whole):
        # Emit the pending run either verbatim or segmented by jieba.
        if keep_whole:
            tokens.append(pending)
        else:
            tokens.extend(list(jieba.cut_for_search(pending)))

    for ch in doc:
        if ch == ' ':
            flush(prev_kind)
            tokens.append(' ')
            pending = ''
            prev_kind = None
            continue
        kind = (ch in string.ascii_letters
                or ch.isdigit()
                or ch in ('-', ':', '.', '：', '/'))
        if kind == prev_kind:
            pending += ch
        else:
            # The character class changed: close the open run, start a new one.
            if prev_kind is not None:
                flush(prev_kind)
            pending = ch
        prev_kind = kind

    if pending:
        flush(prev_kind)
    return tokens

In [140]:
# Sample document (title, body and URL separated by '$$$') used to inspect
# the terms produced by parse_doc(); the recorded output follows.
doc = '''德国杯-爆冷!拜仁0-5惨遭门兴血洗 后防灾难级表现$$$北京时间10月28日凌晨2：45，2021-2022赛季德国杯第二轮继续进行，拜仁慕尼黑客场对阵门兴格拉德巴赫，上半场，科内闪电破门，本塞拜尼连入两球，拜仁0-3落后。下半场，恩博洛梅开二度。最终，拜仁客场0-5惨败于门兴，德国杯惨遭淘汰！$$$https://www.163.com/sports/article/GNCKODUB00058781.html'''
' | '.join(parse_doc(doc))

'德国 | 德国杯 | - | 爆冷 | ! | 拜仁 | 0-5 | 惨遭 | 门兴 | 血洗 |   | 后防 | 灾难 | 级 | 表现 | $ | $ | $ | 北京 | 时间 | 10 | 月 | 28 | 日 | 凌晨 | 2：45 | ， | 2021-2022 | 赛季 | 德国 | 德国杯 | 第二 | 二轮 | 第二轮 | 继续 | 进行 | ， | 拜仁 | 慕尼黑 | 客场 | 对阵 | 门兴格 | 拉 | 德巴 | 巴赫 | 德巴赫 | ， | 上半 | 半场 | 上半场 | ， | 科内 | 闪电 | 破门 | ， | 本塞拜尼 | 连入 | 两球 | ， | 拜仁 | 0-3 | 落后 | 。 | 下半 | 半场 | 下半场 | ， | 恩博洛 | 二度 | 梅开二度 | 。 | 最终 | ， | 拜仁 | 客场 | 0-5 | 惨败 | 于门兴 | ， | 德国 | 德国杯 | 惨遭 | 淘汰 | ！ | $ | $ | $ | https://www.163.com/sports/article/GNCKODUB00058781.html'

In [141]:
def get_char_type(c):
    """Return the type code (e, c, s, f, b) of a single character.

    Args:
        c: the character to classify.

    Returns:
        'e' for ASCII letters, digits and the characters ``- : . ： /``;
        'f' for a double quote; 's' for a space; 'b' for a round bracket;
        'c' for everything else (e.g. Chinese text).
    """
    if (c in string.ascii_letters
            or c.isdigit()
            or c in ('-', ':', '.', '：', '/')):
        return 'e'
    special = {'"': 'f', ' ': 's', '(': 'b', ')': 'b'}
    return special.get(c, 'c')

def parse_query(doc):
    """Split a query into type-tagged parts.

    The query is lower-cased and cut wherever the character type reported by
    ``get_char_type`` changes. Runs of spaces collapse into one ``'s '``
    part, and a double-quoted span becomes one ``'f'`` part holding the text
    between the quotes.

    Args:
        doc: the raw query string.

    Returns:
        List of parts; the first character of each part is its type tag
        (e: ASCII/number run, c: other text, s: space, f: quoted text,
        b: bracket) and the rest is the text. Empty list for an empty query.
    """
    doc = doc.lower()
    if not doc:
        # The scan below needs at least one real character; an empty query
        # used to fail with an unbound loop variable.
        return []
    doc += ' '  # sentinel so that the final run is closed by a type change
    result = []
    doclen = len(doc)
    i = 0
    while i < doclen - 1:
        cur_char_type = get_char_type(doc[i])
        # Advance j to the end of the current run (at most to the sentinel).
        j = i + 1
        if cur_char_type == 'f':
            # Inside quotes: run until the closing quote.
            while j < doclen - 1 and get_char_type(doc[j]) != 'f':
                j += 1
        else:
            while j < doclen - 1 and get_char_type(doc[j]) == cur_char_type:
                j += 1
        if cur_char_type == 's':
            result.append('s ')  # consecutive spaces collapse into one part
        elif cur_char_type == 'f':
            result.append(cur_char_type + doc[i + 1:j])  # text between the quotes
            j += 1  # skip the closing quote
        else:
            result.append(cur_char_type + doc[i:j])
        i = j
    return result

# Example: mixed Chinese / number / operator query; the tagged parts are shown below.
parse_query("拜仁慕尼黑0-5 AND 德国杯")

['c拜仁慕尼黑', 'e0-5', 's ', 'eand', 's ', 'c德国杯']

In [142]:
def conv_part(part):
    """Convert one type-tagged query part into source text for ``eval``.

    Args:
        part: a part produced by ``parse_query``; its first character is the
            type tag and the rest is the text.

    Returns:
        For 'e' (ASCII/number run) and 'f' (quoted text): a single
        ``self.get_word_match(...)`` call on the whole text.
        For 'c' (other text): a bracketed ``&`` chain with one call per
        ``jieba.cut`` term.

    Raises:
        ValueError: for any other tag; previously None was returned, which
            ended up as the text ``None`` inside the expression.
    """
    flag, body = part[:1], part[1:]
    # {!r} produces a correctly escaped string literal, so quotes or
    # backslashes in the query cannot alter the generated expression.
    if flag in ('e', 'f'):
        return "self.get_word_match({!r})".format(body)
    if flag == 'c':
        terms = list(jieba.cut(body)) or ['']
        return "({})".format(
            " & ".join("self.get_word_match({!r})".format(term) for term in terms))
    raise ValueError("cannot convert query part {!r}".format(part))

def query_to_set_expression(query):
    """Translate a boolean query into a Python set expression string.

    The query is split with ``parse_query``. Brackets are copied, and/+
    become ``&``, or becomes ``|``, not/- becomes ``-``, and every other part
    is converted with ``conv_part``. Content parts that follow each other
    directly or across a single space are joined with ``&``.

    Args:
        query: the raw boolean query.

    Returns:
        Source text meant to be evaluated with ``eval`` in a method scope
        (it refers to ``self``).
    """
    operators = {'and': '&', 'AND': '&', '+': '&',
                 'or': '|', 'OR': '|',
                 'not': '-', 'NOT': '-', '-': '-'}
    query_new_parts = []
    all_parts = parse_query(query)
    count_parts = len(all_parts)
    idx = 0
    while idx < count_parts:
        part = all_parts[idx]
        flag, body = part[:1], part[1:]
        if flag == 'b':
            # Emit only the bracket characters; the type tag used to be
            # copied as well ("b("), which made the expression invalid.
            query_new_parts.append(body)
        elif body == ' ' or body == '':
            query_new_parts.append(' ')
        elif body in operators:
            query_new_parts.append(operators[body])
        else:
            nxt = all_parts[idx + 1][1:] if idx + 1 < count_parts else None
            nxt2 = all_parts[idx + 2][1:] if idx + 2 < count_parts else None
            converted = conv_part(part)
            if (nxt is not None
                    and not nxt.startswith((' ', ')'))
                    and nxt not in operators):
                # Directly adjacent content: implicit AND.
                query_new_parts.append("{} & ".format(converted))
            elif (nxt == ' '
                    and nxt2 is not None
                    and not nxt2.startswith((')', ' '))
                    and nxt2 not in operators):
                # Content separated by a single space: implicit AND, and the
                # space part is skipped.
                query_new_parts.append("{} & ".format(converted))
                idx += 2
                continue
            else:
                query_new_parts.append(converted)
        idx += 1
    return ''.join(query_new_parts)

# Example: the generated set expression for the query is shown below.
query_to_set_expression("拜仁慕尼黑0-5 AND 德国杯")

"(self.get_word_match('拜仁') & self.get_word_match('慕尼黑')) & self.get_word_match('0-5') & (self.get_word_match('德国杯'))"

In [143]:
# Remove the module-level prototype functions from the namespace; the
# MySearchC6V1 class defined afterwards carries methods with the same names.
del conv_part
del query_to_set_expression
del parse_doc
del get_char_type
del parse_query

In [144]:
class MySearchC6V1(MySearchC6V0):
    """Search engine that indexes whole English/number strings as single terms.

    Version history
    ---------------
    C3V0: Base class for Search Engine.
    C3V1: Data multiplication added.
    C3V2: Sorting optimization.
    C3V3: Add lowered version of docs.
    C3V4: For long doc.
    C3V5: Caching search results.
    C3V6: Pre-caching all words in docs.
    C3V7: Add Serialize/UnSerialize.
    C4V1: Add basic Bool query support
    C4V2: Add wordseg to get_word_match()
    ----------------C5V0-----------------
    C5V1: Use VSMTFIDF.score() as score
    C5V2: Use BM25.score() as score
    C5V3: Use MiddleRank -> BM25
    ----------------C6V0-----------------
    C6V1: Add parse_doc to include whole Eng/Num string

    Attributes
    ----------
    filename : str
        file name of doc data
    multi_factor : int
        data multiplication factor(default 1)

    Methods added or overridden here
    --------------------------------
    parse_doc(doc):
        parse doc into terms, including whole Eng/Num string
    pre_cache_all():
        build the index from parse_doc() terms
    get_char_type(c):
        get type of character c
    parse_query(q):
        parse query q to Cn/En/Num parts with prefix
    conv_part(part):
        convert part to set call
    query_to_set_expression(query):
        convert bool query to set expression(for eval process).
    get_term_match(term):
        get docs matching term exactly
    """

    # Query tokens that map to set operators.
    _SET_OPERATORS = {'and': '&', 'AND': '&', '+': '&',
                      'or': '|', 'OR': '|',
                      'not': '-', 'NOT': '-', '-': '-'}

    def parse_doc(self, doc):
        """Split a document into terms, keeping ASCII-like runs in one piece.

        A run of characters that ``get_char_type`` classifies as 'e' is kept
        as one term; any other run is segmented with
        ``jieba.cut_for_search``. Every space ends the current run and is
        emitted as its own ``' '`` term.

        Args:
            doc: document text.

        Returns:
            List of terms in document order.
        """
        tokens = []
        pending = ''
        # None: no run is open; True: ASCII-like run; False: any other run.
        prev_kind = None

        def flush(keep_whole):
            if keep_whole:
                tokens.append(pending)
            else:
                tokens.extend(jieba.cut_for_search(pending))

        for ch in doc:
            if ch == ' ':
                flush(prev_kind)
                tokens.append(' ')
                pending = ''
                prev_kind = None
                continue
            kind = self.get_char_type(ch) == 'e'
            if kind == prev_kind:
                pending += ch
            else:
                if prev_kind is not None:
                    flush(prev_kind)
                pending = ch
            prev_kind = kind

        if pending:
            flush(prev_kind)
        return tokens

    def pre_cache_all(self):
        """Build the inverted index and term frequencies from parse_doc() terms."""
        for tid, doc in enumerate(self.docs_lower):
            doc_tf_dict = defaultdict(int)
            doc_word_count = 0
            for word in self.parse_doc(doc):
                self.search_cache[word].add(tid)
                doc_tf_dict[word] += 1
                doc_word_count += 1
            # Turn raw counts into relative frequencies.
            for word in doc_tf_dict:
                doc_tf_dict[word] /= doc_word_count
            self.doc_word_dict.append(doc_tf_dict)

    def get_char_type(self, c):
        """Return the type code (e, c, s, f, b) of a single character.

        Args:
            c: the character to classify.

        Returns:
            'e' for ASCII letters, digits and the characters ``- : . ： /``;
            'f' for a double quote; 's' for a space; 'b' for a round
            bracket; 'c' for everything else (e.g. Chinese text).
        """
        if (c in string.ascii_letters
                or c.isdigit()
                or c in ('-', ':', '.', '：', '/')):
            return 'e'
        if c == '"':
            return 'f'
        if c == ' ':
            return 's'
        if c in ('(', ')'):
            return 'b'
        return 'c'

    def parse_query(self, doc):
        """Split a query into type-tagged parts.

        The query is lower-cased and cut wherever the character type
        changes. Runs of spaces collapse into one ``'s '`` part, and a
        double-quoted span becomes one ``'f'`` part holding the text between
        the quotes.

        Args:
            doc: the raw query string.

        Returns:
            List of parts; the first character of each part is its type tag
            (e, c, s, f or b) and the rest is the text. Empty list for an
            empty query.
        """
        doc = doc.lower()
        if not doc:
            # An empty query used to fail with an unbound loop variable.
            return []
        doc += ' '  # sentinel so that the final run is closed by a type change
        result = []
        doclen = len(doc)
        i = 0
        while i < doclen - 1:
            cur_char_type = self.get_char_type(doc[i])
            # Advance j to the end of the current run (at most to the sentinel).
            j = i + 1
            if cur_char_type == 'f':
                # Inside quotes: run until the closing quote.
                while j < doclen - 1 and self.get_char_type(doc[j]) != 'f':
                    j += 1
            else:
                while j < doclen - 1 and self.get_char_type(doc[j]) == cur_char_type:
                    j += 1
            if cur_char_type == 's':
                result.append('s ')  # consecutive spaces collapse into one part
            elif cur_char_type == 'f':
                result.append(cur_char_type + doc[i + 1:j])  # text between the quotes
                j += 1  # skip the closing quote
            else:
                result.append(cur_char_type + doc[i:j])
            i = j
        return result

    def conv_part(self, part):
        """Convert one type-tagged query part into source text for ``eval``.

        Args:
            part: a part produced by ``parse_query``.

        Returns:
            For 'e': one ``self.get_term_match(...)`` call on the whole run.
            For 'c': a bracketed ``&`` chain, one call per ``jieba.cut`` term.
            For 'f': a bracketed ``&`` chain, one call per non-blank
            ``parse_doc`` term of the quoted text (the same segmentation the
            index was built with).

        Raises:
            ValueError: for any other tag; previously None was returned,
                which ended up as the text ``None`` inside the expression.
        """
        flag, body = part[:1], part[1:]
        if flag == 'e':
            terms = [body]
        elif flag == 'c':
            terms = list(jieba.cut(body)) or ['']
        elif flag == 'f':
            terms = [term for term in self.parse_doc(body) if term.strip()] or ['']
        else:
            raise ValueError("cannot convert query part {!r}".format(part))
        # {!r} produces a correctly escaped string literal, so quotes or
        # backslashes in the query cannot alter the generated expression.
        chain = " & ".join("self.get_term_match({!r})".format(term) for term in terms)
        return chain if flag == 'e' else "({})".format(chain)

    def query_to_set_expression(self, query):
        """Translate a boolean query into a Python set expression string.

        Brackets are copied, and/+ become ``&``, or becomes ``|``, not/-
        becomes ``-``, and every other part is converted with ``conv_part``.
        Content parts that follow each other directly or across a single
        space are joined with ``&``.

        Args:
            query: the raw boolean query.

        Returns:
            Source text meant to be evaluated with ``eval`` in a method
            scope (it refers to ``self``).
        """
        operators = self._SET_OPERATORS
        query_new_parts = []
        all_parts = self.parse_query(query)
        count_parts = len(all_parts)
        idx = 0
        while idx < count_parts:
            part = all_parts[idx]
            flag, body = part[:1], part[1:]
            if flag == 'b':
                # Emit only the bracket characters; the type tag used to be
                # copied as well ("b("), which made the expression invalid.
                query_new_parts.append(body)
            elif body == ' ' or body == '':
                query_new_parts.append(' ')
            elif body in operators:
                query_new_parts.append(operators[body])
            else:
                nxt = all_parts[idx + 1][1:] if idx + 1 < count_parts else None
                nxt2 = all_parts[idx + 2][1:] if idx + 2 < count_parts else None
                converted = self.conv_part(part)
                if (nxt is not None
                        and not nxt.startswith((' ', ')'))
                        and nxt not in operators):
                    # Directly adjacent content: implicit AND.
                    query_new_parts.append("{} & ".format(converted))
                elif (nxt == ' '
                        and nxt2 is not None
                        and not nxt2.startswith((')', ' '))
                        and nxt2 not in operators):
                    # Content separated by a single space: implicit AND, and
                    # the space part is skipped.
                    query_new_parts.append("{} & ".format(converted))
                    idx += 2
                    continue
                else:
                    query_new_parts.append(converted)
            idx += 1
        return ''.join(query_new_parts)

    def get_term_match(self, term):
        """Return the ids of documents indexed under exactly ``term``."""
        return self.search_cache.get(term, set())

In [145]:
# Rebuild the index with the C6V1 engine (whole English/number strings kept as terms).
searcher = MySearchC6V1('titles_l.txt', 1)

In [146]:
# Boolean query mixing Chinese terms, a score token and an AND operator.
query = '拜仁0-5 AND 德国杯'
search_result = searcher.search(query, num=10)
searcher.render(search_result, query)

In [147]:
# Inspect the set expression generated for the query above (recorded output below).
searcher.query_to_set_expression(query)

"(self.get_term_match('拜仁')) & self.get_term_match('0-5') & (self.get_term_match('德国杯'))"

In [None]:
# Unquoted phrase: the words 'or' and 'not' are read as boolean operators here.
query = 'to be or not to be'
search_result = searcher.search(query, num=10)
searcher.render(search_result, query)

### 对特殊短语的查询需求，比如“0-3”，“to be or not to be” —— 临近查询 
#### - 采用特殊索引结构

In [148]:
class MySearchC6V2(MySearchC6V1):
    """Search engine with an additional character-bigram inverted index.

    Version history
    ---------------
    C3V0: Base class for Search Engine.
    C3V1: Data multiplication added.
    C3V2: Sorting optimization.
    C3V3: Add lowered version of docs.
    C3V4: For long doc.
    C3V5: Caching search results.
    C3V6: Pre-caching all words in docs.
    C3V7: Add Serialize/UnSerialize.
    C4V1: Add basic Bool query support
    C4V2: Add wordseg to get_word_match()
    ----------------C5V0-----------------
    C5V1: Use VSMTFIDF.score() as score
    C5V2: Use BM25.score() as score
    C5V3: Use MiddleRank -> BM25
    ----------------C6V0-----------------
    C6V1: Add parse_doc to include whole Eng/Num string
    C6V2: Add 2gram-Inverted-Index and get_frag_match()

    Attributes
    ----------
    filename : str
        file name of doc data
    multi_factor : int
        data multiplication factor(default 1)
    search_cache_b : defaultdict(set)
        bigram index, two-character string -> set of document ids

    Methods added or overridden here
    --------------------------------
    pre_cache_all():
        build the term index and the bigram index
    load_data(filename):
        load data from file and make sure the bigram index exists
    get_frag_match(frag):
        get docs matching frag
    """

    def __init__(self, filename, multi_factor=1):
        """Create the engine and load (or build) its indexes from ``filename``."""
        # The bigram index must exist before the base constructor runs,
        # because that constructor calls load_data(), which fills it.
        self.search_cache_b = defaultdict(set)  # bigram index
        super().__init__(filename, multi_factor)

    def _build_bigram_index(self):
        """Index every two-character substring of each lower-cased document."""
        for tid, doc in enumerate(self.docs_lower):
            for i in range(len(doc) - 1):
                self.search_cache_b[doc[i:i + 2]].add(tid)

    def pre_cache_all(self):
        """Build the term index (inherited) and then the bigram index."""
        super().pre_cache_all()
        self._build_bigram_index()

    def load_data(self, filename):
        """Load data as the parent class does and ensure the bigram index exists.

        The ``dat`` pickle written by ``save_data`` does not contain the
        bigram index, so after loading such a file the index is rebuilt from
        the loaded documents. Without this, ``get_frag_match`` found nothing
        on an engine restored from a ``dat`` file.

        Args:
            filename: path ending in ``txt`` or ``dat``.
        """
        super().load_data(filename)
        if filename[-3:] == 'dat':
            self._build_bigram_index()

    def get_frag_match(self, frag):
        """Find documents that may contain ``frag`` verbatim, via the bigram index.

        Args:
            frag: the string to search for as-is (case-insensitive).

        Returns:
            Set of ids of documents containing every character bigram of
            ``frag``. Bigram positions are not checked, so this can be a
            superset of the documents containing ``frag`` contiguously.
            For a single character, the documents containing it; for an
            empty string, an empty set.
        """
        frag = frag.lower()  # case normalization
        if not frag:
            return set()
        if len(frag) == 1:
            # A single character has no bigram; scan the documents instead
            # (None used to be returned here).
            return {tid for tid, doc in enumerate(self.docs_lower) if frag in doc}
        result = None
        for i in range(len(frag) - 1):
            postings = self.search_cache_b.get(frag[i:i + 2], set())
            # Copy on first use so the caller never receives the index's own set.
            result = set(postings) if result is None else result & postings
            if not result:
                break
        return result

In [149]:
# Rebuild with the C6V2 engine, which also builds the character-bigram index.
searcher = MySearchC6V2('titles_l.txt', 1)

In [150]:
# Phrase lookup through the bigram index (recorded output below: {1301}).
searcher.get_frag_match('to be or not to be')

{1301}

In [None]:
# Peek at the first entries of the NLTK English word list.
import nltk
print(nltk.corpus.words.words()[:10])

In [None]:
from collections import Counter
import nltk

class Corrector():
    """Spelling corrector based on a character-bigram inverted index.

    Attributes:
        index_b: bigram index, two-character string -> list of word ids.
        max_id: id that the next added word will receive.
        doc_list: indexed words in their original spelling, by id.
    """

    # How many of the best-overlapping words are examined per correction.
    _MAX_CANDIDATES = 300

    def __init__(self):
        """Build the index from the NLTK ``words`` corpus."""
        self.index_b = dict()  # bigram index
        self.max_id = 0
        self.doc_list = []

        for doc in nltk.corpus.words.words():
            self.add_doc(doc)

    def add_doc(self, doc):
        """Add one word to the bigram index.

        Args:
            doc: the word to index; it is stored as given and indexed
                lower-cased.

        Returns:
            The id assigned to the word.
        """
        doc_id = self.max_id
        self.doc_list.append(doc)
        doc = doc.lower()

        # One posting per bigram occurrence; a repeated bigram adds the id again.
        for i in range(len(doc) - 1):
            self.index_b.setdefault(doc[i:i + 2], []).append(doc_id)

        self.max_id += 1
        return doc_id

    def correct(self, word, limit=5):
        """Suggest indexed words that are likely corrections of ``word``.

        Candidates are ranked by the number of bigram postings they share
        with ``word``; only words whose length differs from it by at most
        one character are kept.

        Args:
            word: the word to correct (case-insensitive).
            limit: maximum number of suggestions, default 5.

        Returns:
            Up to ``limit`` suggestions, best overlap first.
        """
        word = word.lower()  # case normalization
        doclen = len(word)
        docid_counter = Counter()
        for i in range(doclen - 1):
            docid_counter.update(self.index_b.get(word[i:i + 2], []))

        result = []
        for doc_id, _ in docid_counter.most_common(self._MAX_CANDIDATES):
            # Checked before appending: the old "count > limit" test after
            # the append returned limit + 1 words.
            if len(result) >= limit:
                break
            cor_word = self.doc_list[doc_id]
            if doclen - 1 <= len(cor_word) <= doclen + 1:
                result.append(cor_word)
        return result

In [None]:
# Build the spelling corrector; the constructor indexes the NLTK word list.
cor = Corrector()

In [None]:
# Ask for corrections of a misspelled word.
print(cor.correct('retrival'))

In [None]:
# Dictionary-based maximum-matching segmentation demo: the same sentence is
# segmented left-to-right and right-to-left with a tiny dictionary.
doc = '华为Mate30采用安卓系统'

n = 6  # longest candidate word, in characters
word_set = {'华为', '安卓', '安卓系统'}

# Forward maximum matching: at each position take the longest dictionary
# word that starts there, or a single character when none matches.
result_f = []
i = 0
while True:
    end_idx = min(i + n, len(doc))
    j = end_idx
    while j > i + 1 and doc[i:j] not in word_set:
        j -= 1
    result_f.append(doc[i:j])
    i = j
    if i == len(doc):
        break
print('|'.join(result_f))

# Backward maximum matching: at each position take the longest dictionary
# word that ends there, or a single character when none matches.
result_b = []
i = len(doc)
while True:
    end_idx = max(i - n, 0)
    j = end_idx
    while j < i - 1 and doc[j:i] not in word_set:
        j += 1
    result_b.append(doc[j:i])
    i = j
    if i == 0:
        break
result_b.reverse()  # pieces were collected from the end of the sentence
print('|'.join(result_b))