In [70]:
import sys
sys.path.append('..')

import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoModelForMaskedLM

import pickle

from utils import *
from dataset import *

from preprocess import *
from wrapper import *
from models import *
from pipeline import PipelineGED

# Ask PyTorch to release any cached, currently unused CUDA memory before loading models.
torch.cuda.empty_cache()

# Device used by get_dataset() and get_pipeline() below (both read this global).
# Swap in the commented line to run on the first GPU instead of the CPU.
device = torch.device("cpu")
# device = torch.device("cuda:0")
# NOTE(review): ntf() is not defined in this notebook; it presumably comes from one of the
# star imports above (utils / dataset / preprocess / wrapper / models) — confirm what it does.
ntf()

In [71]:
def sigmoid_dim_1(arr):
    """Apply the logistic sigmoid to index 1 along axis 1 of a 3-D array.

    Parameters
    ----------
    arr : array-like with at least three axes
        Only ``arr[:, 1, :]`` is read.
        NOTE(review): callers pass pipeline outputs, presumably laid out as
        (n_samples, n_classes, n_checkpoints) — confirm against PipelineGED.

    Returns
    -------
    numpy.ndarray
        ``1 / (1 + exp(-x))`` evaluated element-wise on ``arr[:, 1, :]``.
    """
    logits = np.asarray(arr)[:, 1, :]
    # sigmoid(x) = exp(x - log(1 + exp(x))). Going through logaddexp keeps the
    # computation finite for large |x|; the naive exp(x) / (1 + exp(x)) turns
    # into inf / inf = nan once exp(x) overflows.
    return np.exp(logits - np.logaddexp(0, logits))

In [72]:
## useful setup functions
def get_dataset(model_name, oob_model_name, max_length):
    """Build, tokenize and construct the test-set dataset for one model.

    Reads the notebook globals ``test_df`` (source frame, re-indexed with
    ``reset_index()``) and ``device``.

    Parameters
    ----------
    model_name : forwarded as ``model_name``.
    oob_model_name : forwarded as ``aux_model_name``.
    max_length : forwarded as ``maxlength``.

    Returns
    -------
    DatasetWithAuxiliaryEmbeddings
        The dataset after ``tokenize()`` and ``construct_dataset()`` were called.
    """
    # Test-mode configuration: no train/val split and every text
    # normalisation switch turned off.
    dataset = DatasetWithAuxiliaryEmbeddings(
        df=test_df.reset_index(),
        model_name=model_name,
        aux_model_name=oob_model_name,
        maxlength=max_length,
        train_val_split=-1,
        test=True,
        remove_username=False,
        remove_punctuation=False,
        to_simplified=False,
        emoji_to_text=False,
        device=device,
        split_words=False,
        cut_all=False,
    )
    dataset.tokenize()
    dataset.construct_dataset()
    return dataset

def load_model_configs(MODEL_ARCH, MODEL_CONFIG):
    """Unpack one architecture's entry from the configuration mapping.

    Parameters
    ----------
    MODEL_ARCH : key into ``MODEL_CONFIG``.
    MODEL_CONFIG : mapping whose values provide ``'max_length'``,
        ``'model_name'``, ``'model_architecture'`` and ``'checkpoints'``.

    Returns
    -------
    tuple
        ``(model_name, oob_model_name, model_architecture, max_length,
        checkpoints)``; ``oob_model_name`` is always ``None``.

    Raises
    ------
    KeyError
        If the architecture or one of the four fields is missing.
    """
    entry = MODEL_CONFIG[MODEL_ARCH]
    wanted = ('max_length', 'model_name', 'model_architecture', 'checkpoints')
    max_length, model_name, model_architecture, checkpoints = (entry[field] for field in wanted)
    # No auxiliary (out-of-the-box) model is configured for any architecture.
    return model_name, None, model_architecture, max_length, checkpoints

