In [1]:
def verify_solution_text(names, solution, solution_text):
    """
    Verifies if the solution_text correctly describes the knight/knave status of each person.

    Only the text after the last "RESULT:" marker is inspected (the whole string
    when the marker is absent). Periods are removed, " and " is treated as a
    comma, and the remainder is split on ", " into one statement per person.

    Args:
        names: List of names
        solution: List of booleans (True for knight, False for knave),
            aligned index-for-index with names
        solution_text: String describing the solution

    Returns:
        Tuple (is_correct, discrepancies). is_correct is True only when no
        discrepancy was found; discrepancies is always a list of strings
        (empty on success), so callers can safely join it.
    """
    import re  # local import keeps the function usable without a separate import cell

    # Make sure we have the same number of names and solutions
    if len(names) != len(solution):
        return False, ["Mismatch in lengths of names and solution arrays"]

    # A missing completion (e.g. None) is reported instead of raising
    if not isinstance(solution_text, str):
        return False, [f"Solution text must be a string, got {type(solution_text).__name__}"]

    # Clean up the solution text and split by commas and 'and'
    text = solution_text.split("RESULT:")[-1].strip().replace('.', '')
    # Handle 'and' at the end
    text = text.replace(' and ', ', ')

    parts = text.split(', ')

    if len(parts) != len(names):
        return False, [f"Solution text has {len(parts)} parts but there are {len(names)} people"]

    # Whole-word patterns so that e.g. "Jack" does not match inside "Jackson"
    name_patterns = [re.compile(rf"(?<!\w){re.escape(name)}(?!\w)") for name in names]

    # Check each person
    discrepancies = []
    seen = set()

    for part in parts:
        # Find which name this part refers to (first name in `names` order wins)
        name_idx = next(
            (j for j, pattern in enumerate(name_patterns) if pattern.search(part)),
            -1,
        )

        if name_idx == -1:
            discrepancies.append(f"Couldn't find any name in '{part}'")
            continue

        name = names[name_idx]

        # Because len(parts) == len(names), a repeated name implies that
        # somebody else was never described at all.
        if name_idx in seen:
            discrepancies.append(f"{name} is described more than once")
            continue
        seen.add(name_idx)

        # Check if the knight/knave status is correct
        lowered = part.lower()
        is_knight = "knight" in lowered
        is_knave = "knave" in lowered

        if is_knight and is_knave:
            discrepancies.append(f"Ambiguous status for {name} in '{part}': mentions both knight and knave")
        elif is_knight and not solution[name_idx]:
            discrepancies.append(f"{name} is described as knight but should be knave")
        elif is_knave and solution[name_idx]:
            discrepancies.append(f"{name} is described as knave but should be knight")
        elif not is_knight and not is_knave:
            discrepancies.append(f"Couldn't determine if {name} is knight or knave in '{part}'")

    return len(discrepancies) == 0, discrepancies

# Use the verification results to add the verified and discrepancies columns to the dataset
def eval_dataset(data, field='solution_text', verified_col='verified', discrepancies_col='discrepancies'):
    """
    Updates the dataset with verification results.

    Each row's names, solution and `field` text are passed to
    verify_solution_text, and two columns are appended with the outcome.

    Args:
        data: The dataset to update. It must support len(), column access by
            name (data['names'], data['solution'], data[field]) and
            add_column(name, values).
        field: Name of the column holding the text to verify
        verified_col: Name of the boolean column to add
        discrepancies_col: Name of the string column to add; holds the row's
            discrepancy messages joined with ", " (empty string when none)

    Returns:
        The value returned by add_column after both columns were added.
    """
    # Look each column up once instead of once per row inside the loop
    names_col = data['names']
    solution_col = data['solution']
    text_col = data[field]

    verified = []
    discrepancies = []

    for i in range(len(data)):
        is_verified, found = verify_solution_text(names_col[i], solution_col[i], text_col[i])

        # A verifier may report a single message as a plain string; joining a
        # string directly would interleave ", " between its characters.
        if isinstance(found, str):
            found = [found]

        verified.append(is_verified)
        discrepancies.append(", ".join(found))

    data = data.add_column(verified_col, verified)
    data = data.add_column(discrepancies_col, discrepancies)

    return data

In [7]:
# Load the LoRA adapter named in `adapter` below
# (chloeli/qwen-2.5-1.5B-instruct-sft-lora-countdown-search-react-1k) together with its tokenizer.
import sys
# NOTE(review): machine-specific absolute path, added so that the `finetune` package below
# can be imported — consider making this configurable.
sys.path.append('/cs/student/msc/ml/2024/ycheah/projects/sos/stream-of-search')
from finetune.run_adapter_model import load_model, generate, generate_batch

adapter="chloeli/qwen-2.5-1.5B-instruct-sft-lora-countdown-search-react-1k"
# Number of prompts sent to generate_batch per call in the evaluation loop
batch_size=32
model, tokenizer = load_model(adapter)
# Inference only: eval mode, on the GPU
model.eval()
model.cuda()
# Reuse the EOS token as the padding token
tokenizer.pad_token = tokenizer.eos_token

