In [None]:
import sys
import gc
import os
import time
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
from functools import partial
from scipy.special import softmax
import torch
from torch import nn
import torch.nn.functional as F

from transformers import LlamaTokenizer, AutoTokenizer, AutoModelForCausalLM
import pickle

import logging
# Silence transformers' info/warning output: only ERROR and above get through this logger.
logging.getLogger(name="transformers").setLevel(level=logging.ERROR)

In [None]:
# Configuration
CONFIG = {
    # Dataset configuration
    'dataset_path': 'data/SciQ/test.csv',  # Change this to your dataset path
    'dataset_name': 'SciQ',  # For tracking purposes
    
    # Model configuration
    'model_name': 'daryl149/llama-2-7b-chat-hf',
    'torch_dtype': torch.bfloat16,
    'device': 'cuda',
    
    # Evaluation configuration
    'get_sample': False,  # Set to True to test on single sample
    'sample_index': 0,    # Which sample to test if get_sample is True
    'batch_size': 4,      # For STL approach
    'seed': 42,           # Seeds torch's RNG so sampled (CoT) generations can be repeated
    
    # Prompt templates.
    # The templates intentionally do NOT begin with a literal "<s>": the Llama tokenizer
    # already prepends the BOS token when encoding, so "<s>" in the text would produce
    # two BOS tokens at the start of every prompt.
    'baseline_prompt': """[INST] Your task is to analyze the question and answer options A, B, C or D below.
{query}
[/INST] Answer:""",
    
    'stl_prompt': """[INST] Your task is to analyze the question and answer below. If the answer is correct, respond yes, if it is not correct respond no.
{question} [/INST]""",
    
    'cot_prompt': """[INST] Your task is to analyze the question and answer options A, B, C or D below. Let's think step by step and provide your reasoning, then give your final answer in the format "Answer: [A/B/C/D]".

{query}
[/INST] Let me think step by step:

""",
    
    # Generation parameters
    'max_new_tokens_baseline': 1,
    'max_new_tokens_stl': 1,
    'max_new_tokens_cot': 200,
    
    # Output configuration
    'save_results': True,
    'results_dir': 'results',
}

# Seed early so every later cell starts from the same RNG state
torch.manual_seed(CONFIG['seed'])

# Create results directory
os.makedirs(CONFIG['results_dir'], exist_ok=True)

print("Configuration:")
for key, value in CONFIG.items():
    if 'prompt' in key:
        print(f"  {key}: [Template defined]")
    else:
        print(f"  {key}: {value}")

In [None]:
def load_model_and_tokenizer(model_name, torch_dtype=torch.bfloat16, device='cuda', trust_remote_code=True):
    """Load a causal LM and its tokenizer, configured for left-padded batch generation.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        torch_dtype: dtype the model weights are loaded in.
        device: device the model is moved to (e.g. ``'cuda'``, ``'cuda:1'``, ``'cpu'``).
        trust_remote_code: forwarded to ``AutoModelForCausalLM.from_pretrained``. When True the
            checkpoint repository may execute its own Python code, so only enable it for
            repositories you trust. Defaults to True to match the previous behaviour.

    Returns:
        ``(model, tokenizer)``.
    """
    print(f"Loading model: {model_name}")
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = model.to(device)
    model.eval()  # inference only; make the mode explicit
    
    # Left padding keeps every prompt's last real token at position -1 in a batch.
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    # Take the pad id from the tokenizer so model and tokenizer always agree, and set it on
    # the generation config too so generate() does not have to fall back on its own.
    model.config.pad_token_id = tokenizer.pad_token_id
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.pad_token_id = tokenizer.pad_token_id
    
    print(f"Model loaded successfully on {device}")
    return model, tokenizer

# Load model and tokenizer using the settings from CONFIG
model, tokenizer = load_model_and_tokenizer(
    model_name=CONFIG['model_name'],
    torch_dtype=CONFIG['torch_dtype'],
    device=CONFIG['device'],
)