def get_pipeline(model_name, oob_model_name, max_length, model_architecture, checkpoints):
    """Create a ``PipelineGED`` and a convenience callable that runs it.

    Parameters
    ----------
    model_name, oob_model_name, model_architecture :
        Forwarded unchanged to ``PipelineGED``.
    max_length :
        Used both as the tokenisation ``maxlength`` in the data config and as
        the pipeline's ``hidden_layer_size``.
    checkpoints :
        Default checkpoint list bound into the returned callable.

    Returns
    -------
    tuple
        ``(pipeline, run)`` where ``run(texts, checkpoints=..., majority_vote=False)``
        calls the pipeline on the notebook-global ``device`` with
        ``output_probabilities=True``, ``display=False`` and ``aggregate=False``.
    """
    pipeline = PipelineGED(
        model_name=model_name,
        oob_model_name=oob_model_name,
        # Same test-mode preprocessing switches as get_dataset(), minus the
        # auxiliary model name and device entries.
        data_configs={
            'model_name': model_name,
            'maxlength': max_length,
            'train_val_split': -1,
            'test': True,
            'remove_username': False,
            'remove_punctuation': False,
            'to_simplified': False,
            'emoji_to_text': False,
            'split_words': False,
            'cut_all': False,
        },
        model_architecture=model_architecture,
        hidden_layer_size=max_length,
    )

    def apply_ged_pipeline_oob(texts, checkpoints=checkpoints, majority_vote=False):
        # majority_vote is accepted but never read.
        # NOTE(review): the notebook applies softmax/sigmoid to this result, so
        # it is presumably un-normalised despite output_probabilities=True — confirm.
        return pipeline(
            texts=texts,
            checkpoints=checkpoints,
            device=device,
            output_probabilities=True,
            display=False,
            aggregate=False,
        )

    return pipeline, apply_ged_pipeline_oob

def get_features(t):
    """Softmax a tensor over axis 1 and return the index-1 slice as NumPy.

    Parameters
    ----------
    t : torch.Tensor with at least three axes.

    Returns
    -------
    numpy.ndarray
        ``softmax(t, dim=1)[:, 1, :]`` moved to the CPU.
    """
    normalised = torch.softmax(t, dim=1)
    index_one = normalised[:, 1, :]
    return index_one.cpu().numpy()

In [73]:
# Unlabeled test sentences: tab-separated file, indexed by its 'id' column.
test_df = pd.read_csv('../data/data-org/test.csv', sep='\t', index_col='id')

# One entry per architecture key. load_model_configs() reads exactly these four
# fields: 'model_name', 'model_architecture', 'max_length' and 'checkpoints'.
MODEL_CONFIG = {
    'macbert':{
        'model_name':'hfl/chinese-macbert-base', 
        'model_architecture':'bert_with_clf_head', 
        'max_length':128, 
        'checkpoints':[
            '../finetuned_models/ensemble_1/macbert/model0/checkpoint-1142/pytorch_model.bin', 
            '../finetuned_models/ensemble_1/macbert/model1/checkpoint-2278/pytorch_model.bin', 
            '../finetuned_models/ensemble_1/macbert/model2/checkpoint-1137/pytorch_model.bin', 
            '../finetuned_models/ensemble_1/macbert/model3/checkpoint-1138/pytorch_model.bin', 
            '../finetuned_models/ensemble_1/macbert/model4/checkpoint-1141/pytorch_model.bin', 
        ], 
    }, 
    # Only the first checkpoint is enabled here; the other four are commented out.
    # NOTE(review): the ensemble cell feeds this pipeline's output into a pickled SVM —
    # confirm that a single checkpoint yields the feature width that SVM was fitted on.
    'roberta-word-based':{
        'model_name':'uer/roberta-base-word-chinese-cluecorpussmall', 
        'model_architecture':'bert_word_based', 
        'max_length':64, 
        'checkpoints':[
            '../finetuned_models/rww/model0/checkpoint-1238/pytorch_model.bin', 
            # '../finetuned_models/rww/model1/checkpoint-2478/pytorch_model.bin', 
            # '../finetuned_models/rww/model2/checkpoint-2474/pytorch_model.bin', 
            # '../finetuned_models/rww/model3/checkpoint-2468/pytorch_model.bin', 
            # '../finetuned_models/rww/model4/checkpoint-1238/pytorch_model.bin', 
        ], 
    }, 
    'pert':{
        'model_name':'hfl/chinese-pert-base', 
        'model_architecture':'bert_with_clf_head', 
        'max_length':64, 
        'checkpoints':[
            '../finetuned_models/pert_benchmark/model0/checkpoint-2268/pytorch_model.bin', 
        ], 
    }, 
    # Shares its base model name with 'roberta-word-based' but uses a different architecture key.
    'bert_with_bigru':{
        'model_name':'uer/roberta-base-word-chinese-cluecorpussmall', 
        'model_architecture':'bert_with_bigru', 
        'max_length':64, 
        'checkpoints':[
            '../finetuned_models/rww-bigru/model0/checkpoint-2404/pytorch_model.bin', 
        ], 
    }, 
}

