In [1]:
import torch
from transformer_lens import HookedTransformer 
from sae_lens import SAE, HookedSAETransformer
from datasets import load_dataset, concatenate_datasets
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
from torch.nn import functional as F
from functools import partial
import re

class FeatureSelector:
    def __init__(
        self,
        device: str = "cuda"
    ):
        self.device = device
        
        # Load model as HookedTransformer
        self.model = HookedSAETransformer.from_pretrained(
            "google/gemma-2-2b-it",
            device=device,
            fold_ln=False,
            center_writing_weights=False,
            center_unembed=False,
        )
        
        # Load SAE
        self.sae, cfg_dict, sparsity = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id="layer_20/width_16k/canonical",
            device=device,
        )
        self.sae.use_error_term = True

        self.threshold=0.01

    def calculate_sparsity(self, dataset, batch_size=1, n_batches=100):
        """Calculate feature sparsities on a dataset"""
        total_activations = 0
        feature_activations = torch.zeros(self.sae.cfg.d_sae, device=self.device)
        
        for i in tqdm(range(n_batches)):
            batch = dataset.select(range(i*batch_size, (i+1)*batch_size))
            
            # For WMDP, format as multiple choice question
            if 'choices' in batch.features:
                prompts = []
                for item in batch:
                    prompt = f"{item['question']}\nA: {item['choices'][0]}\nB: {item['choices'][1]}\nC: {item['choices'][2]}\nD: {item['choices'][3]}\nAnswer:"
                    prompts.append(prompt)
            else:
                prompts = batch['text']
            
            tokens = self.model.to_tokens(prompts)
            
            _, cache = self.model.run_with_cache_with_saes(
                tokens,
                saes=[self.sae]
            )
            
            # Get feature activations
            acts = cache[f"blocks.{self.sae.cfg.hook_layer}.hook_resid_post.hook_sae_acts_post"]
            
            # Count positive activations
            feature_activations += (acts > self.threshold).float().sum(dim=(0,1))
            total_activations += acts.shape[0] * acts.shape[1]
            
        return feature_activations / total_activations
    
    def select_features(
        self, 
        wmdp_dataset,
        wikitext_dataset,
        retain_sparsity_threshold=0.01,
        n_features=20
    ):
        """Select features that are active on WMDP but not on WikiText"""
        print("Calculating sparsities on WMDP-bio...")
        wmdp_sparsities = self.calculate_sparsity(wmdp_dataset)
        
        print("Calculating sparsities on WikiText...")
        wikitext_sparsities = self.calculate_sparsity(wikitext_dataset)
        
        # Find features below threshold on WikiText
        retain_mask = wikitext_sparsities < retain_sparsity_threshold
        
        # Sort by activation on WMDP
        wmdp_sparsities[~retain_mask] = -1
        _, top_features = torch.topk(wmdp_sparsities, n_features)
        
        return top_features.cpu().numpy()

    def evaluate_ablation(
            self,
            feature_indices,
            mmlu_dataset,
            openwebtext_dataset,
            clamp_value=0.0,
            owt_token_limit=50000,
            batch_size=1,
        ):
        """Evaluate ablation effects on MMLU and increased loss on OpenWebText using feature ablation hook"""

        # MMLU evaluation
        # Add SAE and ablation hook
        def ablate_feature_hook(activations, hook):
            # activations shape: [batch_size, seq_len, num_features]
            activations[:, :, feature_indices] = clamp_value
            return activations

        hook_point = self.sae.cfg.hook_name + ".hook_sae_acts_post"
        self.model.add_sae(self.sae)
        self.model.add_hook(hook_point, ablate_feature_hook, dir='fwd')

        mmlu_correct = 0
        mmlu_total = 0

        for item in tqdm(mmlu_dataset, desc="Evaluating MMLU"):
            prompt = f"{item['question']}\nA) {item['choices'][0]}\nB) {item['choices'][1]}\nC) {item['choices'][2]}\nD) {item['choices'][3]}\nAnswer:"

            # Convert prompt to tokens to get prompt length
            prompt_tokens = self.model.to_tokens(prompt, prepend_bos=False)
            prompt_length = prompt_tokens.shape[1]

            # Generate model output with modified SAE

            generated_tokens = self.model.generate(
                prompt_tokens,
                max_new_tokens=5,
                temperature=0.0,
                do_sample=False,
                prepend_bos=False,
            )

            # Slice the generated tokens correctly
            new_tokens = generated_tokens[0][prompt_length:]

            # Decode the generated tokens into text
            generated_text = self.model.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True
            )

            # Extract the predicted answer using regex
            match = re.search(r'\b([ABCD])\b', generated_text)
            if match:
                pred_letter = match.group(1)
            else:
                # Handle cases where no match is found
                pred_letter = None

            answer_key = {0:'A', 1:'B', 2: 'C', 3:'D'}
            correct_answer = answer_key[item['answer']]
            if pred_letter == correct_answer:
                mmlu_correct += 1
            mmlu_total += 1

        # Reset hooks and SAEs before OpenWebText processing
        self.model.reset_hooks()
        self.model.reset_saes()



        # Process OpenWebText with a fixed subset
        total_loss_original = 0.0
        total_loss_ablation = 0.0
        
        # Pre-select a small subset that's likely to contain enough tokens
        # Assuming average tokens per text is ~200, select ~300 samples to be safe
        n_samples = 300  # This should comfortably give us 50k tokens
        subset = openwebtext_dataset.select(range(n_samples))
        
        print("Processing OpenWebText...")
        accumulated_tokens = 0
        
        for i in tqdm(range(0, len(subset), batch_size)):
            batch = subset.select(range(i, min(i + batch_size, len(subset))))
            tokens = self.model.to_tokens(batch['text'], prepend_bos=False)
            
            # Stop if we've reached the token limit
            if accumulated_tokens >= owt_token_limit:
                break
                
            # Only process up to the token limit
            if accumulated_tokens + tokens.shape[1] > owt_token_limit:
                tokens = tokens[:, :(owt_token_limit - accumulated_tokens)]
            
            # Compute original loss without ablation
            with torch.no_grad():
                logits = self.model(tokens, return_type='logits')
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = tokens[..., 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    reduction='sum'
                )
                total_loss_original += loss.item()
    
            # Add SAE and ablation hook
            self.model.add_sae(self.sae)
            self.model.add_hook(hook_point, ablate_feature_hook, dir='fwd')
    
            # Compute loss with ablation
            with torch.no_grad():
                logits = self.model(tokens, return_type='logits')
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = tokens[..., 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    reduction='sum'
                )
                total_loss_ablation += loss.item()
    
            # Reset hooks and SAEs for next batch
            self.model.reset_hooks()
            self.model.reset_saes()
            
            accumulated_tokens += tokens.shape[1]
    
        # Compute increased loss per token
        increased_loss = (total_loss_ablation - total_loss_original) / accumulated_tokens
    
        return {
            'mmlu_accuracy': mmlu_correct / mmlu_total,
            'increased_loss_openwebtext': increased_loss
        }

        # Now process OpenWebText

        # total_loss_original = 0.0
        # total_loss_ablation = 0.0
        # total_tokens = 0

        # print("Processing OpenWebText...")
        # for i in tqdm(range(len(openwebtext_dataset)), desc="Evaluating OpenWebText"):
        #     # Break if we have processed enough tokens
        #     if total_tokens >= owt_token_limit:
        #         break

        #     # Get a batch of data
        #     batch = openwebtext_dataset.select(range(i*batch_size, (i+1)*batch_size))
        #     texts = batch['text']
        #     tokens = self.model.to_tokens(texts, prepend_bos=False)
        #     num_tokens = tokens.shape[1]
        #     if total_tokens + num_tokens > owt_token_limit:
        #         # Trim tokens to reach exactly owt_token_limit tokens
        #         tokens = tokens[:, :owt_token_limit - total_tokens]
        #         num_tokens = tokens.shape[1]

        #     # Compute original loss without ablation
        #     with torch.no_grad():
        #         logits = self.model(tokens, return_type='logits')
        #         shift_logits = logits[..., :-1, :].contiguous()
        #         shift_labels = tokens[..., 1:].contiguous()
        #         loss = F.cross_entropy(
        #             shift_logits.view(-1, shift_logits.size(-1)),
        #             shift_labels.view(-1),
        #             reduction='sum'
        #         )
        #         total_loss_original += loss.item()

        #     # Add SAE and ablation hook
        #     self.model.add_sae(self.sae)
        #     self.model.add_hook(hook_point, ablate_feature_hook, dir='fwd')

        #     # Compute loss with ablation
        #     with torch.no_grad():
        #         logits = self.model(tokens, return_type='logits')
        #         shift_logits = logits[..., :-1, :].contiguous()
        #         shift_labels = tokens[..., 1:].contiguous()
        #         loss = F.cross_entropy(
        #             shift_logits.view(-1, shift_logits.size(-1)),
        #             shift_labels.view(-1),
        #             reduction='sum'
        #         )
        #         total_loss_ablation += loss.item()

        #     # Reset hooks and SAEs for next batch
        #     self.model.reset_hooks()
        #     self.model.reset_saes()

        #     total_tokens += num_tokens

        # # Compute increased loss per token
        # increased_loss = (total_loss_ablation - total_loss_original) / total_tokens

        # return {
        #     'mmlu_accuracy': mmlu_correct / mmlu_total,
        #     'increased_loss_openwebtext': increased_loss
        # }
    def evaluate_wmdp(
        self,
        feature_indices,
        wmdp_dataset,
        clamp_value=0.0
        ):
        """Evaluate ablation effects on WMDP-bio using feature ablation hook"""

        # Define the ablation hook
        def ablate_feature_hook(activations, hook):
            # activations shape: [batch_size, seq_len, num_features]
            activations[:, :, feature_indices] = clamp_value
            return activations

        # The hook point where SAE feature activations are available
        hook_point = self.sae.cfg.hook_name + ".hook_sae_acts_post"

        # Add the SAE to the model
        self.model.add_sae(self.sae)
        # Add the ablation hook
        self.model.add_hook(hook_point, ablate_feature_hook, dir='fwd')

        # WMDP evaluation
        wmdp_correct = 0
        wmdp_total = 0

        for item in tqdm(wmdp_dataset, desc="Evaluating WMDP-bio"):
            prompt = f"{item['question']}\nA: {item['choices'][0]}\nB: {item['choices'][1]}\nC: {item['choices'][2]}\nD: {item['choices'][3]}\nAnswer:"

            # Convert prompt to tokens to get prompt length
            prompt_tokens = self.model.to_tokens(prompt, prepend_bos=False)
            prompt_length = prompt_tokens.shape[1]

            # Generate model output with modified SAE
            with nullcontext():
                generated_tokens = self.model.generate(
                    prompt_tokens,
                    max_new_tokens=5,
                    temperature=0.0,
                    do_sample=False,
                    prepend_bos=False,
                )

            # Slice the generated tokens correctly
            new_tokens = generated_tokens[0][prompt_length:]

            # Decode the generated tokens into text
            generated_text = self.model.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True
            )

            # Extract the predicted answer using regex
            match = re.search(r'\b([ABCD])\b', generated_text)
            if match:
                pred_letter = match.group(1)
            else:
                # Handle cases where no match is found
                pred_letter = None

            answer_key = {0:'A', 1:'B', 2: 'C', 3:'D'}
            if pred_letter == answer_key[item['answer']]:
                wmdp_correct += 1
            wmdp_total += 1

        # Reset hooks and SAEs
        self.model.reset_hooks()
        self.model.reset_saes()

        return {
            'wmdp_accuracy': wmdp_correct / wmdp_total,
        }

     # New function to evaluate MMLU without SAE
    def evaluate_mmlu_no_sae(self, mmlu_dataset):
        """Evaluate the model on MMLU without using an SAE"""
        # Ensure the model has no SAEs or hooks
        self.model.reset_hooks()
        self.model.reset_saes()
        
        mmlu_correct = 0
        mmlu_total = 0

        for item in tqdm(mmlu_dataset, desc="Evaluating MMLU without SAE"):
            prompt = f"{item['question']}\nA) {item['choices'][0]}\nB) {item['choices'][1]}\nC) {item['choices'][2]}\nD) {item['choices'][3]}\nAnswer:"

            # Convert prompt to tokens to get prompt length
            prompt_tokens = self.model.to_tokens(prompt, prepend_bos=False)
            prompt_length = prompt_tokens.shape[1]

            # Generate model output

            generated_tokens = self.model.generate(
                prompt_tokens,
                max_new_tokens=5,
                temperature=0.0,
                do_sample=False,
                prepend_bos=False,
            )

            # Slice the generated tokens correctly
            new_tokens = generated_tokens[0][prompt_length:]

            # Decode the generated tokens into text
            generated_text = self.model.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True
            )

            # Extract the predicted answer using regex
            match = re.search(r'\b([ABCD])\b', generated_text)
            if match:
                pred_letter = match.group(1)
            else:
                # Handle cases where no match is found
                pred_letter = None

            answer_key = {0:'A', 1:'B', 2: 'C', 3:'D'}
            correct_answer = answer_key[item['answer']]
            if pred_letter == correct_answer:
                mmlu_correct += 1
            mmlu_total += 1

        return {
            'mmlu_accuracy_no_sae': mmlu_correct / mmlu_total
        }

    # New function to evaluate WMDP without SAE
    def evaluate_wmdp_no_sae(self, wmdp_dataset):
        """Evaluate the model on WMDP-bio without using an SAE"""
        # Ensure the model has no SAEs or hooks
        self.model.reset_hooks()
        self.model.reset_saes()
        
        wmdp_correct = 0
        wmdp_total = 0

        for item in tqdm(wmdp_dataset, desc="Evaluating WMDP-bio without SAE"):
            prompt = f"{item['question']}\nA: {item['choices'][0]}\nB: {item['choices'][1]}\nC: {item['choices'][2]}\nD: {item['choices'][3]}\nAnswer:"

            # Convert prompt to tokens to get prompt length
            prompt_tokens = self.model.to_tokens(prompt, prepend_bos=False)
            prompt_length = prompt_tokens.shape[1]

            # Generate model output
            generated_tokens = self.model.generate(
                prompt_tokens,
                max_new_tokens=5,
                temperature=0.0,
                do_sample=False,
                prepend_bos=False,
            )

            # Slice the generated tokens correctly
            new_tokens = generated_tokens[0][prompt_length:]

            # Decode the generated tokens into text
            generated_text = self.model.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True
            )

            # Extract the predicted answer using regex
            match = re.search(r'\b([ABCD])\b', generated_text)
            if match:
                pred_letter = match.group(1)
            else:
                # Handle cases where no match is found
                pred_letter = None

            answer_key = {0:'A', 1:'B', 2: 'C', 3:'D'}
            if pred_letter == answer_key[item['answer']]:
                wmdp_correct += 1
            wmdp_total += 1

        return {
            'wmdp_accuracy_no_sae': wmdp_correct / wmdp_total
        }