In [None]:
def get_token_ids(tokenizer):
    """Look up the vocabulary ids used to score answers.

    Returns:
        labels: ``["A", "B", "C", "D"]``.
        labels_token_id: list of 4 ids, one per option letter, from encoding ``"A B C D"``
            without special tokens.
        yes_token_id: list of ids for ``"yes"`` (without special tokens).
        no_token_id: list of ids for ``"no"`` (without special tokens).

    Raises:
        ValueError: if ``"A B C D"`` does not encode to exactly one token per letter.
    """
    labels = ["A", "B", "C", "D"]
    # Ask the tokenizer not to add special tokens instead of slicing off position 0:
    # the slice silently drops a real token from tokenizers that do not prepend BOS.
    labels_token_id = list(tokenizer.encode(" ".join(labels), add_special_tokens=False))
    yes_token_id = list(tokenizer.encode("yes", add_special_tokens=False))
    no_token_id = list(tokenizer.encode("no", add_special_tokens=False))

    if len(labels_token_id) != len(labels):
        raise ValueError(
            f"Expected one token per option letter, got {labels_token_id} for {labels}"
        )
    
    print(f"Labels token IDs: {labels_token_id}")
    print(f"Yes token IDs: {yes_token_id}")
    print(f"No token IDs: {no_token_id}")
    
    return labels, labels_token_id, yes_token_id, no_token_id

# Resolve the token ids of the option letters and of "yes"/"no"
labels, labels_token_id, yes_token_id, no_token_id = get_token_ids(tokenizer=tokenizer)

In [None]:
def load_dataset(dataset_path, get_sample=False, sample_index=0):
    """Load a multiple-choice CSV and add the ``instruction`` column used by the prompts.

    The CSV must contain ``question``, ``A``, ``B``, ``C``, ``D`` and ``answer`` columns. An
    ``id`` column, if present, is dropped. Missing cells become ``' '``, every column is cast
    to ``str``, and ``answer`` is stripped of surrounding whitespace so it compares equal to
    the option letters.

    Args:
        dataset_path: path to the CSV file.
        get_sample: if True, keep only the row at position ``sample_index``.
        sample_index: positional index of the row kept when ``get_sample`` is True.

    Returns:
        The prepared DataFrame.

    Raises:
        FileNotFoundError: if ``dataset_path`` does not exist.
        KeyError: if a required column is missing.
    """
    print(f"Loading dataset from: {dataset_path}")
    
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Dataset not found at {dataset_path}")
    
    raw_df = pd.read_csv(dataset_path)

    required_columns = ['question', 'A', 'B', 'C', 'D', 'answer']
    missing = [col for col in required_columns if col not in raw_df.columns]
    if missing:
        raise KeyError(f"Dataset is missing required column(s): {missing}")
    
    df = (
        raw_df
        .drop(columns=['id'], errors='ignore')  # 'id' is optional in some dataset formats
        .fillna(' ')
        .astype(str)
    )
    df = df.assign(answer=df['answer'].str.strip())
    
    if get_sample:
        df = df.iloc[[sample_index]]
        print(f"Using sample at index {sample_index}")
    
    # Instruction text shared by the baseline and CoT prompts
    df = df.assign(
        instruction='Question: ' + df['question']
        + '\n\nA. ' + df['A']
        + '\n\nB. ' + df['B']
        + '\n\nC. ' + df['C']
        + '\n\nD. ' + df['D']
    )
    
    print(f"Dataset loaded: {len(df)} samples")
    if len(df) == 0:
        print("Warning: dataset is empty")
        return df

    first = df.iloc[0]
    print("\nSample question:")
    print(f"Question: {first['question']}")
    print(f"A. {first['A']}")
    print(f"B. {first['B']}")
    print(f"C. {first['C']}")
    print(f"D. {first['D']}")
    print(f"Answer: {first['answer']}")
    
    return df

# Load the evaluation dataset described in CONFIG
df = load_dataset(
    dataset_path=CONFIG['dataset_path'],
    get_sample=CONFIG['get_sample'],
    sample_index=CONFIG['sample_index'],
)

In [None]:
def count_tokens(tokenizer, text):
    """Return how many token ids ``tokenizer.encode(text)`` produces.

    ``encode`` is called with its default arguments, so any special tokens the tokenizer
    adds by default are part of the count.
    """
    token_ids = tokenizer.encode(text)
    return len(token_ids)