def message_template(example_question):
    """
    Wrap a puzzle question in a single user chat message that asks for the
    SOLUTION/RESULT answer format.

    Args:
        example_question: Puzzle text; a period and the formatting
            instructions are appended after it.

    Returns:
        A one-element list holding a {"content": ..., "role": "user"} dict.
    """
    # "\\ " emits the same backslash + space the prompt has always contained
    # after "YES/NO", without relying on the invalid "\ " escape sequence
    # (which triggers a warning on current Python versions).
    instructions = (
        "Conclude with the final result in EXACTLY this format:\n"
        "```\n"
        "SOLUTION: YES/NO\\ \n"
        "RESULT: final_value\n"
        "```\n"
        "The final_value should be statements separated by commas. "
        "For example, 'Michael is a knight, Zoey is a knight, and Ethan is a knight.'"
    )
    return [{ "content": f"{example_question}.\n{instructions}", "role": "user" }]



2025-04-03 19:09:25,708 - INFO - Using base model: Qwen/Qwen2.5-1.5B-Instruct
2025-04-03 19:09:25,709 - INFO - Loading base model without quantization...
2025-04-03 19:09:25,868 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
2025-04-03 19:09:26,462 - INFO - Applying LoRA adapters...


In [None]:
from tqdm import tqdm
import datasets
# Load the "test" configuration of the K-and-K knights-and-knaves dataset.
# The result is indexed by keys such as "2ppl" in the evaluation loop later in this notebook.
data_ = datasets.load_dataset("K-and-K/knights-and-knaves", name="test")


In [9]:
import json
import os

# Generation settings: used both as the prompt max_length and as max_new_tokens
context_len = 512
# NOTE(review): sampling with temperature > 0 and no seed set in this cell, so
# scores may vary between runs — confirm whether generate_batch seeds internally.
temperature = 0.7

# Dataset entries to evaluate (keys of `data_`)
keys = ["2ppl", "3ppl", "4ppl"]
results = {'trajectories': {}, 'scores': {}}

for key in keys:
    output_texts_concat = []

    data = data_[key]
    data = data.map(lambda x: {
        "test_prompt": message_template(x['quiz'])
    })

    # Ceiling division so the progress bar total includes the final partial batch
    num_batches = (len(data) + batch_size - 1) // batch_size

    # Generate completions batch by batch
    for data_batch in tqdm(data.iter(batch_size=batch_size), total=num_batches):
        # NOTE(review): tokenize=False is passed, so the tensor/padding/truncation
        # arguments presumably have no effect here, and add_generation_prompt is
        # left at its default — confirm what generate_batch expects as input.
        chat_inputs = tokenizer.apply_chat_template(data_batch["test_prompt"], return_tensors="pt", padding=True, truncation=True, max_length=context_len, return_length=True, tokenize=False)
        outputs = generate_batch(model, tokenizer, chat_inputs, max_new_tokens=context_len, temperature=temperature)
        output_texts_concat.extend(outputs)

    # Add completions column to dataset
    column_name = f"completions_{key}"
    data = data.add_column(column_name, output_texts_concat)

    # Evaluate completions
    verified_column = f"verified_{key}"
    discrepancies_column = f"discrepancies_{key}"
    data = eval_dataset(data, column_name, verified_column, discrepancies_column)

    # Read each result column once rather than once per row
    completions = list(data[column_name])
    verified_flags = list(data[verified_column])
    discrepancy_texts = list(data[discrepancies_column])

    # Calculate score (percentage of verified completions; 0 for an empty split)
    score = verified_flags.count(True) / len(verified_flags) * 100 if verified_flags else 0.0
    print(f"{key} score: {score:.2f}%")

    # Store score and per-example trajectories
    results['scores'][key] = score
    results['trajectories'][key] = [
        {
            'completions': completion,
            'verified': is_verified,
            'discrepancies': discrepancy_text
        }
        for completion, is_verified, discrepancy_text in zip(completions, verified_flags, discrepancy_texts)
    ]

# `adapter` contains a "/", so the results land in a nested directory per adapter
savepath = f"./results/ood/{adapter}/knk.json"
os.makedirs(os.path.dirname(savepath), exist_ok=True)
with open(savepath, 'w') as f:
    json.dump(results, f, indent=4)

4it [01:16, 19.16s/it]                       


2ppl score: 2.00%


4it [01:18, 19.70s/it]                       


3ppl score: 0.00%


4it [01:20, 20.25s/it]                       


4ppl score: 0.00%


In [10]:
# Scratch check: relies on `key` still holding the last value from the evaluation loop
f"verified_{key}"

'verified_4ppl'

In [11]:
# Scratch check: number of completions collected for the last evaluated key
len(output_texts_concat)

100