In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration
from data_process import text_preprocess, build_dataset, get_sentence_prediction
from rouge import Rouge
from tqdm import tqdm
import pickle
import time
import datetime
from config import config
config = config()

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'BertTokenizer'.


In [2]:
def predict(test_dataset, model, batch_size, is_test=True, device=None):
    """Run ``model`` over ``test_dataset`` and collect the decoded sentences.

    Every tensor of a batch is split into two halves along dim 1; each half is
    fed through the model on its own and the two logit tensors are joined back
    along dim 1 before being handed to ``get_sentence_prediction``.

    Args:
        test_dataset: dataset accepted by ``torch.utils.data.DataLoader`` whose
            batches are dicts of tensors that can be passed to the model as
            keyword arguments (a ``'labels'`` entry is required when
            ``is_test`` is true).
        model: model returning a mapping with ``'logits'`` (and ``'loss'``
            when ``is_test`` is true).
        batch_size: number of samples per batch for the ``DataLoader``.
        is_test: when true, labels are available: loss and token accuracy are
            accumulated and printed, and target sentences are returned too.
            When false, only predictions are decoded.
        device: ``torch.device`` to run on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        ``(sentence_pred, sentence_target)``; ``sentence_target`` stays an
        empty list when ``is_test`` is false.
    """
    if device is None:
        # Previously an implicit global defined in a later notebook cell.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    test_iter = DataLoader(test_dataset, batch_size, shuffle=False)

    model = model.to(device)
    model.eval()

    sentence_target = []
    sentence_pred = []
    epoch_loss_test, epoch_acc_test = 0.0, 0.0
    start_time = time.time()
    with torch.no_grad():
        for inputs in tqdm(test_iter, desc="testing: "):
            # Split every input tensor into two halves along dim 1.
            inputs_1, inputs_2 = {}, {}
            for name in inputs.keys():
                first, second = torch.chunk(inputs[name].long(), chunks=2, dim=1)
                inputs_1[name] = first.contiguous().to(device)
                inputs_2[name] = second.contiguous().to(device)

            # One forward pass per half; logits are moved to the CPU right away.
            loss_sum = 0.0
            logits_parts = []
            for inputs_ in (inputs_1, inputs_2):
                outputs_dict = model(**inputs_)  # logits: [batch, seq_len, vocab_size] per original note
                if is_test:
                    loss_sum += outputs_dict['loss'].item()
                logits_parts.append(outputs_dict['logits'].cpu())

            # Rejoin along dim 1, the axis the inputs were chunked on. The old
            # inference branch joined on dim 0, stacking the halves as extra
            # batch rows instead of restoring each sequence.
            outputs = torch.cat(logits_parts, dim=1)

            if is_test:
                labels = inputs['labels'].cpu()
                # Actual size of this batch (the last one may be smaller);
                # kept separate from the ``batch_size`` argument.
                cur_batch, seq_len = outputs.size(0), outputs.size(1)

                # Per-sample fraction of positions where argmax == label,
                # summed over the batch.
                correct = (labels == outputs.argmax(dim=-1)).sum(dim=1)
                batch_acc_test = (correct.float() / seq_len).sum().item()

                # NOTE(review): the loss is the sum of the two half losses
                # divided by the batch size; kept as-is so reported numbers
                # stay comparable -- confirm this is the intended normalisation.
                epoch_loss_test += loss_sum / cur_batch
                epoch_acc_test += batch_acc_test / cur_batch

                batch_sentence_pred, batch_sentence_target = get_sentence_prediction(outputs, labels)
                sentence_pred.extend(batch_sentence_pred)
                sentence_target.extend(batch_sentence_target)
            else:
                batch_sentence_pred = get_sentence_prediction(outputs)
                sentence_pred.extend(batch_sentence_pred)

    # Print the summary metrics.
    if is_test:
        duration = str(datetime.timedelta(seconds=time.time() - start_time))[:7]
        # Guard against an empty dataset (zero batches).
        num_batches = max(len(test_iter), 1)
        print("Time: {} | test_loss: {:.3} | test_acc: {:.3}".format(
            duration, epoch_loss_test/num_batches, epoch_acc_test/num_batches))

    torch.cuda.empty_cache()
    return sentence_pred, sentence_target

In [3]:
# Run on the GPU when torch reports one as available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device:', device)

# True  -> build the dataset from the raw file at config.test_path.
# False -> load an already-built dataset pickled at config.train_dataset.
is_test = True

if not is_test:
    with open(config.train_dataset, 'rb') as f:
        test_dataset = pickle.load(f)
else:
    # Text cleaning.
    test_data = text_preprocess(config.test_path)
    # Build the dataset.
    test_dataset = build_dataset(test_data)

model = BartForConditionalGeneration.from_pretrained(config.model_name)

device: cuda


  0%|          | 0/1 [00:00<?, ?ba/s]

In [4]:
# Restore the weights stored under 'model_state_dict' in the saved checkpoint.
# NOTE(security): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
try:
    saved_model = torch.load("save_models/epoch_8.pkl", map_location=device)
    model.load_state_dict(saved_model['model_state_dict'])
except Exception as err:
    # Catch Exception rather than using a bare ``except`` so that
    # KeyboardInterrupt / SystemExit still propagate, and report the cause
    # instead of hiding it.
    print("load model state error!", repr(err))
    print("WARNING: evaluation below runs WITHOUT the checkpoint weights; "
          "the reported metrics do not describe the saved model.")

batch_size = 8
sentence_pred, sentence_target = predict(test_dataset, model, batch_size)