### ensemble

In [29]:
# Ensemble inference: run the word-based RoBERTa pipeline over the whole test set
# and score its outputs with the pickled second-stage classifier.
MODEL_ARCH = 'roberta-word-based'

model_name, oob_model_name, model_architecture, max_length, checkpoints = load_model_configs(MODEL_ARCH, MODEL_CONFIG)
test = get_dataset(model_name, oob_model_name, max_length)

_, pipe_2 = get_pipeline(model_name, oob_model_name, max_length, model_architecture, checkpoints)

# SECURITY: pickle.load can execute arbitrary code; only load files produced by this project.
# The context manager closes the file handle, which the bare open() call never did.
with open('../finetuned_models/ensemble_1/roberta_word_based/svm.sav', 'rb') as svm_file:
    clf_2 = pickle.load(svm_file)

rww_logits = pipe_2(texts=test_df.text.values, checkpoints=checkpoints, )

# NOTE(review): features here come from sigmoid_dim_1, whereas the perturbation cells
# further down feed clf_2 with get_features (softmax over axis 1). The two transforms
# give different values — confirm which one the classifier was fitted on.
rww_probs = sigmoid_dim_1(rww_logits.cpu().numpy())
rww_pred = clf_2.predict_proba(rww_probs)

Some weights of the model checkpoint at uer/roberta-base-word-chinese-cluecorpussmall were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at uer/roberta-base-word-chinese-cluecorpussmall and are newly initialized: ['bert

In [141]:
# Hard decision per sentence: the column of rww_pred with the larger probability.
preds = rww_pred.argmax(1)

# Join gold labels (aligned on the 'id' index), predictions and the column-1
# probability onto a copy of the test sentences, keeping only the columns of interest.
test_data = test_df.copy(deep=True).assign(
    comp=pd.read_csv('../data/data-org/labeled_test.csv').set_index('id').label,
    pred=preds,
    score=rww_pred[:, 1],
)[['comp', 'pred',  'score', 'text']]

# Number of misclassified sentences, followed by the full, untruncated list.
print(len(test_data.loc[test_data.comp != test_data.pred]))
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(test_data.loc[test_data.comp != test_data.pred])

352


Unnamed: 0_level_0,comp,pred,score,text
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,1,0,0.220035,唐诗宋词是我国浩如烟海的古代文化中一块瑰丽的瑰宝，所以我们要好好珍惜，将它们传承下去。
2,1,0,0.190483,这次迎新活动举办得非常成功，参加活动的人数超过1000人。
3,1,0,0.206641,春天到了，山上姹紫嫣红的野花开得灿烂极了。
8,0,1,0.756092,王小红在这次考试中取得了好成绩，校长亲自为他颁发了奖状。
10,0,1,0.622002,妹妹总是跟在哥哥的后面，想要哥哥给自己买糖吃。
11,1,0,0.224153,王磊是个著名的画家，能画出精妙绝伦的佳作。
18,0,1,0.648255,我国生产的石油，长期不能自给。
22,0,1,0.953546,厦门的彩虹沙滩很宽很美，小红很羡慕生活在海滩的人，因为他们出门就可以赶海。
24,0,1,0.802656,每次到姥姥家，她都会做上一桌美味的食物，而我都会吃好多，直到把肚皮撑得溜圆。
25,1,0,0.335441,电影《流浪地球》一经上映，刘慈欣的同名小说也备受关注，观众认为其气势恢宏、令人震撼。


### single model

In [74]:
# Switch to the single BiGRU-headed model. This rebinds MODEL_ARCH, the unpacked
# config values, `test` and `pipe_2`, replacing the ensemble objects bound earlier.
MODEL_ARCH = 'bert_with_bigru'

model_name, oob_model_name, model_architecture, max_length, checkpoints = load_model_configs(MODEL_ARCH, MODEL_CONFIG)
# `test` is built here but not referenced by any later cell in this notebook.
test = get_dataset(model_name, oob_model_name, max_length)

# Only the callable is kept; the pipeline object itself is discarded.
_, pipe_2 = get_pipeline(model_name, oob_model_name, max_length, model_architecture, checkpoints)

In [None]:
# Run the single-model pipeline over every test sentence (overwrites the ensemble's rww_logits).
rww_logits = pipe_2(texts=test_df.text.values, checkpoints=checkpoints, )

100%|██████████| 65/65 [01:22<00:00,  1.27s/it]


In [13]:
# Softmax over axis 1, then keep index 0 of the last axis (one checkpoint is configured).
probs = torch.softmax(rww_logits, dim=1)[..., 0].cpu().numpy()
preds = probs.argmax(1)

# Join gold labels (aligned on the 'id' index), predictions and the column-1
# probability onto a copy of the test sentences, keeping only the columns of interest.
test_data = test_df.copy(deep=True).assign(
    comp=pd.read_csv('../data/data-org/labeled_test.csv').set_index('id').label,
    pred=preds,
    score=probs[:, 1],
)[['comp', 'pred',  'score', 'text']]

# Report the number of misclassified sentences and save them next to the checkpoint.
print(len(test_data.loc[test_data.comp != test_data.pred]))
test_data.loc[test_data.comp != test_data.pred].to_csv('../finetuned_models/rww-bigru/model0/wrong.csv')

340


In [8]:
from sklearn.metrics import precision_score, recall_score, f1_score

def display_scores(y_true, y_pred):
    """Print and return precision, recall and F1 for a set of predictions.

    Parameters
    ----------
    y_true : ground-truth labels.
    y_pred : predicted labels.

    Returns
    -------
    tuple
        ``(precision, recall, f1)`` as computed by scikit-learn with its
        default arguments.
    """
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    # One metric per line, in the same order as the returned tuple.
    print(f'Precision = {precision}', f'Recall = {recall}', f'F1 = {f1}', sep='\n')
    return precision, recall, f1

# Score whichever predictions test_data currently holds (gold 'comp' vs. predicted 'pred').
display_scores(test_data.comp, test_data.pred)

Precision = 0.6626506024096386
Recall = 0.7277882797731569
F1 = 0.6936936936936937


(0.6626506024096386, 0.7277882797731569, 0.6936936936936937)

In [69]:
# Scratch probe sets for the single model. Each `texts = [...]` below overwrites
# the previous one, so only the LAST list is actually evaluated; the earlier
# lists are kept as a record of sentences that were tried.
texts = [
    '我国生产的石油，长期不能自给。',
    '石油不能自给。',
    '厦门的彩虹沙滩很宽很美，小红很羡慕生活在海滩的人，因为他们出门就可以赶海。',
    '厦门的彩虹沙滩很宽很美。',
    '小红很羡慕生活在海滩的人，因为他们出门就可以赶海。'
]

texts = [
    '每天抽出一点时间来练字不仅可以陶冶情操，还可以提升自己的书法水平。',
    '每天抽出一点时间来练字不仅可以提升自己的书法水平，还可以陶冶情操。',
]

# NOTE(review): the third and fourth sentences of this list are identical — the
# second of the pair was presumably meant to be an edited variant.
texts = [
    '我特别喜欢去文具店买各种好看的笔记本、水笔、荧光笔、笔袋等文具。',
    '我特别喜欢去文具店买各种好看的水笔、荧光笔等文具。',
    '晓敏获得此次比赛的冠军，没有一个人不是心服口服。',
    '晓敏获得此次比赛的冠军，没有一个人不是心服口服。',
    '小明的爸妈出差去了，小明一个人在家很害怕，所以叫了很多好朋友晚上去他家玩。',
    '小明一个人在家很害怕，所以叫了很多好朋友晚上去他家玩。',
    '研究并了解祖国的悠久历史，是每个学生应该具备的基本能力。',
    '了解并研究祖国的悠久历史，是每个学生应该具备的基本能力。'
]

texts = [
    '小孩子的父母教育他要把垃圾扔进垃圾桶里，不可以随便扔在路上。',
    '父母教育他要把垃圾扔进垃圾桶里，不可以随便扔在路上。'
]

# Use torch.nn explicitly: a bare `nn` is only in scope if one of the star
# imports happens to leak it, and the rest of this notebook spells it torch.nn.
# Index 0 of the last axis is taken before the softmax over axis 1.
torch.nn.Softmax(dim=1)(pipe_2(texts, checkpoints=checkpoints, )[..., 0])

100%|██████████| 1/1 [00:00<00:00,  5.43it/s]


tensor([[0.4835, 0.5165],
        [0.5285, 0.4715]])

## Single sentence perturbations

In [18]:
## 搭配
texts = [
    '有关部门对极少数不尊重环卫工人劳动、甚至辱骂侮辱殴打环卫工人的事件，及时进行了批评教育和严肃处理。', 
    '有关部门对极少数不尊重环卫工人劳动、甚至侮辱殴打环卫工人的人及时进行了批评教育。', 
]

texts = [
    '通过最近报刊上发表的一系列文章，给了我们一个十分有益的启示：要形成好的社会风气，就必须加强国民素质教育。', 
    '最近报刊上发表的一系列文章，给了我们一个十分有益的启示：要形成好的社会风气，就必须加强国民素质教育。',
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

NameError: name 'clf_2' is not defined

In [71]:
# Same sentence with and without the leading '通过'.
texts = [
    '通过最近报刊上发表的一系列文章，给了我们一个十分有益的启示：要形成好的社会风气，就必须加强国民素质教育。', 
    '最近报刊上发表的一系列文章，给了我们一个十分有益的启示：要形成好的社会风气，就必须加强国民素质教育。',
]

# Second-stage classifier probabilities for each sentence, computed from softmax features.
probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  6.02it/s]
100%|██████████| 1/1 [00:00<00:00,  6.37it/s]
100%|██████████| 1/1 [00:00<00:00,  6.85it/s]
100%|██████████| 1/1 [00:00<00:00,  6.41it/s]
100%|██████████| 1/1 [00:00<00:00,  7.09it/s]


array([[0.06508228, 0.93491772],
       [0.77069301, 0.22930699]])

In [72]:
# Same sentence with and without '而' between '雄伟' and '享誉'.
texts = [
    '万里长城以气魄雄伟而享誉世界。', 
    '万里长城以气魄雄伟享誉世界。',
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  5.78it/s]
100%|██████████| 1/1 [00:00<00:00,  6.02it/s]
100%|██████████| 1/1 [00:00<00:00,  6.13it/s]
100%|██████████| 1/1 [00:00<00:00,  5.95it/s]
100%|██████████| 1/1 [00:00<00:00,  6.54it/s]


array([[0.92655004, 0.07344996],
       [0.95226426, 0.04773574]])

In [73]:
# First sentence says '被列入到…名录收录了'; the second drops '列入到'.
texts = [
    '徽雕，是徽州古建筑中的精华，徽州“三雕”（木雕、砖雕、石雕）的制作技艺，都已经被列入到国家首批非物质文化遗产名录收录了。', 
    '徽雕，是徽州古建筑中的精华，徽州“三雕”（木雕、砖雕、石雕）的制作技艺，都已经被国家首批非物质文化遗产名录收录了。',
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  5.71it/s]
100%|██████████| 1/1 [00:00<00:00,  6.76it/s]
100%|██████████| 1/1 [00:00<00:00,  7.27it/s]
100%|██████████| 1/1 [00:00<00:00,  6.90it/s]
100%|██████████| 1/1 [00:00<00:00,  7.39it/s]


array([[0.17271004, 0.82728996],
       [0.77389363, 0.22610637]])

In [68]:
# The two sentences differ only in the connective: '因而' vs. '却'.
texts = [
    '张强从小生活在爷爷奶奶身边，因而对父母有着浓厚的感情。', 
    '张强从小生活在爷爷奶奶身边，却对父母有着浓厚的感情。',
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  7.46it/s]
100%|██████████| 1/1 [00:00<00:00,  6.21it/s]
100%|██████████| 1/1 [00:00<00:00,  7.35it/s]
100%|██████████| 1/1 [00:00<00:00,  7.09it/s]
100%|██████████| 1/1 [00:00<00:00,  6.67it/s]


array([[0.92316834, 0.07683166],
       [0.83448222, 0.16551778]])

In [69]:
# The two sentences swap the positions of '课外' and '课堂中'.
texts = [
    '我们不仅要在课外学语文，还要在课堂中学语文。', 
    '我们不仅要在课堂中学语文，还要在课外学语文。', 
]

# The commented-out call refers to apply_ged_pipeline, which is not defined in this notebook.
# probs = apply_ged_pipeline(texts)
probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  6.06it/s]
100%|██████████| 1/1 [00:00<00:00,  6.80it/s]
100%|██████████| 1/1 [00:00<00:00,  6.49it/s]
100%|██████████| 1/1 [00:00<00:00,  6.04it/s]
100%|██████████| 1/1 [00:00<00:00,  6.85it/s]


