# Knowledge Editing Benchmark Evaluation

This notebook runs a benchmark evaluation of knowledge-editing algorithms on the CounterFact dataset.

## Algorithms Supported
- ICL (In-Context Learning)
- ROME (WIP)
- MEMIT (WIP)
- Fine-tuning (WIP)


## 1. Setup and Imports


In [1]:
import os
import sys
import json
import time
import random
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# Make the evaluation modules importable by putting the `frontend` directory on sys.path.
# Preferred location: a `frontend` directory next to the current working directory.
frontend_dir = os.path.join(os.path.dirname(os.getcwd()), 'frontend')
# If that directory does not exist (e.g. running from the root directory),
# fall back to the relative 'frontend' entry.
sys.path.insert(0, frontend_dir if os.path.exists(frontend_dir) else 'frontend')

from eval import CounterfactEvaluator, generate_text
from algorithms.icl import ICLAlgorithm




## 2. Configuration


In [2]:
# Model Configuration
MODEL_NAME = "google/gemma-2-2b"  # Model id passed to from_pretrained; change to your desired model
CUDA_VISIBLE_DEVICES = "0"  # GPU device(s) to use; written to os.environ in the model-loading cell

# Generation Parameters (forwarded to evaluator.evaluate)
MAX_NEW_TOKENS = 15
# NOTE(review): the recorded run logs "generation flags are not valid and may be ignored:
# ['temperature', 'top_p']" — confirm that sampling is actually enabled before relying on this value.
TEMPERATURE = 0.7
BATCH_SIZE = 16

# Dataset Configuration
# NOTE(review): in the cells visible here DATASET_NAME is only recorded in the saved config and
# in the output filenames; the evaluator instantiated is always CounterfactEvaluator.
DATASET_NAME = "counterfact"  # Options: counterfact, mquake-cf, wikiupdate
USE_SUBSAMPLING = True  # Use subset of data for faster evaluation
SUBSAMPLING_SIZE = 10  # Number of samples to evaluate (ignored unless smaller than the dataset)
SUBSAMPLING_SEED = 42  # Random seed for subsampling

# Algorithm
# Only "icl" is accepted by the initialization cell; any other value raises ValueError there.
ALGORITHM = "icl"  # Options: icl, rome, memit, ft


## 3. Load Model


In [3]:
# Restrict which GPUs this process may see.
# NOTE(review): torch has already been imported by the time this runs; presumably this still
# takes effect because CUDA has not been initialised yet — confirm, or export the variable
# before starting the kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES

print(f"Loading model: {MODEL_NAME}")
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Only use half precision when a GPU is available; on a CPU-only machine fall back to
# float32 instead of forcing float16 onto the CPU.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# Load model and tokenizer.
# NOTE: the recorded run shows newer transformers releases warning that `torch_dtype` is
# deprecated in favour of `dtype`; the older keyword is kept so existing installs keep working.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=model_dtype,
    device_map="auto",
    trust_remote_code=True
)

# Report where the weights actually ended up (device_map="auto" decides the placement),
# rather than the device that was merely detected above.
print(f"Model loaded successfully on device: {model.device}")


Loading model: google/gemma-2-2b
CUDA_VISIBLE_DEVICES: 0
Using device: cuda


