In [1]:
import torch
import pickle
import pandas as pd

from ark_nlp.model.tm.bert import Bert
from ark_nlp.model.tm.bert import BertConfig
from ark_nlp.model.tm.bert import Dataset
from ark_nlp.model.tm.bert import Task
from ark_nlp.model.tm.bert import get_default_model_optimizer
from ark_nlp.model.tm.bert import Tokenizer

### 一、数据读入与处理

#### 1. 召回模型

In [3]:
import math
import copy
import logging
import numpy as np

from six import iteritems


logger = logging.getLogger(__name__)


class BM25(object):
    """
    BM25模型

    Args:
        corpus (:obj:`list`):
            检索的语料
        k1 (:obj:`float`, optional, defaults to 1.5):
            取正值的调优参数，用于文档中的词项频率进行缩放控制
        b (:obj:`float`, optional, defaults to 0.75):
            0到1之间的参数，决定文档长度的缩放程度，b=1表示基于文档长度对词项权重进行完全的缩放，b=0表示归一化时不考虑文档长度因素
        epsilon (:obj:`float`, optional, defaults to 0.25):
            idf的下限值
        tokenizer (:obj:`object`, optional, defaults to None):
            分词器，用于对文档进行分词操作，默认为None，按字颗粒对文档进行分词
        is_retain_docs (:obj:`bool`, optional, defaults to False):
            是否保持原始文档

    Reference:
        [1] https://github.com/RaRe-Technologies/gensim/blob/3.8.3/gensim/summarization/bm25.py
    """  # noqa: ignore flake8"

    def __init__(
        self,
        corpus,
        k1=1.5,
        b=0.75,
        epsilon=0.25,
        tokenizer=None,
        is_retain_docs=False
    ):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon

        self.docs = None
        self.corpus_size = 0
        self.avgdl = 0
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []

        if is_retain_docs:
            self.docs = copy.deepcopy(corpus)

        if tokenizer:
            corpus = [self.tokenizer.tokenize(document) for document in corpus]
        else:
            corpus = [list(document) for document in corpus]

        self._initialize(corpus)

    def _initialize(self, corpus):
        """Calculates frequencies of terms in documents and in corpus. Also computes inverse document frequencies."""
        nd = {}  # word -> number of documents with word
        num_doc = 0
        for document in corpus:                        
            self.corpus_size += 1
            self.doc_len.append(len(document))
            num_doc += len(document)

            frequencies = {}
            for word in document:
                if word not in frequencies:
                    frequencies[word] = 0
                frequencies[word] += 1
            self.doc_freqs.append(frequencies)

            for word, freq in iteritems(frequencies):
                if word not in nd:
                    nd[word] = 0
                nd[word] += 1

        self.avgdl = float(num_doc) / self.corpus_size

        idf_sum = 0
        negative_idfs = []
        for word, freq in iteritems(nd):
            idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)
        self.average_idf = float(idf_sum) / len(self.idf)

        if self.average_idf < 0:
            logger.warning(
                'Average inverse document frequency is less than zero. Your corpus of {} documents'
                ' is either too small or it does not originate from natural text. BM25 may produce'
                ' unintuitive results.'.format(self.corpus_size)
            )

        eps = self.epsilon * self.average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    def get_score(self, query, index):
        score = 0.0
        doc_freqs = self.doc_freqs[index]
        numerator_constant = self.k1 + 1
        denominator_constant = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
        for word in query:
            if word in doc_freqs:
                df = self.doc_freqs[index][word]
                idf = self.idf[word]
                score += (idf * df * numerator_constant) / (df + denominator_constant)
        return score

    def get_scores(self, query):
        scores = [self.get_score(query, index) for index in range(self.corpus_size)]
        return scores

    def recall(self, query, topk=5):
        scores = self.get_scores(query)
        indexs = np.argsort(scores)[::-1][:topk]

        if self.docs is None:
            return [[i, scores[i]] for i in indexs]
        else:
            return [[self.docs[i], scores[i]] for i in indexs]

bm25_model = pickle.load(open('checkpoint/recall/bm25_model.pkl', 'rb'))
map_dict = pickle.load(open('checkpoint/recall/map_dict.pkl', 'rb'))

#### 2. 数据生成