def baseline_approach(model, tokenizer, df, labels_token_id, config):
    """Baseline approach: predict the answer letter directly.

    For every row, ``config['baseline_prompt']`` is filled with ``row['instruction']`` and the
    model generates ``config['max_new_tokens_baseline']`` token(s) greedily. The scores of the
    first generated position at ``labels_token_id`` rank the options A-D. The top-ranked
    letter is the prediction.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``.
        df: DataFrame with ``instruction``, ``question`` and ``answer`` columns.
        labels_token_id: vocabulary ids of the option letters, in A-D order.
        config: configuration dict (reads ``baseline_prompt`` and ``max_new_tokens_baseline``).

    Returns:
        ``(results, accuracy)``: one result dict per row, and the fraction of rows whose
        top-ranked letter equals ``row['answer']`` (0.0 for an empty ``df``).
    """
    print("\n" + "="*50)
    print("BASELINE APPROACH")
    print("="*50)
    
    option_letters = np.array(["A", "B", "C", "D"])
    results = []
    
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing baseline"):
        start_time = time.time()
        
        full_prompt = config['baseline_prompt'].format(query=row['instruction'])
        # Follow the model's actual device. The old hard-coded f"cuda:{index}" broke for
        # CPU models, where the index is None.
        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
        # Length of the prompt as fed to the model; avoids tokenizing the prompt a second time.
        input_tokens = int(inputs["input_ids"].shape[1])
        
        with torch.no_grad():
            output = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=config['max_new_tokens_baseline'],
                # Force greedy decoding. If the checkpoint's generation config enables
                # sampling, the returned scores can be temperature/top-p processed
                # rather than plain logits.
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
            )
        
        first_token_logits = output.scores[0][0]
        option_logits = first_token_logits[labels_token_id].float().cpu().numpy()
        
        # All four letters, best first, e.g. "C A D B"
        ranking = option_letters[np.argsort(option_logits)[::-1]]
        pred = ' '.join(ranking)
        
        latency = time.time() - start_time
        question = row['question']
        
        results.append({
            'question_id': idx,
            'approach': 'baseline',
            'question': question[:100] + '...' if len(question) > 100 else question,
            'true_answer': row['answer'],
            'predicted_answer': str(ranking[0]),
            'full_prediction': pred,
            'input_tokens': input_tokens,
            'output_tokens': config['max_new_tokens_baseline'],
            'latency_seconds': latency,
            'logits': option_logits.tolist(),
            'probabilities': softmax(option_logits).tolist()
        })
    
    n_correct = sum(r['predicted_answer'] == r['true_answer'] for r in results)
    accuracy = n_correct / len(results) if results else 0.0
    
    print(f"\nBaseline Results:")
    print(f"Accuracy: {accuracy * 100:.2f}%")
    if results:
        print(f"Average latency: {np.mean([r['latency_seconds'] for r in results]):.3f}s")
        print(f"Average input tokens: {np.mean([r['input_tokens'] for r in results]):.1f}")
    
    return results, accuracy

# Run the baseline (direct letter prediction) approach
baseline_results, baseline_accuracy = baseline_approach(
    model=model,
    tokenizer=tokenizer,
    df=df,
    labels_token_id=labels_token_id,
    config=CONFIG,
)

In [None]:
def get_stl_prompts(row, config):
    """Build the four yes/no verification prompts (options A-D, in order) for one question.

    Each prompt places the question *and* one proposed answer in the ``{question}`` slot of
    ``config['stl_prompt']``, then appends a ``### Response:`` cue.

    Args:
        row: mapping with ``question``, ``A``, ``B``, ``C`` and ``D`` entries (e.g. a DataFrame row).
        config: configuration dict providing ``stl_prompt``.

    Returns:
        List of four prompt strings, one per option A, B, C, D.
    """
    template = config['stl_prompt']
    prompts = []
    for letter in "ABCD":
        # The proposed answer goes inside the {question} slot, so it sits within the
        # instruction being judged. Appending it after the template would put it behind
        # the template's closing [/INST] tag.
        question = f"\nQuestion: {row['question']}\nProposed answer: {row[letter]}"
        prompts.append(template.format(question=question) + "\n\n### Response:\n")
    return prompts

