In [1]:
import torch
import pickle
import pandas as pd

from ark_nlp.model.tm.bert import Bert
from ark_nlp.model.tm.bert import BertConfig
from ark_nlp.model.tm.bert import Dataset
from ark_nlp.model.tm.bert import Task
from ark_nlp.model.tm.bert import get_default_model_optimizer
from ark_nlp.model.tm.bert import Tokenizer

### 一、数据读入与处理

#### 1. 召回模型

In [3]:
import math
import copy
import logging
import numpy as np

from six import iteritems


# Module-level logger; BM25._initialize uses it to warn when the average idf is negative.
logger = logging.getLogger(__name__)


class BM25(object):
    """
    BM25 ranking model built over a fixed corpus.

    Args:
        corpus (:obj:`list`):
            Documents to index for retrieval.
        k1 (:obj:`float`, optional, defaults to 1.5):
            Positive tuning parameter that scales the term-frequency component.
        b (:obj:`float`, optional, defaults to 0.75):
            Parameter between 0 and 1 controlling document-length normalisation:
            b=1 fully scales term weights by document length, b=0 ignores length.
        epsilon (:obj:`float`, optional, defaults to 0.25):
            Multiplied by the average idf to obtain the floor value that replaces
            negative idf values.
        tokenizer (:obj:`object`, optional, defaults to None):
            Object exposing ``tokenize(document)``, used to split every document
            into terms. When None, documents are split into single characters.
        is_retain_docs (:obj:`bool`, optional, defaults to False):
            Whether to keep a deep copy of the original documents, in which case
            :meth:`recall` returns documents instead of corpus indices.

    Reference:
        [1] https://github.com/RaRe-Technologies/gensim/blob/3.8.3/gensim/summarization/bm25.py
    """  # noqa: ignore flake8"

    def __init__(
        self,
        corpus,
        k1=1.5,
        b=0.75,
        epsilon=0.25,
        tokenizer=None,
        is_retain_docs=False
    ):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon
        # Stored for reference only. Instances unpickled from older checkpoints
        # may not carry this attribute, so no other method reads it.
        self.tokenizer = tokenizer

        self.docs = None
        self.corpus_size = 0
        self.avgdl = 0
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []

        if is_retain_docs:
            # Deep copy so later mutation of the caller's corpus cannot change
            # what recall() returns.
            self.docs = copy.deepcopy(corpus)

        if tokenizer:
            # Use the constructor argument directly: the attribute lookup
            # `self.tokenizer` is not needed to tokenize the corpus.
            corpus = [tokenizer.tokenize(document) for document in corpus]
        else:
            # Character-level fallback.
            corpus = [list(document) for document in corpus]

        self._initialize(corpus)

    def _initialize(self, corpus):
        """Compute per-document term frequencies, document lengths and idf values.

        Args:
            corpus: Iterable of tokenized documents (each an iterable of terms).

        Populates ``corpus_size``, ``doc_len``, ``doc_freqs``, ``avgdl``, ``idf``
        and ``average_idf``. An empty corpus (or one made only of empty
        documents) yields ``avgdl == 0`` / ``average_idf == 0`` instead of
        raising ``ZeroDivisionError``.
        """
        nd = {}  # word -> number of documents containing the word
        total_terms = 0
        for document in corpus:
            self.corpus_size += 1
            self.doc_len.append(len(document))
            total_terms += len(document)

            frequencies = {}
            for word in document:
                frequencies[word] = frequencies.get(word, 0) + 1
            self.doc_freqs.append(frequencies)

            for word in frequencies:
                nd[word] = nd.get(word, 0) + 1

        self.avgdl = float(total_terms) / self.corpus_size if self.corpus_size else 0.0

        idf_sum = 0
        negative_idfs = []
        for word, freq in nd.items():
            idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)
        self.average_idf = float(idf_sum) / len(self.idf) if self.idf else 0.0

        if self.average_idf < 0:
            logger.warning(
                'Average inverse document frequency is less than zero. Your corpus of {} documents'
                ' is either too small or it does not originate from natural text. BM25 may produce'
                ' unintuitive results.'.format(self.corpus_size)
            )

        # Words occurring in more than half of the documents get a negative idf;
        # replace those values with a small floor derived from the average idf.
        eps = self.epsilon * self.average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    def get_score(self, query, index):
        """Return the BM25 score of ``query`` against the document at ``index``.

        Args:
            query: Iterable of terms. A plain string is iterated character by
                character.
            index (int): Position of the document in the indexed corpus.

        Returns:
            float: Sum of the per-term BM25 contributions; 0.0 when no query
            term occurs in the document.
        """
        score = 0.0
        doc_freqs = self.doc_freqs[index]
        numerator_constant = self.k1 + 1
        # avgdl is 0 only when every indexed document is empty; no term can
        # match then, so the length normalisation value is irrelevant.
        if self.avgdl:
            length_norm = self.b * self.doc_len[index] / self.avgdl
        else:
            length_norm = 0.0
        denominator_constant = self.k1 * (1 - self.b + length_norm)
        for word in query:
            if word in doc_freqs:
                df = doc_freqs[word]
                idf = self.idf[word]
                score += (idf * df * numerator_constant) / (df + denominator_constant)
        return score

    def get_scores(self, query):
        """Return the BM25 score of ``query`` against every indexed document, in corpus order."""
        return [self.get_score(query, index) for index in range(self.corpus_size)]

    def recall(self, query, topk=5):
        """Return the ``topk`` best-scoring documents for ``query``.

        Args:
            query: Iterable of terms (see :meth:`get_score`).
            topk (int, optional): Maximum number of results. Defaults to 5.

        Returns:
            list: ``[item, score]`` pairs sorted by descending score, where
            ``item`` is the original document when the model was built with
            ``is_retain_docs=True`` and the corpus index otherwise.
        """
        scores = self.get_scores(query)
        top_indices = np.argsort(scores)[::-1][:topk]

        if self.docs is None:
            return [[i, scores[i]] for i in top_indices]
        return [[self.docs[i], scores[i]] for i in top_indices]