In [4]:
train_data_df = pd.read_json('../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CHIP-CDN/CHIP-CDN_train.json')
dev_data_df = pd.read_json('../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CHIP-CDN/CHIP-CDN_dev.json')

In [5]:
pair_dataset = []
for _raw_word, _normalized_result in zip(train_data_df['text'], train_data_df['normalized_result']):
    normalized_words = set(_normalized_result.split('##'))
    search_result_ = set()
    train_pair_dataset = []
    for _index, _search_word in enumerate(
        [_result for _results in bm25_model.recall(_raw_word, topk=1000) for _result in map_dict[_results[0]]]):

        if _search_word in normalized_words:
            continue
        elif _search_word in search_result_:
            continue
        else:
            train_pair_dataset.append([_raw_word, _search_word, '0'])
            
        search_result_.add(_search_word)
            
        if len(train_pair_dataset) == 20:
            pair_dataset.extend(train_pair_dataset)
            break
                    
    for _st_word in normalized_words:
        for _ in range(10):
            pair_dataset.append([_raw_word, _st_word, '1'])

In [6]:
pair_dev_dataset = []
for _raw_word, _normalized_result in zip(train_data_df['text'], train_data_df['normalized_result']):
    normalized_words = set(_normalized_result.split('##'))
    search_result_ = set()
    dev_pair_dataset = []
    for _index, _search_word in enumerate(
        [_result for _results in bm25_model.recall(_raw_word, topk=1000) for _result in map_dict[_results[0]]]):

        if _search_word in normalized_words:
            continue
        elif _search_word in search_result_:
            continue
        else:
            dev_pair_dataset.append([_raw_word, _search_word, '0'])
            
        search_result_.add(_search_word)
            
        if len(dev_pair_dataset) == 1:
            pair_dev_dataset.extend(dev_pair_dataset)
            break
                    
    for _st_word in normalized_words:
        pair_dev_dataset.append([_raw_word, _st_word, '1'])

In [7]:
train_data_df = pd.DataFrame(pair_dataset, columns=['text_a', 'text_b', 'label'])
dev_data_df = pd.DataFrame(pair_dev_dataset, columns=['text_a', 'text_b', 'label'])

In [8]:
tm_train_dataset = Dataset(train_data_df)
tm_dev_dataset = Dataset(dev_data_df)

#### 2. 词典创建

In [9]:
tokenizer = Tokenizer(vocab='nghuyong/ernie-1.0', max_seq_len=50)

#### 3. 生成分词器

#### 4. ID化

In [10]:
tm_train_dataset.convert_to_ids(tokenizer)
tm_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [11]:
config = BertConfig.from_pretrained('nghuyong/ernie-1.0',
                                    num_labels=len(tm_train_dataset.cat2id))

#### 2. 模型创建

In [12]:
torch.cuda.empty_cache()

In [13]:
dl_module = Bert.from_pretrained('nghuyong/ernie-1.0', 
                                 config=config)

Some weights of the model checkpoint at nghuyong/ernie-1.0 were not used when initializing Bert: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing Bert from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Bert from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Bert were not initialized from the model checkpoint at nghuyong/ernie-1.0 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [14]:
# 设置运行次数
num_epoches = 2
batch_size = 32

In [15]:
param_optimizer = list(dl_module.named_parameters())
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]      

#### 2. 任务创建

In [16]:
model = Task(dl_module, 'adamw', 'lsce', cuda_device=3, ema_decay=0.995)

#### 3. 训练

In [17]:
model.fit(tm_train_dataset, 
          tm_dev_dataset,
          lr=3e-5,
          epochs=num_epoches, 
          batch_size=batch_size,
          params=optimizer_grouped_parameters
         )

  1% 100/7161 [00:26<31:18,  3.76it/s]

[99/7161],train loss is:0.536521,train evaluation is:0.770625


  3% 200/7161 [00:53<31:20,  3.70it/s]

[199/7161],train loss is:0.504741,train evaluation is:0.796406


  4% 300/7161 [01:20<31:13,  3.66it/s]

[299/7161],train loss is:0.479518,train evaluation is:0.816771


  6% 400/7161 [01:48<31:09,  3.62it/s]

[399/7161],train loss is:0.464169,train evaluation is:0.829453


  7% 500/7161 [02:15<30:46,  3.61it/s]

[499/7161],train loss is:0.453538,train evaluation is:0.838688


  8% 600/7161 [02:43<30:17,  3.61it/s]