def single_token_logits_approach(model, tokenizer, df, yes_token_id, no_token_id, config):
    """Single Token Logits (STL) approach: score every option with a yes/no judgement.

    Each (question, option) pair becomes its own prompt (via ``get_stl_prompts``), and the
    prompts are run through ``model.generate`` in batches of ``config['batch_size']``. The
    score of the "yes" token at the first generated position is the option's score; the
    highest-scoring option is the prediction.

    NOTE(review): ranking uses the yes logit only; the no logit is recorded (``no_logits``)
    but not used for scoring.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model`` (expected to pad on the left).
        df: DataFrame with ``question``, ``A``-``D`` and ``answer`` columns.
        yes_token_id: id (or list of ids, first one used) of the "yes" token.
        no_token_id: id (or list of ids, first one used) of the "no" token.
        config: configuration dict (reads ``stl_prompt``, ``batch_size``, ``max_new_tokens_stl``).

    Returns:
        ``(question_results, accuracy)``: one result dict per question, and the fraction of
        questions whose top-scored option equals ``df['answer']`` (0.0 for an empty ``df``).
    """
    print("\n" + "="*50)
    print("SINGLE TOKEN LOGITS APPROACH") 
    print("="*50)
    
    option_letters = np.array(["A", "B", "C", "D"])
    n_options = len(option_letters)
    batch_size = config['batch_size']
    # get_token_ids returns lists of ids. Index with one scalar id so each score is a single
    # float, even if the word happened to split into several tokens.
    yes_id = yes_token_id[0] if isinstance(yes_token_id, (list, tuple)) else yes_token_id
    no_id = no_token_id[0] if isinstance(no_token_id, (list, tuple)) else no_token_id
    
    # Question-major layout: [q0-A, q0-B, q0-C, q0-D, q1-A, ...]
    all_prompts = [
        prompt
        for _, row in df.iterrows()
        for prompt in get_stl_prompts(row, config)
    ]
    
    yes_logits_all, no_logits_all, input_tokens_all, latency_all = [], [], [], []
    
    with torch.no_grad():
        for start in tqdm(range(0, len(all_prompts), batch_size), desc="Processing STL batches"):
            batch_start_time = time.time()
            batch = all_prompts[start:start + batch_size]
            
            batch_tokens = tokenizer(
                batch, return_tensors="pt", return_attention_mask=True, padding=True
            ).to(model.device)
            
            batch_outputs = model.generate(
                **batch_tokens,
                max_new_tokens=config['max_new_tokens_stl'],
                # Greedy, so the scores are not sampling-processed (temperature/top-p)
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
            )
            
            first_token_logits = batch_outputs.scores[0].float().cpu()
            batch_latency = time.time() - batch_start_time
            
            yes_logits_all.extend(first_token_logits[:, yes_id].tolist())
            no_logits_all.extend(first_token_logits[:, no_id].tolist())
            # Unpadded prompt lengths: padding positions are 0 in the attention mask
            input_tokens_all.extend(batch_tokens["attention_mask"].sum(dim=1).tolist())
            # Spread the batch latency evenly over its prompts
            latency_all.extend([batch_latency / len(batch)] * len(batch))
            
            # Cleanup
            del batch_tokens, batch_outputs, first_token_logits
            gc.collect()
            torch.cuda.empty_cache()
    
    # One row per question, one column per option (A-D)
    yes_logits = np.asarray(yes_logits_all, dtype=float).reshape(-1, n_options)
    no_logits = np.asarray(no_logits_all, dtype=float).reshape(-1, n_options)
    input_tokens = np.asarray(input_tokens_all, dtype=float).reshape(-1, n_options)
    latencies = np.asarray(latency_all, dtype=float).reshape(-1, n_options)
    probs_output = softmax(yes_logits, axis=1)
    
    question_results = []
    for q_pos, option_logits in enumerate(yes_logits):
        ranking = option_letters[np.argsort(option_logits)[::-1]]
        row = df.iloc[q_pos]
        question = row['question']
        question_results.append({
            # Use the DataFrame index label, like the baseline and CoT approaches do
            'question_id': df.index[q_pos],
            'approach': 'stl',
            'question': question[:100] + '...' if len(question) > 100 else question,
            'true_answer': row['answer'],
            'predicted_answer': str(ranking[0]),
            'full_prediction': ' '.join(ranking),
            'input_tokens': float(input_tokens[q_pos].mean()),
            'output_tokens': n_options * config['max_new_tokens_stl'],  # one prompt per option
            'latency_seconds': float(latencies[q_pos].sum()),
            'yes_logits': option_logits.tolist(),
            'no_logits': no_logits[q_pos].tolist(),
            'probabilities': probs_output[q_pos].tolist()
        })
    
    n_correct = sum(r['predicted_answer'] == r['true_answer'] for r in question_results)
    accuracy = n_correct / len(question_results) if question_results else 0.0
    
    print(f"\nSTL Results:")
    print(f"Accuracy: {accuracy * 100:.2f}%")
    if question_results:
        print(f"Average latency: {np.mean([r['latency_seconds'] for r in question_results]):.3f}s")
        print(f"Average input tokens: {np.mean([r['input_tokens'] for r in question_results]):.1f}")
    
    return question_results, accuracy