# Load the recall-stage artefacts from disk.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load checkpoints that you created yourself or otherwise trust.
# Context managers make sure both file handles are closed once the data is read.
with open('checkpoint/recall/bm25_model.pkl', 'rb') as f:
    bm25_model = pickle.load(f)  # presumably a pickled BM25 instance — confirm against the recall notebook
with open('checkpoint/recall/map_dict.pkl', 'rb') as f:
    map_dict = pickle.load(f)  # indexed below by the first element of each bm25_model.recall() hit

#### 2. 数据生成

In [4]:
# Read the raw CHIP-CDN train / dev splits (JSON) into DataFrames.
_cdn_train_path = '../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CHIP-CDN/CHIP-CDN_train.json'
_cdn_dev_path = '../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CHIP-CDN/CHIP-CDN_dev.json'

train_data_df = pd.read_json(_cdn_train_path)
dev_data_df = pd.read_json(_cdn_dev_path)

In [5]:
# Build the training pairs [raw mention, candidate term, label]:
#   label '0' -> recalled candidate that is NOT one of the gold normalized terms
#   label '1' -> gold normalized term (each positive pair is repeated 10 times)
pair_dataset = []
for _raw_word, _normalized_result in zip(train_data_df['text'], train_data_df['normalized_result']):
    normalized_words = set(_normalized_result.split('##'))

    # Flatten the top-1000 BM25 hits into candidate terms via map_dict
    # (map_dict is indexed by the first element of every recall hit).
    _candidates = []
    for _hit in bm25_model.recall(_raw_word, topk=1000):
        _candidates.extend(map_dict[_hit[0]])

    _seen_negatives = set()
    _negatives = []
    for _candidate in _candidates:
        if _candidate in normalized_words or _candidate in _seen_negatives:
            continue
        _seen_negatives.add(_candidate)
        _negatives.append([_raw_word, _candidate, '0'])
        # NOTE(review): negatives are only kept once exactly 20 have been
        # collected; a mention with fewer distinct candidates contributes none.
        if len(_negatives) == 20:
            pair_dataset.extend(_negatives)
            break

    pair_dataset.extend(
        [_raw_word, _st_word, '1']
        for _st_word in normalized_words
        for _ in range(10)
    )

In [6]:
# Build the validation pairs [raw mention, candidate term, label] from the DEV
# split (dev_data_df), so that evaluation is not run on training mentions.
# Assumes the dev JSON has the same 'text' / 'normalized_result' columns as the
# train JSON — TODO confirm.
#   label '0' -> the first recalled candidate that is not a gold normalized term
#   label '1' -> every gold normalized term (once each)
pair_dev_dataset = []
for _raw_word, _normalized_result in zip(dev_data_df['text'], dev_data_df['normalized_result']):
    normalized_words = set(_normalized_result.split('##'))

    # Lazily flatten the BM25 hits into candidate terms; only one negative is
    # needed, so iteration stops at the first non-gold candidate.
    _candidates = (
        _result
        for _results in bm25_model.recall(_raw_word, topk=1000)
        for _result in map_dict[_results[0]]
    )
    for _search_word in _candidates:
        if _search_word not in normalized_words:
            pair_dev_dataset.append([_raw_word, _search_word, '0'])
            break

    for _st_word in normalized_words:
        pair_dev_dataset.append([_raw_word, _st_word, '1'])

