In [24]:
import torch
from tqdm import tqdm
import torch.nn as nn
from datasets import load_from_disk
from transformers import BertTokenizer,BertConfig,AdamW,BertModel
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

# 定义神经网络
class BertClassificationModel(nn.Module):
    """BERT encoder followed by a single linear classification head.

    The element at index 1 of the ``BertModel`` output (the pooled
    representation) is projected by a linear layer to one logit per class.

    Args:
        pretrained_weights: Path or model id handed to
            ``BertModel.from_pretrained``. Defaults to the local
            ``bert-base-chinese`` directory this notebook was written for.
        num_labels: Number of output classes; 2 means binary classification.
    """

    def __init__(self, pretrained_weights=r"E:\model\bert-base-chinese", num_labels=2):
        super().__init__()
        # Load the pretrained BERT encoder.
        self.bert = BertModel.from_pretrained(pretrained_weights)
        # Fine-tune the whole encoder: every parameter stays trainable.
        for param in self.bert.parameters():
            param.requires_grad = True
        # Take the hidden size from the loaded checkpoint's config rather than
        # hard-coding 768, so other BERT sizes work without editing the class.
        self.dense = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return the classification logits for a batch of encoded inputs.

        Args:
            input_ids: Token id tensor produced by the tokenizer.
            token_type_ids: Segment id tensor produced by the tokenizer.
            attention_mask: Padding mask tensor produced by the tokenizer.

        Returns:
            The output of the linear head applied to the pooled BERT output.
        """
        bert_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Index 1 of the BERT output is the pooled sentence representation.
        bert_cls_hidden_state = bert_output[1]
        # Map the pooled vector to one logit per class.
        return self.dense(bert_cls_hidden_state)

# 使用BertTokenizer 编码成Bert需要的输入格式
def encoder(max_len,vocab_path,text_list):
    """Tokenize ``text_list`` into the three tensors the BERT model consumes.

    Args:
        max_len: Value forwarded to the tokenizer as ``max_length``.
        vocab_path: Location passed to ``BertTokenizer.from_pretrained``.
        text_list: The texts to encode.

    Returns:
        A tuple ``(input_ids, token_type_ids, attention_mask)`` of PyTorch
        tensors taken from the tokenizer's output.
    """
    # Load the tokenizer for the given vocabulary.
    bert_tokenizer = BertTokenizer.from_pretrained(vocab_path)
    # Encode all texts in one call, with padding and truncation enabled and
    # PyTorch tensors requested as the return type.
    encoded = bert_tokenizer(
        text_list,
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors='pt',
    )
    return (
        encoded['input_ids'],
        encoded['token_type_ids'],
        encoded['attention_mask'],
    )
# 将数据加载为Tensor格式
def load_data(Dataset, max_len=150, vocab_path=r"E:\model\bert-base-chinese\vocab.txt"):
    """Convert an iterable of ``{'text', 'label'}`` records into a TensorDataset.

    Args:
        Dataset: Iterable whose items expose ``item['label']`` and
            ``item['text']``. Change the keys below if the labels or texts
            live under different names.
        max_len: Maximum token length forwarded to ``encoder``
            (default 150, the value previously hard-coded here).
        vocab_path: Tokenizer vocabulary location forwarded to ``encoder``
            (default: the path previously hard-coded here).

    Returns:
        A ``TensorDataset`` of ``(input_ids, token_type_ids, attention_mask,
        labels)``.
    """
    text_list = []
    labels = []
    for item in Dataset:
        labels.append(int(item['label']))
        text_list.append(item['text'])
    # Encode every text at once to obtain the three model inputs.
    input_ids, token_type_ids, attention_mask = encoder(
        max_len=max_len,
        vocab_path=vocab_path,
        text_list=text_list,
    )
    labels = torch.tensor(labels)
    # Bundle the encoded inputs and the labels into one dataset.
    return TensorDataset(input_ids, token_type_ids, attention_mask, labels)

# Build the DataLoaders.
# Mini-batch size shared by all loaders.
batch_size = 16
# Load the dataset from disk. A raw string is required here: '\d' and '\C'
# are not valid escape sequences and trigger a warning in a normal literal.
dataset = load_from_disk(r'E:\datasets\ChnSentiCorp')
# Pick the training and validation splits.
dataset_train = dataset['train']
dataset_validation = dataset['validation']
# Convert both splits to TensorDatasets.
dataset_train_ts = load_data(dataset_train)
dataset_validation_ts = load_data(dataset_validation)
# Wrap the tensors in DataLoaders. Only the training data is shuffled;
# accuracy on the validation split does not depend on batch order.
train_loader = DataLoader(dataset=dataset_train_ts, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(dataset=dataset_validation_ts, batch_size=batch_size, shuffle=False)

# 定义验证函数
def dev(model,validation_loader):
    """Compute classification accuracy of ``model`` on ``validation_loader``.

    The model is moved to the notebook-level ``device``, switched to eval
    mode for the duration of the pass, and then restored to whatever mode
    (train/eval) it was in on entry, so calling this in the middle of a
    training loop does not leave dropout disabled.

    Args:
        model: Module called as ``model(input_ids, token_type_ids, attention_mask)``.
        validation_loader: Iterable yielding
            ``(input_ids, token_type_ids, attention_mask, labels)`` batches.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when the loader
        yields no samples.
    """
    model.to(device)
    # Remember the caller's mode so it can be restored afterwards.
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            # Wrap the loader itself (not enumerate) so tqdm can show a total.
            for input_ids, token_type_ids, attention_mask, labels in tqdm(validation_loader, desc='Dev Itreation:'):
                input_ids = input_ids.to(device)
                token_type_ids = token_type_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)
                out_put = model(input_ids, token_type_ids, attention_mask)
                _, predicted = torch.max(out_put, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        model.train(was_training)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0

# 定义训练函数 
def train(model, train_loader, validation_loader, total_epochs=2,
          save_path=r"E:\output\savedmodel\model_new.pkl"):
    """Fine-tune ``model`` and save the checkpoint with the best validation accuracy.

    Uses the notebook-level ``device``. Every 2 steps the running training
    accuracy is printed; every 50 steps the model is evaluated with ``dev``
    and, if the validation accuracy improved, the whole model object is
    saved to ``save_path``. The running training accuracy is cumulative over
    all epochs (the counters are not reset between epochs).

    Args:
        model: Module called as ``model(input_ids, token_type_ids, attention_mask)``.
        train_loader: Iterable of
            ``(input_ids, token_type_ids, attention_mask, labels)`` batches.
        validation_loader: Loader handed to ``dev`` for periodic validation.
        total_epochs: Number of passes over ``train_loader`` (default 2).
        save_path: Where the best model is written with ``torch.save``.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Apply weight decay to everything except biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer_params = {'lr': 1e-5, 'eps': 1e-6, 'correct_bias': False}
    optimizer = AdamW(optimizer_grouped_parameters, **optimizer_params)
    # Halve the learning rate when the monitored accuracy stops improving.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, min_lr=1e-7, patience=5,
                                  verbose=True, threshold=0.0001, eps=1e-08)

    bestAcc = 0
    correct = 0
    total = 0
    print('Training and verification begin!')
    for epoch in range(total_epochs):
        # (Re-)enter training mode at the start of every epoch.
        model.train()
        for step, (input_ids, token_type_ids, attention_mask, labels) in enumerate(train_loader):
            # Move the batch to the same device as the model. This statement
            # had been merged into a comment, so batches stayed on the CPU.
            input_ids = input_ids.to(device)
            token_type_ids = token_type_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            out_put = model(input_ids, token_type_ids, attention_mask)
            loss = criterion(out_put, labels)
            _, predict = torch.max(out_put.data, 1)
            correct += (predict == labels).sum().item()
            total += labels.size(0)
            loss.backward()
            optimizer.step()

            # Print progress every two steps.
            if (step + 1) % 2 == 0:
                train_acc = correct / total
                print("Train Epoch[{}/{}],step[{}/{}],tra_acc{:.6f} %,loss:{:.6f}".format(epoch + 1, total_epochs, step + 1, len(train_loader),train_acc*100,loss.item()))
            # Validate every fifty steps and keep the best checkpoint.
            if (step + 1) % 50 == 0:
                train_acc = correct / total
                acc = dev(model, validation_loader)
                # dev() evaluates in eval mode; make sure dropout is active
                # again before the next optimisation step.
                model.train()
                if bestAcc < acc:
                    bestAcc = acc
                    torch.save(model, save_path)
                print("DEV Epoch[{}/{}],step[{}/{}],tra_acc:{:.6f} %,bestAcc{:.6f}%,dev_acc{:.6f} %,loss:{:.6f}".format(epoch + 1, total_epochs, step + 1, len(train_loader),train_acc*100,bestAcc*100,acc*100,loss.item()))
        scheduler.step(bestAcc)