def main():
    # Initialize selector
    selector = FeatureSelector()
    
    # Load datasets from HuggingFace
    wmdp = load_dataset("cais/wmdp", name="wmdp-bio", split="test")
    wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    openwebtext = load_dataset("stas/openwebtext-10k", split="train")  # Load OpenWebText

    # Load specific MMLU subjects
    mmlu_subjects = [
        "high_school_us_history",
        "high_school_geography", 
        "college_computer_science",
        "human_aging"
    ]
    
    mmlu_datasets = []
    for subject in mmlu_subjects:
        ds = load_dataset("cais/mmlu", subject, split="test")
        mmlu_datasets.append(ds)
    
    # Concatenate all MMLU subjects
    mmlu = concatenate_datasets(mmlu_datasets)
    
    # Select features
    features = selector.select_features(
        wmdp_dataset=wmdp,
        wikitext_dataset=wikitext,
        retain_sparsity_threshold=0.01,
        n_features=20 # Aashiq to change
    )
    
    print(f"\nSelected features: {features}")
    
    # Test different clamp values
    clamp_values = [-20.0]  # You can test different clamp values as needed
    results = []
    
    for clamp_value in clamp_values:
        print(f"\nTesting clamp value {clamp_value}")
        metrics = selector.evaluate_ablation(
            feature_indices=features,
            mmlu_dataset=mmlu,
            openwebtext_dataset=openwebtext,
            clamp_value=clamp_value
        )
        
        # Evaluate on WMDP-bio
        metrics_wmdp = selector.evaluate_wmdp(
            feature_indices=features,
            wmdp_dataset=wmdp,
            clamp_value=clamp_value
        )
        
         # Evaluate without SAE
        metrics_no_sae_mmlu = selector.evaluate_mmlu_no_sae(mmlu)
        metrics_no_sae_wmdp = selector.evaluate_wmdp_no_sae(wmdp)

        # Combine metrics
        results.append({
            'clamp_value': clamp_value,
            **metrics,
            **metrics_wmdp,
            **metrics_no_sae_mmlu,
            **metrics_no_sae_wmdp
        })
        
    # Show results
    results_df = pd.DataFrame(results)
    print("\nResults:")
    print(results_df)