array([[0.95471314, 0.04528686],
       [0.93988721, 0.06011279]])

In [39]:
# Same sentence with and without the trailing '得很'.
texts = [
    '小明待人非常大方友善得很。', 
    '小明待人非常大方友善。', 
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00, 21.81it/s]
100%|██████████| 1/1 [00:00<00:00, 22.80it/s]
100%|██████████| 1/1 [00:00<00:00, 22.81it/s]
100%|██████████| 1/1 [00:00<00:00, 22.30it/s]
100%|██████████| 1/1 [00:00<00:00, 22.80it/s]


array([[0.07760738, 0.92239262],
       [0.50696052, 0.49303948]])

In [40]:
# Semantics: one-sided vs. two-sided expressions
# The two sentences differ only in '兴亡' vs. '复兴'.
texts = [
    '他为了民族的兴亡和人民的利益奋斗了一生。', 
    '他为了民族的复兴和人民的利益奋斗了一生。', 
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00, 23.33it/s]
100%|██████████| 1/1 [00:00<00:00, 22.80it/s]
100%|██████████| 1/1 [00:00<00:00, 22.30it/s]
100%|██████████| 1/1 [00:00<00:00, 91.21it/s]
100%|██████████| 1/1 [00:00<00:00, 100.34it/s]


array([[0.88734228, 0.11265772],
       [0.88324504, 0.11675496]])