In [7]:
# Turn the generated pair lists into DataFrames with the text_a / text_b / label columns.
# NOTE(review): this rebinds train_data_df / dev_data_df, which previously held the raw
# JSON frames; the pair-generation cells read those raw frames, so re-running them after
# this cell requires reloading the JSON first.
_pair_columns = ['text_a', 'text_b', 'label']
train_data_df = pd.DataFrame(pair_dataset, columns=_pair_columns)
dev_data_df = pd.DataFrame(pair_dev_dataset, columns=_pair_columns)

In [8]:
# Wrap the pair DataFrames in ark_nlp text-matching Dataset objects.
tm_train_dataset = Dataset(train_data_df)
tm_dev_dataset = Dataset(dev_data_df)

#### 2. 词典创建

In [9]:
# Tokenizer built from the 'nghuyong/ernie-1.0' vocabulary, with max_seq_len set to 50.
tokenizer = Tokenizer(vocab='nghuyong/ernie-1.0', max_seq_len=50)

#### 3. 生成分词器

#### 4. ID化

In [10]:
# Convert both datasets to id form with the shared tokenizer (train first, then dev).
for _tm_dataset in (tm_train_dataset, tm_dev_dataset):
    _tm_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [11]:
# One output label per category found in the training dataset.
num_labels = len(tm_train_dataset.cat2id)
config = BertConfig.from_pretrained('nghuyong/ernie-1.0', num_labels=num_labels)

#### 2. 模型创建

In [12]:
# Ask PyTorch to empty its CUDA cache before the model is instantiated.
torch.cuda.empty_cache()

In [13]:
# Instantiate the ark_nlp Bert module from the 'nghuyong/ernie-1.0' checkpoint using the config above.
# Per the loader warning printed below, the classifier weights are newly initialised.
dl_module = Bert.from_pretrained('nghuyong/ernie-1.0', 
                                 config=config)

Some weights of the model checkpoint at nghuyong/ernie-1.0 were not used when initializing Bert: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing Bert from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Bert from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Bert were not initialized from the model checkpoint at nghuyong/ernie-1.0 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [14]:
# Number of training epochs (passed to model.fit as `epochs`)
num_epoches = 2
# Mini-batch size (passed to model.fit as `batch_size`)
batch_size = 32

In [15]:
# Parameter-name fragments that are excluded from weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

# (name, parameter) pairs of the module, skipping every parameter whose name contains 'pooler'.
param_optimizer = [
    (_name, _param)
    for _name, _param in dl_module.named_parameters()
    if 'pooler' not in _name
]

# Split the remaining parameters into a decayed and a non-decayed group.
_decay_params, _no_decay_params = [], []
for _name, _param in param_optimizer:
    if any(_fragment in _name for _fragment in no_decay):
        _no_decay_params.append(_param)
    else:
        _decay_params.append(_param)

optimizer_grouped_parameters = [
    {'params': _decay_params, 'weight_decay': 0.01},
    {'params': _no_decay_params, 'weight_decay': 0.0},
]

#### 2. 任务创建

In [16]:
# Wrap dl_module in an ark_nlp Task; 'adamw' and 'lsce' are presumably the optimizer and loss identifiers — confirm against ark_nlp.
# NOTE(review): cuda_device=3 is hard-coded — confirm the target machine exposes a GPU with that index.
# ema_decay=0.995 presumably enables the weight EMA that is accessed through model.ema after training.
model = Task(dl_module, 'adamw', 'lsce', cuda_device=3, ema_decay=0.995)

#### 3. 训练

In [None]:
# Train on tm_train_dataset; tm_dev_dataset is passed as the second dataset (presumably for validation).
# `params` supplies the weight-decay parameter groups built in optimizer_grouped_parameters.
model.fit(tm_train_dataset, 
          tm_dev_dataset,
          lr=3e-5,
          epochs=num_epoches, 
          batch_size=batch_size,
          params=optimizer_grouped_parameters
         )

  1% 100/7161 [00:26<31:18,  3.76it/s]

[99/7161],train loss is:0.536521,train evaluation is:0.770625


  3% 200/7161 [00:53<31:20,  3.70it/s]

[199/7161],train loss is:0.504741,train evaluation is:0.796406


  4% 300/7161 [01:20<31:13,  3.66it/s]

[299/7161],train loss is:0.479518,train evaluation is:0.816771


  6% 400/7161 [01:48<31:09,  3.62it/s]

