Find this notebook at 

In [1]:
# Install dependencies
# !curl -L https://raw.githubusercontent.com/chloeli-15/stream-of-search/refs/heads/main/finetune/setup.py -o ./setup.py
# !python -m pip install .
# !python -m pip install flash-attn --no-build-isolation
# # installation - will take a while

! pip install tiktoken datasets

Collecting tiktoken
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m33.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading datasets-3.5.0-py3-none-any.whl (491 kB)


# Parsers for countdown trajectories

In [2]:
from tqdm import tqdm
# from datasets import load_dataset
import re
import json, sys, os, glob
import numpy as np
from typing import List, Tuple
import math
import matplotlib.pyplot as plt
import re
import re
import ast
import tiktoken
from transformers.utils.logging import disable_progress_bar
disable_progress_bar()


def validate_operations(initial_numbers, operations, target):
    """
    Validates if the operations use all initial numbers exactly once and reach the target.
    """
    # Make a copy of initial numbers to track usage
    available_numbers = initial_numbers.copy()
    # Dictionary to store intermediate results
    intermediate_results = {}

    for i, operation in enumerate(operations):
        # Parse the operation (format like '9-5=4')
        parts = operation.split('=')
        if len(parts) != 2:
            return False

        result_str = parts[1].strip()
        expression = parts[0].strip()

        # Find operator (+, -, *, /)
        op_match = re.search(r'[+\-*/]', expression)
        if not op_match:
            return False

        operator = op_match.group(0)
        operands = expression.split(operator)
        if len(operands) != 2:
            return False

        left_operand = operands[0].strip()
        right_operand = operands[1].strip()

        # Try to get values for operands
        left_val = None
        if left_operand.isdigit() or (left_operand.startswith('-') and left_operand[1:].isdigit()):
            left_val = int(left_operand)
            if left_val in available_numbers:
                available_numbers.remove(left_val)
            elif left_val not in intermediate_results.values():
                return False  # Not an available number or intermediate result
        else:
            return False  # Non-numeric operand

        right_val = None
        if right_operand.isdigit() or (right_operand.startswith('-') and right_operand[1:].isdigit()):
            right_val = int(right_operand)
            if right_val in available_numbers:
                available_numbers.remove(right_val)
            elif right_val not in intermediate_results.values():
                return False  # Not an available number or intermediate result
        else:
            return False  # Non-numeric operand

        # Calculate the result
        calculated_result = None
        if operator == '+':
            calculated_result = left_val + right_val
        elif operator == '-':
            calculated_result = left_val - right_val
        elif operator == '*':
            calculated_result = left_val * right_val
        elif operator == '/':
            if right_val == 0:
                return False  # Division by zero
            calculated_result = left_val / right_val

        # Verify the result matches what's stated
        result_val = int(result_str) if result_str.isdigit() else None
        if result_val is None or abs(calculated_result - result_val) > 1e-10:
            return False

        # Store intermediate result for future operations
        intermediate_results[f"step{i}"] = result_val

    # Check if all initial numbers were used and final result matches target
    return len(available_numbers) == 0 and result_val == target