In [41]:
# Two edits between the sentences: measure word '条' -> '道', and '是不是' -> '是'.
texts = [
    '我不禁怀疑这条题目是不是老师讲错了。', 
    '我不禁怀疑这道题目是老师讲错了。', 
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00, 22.80it/s]
100%|██████████| 1/1 [00:00<00:00, 22.30it/s]
100%|██████████| 1/1 [00:00<00:00, 23.89it/s]
100%|██████████| 1/1 [00:00<00:00, 91.24it/s]
100%|██████████| 1/1 [00:00<00:00, 100.31it/s]


array([[0.71639447, 0.28360553],
       [0.04231937, 0.95768063]])

In [82]:
# Three variants of one sentence: '就含有…等成分就都有', '含有的…等成分就都有',
# and '就含有…等成分，这些成分都有'.
texts = [
    '许多水果具有药用功效，例如大家都很熟悉的柠檬中间就含有柠檬酸、柠檬多酚及维生素C等成分就都有很强的抑制血小板聚集的作用。', 
    '许多水果具有药用功效，例如大家都很熟悉的柠檬中间含有的柠檬酸、柠檬多酚及维生素C等成分就都有很强的抑制血小板聚集的作用。', 
    '许多水果具有药用功效，例如大家都很熟悉的柠檬中间就含有柠檬酸、柠檬多酚及维生素C等成分，这些成分都有很强的抑制血小板聚集的作用。', 
]
probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  4.85it/s]
100%|██████████| 1/1 [00:00<00:00,  5.81it/s]
100%|██████████| 1/1 [00:00<00:00,  5.38it/s]
100%|██████████| 1/1 [00:00<00:00,  4.35it/s]
100%|██████████| 1/1 [00:00<00:00,  4.68it/s]


