# Simple MCQA Example Submission - DAS Training

This notebook demonstrates how to train DAS (Distributed Alignment Search) featurizers for the Simple MCQA (Multiple Choice Question Answering) task using a Gemma model.

## Overview
1. Load the Simple MCQA datasets and set up the model
2. Filter datasets based on model performance
3. Configure experiment settings for DAS training
4. Train DAS featurizers on residual stream representations
5. Load and test trained models

The Simple MCQA task consists of multiple-choice questions: given a short context and a question, the model must output the symbol (A, B, C, or D) of the correct answer among four options.
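For reference, the prompt in the sample printed in Step 1 follows a fixed layout: the context sentence and question, one line per choice prefixed by its symbol, and a trailing `Answer:`. The sketch below rebuilds that layout from its parts. `format_mcqa_prompt` is a hypothetical helper written for illustration, not part of the task code, and the real task may construct prompts differently.

```python
from typing import Sequence


def format_mcqa_prompt(
    statement: str,
    question: str,
    choices: Sequence[str],
    symbols: Sequence[str] = ("A", "B", "C", "D"),
) -> str:
    """Hypothetical helper: rebuild the Simple MCQA prompt layout seen in the sample output."""
    if len(choices) != len(symbols):
        raise ValueError("need exactly one symbol per choice")
    lines = [f"Question: {statement} {question}"]
    # One line per option, e.g. "A. red"
    lines += [f"{symbol}. {choice}" for symbol, choice in zip(symbols, choices)]
    # The model is expected to continue after "Answer:" with a single symbol token
    lines.append("Answer:")
    return "\n".join(lines)


prompt = format_mcqa_prompt(
    "Coconuts are brown.", "What color are coconuts?", ["red", "orange", "brown", "purple"]
)
print(prompt)
```

For this prompt the expected completion is ` C` (with a leading space), matching the `EXPECTED OUTPUT` shown in Step 1.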

## Step 1: Setup and Data Loading

Load the Simple MCQA task components and initialize the model pipeline.

In [None]:
import sys
from pathlib import Path
sys.path.append(str(Path().resolve().parent.parent))

from tasks.simple_MCQA.simple_MCQA import get_token_positions, get_counterfactual_datasets, get_causal_model
from CausalAbstraction.experiments.aggregate_experiments import residual_stream_baselines
from CausalAbstraction.neural.pipeline import LMPipeline
from CausalAbstraction.experiments.filter_experiment import FilterExperiment
import gc
import torch
import os

gc.collect()
torch.cuda.empty_cache()

# Get counterfactual datasets and causal model
counterfactual_datasets = get_counterfactual_datasets(hf=True, size=None, load_private_data=False)
causal_model = get_causal_model()

# Print available datasets
print("Available datasets:", counterfactual_datasets.keys())

device = "cuda:0" if torch.cuda.is_available() else "cpu"

def clear_memory():
    """Release Python garbage and cached CUDA memory between experiment stages."""
    gc.collect()

    if torch.cuda.is_available():
        # Return cached blocks to the CUDA allocator
        torch.cuda.empty_cache()
        # Force a synchronization point so pending work finishes before continuing
        torch.cuda.synchronize()


def checker(output_text, expected):
    """Return True if the expected answer appears in the generated text (substring match)."""
    return expected in output_text

model_name = "google/gemma-2-2b"
pipeline = LMPipeline(model_name, max_new_tokens=1, device=device, dtype=torch.float16)
pipeline.tokenizer.padding_side = "left"
print("DEVICE:", pipeline.model.device)

# Get a sample input and check model's prediction
sampled_example = next(iter(counterfactual_datasets.values()))[0]
print("INPUT:", sampled_example["input"])
print("EXPECTED OUTPUT:", causal_model.run_forward(sampled_example["input"])["raw_output"])
print("MODEL PREDICTION:", pipeline.dump(pipeline.generate(sampled_example["input"])))

Available datasets: dict_keys(['answerPosition_train', 'randomLetter_train', 'answerPosition_randomLetter_train', 'answerPosition_validation', 'randomLetter_validation', 'answerPosition_randomLetter_validation', 'answerPosition_test', 'randomLetter_test', 'answerPosition_randomLetter_test'])



DEVICE: cuda:0
INPUT: {'choice0': 'red', 'choice1': 'orange', 'choice2': 'brown', 'choice3': 'purple', 'question': ['brown', 'question: coconuts'], 'raw_input': 'Question: Coconuts are brown. What color are coconuts?\nA. red\nB. orange\nC. brown\nD. purple\nAnswer:', 'symbol0': 'A', 'symbol1': 'B', 'symbol2': 'C', 'symbol3': 'D'}
EXPECTED OUTPUT:  C
MODEL PREDICTION:  C
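
The pipeline sets `padding_side = "left"` because generation continues from the last position of each row. With right padding, shorter prompts in a batch would end in pad tokens, and the single generated token (`max_new_tokens=1`) would be predicted from a pad position. The pure-Python sketch below only illustrates the padding layout; `pad_batch` and the token ids are hypothetical, and the real tokenizer handles padding internally.

```python
from typing import List, Tuple


def pad_batch(
    batch: List[List[int]], pad_id: int, side: str = "left"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-id sequences to equal length; return (input_ids, attention_mask)."""
    width = max(len(seq) for seq in batch)
    ids, masks = [], []
    for seq in batch:
        pad = [pad_id] * (width - len(seq))
        real = [1] * len(seq)
        zeros = [0] * len(pad)
        if side == "left":
            ids.append(pad + seq)
            masks.append(zeros + real)
        elif side == "right":
            ids.append(seq + pad)
            masks.append(real + zeros)
        else:
            raise ValueError("side must be 'left' or 'right'")
    return ids, masks


batch = [[2, 10, 11, 12], [2, 20]]  # hypothetical token ids of two prompts
left_ids, _ = pad_batch(batch, pad_id=0, side="left")
right_ids, _ = pad_batch(batch, pad_id=0, side="right")

# The final column is the position next-token prediction reads from
print([row[-1] for row in left_ids])   # [12, 20]: last real token of each prompt
print([row[-1] for row in right_ids])  # [12, 0]: the second prompt ends in padding
```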


## Step 2: Filter Datasets Based on Model Performance

Filter datasets to keep only examples where the model produces correct outputs. This ensures we train featurizers on cases where the model actually succeeds at the task.
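Conceptually, filtering runs the model on each example and keeps it only if the checker accepts the output. The sketch below assumes an example is kept when the model is correct on the base input and on every counterfactual input; that criterion, the `filter_examples` and `overall_keep_rate` helpers, and the toy outputs are assumptions for illustration, not the actual `FilterExperiment` implementation. It also shows how the overall keep rate in the log below is pooled across datasets.

```python
from typing import Any, Callable, Dict, List, Tuple

Example = Dict[str, Any]


def checker(output_text: str, expected: str) -> bool:
    # Same substring check as the notebook's checker
    return expected in output_text


def filter_examples(
    examples: List[Example],
    predict: Callable[[Any], str],
    expected: Callable[[Any], str],
) -> List[Example]:
    """Assumed criterion: keep an example if the model is right on the base and all counterfactuals."""
    kept = []
    for ex in examples:
        inputs = [ex["input"], *ex["counterfactual_inputs"]]
        if all(checker(predict(x), expected(x)) for x in inputs):
            kept.append(ex)
    return kept


def overall_keep_rate(counts: List[Tuple[int, int]]) -> float:
    """counts holds (kept, original) pairs per dataset; returns the pooled keep rate in percent."""
    kept = sum(k for k, _ in counts)
    total = sum(n for _, n in counts)
    return 100.0 * kept / total


# Toy model that answers "q2" incorrectly
answers = {"q1": " A", "q2": " B", "q3": " C"}
model_outputs = {"q1": " A", "q2": " D", "q3": " C"}
data = [
    {"input": "q1", "counterfactual_inputs": ["q3"]},  # all correct: kept
    {"input": "q1", "counterfactual_inputs": ["q2"]},  # counterfactual wrong: dropped
]
kept = filter_examples(data, lambda x: model_outputs[x], lambda x: answers[x])

# Per-dataset (kept, original) counts from the filtering log below
counts = [(106, 110), (75, 110), (71, 110), (50, 50), (37, 50), (35, 50), (50, 50), (34, 50), (37, 50)]
print(f"Overall keep rate: {overall_keep_rate(counts):.1f}%")  # Overall keep rate: 78.6%
```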

In [3]:
# Filter the datasets based on model performance
print("\nFiltering datasets based on model performance...")
exp = FilterExperiment(pipeline, causal_model, checker)
filtered_datasets = exp.filter(counterfactual_datasets, verbose=True, batch_size=1024)

token_positions = get_token_positions(pipeline, causal_model)

# Display token highlighting for one sample: first dataset, first token position
print("\nToken positions highlighted in samples:")
example = next(iter(filtered_datasets.values()))[0]
print(token_positions[0].highlight_selected_token(example["counterfactual_inputs"][0]))

gc.collect()
torch.cuda.empty_cache()




Filtering datasets based on model performance...


Filtering answerPosition_train: 100%|██████████| 1/1 [00:01<00:00,  1.12s/it]


Dataset 'answerPosition_train': kept 106/110 examples (96.4%)


Filtering randomLetter_train: 100%|██████████| 1/1 [00:01<00:00,  1.14s/it]


Dataset 'randomLetter_train': kept 75/110 examples (68.2%)


Filtering answerPosition_randomLetter_train: 100%|██████████| 1/1 [00:01<00:00,  1.15s/it]


Dataset 'answerPosition_randomLetter_train': kept 71/110 examples (64.5%)


Filtering answerPosition_validation: 100%|██████████| 1/1 [00:00<00:00,  1.27it/s]


Dataset 'answerPosition_validation': kept 50/50 examples (100.0%)


Filtering randomLetter_validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s]


Dataset 'randomLetter_validation': kept 37/50 examples (74.0%)


Filtering answerPosition_randomLetter_validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s]


Dataset 'answerPosition_randomLetter_validation': kept 35/50 examples (70.0%)


Filtering answerPosition_test: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s]


Dataset 'answerPosition_test': kept 50/50 examples (100.0%)


Filtering randomLetter_test: 100%|██████████| 1/1 [00:00<00:00,  1.26it/s]


Dataset 'randomLetter_test': kept 34/50 examples (68.0%)


Filtering answerPosition_randomLetter_test: 100%|██████████| 1/1 [00:00<00:00,  1.22it/s]

Dataset 'answerPosition_randomLetter_test': kept 37/50 examples (74.0%)

Total filtering results:
Original examples: 630
Kept examples: 495
Overall keep rate: 78.6%

Token positions highlighted in samples:
<bos>Question: Coconuts are brown. What color are coconuts?
A. red
B. orange
C. purple
**D**. brown
Answer:





## Step 3: Configure Experiment Settings for DAS Training

Set up the training configuration: batch sizes, number of training epochs, feature dimension, the layer range, and the target variables. For Simple MCQA, we target the `answer_pointer` variable, which in the causal model records the position of the correct answer among the four choices.
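To make the target variable concrete, here is a toy version assuming `answer_pointer` is the 0-based index of the choice matching the stated color and the output is the symbol at that index. It contrasts a counterfactual run with an interchange intervention on `answer_pointer`, using the base prompt from Step 1 and the choice order highlighted in Step 2, plus hypothetical relabelled symbols (as in a combined answerPosition/randomLetter counterfactual). `answer_pointer` and `run_toy_model` are hypothetical stand-ins for the model returned by `get_causal_model()`.

```python
from typing import Optional, Sequence

SYMBOLS = ("A", "B", "C", "D")


def answer_pointer(color: str, choices: Sequence[str]) -> int:
    """Toy variable: 0-based index of the choice that matches the stated color."""
    return list(choices).index(color)


def run_toy_model(
    color: str,
    choices: Sequence[str],
    symbols: Sequence[str] = SYMBOLS,
    pointer_override: Optional[int] = None,
) -> str:
    """Toy causal model: output the symbol at answer_pointer (optionally intervened on)."""
    pointer = answer_pointer(color, choices) if pointer_override is None else pointer_override
    # Leading space matches the model's output token, e.g. " C"
    return " " + symbols[pointer]


base_choices = ["red", "orange", "brown", "purple"]  # base prompt from Step 1
cf_choices = ["red", "orange", "purple", "brown"]    # choice order highlighted in Step 2
cf_symbols = ("W", "X", "Y", "Z")                    # hypothetical relabelled symbols

base_out = run_toy_model("brown", base_choices)            # " C"
cf_out = run_toy_model("brown", cf_choices, cf_symbols)    # " Z"

# Interchange intervention: take answer_pointer from the counterfactual,
# keep everything else (including the symbols) from the base prompt
cf_pointer = answer_pointer("brown", cf_choices)
intervened = run_toy_model("brown", base_choices, pointer_override=cf_pointer)  # " D"
```

The intervened output (` D`) differs from both the base output and the counterfactual output, which is what lets training distinguish a subspace that encodes `answer_pointer` from one that simply copies the counterfactual answer.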

In [3]:
import os

# Layers to train on: range(start, end), i.e. only layer 0 here
start = 0
end = 1

# Training configuration shared by all methods
config = {
    "batch_size": 64,                   # training batch size
    "evaluation_batch_size": 1024,      # batch size used during evaluation
    "training_epoch": 1,                # number of passes over the training data
    "n_features": 16,                   # size of the learned feature subspace
    "regularization_coefficient": 0.0,  # 0.0 disables the regularization term
    "output_scores": False,
}

names = ["answerPosition", "randomLetter", "answerPosition_randomLetter"]

# Prepare train and test data dictionaries
train_data = {}
test_data = {}

for name in names:
    if name + "_train" in filtered_datasets:
        train_data[name + "_train"] = filtered_datasets[name + "_train"]
    if name + "_test" in filtered_datasets:
        test_data[name + "_test"] = filtered_datasets[name + "_test"]
    # Private test splits are included only if they were loaded
    if name + "_testprivate" in filtered_datasets:
        test_data[name + "_testprivate"] = filtered_datasets[name + "_testprivate"]

verbose = False
results_dir = "mock_submission_results"
model_dir = os.path.join("mock_submission", "4_answer_MCQA_Gemma2ForCausalLM_answer_pointer")
target_variables = ["answer_pointer"]

## Step 4: Train DAS Featurizers on Residual Stream

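Train DAS (Distributed Alignment Search) featurizers on the residual stream representations. DAS learns a low-dimensional subspace of the residual stream that encodes the target causal variable: swapping that subspace between two inputs should change the model's output the way the causal model predicts.

The training process:
1. Run the model on base and counterfactual inputs and read the residual stream at the selected layers and token positions
2. Map the high-dimensional activations into a lower-dimensional feature space with a learned featurizer
3. Train the featurizer with interchange interventions: patch the counterfactual's features into the base run and optimize so the model's output matches the causal model's output for the intervened target variable
4. Save the trained featurizers for later use in interventions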
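The core operation behind DAS is an interchange intervention on a learned subspace: with a matrix `R` whose `k` rows are orthonormal, the base activation `b` is patched with a source activation `s` as `b + R^T (R s - R b)`, so the `k` subspace coordinates come from the source while the orthogonal complement keeps its base values. The pure-Python sketch below applies that formula to a hand-picked 4-dimensional example with `k = 2`. In real training `R` is learned by gradient descent under an orthogonality constraint; treating `n_features` from Step 3 as `k` is an assumption about the library's naming.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Compute m v for a k x d matrix m and a length-d vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose_matvec(m: Matrix, v: Vector) -> Vector:
    """Compute m^T v for a k x d matrix m and a length-k vector v."""
    d = len(m[0])
    return [sum(m[i][j] * v[i] for i in range(len(m))) for j in range(d)]


def interchange(base: Vector, source: Vector, rotation: Matrix) -> Vector:
    """Swap the subspace spanned by rotation's orthonormal rows from source into base."""
    delta = [s - b for s, b in zip(matvec(rotation, source), matvec(rotation, base))]
    correction = transpose_matvec(rotation, delta)
    return [b + c for b, c in zip(base, correction)]


r = 1 / math.sqrt(2)
R = [
    [r, r, 0.0, 0.0],    # feature 1 mixes dims 0 and 1
    [0.0, 0.0, 1.0, 0.0],  # feature 2 is dim 2
]
base = [1.0, 3.0, 5.0, 7.0]
source = [2.0, 0.0, -1.0, 4.0]
patched = interchange(base, source, R)  # approximately [0.0, 2.0, -1.0, 7.0]
```

After patching, `matvec(R, patched)` equals `matvec(R, source)` up to float rounding, while the direction `(r, -r, 0, 0)` and dimension 3 keep their base values.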

In [4]:
# Run DAS (Distributed Alignment Search)
print("Running DAS method...")

residual_stream_baselines(
    pipeline=pipeline,
    task=causal_model,
    token_positions=token_positions,
    train_data=train_data,
    test_data=test_data,
    config=config,
    target_variables=target_variables,
    checker=checker,
    start=start,
    end=end,
    verbose=verbose,
    model_dir=model_dir,
    results_dir=results_dir,
    methods=["DAS"]  # Only run DAS method
)

clear_memory()

Running DAS method...

Epoch: 0: 100%|██████████| 4/4 [00:02<00:00,  1.54it/s, loss=1.71, accuracy=0.5, token_accuracy=0.5]
Epoch: 100%|██████████| 1/1 [00:02<00:00,  2.59s/it]

Epoch: 0: 100%|██████████| 4/4 [00:02<00:00,  1.51it/s, loss=2.17, accuracy=0.25, token_accuracy=0.25]
Epoch: 100%|██████████| 1/1 [00:02<00:00,  2.64s/it]

Epoch: 0: 100%|██████████| 4/4 [00:02<00:00,  1.58it/s, loss=2.19, accuracy=0.23, token_accuracy=0.23]
Epoch: 100%|██████████| 1/1 [00:02<00:00,  2.53s/it]

## Alternative Methods (Optional)

The following cells show how to train other baseline methods, such as DBM (Desiderata-Based Masking) and DBM+SAE (DBM applied to Sparse Autoencoder features). They are commented out; uncomment them to compare these methods against DAS.
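As a rough picture of how masking differs from DAS: instead of learning a rotated subspace, a masking method learns one weight per dimension (residual-stream units for DBM, SAE features for DBM+SAE) and patches only the dimensions the mask selects. The sketch below assumes a sigmoid-relaxed mask with a temperature during training and a hard threshold at evaluation time; these details, and the helper names, are assumptions for illustration and may not match the library's DBM implementation.

```python
import math
from typing import List


def soft_mask(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Assumed training-time relaxation: a mask value in (0, 1) per dimension."""
    return [1.0 / (1.0 + math.exp(-z / temperature)) for z in logits]


def hard_mask(logits: List[float]) -> List[int]:
    """Assumed evaluation-time mask: 1 where sigmoid(z) > 0.5, i.e. where z > 0."""
    return [1 if z > 0 else 0 for z in logits]


def masked_interchange(base: List[float], source: List[float], mask: List[float]) -> List[float]:
    """Per-dimension mix: masked dims come from source, the rest keep their base values."""
    return [(1 - m) * b + m * s for b, s, m in zip(base, source, mask)]


logits = [3.0, -2.0, 0.5]  # hypothetical learned mask logits
base = [1.0, 1.0, 1.0]
source = [5.0, 5.0, 5.0]
print(masked_interchange(base, source, hard_mask(logits)))  # [5.0, 1.0, 5.0]
```

With the soft mask, gradients flow to every logit during training; thresholding afterwards yields a discrete set of patched dimensions.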

In [5]:
# Example: Run DBM+SAE method (uncomment to use)
# NOTE: This method requires the sae_lens library and pretrained SAEs for the model (e.g. Gemma Scope for gemma-2-2b, as loaded in Step 5)

# print("Running DBM+SAE method...")

# residual_stream_baselines(
#     pipeline=pipeline,
#     task=causal_model,
#     token_positions=token_positions,
#     train_data=train_data,
#     test_data=test_data,
#     config=config,
#     target_variables=target_variables,
#     checker=checker,
#     start=start,
#     end=end,
#     verbose=verbose,
#     model_dir=model_dir,
#     results_dir=results_dir,
#     methods=["DBM+SAE"]  # Only run DBM+SAE method
# )

# clear_memory()

In [6]:
# Example: Run DBM method (uncomment to use)
# print("Running DBM method...")

# residual_stream_baselines(
#     pipeline=pipeline,
#     task=causal_model,
#     token_positions=token_positions,
#     train_data=train_data,
#     test_data=test_data,
#     config=config,
#     target_variables=target_variables,
#     checker=checker,
#     start=start,
#     end=end,
#     verbose=verbose,
#     model_dir=model_dir,
#     results_dir=results_dir,
#     methods=["DBM"]  # Only run DBM method
# )

# clear_memory()

## Step 5: Load Trained Models and Run Inference

This section demonstrates how to load previously trained featurizers and use them for inference on test data. This is useful for:

1. **Testing trained models**: Verify that saved models work correctly
2. **Running interventions**: Use the trained featurizers to perform causal interventions
3. **Evaluation**: Test model performance on held-out test data

The process involves:
- Loading the trained featurizer from disk
- Running interventions on test datasets  
- Collecting results for analysis

This is exactly what the evaluation system will do with your submitted models.
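Before calling `load_featurizers`, it can help to check that the directory the loading cell builds actually exists. The sketch below reproduces the naming used in the loading cell (`{method}_{model class name}_{target variables joined by "-"}` under `model_dir`). `featurizer_dir` and `missing_featurizer_dirs` are hypothetical helpers, and no assumption is made about which files are written inside each directory.

```python
import os
from typing import Iterable, List


def featurizer_dir(
    model_dir: str, method: str, model_class_name: str, target_variables: Iterable[str]
) -> str:
    """Directory name used by the loading cell: <method>_<ModelClass>_<var1-var2-...>."""
    return os.path.join(model_dir, f"{method}_{model_class_name}_{'-'.join(target_variables)}")


def missing_featurizer_dirs(
    model_dir: str, methods: Iterable[str], model_class_name: str, target_variables: List[str]
) -> List[str]:
    """Return the expected featurizer directories that do not exist on disk."""
    expected = [featurizer_dir(model_dir, m, model_class_name, target_variables) for m in methods]
    return [path for path in expected if not os.path.isdir(path)]


example_model_dir = os.path.join("mock_submission", "4_answer_MCQA_Gemma2ForCausalLM_answer_pointer")
print(featurizer_dir(example_model_dir, "DAS", "Gemma2ForCausalLM", ["answer_pointer"]))
```

In the notebook, `pipeline.model.__class__.__name__` supplies the class name, so running `missing_featurizer_dirs(model_dir, ["DAS"], pipeline.model.__class__.__name__, target_variables)` before the loop would flag a missing training run early.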

In [7]:
# Example: Load saved models and run inference
# This demonstrates how to load previously trained featurizers and run interventions

from CausalAbstraction.experiments.residual_stream_experiment import PatchResidualStream

print("Loading trained models and running inference...")

for method in ["DAS"]:  # Can also test "DBM", "DBM+SAE", etc.
    print(f"Testing {method} method...")
    
    # Create experiment with same configuration
    config["method_name"] = method
    experiment = PatchResidualStream(
        pipeline, causal_model, list(range(start, end)), token_positions, checker, config=config
    )
    
    # Set up SAE loader if needed for DBM+SAE method
    if method == "DBM+SAE":
        from sae_lens import SAE
        def sae_loader(layer):
            sae, _, _ = SAE.from_pretrained(
                release = "gemma-scope-2b-pt-res-canonical",
                sae_id = f"layer_{layer}/width_16k/canonical",
                device = "cpu",
            )
            return sae
        experiment.build_SAE_feature_intervention(sae_loader)
    
    # Load the trained featurizers
    # Single quotes inside the f-string keep this valid on Python versions before 3.12
    method_model_dir = os.path.join(
        model_dir, f"{method}_{pipeline.model.__class__.__name__}_{'-'.join(target_variables)}"
    )

    experiment.load_featurizers(method_model_dir)
    
    # Run interventions on test data
    raw_results = experiment.perform_interventions(
        test_data,
        verbose=verbose,
        target_variables_list=[target_variables],
        save_dir=results_dir + "_loaded",
    )
    
    # Clean up
    del experiment, raw_results
    clear_memory()

print("Inference completed!")

Loading trained models and running inference...
Testing DAS method...


Inference completed!
