In [1]:
!pip install -U evaluate nltk rouge_score absl-py



In [1]:
import evaluate
import torch
import transformers
from typing import Dict
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType, PeftModel,PeftConfig
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from datasets import load_dataset, load_from_disk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_id= "../base/qwen/Qwen2-0_5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16, load_in_8bit=True, trust_remote_code=True)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


In [3]:
data = load_from_disk("../outputs/final/Qwen2-0_5B-instruct-lora/eval_data")

In [4]:
refs = data["output"]
messages = data.map(lambda example: {"message": [{"role": "system", "content": "为下面的新闻生成摘要"},{"role": "user", "content": example["input"]}]})["message"]

In [5]:
preds = []
for message in messages:
    text = tokenizer.apply_chat_template(
        message,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    preds.append(response)

In [12]:
import sys
sys.path.append("..")
from metrics.bleu.bleu import Bleu
from metrics.rouge.rouge import Rouge

# Calculate BLEU and ROUGE
rouge = evaluate.load("../metrics/rouge")
bleu = Bleu()
print("load done")
result_rouge = rouge.compute(predictions=preds, references=refs)
result_bleu = bleu.compute(predictions=preds, references=refs)

print("ROUGE:", result_rouge)
print("BLEU:", result_bleu)

load done
ROUGE: {'rouge1': 0.06666666666666667, 'rouge2': 0.0, 'rougeL': 0.06666666666666667, 'rougeLsum': 0.06666666666666667}
BLEU: {'bleu': 0.0, 'precisions': [0.0, 0.0, 0.0, 0.0], 'brevity_penalty': 1.0, 'length_ratio': 1.8666666666666667, 'translation_length': 28, 'reference_length': 15}


In [10]:
preds[:5]


['瑞信亚太区私人银行董事总经理、大中华区副主席陶冬对经济前景持乐观态度。预计经济增长率在3.4%-4%，明年初将出现小幅调整。未来半年内，美元仍不会大幅走强，并且短期内不太可能有大的行情，整体上美元汇率将处于90到93的震荡区间。',
 '优步宣布推出新订阅服务 All-Access Plan 初始覆盖洛杉矶、奥斯汀、奥兰多、丹佛和迈阿密',
 '随着互联网存款产品业务全面停摆，多家中小银行开始与互联网平台开展金融服务合作。一些银行希望通过将互联网存款产品引入零售业务智能化线上化服务，以及利用合规性手段推进脱敏化的数据合作，减轻监管压力，以期顺利转为传统意义上的个人理财代理业务。',
 '美国在特朗普就任美国总统后，宣布对中国发起贸易战，中、美关系再次面临新的挑战。',
 '国内女科高危感染防控新战线开炮：国产四价HPV疫苗有望在2019年上市']

In [13]:
refs[:5]

['陶冬：全球经济正经历发展小顶峰，缓慢加息是美国经济主旋律',
 '优步推出订阅服务 价值14.99美元',
 '部分银行希冀将互联网存款产品销售纳入其业务范畴',
 '中国之声：一纸任性征税清单 三轮经贸磋商归零',
 '智飞生物：四价HPV疫苗存在“抢苗”现象 未来供应不成问题']