[599/7161],train loss is:0.444477,train evaluation is:0.845990


 10% 700/7161 [03:11<29:52,  3.60it/s]

[699/7161],train loss is:0.437269,train evaluation is:0.851295


 11% 800/7161 [03:38<29:33,  3.59it/s]

[799/7161],train loss is:0.430789,train evaluation is:0.855820


 13% 900/7161 [04:06<29:16,  3.57it/s]

[899/7161],train loss is:0.424376,train evaluation is:0.859688


 14% 1000/7161 [04:34<28:50,  3.56it/s]

[999/7161],train loss is:0.420098,train evaluation is:0.863500


 15% 1100/7161 [05:03<28:19,  3.57it/s]

[1099/7161],train loss is:0.414926,train evaluation is:0.867074


 17% 1200/7161 [05:31<28:01,  3.54it/s]

[1199/7161],train loss is:0.411834,train evaluation is:0.869115


 18% 1300/7161 [05:59<27:36,  3.54it/s]

[1299/7161],train loss is:0.408520,train evaluation is:0.871635


 20% 1400/7161 [06:27<27:09,  3.53it/s]

[1399/7161],train loss is:0.405696,train evaluation is:0.873817


 21% 1500/7161 [06:55<26:46,  3.52it/s]

[1499/7161],train loss is:0.401988,train evaluation is:0.876313


 22% 1600/7161 [07:24<26:10,  3.54it/s]

[1599/7161],train loss is:0.399567,train evaluation is:0.877988


 24% 1700/7161 [07:52<25:40,  3.55it/s]

[1699/7161],train loss is:0.396465,train evaluation is:0.880276


 25% 1800/7161 [08:20<25:07,  3.56it/s]

[1799/7161],train loss is:0.393469,train evaluation is:0.882083


 27% 1900/7161 [08:48<24:44,  3.54it/s]

[1899/7161],train loss is:0.391088,train evaluation is:0.883503


 28% 2000/7161 [09:17<24:06,  3.57it/s]

[1999/7161],train loss is:0.388921,train evaluation is:0.885047


 29% 2100/7161 [09:45<23:36,  3.57it/s]

[2099/7161],train loss is:0.386407,train evaluation is:0.886682


 31% 2200/7161 [10:13<23:13,  3.56it/s]

[2199/7161],train loss is:0.384940,train evaluation is:0.887727


 32% 2300/7161 [10:41<22:44,  3.56it/s]

[2299/7161],train loss is:0.383162,train evaluation is:0.888886


 34% 2400/7161 [11:09<22:18,  3.56it/s]

[2399/7161],train loss is:0.381412,train evaluation is:0.890026


 35% 2500/7161 [11:37<21:44,  3.57it/s]

[2499/7161],train loss is:0.379550,train evaluation is:0.891387


 36% 2600/7161 [12:05<21:22,  3.56it/s]

[2599/7161],train loss is:0.377604,train evaluation is:0.892692


 38% 2700/7161 [12:33<20:53,  3.56it/s]

[2699/7161],train loss is:0.375888,train evaluation is:0.893831


 39% 2800/7161 [13:01<20:33,  3.53it/s]

[2799/7161],train loss is:0.374172,train evaluation is:0.895033


 40% 2900/7161 [13:29<20:02,  3.54it/s]

[2899/7161],train loss is:0.372120,train evaluation is:0.896433


 42% 3000/7161 [13:57<19:28,  3.56it/s]

[2999/7161],train loss is:0.370925,train evaluation is:0.897260


 43% 3100/7161 [14:25<18:57,  3.57it/s]

[3099/7161],train loss is:0.369513,train evaluation is:0.898135


 45% 3200/7161 [14:53<18:27,  3.58it/s]

[3199/7161],train loss is:0.367987,train evaluation is:0.899121


 46% 3300/7161 [15:21<18:02,  3.57it/s]

[3299/7161],train loss is:0.366034,train evaluation is:0.900379


 47% 3400/7161 [15:49<17:34,  3.57it/s]

[3399/7161],train loss is:0.365053,train evaluation is:0.900993


 49% 3500/7161 [16:17<17:07,  3.56it/s]

[3499/7161],train loss is:0.363869,train evaluation is:0.901795


 50% 3600/7161 [16:45<16:39,  3.56it/s]

