In [28]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load the fine-tuned checkpoint (judging by the directory name, an
# OpenLLaMA-3B fine-tune) and wrap it in a text-generation pipeline.
# `pipe` is the object the evaluation code calls to generate answers.
# NOTE(review): no device / dtype is passed to either loader, so this presumably
# runs in fp32 on CPU — if generation is slow, consider `torch_dtype=` /
# `device_map=` (confirm against the available hardware).
model_path = "./hospital_open_llama_3b"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.34it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [29]:
from datasets import load_dataset
from tqdm import tqdm
import json

# Load the held-out test file. With a plain `data_files` string, `datasets`
# exposes the file's rows under the "train" split name, hence split="train".
eval_dataset = load_dataset(
    "json",
    data_files="hospital_test.json",
    split="train",
)

def evaluate_position(sample, text_pipe=None, gen_kwargs=None) -> int:
    """Generate an answer for one chat sample and check its ``positions`` field.

    The first two entries of ``sample["messages"]`` are rendered with the
    tokenizer's chat template and used as the prompt. The third entry's
    ``content`` is the reference answer: a JSON object with ``"positions"`` and
    ``"explanation"`` keys. The model's completion is expected to be a JSON
    object with a ``"positions"`` key.

    Args:
        sample: Dataset row with a ``"messages"`` list of at least three
            ``{"content": ...}`` dicts.
        text_pipe: Text-generation pipeline to call. Defaults to the
            notebook-level ``pipe``.
        gen_kwargs: Optional overrides merged over the default generation
            settings (e.g. ``{"do_sample": False}`` for deterministic greedy
            decoding).

    Returns:
        1 if the predicted positions equal the reference positions, else 0.
        A completion that is not a JSON object containing ``"positions"``
        also scores 0.

    Raises:
        json.JSONDecodeError / KeyError: if the *reference* answer is malformed.
            This is a dataset problem, so it is not counted as a model error.
    """
    if text_pipe is None:
        text_pipe = pipe  # notebook-level pipeline
    tokenizer = text_pipe.tokenizer

    # Parse the reference first and outside the try/except below. A malformed
    # reference must not be silently scored as a model formatting failure.
    reference = json.loads(sample["messages"][2]["content"])
    reference_positions = reference["positions"]

    # Prompt = first two messages (presumably system + user; verify against data).
    prompt = tokenizer.apply_chat_template(
        sample["messages"][:2], tokenize=False, add_generation_prompt=True
    )
    call_kwargs = {
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.95,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if gen_kwargs:
        call_kwargs.update(gen_kwargs)
    outputs = text_pipe(prompt, **call_kwargs)

    generated_text = outputs[0]["generated_text"]
    # With return_full_text at its default (True) the output starts with the
    # prompt. Strip it only when present, so return_full_text=False also works.
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]
    completion = generated_text.strip()

    print(reference_positions)
    print(reference["explanation"])
    print(f"Generated Answer:{completion}")

    try:
        predicted_positions = json.loads(completion)["positions"]
    except (json.JSONDecodeError, KeyError, TypeError):
        # Not JSON, JSON without "positions", or JSON that is not an object
        # (list / number / string / null): score as a format error.
        print("error format")
        return 0
    print(predicted_positions)

    # Exact equality of the parsed values (order-sensitive if they are lists).
    if predicted_positions == reference_positions:
        print("right")
        return 1
    print("wrong")
    return 0

# Evaluate on a random subset of the test set and report how often the
# predicted "positions" exactly match the reference.
# NOTE: model generation itself may still be stochastic (sampling).
number_of_eval_samples = 127
eval_seed = 42  # fixed shuffle seed so the evaluated subset is reproducible

# Never request more rows than the test file has (select() would fail).
n_eval = min(number_of_eval_samples, len(eval_dataset))
eval_subset = eval_dataset.shuffle(seed=eval_seed).select(range(n_eval))

# iterate over eval subset and predict. A plain append loop (not a
# comprehension) keeps partial results in `success_rate` if interrupted.
success_rate = []
for s in tqdm(eval_subset):
    success_rate.append(evaluate_position(s))

# compute accuracy (NaN instead of ZeroDivisionError when nothing was evaluated)
accuracy = sum(success_rate) / len(success_rate) if success_rate else float("nan")

print(f"Accuracy: {accuracy * 100:.2f}%")



  0%|          | 0/127 [00:00<?, ?it/s]


KeyboardInterrupt: 