if __name__ == "__main__":
    main()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-2-2b-it into HookedTransformer
Calculating sparsities on WMDP-bio...


  0%|          | 0/100 [00:00<?, ?it/s]

Calculating sparsities on WikiText...


  0%|          | 0/100 [00:00<?, ?it/s]


Selected features: [10589 15197 15647  8367  4529  6175   966  5267  8155  6145  6344 11438
  6864   518   377  5447  9571  1260  4037 15432  5481 15696 15231 15869
 11709 15091  3481   143 11810  6283  5595  3074  7859 11808  7059 13835
  2376  2627 13556  5544 15017  6551  2605  2207   575 14247  1249 13477
 15367  5370   275  1518  6139   916  4569 15825  8745  5175 13442  1278
  7232  2583  5402  8869 11705  5615  1266  2684 10447  4341 12951 13176
 14527 11074  2335 11254   872  3145 10710 16325 13345 10174   112  1871
  4255 15392  2963  6048  2496  7154  8526 14217 16294 10122 14772   102
  9089  4965  6372 16264  6682  7810 14631  4383  7879  9388 13893  8433
  9722  3270 13626 14697 15525  4789 15874 13389  5432  2981 14314 13947
  5270  9991 10879   605  2468   214  2306  4370 14303 10151  6501  9295
  8221 15694 15882  8537 10167 11044  1833  9796 11296 12644 15650 10560
 10883 11063 12682  2439 11010 16005 16188  9251  8114  2888  1635  2993
  8092  8768 12298 14869 10591 

Evaluating MMLU:   0%|          | 0/725 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

Processing OpenWebText...


  0%|          | 0/75 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 22.04 GiB. GPU 0 has a total capacity of 79.25 GiB of which 4.79 GiB is free. Including non-PyTorch memory, this process has 74.45 GiB memory in use. Of the allocated memory 69.89 GiB is allocated by PyTorch, and 4.07 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Custom version

In [2]:
# import torch
# from transformer_lens import HookedTransformer 
# from sae_lens import SAE, HookedSAETransformer
# from datasets import load_dataset, concatenate_datasets
# from tqdm.notebook import tqdm
# import pandas as pd
# import numpy as np
# from torch.nn import functional as F
# from functools import partial
# import re
# from typing import List, Dict, Tuple, Optional
# from dataclasses import dataclass

# @dataclass
# class EvaluationConfig:
#     batch_size: int = 4
#     owt_token_limit: int = 50000
#     clamp_value: float = 0.0
#     retain_sparsity_threshold: float = 0.01
#     n_features: int = 200
#     sparsity_batch_size: int = 1
#     sparsity_n_batches: int = 100

# class FeatureSelector:
#     def __init__(self, device: str = "cuda"):
#         self.device = device
#         self._initialize_model()
#         self.threshold = 0.01
        
#     def _initialize_model(self):
#         """Initialize model and SAE with cached properties"""
#         self.model = HookedSAETransformer.from_pretrained(
#             "google/gemma-2-2b-it",
#             device=self.device,
#             fold_ln=False,
#             center_writing_weights=False,
#             center_unembed=False,
#         )
        
#         self.sae, _, _ = SAE.from_pretrained(
#             release="gemma-scope-2b-pt-res-canonical",
#             sae_id="layer_20/width_16k/canonical",
#             device=self.device,
#         )
#         self.sae.use_error_term = True
#         self.hook_point = self.sae.cfg.hook_name + ".hook_sae_acts_post"

#     @torch.no_grad()
#     def _process_batch(self, batch: Dict) -> torch.Tensor:
#         """Process a single batch and return tokens"""
#         if 'choices' in batch.features:
#             prompts = [
#                 f"{item['question']}\nA: {item['choices'][0]}\nB: {item['choices'][1]}\nC: {item['choices'][2]}\nD: {item['choices'][3]}\nAnswer:"
#                 for item in batch
#             ]
#         else:
#             prompts = batch['text']
#         return self.model.to_tokens(prompts)

#     @torch.no_grad()
#     def calculate_sparsity(self, dataset, config: EvaluationConfig) -> torch.Tensor:
#         """Optimized sparsity calculation with batching"""
#         feature_activations = torch.zeros(self.sae.cfg.d_sae, device=self.device)
#         total_activations = 0
        
#         for i in tqdm(range(config.sparsity_n_batches)):
#             batch = dataset.select(range(i*config.sparsity_batch_size, 
#                                        (i+1)*config.sparsity_batch_size))
#             tokens = self._process_batch(batch)
            
#             _, cache = self.model.run_with_cache_with_saes(
#                 tokens,
#                 saes=[self.sae]
#             )
            
#             acts = cache[f"blocks.{self.sae.cfg.hook_layer}.hook_resid_post.hook_sae_acts_post"]
#             feature_activations += (acts > self.threshold).float().sum(dim=(0,1))
#             total_activations += acts.numel() // acts.size(-1)
            
#         return feature_activations / total_activations

#     def select_features(self, wmdp_dataset, wikitext_dataset, 
#                        config: EvaluationConfig) -> np.ndarray:
#         """Optimized feature selection"""
#         wmdp_sparsities = self.calculate_sparsity(wmdp_dataset, config)
#         wikitext_sparsities = self.calculate_sparsity(wikitext_dataset, config)
        
#         retain_mask = wikitext_sparsities < config.retain_sparsity_threshold
#         wmdp_sparsities[~retain_mask] = -1
#         _, top_features = torch.topk(wmdp_sparsities, config.n_features)
        
#         return top_features.cpu().numpy()

#     def _create_ablation_hook(self, feature_indices: np.ndarray, 
#                             clamp_value: float):
#         """Create ablation hook function"""
#         def ablate_feature_hook(activations, hook):
#             activations[:, :, feature_indices] = clamp_value
#             return activations
#         return ablate_feature_hook

#     @torch.no_grad()
#     def _evaluate_multiple_choice(self, dataset, hook_fn=None) -> Tuple[int, int]:
#         """Evaluate multiple choice questions"""
#         if hook_fn:
#             self.model.add_sae(self.sae)
#             self.model.add_hook(self.hook_point, hook_fn, dir='fwd')
            
#         correct = total = 0
#         answer_key = {0:'A', 1:'B', 2:'C', 3:'D'}
        
#         for batch in tqdm(dataset):
#             prompt = f"{batch['question']}\nA) {batch['choices'][0]}\nB) {batch['choices'][1]}\nC) {batch['choices'][2]}\nD) {batch['choices'][3]}\nAnswer:"
#             tokens = self.model.to_tokens(prompt, prepend_bos=False)
            
#             output = self.model.generate(
#                 tokens,
#                 max_new_tokens=5,
#                 temperature=0.0,
#                 do_sample=False,
#                 prepend_bos=False,
#             )
            
#             generated_text = self.model.tokenizer.decode(
#                 output[0][tokens.shape[1]:],
#                 skip_special_tokens=True
#             )
            
#             if match := re.search(r'\b([ABCD])\b', generated_text):
#                 correct += match.group(1) == answer_key[batch['answer']]
#             total += 1
            
#         if hook_fn:
#             self.model.reset_hooks()
#             self.model.reset_saes()
            
#         return correct, total

#     @torch.no_grad()
#     def _evaluate_loss(self, dataset, hook_fn, config: EvaluationConfig) -> float:
#         """Calculate loss on dataset"""
#         total_loss = 0.0
#         accumulated_tokens = 0
        
#         for i in range(0, len(dataset), config.batch_size):
#             if accumulated_tokens >= config.owt_token_limit:
#                 break
                
#             batch = dataset.select(range(i, min(i + config.batch_size, len(dataset))))
#             tokens = self.model.to_tokens(batch['text'], prepend_bos=False)
            
#             if accumulated_tokens + tokens.shape[1] > config.owt_token_limit:
#                 tokens = tokens[:, :(config.owt_token_limit - accumulated_tokens)]
            
#             logits = self.model(tokens, return_type='logits')
#             shift_logits = logits[..., :-1, :].contiguous()
#             shift_labels = tokens[..., 1:].contiguous()
            
#             loss = F.cross_entropy(
#                 shift_logits.view(-1, shift_logits.size(-1)),
#                 shift_labels.view(-1),
#                 reduction='sum'
#             )
            
#             total_loss += loss.item()
#             accumulated_tokens += tokens.shape[1]
            
#         return total_loss / accumulated_tokens

#     def evaluate_all(self, feature_indices: np.ndarray, datasets: Dict, 
#                     config: EvaluationConfig) -> Dict:
#         """Unified evaluation method"""
#         hook_fn = self._create_ablation_hook(feature_indices, config.clamp_value)
        
#         # Evaluate without SAE
#         mmlu_correct, mmlu_total = self._evaluate_multiple_choice(datasets['mmlu'])
#         wmdp_correct, wmdp_total = self._evaluate_multiple_choice(datasets['wmdp'])
        
#         # Evaluate with ablation
#         mmlu_abl_correct, mmlu_total = self._evaluate_multiple_choice(
#             datasets['mmlu'], hook_fn)
#         wmdp_abl_correct, wmdp_total = self._evaluate_multiple_choice(
#             datasets['wmdp'], hook_fn)
        
#         # Calculate losses
#         base_loss = self._evaluate_loss(datasets['openwebtext'], None, config)
#         abl_loss = self._evaluate_loss(datasets['openwebtext'], hook_fn, config)
        
#         return {
#             'mmlu_accuracy': mmlu_correct / mmlu_total,
#             'mmlu_ablated_accuracy': mmlu_abl_correct / mmlu_total,
#             'wmdp_accuracy': wmdp_correct / wmdp_total,
#             'wmdp_ablated_accuracy': wmdp_abl_correct / wmdp_total,
#             'openwebtext_loss_increase': abl_loss - base_loss
#         }

# def main():
#     config = EvaluationConfig()
#     selector = FeatureSelector()
    
#     # Load datasets
#     datasets = {
#         'wmdp': load_dataset("cais/wmdp", name="wmdp-bio", split="test"),
#         'wikitext': load_dataset("wikitext", "wikitext-2-raw-v1", split="test"),
#         'openwebtext': load_dataset("stas/openwebtext-10k", split="train"),
#         'mmlu': concatenate_datasets([
#             load_dataset("cais/mmlu", subject, split="test")
#             for subject in ["high_school_us_history", "high_school_geography", 
#                           "college_computer_science", "human_aging"]
#         ])
#     }
    
#     features = selector.select_features(
#         datasets['wmdp'],
#         datasets['wikitext'],
#         config
#     )
    
#     results = selector.evaluate_all(features, datasets, config)
#     print("\nResults:")
#     print(pd.DataFrame([results]))

# if __name__ == "__main__":
#     main()

In [2]:
# import torch
# from transformer_lens import HookedTransformer 
# from sae_lens import SAE, HookedSAETransformer
# from datasets import load_dataset, concatenate_datasets
# from tqdm import tqdm
# import pandas as pd
# import numpy as np
# from torch.nn import functional as F
# from functools import partial
# from contextlib import nullcontext
# import re
# from typing import List, Dict, Tuple, Optional
# from dataclasses import dataclass

# @dataclass
# class EvaluationConfig:
#     batch_size: int = 64
#     owt_token_limit: int = 50000
#     clamp_value: float = 0.0
#     retain_sparsity_threshold: float = 0.01
#     n_features: int = 200
#     sparsity_batch_size: int = 4
#     sparsity_n_batches: int = 100

# class FeatureSelector:
#     def __init__(self, device: str = "cuda"):
#         self.device = device
#         self._initialize_model()
#         self.threshold = 0.01
        
#     def _initialize_model(self):
#         """Initialize model and SAE with cached properties"""
#         self.model = HookedSAETransformer.from_pretrained(
#             "google/gemma-2-2b-it",
#             device=self.device,
#             fold_ln=False,
#             center_writing_weights=False,
#             center_unembed=False,
#         )
        
#         self.sae, _, _ = SAE.from_pretrained(
#             release="gemma-scope-2b-pt-res-canonical",
#             sae_id="layer_20/width_16k/canonical",
#             device=self.device,
#         )
#         self.sae.use_error_term = True
#         self.hook_point = self.sae.cfg.hook_name + ".hook_sae_acts_post"

#     @torch.no_grad()
#     def _process_batch(self, batch: Dict) -> torch.Tensor:
#         """Process a single batch and return tokens"""
#         if 'choices' in batch.features:
#             prompts = [
#                 f"{item['question']}\nA: {item['choices'][0]}\nB: {item['choices'][1]}\nC: {item['choices'][2]}\nD: {item['choices'][3]}\nAnswer:"
#                 for item in batch
#             ]
#         else:
#             prompts = batch['text']
#         return self.model.to_tokens(prompts)

#     @torch.no_grad()
#     def calculate_sparsity(self, dataset, config: EvaluationConfig, desc: str) -> torch.Tensor:
#         """Optimized sparsity calculation with batching"""
#         feature_activations = torch.zeros(self.sae.cfg.d_sae, device=self.device)
#         total_activations = 0
        
#         for i in tqdm(range(config.sparsity_n_batches), desc=f"Calculating sparsity for {desc}"):
#             batch = dataset.select(range(i*config.sparsity_batch_size, 
#                                        (i+1)*config.sparsity_batch_size))
#             tokens = self._process_batch(batch)
            
#             _, cache = self.model.run_with_cache_with_saes(
#                 tokens,
#                 saes=[self.sae]
#             )
            
#             acts = cache[f"blocks.{self.sae.cfg.hook_layer}.hook_resid_post.hook_sae_acts_post"]
#             feature_activations += (acts > self.threshold).float().sum(dim=(0,1))
#             total_activations += acts.numel() // acts.size(-1)
            
#         return feature_activations / total_activations

#     def select_features(self, wmdp_dataset, wikitext_dataset, 
#                        config: EvaluationConfig) -> np.ndarray:
#         """Optimized feature selection"""
#         print("Starting feature selection process...")
#         wmdp_sparsities = self.calculate_sparsity(wmdp_dataset, config, "WMDP")
#         wikitext_sparsities = self.calculate_sparsity(wikitext_dataset, config, "WikiText")
        
#         retain_mask = wikitext_sparsities < config.retain_sparsity_threshold
#         wmdp_sparsities[~retain_mask] = -1
#         _, top_features = torch.topk(wmdp_sparsities, config.n_features)
        
#         return top_features.cpu().numpy()

#     def _create_ablation_hook(self, feature_indices: np.ndarray, 
#                             clamp_value: float):
#         """Create ablation hook function"""
#         def ablate_feature_hook(activations, hook):
#             activations[:, :, feature_indices] = clamp_value
#             return activations
#         return ablate_feature_hook

#     @torch.no_grad()
#     def _evaluate_multiple_choice(self, dataset, hook_fn=None, desc="Evaluating") -> Tuple[int, int]:
#         """Evaluate multiple choice questions"""
#         if hook_fn:
#             self.model.add_sae(self.sae)
#             self.model.add_hook(self.hook_point, hook_fn, dir='fwd')
            
#         correct = total = 0
#         answer_key = {0:'A', 1:'B', 2:'C', 3:'D'}
        
#         for batch in tqdm(dataset, desc=desc):
#             prompt = f"{batch['question']}\nA) {batch['choices'][0]}\nB) {batch['choices'][1]}\nC) {batch['choices'][2]}\nD) {batch['choices'][3]}\nAnswer:"
#             tokens = self.model.to_tokens(prompt, prepend_bos=False)

#             with nullcontext():
#                 output = self.model.generate(
#                     tokens,
#                     max_new_tokens=5,
#                     temperature=0.0,
#                     do_sample=False,
#                     prepend_bos=False,
#                 )
            
#             generated_text = self.model.tokenizer.decode(
#                 output[0][tokens.shape[1]:],
#                 skip_special_tokens=True
#             )
            
#             if match := re.search(r'\b([ABCD])\b', generated_text):
#                 correct += match.group(1) == answer_key[batch['answer']]
#             total += 1
            
#         if hook_fn:
#             self.model.reset_hooks()
#             self.model.reset_saes()
            
#         return correct, total

#     @torch.no_grad()
#     def _evaluate_loss(self, dataset, hook_fn, config: EvaluationConfig, desc="Calculating loss") -> float:
#         """Calculate loss on dataset"""
#         total_loss = 0.0
#         accumulated_tokens = 0
        
#         # Calculate number of batches needed based on token limit
#         n_batches = min(len(dataset) // config.batch_size, 
#                        config.owt_token_limit // (config.batch_size * 100))  # Assuming average 100 tokens per sample
        
#         for i in tqdm(range(n_batches), desc=desc):
#             if accumulated_tokens >= config.owt_token_limit:
#                 break
                
#             batch = dataset.select(range(i * config.batch_size, 
#                                        min((i + 1) * config.batch_size, len(dataset))))
#             tokens = self.model.to_tokens(batch['text'], prepend_bos=False)
            
#             if accumulated_tokens + tokens.shape[1] > config.owt_token_limit:
#                 tokens = tokens[:, :(config.owt_token_limit - accumulated_tokens)]
            
#             logits = self.model(tokens, return_type='logits')
#             shift_logits = logits[..., :-1, :].contiguous()
#             shift_labels = tokens[..., 1:].contiguous()
            
#             loss = F.cross_entropy(
#                 shift_logits.view(-1, shift_logits.size(-1)),
#                 shift_labels.view(-1),
#                 reduction='sum'
#             )
            
#             total_loss += loss.item()
#             accumulated_tokens += tokens.shape[1]
            
#         return total_loss / accumulated_tokens

#     def evaluate_all(self, feature_indices: np.ndarray, datasets: Dict, 
#                     config: EvaluationConfig) -> Dict:
#         """Unified evaluation method"""
#         print("\nStarting evaluation process...")
#         hook_fn = self._create_ablation_hook(feature_indices, config.clamp_value)
        
#         # Evaluate without SAE
#         mmlu_correct, mmlu_total = self._evaluate_multiple_choice(
#             datasets['mmlu'], 
#             desc="Evaluating MMLU (baseline)"
#         )
#         wmdp_correct, wmdp_total = self._evaluate_multiple_choice(
#             datasets['wmdp'], 
#             desc="Evaluating WMDP (baseline)"
#         )
        
#         # Evaluate with ablation
#         mmlu_abl_correct, mmlu_total = self._evaluate_multiple_choice(
#             datasets['mmlu'], 
#             hook_fn,
#             desc="Evaluating MMLU (ablated)"
#         )
#         wmdp_abl_correct, wmdp_total = self._evaluate_multiple_choice(
#             datasets['wmdp'], 
#             hook_fn,
#             desc="Evaluating WMDP (ablated)"
#         )
        
#         # Calculate losses
#         base_loss = self._evaluate_loss(
#             datasets['openwebtext'], 
#             None, 
#             config,
#             desc="Calculating OpenWebText baseline loss"
#         )
#         abl_loss = self._evaluate_loss(
#             datasets['openwebtext'], 
#             hook_fn, 
#             config,
#             desc="Calculating OpenWebText ablated loss"
#         )
        
#         return {
#             'mmlu_accuracy': mmlu_correct / mmlu_total,
#             'mmlu_ablated_accuracy': mmlu_abl_correct / mmlu_total,
#             'wmdp_accuracy': wmdp_correct / wmdp_total,
#             'wmdp_ablated_accuracy': wmdp_abl_correct / wmdp_total,
#             'openwebtext_loss_increase': abl_loss - base_loss
#         }

# def main():
#     print("Initializing Feature Selector...")
#     config = EvaluationConfig()
#     selector = FeatureSelector()
    
#     # Load datasets
#     print("\nLoading datasets...")
#     datasets = {
#         'wmdp': load_dataset("cais/wmdp", name="wmdp-bio", split="test"),
#         'wikitext': load_dataset("wikitext", "wikitext-2-raw-v1", split="test"),
#         'openwebtext': load_dataset("stas/openwebtext-10k", split="train"),
#         'mmlu': concatenate_datasets([
#             load_dataset("cais/mmlu", subject, split="test")
#             for subject in ["high_school_us_history", "high_school_geography", 
#                           "college_computer_science", "human_aging"]
#         ])
#     }
    
#     features = selector.select_features(
#         datasets['wmdp'],
#         datasets['wikitext'],
#         config
#     )
    
#     results = selector.evaluate_all(features, datasets, config)
#     print("\nEvaluation Results:")
#     print(pd.DataFrame([results]))

# if __name__ == "__main__":
#     main()