array([[0.07215578, 0.92784422],
       [0.72542971, 0.27457029],
       [0.84729584, 0.15270416]])

In [84]:
# The two sentences differ in the verb-object pair: '提高…审美素养' vs. '加强…审美条件'.
texts = [
    '“高雅艺术进校园”活动旨在提高学生们的审美素养，引导学生树立正确的文化观，增强学生的文化自信，提升校园文化品位，优化育人环境。', 
    '“高雅艺术进校园”活动旨在加强学生们的审美条件，引导学生树立正确的文化观，增强学生的文化自信，提升校园文化品位，优化育人环境。', 
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  8.62it/s]
100%|██████████| 1/1 [00:00<00:00,  8.06it/s]
100%|██████████| 1/1 [00:00<00:00,  7.73it/s]
100%|██████████| 1/1 [00:00<00:00,  6.98it/s]
100%|██████████| 1/1 [00:00<00:00,  7.03it/s]


array([[0.96792445, 0.03207555],
       [0.05741971, 0.94258029]])

In [86]:
# The two sentences swap the order of the modifiers '新出土的' and '两千多年前'.
texts = [
    '北京博物馆展出了新出土的两千多年前的文物。', 
    '北京博物馆展出了两千多年前新出土的文物。', 
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  6.34it/s]
100%|██████████| 1/1 [00:00<00:00,  6.33it/s]
100%|██████████| 1/1 [00:00<00:00,  5.78it/s]
100%|██████████| 1/1 [00:00<00:00,  6.20it/s]
100%|██████████| 1/1 [00:00<00:00,  6.41it/s]