[3599/7161],train loss is:0.362318,train evaluation is:0.902778


 52% 3700/7161 [17:14<16:11,  3.56it/s]

[3699/7161],train loss is:0.360965,train evaluation is:0.903606


 53% 3800/7161 [17:42<15:37,  3.58it/s]

[3799/7161],train loss is:0.359614,train evaluation is:0.904457


 54% 3900/7161 [18:10<15:11,  3.58it/s]

[3899/7161],train loss is:0.358287,train evaluation is:0.905353


 56% 4000/7161 [18:38<14:45,  3.57it/s]

[3999/7161],train loss is:0.357015,train evaluation is:0.906203


 57% 4100/7161 [19:06<14:16,  3.57it/s]

[4099/7161],train loss is:0.355961,train evaluation is:0.906799


 59% 4200/7161 [19:34<13:48,  3.57it/s]

[4199/7161],train loss is:0.354777,train evaluation is:0.907612


 60% 4300/7161 [20:02<13:30,  3.53it/s]

[4299/7161],train loss is:0.353956,train evaluation is:0.908154


 61% 4400/7161 [20:30<12:58,  3.55it/s]

[4399/7161],train loss is:0.352900,train evaluation is:0.908849


 63% 4500/7161 [20:58<12:28,  3.56it/s]

[4499/7161],train loss is:0.351869,train evaluation is:0.909542


 64% 4600/7161 [21:26<12:00,  3.56it/s]

[4599/7161],train loss is:0.351026,train evaluation is:0.910217


 66% 4700/7161 [21:54<11:31,  3.56it/s]

[4699/7161],train loss is:0.350330,train evaluation is:0.910645


 67% 4800/7161 [22:22<10:59,  3.58it/s]

[4799/7161],train loss is:0.349470,train evaluation is:0.911243


 68% 4900/7161 [22:50<10:33,  3.57it/s]

[4899/7161],train loss is:0.348573,train evaluation is:0.911792


 70% 5000/7161 [23:18<10:05,  3.57it/s]

[4999/7161],train loss is:0.347706,train evaluation is:0.912369


 71% 5100/7161 [23:46<09:38,  3.56it/s]

[5099/7161],train loss is:0.346848,train evaluation is:0.912935


 73% 5200/7161 [24:14<09:09,  3.57it/s]

[5199/7161],train loss is:0.346106,train evaluation is:0.913456


 74% 5300/7161 [24:42<08:42,  3.56it/s]

[5299/7161],train loss is:0.345357,train evaluation is:0.913945


 75% 5400/7161 [25:10<08:14,  3.56it/s]

[5399/7161],train loss is:0.344600,train evaluation is:0.914421


 77% 5500/7161 [25:38<07:49,  3.54it/s]

[5499/7161],train loss is:0.343725,train evaluation is:0.915006


 78% 5600/7161 [26:07<07:17,  3.57it/s]

[5599/7161],train loss is:0.342855,train evaluation is:0.915614


 80% 5700/7161 [26:35<06:49,  3.57it/s]

[5699/7161],train loss is:0.342163,train evaluation is:0.916069


 81% 5800/7161 [27:02<06:20,  3.57it/s]

[5799/7161],train loss is:0.341411,train evaluation is:0.916557


 82% 5900/7161 [27:30<05:54,  3.56it/s]

[5899/7161],train loss is:0.340536,train evaluation is:0.917119


 84% 6000/7161 [27:58<05:25,  3.57it/s]

[5999/7161],train loss is:0.339724,train evaluation is:0.917677


 85% 6100/7161 [28:26<04:56,  3.58it/s]

[6099/7161],train loss is:0.338938,train evaluation is:0.918176


 87% 6200/7161 [28:54<04:29,  3.56it/s]

[6199/7161],train loss is:0.338160,train evaluation is:0.918710


 88% 6300/7161 [29:22<04:00,  3.57it/s]

[6299/7161],train loss is:0.337357,train evaluation is:0.919221


 89% 6400/7161 [29:51<03:33,  3.57it/s]

[6399/7161],train loss is:0.336425,train evaluation is:0.919795


 91% 6500/7161 [30:19<03:05,  3.56it/s]

[6499/7161],train loss is:0.335872,train evaluation is:0.920149


 92% 6600/7161 [30:47<02:38,  3.54it/s]

