In [1]:
import pandas as pd

from tqdm import tqdm
from collections import defaultdict

### 一、数据读入

In [2]:
# Load the CHIP-CDN training split. Each record pairs a raw diagnosis string (`text`)
# with its normalized labels (`normalized_result`), which are joined by '##'.
train_df = pd.read_json('../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CHIP-CDN/CHIP-CDN_train.json')

In [14]:
# Preview the first rows of the training data.
train_df.head()

Unnamed: 0,text,normalized_result
0,左膝退变伴游离体,膝骨关节病##膝关节游离体
1,糖尿病反复低血糖;骨质疏松;高血压冠心病不稳定心绞痛,糖尿病性低血糖症##骨质疏松##高血压##冠状动脉粥样硬化性心脏病##不稳定性心绞痛
2,右乳腺癌IV期,乳腺恶性肿瘤##癌
3,头痛.头晕.高血压,头痛##头晕##高血压
4,骶裂半大便失控,骶椎椎板裂##大便失禁


In [8]:
# Load the ICD-10 vocabulary sheet. The file has no header row (header=None),
# so the columns are given names in a separate step.
icd_df = pd.read_excel(
    '../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CHIP-CDN/国际疾病分类 ICD-10北京临床版v601.xlsx',
    header=None,
)

In [10]:
# Name the two header-less columns: the ICD code and the disease name.
icd_df.columns = ['code', 'name']

In [11]:
# Preview the first two rows of the ICD-10 table.
icd_df.head(2)

Unnamed: 0,code,name
0,A00,霍乱
1,A00.0,"霍乱,由于01群霍乱弧菌,霍乱生物型所致"


In [12]:
# Maps a surface string (an ICD name, a gold label, or a raw training text) to the
# set of normalized labels associated with it. Its keys become the BM25 corpus.
map_dict = defaultdict(set)

In [13]:
# Every ICD-10 disease name is registered as a key that maps to itself.
for icd_name in icd_df['name']:
    map_dict[icd_name].add(icd_name)

In [15]:
# Every gold label from the training set is registered as a key that maps to itself.
# Labels are visited in file order so that key insertion order stays unchanged.
for joined_labels in train_df['normalized_result']:
    for gold_label in joined_labels.split('##'):
        map_dict[gold_label].add(gold_label)

In [16]:
# Each raw training text maps to all of its gold labels.
for raw_text, joined_labels in zip(train_df['text'], train_df['normalized_result']):
    map_dict[raw_text].update(joined_labels.split('##'))

### 二、召回模型构建

In [17]:
import math
import copy
import logging
import numpy as np

from six import iteritems


# Module-level logger; BM25 uses it to warn when the average idf is negative.
logger = logging.getLogger(__name__)