def evaluate_countdown_trajectory_claude(ds_entry) -> Tuple[bool, str]:
    """
    Given:
      - target: the desired final result.
      - nums: list of initial numbers.
      - trajectory: a string that should include lines of the form:
            SOLUTION: YES/NO
            OPERATIONS: [list of strings like 'A+B=C', ...]
            RESULT: final_value
    This function uses regex to extract the parts and then simulates the operations
    by “consuming” initial numbers (and intermediate results) from a pool, verifying:
      • each operation string has one binary operator
      • each operation is valid (left op right equals given result)
      • the operations use all initial numbers exactly once (by simulating removal from pool).
    Returns a tuple (is_valid, message) where is_valid is True only if the trajectory is correct.
    """
    # target: int, nums: List[int], trajectory: str
    trajectory = ds_entry['completion']
    target = ds_entry['target']
    nums = ds_entry['nums']
    # Extract SOLUTION (YES or NO)
    sol_match = re.search(r"SOLUTION:\s*(YES|NO)", trajectory, re.IGNORECASE)
    if not sol_match:
        return False, "Could not find SOLUTION declaration."
    sol_decl = sol_match.group(1).upper()
    if sol_decl != "YES":
        return False, "The trajectory indicates no valid solution."

    # Extract OPERATIONS list
    ops_match = re.search(r"OPERATIONS:\s*(\[[^\]]*\])", trajectory, re.DOTALL)
    if not ops_match:
        return False, "Could not find OPERATIONS list."
    try:
        operations = ast.literal_eval(ops_match.group(1))
        if not isinstance(operations, list):
            return False, "OPERATIONS is not a list."
    except Exception as e:
        return False, f"Failed to parse OPERATIONS list: {e}"

    # Extract final RESULT
    res_match = re.search(r"RESULT:\s*([-+.\d]+)", trajectory)
    if not res_match:
        return False, "Could not find RESULT."
    try:
        expected_final = float(res_match.group(1))
    except Exception as e:
        return False, f"Failed to parse RESULT: {e}"

    # Simulation: available numbers (as floats).
    available = [float(n) for n in nums]
    # We simulate the sequence of operations.
    # Each operation must be of the form "operand1 operator operand2 = result"
    op_pattern = re.compile(r"^\s*([\d.]+)\s*([\+\-\*/])\s*([\d.]+)\s*=\s*([\d.]+)\s*$")

    for idx, op_str in enumerate(operations):
        m = op_pattern.match(op_str)
        if not m:
            return False, f"Operation '{op_str}' does not match required pattern."
        op1_str, operator, op2_str, given_result_str = m.groups()
        try:
            op1 = float(op1_str)
            op2 = float(op2_str)
            op_result = float(given_result_str)
        except Exception as e:
            return False, f"Error converting numbers in op '{op_str}': {e}"

        # Check that the operation is valid:
        if operator == '+':
            computed = op1 + op2
        elif operator == '-':
            computed = op1 - op2
        elif operator == '*':
            computed = op1 * op2
        elif operator == '/':
            # Avoid division by zero
            if math.isclose(op2,0.0):
                return False, f"Division by zero in op '{op_str}'."
            computed = op1 / op2
        else:
            return False, f"Unknown operator '{operator}' in op '{op_str}'."

        if not math.isclose(computed, op_result, rel_tol=1e-5):
            return False, f"In op '{op_str}', computed {computed} which does not match given {op_result}."

        # Now simulate consumption:
        # For each operand, check if it is "available". If yes, remove one instance.
        # (We assume that if the operand equals a value in available within tolerance,
        #  it comes from the pool of that operand.)
        def consume(value: float, pool: List[float]) -> bool:
            for i, num in enumerate(pool):
                if math.isclose(num, value, rel_tol=1e-5):
                    del pool[i]
                    return True
            return False

        # Try to consume op1 and op2
        if not consume(op1, available):
            return False, f"Operand {op1} in op '{op_str}' not available from initial/intermediate numbers."
        if not consume(op2, available):
            return False, f"Operand {op2} in op '{op_str}' not available from initial/intermediate numbers."
        # Append current operation result to available pool
        available.append(op_result)

    # At the end, exactly one number should remain; it should equal the target.
    if len(available) != 1:
        return False, f"After all operations, expected one value but got {len(available)} values: {available}"
    if not math.isclose(available[0], float(target), rel_tol=1e-5):
        return False, f"Final value {available[0]} does not equal the target {target}."

    return True, f"Trajectory is valid with operations: {operations}"

def evaluate_countdown_trajectory(ds_entry):
    solved, remarks = evaluate_countdown_trajectory_claude(ds_entry)
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(ds_entry['completion'])

    return {
        'solved': solved,
        'target': ds_entry['target'],
        'initial_numbers': ds_entry['nums'],
        'remarks': remarks,
        'completion_length': len(tokens),
    }