In [25]:
# Device configuration.
# Prefer GPU index 3 (the original choice) when the machine actually has at
# least four GPUs; otherwise fall back to the first GPU, or to the CPU when
# CUDA is unavailable. Selecting 'cuda:3' unconditionally fails on machines
# with fewer GPUs.
if torch.cuda.is_available():
    device = torch.device('cuda:3' if torch.cuda.device_count() > 3 else 'cuda:0')
else:
    device = torch.device('cpu')

In [26]:
# Instantiate the classification model (loads the pretrained BERT weights)
model = BertClassificationModel()

In [27]:
# Run training together with the periodic validation performed inside train()
train(model,train_loader,validation_loader)



Training and verification begin!
Train Epoch[1/2],step[2/600],tra_acc71.875000 %,loss:0.626144
Train Epoch[1/2],step[4/600],tra_acc62.500000 %,loss:0.821733
Train Epoch[1/2],step[6/600],tra_acc62.500000 %,loss:0.629073
Train Epoch[1/2],step[8/600],tra_acc53.906250 %,loss:0.802826
Train Epoch[1/2],step[10/600],tra_acc55.625000 %,loss:0.644059
Train Epoch[1/2],step[12/600],tra_acc55.729167 %,loss:0.701151
Train Epoch[1/2],step[14/600],tra_acc58.928571 %,loss:0.532708
Train Epoch[1/2],step[16/600],tra_acc62.500000 %,loss:0.494226
Train Epoch[1/2],step[18/600],tra_acc65.625000 %,loss:0.355855
Train Epoch[1/2],step[20/600],tra_acc68.437500 %,loss:0.334875
Train Epoch[1/2],step[22/600],tra_acc69.602273 %,loss:0.590296
Train Epoch[1/2],step[24/600],tra_acc69.010417 %,loss:0.763688
Train Epoch[1/2],step[26/600],tra_acc69.711538 %,loss:0.428184
Train Epoch[1/2],step[28/600],tra_acc70.535714 %,loss:0.604404
Train Epoch[1/2],step[30/600],tra_acc71.458333 %,loss:0.499765
Train Epoch[1/2],step[32/6

Dev Itreation:: 75it [04:13,  3.38s/it]


DEV Epoch[1/2],step[50/600],tra_acc:75.875000 %,bestAcc83.416667%,dev_acc83.416667 %,loss:0.334971
Train Epoch[1/2],step[52/600],tra_acc76.562500 %,loss:0.228233
Train Epoch[1/2],step[54/600],tra_acc76.851852 %,loss:0.172206
Train Epoch[1/2],step[56/600],tra_acc77.120536 %,loss:0.487175
Train Epoch[1/2],step[58/600],tra_acc77.478448 %,loss:0.319302
Train Epoch[1/2],step[60/600],tra_acc77.708333 %,loss:0.241817
Train Epoch[1/2],step[62/600],tra_acc78.024194 %,loss:0.419322
Train Epoch[1/2],step[64/600],tra_acc78.320312 %,loss:0.311719
Train Epoch[1/2],step[66/600],tra_acc78.977273 %,loss:0.098924
Train Epoch[1/2],step[68/600],tra_acc79.503676 %,loss:0.298042
Train Epoch[1/2],step[70/600],tra_acc79.910714 %,loss:0.130378
Train Epoch[1/2],step[72/600],tra_acc80.295139 %,loss:0.177881
Train Epoch[1/2],step[74/600],tra_acc80.743243 %,loss:0.027336
Train Epoch[1/2],step[76/600],tra_acc80.921053 %,loss:0.192691
Train Epoch[1/2],step[78/600],tra_acc80.769231 %,loss:0.922562
Train Epoch[1/2],st

Dev Itreation:: 75it [04:12,  3.36s/it]


DEV Epoch[1/2],step[100/600],tra_acc:82.375000 %,bestAcc87.500000%,dev_acc87.500000 %,loss:0.354603
Train Epoch[1/2],step[102/600],tra_acc82.414216 %,loss:0.435248
Train Epoch[1/2],step[104/600],tra_acc82.692308 %,loss:0.144677
Train Epoch[1/2],step[106/600],tra_acc82.783019 %,loss:0.411722
Train Epoch[1/2],step[108/600],tra_acc82.754630 %,loss:0.132808
Train Epoch[1/2],step[110/600],tra_acc82.954545 %,loss:0.212569
Train Epoch[1/2],step[112/600],tra_acc83.035714 %,loss:0.247810
Train Epoch[1/2],step[114/600],tra_acc83.333333 %,loss:0.153866
Train Epoch[1/2],step[116/600],tra_acc83.405172 %,loss:0.476932
Train Epoch[1/2],step[118/600],tra_acc83.474576 %,loss:0.117877
Train Epoch[1/2],step[120/600],tra_acc83.541667 %,loss:0.258853
Train Epoch[1/2],step[122/600],tra_acc83.452869 %,loss:0.564014
Train Epoch[1/2],step[124/600],tra_acc83.568548 %,loss:0.362664
Train Epoch[1/2],step[126/600],tra_acc83.779762 %,loss:0.180259
Train Epoch[1/2],step[128/600],tra_acc83.984375 %,loss:0.181857
Trai

Dev Itreation:: 75it [04:11,  3.35s/it]


DEV Epoch[1/2],step[150/600],tra_acc:84.791667 %,bestAcc89.000000%,dev_acc89.000000 %,loss:0.263572
Train Epoch[1/2],step[152/600],tra_acc84.868421 %,loss:0.332283
Train Epoch[1/2],step[154/600],tra_acc84.983766 %,loss:0.095542
Train Epoch[1/2],step[156/600],tra_acc85.096154 %,loss:0.298339
Train Epoch[1/2],step[158/600],tra_acc85.166139 %,loss:0.119378
Train Epoch[1/2],step[160/600],tra_acc85.312500 %,loss:0.162254
Train Epoch[1/2],step[162/600],tra_acc85.262346 %,loss:0.401413
Train Epoch[1/2],step[164/600],tra_acc85.251524 %,loss:0.306531
Train Epoch[1/2],step[166/600],tra_acc85.353916 %,loss:0.156823
Train Epoch[1/2],step[168/600],tra_acc85.416667 %,loss:0.403808
Train Epoch[1/2],step[170/600],tra_acc85.257353 %,loss:0.572626
Train Epoch[1/2],step[172/600],tra_acc85.174419 %,loss:0.580039
Train Epoch[1/2],step[174/600],tra_acc85.308908 %,loss:0.152496
Train Epoch[1/2],step[176/600],tra_acc85.404830 %,loss:0.240043
Train Epoch[1/2],step[178/600],tra_acc85.428371 %,loss:0.284264
Trai

Dev Itreation:: 75it [04:11,  3.35s/it]


DEV Epoch[1/2],step[200/600],tra_acc:85.781250 %,bestAcc89.500000%,dev_acc89.500000 %,loss:0.092837
Train Epoch[1/2],step[202/600],tra_acc85.829208 %,loss:0.334193
Train Epoch[1/2],step[204/600],tra_acc85.906863 %,loss:0.468750
Train Epoch[1/2],step[206/600],tra_acc85.891990 %,loss:0.468275
Train Epoch[1/2],step[208/600],tra_acc85.967548 %,loss:0.167955
Train Epoch[1/2],step[210/600],tra_acc85.982143 %,loss:0.412424
Train Epoch[1/2],step[212/600],tra_acc85.996462 %,loss:0.261865
Train Epoch[1/2],step[214/600],tra_acc85.952103 %,loss:0.304708
Train Epoch[1/2],step[216/600],tra_acc85.966435 %,loss:0.172056
Train Epoch[1/2],step[218/600],tra_acc85.951835 %,loss:0.479677
Train Epoch[1/2],step[220/600],tra_acc85.994318 %,loss:0.105435
Train Epoch[1/2],step[222/600],tra_acc86.036036 %,loss:0.292577
Train Epoch[1/2],step[224/600],tra_acc86.104911 %,loss:0.121126
Train Epoch[1/2],step[226/600],tra_acc86.172566 %,loss:0.140068
Train Epoch[1/2],step[228/600],tra_acc86.239035 %,loss:0.143164
Trai

Dev Itreation:: 75it [04:11,  3.35s/it]


DEV Epoch[1/2],step[250/600],tra_acc:86.775000 %,bestAcc89.916667%,dev_acc89.916667 %,loss:0.198803
Train Epoch[1/2],step[252/600],tra_acc86.830357 %,loss:0.157916
Train Epoch[1/2],step[254/600],tra_acc86.860236 %,loss:0.348669
Train Epoch[1/2],step[256/600],tra_acc86.840820 %,loss:0.382891
Train Epoch[1/2],step[258/600],tra_acc86.845930 %,loss:0.337985
Train Epoch[1/2],step[260/600],tra_acc86.899038 %,loss:0.358330
Train Epoch[1/2],step[262/600],tra_acc86.951336 %,loss:0.075473
Train Epoch[1/2],step[264/600],tra_acc86.979167 %,loss:0.306531
Train Epoch[1/2],step[266/600],tra_acc87.006579 %,loss:0.340688
Train Epoch[1/2],step[268/600],tra_acc87.080224 %,loss:0.295960
Train Epoch[1/2],step[270/600],tra_acc87.037037 %,loss:0.629853
Train Epoch[1/2],step[272/600],tra_acc87.040441 %,loss:0.228393
Train Epoch[1/2],step[274/600],tra_acc87.066606 %,loss:0.161662
Train Epoch[1/2],step[276/600],tra_acc87.092391 %,loss:0.174432
Train Epoch[1/2],step[278/600],tra_acc87.072842 %,loss:0.167732
Trai

Dev Itreation:: 75it [04:08,  3.32s/it]


DEV Epoch[1/2],step[300/600],tra_acc:87.312500 %,bestAcc89.916667%,dev_acc87.833333 %,loss:0.371084
Train Epoch[1/2],step[302/600],tra_acc87.334437 %,loss:0.413331
Train Epoch[1/2],step[304/600],tra_acc87.376645 %,loss:0.236385
Train Epoch[1/2],step[306/600],tra_acc87.377451 %,loss:0.401605
Train Epoch[1/2],step[308/600],tra_acc87.418831 %,loss:0.070097
Train Epoch[1/2],step[310/600],tra_acc87.459677 %,loss:0.059098
Train Epoch[1/2],step[312/600],tra_acc87.479968 %,loss:0.425469
Train Epoch[1/2],step[314/600],tra_acc87.539809 %,loss:0.169715
Train Epoch[1/2],step[316/600],tra_acc87.579114 %,loss:0.267189
Train Epoch[1/2],step[318/600],tra_acc87.598270 %,loss:0.093757
Train Epoch[1/2],step[320/600],tra_acc87.617188 %,loss:0.087927
Train Epoch[1/2],step[322/600],tra_acc87.616460 %,loss:0.467100
Train Epoch[1/2],step[324/600],tra_acc87.615741 %,loss:0.364666
Train Epoch[1/2],step[326/600],tra_acc87.672546 %,loss:0.201078
Train Epoch[1/2],step[328/600],tra_acc87.690549 %,loss:0.235306
Trai

Dev Itreation:: 75it [04:07,  3.30s/it]


DEV Epoch[1/2],step[350/600],tra_acc:87.875000 %,bestAcc90.500000%,dev_acc90.500000 %,loss:0.250373
Train Epoch[1/2],step[352/600],tra_acc87.872869 %,loss:0.229697
Train Epoch[1/2],step[354/600],tra_acc87.888418 %,loss:0.313238
Train Epoch[1/2],step[356/600],tra_acc87.938904 %,loss:0.050853
Train Epoch[1/2],step[358/600],tra_acc87.971369 %,loss:0.060678
Train Epoch[1/2],step[360/600],tra_acc87.968750 %,loss:0.331945
Train Epoch[1/2],step[362/600],tra_acc87.966160 %,loss:0.095914
Train Epoch[1/2],step[364/600],tra_acc87.980769 %,loss:0.312719
Train Epoch[1/2],step[366/600],tra_acc87.995219 %,loss:0.202674
Train Epoch[1/2],step[368/600],tra_acc88.026495 %,loss:0.314659
Train Epoch[1/2],step[370/600],tra_acc88.074324 %,loss:0.153767
Train Epoch[1/2],step[372/600],tra_acc88.121640 %,loss:0.063775
Train Epoch[1/2],step[374/600],tra_acc88.151738 %,loss:0.358274
Train Epoch[1/2],step[376/600],tra_acc88.148271 %,loss:0.504370
Train Epoch[1/2],step[378/600],tra_acc88.210979 %,loss:0.085268
Trai

Dev Itreation:: 75it [03:59,  3.20s/it]


DEV Epoch[1/2],step[400/600],tra_acc:88.406250 %,bestAcc90.500000%,dev_acc90.333333 %,loss:0.397944
Train Epoch[1/2],step[402/600],tra_acc88.417289 %,loss:0.084185
Train Epoch[1/2],step[404/600],tra_acc88.443688 %,loss:0.193778
Train Epoch[1/2],step[406/600],tra_acc88.454433 %,loss:0.126475
Train Epoch[1/2],step[408/600],tra_acc88.449755 %,loss:0.221054
Train Epoch[1/2],step[410/600],tra_acc88.475610 %,loss:0.071736
Train Epoch[1/2],step[412/600],tra_acc88.486044 %,loss:0.139626
Train Epoch[1/2],step[414/600],tra_acc88.466184 %,loss:0.454997
Train Epoch[1/2],step[416/600],tra_acc88.491587 %,loss:0.260643
Train Epoch[1/2],step[418/600],tra_acc88.486842 %,loss:0.356250
Train Epoch[1/2],step[420/600],tra_acc88.511905 %,loss:0.348447
Train Epoch[1/2],step[422/600],tra_acc88.551540 %,loss:0.079849
Train Epoch[1/2],step[424/600],tra_acc88.590802 %,loss:0.152501
Train Epoch[1/2],step[426/600],tra_acc88.615023 %,loss:0.284456
Train Epoch[1/2],step[428/600],tra_acc88.609813 %,loss:0.272826
Trai

Dev Itreation:: 75it [04:00,  3.21s/it]


DEV Epoch[1/2],step[450/600],tra_acc:88.819444 %,bestAcc91.833333%,dev_acc91.833333 %,loss:0.133069
Train Epoch[1/2],step[452/600],tra_acc88.813606 %,loss:0.453514
Train Epoch[1/2],step[454/600],tra_acc88.794053 %,loss:0.426684
Train Epoch[1/2],step[456/600],tra_acc88.774671 %,loss:0.159064
Train Epoch[1/2],step[458/600],tra_acc88.810044 %,loss:0.086221
Train Epoch[1/2],step[460/600],tra_acc88.845109 %,loss:0.086737
Train Epoch[1/2],step[462/600],tra_acc88.866342 %,loss:0.065203
Train Epoch[1/2],step[464/600],tra_acc88.900862 %,loss:0.158056
Train Epoch[1/2],step[466/600],tra_acc88.921674 %,loss:0.230441
Train Epoch[1/2],step[468/600],tra_acc88.955662 %,loss:0.134935
Train Epoch[1/2],step[470/600],tra_acc89.002660 %,loss:0.085640
Train Epoch[1/2],step[472/600],tra_acc88.996292 %,loss:0.503653
Train Epoch[1/2],step[474/600],tra_acc88.989979 %,loss:0.442141
Train Epoch[1/2],step[476/600],tra_acc89.009979 %,loss:0.399784
Train Epoch[1/2],step[478/600],tra_acc89.029812 %,loss:0.090952
Trai

Dev Itreation:: 75it [03:57,  3.17s/it]


DEV Epoch[1/2],step[500/600],tra_acc:89.162500 %,bestAcc92.166667%,dev_acc92.166667 %,loss:0.067616
Train Epoch[1/2],step[502/600],tra_acc89.143426 %,loss:0.346785
Train Epoch[1/2],step[504/600],tra_acc89.149306 %,loss:0.404289
Train Epoch[1/2],step[506/600],tra_acc89.155138 %,loss:0.136020
Train Epoch[1/2],step[508/600],tra_acc89.148622 %,loss:0.651542
Train Epoch[1/2],step[510/600],tra_acc89.166667 %,loss:0.138091
Train Epoch[1/2],step[512/600],tra_acc89.196777 %,loss:0.122743
Train Epoch[1/2],step[514/600],tra_acc89.214494 %,loss:0.191869
Train Epoch[1/2],step[516/600],tra_acc89.183624 %,loss:0.360739
Train Epoch[1/2],step[518/600],tra_acc89.201255 %,loss:0.317472
Train Epoch[1/2],step[520/600],tra_acc89.218750 %,loss:0.127023
Train Epoch[1/2],step[522/600],tra_acc89.236111 %,loss:0.371223
Train Epoch[1/2],step[524/600],tra_acc89.241412 %,loss:0.301351
Train Epoch[1/2],step[526/600],tra_acc89.246673 %,loss:0.148440
Train Epoch[1/2],step[528/600],tra_acc89.240057 %,loss:0.204783
Trai

Dev Itreation:: 75it [03:59,  3.20s/it]


DEV Epoch[1/2],step[550/600],tra_acc:89.340909 %,bestAcc92.166667%,dev_acc92.083333 %,loss:0.076465
Train Epoch[1/2],step[552/600],tra_acc89.311594 %,loss:0.241284
Train Epoch[1/2],step[554/600],tra_acc89.327617 %,loss:0.396355
Train Epoch[1/2],step[556/600],tra_acc89.366007 %,loss:0.105739
Train Epoch[1/2],step[558/600],tra_acc89.370520 %,loss:0.117944
Train Epoch[1/2],step[560/600],tra_acc89.352679 %,loss:0.313359
Train Epoch[1/2],step[562/600],tra_acc89.368327 %,loss:0.182034
Train Epoch[1/2],step[564/600],tra_acc89.383865 %,loss:0.132844
Train Epoch[1/2],step[566/600],tra_acc89.333039 %,loss:0.374160
Train Epoch[1/2],step[568/600],tra_acc89.337588 %,loss:0.132997
Train Epoch[1/2],step[570/600],tra_acc89.364035 %,loss:0.065078
Train Epoch[1/2],step[572/600],tra_acc89.368444 %,loss:0.218899
Train Epoch[1/2],step[574/600],tra_acc89.405488 %,loss:0.100511
Train Epoch[1/2],step[576/600],tra_acc89.409722 %,loss:0.082217
Train Epoch[1/2],step[578/600],tra_acc89.413927 %,loss:0.105598
Trai

Dev Itreation:: 75it [03:58,  3.18s/it]


DEV Epoch[1/2],step[600/600],tra_acc:89.593750 %,bestAcc92.916667%,dev_acc92.916667 %,loss:0.316182
Train Epoch[2/2],step[2/600],tra_acc89.607558 %,loss:0.112155
Train Epoch[2/2],step[4/600],tra_acc89.631623 %,loss:0.042632
Train Epoch[2/2],step[6/600],tra_acc89.655528 %,loss:0.110118
Train Epoch[2/2],step[8/600],tra_acc89.658717 %,loss:0.136336
Train Epoch[2/2],step[10/600],tra_acc89.661885 %,loss:0.049295
Train Epoch[2/2],step[12/600],tra_acc89.675245 %,loss:0.332547
Train Epoch[2/2],step[14/600],tra_acc89.708876 %,loss:0.030710
Train Epoch[2/2],step[16/600],tra_acc89.742289 %,loss:0.067195
Train Epoch[2/2],step[18/600],tra_acc89.775485 %,loss:0.019404
Train Epoch[2/2],step[20/600],tra_acc89.808468 %,loss:0.026674
Train Epoch[2/2],step[22/600],tra_acc89.841238 %,loss:0.015758
Train Epoch[2/2],step[24/600],tra_acc89.863782 %,loss:0.110879
Train Epoch[2/2],step[26/600],tra_acc89.896166 %,loss:0.028290
Train Epoch[2/2],step[28/600],tra_acc89.918392 %,loss:0.096047
Train Epoch[2/2],step[

Dev Itreation:: 75it [03:58,  3.19s/it]


DEV Epoch[2/2],step[50/600],tra_acc:90.086538 %,bestAcc93.166667%,dev_acc93.166667 %,loss:0.095459
Train Epoch[2/2],step[52/600],tra_acc90.097776 %,loss:0.176908
Train Epoch[2/2],step[54/600],tra_acc90.118502 %,loss:0.122987
Train Epoch[2/2],step[56/600],tra_acc90.139101 %,loss:0.103460
Train Epoch[2/2],step[58/600],tra_acc90.169073 %,loss:0.068835
Train Epoch[2/2],step[60/600],tra_acc90.189394 %,loss:0.098000
Train Epoch[2/2],step[62/600],tra_acc90.209592 %,loss:0.079552
Train Epoch[2/2],step[64/600],tra_acc90.220256 %,loss:0.178233
Train Epoch[2/2],step[66/600],tra_acc90.212087 %,loss:0.608726
Train Epoch[2/2],step[68/600],tra_acc90.232036 %,loss:0.264960
Train Epoch[2/2],step[70/600],tra_acc90.242537 %,loss:0.184145
Train Epoch[2/2],step[72/600],tra_acc90.252976 %,loss:0.181341
Train Epoch[2/2],step[74/600],tra_acc90.254080 %,loss:0.255055
Train Epoch[2/2],step[76/600],tra_acc90.282914 %,loss:0.032299
Train Epoch[2/2],step[78/600],tra_acc90.283923 %,loss:0.065126
Train Epoch[2/2],st

Dev Itreation:: 75it [03:57,  3.17s/it]


DEV Epoch[2/2],step[100/600],tra_acc:90.464286 %,bestAcc93.166667%,dev_acc91.500000 %,loss:0.166085
Train Epoch[2/2],step[102/600],tra_acc90.491453 %,loss:0.079678
Train Epoch[2/2],step[104/600],tra_acc90.518466 %,loss:0.027104
Train Epoch[2/2],step[106/600],tra_acc90.527620 %,loss:0.114103
Train Epoch[2/2],step[108/600],tra_acc90.554379 %,loss:0.010106
Train Epoch[2/2],step[110/600],tra_acc90.580986 %,loss:0.005324
Train Epoch[2/2],step[112/600],tra_acc90.598666 %,loss:0.249382
Train Epoch[2/2],step[114/600],tra_acc90.598739 %,loss:0.094185
Train Epoch[2/2],step[116/600],tra_acc90.590084 %,loss:0.009122
Train Epoch[2/2],step[118/600],tra_acc90.607591 %,loss:0.060076
Train Epoch[2/2],step[120/600],tra_acc90.607639 %,loss:0.381183
Train Epoch[2/2],step[122/600],tra_acc90.599030 %,loss:0.096547
Train Epoch[2/2],step[124/600],tra_acc90.625000 %,loss:0.086935
Train Epoch[2/2],step[126/600],tra_acc90.625000 %,loss:0.280258
Train Epoch[2/2],step[128/600],tra_acc90.650755 %,loss:0.037220
Trai

Dev Itreation:: 75it [03:58,  3.19s/it]


DEV Epoch[2/2],step[150/600],tra_acc:90.800000 %,bestAcc93.166667%,dev_acc92.916667 %,loss:0.084780
Train Epoch[2/2],step[152/600],tra_acc90.791223 %,loss:0.215840
Train Epoch[2/2],step[154/600],tra_acc90.807361 %,loss:0.051071
Train Epoch[2/2],step[156/600],tra_acc90.806878 %,loss:0.109129
Train Epoch[2/2],step[158/600],tra_acc90.822889 %,loss:0.021396
Train Epoch[2/2],step[160/600],tra_acc90.838816 %,loss:0.055829
Train Epoch[2/2],step[162/600],tra_acc90.846457 %,loss:0.192246
Train Epoch[2/2],step[164/600],tra_acc90.870419 %,loss:0.026778
Train Epoch[2/2],step[166/600],tra_acc90.894256 %,loss:0.072874
Train Epoch[2/2],step[168/600],tra_acc90.901693 %,loss:0.119240
Train Epoch[2/2],step[170/600],tra_acc90.909091 %,loss:0.334063
Train Epoch[2/2],step[172/600],tra_acc90.908355 %,loss:0.107063
Train Epoch[2/2],step[174/600],tra_acc90.931848 %,loss:0.050518
Train Epoch[2/2],step[176/600],tra_acc90.939111 %,loss:0.117987
Train Epoch[2/2],step[178/600],tra_acc90.938303 %,loss:0.083407
Trai

Dev Itreation:: 75it [03:59,  3.19s/it]


DEV Epoch[2/2],step[200/600],tra_acc:91.093750 %,bestAcc93.166667%,dev_acc92.666667 %,loss:0.255332
Train Epoch[2/2],step[202/600],tra_acc91.100374 %,loss:0.079110
Train Epoch[2/2],step[204/600],tra_acc91.106965 %,loss:0.073896
Train Epoch[2/2],step[206/600],tra_acc91.121278 %,loss:0.019093
Train Epoch[2/2],step[208/600],tra_acc91.135520 %,loss:0.249191
Train Epoch[2/2],step[210/600],tra_acc91.134259 %,loss:0.414974
Train Epoch[2/2],step[212/600],tra_acc91.156096 %,loss:0.019180
Train Epoch[2/2],step[214/600],tra_acc91.177826 %,loss:0.010200
Train Epoch[2/2],step[216/600],tra_acc91.199449 %,loss:0.011773
Train Epoch[2/2],step[218/600],tra_acc91.213325 %,loss:0.046600
Train Epoch[2/2],step[220/600],tra_acc91.219512 %,loss:0.211401
Train Epoch[2/2],step[222/600],tra_acc91.240876 %,loss:0.036961
Train Epoch[2/2],step[224/600],tra_acc91.254551 %,loss:0.228608
Train Epoch[2/2],step[226/600],tra_acc91.275726 %,loss:0.016616
Train Epoch[2/2],step[228/600],tra_acc91.289251 %,loss:0.008761
Trai

Dev Itreation:: 75it [03:58,  3.18s/it]


DEV Epoch[2/2],step[250/600],tra_acc:91.426471 %,bestAcc93.583333%,dev_acc93.583333 %,loss:0.305708
Train Epoch[2/2],step[252/600],tra_acc91.439261 %,loss:0.028556
Train Epoch[2/2],step[254/600],tra_acc91.422717 %,loss:0.282940
Train Epoch[2/2],step[256/600],tra_acc91.442757 %,loss:0.060341
Train Epoch[2/2],step[258/600],tra_acc91.448135 %,loss:0.031676
Train Epoch[2/2],step[260/600],tra_acc91.460756 %,loss:0.016549
Train Epoch[2/2],step[262/600],tra_acc91.480568 %,loss:0.052185
Train Epoch[2/2],step[264/600],tra_acc91.478588 %,loss:0.087842
Train Epoch[2/2],step[266/600],tra_acc91.483834 %,loss:0.285417
Train Epoch[2/2],step[268/600],tra_acc91.489055 %,loss:0.120830
Train Epoch[2/2],step[270/600],tra_acc91.501437 %,loss:0.168090
Train Epoch[2/2],step[272/600],tra_acc91.520929 %,loss:0.044661
Train Epoch[2/2],step[274/600],tra_acc91.518879 %,loss:0.199233
Train Epoch[2/2],step[276/600],tra_acc91.538242 %,loss:0.037744
Train Epoch[2/2],step[278/600],tra_acc91.557517 %,loss:0.127501
Trai

Dev Itreation:: 75it [03:58,  3.17s/it]


DEV Epoch[2/2],step[300/600],tra_acc:91.618056 %,bestAcc93.583333%,dev_acc92.583333 %,loss:0.060921
Train Epoch[2/2],step[302/600],tra_acc91.629712 %,loss:0.046311
Train Epoch[2/2],step[304/600],tra_acc91.648230 %,loss:0.027955
Train Epoch[2/2],step[306/600],tra_acc91.659768 %,loss:0.070474
Train Epoch[2/2],step[308/600],tra_acc91.678139 %,loss:0.013178
Train Epoch[2/2],step[310/600],tra_acc91.689560 %,loss:0.043633
Train Epoch[2/2],step[312/600],tra_acc91.694079 %,loss:0.169116
Train Epoch[2/2],step[314/600],tra_acc91.705416 %,loss:0.030090
Train Epoch[2/2],step[316/600],tra_acc91.716703 %,loss:0.034378
Train Epoch[2/2],step[318/600],tra_acc91.727941 %,loss:0.298305
Train Epoch[2/2],step[320/600],tra_acc91.732337 %,loss:0.009025
Train Epoch[2/2],step[322/600],tra_acc91.750271 %,loss:0.013562
Train Epoch[2/2],step[324/600],tra_acc91.761364 %,loss:0.316607
Train Epoch[2/2],step[326/600],tra_acc91.765659 %,loss:0.124447
Train Epoch[2/2],step[328/600],tra_acc91.783405 %,loss:0.012186
Trai

Dev Itreation:: 75it [03:58,  3.19s/it]


DEV Epoch[2/2],step[350/600],tra_acc:91.914474 %,bestAcc93.583333%,dev_acc92.583333 %,loss:0.015222
Train Epoch[2/2],step[352/600],tra_acc91.918330 %,loss:0.168031
Train Epoch[2/2],step[354/600],tra_acc91.928721 %,loss:0.031676
Train Epoch[2/2],step[356/600],tra_acc91.945607 %,loss:0.015913
Train Epoch[2/2],step[358/600],tra_acc91.962422 %,loss:0.018794
Train Epoch[2/2],step[360/600],tra_acc91.966146 %,loss:0.147029
Train Epoch[2/2],step[362/600],tra_acc91.976351 %,loss:0.060557
Train Epoch[2/2],step[364/600],tra_acc91.992998 %,loss:0.043504
Train Epoch[2/2],step[366/600],tra_acc91.996636 %,loss:0.081367
Train Epoch[2/2],step[368/600],tra_acc92.006715 %,loss:0.024791
Train Epoch[2/2],step[370/600],tra_acc92.016753 %,loss:0.089874
Train Epoch[2/2],step[372/600],tra_acc92.033179 %,loss:0.038698
Train Epoch[2/2],step[374/600],tra_acc92.030287 %,loss:0.003802
Train Epoch[2/2],step[376/600],tra_acc92.046619 %,loss:0.029959
Train Epoch[2/2],step[378/600],tra_acc92.062883 %,loss:0.005137
Trai

Dev Itreation:: 75it [03:58,  3.18s/it]


DEV Epoch[2/2],step[400/600],tra_acc:92.150000 %,bestAcc93.583333%,dev_acc93.416667 %,loss:0.021796
Train Epoch[2/2],step[402/600],tra_acc92.165669 %,loss:0.051693
Train Epoch[2/2],step[404/600],tra_acc92.181275 %,loss:0.022649
Train Epoch[2/2],step[406/600],tra_acc92.184394 %,loss:0.217616
Train Epoch[2/2],step[408/600],tra_acc92.199901 %,loss:0.065676
Train Epoch[2/2],step[410/600],tra_acc92.196782 %,loss:0.210993
Train Epoch[2/2],step[412/600],tra_acc92.199852 %,loss:0.126941
Train Epoch[2/2],step[414/600],tra_acc92.215237 %,loss:0.022777
Train Epoch[2/2],step[416/600],tra_acc92.230561 %,loss:0.040117
Train Epoch[2/2],step[418/600],tra_acc92.245825 %,loss:0.026756
Train Epoch[2/2],step[420/600],tra_acc92.242647 %,loss:0.145742
Train Epoch[2/2],step[422/600],tra_acc92.245597 %,loss:0.096700
Train Epoch[2/2],step[424/600],tra_acc92.248535 %,loss:0.396861
Train Epoch[2/2],step[426/600],tra_acc92.257554 %,loss:0.151851
Train Epoch[2/2],step[428/600],tra_acc92.266537 %,loss:0.006214
Trai

Dev Itreation:: 75it [03:59,  3.19s/it]


DEV Epoch[2/2],step[450/600],tra_acc:92.309524 %,bestAcc93.666667%,dev_acc93.666667 %,loss:0.043823
Train Epoch[2/2],step[452/600],tra_acc92.324144 %,loss:0.043059
Train Epoch[2/2],step[454/600],tra_acc92.314991 %,loss:0.214980
Train Epoch[2/2],step[456/600],tra_acc92.323627 %,loss:0.089032
Train Epoch[2/2],step[458/600],tra_acc92.332231 %,loss:0.026956
Train Epoch[2/2],step[460/600],tra_acc92.340802 %,loss:0.071582
Train Epoch[2/2],step[462/600],tra_acc92.331685 %,loss:0.400912
Train Epoch[2/2],step[464/600],tra_acc92.346100 %,loss:0.012798
Train Epoch[2/2],step[466/600],tra_acc92.360460 %,loss:0.065987
Train Epoch[2/2],step[468/600],tra_acc92.368914 %,loss:0.095042
Train Epoch[2/2],step[470/600],tra_acc92.377336 %,loss:0.035781
Train Epoch[2/2],step[472/600],tra_acc92.374067 %,loss:0.798726
Train Epoch[2/2],step[474/600],tra_acc92.382449 %,loss:0.311793
Train Epoch[2/2],step[476/600],tra_acc92.379182 %,loss:0.500892
Train Epoch[2/2],step[478/600],tra_acc92.393321 %,loss:0.042323
Trai

Dev Itreation:: 75it [03:59,  3.19s/it]


DEV Epoch[2/2],step[500/600],tra_acc:92.500000 %,bestAcc93.666667%,dev_acc93.166667 %,loss:0.228871
Train Epoch[2/2],step[502/600],tra_acc92.490926 %,loss:0.240640
Train Epoch[2/2],step[504/600],tra_acc92.498868 %,loss:0.166946
Train Epoch[2/2],step[506/600],tra_acc92.506781 %,loss:0.099649
Train Epoch[2/2],step[508/600],tra_acc92.503384 %,loss:0.155181
Train Epoch[2/2],step[510/600],tra_acc92.500000 %,loss:0.167626
Train Epoch[2/2],step[512/600],tra_acc92.507869 %,loss:0.070509
Train Epoch[2/2],step[514/600],tra_acc92.515709 %,loss:0.027458
Train Epoch[2/2],step[516/600],tra_acc92.517921 %,loss:0.165443
Train Epoch[2/2],step[518/600],tra_acc92.520125 %,loss:0.124751
Train Epoch[2/2],step[520/600],tra_acc92.511161 %,loss:0.094959
Train Epoch[2/2],step[522/600],tra_acc92.524510 %,loss:0.075091
Train Epoch[2/2],step[524/600],tra_acc92.521130 %,loss:0.281581
Train Epoch[2/2],step[526/600],tra_acc92.523313 %,loss:0.085952
Train Epoch[2/2],step[528/600],tra_acc92.525488 %,loss:0.034318
Trai

Dev Itreation:: 75it [03:59,  3.19s/it]


DEV Epoch[2/2],step[550/600],tra_acc:92.608696 %,bestAcc93.666667%,dev_acc93.250000 %,loss:0.036323
Train Epoch[2/2],step[552/600],tra_acc92.621528 %,loss:0.020289
Train Epoch[2/2],step[554/600],tra_acc92.634315 %,loss:0.007628
Train Epoch[2/2],step[556/600],tra_acc92.641652 %,loss:0.060804
Train Epoch[2/2],step[558/600],tra_acc92.654361 %,loss:0.005590
Train Epoch[2/2],step[560/600],tra_acc92.650862 %,loss:0.197503
Train Epoch[2/2],step[562/600],tra_acc92.663511 %,loss:0.006199
Train Epoch[2/2],step[564/600],tra_acc92.670747 %,loss:0.142027
Train Epoch[2/2],step[566/600],tra_acc92.672599 %,loss:0.217294
Train Epoch[2/2],step[568/600],tra_acc92.685146 %,loss:0.006395
Train Epoch[2/2],step[570/600],tra_acc92.676282 %,loss:0.064290
Train Epoch[2/2],step[572/600],tra_acc92.688780 %,loss:0.009828
Train Epoch[2/2],step[574/600],tra_acc92.701235 %,loss:0.018620
Train Epoch[2/2],step[576/600],tra_acc92.703019 %,loss:0.060691
Train Epoch[2/2],step[578/600],tra_acc92.710102 %,loss:0.397717
Trai

Dev Itreation:: 75it [04:01,  3.22s/it]


DEV Epoch[2/2],step[600/600],tra_acc:92.781250 %,bestAcc94.166667%,dev_acc94.166667 %,loss:0.061845


In [1]:
# 定义预测函数
import torch.nn.functional as F
def predict(model,test_loader):
    """Run ``model`` over ``test_loader`` and collect predictions and probabilities.

    NOTE(review): this function is defined again in a later cell of the
    notebook; the later definition is the one in effect when the prediction
    script runs. Consider keeping only one copy.

    For every batch the predicted classes and the true labels are printed,
    and the overall accuracy is printed at the end. Uses the notebook-level
    ``device``.

    Args:
        model: Module called as ``model(input_ids, token_type_ids, attention_mask)``.
        test_loader: Iterable of
            ``(input_ids, token_type_ids, attention_mask, labels)`` batches.

    Returns:
        ``(predicts, predict_probs)``: the predicted class index of every
        sample, and the softmax probabilities of every sample as nested lists.
    """
    model.to(device)
    # Evaluation mode for inference.
    model.eval()
    predicts = []
    predict_probs = []
    correct = 0
    total = 0
    with torch.no_grad():
        for input_ids, token_type_ids, attention_mask, labels in test_loader:
            input_ids = input_ids.to(device)
            token_type_ids = token_type_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            out_put = model(input_ids, token_type_ids, attention_mask)
            # Named batch_pred so it does not shadow this function's own name.
            _, batch_pred = torch.max(out_put, 1)
            pre_numpy = batch_pred.cpu().numpy().tolist()
            print("pre_numpy:")
            print(pre_numpy)
            print("labels:")
            print(labels)
            predicts.extend(pre_numpy)
            probs = F.softmax(out_put, dim=1).detach().cpu().numpy().tolist()
            predict_probs.extend(probs)
            correct += (batch_pred == labels).sum().item()
            total += labels.size(0)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    res = correct / total if total else 0.0
    print('**************结果**************\npredict_Accuracy : {} %'.format(100 * res))
    # Return the predicted classes and their probabilities.
    return predicts,predict_probs


In [13]:
# 定义预测函数
import torch.nn.functional as F
def predict(model,test_loader):
    """Run ``model`` over ``test_loader`` and collect predictions and probabilities.

    For every batch the predicted classes and the true labels are printed,
    and the overall accuracy is printed at the end. Uses the notebook-level
    ``device``.

    Args:
        model: Module called as ``model(input_ids, token_type_ids, attention_mask)``.
        test_loader: Iterable of
            ``(input_ids, token_type_ids, attention_mask, labels)`` batches.

    Returns:
        ``(predicts, predict_probs)``: the predicted class index of every
        sample, and the softmax probabilities of every sample as nested lists.
    """
    model.to(device)
    # Evaluation mode for inference.
    model.eval()
    predicts = []
    predict_probs = []
    correct = 0
    total = 0
    with torch.no_grad():
        for input_ids, token_type_ids, attention_mask, labels in test_loader:
            input_ids = input_ids.to(device)
            token_type_ids = token_type_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            out_put = model(input_ids, token_type_ids, attention_mask)
            # Named batch_pred so it does not shadow this function's own name.
            _, batch_pred = torch.max(out_put, 1)
            pre_numpy = batch_pred.cpu().numpy().tolist()
            print("pre_numpy:")
            print(pre_numpy)
            print("labels:")
            print(labels)
            predicts.extend(pre_numpy)
            probs = F.softmax(out_put, dim=1).detach().cpu().numpy().tolist()
            predict_probs.extend(probs)
            correct += (batch_pred == labels).sum().item()
            total += labels.size(0)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    res = correct / total if total else 0.0
    print('**************结果**************\npredict_Accuracy : {} %'.format(100 * res))
    # Return the predicted classes and their probabilities.
    return predicts,predict_probs

# Predict with the trained model.
# 1. Load the test split (no shuffling so outputs line up with the dataset order).
dataset_test = dataset['test']
dataset_test_ts = load_data(dataset_test)
test_loader = DataLoader(dataset=dataset_test_ts, batch_size=batch_size, shuffle=False)
# 2. Load the trained model.
path = r'E:\output\savedmodel\model_new.pkl'
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on one device still loads where that device is absent.
# SECURITY: the file holds a fully pickled model object; unpickling executes
# code, so only load checkpoints you created yourself.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects fully pickled modules; pass
# weights_only=False there if loading fails - confirm against the installed version.
Trained_model = torch.load(path, map_location=device)
# 3. Start predicting.
print("The prediction start !\n**************说明**************\n一、1代表正向情感，0代表负面情感。\n二、预测值代表“训练好的模型对该句子所预测的情感”，对应变量pre_numpy。\n三、真实值代表“实际上该句子所表达的情感”，对应变量labels。\n********************************")
# predicts holds the predicted classes (0 or 1); predict_probs holds the class probabilities.
predicts,predict_probs = predict(Trained_model,test_loader)
# Inspect `predicts` / `predict_probs` in a separate cell if needed.


The prediction start !
**************说明**************
一、1代表正向情感，0代表负面情感。
二、预测值代表“训练好的模型对该句子所预测的情感”，对应变量pre_numpy。
三、真实值代表“实际上该句子所表达的情感”，对应变量labels。
********************************
pre_numpy:
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]
labels:
tensor([1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1])
pre_numpy:
[0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0]
labels:
tensor([0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0])
pre_numpy:
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0]
labels:
tensor([1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0])
pre_numpy:
[1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1]
labels:
tensor([1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1])
pre_numpy:
[0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0]
labels:
tensor([0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0])
pre_numpy:
[0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0]
labels:
tensor([0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0])
pre_numpy:
[0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0]
labels:
t