array([[0.0453879 , 0.9546121 ],
       [0.29029855, 0.70970145]])

In [45]:
# The two sentences swap the order of the last two items in the three-part list.
texts = [
    '我想，人是由三部分组成的：对往事的追忆、对未来的憧憬和对现时的把握。', 
    '我想，人是由三部分组成的：对往事的追忆、对现时的把握和对未来的憧憬。', 
]
probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00, 22.30it/s]
100%|██████████| 1/1 [00:00<00:00, 22.80it/s]
100%|██████████| 1/1 [00:00<00:00, 91.22it/s]
100%|██████████| 1/1 [00:00<00:00, 83.61it/s]
100%|██████████| 1/1 [00:00<00:00, 77.18it/s]


array([[0.90410073, 0.09589927],
       [0.89289776, 0.10710224]])

In [87]:
# Missing / redundant sentence constituents

# First sentence says '不可缺少的必需品'; the second drops '不可缺少'.
texts = [
    '随着通讯日渐发达，手机几乎成为大家不可缺少的必需品，但使用量增加之后，关于手机质量的投诉也越来越多。', 
    '随着通讯日渐发达，手机几乎成为大家的必需品，但使用量增加之后，关于手机质量的投诉也越来越多。', 
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  6.90it/s]
100%|██████████| 1/1 [00:00<00:00,  6.17it/s]
100%|██████████| 1/1 [00:00<00:00,  5.59it/s]
100%|██████████| 1/1 [00:00<00:00,  6.32it/s]
100%|██████████| 1/1 [00:00<00:00,  6.85it/s]


array([[0.58442285, 0.41557715],
       [0.93260904, 0.06739096]])

