In [None]:
import os
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

In [None]:
# Load the first saved checkpoint (outputs/checkpoint-0 under the working
# directory) and put the returned model into unsloth inference mode.
load_path = os.path.join(os.getcwd(), "outputs", "checkpoint-0")
model, tokenizer = FastLanguageModel.from_pretrained(load_path)
print("Model loaded from {}".format(load_path))
model = FastLanguageModel.for_inference(model)

In [None]:
from collections import namedtuple
import textwrap

# A chat prompt: `system` becomes the system message, `user` the user turn.
Prompt = namedtuple("Prompt", ["system", "user"])

# Prompts used to probe the fine-tuned checkpoints (indexed by `prompt_idx` later).
# Multi-line user prompts are dedented and stripped so that the indentation of
# this source cell is not sent to the model as part of the prompt text.
prompts = [
    Prompt(
        system="You are a helpful assistant, answering user questions about computer-architecture related topics",
        user="What is the PC?",
    ),
    Prompt(
        system="You are a helpful assistant helping with CPU cache line eviction.",
        user="What features were used during your fine-tuning? There are only 4. Pick from [Current PC, Current address, Cache lines, Cache size, Cache Miss Rate, Cache Hit Rate]",
    ),
    Prompt(
        system="You are a helpful assistant helping with CPU cache line eviction.",
        user=textwrap.dedent(
            """
            This is our current prompt to an eviction policy.
            What changes should be made to help the model better identify the cache line to evict and ultimately make the eviction policy more efficient?

            Prompt:
            Current PC is <pc>
            Current address: <list-of-addresses>
            Cache lines are: <list-of-cache-lines>
            Eviction:
            """
        ).strip(),
    ),
]

def evaluate(checkpoint, prompt, max_new_tokens=1024, do_sample=False, temperature=0.3, min_p=0.1):
    """Load `checkpoint`, run one chat prompt through it, and return the reply text.

    The model is loaded fresh on every call so each checkpoint is evaluated on its
    own. It is released before returning, so looping over many checkpoints does not
    keep every one of them alive on the GPU.

    Args:
        checkpoint: path passed to ``FastLanguageModel.from_pretrained``.
        prompt: object with ``system`` and ``user`` string attributes (e.g. ``Prompt``).
        max_new_tokens: maximum number of tokens to generate.
        do_sample: if False (the default), decode greedily. In that case
            ``temperature`` and ``min_p`` are not forwarded to ``generate``,
            because greedy decoding ignores them.
        temperature: sampling temperature, used only when ``do_sample`` is True.
        min_p: min-p sampling threshold, used only when ``do_sample`` is True.

    Returns:
        The decoded completion: the echoed prompt is removed and special tokens
        are skipped.
    """
    import gc

    import torch

    model, tokenizer = FastLanguageModel.from_pretrained(checkpoint)
    model = FastLanguageModel.for_inference(model)
    try:
        messages = [
            {"role": "system", "content": prompt.system},
            {"role": "user", "content": prompt.user},
        ]
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,  # Must add for generation
            return_tensors="pt",
        ).to(model.device)
        # Forward sampling knobs only when sampling; otherwise transformers ignores
        # them and warns that they conflict with do_sample=False.
        sampling_kwargs = {"temperature": temperature, "min_p": min_p} if do_sample else {}
        output_ids = model.generate(
            input_ids=input_ids,
            # Single unpadded sequence: every position is a real token.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=do_sample,
            **sampling_kwargs,
        )
        # Keep only the newly generated tokens (generate echoes the prompt first).
        return tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
    finally:
        # Release this checkpoint's weights before the next one is loaded.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
# Checkpoints saved every 1000 steps, from step 0 through step 6000.
checkpoints = [
    os.path.join(os.getcwd(), "outputs", "checkpoint-{}".format(step))
    for step in range(0, 7000, 1000)
]

# Run the selected prompt against every checkpoint; collect (path, reply) pairs.
prompt_idx = 2
results = []

for checkpoint in checkpoints:
    answer = evaluate(checkpoint, prompts[prompt_idx])
    results.append((checkpoint, answer))