In [None]:
import time
from typing import Union

import torch

from datasets import load_dataset

from peft import PeftModel

from sklearn.metrics import f1_score

from transformers import GenerationConfig, LlamaTokenizer, LlamaForCausalLM

from tqdm import tqdm

In [None]:
# Prompt layout used throughout the notebook.
#   "prompt":   text sent to the model, with {instruction} and {input} placeholders.
#   "response": marker used to locate the model's answer inside the decoded output.
prompt_template = {
    "prompt": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "response": "### Response:",
}


class Prompter(object):
    """Build instruction prompts and extract the answer from decoded model output.

    Attributes:
        template: the prompt layout used by this instance (``prompt_template``).
        _verbose: verbosity flag stored at construction; no method here reads it.
    """

    __slots__ = ("template", "_verbose")

    def __init__(self, verbose: bool = False):
        self._verbose = verbose
        # Fill the declared slot; previously `prompter.template` raised AttributeError.
        self.template = prompt_template

    def generate_prompt(
        self,
        definition: str,
        inputs: str,
        targets: Union[None, str] = None,
    ) -> str:
        """Generate a prompt from instruction and input.

        Args:
            definition: task instruction placed under "### Instruction:".
            inputs: task input placed under "### Input:".
            targets: optional expected answer; when non-empty it is appended
                right after the "### Response:" header.

        Returns:
            The formatted prompt string.
        """
        res = self.template["prompt"].format(
            instruction=definition, input=inputs
        )

        if targets:
            res = f"{res}{targets}"

        return res

    def get_response(self, output: str) -> str:
        """Return the text that follows the first response marker in ``output``.

        If the marker occurs more than once, only the text between the first
        and the second occurrence is returned. If the marker is absent (for
        example because the prompt was truncated), an empty string is returned
        instead of raising IndexError, so one malformed generation cannot abort
        a long evaluation run.
        """
        pieces = output.split(self.template["response"])
        if len(pieces) < 2:
            return ""
        return pieces[1].strip()


prompter = Prompter()

In [None]:
# Run configuration: which checkpoints are evaluated and how they are loaded.
load_8bit = True  # forwarded as `load_in_8bit` when the base model is loaded
lora_weights = "lora-llama-natural-instructions-13b"  # adapter handed to PeftModel.from_pretrained
base_model = "decapoda-research/llama-13b-hf"  # checkpoint handed to LlamaForCausalLM.from_pretrained

In [None]:
# The tokenizer is loaded from its own checkpoint rather than from `base_model`.
tokenizer = LlamaTokenizer.from_pretrained("chainyo/alpaca-lora-7b")

# Token id 0 is used for padding, and padding is added on the left of a sequence.
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"

In [None]:
# Load the base checkpoint and wrap it with the LoRA adapter in a single expression.
# The base model is placed with device_map="auto"; the adapter is pinned to device 0.
model = PeftModel.from_pretrained(
    LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
    ),
    lora_weights,
    torch_dtype=torch.float16,
    device_map={"": 0},
)

# Outside of 8-bit mode the wrapped model is explicitly cast to half precision.
if not load_8bit:
    model.half()

# Evaluation mode; the model is only used for inference in this notebook.
model.eval()

# torch.compile is applied only when the installed torch reports version 2 or later.
model = torch.compile(model) if torch.__version__ >= "2" else model

In [None]:
# Benchmark evaluated below: OpenBookQA, "main" configuration, validation split.
# Alternative benchmarks, kept disabled for reference:
# boolq_dataset = load_dataset("boolq", split="validation[:80%]")
# piqa_dataset = load_dataset("piqa", split="validation")
# winogrande_dataset = load_dataset("winogrande", "winogrande_debiased", split="validation")
openbookqa_dataset = load_dataset(
    "openbookqa",
    "main",
    split="validation",
)

In [None]:
# Decoding settings shared by every generate() call in this notebook.
# NOTE(review): `do_sample` is not set here. If it keeps its library default,
# decoding is presumably plain beam search and the sampling knobs
# (temperature / top_p / top_k) would have no effect — confirm against the
# installed transformers version.
_decoding_options = {
    "num_beams": 4,
    "top_k": 40,
    "top_p": 0.75,
    "temperature": 0.2,
}
generation_config = GenerationConfig(**_decoding_options)