def evaluate_countdown_trajectories(results_all_trials):
    """
    Evaluates a dataset of countdown problem solver trajectories and returns a flattened
    list of results in the format expected by visualization tools and wandb logging.
    """
    # Create a dict to track which questions were solved in any trial
    question_solved = {i: False for i in range(len(results_all_trials[0]))}
    trial_success_rates = []

    # Process each trial
    for i in range(len(results_all_trials)):
        successes_in_trial = 0
        for j in range(len(results_all_trials[i])):
            # problem_text = results_all_trials[i][j].get("prompt", "")
            # solution_text = results_all_trials[i][j].get("completion", "")
            results_all_trials[i][j]['parsed_results'] = evaluate_countdown_trajectory(results_all_trials[i][j])

            # Update success counts
            if results_all_trials[i][j]['parsed_results']['solved']:
                successes_in_trial += 1
                question_solved[j] = True

        # Calculate success rate for this trial
        trial_success_rates.append(successes_in_trial / len(results_all_trials[i]))

    # Calculate aggregate statistics
    best_of_n_successes = sum(question_solved.values())
    best_of_n_rate = best_of_n_successes / len(question_solved)
    mean_of_n_trials = sum(trial_success_rates) / len(trial_success_rates)

    # Print summary statistics
    print(f"Success rate for each trial: {trial_success_rates}")
    print(f"\nSummary:")
    print(f"  Best-of-{len(trial_success_rates)} success rate: {best_of_n_rate:.4f} ({best_of_n_successes}/{len(question_solved)})")
    print(f"  Mean success rate across trials: {mean_of_n_trials:.4f}\n")

    # Create the properly formatted results array
    # Start with the metrics summary (which will be at index 1 after hyperparams are added)
    final_results = [{
        'trial_success_rates': trial_success_rates,
        'best_of_n': best_of_n_rate,
        'mean': mean_of_n_trials
    }]

    # Add individual trajectory results (using best results from any trial)
    for j in range(len(results_all_trials[0])):
        # Find the best result for this problem across all trials
        best_result = None
        for i in range(len(results_all_trials)):
            result = results_all_trials[i][j]
            if best_result is None or (result['parsed_results']['solved'] and not best_result['parsed_results']['solved']):
                best_result = result

        # Add this result to our final list
        if best_result:
            final_results.append(best_result)

    return final_results

# Parser for Knights and Knaves

In [3]:
def extract_parts(string):
    # Try the original pattern for folders with "countdown-"
    pattern1 = re.compile(r'(\d+\.\d+B).*countdown-(.+?)$')
    match = pattern1.search(string)

    if match:
        return [match.group(1), match.group(2)]

    string = string.lower()
    # Pattern for folders like "Qwen2.5-1.5B-Instruct"
    pattern2 = re.compile(r'qwen\d+\.\d+-(\d+\.\d+B)-instruct', re.IGNORECASE)
    match = pattern2.search(string)

    if match:
        return [match.group(1), "base_model"]

    # If no pattern matches, return default values
    return ["unknown", "unknown"]


def verify_solution_text(names, solution, solution_text):
    """
    Verifies if the solution_text correctly describes the knight/knave status of each person.

    Args:
        names: List of names
        solution: List of booleans (True for knight, False for knave)
        solution_text: String describing the solution

    Returns:
        Boolean indicating if the solution_text is correct, and any discrepancies found
    """
    # Make sure we have the same number of names and solutions
    if len(names) != len(solution):
        return False, "Mismatch in lengths of names and solution arrays"

    # Clean up the solution text and split by commas and 'and'
    text = solution_text.split("RESULT:")[-1].strip().replace('.', '')
    # Handle 'and' at the end
    text = text.replace(' and ', ', ')

    parts = text.split(', ')

    if len(parts) != len(names):
        return False, f"Solution text has {len(parts)} parts but there are {len(names)} people"

    # Check each person
    discrepancies = []

    for i, part in enumerate(parts):
        # Find which name this part refers to
        name_idx = -1
        for j, name in enumerate(names):
            if name in part:
                name_idx = j
                break

        if name_idx == -1:
            discrepancies.append(f"Couldn't find any name in '{part}'")
            continue

        # Check if the knight/knave status is correct
        is_knight = "knight" in part.lower()
        is_knave = "knave" in part.lower()

        if is_knight and not solution[name_idx]:
            discrepancies.append(f"{names[name_idx]} is described as knight but should be knave")
        elif is_knave and solution[name_idx]:
            discrepancies.append(f"{names[name_idx]} is described as knave but should be knight")
        elif not is_knight and not is_knave:
            discrepancies.append(f"Couldn't determine if {names[name_idx]} is knight or knave in '{part}'")

    return len(discrepancies) == 0, discrepancies