In [89]:
# First sentence says '寻找解决…的破解之道'; the second drops '解决'.
texts = [
    '清华大学联合剑桥大学、麻省理工学院，成立低碳能源大学联盟未来交通研究中心，他们试图寻找解决北京雾霾天出行困难的破解之道。',
    '清华大学联合剑桥大学、麻省理工学院，成立低碳能源大学联盟未来交通研究中心，他们试图寻找北京雾霾天出行困难的破解之道。',
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  6.62it/s]
100%|██████████| 1/1 [00:00<00:00,  6.90it/s]
100%|██████████| 1/1 [00:00<00:00,  6.90it/s]
100%|██████████| 1/1 [00:00<00:00,  6.71it/s]
100%|██████████| 1/1 [00:00<00:00,  6.94it/s]


array([[0.90367589, 0.09632411],
       [0.82693135, 0.17306865]])

In [48]:
# Subject phrase rewritten: '农民耕种的符合政策规定的自留地' vs. '农民在符合政策规定的自留地上耕种'.
texts = [
     '新的土地法规定，农民耕种的符合政策规定的自留地是一种正当的劳动，各级政府不得以各种理由加以干涉。', 
     '新的土地法规定，农民在符合政策规定的自留地上耕种是一种正当的劳动，各级政府不得以各种理由加以干涉。', 
]
probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00, 91.21it/s]
100%|██████████| 1/1 [00:00<00:00, 100.33it/s]
100%|██████████| 1/1 [00:00<00:00, 100.32it/s]
100%|██████████| 1/1 [00:00<00:00, 22.80it/s]
100%|██████████| 1/1 [00:00<00:00, 111.49it/s]


array([[0.07164799, 0.92835201],
       [0.85486411, 0.14513589]])

In [100]:
# Three variants: the second adds '因为' before '他没有', and the third additionally
# replaces '谁也不会认为' with '没有人不会认为'.
texts = [
    '虽然实验没有成功，但谁也不会认为这是他没有作努力的缘故。', 
    '虽然实验没有成功，但谁也不会认为这是因为他没有作努力的缘故。', 
    '虽然实验没有成功，但没有人不会认为这是因为他没有作努力的缘故。', 
]

probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  4.52it/s]
100%|██████████| 1/1 [00:00<00:00,  4.65it/s]
100%|██████████| 1/1 [00:00<00:00,  4.52it/s]
100%|██████████| 1/1 [00:00<00:00,  4.85it/s]
100%|██████████| 1/1 [00:00<00:00,  5.13it/s]


array([[0.41730797, 0.58269203],
       [0.03874894, 0.96125106],
       [0.04123123, 0.95876877]])

In [109]:
# Three consecutive pairs: '能否' vs. '要想' at the start of the sentence;
# '人间温暖' vs. '人间冷暖' at the end; and a full sentence vs. a shortened version of it.
texts = [
    '能否打赢脱贫攻坚战，关键要做到因村因户因人施策，对症下药、精准滴灌，扶到点上、扶到根上。', 
    '要想打赢脱贫攻坚战，关键要做到因村因户因人施策，对症下药、精准滴灌，扶到点上、扶到根上。', 
    '艺人们过去一贯遭白眼，如今却受到人们热切的青睐，在这白眼和青睐之间，他们体味着人间温暖。', 
    '艺人们过去一贯遭白眼，如今却受到人们热切的青睐，在这白眼和青睐之间，他们体味着人间冷暖。', 
    '在这段时间里，我们的身体和精神都有很大的收获，体重逐日增加，最高的达5公斤，精神非常愉快。', 
    '在这段时间里，我们的体重逐日增加，精神非常愉快。'
]
probs = clf_2.predict_proba(get_features(pipe_2(texts)))
probs

100%|██████████| 1/1 [00:00<00:00,  3.38it/s]
100%|██████████| 1/1 [00:00<00:00,  3.38it/s]
100%|██████████| 1/1 [00:00<00:00,  3.37it/s]
100%|██████████| 1/1 [00:00<00:00,  2.62it/s]
100%|██████████| 1/1 [00:00<00:00,  2.21it/s]


array([[0.08230082, 0.91769918],
       [0.94335742, 0.05664258],
       [0.01362464, 0.98637536],
       [0.01152593, 0.98847407],
       [0.05122233, 0.94877767],
       [0.04545544, 0.95454456]])

In [46]:
# Ask PyTorch to release any cached, currently unused CUDA memory at the end of the session.
torch.cuda.empty_cache()