`torch_dtype` is deprecated! Use `dtype` instead!
  warn(
2025-12-13 10:06:36.506266: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-13 10:06:36.520943: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-12-13 10:06:36.539233: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-12-13 10:06:36.544690: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-13 10

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Model loaded successfully on device: cuda


## 4. Initialize Evaluator and Algorithm


In [4]:
# Initialize evaluator
evaluator = CounterfactEvaluator()

# Materialise the evaluator's dataset as a list of cases:
#   - dict            -> its values
#   - other iterable  -> its items (a list is copied)
#   - anything else   -> assumed to expose .values()
raw_data = evaluator.data
if isinstance(raw_data, dict):
    data_list = list(raw_data.values())
elif hasattr(raw_data, '__iter__'):
    data_list = list(raw_data)
else:
    data_list = list(raw_data.values())

# Remember the full size now: the raw container may not support len() (e.g. a generator),
# and data_list is about to be replaced by the subsample.
total_cases = len(data_list)
print(f"Total cases in dataset: {total_cases}")

# Apply subsampling if requested (only when the requested size is smaller than the dataset)
if USE_SUBSAMPLING and SUBSAMPLING_SIZE < total_cases:
    random.seed(SUBSAMPLING_SEED)  # seeds the global RNG so the selection is reproducible
    data_list = random.sample(data_list, SUBSAMPLING_SIZE)
    print(f"Subsampling: Selected {len(data_list)} cases from {total_cases} total cases (seed={SUBSAMPLING_SEED})")

# Temporarily replace evaluator data with subsampled data;
# `original_data` is kept so the evaluation cell can restore it afterwards.
original_data = evaluator.data
evaluator.data = data_list

# Initialize algorithm
if ALGORITHM == "icl":
    algorithm = ICLAlgorithm(model=model, tokenizer=tokenizer)
    print(f"Initialized {ALGORITHM.upper()} algorithm")
else:
    raise ValueError(f"Algorithm {ALGORITHM} not yet implemented. Use 'icl' for now.")


Total cases in dataset: 19728
Subsampling: Selected 10 cases from 19728 total cases (seed=42)
Initialized ICL algorithm


## 5. Run Evaluation


In [None]:
# Progress tracking callback
def progress_callback(case_idx, case_result, total):
    """Report progress after every tenth case and after the last one.

    Args:
        case_idx: Zero-based index of the case that just finished.
        case_result: Mapping for that case; its 'subject' entry is shown ('N/A' if absent).
        total: Total number of cases being evaluated.
    """
    completed = case_idx + 1
    # Stay quiet unless this is a multiple of ten or the final case.
    if completed % 10 != 0 and completed != total:
        return
    percent_done = (completed / total * 100)
    current_subject = case_result.get('subject', 'N/A')
    print(f"Progress: {completed}/{total} ({percent_done:.1f}%) - Current: {current_subject}")

print(f"Starting evaluation with {len(data_list)} cases...")
print(f"Parameters: max_new_tokens={MAX_NEW_TOKENS}, temperature={TEMPERATURE}, batch_size={BATCH_SIZE}")
print("-" * 60)

# Point the evaluator at the (possibly subsampled) cases every time this cell runs.
# The cell restores the full dataset when it finishes, so without this a re-run would
# silently evaluate every case while reporting the subsample size above.
evaluator.data = data_list

start_time = time.time()

try:
    # Run evaluation
    results = evaluator.evaluate(
        algorithm=algorithm,
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        batch_size=BATCH_SIZE,
        progress_callback=progress_callback
    )
    elapsed_time = time.time() - start_time
finally:
    # Restore original data, even when the evaluation raises
    evaluator.data = original_data

print("-" * 60)
print(f"Evaluation completed in {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")
print(f"Total cases evaluated: {len(results)}")


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Starting evaluation with 10 cases...
Parameters: max_new_tokens=15, temperature=0.7, batch_size=16
------------------------------------------------------------


## 6. Calculate Metrics


In [None]:
def calculate_metrics(results, score_prefix="icl"):
    """Calculate aggregate metrics from evaluation results.

    Scores are pooled across all cases (not averaged per case first) and then averaged:
      - reliability: ``<prefix>_scores_paraphrase_prompts``
      - portability: ``<prefix>_scores_generation_prompts``
      - locality:    ``<prefix>_scores_neighborhood_prompts`` + ``<prefix>_scores_attribute_prompts``

    Args:
        results: Mapping of case id -> case result. Each case result may carry an 'eval'
            mapping holding the score lists above. A missing or ``None`` 'eval' entry, or
            a missing or ``None`` score list, contributes no scores.
        score_prefix: Prefix of the score keys inside 'eval'. Defaults to "icl", the only
            prefix used in this notebook.
            NOTE(review): other algorithms are assumed to follow the same
            ``<prefix>_scores_<prompt type>`` naming — confirm before relying on it.

    Returns:
        dict with float 'reliability', 'portability', 'locality' (0.0 when no scores were
        collected for that metric) and int 'total_cases' (``len(results)``).
    """
    def scores_for(eval_data, prompt_type):
        # `or []` also covers a key that is present but set to None.
        return eval_data.get(f'{score_prefix}_scores_{prompt_type}') or []

    def mean_or_zero(scores):
        return sum(scores) / len(scores) if scores else 0.0

    all_reliability_scores = []
    all_portability_scores = []
    all_locality_scores = []

    for case_result in results.values():
        # Tolerate a case without evaluation data instead of raising on None.
        eval_data = (case_result or {}).get('eval') or {}

        all_reliability_scores.extend(scores_for(eval_data, 'paraphrase_prompts'))
        all_portability_scores.extend(scores_for(eval_data, 'generation_prompts'))
        all_locality_scores.extend(scores_for(eval_data, 'neighborhood_prompts'))
        all_locality_scores.extend(scores_for(eval_data, 'attribute_prompts'))

    return {
        'reliability': mean_or_zero(all_reliability_scores),
        'portability': mean_or_zero(all_portability_scores),
        'locality': mean_or_zero(all_locality_scores),
        'total_cases': len(results)
    }

# Aggregate the per-case scores once; `metrics` is reused by the plotting and saving cells.
metrics = calculate_metrics(results)

# Headline numbers, with the three rates shown as percentages.
summary_lines = [
    "Overall Metrics:",
    f"  Reliability: {metrics['reliability']*100:.2f}%",
    f"  Portability: {metrics['portability']*100:.2f}%",
    f"  Locality: {metrics['locality']*100:.2f}%",
    f"  Total Cases: {metrics['total_cases']}",
]
print("\n".join(summary_lines))


## 7. Detailed Results Analysis


In [None]:
import pandas as pd
import numpy as np

# Convert results to a DataFrame (one row per case) for easier analysis.
# Maps each score column label to the key of its score list inside a case's 'eval' mapping.
score_columns = {
    'PR (Paraphrase Reliability)': 'icl_scores_paraphrase_prompts',
    'GR (Generation Portability)': 'icl_scores_generation_prompts',
    'NR (Neighborhood Locality)': 'icl_scores_neighborhood_prompts',
    'AT (Attribute Locality)': 'icl_scores_attribute_prompts',
}


def _mean_or_zero(scores):
    """Return the mean of `scores`, or 0.0 when there are none."""
    return np.mean(scores) if scores else 0.0


case_results = []
for case_id, case_result in results.items():
    eval_data = case_result.get('eval', {})

    row = {
        'Case ID': case_id,
        'Subject': case_result.get('subject', ''),
        'Target New': case_result.get('target_new', ''),
        'Target Old': case_result.get('target_old', ''),
    }
    # Average score per metric for this case
    for column_label, score_key in score_columns.items():
        row[column_label] = _mean_or_zero(eval_data.get(score_key, []))
    case_results.append(row)

df = pd.DataFrame(case_results)

# Display summary statistics for the four score columns
print("Score Statistics:")
print(df[list(score_columns)].describe())

# Display first few results
print("\nFirst 10 Cases:")
display(df.head(10))


## 8. Visualization


In [None]:
import matplotlib.pyplot as plt

# Plot score distributions: one histogram per per-case score column, in a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('Score Distributions', fontsize=16)

metrics_list = ['PR (Paraphrase Reliability)', 'GR (Generation Portability)', 
                'NR (Neighborhood Locality)', 'AT (Attribute Locality)']

# axes.flat walks the grid row by row, matching the order of metrics_list.
for ax, metric in zip(axes.flat, metrics_list):
    metric_mean = df[metric].mean()
    # Pin the bins to the displayed [0, 1] window. With automatic binning, a column whose
    # values are all identical (e.g. every case scoring 1.0) gets bins centred on that
    # value, which can fall outside the x-limits set below and leave the panel empty.
    # NOTE(review): assumes scores lie in [0, 1]; values outside were already hidden by xlim.
    ax.hist(df[metric], bins=20, range=(0, 1), alpha=0.7, edgecolor='black')
    ax.set_title(metric)
    ax.set_xlabel('Score')
    ax.set_ylabel('Frequency')
    ax.set_xlim(0, 1)
    ax.axvline(metric_mean, color='red', linestyle='--', label=f'Mean: {metric_mean:.3f}')
    ax.legend()

plt.tight_layout()
plt.show()

# Plot overall metrics as a bar chart, in percent
fig, ax = plt.subplots(figsize=(8, 6))
metric_names = ['Reliability', 'Portability', 'Locality']
metric_values = [metrics['reliability'], metrics['portability'], metrics['locality']]
percentages = [value * 100 for value in metric_values]
bars = ax.bar(metric_names, percentages, color=['#4CAF50', '#2196F3', '#FF9800'])
ax.set(ylabel='Score (%)', title='Overall Metrics', ylim=(0, 100))

# Add value labels on bars, centred horizontally and sitting on top of each bar
for bar, value in zip(bars, metric_values):
    label_x = bar.get_x() + bar.get_width()/2.
    ax.text(label_x, bar.get_height(), f'{value*100:.2f}%', ha='center', va='bottom')

plt.tight_layout()
plt.show()


## 9. Save Results


In [None]:
# Save results to JSON
from datetime import datetime

# Subsampling details are only recorded when subsampling was enabled
if USE_SUBSAMPLING:
    recorded_size, recorded_seed = SUBSAMPLING_SIZE, SUBSAMPLING_SEED
else:
    recorded_size = recorded_seed = None

# Everything needed to interpret the run: settings, aggregate metrics, raw results, timing
run_config = {
    'model_name': MODEL_NAME,
    'algorithm': ALGORITHM,
    'max_new_tokens': MAX_NEW_TOKENS,
    'temperature': TEMPERATURE,
    'batch_size': BATCH_SIZE,
    'dataset': DATASET_NAME,
    'use_subsampling': USE_SUBSAMPLING,
    'subsampling_size': recorded_size,
    'subsampling_seed': recorded_seed,
}
output_data = {
    'config': run_config,
    'metrics': metrics,
    'results': results,
    'elapsed_time': elapsed_time
}

# Both output files share one timestamp so they can be paired up later
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = f"benchmark_results_{DATASET_NAME}_{ALGORITHM}_{timestamp}.json"
csv_filename = f"benchmark_results_{DATASET_NAME}_{ALGORITHM}_{timestamp}.csv"

# Full results as JSON (written to the current working directory)
with open(output_filename, 'w', encoding='utf-8') as json_file:
    json.dump(output_data, json_file, indent=2, ensure_ascii=False)

print(f"Results saved to: {output_filename}")

# Also save the per-case DataFrame as CSV
df.to_csv(csv_filename, index=False)
print(f"CSV saved to: {csv_filename}")


## 10. Examine Individual Cases


In [None]:
# View detailed results for a specific case
case_id_to_examine = list(results)[0]  # Change to examine different case
case_result = results[case_id_to_examine]

separator = "=" * 60

# Header: identifying fields of the case ('N/A' when a field is missing)
print(f"Case ID: {case_id_to_examine}")
print(f"Subject: {case_result.get('subject', 'N/A')}")
print(f"Target Old: {case_result.get('target_old', 'N/A')}")
print(f"Target New: {case_result.get('target_new', 'N/A')}")
print(f"Prompt: {case_result.get('prompt', 'N/A')}")
print("\n" + separator)
print("Evaluation Results:")
print(separator)

eval_data = case_result.get('eval', {})

# Display generated texts and scores, one section per prompt type that has generations
prompt_types = ('paraphrase_prompts', 'generation_prompts', 'neighborhood_prompts', 'attribute_prompts')
for test_type in prompt_types:
    generated_texts = eval_data.get(f'icl_{test_type}', [])
    if not generated_texts:
        continue
    scores = eval_data.get(f'icl_scores_{test_type}', [])

    print(f"\n{test_type.replace('_', ' ').title()}:")
    # Texts are paired with scores positionally; only the first 100 characters are shown.
    for position, (text, score) in enumerate(zip(generated_texts, scores), start=1):
        print(f"  [{position}] Score: {score}, Generated: {text[:100]}...")

# Pretty print full JSON for this case
print("\n" + separator)
print("Full JSON:")
print(separator)
print(json.dumps(case_result, indent=2, ensure_ascii=False))