[399/7161],train loss is:0.464169,train evaluation is:0.829453


  7% 500/7161 [02:15<30:46,  3.61it/s]

[499/7161],train loss is:0.453538,train evaluation is:0.838688


  8% 600/7161 [02:43<30:17,  3.61it/s]

[599/7161],train loss is:0.444477,train evaluation is:0.845990


 10% 700/7161 [03:11<29:52,  3.60it/s]

[699/7161],train loss is:0.437269,train evaluation is:0.851295


 11% 800/7161 [03:38<29:33,  3.59it/s]

[799/7161],train loss is:0.430789,train evaluation is:0.855820


 13% 900/7161 [04:06<29:16,  3.57it/s]

[899/7161],train loss is:0.424376,train evaluation is:0.859688


 14% 1000/7161 [04:34<28:50,  3.56it/s]

[999/7161],train loss is:0.420098,train evaluation is:0.863500


 15% 1100/7161 [05:03<28:19,  3.57it/s]

[1099/7161],train loss is:0.414926,train evaluation is:0.867074


 17% 1200/7161 [05:31<28:01,  3.54it/s]

[1199/7161],train loss is:0.411834,train evaluation is:0.869115


 18% 1300/7161 [05:59<27:36,  3.54it/s]

[1299/7161],train loss is:0.408520,train evaluation is:0.871635


 20% 1400/7161 [06:27<27:09,  3.53it/s]

[1399/7161],train loss is:0.405696,train evaluation is:0.873817


 21% 1500/7161 [06:55<26:46,  3.52it/s]

[1499/7161],train loss is:0.401988,train evaluation is:0.876313


 22% 1600/7161 [07:24<26:10,  3.54it/s]

[1599/7161],train loss is:0.399567,train evaluation is:0.877988


 24% 1700/7161 [07:52<25:40,  3.55it/s]

[1699/7161],train loss is:0.396465,train evaluation is:0.880276


 25% 1800/7161 [08:20<25:07,  3.56it/s]

[1799/7161],train loss is:0.393469,train evaluation is:0.882083


 27% 1900/7161 [08:48<24:44,  3.54it/s]

[1899/7161],train loss is:0.391088,train evaluation is:0.883503


 28% 2000/7161 [09:17<24:06,  3.57it/s]

[1999/7161],train loss is:0.388921,train evaluation is:0.885047


 29% 2100/7161 [09:45<23:36,  3.57it/s]

[2099/7161],train loss is:0.386407,train evaluation is:0.886682


 31% 2200/7161 [10:13<23:13,  3.56it/s]

[2199/7161],train loss is:0.384940,train evaluation is:0.887727


 32% 2300/7161 [10:41<22:44,  3.56it/s]

[2299/7161],train loss is:0.383162,train evaluation is:0.888886


 34% 2400/7161 [11:09<22:18,  3.56it/s]

[2399/7161],train loss is:0.381412,train evaluation is:0.890026


 35% 2500/7161 [11:37<21:44,  3.57it/s]

[2499/7161],train loss is:0.379550,train evaluation is:0.891387


 36% 2600/7161 [12:05<21:22,  3.56it/s]

[2599/7161],train loss is:0.377604,train evaluation is:0.892692


 38% 2700/7161 [12:33<20:53,  3.56it/s]

[2699/7161],train loss is:0.375888,train evaluation is:0.893831


 39% 2800/7161 [13:01<20:33,  3.53it/s]

[2799/7161],train loss is:0.374172,train evaluation is:0.895033


 40% 2900/7161 [13:29<20:02,  3.54it/s]

[2899/7161],train loss is:0.372120,train evaluation is:0.896433


 42% 3000/7161 [13:57<19:28,  3.56it/s]

[2999/7161],train loss is:0.370925,train evaluation is:0.897260


 43% 3100/7161 [14:25<18:57,  3.57it/s]

[3099/7161],train loss is:0.369513,train evaluation is:0.898135


 45% 3200/7161 [14:53<18:27,  3.58it/s]

[3199/7161],train loss is:0.367987,train evaluation is:0.899121


 46% 3300/7161 [15:21<18:02,  3.57it/s]

[3299/7161],train loss is:0.366034,train evaluation is:0.900379


 47% 3400/7161 [15:49<17:34,  3.57it/s]

[3399/7161],train loss is:0.365053,train evaluation is:0.900993


 49% 3500/7161 [16:17<17:07,  3.56it/s]

[3499/7161],train loss is:0.363869,train evaluation is:0.901795


 50% 3600/7161 [16:45<16:39,  3.56it/s]

