In [None]:
import os
import random
import warnings
from pprint import pprint

from datasets import Dataset
from peft import PeftModel, PeftConfig
from transformers import set_seed, AutoModelForCausalLM, AutoTokenizer

from util import generate_eval_prompts, initialize_dfs

In [None]:
seed = 1212  # global RNG seed, applied with set_seed(...) in the next cell

# Dataset split fractions
test_split = 0.1    # held-out test set: 10% of the full dataset
valid_split = 0.05  # validation set: 5% of the training portion (not consumed by the cells below)

# Base model id (the tokenizer is loaded from it) and local artifact locations
model_name = "mistralai/Mistral-7B-v0.1"
model_checkpoint, cache_dir = (
    os.path.expanduser(raw_path)
    for raw_path in ("~/finetuning_mistral7b_v1/checkpoint-89", "~/.cache/huggingface/")
)

In [None]:
# Seed the RNGs through the transformers helper so any stochastic step below is repeatable.
set_seed(seed=seed)

In [None]:
# Only the test split is needed for evaluation; the first returned frame is discarded.
_, df_test = initialize_dfs(test=test_split)

# Zero-shot prompts first, then one-shot prompts whose example is picked with fuzzy=True.
s0, r0, p0 = generate_eval_prompts(df_test, shots=0)
s1, r1, p1 = generate_eval_prompts(df_test, shots=1, fuzzy=True)

# Concatenate both prompt variants into one evaluation set.
sources, references, prompts = s0 + s1, r0 + r1, p0 + p1
dataset = Dataset.from_dict(
    {"sources": sources, "references": references, "prompts": prompts}
)
pprint(dataset)

In [None]:
# Read the adapter config stored with the checkpoint; it records which base model
# the PEFT adapter was trained on.
peftconfig = PeftConfig.from_pretrained(model_checkpoint)

# The tokenizer below is loaded from `model_name`, while the weights come from the
# adapter config. Warn early (before the slow model load) if the two differ, so a
# tokenizer/model mismatch doesn't silently corrupt generations.
if peftconfig.base_model_name_or_path != model_name:
    warnings.warn(
        f"Tokenizer source {model_name!r} differs from the adapter's base model "
        f"{peftconfig.base_model_name_or_path!r}; make sure they share a vocabulary."
    )

# Load the base model named in the adapter config; device_map="auto" spreads it
# over the available devices.
model_base = AutoModelForCausalLM.from_pretrained(
    peftconfig.base_model_name_or_path, device_map="auto", cache_dir=cache_dir
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    add_bos_token=True,
    add_eos_token=False,  # always False for inference
)

# Attach the fine-tuned adapter weights from the checkpoint on top of the base model.
new_model = PeftModel.from_pretrained(model_base, model_checkpoint)
print("Peft model loaded")

In [None]:
def generate_response(prompt, model, max_new_tokens=20):
    """Greedily generate a short completion of ``prompt`` with ``model``.

    Encoding and decoding use the module-level ``tokenizer``.

    Args:
        prompt: Prompt text to complete.
        model: Causal LM exposing ``generate`` and ``device`` (e.g. the PEFT
            model loaded above).
        max_new_tokens: Upper bound on the number of generated tokens
            (default 20, matching the previous hard-coded value).

    Returns:
        Only the newly generated text, with special tokens (BOS/EOS) removed.
    """
    encoded_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    # Follow the model's device instead of hard-coding 'cuda'.
    model_inputs = encoded_input.to(model.device)
    prompt_len = model_inputs["input_ids"].shape[1]

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        min_new_tokens=1,
        do_sample=False,  # greedy decoding -> deterministic output
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens. Drop them by position:
    # string replacement on the decoded text fails whenever detokenization does not
    # reproduce the prompt exactly, and it left the <s>/</s> markers in the output.
    new_tokens = generated_ids[0, prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

In [None]:
# Preview one evaluation prompt (row 35) before running generation on it.
pprint(dataset[35]["prompts"])

In [None]:
# Spot-check a single example: compare the model's greedy completion to its reference.
sample_idx = 35
# Fetch one row as a dict instead of materializing the full "prompts"/"references" columns.
sample = dataset[sample_idx]
response = generate_response(sample["prompts"], new_model)
print(f"Reference: {sample['references']}")
print(f"Response: {response}")