def eval_dataset(data, field='solution_text', verified_col='verified', discrepancies_col='discrepancies'):
    """
    Updates the dataset with verification results.

    Args:
        data: The dataset to update
    """
    verified = []
    discrepancies = []

    for i in range(len(data)):
        names = data['names'][i]
        solution = data['solution'][i]
        solution_text = data[field][i]

        is_verified, discrepancy_list = verify_solution_text(names, solution, solution_text)

        verified.append(is_verified)
        discrepancies.append(", ".join(discrepancy_list))

    data = data.add_column(verified_col, verified)
    data = data.add_column(discrepancies_col, discrepancies)

    return data


def load_results(results_dir="./results/ood"):
    """
    Load all KnK results from the results directory.

    Args:
        results_dir: Directory containing results folders

    Returns:
        Dictionary mapping adapter names to their results
    """
    all_results = {}

    # Find all knk.json files
    knk_files = glob.glob(f"{results_dir}/**/knk.json", recursive=True)

    for file_path in knk_files:
        # Extract adapter name from path
        adapter_name = file_path.split(results_dir + '/')[1].split('/knk.json')[0]

        # Load the JSON data
        with open(file_path, 'r') as f:
            data = json.load(f)

        all_results[adapter_name] = data

    return all_results



# Loading fine-tuned models

In [4]:
#%%
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel, PeftConfig
import logging
import os, glob

from accelerate import infer_auto_device_map

#%%
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def load_model(adapter_path, base_model=None, use_quantization=False):
    """Load a QLoRA fine-tuned model from Hugging Face"""

    # Get base model name from adapter config if not provided
    # Look for model in local directory
    # if glob.glob(f"{adapter_path}") and glob.glob(f"{adapter_path}/adapter_config.json") == []:
    #     adapter_path = glob.glob(f"{adapter_path}/*/*/adapter_config.json")[0].split("/adapter_config.json")[0]

    peft_config = PeftConfig.from_pretrained(adapter_path)
    base_model = base_model or peft_config.base_model_name_or_path
    logger.info(f"Using base model: {base_model}")

    # Load base model with or without quantization
    if use_quantization:
        logger.info("Loading base model with quantization...")
        # Set up 4-bit quantization
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True
        )

        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=True,
        )


    else:
        # Load base model without quantization
        logger.info("Loading base model without quantization...")
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map="auto",
            trust_remote_code=True
        )

    # Load and apply adapter weights
    logger.info("Applying LoRA adapters...")
    model = PeftModel.from_pretrained(model, adapter_path)


    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def load_model_from_hub(model_name):
    """Load a model from Hugging Face Hub"""
    # Load the model
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Set padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def generate_batch(model, tokenizer, prompt, max_new_tokens=512, temperature=0.7):
    """
    Generate text using the loaded model
    Takes input str after chat template has been applied
    """
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, padding_side='right').to(model.device)

    # Generate with sampling
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature if temperature>0.0 else None,
            top_p=0.9 if temperature>0.0 else None,
            top_k=20 if temperature>0.0 else None,
            do_sample=temperature>0.0,
        )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


# Run evaluation

In [5]:
import os
import json
import random
import argparse
from datetime import datetime
from tqdm import tqdm

import numpy as np
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, GPTNeoForCausalLM
from datasets import load_dataset, DatasetDict, Dataset

import sys
import pandas as pd

from typing import List, Tuple