[3599/7161],train loss is:0.362318,train evaluation is:0.902778


 52% 3700/7161 [17:14<16:11,  3.56it/s]

[3699/7161],train loss is:0.360965,train evaluation is:0.903606


 53% 3800/7161 [17:42<15:37,  3.58it/s]

[3799/7161],train loss is:0.359614,train evaluation is:0.904457


 54% 3900/7161 [18:10<15:11,  3.58it/s]

[3899/7161],train loss is:0.358287,train evaluation is:0.905353


 56% 4000/7161 [18:38<14:45,  3.57it/s]

[3999/7161],train loss is:0.357015,train evaluation is:0.906203


 57% 4100/7161 [19:06<14:16,  3.57it/s]

[4099/7161],train loss is:0.355961,train evaluation is:0.906799


 59% 4200/7161 [19:34<13:48,  3.57it/s]

[4199/7161],train loss is:0.354777,train evaluation is:0.907612


 60% 4300/7161 [20:02<13:30,  3.53it/s]

[4299/7161],train loss is:0.353956,train evaluation is:0.908154


 61% 4400/7161 [20:30<12:58,  3.55it/s]

[4399/7161],train loss is:0.352900,train evaluation is:0.908849


 63% 4500/7161 [20:58<12:28,  3.56it/s]

[4499/7161],train loss is:0.351869,train evaluation is:0.909542


 64% 4600/7161 [21:26<12:00,  3.56it/s]

[4599/7161],train loss is:0.351026,train evaluation is:0.910217


 66% 4700/7161 [21:54<11:31,  3.56it/s]

[4699/7161],train loss is:0.350330,train evaluation is:0.910645


 67% 4800/7161 [22:22<10:59,  3.58it/s]

[4799/7161],train loss is:0.349470,train evaluation is:0.911243


 68% 4900/7161 [22:50<10:33,  3.57it/s]

[4899/7161],train loss is:0.348573,train evaluation is:0.911792


 70% 5000/7161 [23:18<10:05,  3.57it/s]

[4999/7161],train loss is:0.347706,train evaluation is:0.912369


 71% 5100/7161 [23:46<09:38,  3.56it/s]

[5099/7161],train loss is:0.346848,train evaluation is:0.912935


 73% 5200/7161 [24:14<09:09,  3.57it/s]

[5199/7161],train loss is:0.346106,train evaluation is:0.913456


 74% 5300/7161 [24:42<08:42,  3.56it/s]

[5299/7161],train loss is:0.345357,train evaluation is:0.913945


 75% 5400/7161 [25:10<08:14,  3.56it/s]

[5399/7161],train loss is:0.344600,train evaluation is:0.914421


 77% 5500/7161 [25:38<07:49,  3.54it/s]

[5499/7161],train loss is:0.343725,train evaluation is:0.915006


 78% 5600/7161 [26:07<07:17,  3.57it/s]

[5599/7161],train loss is:0.342855,train evaluation is:0.915614


 80% 5700/7161 [26:35<06:49,  3.57it/s]

[5699/7161],train loss is:0.342163,train evaluation is:0.916069


 80% 5746/7161 [26:47<06:36,  3.57it/s]

In [None]:
# Hand the module's current parameters to ema.store, then let ema.copy_to write into them —
# presumably backing up the raw weights and replacing them with the EMA-averaged ones; confirm against ark_nlp's EMA helper.
model.ema.store(model.module.parameters())
model.ema.copy_to(model.module.parameters())  

<br>

### 四、模型验证与保存

#### 1. 模型验证

In [None]:
from ark_nlp.factory.predictor import TMPredictor

In [None]:
# Build a text-matching predictor from the trained module, the tokenizer and the training label mapping.
tm_predictor_instance = TMPredictor(model.module, tokenizer, tm_train_dataset.cat2id)

In [None]:
# Smoke-test the predictor on a single [text_a, text_b] pair, requesting probabilities as well.
tm_predictor_instance.predict_one_sample(['胸部皮肤破裂伤', '胸部开放性损伤'], return_proba=True)

#### 2. 模型保存

In [None]:
# Create the output directory for the text-similarity checkpoint (-p: no error if it already exists).
!mkdir -p checkpoint/textsim

In [None]:
import pickle

In [None]:
# Save only the module's state_dict (not the whole Task object) to the checkpoint directory.
torch.save(model.module.state_dict(), 'checkpoint/textsim/module.pth')

In [None]:
# Persist the training label -> id mapping alongside the saved weights.
with open('checkpoint/textsim/cat2id.pkl', 'wb') as _cat2id_file:
    pickle.dump(tm_train_dataset.cat2id, _cat2id_file)