[6599/7161],train loss is:0.335305,train evaluation is:0.920521


 94% 6700/7161 [31:15<02:09,  3.56it/s]

[6699/7161],train loss is:0.334723,train evaluation is:0.920910


 95% 6800/7161 [31:43<01:41,  3.56it/s]

[6799/7161],train loss is:0.334186,train evaluation is:0.921278


 96% 6900/7161 [32:11<01:13,  3.57it/s]

[6899/7161],train loss is:0.333549,train evaluation is:0.921658


 98% 7000/7161 [32:39<00:44,  3.59it/s]

[6999/7161],train loss is:0.332928,train evaluation is:0.922054


 99% 7100/7161 [33:07<00:17,  3.58it/s]

[7099/7161],train loss is:0.332300,train evaluation is:0.922421


100% 7161/7161 [33:24<00:00,  3.57it/s]


epoch:[0],train loss is:0.331901,train evaluation is:0.922657 

classification_report: 
               precision    recall  f1-score   support

           0       0.98      0.79      0.88      6000
           1       0.90      0.99      0.94     10915

    accuracy                           0.92     16915
   macro avg       0.94      0.89      0.91     16915
weighted avg       0.93      0.92      0.92     16915

confusion_matrix_: 
 [[ 4751  1249]
 [   85 10830]]
test loss is:0.351285,test acc is:0.921135,f1_score is:0.909438


  1% 100/7161 [00:28<33:02,  3.56it/s]

[99/7161],train loss is:0.280566,train evaluation is:0.957812


  3% 200/7161 [00:56<32:31,  3.57it/s]

[199/7161],train loss is:0.280078,train evaluation is:0.956250


  4% 300/7161 [01:24<32:20,  3.53it/s]

[299/7161],train loss is:0.278170,train evaluation is:0.957604


  6% 400/7161 [01:52<31:35,  3.57it/s]

[399/7161],train loss is:0.277422,train evaluation is:0.957187


  7% 500/7161 [02:20<30:59,  3.58it/s]

[499/7161],train loss is:0.277559,train evaluation is:0.957375


  8% 600/7161 [02:48<30:34,  3.58it/s]

[599/7161],train loss is:0.276903,train evaluation is:0.957865


 10% 700/7161 [03:16<30:15,  3.56it/s]

[699/7161],train loss is:0.276575,train evaluation is:0.958170


 11% 800/7161 [03:44<29:36,  3.58it/s]

[799/7161],train loss is:0.275850,train evaluation is:0.958516


 13% 900/7161 [04:12<29:14,  3.57it/s]

[899/7161],train loss is:0.275652,train evaluation is:0.958889


 14% 1000/7161 [04:40<28:48,  3.56it/s]

[999/7161],train loss is:0.275911,train evaluation is:0.958594


 15% 1100/7161 [05:08<28:15,  3.57it/s]

[1099/7161],train loss is:0.275840,train evaluation is:0.958324


 17% 1200/7161 [05:36<27:50,  3.57it/s]

[1199/7161],train loss is:0.275255,train evaluation is:0.958776


 18% 1300/7161 [06:04<27:31,  3.55it/s]

[1299/7161],train loss is:0.275467,train evaluation is:0.958606


 20% 1400/7161 [06:32<27:09,  3.54it/s]

[1399/7161],train loss is:0.274943,train evaluation is:0.958795


 21% 1500/7161 [07:00<26:43,  3.53it/s]

[1499/7161],train loss is:0.274338,train evaluation is:0.959187


 22% 1600/7161 [07:28<25:59,  3.57it/s]

[1599/7161],train loss is:0.274561,train evaluation is:0.959160


 24% 1700/7161 [07:57<25:30,  3.57it/s]

[1699/7161],train loss is:0.274479,train evaluation is:0.959265


 25% 1800/7161 [08:25<25:07,  3.56it/s]

[1799/7161],train loss is:0.274269,train evaluation is:0.959462


 27% 1900/7161 [08:53<24:35,  3.57it/s]

[1899/7161],train loss is:0.274929,train evaluation is:0.959112


 28% 2000/7161 [09:21<24:00,  3.58it/s]

[1999/7161],train loss is:0.274879,train evaluation is:0.959219


 29% 2100/7161 [09:49<23:34,  3.58it/s]

[2099/7161],train loss is:0.274582,train evaluation is:0.959375


 31% 2200/7161 [10:17<23:06,  3.58it/s]