class BM25(object):
    """
    BM25 ranking model over a fixed corpus.

    Args:
        corpus (:obj:`list`):
            Documents to index for retrieval.
        k1 (:obj:`float`, optional, defaults to 1.5):
            Positive tuning parameter that scales the contribution of the in-document term frequency.
        b (:obj:`float`, optional, defaults to 0.75):
            Parameter between 0 and 1 that controls document-length normalization: b=1 scales term weights fully by document length, b=0 ignores document length.
        epsilon (:obj:`float`, optional, defaults to 0.25):
            Factor applied to the average idf to obtain the value assigned to terms whose raw idf is negative.
        tokenizer (:obj:`object`, optional, defaults to None):
            Object exposing ``tokenize(text)`` used to split documents (and string queries passed to `recall`). When None, documents are split into single characters.
        is_retain_docs (:obj:`bool`, optional, defaults to False):
            Whether to keep a deep copy of the original documents so that `recall` returns documents instead of indices.

    Reference:
        [1] https://github.com/RaRe-Technologies/gensim/blob/3.8.3/gensim/summarization/bm25.py
    """  # noqa: ignore flake8"

    def __init__(
        self,
        corpus,
        k1=1.5,
        b=0.75,
        epsilon=0.25,
        tokenizer=None,
        is_retain_docs=False
    ):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon
        # Keep the tokenizer on the instance. The original code read `self.tokenizer`
        # without ever assigning it, so passing a tokenizer raised AttributeError.
        self.tokenizer = tokenizer

        self.docs = None
        self.corpus_size = 0
        self.avgdl = 0
        self.average_idf = 0.0
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []

        if is_retain_docs:
            self.docs = copy.deepcopy(corpus)

        if self.tokenizer:
            corpus = [self.tokenizer.tokenize(document) for document in corpus]
        else:
            # Default: character-level tokens.
            corpus = [list(document) for document in corpus]

        self._initialize(corpus)

    def _initialize(self, corpus):
        """Compute per-document term frequencies, document lengths and the idf table.

        Args:
            corpus: list of tokenized documents (each an iterable of tokens with a length).
        """
        nd = {}  # word -> number of documents containing the word
        num_doc = 0  # total number of tokens over all documents
        for document in corpus:
            self.corpus_size += 1
            self.doc_len.append(len(document))
            num_doc += len(document)

            frequencies = {}
            for word in document:
                frequencies[word] = frequencies.get(word, 0) + 1
            self.doc_freqs.append(frequencies)

            for word in frequencies:
                nd[word] = nd.get(word, 0) + 1

        # An empty corpus has nothing to index; leave the zeroed defaults in place
        # instead of dividing by zero below.
        if self.corpus_size == 0:
            return

        self.avgdl = float(num_doc) / self.corpus_size

        idf_sum = 0
        negative_idfs = []
        for word, freq in nd.items():
            idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)
        # The vocabulary is empty when every document is empty; avoid dividing by zero.
        self.average_idf = float(idf_sum) / len(self.idf) if self.idf else 0.0

        if self.average_idf < 0:
            logger.warning(
                'Average inverse document frequency is less than zero. Your corpus of {} documents'
                ' is either too small or it does not originate from natural text. BM25 may produce'
                ' unintuitive results.'.format(self.corpus_size)
            )

        # Terms occurring in more than half of the documents get a negative raw idf;
        # replace it with epsilon * average_idf.
        eps = self.epsilon * self.average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    def get_score(self, query, index):
        """Return the BM25 score of the document at `index` for `query`.

        Args:
            query: iterable of tokens (a plain string is iterated character by character).
            index: position of the document in the indexed corpus.

        Returns:
            float: accumulated score over the query tokens present in the document.
        """
        score = 0.0
        doc_freqs = self.doc_freqs[index]
        numerator_constant = self.k1 + 1
        # avgdl is 0 only when every indexed document is empty; no token can match then,
        # so the ratio value is irrelevant and we just avoid the division by zero.
        length_ratio = self.doc_len[index] / self.avgdl if self.avgdl else 0.0
        denominator_constant = self.k1 * (1 - self.b + self.b * length_ratio)
        for word in query:
            if word in doc_freqs:
                df = doc_freqs[word]
                idf = self.idf[word]
                score += (idf * df * numerator_constant) / (df + denominator_constant)
        return score

    def get_scores(self, query):
        """Return the BM25 score of every indexed document for `query`, in corpus order."""
        scores = [self.get_score(query, index) for index in range(self.corpus_size)]
        return scores

    def recall(self, query, topk=5):
        """Return the `topk` best-scoring documents for `query`.

        Args:
            query: raw text or an iterable of tokens. A string query is passed through
                the tokenizer when one was configured, so that it is split the same way
                as the indexed documents; otherwise it is used as given.
            topk: maximum number of results to return.

        Returns:
            list: ``[document, score]`` pairs when the documents were retained,
            otherwise ``[index, score]`` pairs, ordered by descending score.
        """
        # getattr keeps instances pickled before `tokenizer` became an attribute usable.
        tokenizer = getattr(self, 'tokenizer', None)
        if tokenizer and isinstance(query, str):
            query = tokenizer.tokenize(query)

        scores = self.get_scores(query)
        indexs = np.argsort(scores)[::-1][:topk]

        if self.docs is None:
            return [[i, scores[i]] for i in indexs]
        else:
            return [[self.docs[i], scores[i]] for i in indexs]


In [18]:
# Index every key of map_dict (ICD names, gold labels and raw training texts) as a
# BM25 document; the documents are retained so that recall() returns the strings.
bm25_model = BM25(list(map_dict), is_retain_docs=True)

### 三、召回率评估

In [19]:
# Load the CHIP-CDN dev split used to measure the recall of BM25 candidate retrieval.
dev_data_df = pd.read_json('../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CHIP-CDN/CHIP-CDN_dev.json')

a_label = []
new_train_data = []
miss_list = []       # [text, gold labels] pairs with at least one gold label not retrieved
recall_ = 0          # number of queries whose gold labels were all retrieved
query_counter = 0    # number of queries evaluated

for text_, normalized_result_ in tqdm(zip(dev_data_df['text'], dev_data_df['normalized_result'])):
    query_counter += 1

    # Expand the top-200 retrieved surface strings into the labels they map to.
    result = set()
    for _hit, _score in bm25_model.recall(text_, topk=200):
        result.update(map_dict[_hit])

    # A query counts as recalled only if every gold label is among the candidates.
    gold_labels = set(normalized_result_.split('##'))
    if gold_labels <= result:
        recall_ += 1
    else:
        miss_list.append([text_, normalized_result_])

print('召回率为： ', recall_/query_counter)

# Recall recorded for this run: 0.9135

2000it [02:28, 13.46it/s]

召回率为：  0.9135





### 四、模型保存

In [21]:
# !mkdir -p checkpoint/recall

In [22]:
import os
import pickle

# open(..., "wb") does not create missing parent directories, so make sure the
# output directory exists; otherwise a fresh run fails with FileNotFoundError.
os.makedirs('checkpoint/recall', exist_ok=True)

# Persist the fitted BM25 index.
# NOTE: unpickling it requires the BM25 class to be importable under the same
# module path it had when it was pickled.
with open('checkpoint/recall/bm25_model.pkl', "wb") as f:
    pickle.dump(bm25_model, f)

# Persist the surface-string -> label-set mapping used to expand retrieved documents.
with open('checkpoint/recall/map_dict.pkl', "wb") as f:
    pickle.dump(map_dict, f)