def eval_ll(model, tokenizer, data, batch_size=128, context_len=4096, temperature=0.0, n=1):
    """
    Evaluate the model on the data using a sliding window so that the context length is not exceeded
    """
    output_texts_concat = []

    for i, data_batch in tqdm(enumerate(data.iter(batch_size=batch_size)), total=len(data)//batch_size):
        # tokenize and generate data_batch['test_prompt']. Input is a list of dicts with role
        # if args.chat_template:
        chat_inputs = tokenizer.apply_chat_template(data_batch["test_prompt"], return_tensors="pt", padding=True, truncation=True, max_length=context_len, return_length=True, tokenize=False)
        # else:
        #     chat_inputs = data_batch["test_prompt"]["content"] # if no chat template
        outputs = generate_batch(model, tokenizer, chat_inputs, max_new_tokens=context_len, temperature=temperature)
        output_texts_concat.extend(outputs)

    return output_texts_concat


def custom_eval(args=None):
    """Entry point that can be called programmatically or via command line"""
    if args is None:
        args = parser.parse_args()

    timenow = datetime.now().strftime('%Y%m%d-%H%M%S')
    # Initialize wandb if upload_results is True
    if args.experiment_name is None:
        args.experiment_name = f"{timenow}-custom_eval-{args.adapter.split('/')[-1]}"

    torch.manual_seed(args.seed)

    if 'Qwen' not in args.adapter:
        model, tokenizer = load_model(args.adapter, args.ckpt)
    else:
        model, tokenizer = load_model_from_hub(args.adapter)
    model.eval()
    model.cuda()
    tokenizer.pad_token = tokenizer.eos_token

    datasets_to_eval = [
            # Dataset name,        split ,       message_field
            ("MelinaLaimon/stream-of-search", "test", "messages_sos"),
            ("MelinaLaimon/stream-of-search-ood", "countdown_3num", "messages"),
            ("MelinaLaimon/stream-of-search-ood", "countdown_5num", "messages"),
            ("K-and-K/knights-and-knaves", "2ppl", "quiz"),
        ]

    for dataset, split, message_field in datasets_to_eval:
      # load data
      if dataset == "MelinaLaimon/stream-of-search":
          data = load_dataset(dataset, split=split).select(range(args.num))
          data = data.map(lambda x: { # type: ignore
              'test_prompt': [
                  # {'role': 'system', 'content': SYSTEM_PROMPT},
                  x[message_field]["role"=="user"]
              ],
              # 'answer': extract_hash_answer(x['answer'])
          })
      elif dataset == "MelinaLaimon/stream-of-search-ood":
          data = load_dataset(dataset, split=split).select(range(args.num))
          data = data.map(lambda x: { # type: ignore
              'test_prompt': [{"content": x['user_prompt'], "role": "user"}],
          })

      elif dataset == "K-and-K/knights-and-knaves":
          data = load_dataset(dataset, name="train", split=split).select(range(args.num))

          def _message_template(example_question):
              return [{ "content": f"{example_question}.\nConclude with the final result in EXACTLY this format:\n```\nSOLUTION: YES/NO\ \nRESULT: final_value\n```\nThe final_value should be statements separated by commas. For example, 'Michael is a knight, Zoey is a knight, and Ethan is a knight.'", "role": "user" }]

          data = data.map(lambda x: {
              "test_prompt": _message_template(x[message_field])
          })
      else: raise NotImplementedError

      # Add deepseeke prompt for backtracking
      if "deepseek" in adapter:
          print("Adding deepseek instructions to the prompt")
          deepseek_inst = "\nNote that the solution does exist. Verify your solutions before your present your final results and backtrack to correct mistakes from before your mistakes if you have to."

          data = data.map(lambda x: { # type: ignore
              'test_prompt': [
                  {"content": x['test_prompt'][0]["content"] + deepseek_inst, "role": "user"}
              ],
          })

      if "stream-of-search" in dataset:
        results_all_trials = []
        for trial in range(args.gens):

            tokenizer.padding_side = "left"

            results = []
            completions = eval_ll(model, tokenizer, data, batch_size=args.batch_size, context_len=args.ctx, temperature=args.temperature, n=args.gens)
            # parse into list of dictionaries
            for i in range(len(data['test_prompt'])):
                results.append({
                    'nums': data['nums'][i],
                    'target': data['target'][i],
                    'solution': data['solution'][i],
                    'prompt': data['test_prompt'][i][0]['content'],
                    'completion': completions[i]
                })
            results_all_trials.append(results)

        eval_results = evaluate_countdown_trajectories(results_all_trials)
        eval_results.insert(0, {"hyperparams": vars(args)})

        # Save results locally
        model_name = args.adapter.split("/")[-1]
        save_path = os.path.join("results/", f'{model_name}')

        timenow = datetime.now().strftime("%Y%m%d-%H%M%S")
        results_file = f"{save_path}/{split}_{args.num}_{timenow}.json"
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        with open(results_file, "w") as f:
            json.dump(eval_results, f, indent=4)

      else: # knk
        results = {
            "scores": {},
            "trajectories": {}
        }

        output_texts_concat = []
        # Generate completions for this batch
        for i, data_batch in tqdm(enumerate(data.iter(batch_size=args.batch_size)), total=len(data)//args.batch_size):
            chat_inputs = tokenizer.apply_chat_template(data_batch["test_prompt"], return_tensors="pt", padding=True, truncation=True, max_length=args.ctx, return_length=True, tokenize=False)
            outputs = generate_batch(model, tokenizer, chat_inputs, max_new_tokens=args.ctx, temperature=args.temperature)
            output_texts_concat.extend(outputs)

        # Add completions column to dataset
        column_name = f"completions_{split}"
        data = data.add_column(column_name, output_texts_concat)

        # Evaluate completions
        verified_column = f"verified_{split}"
        discrepancies_column = f"discrepancies_{split}"
        data = eval_dataset(data, column_name, verified_column, discrepancies_column)

        # Calculate score
        score = data[verified_column].count(True) / len(data) * 100
        print(f"{split} score: {score:.2f}%")

        # Store score and trajectories
        results['scores'][split] = score
        results['trajectories'][split] = []

        # Create trajectory data using the correct column names for each key
        for i in range(len(data)):
            results['trajectories'][split].append({
                'completions': data[column_name][i],
                'verified': data[verified_column][i],
                'discrepancies': data[discrepancies_column][i]
            })

        savepath = f"./results/{adapter.split('/')[-1]}/knk.json"
        os.makedirs(os.path.dirname(savepath), exist_ok=True)
        with open(savepath, 'w') as f:
            json.dump(results, f, indent=4)

# Running eval
In our main experiments, we run 128 samples on each of these models. That will take a very very long time on colab - therefore we're only using 4 samples from each test here and only do the 0.5B models.

Note that in our experiments, we use much larger hyperparameters. In our experiments for Countdown, Countdown3 and Coundown5 we use:
- nums: sample size to eval, in our experiments we use 128
- ctx: context length, we use 16384 for RSoS and 8192 otherwise
- Batch size: for both 0.5B and 1.5B maximum is only around 4 for a T4 GPU

For KnK we use:
- nums: sample size to eval, in our experiments we use 200
- ctx: context length, we use 8192 for RSoS and 4096 otherwise

This is only a toy example with 4 data points. This will take ~1 hour.

In [None]:
adapters = [
    "yeok/qwen-2.5-0.5B-instruct-sft-lora-countdown-deepseek-correct-5k",
    "chloeli/qwen-2.5-0.5B-instruct-sft-lora-countdown-search-react-correct-seq10k-5k",
    "chloeli/qwen-2.5-0.5B-instruct-sft-lora-countdown-search-seq8k-5k",
    "chloeli/qwen-2.5-0.5B-instruct-sft-lora-countdown-optimal-seq8k-5k",
    "Qwen/Qwen2.5-0.5B-Instruct", # Qwen-

    "yeok/qwen-2.5-1.5B-instruct-sft-lora-countdown-deepseek-correct-5k",
    "chloeli/qwen-2.5-1.5B-instruct-sft-lora-countdown-search-react-correct-seq10k-5k",
    "chloeli/qwen-2.5-1.5B-instruct-sft-lora-countdown-search-seq8k-5k",
    "chloeli/qwen-2.5-1.5B-instruct-sft-lora-countdown-optimal-seq8k-5k",
    "Qwen/Qwen2.5-0.5B-Instruct", # Qwen-
]
for adapter in tqdm(adapters, desc="Evaluating models"):
  args = argparse.Namespace(
      seed=4,
      adapter=adapter,
      ckpt=None,
      batch_size=2, # Batch size, for both 0.5B and 1.5B maximum is only around 4 for a T4 GPU
      num=2,        # sample size to eval, in our experiments we use 128
      chat_template=True,
      temperature=0.7,
      ctx=8192,     # In our experiments we use 16384 RSoS and 8192 otherwise
      gens=1,
      experiment_name=adapter,
      upload_results=False,
      wandb_project=None,
      wandb_entity=None
  )
  custom_eval(args)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating test_target split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Adding deepseek instructions to the prompt


Map:   0%|          | 0/2 [00:00<?, ? examples/s]


  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [04:34<00:00, 274.76s/it]


Success rate for each trial: [0.5]

Summary:
  Best-of-1 success rate: 0.5000 (1/2)
  Mean success rate across trials: 0.5000



Generating countdown_3num split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating countdown_5num split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Adding deepseek instructions to the prompt


Map:   0%|          | 0/2 [00:00<?, ? examples/s]


  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [01:32<00:00, 92.72s/it]


Success rate for each trial: [0.0]

Summary:
  Best-of-1 success rate: 0.0000 (0/2)
  Mean success rate across trials: 0.0000



Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Adding deepseek instructions to the prompt


Map:   0%|          | 0/2 [00:00<?, ? examples/s]


  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [06:55<00:00, 415.71s/it]


Success rate for each trial: [0.0]

Summary:
  Best-of-1 success rate: 0.0000 (0/2)
  Mean success rate across trials: 0.0000



Generating 2ppl split:   0%|          | 0/200 [00:00<?, ? examples/s]

Generating 3ppl split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating 4ppl split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating 5ppl split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating 6ppl split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating 7ppl split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating 8ppl split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Adding deepseek instructions to the prompt


Map:   0%|          | 0/2 [00:00<?, ? examples/s]


  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [01:18<00:00, 78.61s/it]
Evaluating models:  10%|█         | 1/10 [15:27<2:19:07, 927.46s/it]

2ppl score: 0.00%



  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [07:17<00:00, 437.10s/it]


Success rate for each trial: [0.0]

Summary:
  Best-of-1 success rate: 0.0000 (0/2)
  Mean success rate across trials: 0.0000




  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  1.02it/s]


Success rate for each trial: [0.0]

Summary:
  Best-of-1 success rate: 0.0000 (0/2)
  Mean success rate across trials: 0.0000




  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:11<00:00, 11.51s/it]