[2199/7161],train loss is:0.274208,train evaluation is:0.959631


 32% 2300/7161 [10:45<22:46,  3.56it/s]

[2299/7161],train loss is:0.274397,train evaluation is:0.959579


 34% 2400/7161 [11:13<22:17,  3.56it/s]

[2399/7161],train loss is:0.274091,train evaluation is:0.959714


 35% 2500/7161 [11:41<21:51,  3.55it/s]

[2499/7161],train loss is:0.274419,train evaluation is:0.959500


 36% 2600/7161 [12:09<21:22,  3.56it/s]

[2599/7161],train loss is:0.274369,train evaluation is:0.959567


 38% 2700/7161 [12:37<21:01,  3.54it/s]

[2699/7161],train loss is:0.274062,train evaluation is:0.959757


 39% 2800/7161 [13:06<20:27,  3.55it/s]

[2799/7161],train loss is:0.273929,train evaluation is:0.959833


 40% 2900/7161 [13:34<19:56,  3.56it/s]

[2899/7161],train loss is:0.273923,train evaluation is:0.959946


 42% 3000/7161 [14:02<19:26,  3.57it/s]

[2999/7161],train loss is:0.273823,train evaluation is:0.960031


 43% 3100/7161 [14:30<19:01,  3.56it/s]

[3099/7161],train loss is:0.273658,train evaluation is:0.960081


 45% 3200/7161 [14:58<18:30,  3.57it/s]

[3199/7161],train loss is:0.273731,train evaluation is:0.959951


 46% 3300/7161 [15:26<18:05,  3.56it/s]

[3299/7161],train loss is:0.273679,train evaluation is:0.959953


 47% 3400/7161 [15:54<17:29,  3.58it/s]

[3399/7161],train loss is:0.273554,train evaluation is:0.959982


 49% 3500/7161 [16:22<17:05,  3.57it/s]

[3499/7161],train loss is:0.273376,train evaluation is:0.960027


 50% 3600/7161 [16:50<16:42,  3.55it/s]

[3599/7161],train loss is:0.273033,train evaluation is:0.960226


 52% 3700/7161 [17:18<16:17,  3.54it/s]

[3699/7161],train loss is:0.272908,train evaluation is:0.960372


 53% 3800/7161 [17:46<15:48,  3.54it/s]

[3799/7161],train loss is:0.272743,train evaluation is:0.960485


 54% 3900/7161 [18:14<15:07,  3.59it/s]

[3899/7161],train loss is:0.272737,train evaluation is:0.960473


 56% 4000/7161 [18:42<14:46,  3.57it/s]

[3999/7161],train loss is:0.272584,train evaluation is:0.960547


 57% 4100/7161 [19:10<14:18,  3.57it/s]

[4099/7161],train loss is:0.272444,train evaluation is:0.960640


 59% 4200/7161 [19:38<13:48,  3.57it/s]

[4199/7161],train loss is:0.272138,train evaluation is:0.960818


 60% 4300/7161 [20:06<13:21,  3.57it/s]

[4299/7161],train loss is:0.272043,train evaluation is:0.960792


 61% 4400/7161 [20:34<12:51,  3.58it/s]

[4399/7161],train loss is:0.272100,train evaluation is:0.960732


 63% 4500/7161 [21:02<12:23,  3.58it/s]

[4499/7161],train loss is:0.271868,train evaluation is:0.960833


 64% 4600/7161 [21:30<11:55,  3.58it/s]

[4599/7161],train loss is:0.271926,train evaluation is:0.960808


 66% 4700/7161 [21:58<11:33,  3.55it/s]

[4699/7161],train loss is:0.271634,train evaluation is:0.961004


 67% 4800/7161 [22:27<11:04,  3.55it/s]

[4799/7161],train loss is:0.271291,train evaluation is:0.961224


 68% 4900/7161 [22:55<10:33,  3.57it/s]

[4899/7161],train loss is:0.271109,train evaluation is:0.961269


 70% 5000/7161 [23:23<10:07,  3.56it/s]

[4999/7161],train loss is:0.271070,train evaluation is:0.961350


 71% 5100/7161 [23:51<09:38,  3.57it/s]

[5099/7161],train loss is:0.271035,train evaluation is:0.961317


 73% 5200/7161 [24:19<09:08,  3.58it/s]

