In [13]:
import requests  
from lxml import etree 
import pickle
import os
from IPython.core.display import display, HTML
import timeit
import jieba
import math

class MySearcherC12V0:
    """
    第十一次课升级的搜索类版本：
    改善文档频和文档长度加权的影响
    改善IDF权值
    采用BM25打分函数
    """
    def __init__(self, scale=1):
        self.docs = [] #所有文档原始数据
        self.load_data()
        self.docs *= scale #文档倍增
        self.cache = {} #缓存(索引)
        self.vocab = set() #缓存(索引)词表
        self.lower_preprocess()
        self.df = {}
        self.avgdl = 0
        self.build_cache() #构建缓存
        jieba.load_userdict('dict.txt')
    
    def build_cache(self):
        doc_id = 0
        doc_length_sum = 0
        for doc in self.docs:
            doc_word_set = set()
            doc_length_sum += len(doc[3])
            for word in jieba.cut_for_search(
                doc[3]
            ):
                if word not in doc_word_set:
                    result_item = doc_id
                    if word not in self.cache:
                        self.cache[word] = set([result_item])
                    else:
                        self.cache[word].add(result_item)
                    self.vocab.add(word)
                    doc_word_set.add(word)
                    
                    if word in self.df:
                        self.df[word] += 1
                    else:
                        self.df[word] = 1
            doc_id += 1
        self.avgdl = doc_length_sum / len(self.docs)
    
    def search(self, query):
        result = None
        for keyword in jieba.cut(query.lower()):
            if keyword in self.cache:
                if result is None:
                    result = self.cache[keyword]
                else:
                    result = result & self.cache[keyword]
            else:
                result = set([])
                break
                
        if result is None:
            result = set([])
        
        sorted_result = self.rank(query, result)
        return sorted_result
                    
    def lower_preprocess(self):
        for doc_id in range(len(self.docs)):
            self.docs[doc_id].append(
                (self.docs[doc_id][1] 
                 + ' ' 
                 + self.docs[doc_id][2]).lower()
            )
    
    def simple_test(self):
        assert(len(self.search('tiktok')) > 1)
    
    def load_data(self):
        data_filename = 'news_list.dat'
        if os.path.exists(data_filename):
            with open(data_filename,'rb') as f:
                self.docs += pickle.load(f)
#                 self.docs = self.docs + pickle.load(f)
        else:
            url = 'http://news.163.com/special/0001386F/rank_tech.html'  
            headers = {'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36 Edg/85.0.564.63'}
            r = requests.get(url, headers=headers)  
            sel = etree.HTML(r.text) 
            link_set = set()
            news_list = []
            count = 0
            for item in sel.xpath('//td/a'):  
                title = item.text
                link = item.attrib['href']
                if link not in link_set:
                    r = requests.get(link, headers=headers)  
                    sel = etree.HTML(r.text)  
                    text_block = sel.xpath('//div[@id="endText"]') 
                    if text_block:
                        content = ''.join(text_block[0].xpath('./p/text()'))
                        title = sel.xpath('//h1/text()')[0]
                        self.docs.append([link, title, content])
                    link_set.add(link)
                count += 1
                if count % 15 == 0:
                    print(count, 'processed.')
            with open(data_filename,'wb') as f:
                pickle.dump(self.docs, f)
    
    def highlight(self, item, query, sidelen=12):
        result = ''
        positions = []
        content_lower = item[2].lower()
        word_start_map = []
        word_end_map = []
        last_word_end = -1
        query_words = list(jieba.cut(query))
        for keyword in query_words:
            idx = content_lower.find(keyword.lower())
            positions.append(idx)

        for keyword in jieba.cut(content_lower):
            cur_word_start = last_word_end + 1
            cur_word_end = cur_word_start + len(keyword) - 1
            for i in range(cur_word_start, cur_word_end + 1):
                word_start_map.append(cur_word_start)
                word_end_map.append(cur_word_end)
            last_word_end = cur_word_end

        positions.sort()
        segments = []
        i = 0
        while i < len(positions):
            start_pos = max(positions[i] - sidelen, 0)
            end_pos = min(positions[i] + sidelen, len(content_lower) - 1)
            while (i < len(positions) - 1 
                and positions[i+1] - positions[i] <= 2 * sidelen):
                end_pos = min(positions[i+1] + sidelen, len(content_lower) - 1)
                i += 1  
            start_ddd = '...' if start_pos > 0 else ''
            end_ddd = '...' if end_pos < len(content_lower) else ''
            segments.append(start_ddd 
                            + item[2][word_start_map[start_pos]:word_end_map[end_pos] + 1]
                            + end_ddd)
            i += 1
        
        result = text = item[1] + '<br/>' + ''.join(segments)
        text_lower = text.lower()
        for keyword in query_words:
            idx = text_lower.find(keyword.lower())
            if idx >= 0:
                ori_word = text[idx:idx+(len(keyword))]
                result = result.replace(ori_word, '<span style="color:red";>{}</span>'.format(ori_word))
        return result
    
    def render_search_result(self, query):
        count = 0
        for item in self.search(query)[:10]:
            count += 1
            display(HTML('{} [{}] {}'.format(count, item[1], 
                self.highlight(self.docs[item[0]], query))))
    
    def rank(self, query, result_set):
        result = []
        for doc_id in result_set:
            result.append([doc_id, 
                self.score(self.docs[doc_id],
                          query)])
        result.sort(key=lambda x: x[1], reverse=True)
        return result       
    
    def score(self, item, query, k1=2, b=0.75):
        score = 0
        #todo cut
        for keyword in jieba.cut(query):
            f = item[2].lower().count(keyword.lower())
            dl = len(item[2])
            tf = f * (k1 + 1) / (f + k1 * (1 - b + b * (dl / self.avgdl)))
            idf = math.log10((len(self.docs) - self.df[keyword] + 0.5 / (0.5 + self.df[keyword])))
            score += tf * idf
        return score

