In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataset import ReasoningHashDataset
checkpoint_path = "model_20250206_083544/model_checkpoint_batch_300"
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint_path,
    trust_remote_code=True,
)

# Dataset used by the smoke-test cell below: 10k samples, 4-char hashes,
# chain lengths [2, 3, 4, 5], num_chains=3 (exact meaning of each option is
# defined by ReasoningHashDataset).
dataset = ReasoningHashDataset(
    tokenizer=tokenizer,
    rl=True,
    device="cuda",
    num_chains=3,
    vary_hash=True,
    chains=[2, 3, 4, 5],
    hash_length=4,
    num_samples=10_000,
)



  from .autonotebook import tqdm as notebook_tqdm


In [34]:
import torch as T

# Load the checkpoint in bfloat16; device_map="auto" lets transformers place it.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_path,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=T.bfloat16,
)

# Inference only.
model.eval()

# Smoke test: generate a completion for the first dataset sample.
# Fetch the sample ONCE: if the dataset builds items on the fly (not visible
# here -- TODO confirm), indexing it twice could pair input_ids with the
# attention_mask of a different sample.
sample_inputs = dataset[0]["input"]

with T.no_grad():
    output = model.generate(
        input_ids=sample_inputs["input_ids"].unsqueeze(0).to(model.device),
        attention_mask=sample_inputs["attention_mask"].unsqueeze(0).to(model.device),
        max_new_tokens=4000,
        # Without do_sample=True, temperature/top_p are ignored (greedy
        # decoding) unless the checkpoint's generation_config enables sampling.
        do_sample=True,
        temperature=0.85,
        top_p=0.8,
        pad_token_id=tokenizer.eos_token_id,
    )

print("Test output:", tokenizer.decode(output[0], skip_special_tokens=True))

Test output: Map:
10qy=>1vvn
1vvn=>nyw5
j34i=>9ru9
hdka=>fk9r
w0dn=>10qy
2rts=>w0dn
k0l1=>3c5x
2rts=>ev3v
9ru9=>hdka
2rts=>j34i
ev3v=>k0l1
Start: 2rts
Task: Multiple hash chains are provided. Find the shortest chain and provide the end 4 char hash. Think hard in tag! Circle your answer in <circle>HERE</circle> after </think>
-----
START
<think>There are two lists provided. 2rts is listed twice, once as ev3v and once as w0dn. 1vvn is listed twice, once as 10qy and once as nyw5. 9ru9 is listed twice, once as j34i and once as 9ru9. 10qy is listed once as 1vvn and once as 10qy. 2rts is listed once as ev3v and once as w0dn. hdka is listed once as hdka and once as fk9r. k0l1 is listed once as k0l1 and once as 3c5x. j34i is listed once as j34i and once as 9ru9. w0dn is listed once as w0dn and once as 10qy. nyw5 is listed once as 1vvn and once as nyw5. 2rts is listed once as 2rts and once as ev3v. 3c5x is listed once as k0l1 and once as 3c5x. 9ru9 is listed once as 9ru9 and once as hdka. 1vvn 

In [36]:
from typing import List
from tqdm import tqdm
import torch
from collections import Counter
from typing import Tuple

from collections import Counter

def generate_hash_consensus(
    model,
    tokenizer,
    prompt: str,
    num_sequences: int = 8,
    max_length: int = 400
) -> "List[str] | str":
    """
    Sample several completions for ``prompt`` and collect their circled answers.

    Despite the name, no majority vote is taken here. Every answer found inside
    ``<circle>...</circle>`` is returned (duplicates kept, in sample order), so
    callers can test membership (``target in result``) or compute a consensus
    themselves, e.g. ``collections.Counter(result).most_common(1)``.

    Args:
        model: causal LM exposing ``.device`` and ``.generate``.
        tokenizer: tokenizer exposing ``encode``, ``decode`` and ``eos_token_id``.
        prompt: text fed to the model.
        num_sequences: number of sampled completions.
        max_length: maximum number of NEW tokens per completion (forwarded as
            ``max_new_tokens``).

    Returns:
        A list of circled answers, or the sentinel string ``"NO CIRCLE"`` when
        no completion contains both ``<circle>`` and ``</circle>``.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    attention_mask = torch.ones_like(input_ids)
    prompt_len = input_ids.shape[-1]

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        num_return_sequences=num_sequences,
        temperature=0.8,
        top_p=0.85,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        max_new_tokens=max_length
    )

    hashes = []
    for sequence in output:
        # Decode only the tokens generated after the prompt. Locating the prompt
        # inside the fully decoded text raises ValueError whenever decode() does
        # not reproduce the prompt verbatim; slicing by token count cannot fail,
        # and it never exposes the example "<circle>HERE</circle>" that prompts
        # from this dataset contain.
        generated_text = tokenizer.decode(sequence[prompt_len:])
        if "<circle>" in generated_text and "</circle>" in generated_text:
            hashes.append(generated_text.split("<circle>")[1].split("</circle>")[0])

    if not hashes:
        return "NO CIRCLE"
    return hashes

def generate_hash(model, tokenizer, prompt, max_length=100):
    """
    Sample a single completion for ``prompt`` and return its circled answer.

    Args:
        model: causal LM exposing ``.device`` and ``.generate``.
        tokenizer: tokenizer exposing ``encode``, ``decode`` and ``eos_token_id``.
        prompt: text fed to the model.
        max_length: maximum number of NEW tokens to generate (forwarded as
            ``max_new_tokens``, matching ``generate_hash_consensus``).

    Returns:
        The answer enclosed by the first ``<circle>`` tag of the completion, or
        ``"NO CIRCLE"`` if the completion lacks either ``<circle>`` or
        ``</circle>``.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    attention_mask = torch.ones_like(input_ids)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        num_return_sequences=1,
        temperature=0.6,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        # Previously max_length was accepted but never forwarded, so the
        # generation_config default length applied instead.
        max_new_tokens=max_length,
    )

    # Decode only the tokens generated after the prompt; searching the decoded
    # text for the prompt string fails when decode() doesn't round-trip it.
    generated_text = tokenizer.decode(output[0][input_ids.shape[-1]:])

    # Require BOTH tags. With only "<circle>" present, the split below would
    # return the unterminated remainder of the completion as an "answer".
    if "<circle>" not in generated_text or "</circle>" not in generated_text:
        return "NO CIRCLE"
    return generated_text.split("<circle>")[1].split("</circle>")[0]