[5199/7161],train loss is:0.271098,train evaluation is:0.961310


 74% 5300/7161 [24:47<08:38,  3.59it/s]

[5299/7161],train loss is:0.270939,train evaluation is:0.961338


 75% 5400/7161 [25:15<08:15,  3.55it/s]

[5399/7161],train loss is:0.270757,train evaluation is:0.961424


 77% 5500/7161 [25:43<07:44,  3.57it/s]

[5499/7161],train loss is:0.270608,train evaluation is:0.961523


 78% 5600/7161 [26:11<07:16,  3.58it/s]

[5599/7161],train loss is:0.270540,train evaluation is:0.961579


 80% 5700/7161 [26:39<06:49,  3.56it/s]

[5699/7161],train loss is:0.270245,train evaluation is:0.961782


 81% 5800/7161 [27:07<06:24,  3.54it/s]

[5799/7161],train loss is:0.270160,train evaluation is:0.961794


 82% 5900/7161 [27:35<05:56,  3.53it/s]

[5899/7161],train loss is:0.269955,train evaluation is:0.961864


 84% 6000/7161 [28:03<05:26,  3.56it/s]

[5999/7161],train loss is:0.269649,train evaluation is:0.962036


 85% 6100/7161 [28:31<04:57,  3.57it/s]

[6099/7161],train loss is:0.269434,train evaluation is:0.962187


 87% 6200/7161 [28:59<04:28,  3.58it/s]

[6199/7161],train loss is:0.269195,train evaluation is:0.962319


 88% 6300/7161 [29:27<04:01,  3.57it/s]

[6299/7161],train loss is:0.269088,train evaluation is:0.962366


 89% 6400/7161 [29:55<03:32,  3.58it/s]

[6399/7161],train loss is:0.269074,train evaluation is:0.962363


 91% 6500/7161 [30:23<03:04,  3.58it/s]

[6499/7161],train loss is:0.268945,train evaluation is:0.962409


 92% 6600/7161 [30:51<02:36,  3.58it/s]

[6599/7161],train loss is:0.268866,train evaluation is:0.962457


 94% 6700/7161 [31:19<02:08,  3.58it/s]

[6699/7161],train loss is:0.268837,train evaluation is:0.962435


 95% 6800/7161 [31:47<01:41,  3.56it/s]

[6799/7161],train loss is:0.268736,train evaluation is:0.962477


 96% 6900/7161 [32:15<01:13,  3.54it/s]

[6899/7161],train loss is:0.268547,train evaluation is:0.962591


 98% 7000/7161 [32:43<00:45,  3.55it/s]

[6999/7161],train loss is:0.268319,train evaluation is:0.962754


 99% 7100/7161 [33:12<00:17,  3.58it/s]

[7099/7161],train loss is:0.268161,train evaluation is:0.962852


100% 7161/7161 [33:29<00:00,  3.56it/s]


epoch:[1],train loss is:0.268078,train evaluation is:0.962885 

classification_report: 
               precision    recall  f1-score   support

           0       0.99      0.85      0.91      6000
           1       0.92      1.00      0.96     10915

    accuracy                           0.94     16915
   macro avg       0.96      0.92      0.94     16915
weighted avg       0.95      0.94      0.94     16915

confusion_matrix_: 
 [[ 5077   923]
 [   43 10872]]
test loss is:0.314177,test acc is:0.942891,f1_score is:0.935297


In [18]:
model.ema.store(model.module.parameters())
model.ema.copy_to(model.module.parameters())  

<br>

### 四、模型验证与保存

#### 1. 模型验证

In [19]:
from ark_nlp.factory.predictor import TMPredictor

In [20]:
tm_predictor_instance = TMPredictor(model.module, tokenizer, tm_train_dataset.cat2id)

In [21]:
tm_predictor_instance.predict_one_sample(['胸部皮肤破裂伤', '胸部开放性损伤'], return_proba=True)

[('1', 0.9374510049819946)]

#### 2. 模型保存

In [22]:
!mkdir -p checkpoint/textsim

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [23]:
import pickle

In [24]:
torch.save(model.module.state_dict(), 'checkpoint/textsim/module.pth')

In [25]:
with open('checkpoint/textsim/cat2id.pkl', "wb") as f:
    pickle.dump(tm_train_dataset.cat2id, f)