In [14]:
searcherv0 = MySearcherC12V0()

In [15]:
searcherv0.render_search_result('华为基站')

In [11]:
q = '(华为 or 苹果) and 5g手机'

q_cut_parts = list(jieba.cut(q))

result_l = []
for part in q_cut_parts:
    if part == '(' or part == ')':
        result_l.append(part)
    elif part == 'and':
        result_l.append('&')
    elif part == 'or':
        result_l.append('|')
    elif part == 'not':
        result_l.append('-')
    elif part == ' ':
        pass
    else:
        result_l.append(str(searcherv0.cache[part]))
# print(result_l)
result = ''.join(result_l)
print(result)

({0, 3, 133, 134, 143, 146, 150, 23, 26, 158, 34, 165, 41, 50, 52, 53, 180, 57, 58, 185, 63, 193, 66, 77, 80, 83, 227, 231, 104, 108, 114, 122, 123}|{256, 1, 3, 4, 133, 6, 7, 134, 9, 265, 267, 12, 13, 141, 142, 16, 271, 23, 151, 279, 28, 29, 157, 32, 33, 34, 39, 168, 42, 170, 53, 182, 55, 56, 187, 60, 61, 190, 63, 69, 197, 73, 74, 81, 210, 83, 254, 215, 89, 220, 98, 231, 238, 113, 246, 247, 124, 126, 127})&{1, 3, 5, 9, 13, 15, 16, 22, 23, 28, 29, 31, 32, 33, 39, 42, 44, 47, 50, 53, 56, 57, 58, 61, 63, 65, 68, 70, 72, 73, 77, 78, 79, 80, 81, 82, 83, 88, 89, 94, 101, 104, 108, 109, 112, 114, 117, 120, 123, 124, 126, 133, 134, 141, 146, 151, 156, 168, 170, 190, 191, 197, 201, 207, 210, 212, 220, 228, 231, 232, 236, 238, 239, 247, 249, 250, 254, 263, 268, 271, 278, 279}


In [12]:
eval(result)

{1,
 3,
 9,
 13,
 16,
 23,
 28,
 29,
 32,
 33,
 39,
 42,
 50,
 53,
 56,
 57,
 58,
 61,
 63,
 73,
 77,
 80,
 81,
 83,
 89,
 104,
 108,
 114,
 123,
 124,
 126,
 133,
 134,
 141,
 146,
 151,
 168,
 170,
 190,
 197,
 210,
 220,
 231,
 238,
 247,
 254,
 271,
 279}

In [16]:
q = '(华为 or 苹果) and 5g手机'

In [19]:
!pip install pysnooper -i https://pypi.tuna.tsinghua.edu.cn/simple

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [29]:
import pysnooper

# @pysnooper.snoop()
def conv_query(query):
    """
    将布尔查询转化成集合运算表达式
    """
    query += ' '
    qlen = len(query)
    idx = 0
    result_parts = []
    cache = ''
    while idx < qlen:
        if query[idx] in ('(', ')', ' '):
            if cache != '':
                if cache == 'and' or cache == 'AND':
                    result_parts.append('&')
                elif cache == 'or' or cache == 'OR':
                    result_parts.append('|')
                elif cache == 'not' or cache == 'NOT':
                    result_parts.append('-')
                else:
                    result_parts.append(
                        'get_phrase_match("{}")'.format(cache)
                    )
                cache = ''
            result_parts.append(query[idx])
        else:
            cache += query[idx]
        idx += 1
#     if cache != '':
#         result_parts.append(cache)
    return ''.join(result_parts)

print(conv_query(q))

(get_phrase_match("华为") | get_phrase_match("苹果")) & get_phrase_match("5g手机") 


In [33]:
def get_phrase_match(phrase):
    result = {}
    wid = 0
    for word in jieba.cut(phrase):
        if word not in searcherv0.cache:
            result = {}
            break
        if wid == 0:
            result = searcherv0.cache.get(word, {})
        else:
            if len(result) == 0:
                break
            result = result & searcherv0.cache.get(word, {})
        wid += 1
    return result

print(get_phrase_match('5g手机'))

{3, 134, 141, 15, 16, 271, 23, 28, 29, 33, 44, 50, 53, 56, 58, 190, 191, 70, 72, 73, 77, 78, 79, 80, 81, 82, 83, 210, 88, 89, 94, 104, 108, 109, 114, 123, 254}


In [30]:
searcherv0

<__main__.MySearcherC12V0 at 0x23f61b8ab20>