# Run the single-token-logits (per-option yes/no) approach
stl_results, stl_accuracy = single_token_logits_approach(
    model=model,
    tokenizer=tokenizer,
    df=df,
    yes_token_id=yes_token_id,
    no_token_id=no_token_id,
    config=CONFIG,
)

In [None]:
def extract_answer_from_cot(text, default="A"):
    """Extract the chosen option letter (A-D) from a Chain-of-Thought response.

    Explicit answer statements are tried in priority order:

    1. "Answer: B", "the correct answer is (C)", "**Answer:** [D]"
    2. "Therefore, B"
    3. "C is correct" / "(C) is the correct ..."
    4. a bare letter at the very end of the text

    Within the first pattern that matches, the LAST occurrence wins, because models often
    revise or restate their choice while reasoning. Option letters must be uppercase and form
    a whole word, so ordinary words such as "Because" or "during" are never read as answers.

    If no pattern matches, the most frequent standalone uppercase A-D in the text is
    returned, and failing that ``default``.

    Args:
        text: generated reasoning text.
        default: value returned when nothing can be extracted (``"A"`` by default).

    Returns:
        One of ``"A"``, ``"B"``, ``"C"``, ``"D"``, or ``default``.
    """
    if not text:
        return default

    # Option letter, optionally wrapped as "[B]", "(B)" or "**B**". Case-sensitive on
    # purpose; only the surrounding keywords are matched case-insensitively via (?i:...).
    option = r"[\[\(\*]*([ABCD])\b"
    patterns = [
        r"(?i:answer)[\s\*]*(?:(?i:is)\b\s*)?:?[\s\*]*" + option,
        r"(?i:therefore),?\s*" + option,
        r"\b([ABCD])\b[\)\]\*]*\s*(?i:is)\s+(?:(?i:the)\s+)?(?i:correct)",
        r"\b([ABCD])[\)\]\.\*]*\s*$",
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text)
        if matches:
            return matches[-1]

    # Fallback: the most frequent standalone option letter
    from collections import Counter
    letters = re.findall(r'\b([ABCD])\b', text)
    if letters:
        return Counter(letters).most_common(1)[0][0]

    return default

