In [None]:
# === IMPORTS ===
import os
import sys
import json
import numpy as np
from typing import List, Dict
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict

# Project checkout root: placed first on sys.path so the project-local imports
# below resolve, and reused later to build data paths.
BASE_DIR = Path('/home/shougan/projects/aip-fredashi/shougan/balance-budget')
sys.path.insert(0, str(BASE_DIR))

from tuning.config import OUTPUTS_DIR
from tuning.inference.ifeval_inference import run_inference_ifeval
from instruction_following_eval import evaluation_lib, instructions_registry

In [None]:
# === UTILITY FUNCTIONS ===

def pass_at_k(n: int, c: int, k: int) -> float:
    """Calculate pass@k: probability that at least one of k samples is correct.

    Evaluates 1 - C(n - c, k) / C(n, k) as a running product, so no large
    binomial coefficients are formed.

    Args:
        n: Total number of samples available for the prompt.
        c: Number of those samples that are correct; must satisfy 0 <= c <= n.
        k: Number of samples drawn; must satisfy 0 <= k <= n.

    Returns:
        The pass@k estimate as a plain Python float in [0.0, 1.0].

    Raises:
        ValueError: If c or k lies outside [0, n]. Without this check a call
            with k > n would hit the early-return guard below and report 1.0
            even when no sample is correct.
    """
    if not 0 <= c <= n:
        raise ValueError(f"c must satisfy 0 <= c <= n, got n={n}, c={c}")
    if not 0 <= k <= n:
        raise ValueError(f"k must satisfy 0 <= k <= n, got n={n}, k={k}")
    # Fewer than k incorrect samples exist, so every draw of k contains a correct one.
    if n - c < k:
        return 1.0
    # prod_{i = n-c+1 .. n} (1 - k / i) equals C(n - c, k) / C(n, k).
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

def compute_pass_at_k_scores(results_per_prompt: List[List[bool]], k_values: List[int]) -> Dict[int, float]:
    """Average pass@k over prompts for each requested k.

    Args:
        results_per_prompt: One list of per-sample pass/fail flags per prompt.
        k_values: The k values to evaluate.

    Returns:
        Mapping from k to the mean pass@k over the prompts that have at least
        k samples. A k for which no prompt qualifies is left out of the mapping.
    """
    collected = {}
    for k in k_values:
        collected[k] = []

    for outcomes in results_per_prompt:
        total = len(outcomes)
        correct = sum(outcomes)
        for k in k_values:
            # A prompt with fewer than k samples contributes nothing for that k.
            if k > total:
                continue
            collected[k].append(pass_at_k(total, correct, k))

    averaged = {}
    for k, values in collected.items():
        if values:
            averaged[k] = np.mean(values)
    return averaged