# Corpus-level ROUGE, averaged over all prediction/reference pairs.
rouge = Rouge()
scores = rouge.get_scores(sentence_pred, sentence_target, avg=True)
print(scores)

testing: 100%|██████████| 63/63 [01:12<00:00,  1.16s/it]


Time: 0:01:12 | test_loss: 0.363 | test_acc: 0.675
{'rouge-1': {'r': 0.7407703618808655, 'p': 0.8301389575131733, 'f': 0.7826470432601791}, 'rouge-2': {'r': 0.44291685246939977, 'p': 0.4246464920849807, 'f': 0.433483433469579}, 'rouge-l': {'r': 0.6710040981424625, 'p': 0.7518462901387658, 'f': 0.7088875704998725}}


In [11]:
# Keep the prediction/reference pairs whose ROUGE-2 recall is at least 0.6,
# printing the ROUGE-2 scores of every pair that is kept.
show_pred, show_target = [], []
for p, t in zip(sentence_pred, sentence_target):
    score = rouge.get_scores(p, t)[0]['rouge-2']
    if not score['r'] >= 0.6:
        continue
    show_pred.append(p)
    show_target.append(t)
    print(score)

{'r': 0.6697819314641744, 'p': 0.6304985337243402, 'f': 0.6495468227991256}
{'r': 0.6409495548961425, 'p': 0.5869565217391305, 'f': 0.6127659524564761}
{'r': 0.616519174041298, 'p': 0.5854341736694678, 'f': 0.6005747076470225}
{'r': 0.6205882352941177, 'p': 0.5844875346260388, 'f': 0.6019971419374401}
{'r': 0.6897590361445783, 'p': 0.6735294117647059, 'f': 0.6815476140483278}
{'r': 0.6233062330623306, 'p': 0.5882352941176471, 'f': 0.6052631528989266}
{'r': 0.640117994100295, 'p': 0.6061452513966481, 'f': 0.6226685746306883}
{'r': 0.6013986013986014, 'p': 0.581081081081081, 'f': 0.591065287097696}
{'r': 0.6026936026936027, 'p': 0.5507692307692308, 'f': 0.5755626959747625}
{'r': 0.6184738955823293, 'p': 0.5945945945945946, 'f': 0.6062992076003628}


In [30]:
# Print the third retained example (index 2) with all spaces stripped:
# the reference text first, then the model's prediction.
print(
    '"correct": ', show_target[2].replace(' ', ''),
    '\n"predicted": ', show_pred[2].replace(' ', ''),
)

"correct":  热刺开场115秒就取得领先，维尔通亨同贝尔踢墙配合后从左路突入禁区，最后在12码处的劲射打在埃文斯腿上偏转入近角。第8分钟，登贝莱传球，邓普西22码处左脚劲射偏出。1分钟后，桑德罗传球，列农突入禁区左侧的射门被费迪南德挡出底线。贝尔开出角球，考克尔远点头球攻门偏出。第12分钟，邓普西头球解围，范佩西禁区左侧的射门被前阿森纳队友加拉斯封堵。第21分钟，埃弗拉边路阻挡沃克尔犯规，贝尔任意球传中，维尔通亨8码处头球攻门高出。第23分钟，香川真司赢得禁区边缘外任意球，范佩西传射被贝尔挡出底线。热刺第32分钟扩大比分，登贝莱传球下半场，鲁尼替换吉格斯出场。曼联第51分钟扳回一城，范佩西分球右路，鲁尼传中，无人防守的纳尼前点小禁区边缘捅射破门，1-2。热刺第52分钟再度拉开差距，迪福摆脱费迪南德后直传，贝尔禁区左侧10码处劲射被林德加德勉强扑出，邓普西轻松打入空门，3-1。曼联第53分钟再度扳回一城，范佩西直传，沃克尔盯人不紧，香川真司拿球转身后11码处推射右下角入网，2-3。第55分钟，曼联解围角球，登贝莱传球，贝尔25码处左脚劲射偏出。2分钟后，埃弗拉传球，鲁尼禁区前劲射被弗里德尔没收。第61分钟，列农禁区前对香川真司犯规，鲁尼25码处任意球 
"predicted":  热刺第场仅秒取取得领先，维尔通亨前埃尔踢墙配合后突中路突破禁区，面后的12码处的射射打在费文斯手上偏转入网角。第8分钟，邓贝莱传球，邓普西禁码处劲脚劲射偏出。第分钟后，列德罗传球，列农禁入禁区左侧的射门被费迪南德挡出底线。第尔对出角球，拉克尔头点头球攻门偏出近第14分钟，埃普西传球摆围失埃佩西禁区边侧的射门被封曼森纳边员哈拉挡挡堵。第14分钟，埃弗拉左路对挡维克尔犯规被范尔25意球传中，维尔通亨头码处头球攻门偏出。第24分钟，埃川真司传得禁区前缘外任意球机维佩西直射被挡尔挡出。线。第刺第26分钟扳大比分，埃贝莱传球下半场，吉尼替换吉格斯出场。曼联第50分钟扩回一城，范佩西直球右路，鲁尼传中，纳人防守的纳尼小点8禁区边缘捅射入门，1-2。第刺第54分钟再度拉开差距，迪福背脱埃迪南德后传传，贝尔禁区左侧12码处射射被林德加德勉强扑出，邓普西小松打入空门，3-1。第联第54分钟再度拉平一城，鲁佩西直传，鲁克尔禁人不紧，香川真司禁球转身摆禁码处低射远下角入网，2-3。第59分钟，列联解围角