In [None]:

# Smoke test: run one hand-written prompt through the model and print the answer,
# to check the whole pipeline before starting the full evaluation.
prompt = prompter.generate_prompt(
    "In this task, you have to analyze the full sentences and do reasoning and quick maths to find the correct answer.",
    # Plain string: the original f-prefix had no placeholders to interpolate.
    "You are now a superbowl star. You are the quarterback of the team. Your team is down by 3 points. You are in the last 2 minutes of the game. The other team has a score of 28. What is the score of your team?",
)
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048)
input_ids = inputs["input_ids"].to(model.device)
# Pass the tokenizer's attention mask explicitly instead of letting generate()
# infer it from the padding token id.
attention_mask = inputs["attention_mask"].to(model.device)

with torch.no_grad():
    gen_outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
        # Needed because the result is read through `.sequences` below.
        return_dict_in_generate=True,
        # `output_scores` is no longer requested: the scores were never read.
        max_new_tokens=50,
    )

# The first returned sequence holds the prompt followed by the generated tokens.
s = gen_outputs.sequences[0]
output = tokenizer.decode(s, skip_special_tokens=True)
response = prompter.get_response(output)
print(response)

In [None]:
# Evaluate on OpenBookQA: one prompt per question, collecting the model's answer
# and the reference answer key, and accumulating the time spent in generate().
preds = []
ground_truths = []

total_time = 0.0
num_iterations = len(openbookqa_dataset)

# Loop-invariant values, computed once instead of on every iteration.
task_definition = "In this task, you need to read and analyze the input to choose A, B, C or D as the correct response or the correct ending of the input."
cuda_timing = torch.cuda.is_available()

for data in tqdm(openbookqa_dataset, desc="OpenBookQA", total=num_iterations):
    choices = data["choices"]["text"]
    prompt = prompter.generate_prompt(
        task_definition,
        f"{data['question_stem']}\nA: {choices[0]}\nB: {choices[1]}\nC: {choices[2]}\nD: {choices[3]}",
    )
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048)
    input_ids = inputs["input_ids"].to(model.device)
    # Pass the tokenizer's attention mask explicitly instead of letting
    # generate() infer it from the padding token id.
    attention_mask = inputs["attention_mask"].to(model.device)

    # CUDA kernels are launched asynchronously; synchronising before reading the
    # clock keeps pending GPU work from leaking into or out of the timed region.
    if cuda_timing:
        torch.cuda.synchronize()
    # perf_counter() is monotonic and intended for measuring intervals, unlike
    # time.time(), which follows the (adjustable) system clock.
    start_time = time.perf_counter()
    with torch.no_grad():
        gen_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            # Needed because the result is read through `.sequences` below.
            return_dict_in_generate=True,
            # `output_scores` is no longer requested: the scores were never read.
            max_new_tokens=50,
        )
    if cuda_timing:
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    s = gen_outputs.sequences[0]
    output = tokenizer.decode(s, skip_special_tokens=True)
    response = prompter.get_response(output)
    # Drop per-example tensors and strings before the next iteration.
    del inputs, input_ids, attention_mask, gen_outputs, s, output

    iteration_time = end_time - start_time
    total_time += iteration_time

    preds.append(response)
    ground_truths.append(data["answerKey"])

In [None]:
# Inputs for the metric cell. Both names are plain aliases of the lists filled by
# the evaluation loop (no copy is made and no post-processing is applied).
ground_truths_to_f1, preds_to_f1 = ground_truths, preds

In [None]:
# Micro-averaged F1 of the raw model answers against the reference answer keys.
# A prediction only counts as correct when its string equals the answer key exactly.
f1 = f1_score(
    y_true=ground_truths_to_f1,
    y_pred=preds_to_f1,
    average="micro",
)
print(f"Lora Natural Instructions F1: {f1:.4f}")

In [None]:
# Mean duration of one generate() call, in seconds (total_time is accumulated
# from clock differences in the evaluation loop).
# Guard against an empty dataset or slice: report NaN instead of raising
# ZeroDivisionError when no example was evaluated.
if num_iterations:
    avg_inference_time = total_time / num_iterations
else:
    avg_inference_time = float("nan")
print(f"Average time per inference: {avg_inference_time:.4f}")