In [1]:
import sys
sys.path.append('..')

import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset

from utils import *
from dataset import *
from preprocess import *
from wrapper import *
from models import BertWithNER

torch.cuda.is_available()
torch.cuda.set_device('cuda:0')

In [2]:
train_df = pd.read_csv('../data/train.csv', sep='\t')

model_name = 'hfl/chinese-macbert-base'
ner_model_name = 'uer/roberta-base-finetuned-cluener2020-chinese'

train_dataset_config = {
    'model_name':model_name,
    'aux_model_name':ner_model_name,
    'maxlength':64,
    'train_val_split':-1,
    'test':False, 
    'remove_username':False,
    'remove_punctuation':False, 
    'to_simplified':False, 
    'emoji_to_text':False, 
    'device':torch.device('cuda'),
}

train = DatasetWithAuxiliaryEmbeddings(df=train_df[:200], **train_dataset_config)
train.tokenize()
train.construct_dataset()

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\holaj\AppData\Local\Temp\jieba.cache
Loading model cost 0.493 seconds.
Prefix dict has been built successfully.
  indexed_value = torch.tensor(value[index]).squeeze()


In [3]:
checkpoints = [
    '../ner_run_v1/fold0/checkpoint-2475/pytorch_model.bin', 
    '../ner_run_v1/fold1/checkpoint-7425/pytorch_model.bin', 
    '../ner_run_v1/fold2/checkpoint-9900/pytorch_model.bin', 
    '../ner_run_v1/fold3/checkpoint-7425/pytorch_model.bin', 
    '../ner_run_v1/fold4/checkpoint-4950/pytorch_model.bin', 
    '../ner_run_v1/fold5/checkpoint-9900/pytorch_model.bin', 
    '../ner_run_v1/fold6/checkpoint-9900/pytorch_model.bin', 
    '../ner_run_v1/fold7/checkpoint-9900/pytorch_model.bin', 
]

In [4]:
output_tensors = []

for cp in checkpoints:
    model = BertWithNER(bert_model=model_name, ner_model=ner_model_name)
    model.load_state_dict(torch.load(cp))

    logits = []
    dataloader = DataLoader(train.dataset['train'].with_format('torch'), batch_size=16)
    for batch in dataloader:
        outputs = model(**batch)
        logits.append(outputs['logits'])

    output_tensors.append(torch.concat(logits))

Some weights of the model checkpoint at hfl/chinese-macbert-base were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at uer/roberta-base-finetuned-cluener2020-chinese were not used when initializi

In [13]:
data = train_df[:1000]
data['prediction'] = torch.argmax(logits, 1)
data = data[['id', 'label', 'prediction', 'text']]
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(data[data.label != data.prediction])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data['prediction'] = torch.argmax(logits, 1)


Unnamed: 0,id,label,prediction,text
6,7,0,1,苏明娟实现了从农家女孩到共青团安徽省委副书记的“逆袭”，这让我们看到：扶贫救困不能只顾解决眼...
7,8,1,0,在翻阅中国话剧100周年纪念活动资料时，他萌生了创作一台寻找中国话剧源头的剧本的意念。
10,11,0,1,我们正在走进五彩缤纷、朝气蓬勃的青春花季。
13,14,0,1,省科学技术奖的推荐、评审和授奖，坚持公开、公平、公正的原则，依法管理，求真务实，注重实效，严...
16,17,0,1,12月9日下午，江西省考古研究所研究室主任张文江在景德镇市唐代南窑遗址考古成果发布会上宣布，...
18,19,0,1,成立陆军领导机构、组建战略支援部队，是中国军队现代化建设的一个重要里程碑，彰显了一个走向复兴...
20,21,0,1,修改后的《中华人民共和国个人所得税法》于9月1日正式实行。新《中华人民共和国个人所得税法》实...
26,27,0,1,保护野生动植物就是保护人类赖以生存的生态环境，国家应加快建立自然保护地管理体系，让重要的自然...
34,35,0,1,其实，日落的景象和日出同样壮观、绮丽，而且神秘、迷人。
37,38,0,1,江苏省出台了校园足球行动计划纲要，创建一批校园足球特色学校和特色县。