def evaluate_reasoning_hash(model, tokenizer, num_tests: int = 100, hash_length: int = 5,
                            chains: "List[int] | None" = None, vary_hash: bool = True, num_chains: int = 4):
    """
    Measure how often the model recovers the target hash, overall and per chain length.

    Builds a fresh ``ReasoningHashDataset`` of ``num_tests`` samples and, for
    each one, calls ``generate_hash_consensus``. A sample counts as correct when
    the target hash is among the returned circled answers.

    Args:
        model: causal LM to evaluate (switched to eval mode).
        tokenizer: tokenizer passed to the dataset and to generation.
        num_tests: number of evaluation samples.
        hash_length: forwarded to ``ReasoningHashDataset``.
        chains: forwarded to ``ReasoningHashDataset``; defaults to ``[3, 4, 5, 6]``
            (a fresh list per call).
        vary_hash: forwarded to ``ReasoningHashDataset``.
        num_chains: forwarded to ``ReasoningHashDataset``.

    Returns:
        ``(overall_accuracy, chain_accuracies)``. ``chain_accuracies`` maps
        ``find_shortest_path(...) + 1`` to a dict with ``'correct'``,
        ``'total'`` and ``'accuracy'``. ``overall_accuracy`` is ``0.0`` when
        no samples were evaluated.
    """
    if chains is None:
        chains = [3, 4, 5, 6]
    model.eval()

    # Create evaluation dataset
    eval_dataset = ReasoningHashDataset(tokenizer, num_samples=num_tests, hash_length=hash_length,
                                        chains=chains, vary_hash=vary_hash, num_chains=num_chains)

    correct_predictions = 0
    total_predictions = 0
    chain_accuracies = {}

    for idx in tqdm(range(len(eval_dataset)), desc="Evaluating Reasoning Hash", leave=False):
        _full_text, hash_list, start, actual_target, prompt = eval_dataset.get_eval_item(idx)
        predicted_target = generate_hash_consensus(model, tokenizer, prompt)

        # generate_hash_consensus returns a list of answers, or the sentinel
        # string "NO CIRCLE". A bare `in` on a string is a SUBSTRING test, so
        # normalise to a list of candidates and compare whole answers only.
        if isinstance(predicted_target, str):
            candidates = [] if predicted_target == "NO CIRCLE" else [predicted_target]
        else:
            candidates = list(predicted_target)
        is_correct = actual_target in candidates

        total_predictions += 1
        if is_correct:
            correct_predictions += 1

        # Bucket by the dataset's shortest-path length (+1).
        chain_length = ReasoningHashDataset.find_shortest_path(hash_list, start) + 1
        stats = chain_accuracies.setdefault(chain_length, {'correct': 0, 'total': 0})
        stats['total'] += 1
        if is_correct:
            stats['correct'] += 1
        else:
            print(f"@@@Failed with {predicted_target} and {actual_target}")

    # Guard against num_tests == 0 (previously a ZeroDivisionError).
    overall_accuracy = correct_predictions / total_predictions if total_predictions else 0.0
    for stats in chain_accuracies.values():
        stats['accuracy'] = stats['correct'] / stats['total'] if stats['total'] > 0 else 0

    return overall_accuracy, chain_accuracies