def save_responses(results: List[Dict], model_name: str, outputs_dir=None):
    """Write results as JSON Lines to <root>/<model_name>/responses_multi_sample.jsonl.

    Args:
        results: Records to write, one JSON object per line.
        model_name: Name of the sub-directory the file is written into; the
            directory is created if it does not exist.
        outputs_dir: Optional root directory (str or Path). When omitted, the
            module-level OUTPUTS path is used, as before.
    """
    root = OUTPUTS if outputs_dir is None else Path(outputs_dir)
    path = root / model_name
    path.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so pin the file encoding
    # instead of relying on the platform default, which may not be able to encode them.
    with open(path / "responses_multi_sample.jsonl", "w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

def load_responses(model_name: str, outputs_dir=None) -> List[Dict]:
    """Read <root>/<model_name>/responses_multi_sample.jsonl into a list of records.

    Args:
        model_name: Name of the sub-directory holding the file.
        outputs_dir: Optional root directory (str or Path). When omitted, the
            module-level OUTPUTS path is used, as before.

    Returns:
        One decoded JSON object per non-blank line, in file order.
    """
    root = OUTPUTS if outputs_dir is None else Path(outputs_dir)
    # Read as UTF-8 explicitly: the file is written with ensure_ascii=False, so it
    # can contain raw non-ASCII characters the platform default may not decode.
    with open(root / model_name / "responses_multi_sample.jsonl", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) rather than failing in json.loads.
        return [json.loads(line) for line in f if line.strip()]

In [None]:
# === CONFIGURATION ===
# Model and sampling settings
model_name = "llama3-1B"
n_samples = 4               # passed to inference as the number of samples
temperature = 0.7
k_values = [1, 2]           # k values for which pass@k is computed
num_examples = 20  # Set to None for full dataset
run_inference_flag = True   # False loads previously saved responses instead

# Filesystem locations
IFEVAL_INPUT_PATH = BASE_DIR / "instruction_following_eval" / "data" / "input_data.jsonl"
OUTPUTS = Path(OUTPUTS_DIR, "pass@k_responses")

In [None]:
# === STEP 1: LOAD IFEVAL INPUTS ===
# Index every IFEval input by its prompt text so later cells can look up the
# input that belongs to a generated response.
inputs_map = {}
for ifeval_input in evaluation_lib.read_prompt_list(str(IFEVAL_INPUT_PATH)):
    inputs_map[ifeval_input.prompt] = ifeval_input
print(f"Loaded {len(inputs_map)} IFEval prompts")

In [None]:
# Inspect the first IFEval input and confirm it can be found in inputs_map by
# its prompt text. The input file is parsed once and the first entry reused.
first_input = evaluation_lib.read_prompt_list(str(IFEVAL_INPUT_PATH))[0]
print(first_input)
test_prompt = first_input.prompt
print("\n", test_prompt)
print("\n",inputs_map[test_prompt])

In [None]:
# === STEP 2: RUN INFERENCE OR LOAD CACHED RESPONSES ===
print(f"Model: {model_name}")

if not run_inference_flag:
    # Reuse the responses previously written by save_responses.
    print("Loading cached responses...")
    model_results = load_responses(model_name)
    print(f"Loaded {len(model_results)} results")
else:
    print(f"Running inference with n_samples={n_samples}, temperature={temperature}")
    raw_results = run_inference_ifeval(
        model_name=model_name,
        n_samples=n_samples,
        temperature=temperature,
        save_results=False,
        num_examples=num_examples,
    )
    # Each raw record is read as {"prompt": ..., "response": ...}; collect the
    # responses that share a prompt into one list for pass@k evaluation.
    grouped = defaultdict(list)
    for record in raw_results:
        grouped[record["prompt"]].append(record["response"])

    model_results = []
    for prompt_text, prompt_responses in grouped.items():
        model_results.append({"prompt": prompt_text, "responses": prompt_responses})
    print(f"Generated {len(model_results)} prompts with {n_samples} samples each")
    save_responses(model_results, model_name)
    print("Responses saved.")

In [None]:
# === STEP 4: EVALUATE SINGLE RESPONSE USING PRE-BUILT FUNCTIONS ===
def evaluate_single_response(inp: evaluation_lib.InputExample, response: str, strict: bool = True) -> bool:
    """Check one response against one IFEval input.

    Args:
        inp: The IFEval input the response was generated for.
        response: The model response to check.
        strict: Selects evaluation_lib's strict check when true, its loose
            check otherwise.

    Returns:
        The follow_all_instructions field of the evaluation result.
    """
    if strict:
        checker = evaluation_lib.test_instruction_following_strict
    else:
        checker = evaluation_lib.test_instruction_following_loose
    # The checkers take a prompt -> response mapping; supply just this one pair.
    outcome = checker(inp, {inp.prompt: response})
    return outcome.follow_all_instructions

In [None]:
# Smoke test: evaluate the first stored response of the first prompt.
first_result = model_results[0]
test_eval_input = inputs_map[first_result["prompt"]]
test_response = first_result["responses"][0]
evaluate_single_response(test_eval_input, test_response)

In [None]:
# === STEP 5: COMPUTE PASS@K SCORES ===

# NOTE(review): this replaces the model_name / model_results set by the earlier
# cells with a run loaded from disk — confirm that is intended.
model_name = "llama3-8B_sft-tuluif-10000"
model_results = load_responses(model_name)
final_results_strict = defaultdict(list)
final_results_loose = defaultdict(list)
# === PIPELINE ===
# results_per_prompt maps each prompt to its per-sample outcomes plus pass@k
# values; final_results_* collect the per-prompt pass@k values under "k<k>".
results_per_prompt = {}
for item in tqdm(model_results, desc="Evaluating responses"):
    prompt = item["prompt"]
    responses = item["responses"]
    eval_input = inputs_map[prompt]
    strict_results = [evaluate_single_response(eval_input, r, strict=True) for r in responses]
    loose_results = [evaluate_single_response(eval_input, r, strict=False) for r in responses]
    # Compute each pass@k value once and reuse it for every place it is stored.
    strict_k = [pass_at_k(len(strict_results), sum(strict_results), i) for i in k_values]
    loose_k = [pass_at_k(len(loose_results), sum(loose_results), i) for i in k_values]
    results_per_prompt[prompt] = {
        "strict": strict_results,
        "loose": loose_results,
        **{f"strict_k{i}": value for i, value in zip(k_values, strict_k)},
        **{f"loose_k{i}": value for i, value in zip(k_values, loose_k)},
        "strict_k": strict_k,
        "loose_k": loose_k,
    }
    for i, strict_value, loose_value in zip(k_values, strict_k, loose_k):
        final_results_strict[f"k{i}"].append(strict_value)
        final_results_loose[f"k{i}"].append(loose_value)

In [None]:

# Mean strict pass@k over all evaluated prompts, in k_values order.
strict_means = []
for k in k_values:
    strict_means.append(np.mean(final_results_strict[f"k{k}"]))
print(strict_means)


In [None]:
# Mean loose pass@k over the evaluated prompts at the largest configured k.
# The key is derived from k_values because results_per_prompt only holds
# "loose_k<k>" entries for the configured k values.
report_k = max(k_values)
loose_key = f"loose_k{report_k}"
my_sum = 0

for prompt, result in results_per_prompt.items():
    my_sum += result[loose_key]
# Divide by the number of prompts actually evaluated rather than a fixed count,
# so the mean stays correct when only a subset of prompts was run.
print(my_sum / len(results_per_prompt) if results_per_prompt else float("nan"))

In [None]:
# === FINAL RESULTS ===
# Aggregate the per-prompt, per-sample outcomes into the mean pass@k for each k.
# strict_scores / loose_scores are computed here so this cell does not depend on
# names that no other cell defines.
strict_scores = {
    k: float(score)
    for k, score in compute_pass_at_k_scores(
        [result["strict"] for result in results_per_prompt.values()], k_values
    ).items()
}
loose_scores = {
    k: float(score)
    for k, score in compute_pass_at_k_scores(
        [result["loose"] for result in results_per_prompt.values()], k_values
    ).items()
}

print(f"\n{'='*50}")
print(f"RESULTS: {model_name}")
print(f"{'='*50}")
print(f"Strict pass@k: {strict_scores}")
print(f"Loose pass@k:  {loose_scores}")