Success rate for each trial: [0.0]

Summary:
  Best-of-1 success rate: 0.0000 (0/2)
  Mean success rate across trials: 0.0000




  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:04<00:00,  4.68s/it]
Evaluating models:  20%|██        | 2/10 [23:17<1:27:46, 658.32s/it]

2ppl score: 0.00%



  0%|          | 0/1 [00:00<?, ?it/s][A

# Visualizing the results

In [None]:
import json
import pandas as pd
import glob as glob

methods = {
    "Baseline": "Qwen2.5-x.5B-Instruct",
    "OP": "qwen-2.5-x.5B-instruct-sft-lora-countdown-optimal-seq8k-5k",
    "SoS": "qwen-2.5-x.5B-instruct-sft-lora-countdown-search-seq8k-5k",
    "RSoS": "qwen-2.5-x.5B-instruct-sft-lora-countdown-search-react-correct-seq10k-5k",
    "Distill": "qwen-2.5-x.5B-instruct-sft-lora-countdown-deepseek-correct-5k"
}

tasks = {
    "Countdown": "test_*.json",
    "Countdown-3": "countdown_3num_*.json",
    "Countdown-5": "countdown_5num_*.json",
    "KnK": "knk.json"
}

sizes = ["0.5B", "1.5B"]

def parse_results_from_json(file):
    try:
        with open(file, 'r') as f:
            data = json.load(f)
        # print("File:", file)
        # Extract the relevant information from the JSON data
        if "knk" in file:
            return data["scores"]["2ppl"]
        if "countdown" in file or "test" in file:
            # hyperparams = data[0]['hyperparams']
            return data[1]['mean']*100
    except Exception as e:
        print("Error reading file:", file)
        print("Error message:", e)
        return None

# initialize results to store for sizes, method_key, task_key
results = {size: {method_key: {task_key: None for task_key in tasks.keys()} for method_key in methods.keys()} for size in sizes}

available_files = []
for size_val in sizes:
    for method_key, method_val in methods.items():
        for task_key, task_val in tasks.items():
            folder_to_look_for = f"results/{method_val.replace('x.5B', size_val)}/{task_val}"
            specific_file = glob.glob(folder_to_look_for)
            if specific_file:
                results[size_val][method_key][task_key] = parse_results_from_json(specific_file[0])
                available_files.append(specific_file[0])
# above works, put their filepath in a pandas df

# put results in a pandas df
df = pd.DataFrame.from_dict({(i, j): results[i][j]
                           for i in results.keys()
                           for j in results[i].keys()},
                           orient='index')
# Tranpose
df = df.transpose()
# make the df float up to 2 decimals
df = df.round(2)
df.to_latex("results.tex", index=True, float_format="%.2f")
df

# Visualizing Trajectories

In [None]:
for file in available_files:
  with open(file, 'r') as f:
    data = json.load(f)
  for item in data:
    print(item)