# In [None]:
import json
import re
import torch
from ollama import Client
from tqdm import tqdm

MODEL = "gemma3:1b"
client = Client(host="http://localhost:11434")

# Soft-prompt table: 30 rows x 768 columns. load_state_dict (strict by default)
# raises if the checkpoint's weight has a different shape.
# NOTE(review): `embedding` is not referenced by the evaluation loop — prompts
# go to Ollama as plain text, so these vectors never reach the model. Confirm
# how the soft prompt is meant to be injected.
embedding = torch.nn.Embedding(30, 768)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; load_state_dict copies the values into this CPU module regardless.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (or pass weights_only=True on torch versions that support it).
embedding.load_state_dict(
    torch.load("softprompt_gemma1b.pt", map_location="cpu")
)
embedding.eval()

TEST_PATH = "test_100.jsonl"
# One JSON object per line. Blank lines (e.g. a trailing newline at EOF) are
# skipped instead of crashing json.loads; the file is closed after reading.
with open(TEST_PATH, encoding="utf-8") as fh:
    data = [json.loads(line) for line in fh if line.strip()]

def build_prompt(q: str, num_virtual_tokens: int = 30) -> str:
    """Return the evaluation prompt for question *q*.

    The prompt is one line of space-separated placeholder tags
    ``<v0> <v1> ... <vN-1>`` (N = *num_virtual_tokens*), a newline, then
    ``Instruction: `` followed by *q*.

    Args:
        q: Question text, interpolated verbatim.
        num_virtual_tokens: Number of placeholder tags to emit. Defaults to
            30, the row count of the soft-prompt ``nn.Embedding`` in this file.

    Returns:
        The assembled prompt string.

    NOTE(review): the tags are plain text, so the model tokenizes them like
    any other text rather than substituting soft-prompt vectors — confirm
    this is intended.
    """
    placeholders = " ".join(f"<v{i}>" for i in range(num_virtual_tokens))
    return f"{placeholders}\nInstruction: {q}"

# Value of a "p_answer" key: a non-empty JSON string (escaped quotes allowed)
# captured in group 1, or a bare JSON number captured in group 2. The two
# alternatives in group 1 are disjoint, so matching stays linear-time.
_P_ANSWER_RE = re.compile(
    r'"p_answer"\s*:\s*'
    r'(?:"((?:[^"\\]|\\.)+)"'
    r'|(-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?))'
)


def extract(response):
    """Pull the ``p_answer`` value out of a model response.

    Searches *response* for the first ``"p_answer": <value>`` pair, where the
    value is either a non-empty JSON string (escape sequences such as
    ``\\"`` are decoded) or a bare JSON number. Only the key/value pair is
    matched, so surrounding prose or code fences do not matter.

    Args:
        response: Raw model output text. ``None`` or ``""`` is tolerated.

    Returns:
        The answer with surrounding whitespace stripped, or ``""`` when no
        ``p_answer`` value is found.
    """
    if not response:
        return ""
    m = _P_ANSWER_RE.search(response)
    if m is None:
        return ""
    if m.group(2) is not None:
        return m.group(2)
    raw = m.group(1)
    try:
        # Re-wrap in quotes so json decodes escapes like \" and \u00e9.
        value = json.loads(f'"{raw}"')
    except json.JSONDecodeError:
        # Invalid escape or raw control character: keep the text as written.
        value = raw
    return value.strip()

# Exact-match accuracy over the test set.
# NOTE(review): the prompt from build_prompt never asks for JSON with a
# "p_answer" key, so extract() may return "" for most responses unless the
# model was trained to emit that format — confirm.
correct = 0
for item in tqdm(data):
    prompt = build_prompt(item["question"])
    result = client.generate(model=MODEL, prompt=prompt)
    pred = extract(result["response"]).replace(" ", "")
    # Normalize the gold answer exactly like the prediction: str() so numeric
    # JSON answers can match, and space removal so the two sides are
    # compared on equal terms.
    gold = str(item["answer"]).replace(" ", "")
    if pred == gold:
        correct += 1

# An empty test set reports 0% instead of raising ZeroDivisionError.
acc = correct / len(data) * 100 if data else 0.0
print(f"\n🎯 Evaluation Accuracy = {acc:.2f}%")