def chain_of_thought_approach(model, tokenizer, df, config):
    """Chain-of-Thought approach: generate free-form reasoning, then parse the chosen letter.

    For every row, ``config['cot_prompt']`` is filled with ``row['instruction']``. Up to
    ``config['max_new_tokens_cot']`` tokens are then generated, and the answer is extracted
    from the decoded continuation with ``extract_answer_from_cot``.

    Sampling settings come from optional config keys:

    - ``'cot_do_sample'`` (default True)
    - ``'cot_temperature'`` (default 0.7)
    - ``'cot_top_p'`` (default 0.9)

    If the config has a ``'seed'``, ``torch.manual_seed`` is called once before the loop so
    sampled runs can be repeated.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``.
        df: DataFrame with ``instruction``, ``question`` and ``answer`` columns.
        config: configuration dict.

    Returns:
        ``(results, accuracy)``: one result dict per row, and the fraction of rows whose
        extracted letter equals ``row['answer']`` (0.0 for an empty ``df``).
    """
    print("\n" + "="*50)
    print("CHAIN-OF-THOUGHT APPROACH")
    print("="*50)
    
    do_sample = config.get('cot_do_sample', True)
    generation_kwargs = {
        'max_new_tokens': config['max_new_tokens_cot'],
        'do_sample': do_sample,
        'pad_token_id': tokenizer.eos_token_id,
    }
    if do_sample:
        generation_kwargs['temperature'] = config.get('cot_temperature', 0.7)
        generation_kwargs['top_p'] = config.get('cot_top_p', 0.9)

    seed = config.get('seed')
    if seed is not None:
        torch.manual_seed(seed)
    
    results = []
    
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing CoT"):
        start_time = time.time()
        
        full_prompt = config['cot_prompt'].format(query=row['instruction'])
        # Follow the model's actual device instead of hard-coding "cuda:<index>"
        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
        prompt_length = int(inputs["input_ids"].shape[1])
        
        with torch.no_grad():
            output = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generation_kwargs,
            )
        
        # generate() returns prompt + continuation; keep only the continuation
        generated_ids = output[0][prompt_length:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        
        predicted_answer = extract_answer_from_cot(generated_text)
        
        latency = time.time() - start_time
        question = row['question']
        
        results.append({
            'question_id': idx,
            'approach': 'cot',
            'question': question[:100] + '...' if len(question) > 100 else question,
            'true_answer': row['answer'],
            'predicted_answer': predicted_answer,
            'full_prediction': generated_text[:500] + '...' if len(generated_text) > 500 else generated_text,
            'reasoning': generated_text,
            'input_tokens': prompt_length,
            'output_tokens': len(generated_ids),
            'latency_seconds': latency,
            'logits': [],  # Not applicable for CoT
            'probabilities': []  # Not applicable for CoT
        })
    
    n_correct = sum(r['predicted_answer'] == r['true_answer'] for r in results)
    accuracy = n_correct / len(results) if results else 0.0
    
    print(f"\nCoT Results:")
    print(f"Accuracy: {accuracy * 100:.2f}%")
    if results:
        print(f"Average latency: {np.mean([r['latency_seconds'] for r in results]):.3f}s")
        print(f"Average input tokens: {np.mean([r['input_tokens'] for r in results]):.1f}")
        print(f"Average output tokens: {np.mean([r['output_tokens'] for r in results]):.1f}")
    
    return results, accuracy

# Run the chain-of-thought (reason, then answer) approach
cot_results, cot_accuracy = chain_of_thought_approach(
    model=model,
    tokenizer=tokenizer,
    df=df,
    config=CONFIG,
)

In [None]:
def save_results(baseline_results, stl_results, cot_results, config,
                 baseline_accuracy=None, stl_accuracy=None, cot_accuracy=None):
    """Write all per-question results to a single CSV and print a summary.

    Args:
        baseline_results, stl_results, cot_results: lists of per-question result dicts, each
            with ``predicted_answer`` and ``true_answer`` keys.
        config: configuration dict (reads ``save_results``, ``model_name``,
            ``dataset_name`` and ``results_dir``).
        baseline_accuracy, stl_accuracy, cot_accuracy: optional accuracies for the summary.
            When omitted, each is computed from its results as the fraction of entries with
            ``predicted_answer == true_answer``. This avoids depending on notebook globals.

    Returns:
        The combined results DataFrame, or None when ``config['save_results']`` is falsy.
    """
    if not config['save_results']:
        return None

    def _accuracy(results):
        # Same definition the approaches use: the top prediction equals the gold letter
        if not results:
            return 0.0
        return sum(r['predicted_answer'] == r['true_answer'] for r in results) / len(results)

    if baseline_accuracy is None:
        baseline_accuracy = _accuracy(baseline_results)
    if stl_accuracy is None:
        stl_accuracy = _accuracy(stl_results)
    if cot_accuracy is None:
        cot_accuracy = _accuracy(cot_results)
    
    results_df = pd.DataFrame(baseline_results + stl_results + cot_results)
    
    # Filename encodes model, dataset and timestamp
    model_name = config['model_name'].replace('/', '_')
    dataset_name = config['dataset_name']
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    
    os.makedirs(config['results_dir'], exist_ok=True)
    filename = os.path.join(
        config['results_dir'], f"results_{model_name}_{dataset_name}_{timestamp}.csv"
    )
    results_df.to_csv(filename, index=False)
    
    print(f"\nResults saved to: {filename}")
    
    print(f"\nSummary:")
    print(f"Dataset: {dataset_name}")
    print(f"Model: {config['model_name']}")
    print(f"Total questions: {len(baseline_results)}")
    print(f"Baseline accuracy: {baseline_accuracy * 100:.2f}%")
    print(f"STL accuracy: {stl_accuracy * 100:.2f}%")
    print(f"CoT accuracy: {cot_accuracy * 100:.2f}%")
    
    return results_df

# Persist every approach's per-question results to one CSV
results_df = save_results(
    baseline_results,
    stl_results,
    cot_results,
    config=CONFIG,
)

In [None]:
# Display sample results: the first question's outcome under each approach
if baseline_results:
    print("\n" + "=" * 60)
    print("SAMPLE RESULTS")
    print("=" * 60)
    
    sample_idx = 0
    baseline_sample, stl_sample, cot_sample = (
        approach_results[sample_idx]
        for approach_results in (baseline_results, stl_results, cot_results)
    )
    
    print(f"\nQuestion: {baseline_sample['question']}")
    print(f"True Answer: {baseline_sample['true_answer']}")
    
    # Baseline: direct letter prediction
    baseline_probs = [f'{p:.3f}' for p in baseline_sample['probabilities']]
    print("\nBaseline:")
    print(f"  Prediction: {baseline_sample['predicted_answer']}")
    print(f"  Latency: {baseline_sample['latency_seconds']:.3f}s")
    print(f"  Input tokens: {baseline_sample['input_tokens']}")
    print(f"  Probabilities: {baseline_probs}")
    
    # STL: input tokens are a per-option average, hence the float format
    stl_probs = [f'{p:.3f}' for p in stl_sample['probabilities']]
    print("\nSTL:")
    print(f"  Prediction: {stl_sample['predicted_answer']}")
    print(f"  Latency: {stl_sample['latency_seconds']:.3f}s")
    print(f"  Input tokens: {stl_sample['input_tokens']:.1f}")
    print(f"  Probabilities: {stl_probs}")
    
    # CoT: show the start of the generated reasoning
    reasoning_preview = cot_sample['reasoning'][:200]
    print("\nCoT:")
    print(f"  Prediction: {cot_sample['predicted_answer']}")
    print(f"  Latency: {cot_sample['latency_seconds']:.3f}s")
    print(f"  Input/Output tokens: {cot_sample['input_tokens']}/{cot_sample['output_tokens']}")
    print(f"  Reasoning: {reasoning_preview}...")

In [None]:
# Performance comparison: accuracy, mean latency and mean token usage per approach
print("\n" + "=" * 60)
print("PERFORMANCE COMPARISON")
print("=" * 60)

runs = {
    'Baseline': (baseline_accuracy, baseline_results),
    'STL': (stl_accuracy, stl_results),
    'CoT': (cot_accuracy, cot_results),
}
approaches = list(runs)
accuracies = [acc * 100 for acc, _ in runs.values()]
avg_latencies = [np.mean([r['latency_seconds'] for r in res]) for _, res in runs.values()]
avg_input_tokens = [np.mean([r['input_tokens'] for r in res]) for _, res in runs.values()]
avg_output_tokens = [np.mean([r['output_tokens'] for r in res]) for _, res in runs.values()]

print(f"{'Approach':<12} {'Accuracy':<10} {'Latency (s)':<12} {'Input Tokens':<13} {'Output Tokens':<13}")
print("-" * 70)
for approach, acc, latency, n_in, n_out in zip(
    approaches, accuracies, avg_latencies, avg_input_tokens, avg_output_tokens
):
    print(f"{approach:<12} {acc:<10.2f} {latency:<12.3f} {n_in:<13.1f} {n_out:<13.1f}")

In [None]:
# Cleanup: drop the model/tokenizer references, collect garbage and empty the CUDA cache
print("\nCleaning up...")
del model, tokenizer
gc.collect()
torch.cuda.empty_cache()
print("Done!")