# SCONE Causal Abstraction Experiment

This notebook implements a causal abstraction experiment on the ScoNe (Scoped Negation) natural language inference dataset. The experiment investigates how a large language model handles negation in entailment judgments. It tests whether the intermediate variables of a hand-written causal model (lexical entailment, negation scope, final relation) are realized in the model's internal representations.

## Overview
- **Dataset**: SCONE entailment dataset with negation scoping
- **Model**: Llama-3.1-8B-Instruct
- **Method**: Interchange interventions on the residual stream, trained with Distributed Alignment Search (DAS), using pyvene
- **Target**: Understanding how models process logical scope and entailment

## 1. Setup and Imports

In [1]:
!pip install --upgrade "numpy<2"


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import json
import os
import random
import re
import sys
import zlib
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import pyvene as pv
import torch

def setup_cuda():
    """Report the CUDA build/runtime status of the installed torch.

    Prints the CUDA version torch was built with, whether a CUDA device is
    usable right now, and (when one is) the name of device 0.

    Returns:
        bool: the value of ``torch.cuda.is_available()``.
    """
    print("Built with CUDA:", torch.version.cuda)
    cuda_usable = torch.cuda.is_available()
    print("cuda.is_available():", cuda_usable)
    if not cuda_usable:
        return cuda_usable
    print("Device name:", torch.cuda.get_device_name(0))
    return cuda_usable

def add_project_root(root_path="/data/jason_lim/CausalAbstraction"):
    """Make ``root_path`` importable by appending it to ``sys.path`` at most once.

    Args:
        root_path: Directory to add. Defaults to the author's checkout location.

    Returns:
        str: ``root_path``, unchanged.
    """
    already_present = root_path in sys.path
    if not already_present:
        sys.path.append(root_path)
    return root_path

# Setup: seed every RNG the notebook relies on so Restart & Run All is
# reproducible -- `random` drives the counterfactual sampling below, NumPy's
# global generator is what pandas `.sample` uses when no `random_state` is
# given, and torch covers model-side randomness.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Make project packages importable and report the available hardware.
add_project_root()
cuda_ok = setup_cuda()

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.
Built with CUDA: 12.1
cuda.is_available(): False


In [3]:
# Internal module imports
from causal.causal_model import CausalModel, CounterfactualDataset
from experiments.scone_filter_experiment import SconeFilterExperiment
from lm_units.fixed_pipeline import LMPipeline
from lm_units.LM_units import TokenPosition, get_last_token_index
from experiments.residual_stream_experiment import PatchResidualStream

In [4]:
import os
import sys

# Run from the project root so the relative `data/...` paths used below resolve.
PROJECT_ROOT = "/data/jason_lim/CausalAbstraction"
os.chdir(PROJECT_ROOT)

# Add to Python path only if absent: re-running this cell must not keep
# appending duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Verify the changes
print(f"Current working directory: {os.getcwd()}")
print(f"Path added to sys.path: {PROJECT_ROOT in sys.path}")

Current working directory: /data/jason_lim/CausalAbstraction
Path added to sys.path: True


## 2. Data Loading Functions

In [5]:


def load_scone_datasets(data_dir="data", half=True):
    """Read the SCONE train/test CSVs for each subset and stack them.

    Args:
        data_dir: Directory containing ``train/`` and ``test/`` sub-folders,
            each holding one ``<subset>.csv`` per subset.
        half: If True, load only ``no_negation``, ``one_scoped`` and
            ``one_not_scoped``; otherwise also load ``one_scoped_one_not_scoped``,
            ``two_scoped`` and ``two_not_scoped``.

    Returns:
        dict[str, pandas.DataFrame]: subset name -> train rows followed by test
        rows, re-indexed from 0.
    """
    subset_names = ["no_negation", "one_scoped", "one_not_scoped"]
    if not half:
        subset_names += ["one_scoped_one_not_scoped", "two_scoped", "two_not_scoped"]

    def _read_split(split, subset):
        return pd.read_csv(f"{data_dir}/{split}/{subset}.csv")

    return {
        subset: pd.concat(
            [_read_split("train", subset), _read_split("test", subset)],
            ignore_index=True,
        )
        for subset in subset_names
    }

def load_entailment_data(file_path="data/lookup/entailment_lookup_filtered.jsonl"):
    """Load the hyponym/hypernym entailment lookup from a JSON Lines file.

    Blank lines (e.g. a trailing newline at EOF) are skipped silently. Any other
    line that is not valid JSON is reported on stdout and skipped.

    Args:
        file_path: Path to a ``.jsonl`` file with one JSON value per line.

    Returns:
        list: The decoded entries, in file order.
    """
    entailment_data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            try:
                entailment_data.append(json.loads(stripped))
            except json.JSONDecodeError:
                print(f"Failed to parse line: {stripped}")
    return entailment_data

# Load data: the single-negation SCONE subsets plus the entailment lookup.
datasets_dict = load_scone_datasets(half=True)
entailment_data = load_entailment_data()

print(f"Loaded {len(datasets_dict)} datasets:")
for subset_name, subset_df in datasets_dict.items():
    print(f"  {subset_name}: {len(subset_df)} examples")
print(f"Loaded {len(entailment_data)} entailment entries")

Loaded 3 datasets:
  no_negation: 1202 examples
  one_scoped: 1202 examples
  one_not_scoped: 1202 examples
Loaded 499 entailment entries


## 3. Input/Output Processing Functions

In [6]:
def input_dumper(example):
    """Render an example as an entailment prompt.

    Strings are treated as already-rendered prompts and returned unchanged.
    Anything else is read with ``.get("PremiseSentence", "")`` and
    ``.get("HypothesisSentence", "")``, so missing keys render as empty text.
    """
    if not isinstance(example, str):
        premise = example.get("PremiseSentence", "")
        hypothesis = example.get("HypothesisSentence", "")
        return "".join([
            f"Sentence 1: {premise}\n",
            f"Sentence 2: {hypothesis}\n",
            "Does Sentence 1 entail Sentence 2? Please respond only with either 'entailment' or 'neutral'.\n",
            "Answer: ",
        ])
    return example

def output_dumper(setting):
    """Return the label stored under ``'FinalRelation'`` (KeyError if absent)."""
    final_relation = setting['FinalRelation']
    return final_relation

def checker(neural_output, causal_output, debug=False):
    """Decide whether the model's output matches the expected causal label.

    Args:
        neural_output: Model output; converted with ``str()`` before matching.
        causal_output: Expected label: a plain string, or a dict with a
            ``"raw_output"`` key. Anything else is converted with ``str()``.
        debug: If True, print the normalized label, the output and the verdict.

    Returns:
        bool: True iff the lower-cased, stripped expected label is non-empty and
        is a substring of the lower-cased, stripped model output.
    """
    if isinstance(causal_output, str):
        expected_label = causal_output
    elif isinstance(causal_output, dict) and "raw_output" in causal_output:
        expected_label = causal_output["raw_output"]
    else:
        expected_label = str(causal_output)

    # str() guards against a non-string raw_output (e.g. None), which would
    # otherwise crash on .lower().
    expected = str(expected_label).lower().strip()
    produced = str(neural_output).lower().strip()
    # An empty label is a substring of every string; without this guard such
    # examples would be scored as correct unconditionally.
    is_match = bool(expected) and expected in produced
    if debug:
        print(f"checker: expected={expected!r} output={produced!r} match={is_match}")
    return is_match

## 4. Model Pipeline Setup

In [7]:
def init_pipeline(model_name="meta-llama/Llama-3.1-8B-Instruct", max_new_tokens=10, dtype=None):
    """Load ``model_name`` into an ``LMPipeline`` on the GPU if available, else the CPU.

    Args:
        model_name: Model id passed to ``LMPipeline``.
        max_new_tokens: Generation budget passed to ``LMPipeline``.
        dtype: Torch dtype for the weights. ``None`` picks ``torch.float16`` on
            CUDA and ``torch.bfloat16`` on CPU.

    Returns:
        LMPipeline: the pipeline, with its tokenizer switched to right padding.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if dtype is None:
        # float16 kernels are missing or very slow on CPU in many PyTorch
        # builds; bfloat16 keeps the same memory footprint and is better
        # supported there.
        dtype = torch.float16 if device == "cuda" else torch.bfloat16
    pipe = LMPipeline(model_name, max_new_tokens=max_new_tokens, device=device, dtype=dtype)
    pipe.tokenizer.padding_side = "right"
    print(f"Loaded model on {pipe.model.device}")
    return pipe

# Initialize pipeline
pipeline = init_pipeline()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded model on cpu


## 5. Token Position Functions

In [None]:
# Word-boundary markers used by common tokenizers: SentencePiece prefixes a
# space with '▁'; byte-level BPE tokenizers (e.g. the Llama 3 family) use 'Ġ'
# for a space and 'Ċ' for a newline. '<0x0A>' is SentencePiece's byte token for '\n'.
_SPACE_MARKERS = ('▁', 'Ġ')
_NEWLINE_MARKERS = ('Ċ', '\n', '<0x0A>')
_PREDICATE_STOPWORDS = {"that", "with", "from", "this"}


def _clean_token(token):
    """Strip word-boundary markers from a token, then lower-case it."""
    # Strip before lowering: 'Ġ'.lower() is 'ġ', which would no longer match.
    for marker in _SPACE_MARKERS + _NEWLINE_MARKERS:
        token = token.replace(marker, '')
    return token.lower()


def _find_predicate_index(prompt, pipeline, sentence_number, fallback, caller, debug=False):
    """Return ``[j]`` for the first content-like token after "Sentence <sentence_number>".

    Tokenizes ``prompt`` with ``pipeline.tokenizer.tokenize``, finds the first
    token containing "sentence" whose next non-blank token contains
    ``sentence_number``, then scans at most 20 following tokens (stopping at
    a newline) for the first token longer than 3 characters that is not a
    stopword. If the marker is found but no such token is, returns the
    marker's position + 3, clamped to the last token. Returns ``[fallback]``
    if the marker is absent or tokenization fails.
    """
    try:
        tokens = pipeline.tokenizer.tokenize(prompt)
        for i, token in enumerate(tokens):
            if 'sentence' not in _clean_token(token):
                continue
            # Skip whitespace-only pieces: byte-level BPE may split " 1" into
            # a lone space token followed by "1".
            k = i + 1
            while k < len(tokens) and not _clean_token(tokens[k]):
                k += 1
            if k >= len(tokens) or sentence_number not in _clean_token(tokens[k]):
                continue
            search_start = k + 1
            search_end = min(search_start + 20, len(tokens))
            for j in range(search_start, search_end):
                if any(marker in tokens[j] for marker in _NEWLINE_MARKERS):
                    break  # end of this sentence's line
                clean_token = _clean_token(tokens[j])
                if len(clean_token) > 3 and clean_token not in _PREDICATE_STOPWORDS:
                    return [j]
            return [min(search_start + 3, len(tokens) - 1)]
        return [fallback]
    except Exception as e:
        if debug:
            print(f"Error in {caller}: {e}")
        return [fallback]


def get_predicate_premise(prompt, pipeline, debug=False):
    """Return ``[index]`` of a heuristic predicate token in the premise ("Sentence 1").

    Falls back to ``[5]`` when the "Sentence 1" marker is not found or
    tokenization fails.
    NOTE(review): indices come from ``tokenizer.tokenize``, which normally does
    not insert special tokens such as BOS; confirm they line up with the ids
    the pipeline actually feeds the model.
    """
    return _find_predicate_index(prompt, pipeline, '1', 5, "get_predicate_premise", debug)


def get_predicate_hypothesis(prompt, pipeline, debug=False):
    """Return ``[index]`` of a heuristic predicate token in the hypothesis ("Sentence 2").

    Falls back to ``[15]`` when the "Sentence 2" marker is not found or
    tokenization fails. See ``get_predicate_premise`` for the BOS caveat.
    """
    return _find_predicate_index(prompt, pipeline, '2', 15, "get_predicate_hypothesis", debug)

def get_negation_index(prompt, pipeline, debug=False):
    """Return ``[index]`` of the first negation token in ``prompt``.

    Each token is cleaned (word-boundary markers '▁', 'Ġ', 'Ċ' removed,
    lower-cased, curly apostrophe normalized). It counts as a negation when
    the cleaned form is one of the negation words below, equals "'t", or ends
    in "n't". The "'t" case covers tokenizers that split "doesn't" into
    "doesn" + "'t"; that split is presumed for Llama 3 — TODO confirm.
    Whole-token matching avoids false hits such as "no" inside "know" or
    "another". Falls back to ``[3]`` if nothing matches or tokenization fails.
    """
    negation_words = {"not", "no", "never", "none", "nothing", "nobody", "cannot", "n't"}
    try:
        tokens = pipeline.tokenizer.tokenize(prompt)
        for i, token in enumerate(tokens):
            clean_token = token
            # Strip markers before lowering ('Ġ'.lower() is 'ġ').
            for marker in ('▁', 'Ġ', 'Ċ'):
                clean_token = clean_token.replace(marker, '')
            clean_token = clean_token.lower().replace('\u2019', "'")
            if clean_token in negation_words or clean_token == "'t" or clean_token.endswith("n't"):
                return [i]
        return [3]  # Fallback if no negation found
    except Exception as e:
        if debug:
            print(f"Error in get_negation_index: {e}")
        return [3]

## 6. Causal Model Definition

In [None]:
# Mechanism functions
def create_raw_input(premise_sentence, hypothesis_sentence):
    """Build the entailment prompt fed to the model from a sentence pair.

    Returns ``None`` if either sentence is ``None`` (the sentence mechanisms
    return ``None`` when no SCONE row matches the predicates/scope).
    NOTE(review): the instruction wording differs from the prompt built by
    ``input_dumper`` in this notebook; confirm which one the experiment should use.
    """
    if premise_sentence is None or hypothesis_sentence is None:
        return None

    # Each line ends exactly at "\n"; there is no trailing space before the
    # newline and no leading space before "Answer:".
    prompt = (
        f"Sentence 1: {premise_sentence}\n"
        f"Sentence 2: {hypothesis_sentence}\n"
        "Does Sentence 1 entail Sentence 2? Answer only 'entailment' or 'neutral'.\n"
        "Answer: "
    )
    return prompt

def create_raw_output(final_relation):
    """Identity mechanism: the expected model output is the FinalRelation label itself."""
    raw_output = final_relation
    return raw_output

def compute_base_entailment(pred1, pred2, lookup_data=None):
    """Return "entailment" if (pred1, pred2) is a positive hyponym/hypernym pair.

    Args:
        pred1: Candidate hyponym (premise predicate).
        pred2: Candidate hypernym (hypothesis predicate).
        lookup_data: Iterable of lookup entries. Defaults to the notebook's
            global ``entailment_data``. Only dict entries with ``label == 1``
            that have both ``hyponym`` and ``hypernym`` keys count.

    Returns:
        str: "entailment" for a listed positive pair, otherwise "neutral".
    """
    entries = entailment_data if lookup_data is None else lookup_data
    # Short-circuit scan instead of rebuilding a full dict on every call; this
    # mechanism is evaluated on every forward pass of the causal model.
    is_positive_pair = any(
        isinstance(entry, dict)
        and entry.get('label') == 1
        and 'hyponym' in entry
        and 'hypernym' in entry
        and entry['hyponym'] == pred1
        and entry['hypernym'] == pred2
        for entry in entries
    )
    return "entailment" if is_positive_pair else "neutral"

def compute_final_relation(base_entailment, scope):
    """Combine lexical entailment with negation scope into the final label.

    Only the odd scope ``"one_scoped"`` flips the relation: "entailment"
    becomes "neutral" and any other value becomes "entailment". Every other
    scope, including unknown ones, passes ``base_entailment`` through unchanged.
    """
    if scope != "one_scoped":
        return base_entailment
    return "neutral" if base_entailment == "entailment" else "entailment"

def _select_scone_row(pred1, pred2, scope, datasets_dict):
    """Pick one row of ``datasets_dict[scope]`` whose lexical items are (pred1, pred2).

    The choice is pseudo-random but deterministic per (pred1, pred2, scope),
    so the premise and hypothesis mechanisms, which are evaluated
    independently, always read their sentences from the same row. Returns
    ``None`` if the scope is not loaded or no row matches.
    """
    if scope not in datasets_dict:
        return None
    df = datasets_dict[scope]
    matches = df[(df['sentence1_lex'] == pred1) & (df['sentence2_lex'] == pred2)]
    if matches.empty:
        return None
    # crc32 (unlike built-in hash() on str) is stable across interpreter sessions.
    seed = zlib.crc32(f"{pred1}\x1f{pred2}\x1f{scope}".encode("utf-8"))
    return matches.sample(n=1, random_state=seed).iloc[0]


def _sentence_column(df, scope, base_column):
    """Return ``base_column`` for "one_scoped", else its ``_edited`` variant when ``df`` has one."""
    if scope == "one_scoped":
        return base_column
    edited_column = f"{base_column}_edited"
    return edited_column if edited_column in df.columns else base_column


def get_premise_sentence(pred1, pred2, scope, datasets_dict):
    """Return the premise (sentence 1) of a SCONE row matching the predicates and scope.

    Returns ``None`` if ``scope`` is not loaded or no row has
    ``sentence1_lex == pred1`` and ``sentence2_lex == pred2``. The row is the
    same one ``get_hypothesis_sentence`` uses for these arguments.
    """
    row = _select_scone_row(pred1, pred2, scope, datasets_dict)
    if row is None:
        return None
    return row[_sentence_column(datasets_dict[scope], scope, 'sentence1')]


def get_hypothesis_sentence(pred1, pred2, scope, datasets_dict):
    """Return the hypothesis (sentence 2) of a SCONE row matching the predicates and scope.

    Returns ``None`` if ``scope`` is not loaded or no row has
    ``sentence1_lex == pred1`` and ``sentence2_lex == pred2``. The row is the
    same one ``get_premise_sentence`` uses for these arguments.
    """
    row = _select_scone_row(pred1, pred2, scope, datasets_dict)
    if row is None:
        return None
    return row[_sentence_column(datasets_dict[scope], scope, 'sentence2')]


In [None]:
class SconeCausalModel(CausalModel):
    """SCONE entailment causal model built on the project's ``CausalModel``.

    Besides the generic constructor arguments it keeps the SCONE subsets and
    the entailment lookup entries. It overrides ``label_counterfactual_data``
    so that labels come from intervening on each base input's own exogenous
    variables.
    """

    def __init__(self, datasets_dict, entailment_data, variables, values, parents, mechanisms):
        super().__init__(variables, values, parents, mechanisms, id="SCONE_entailment")
        self.datasets_dict = datasets_dict
        self.entailment_data = entailment_data
        # Scratch state cleared at the start of every run_forward call.
        self.current_sentence_pair = None
        self.current_gold_label = None

        # (hyponym, hypernym) -> "entailment" for every positive (label == 1) dict entry.
        self.entailment_lookup = {
            (entry['hyponym'], entry['hypernym']): "entailment"
            for entry in entailment_data
            if isinstance(entry, dict) and entry.get('label') == 1
        }

    def run_forward(self, intervention=None, debug=False):
        """Clear per-pass scratch state, then delegate to ``CausalModel.run_forward``.

        ``debug`` is accepted for interface compatibility but not forwarded.
        """
        self.current_sentence_pair = None
        self.current_gold_label = None
        return super().run_forward(intervention=intervention)

    def label_counterfactual_data(self, dataset, target_variables):
        """Label every example by intervening on its base input.

        For each example, the intervention starts from the base input's
        Predicate1, Predicate2 and Scope. Each variable in ``target_variables``
        is then overwritten with the first counterfactual input's value (when
        that input has it), and the model is re-run. The resulting
        ``raw_output`` becomes the label.

        Returns:
            CounterfactualDataset with columns ``input``, ``counterfactual_inputs``,
            ``label`` and ``setting`` (the intervened target values plus
            ``raw_output``), and id ``"<dataset id>_<targets joined by _>_labeled"``.
        """
        columns = {"input": [], "counterfactual_inputs": [], "label": [], "setting": []}

        for example in dataset:
            base_input = example["input"]
            cf_inputs = example["counterfactual_inputs"]
            # When a list is given, only its first counterfactual is used.
            source = cf_inputs[0] if isinstance(cf_inputs, list) else cf_inputs

            # Keep the base input's exogenous variables...
            intervention = {key: base_input[key] for key in ("Predicate1", "Predicate2", "Scope")}
            # ...and swap in the counterfactual's value for each target variable.
            intervention.update({var: source[var] for var in target_variables if var in source})

            expected_output = self.run_forward(intervention=intervention)["raw_output"]

            setting = {var: intervention[var] for var in target_variables if var in intervention}
            setting["raw_output"] = expected_output

            columns["input"].append(base_input)
            columns["counterfactual_inputs"].append(cf_inputs)
            columns["label"].append(expected_output)
            columns["setting"].append(setting)

        dataset_id = getattr(dataset, 'id', 'unknown')
        target_vars_str = "_".join(target_variables)
        return CounterfactualDataset.from_dict(columns, id=f"{dataset_id}_{target_vars_str}_labeled")

## 7. Counterfactual Generation Functions

In [None]:
def _sample_scone_input(datasets_dict, scopes):
    """Draw a random (scope, row index, exogenous-input dict) from ``datasets_dict``."""
    scope = random.choice(scopes)
    df = datasets_dict[scope]
    idx = random.randint(0, len(df) - 1)
    row = df.iloc[idx]
    exogenous = {
        "Scope": scope,
        "Predicate1": row.get("sentence1_lex", None),
        "Predicate2": row.get("sentence2_lex", None),
    }
    return scope, idx, exogenous


def create_base_entailment_counterfactual(datasets_dict, scone_model, debug=False,
                                          max_attempts=50, max_fallback_attempts=10):
    """Pair a random SCONE example with one whose BaseEntailment is the opposite.

    Args:
        datasets_dict: scope name -> DataFrame with ``sentence1_lex`` /
            ``sentence2_lex`` columns.
        scone_model: Object whose ``run_forward(intervention=...)`` returns a
            dict containing ``"BaseEntailment"``.
        debug: If True, report when no opposite-entailment row was found.
        max_attempts: Random draws tried when looking for the opposite entailment.
        max_fallback_attempts: Draws tried to find a row other than the original
            once the search above fails.

    Returns:
        dict: ``{"input": <forward result>, "counterfactual_inputs": [<forward result>]}``.
        If no opposite-entailment row is found, the counterfactual is a random
        other row (sharing the base entailment), or the original inputs if
        every fallback draw hit the original row.

    Raises:
        ValueError: if none of the three single-negation scopes is loaded.
    """
    allowed_scopes = [s for s in ['no_negation', 'one_scoped', 'one_not_scoped']
                      if s in datasets_dict]
    if not allowed_scopes:
        raise ValueError("None of the allowed scopes are available in datasets_dict.")

    orig_scope, orig_idx, original_input = _sample_scone_input(datasets_dict, allowed_scopes)
    original_result = scone_model.run_forward(intervention=original_input)
    original_base_entailment = original_result["BaseEntailment"]
    target_base_entailment = "neutral" if original_base_entailment == "entailment" else "entailment"

    for _ in range(max_attempts):
        _, _, candidate_input = _sample_scone_input(datasets_dict, allowed_scopes)
        candidate_result = scone_model.run_forward(intervention=candidate_input)
        if candidate_result["BaseEntailment"] == target_base_entailment:
            return {"input": original_result, "counterfactual_inputs": [candidate_result]}

    if debug:
        print(f"create_base_entailment_counterfactual: no '{target_base_entailment}' "
              f"counterfactual after {max_attempts} attempts; using a fallback row")

    fallback_input = None
    for _ in range(max_fallback_attempts):
        cf_scope, cf_idx, candidate_input = _sample_scone_input(datasets_dict, allowed_scopes)
        if not (cf_scope == orig_scope and cf_idx == orig_idx):
            fallback_input = candidate_input
            break
    if fallback_input is None:
        # Every draw hit the original row (e.g. a single-row dataset). Reuse the
        # original exogenous values instead of failing on an unbound variable.
        fallback_input = dict(original_input)

    counterfactual_result = scone_model.run_forward(intervention=fallback_input)
    return {"input": original_result, "counterfactual_inputs": [counterfactual_result]}

def create_scope_counterfactual(datasets_dict, scone_model, debug=False):
    """Pair a random SCONE example with one whose negation-scope parity differs.

    Scope mapping:
      - Even-parity scopes (``no_negation``, ``one_not_scoped``) map to the odd
        scope ``one_scoped``.
      - ``one_scoped`` maps to a random even scope.
      - If the opposite parity is not loaded, any other loaded scope is used.
      - If only one scope is loaded, the same scope is used, with up to 10
        redraws to avoid the original row.

    The counterfactual's predicates come from an independently sampled row.
    ``debug`` is accepted but unused.

    Returns:
        dict: ``{"input": <forward result for the original>,
        "counterfactual_inputs": [<forward result for the counterfactual>]}``.

    Raises:
        ValueError: if none of the three single-negation scopes is loaded.
    """
    candidate_scopes = ['no_negation', 'one_scoped', 'one_not_scoped']
    allowed_scopes = list(filter(lambda s: s in datasets_dict, candidate_scopes))
    if len(allowed_scopes) == 0:
        raise ValueError("None of the allowed scopes are available in datasets_dict.")

    def lexical_pair(row):
        return row.get("sentence1_lex", None), row.get("sentence2_lex", None)

    orig_scope = random.choice(allowed_scopes)
    orig_frame = datasets_dict[orig_scope]
    orig_idx = random.randint(0, len(orig_frame) - 1)
    orig_pred1, orig_pred2 = lexical_pair(orig_frame.iloc[orig_idx])

    even_scopes = [s for s in allowed_scopes if s in ('no_negation', 'one_not_scoped')]
    odd_scopes = [s for s in allowed_scopes if s == 'one_scoped']

    if odd_scopes and orig_scope in even_scopes:
        new_scope = odd_scopes[0]
    elif even_scopes and orig_scope in odd_scopes:
        new_scope = random.choice(even_scopes)
    else:
        other_scopes = [s for s in allowed_scopes if s != orig_scope]
        new_scope = random.choice(other_scopes) if other_scopes else orig_scope

    new_frame = datasets_dict[new_scope]
    new_idx = random.randint(0, len(new_frame) - 1)
    if new_scope == orig_scope:
        # At most 10 redraws to avoid pairing a row with itself.
        for _ in range(10):
            if new_idx != orig_idx or len(new_frame) <= 1:
                break
            new_idx = random.randint(0, len(new_frame) - 1)

    new_pred1, new_pred2 = lexical_pair(new_frame.iloc[new_idx])

    original_result = scone_model.run_forward(
        intervention={"Scope": orig_scope, "Predicate1": orig_pred1, "Predicate2": orig_pred2})
    counterfactual_result = scone_model.run_forward(
        intervention={"Scope": new_scope, "Predicate1": new_pred1, "Predicate2": new_pred2})
    return {"input": original_result, "counterfactual_inputs": [counterfactual_result]}

def build_counterfactual_datasets(datasets_dict, scone_model, size=100):
    """Sample ``size`` base/counterfactual pairs of each kind.

    Returns:
        dict with two ``CounterfactualDataset`` entries, built in this order:
        ``"base_entailment_counterfactuals"`` (from
        ``create_base_entailment_counterfactual``) and ``"scope_counterfactuals"``
        (from ``create_scope_counterfactual``).
    """
    samplers = {
        "base_entailment_counterfactuals": lambda: create_base_entailment_counterfactual(datasets_dict, scone_model),
        "scope_counterfactuals": lambda: create_scope_counterfactual(datasets_dict, scone_model),
    }
    return {name: CounterfactualDataset.from_sampler(size, sampler) for name, sampler in samplers.items()}

## 8. Causal Model Setup

In [None]:
# Create valid predicate pairs from entailment data: every positive
# (label == 1) hyponym/hypernym dict entry, the same rule compute_base_entailment
# and SconeCausalModel apply.
valid_pairs = [
    (entry['hyponym'], entry['hypernym'])
    for entry in entailment_data
    if isinstance(entry, dict) and entry.get('label') == 1
]

print(f"Found {len(valid_pairs)} valid predicate pairs")

# Distinct predicates in first-seen order. They are computed once here rather
# than inside the sampling lambdas, and de-duplicated so each predicate appears
# once in its value domain and is sampled uniformly.
# NOTE(review): Predicate1 and Predicate2 are sampled independently, so an
# un-intervened forward pass rarely hits a pair present in SCONE and the
# sentence mechanisms may then return None.
premise_predicates = list(dict.fromkeys(pair[0] for pair in valid_pairs))
hypothesis_predicates = list(dict.fromkeys(pair[1] for pair in valid_pairs))

# Define causal model components
variables = ["Predicate1", "Predicate2", "Scope", "BaseEntailment",
             "FinalRelation", "PremiseSentence", "HypothesisSentence",
             "raw_input", "raw_output"]

scope_types = ["no_negation", "one_scoped", "one_not_scoped"]

values = {
    "Predicate1": premise_predicates,
    "Predicate2": hypothesis_predicates,
    "Scope": scope_types,
    "BaseEntailment": ["entailment", "neutral"],
    "FinalRelation": ["entailment", "neutral"],
    "PremiseSentence": None,
    "HypothesisSentence": None,
    "raw_input": None,
    "raw_output": ["entailment", "neutral"]
}

parents = {
    "Predicate1": [],
    "Predicate2": [],
    "Scope": [],
    "BaseEntailment": ["Predicate1", "Predicate2"],
    "FinalRelation": ["BaseEntailment", "Scope"],
    "PremiseSentence": ["Predicate1", "Predicate2", "Scope"],
    "HypothesisSentence": ["Predicate1", "Predicate2", "Scope"],
    "raw_input": ["PremiseSentence", "HypothesisSentence"],
    "raw_output": ["FinalRelation"]
}

mechanisms = {
    "Predicate1": lambda: random.choice(premise_predicates),
    "Predicate2": lambda: random.choice(hypothesis_predicates),
    "Scope": lambda: random.choice(scope_types),
    "BaseEntailment": compute_base_entailment,
    "FinalRelation": compute_final_relation,
    "PremiseSentence": lambda p1, p2, scope: get_premise_sentence(p1, p2, scope, datasets_dict),
    "HypothesisSentence": lambda p1, p2, scope: get_hypothesis_sentence(p1, p2, scope, datasets_dict),
    "raw_input": create_raw_input,
    "raw_output": create_raw_output
}

# Create SCONE causal model
scone_model = SconeCausalModel(datasets_dict, entailment_data, variables, values, parents, mechanisms)

print("Causal model created successfully!")

## 9. Dataset Filtering and Preparation

In [None]:
def run_filter_experiment(pipeline, scone_model, datasets, batch_size=1, verbose=True, debug=True):
    """Keep only the examples the model answers correctly, via ``SconeFilterExperiment``.

    Args:
        pipeline: The loaded ``LMPipeline``.
        scone_model: Causal model passed to ``SconeFilterExperiment``.
        datasets: Mapping of dataset name -> counterfactual dataset.
        batch_size: Examples per batch, passed to ``filter`` (default 1).
        verbose: Passed to ``SconeFilterExperiment.filter``.
        debug: Passed to the ``SconeFilterExperiment`` constructor.

    Returns:
        The result of ``SconeFilterExperiment.filter``, which the caller treats
        as a name -> dataset mapping.
    """
    print("\nFiltering datasets based on model performance...")
    exp = SconeFilterExperiment(pipeline, scone_model, checker, debug=debug)
    return exp.filter(datasets, verbose=verbose, batch_size=batch_size)

def save_filtered_datasets(filtered_datasets, save_dir="data/counterfactuals_filtered"):
    """Write each dataset to ``<save_dir>/<name>.jsonl``, one JSON object per example.

    Args:
        filtered_datasets: Mapping of dataset name -> sized iterable of
            JSON-serializable examples.
        save_dir: Output directory, created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    for dataset_name, dataset in filtered_datasets.items():
        dataset_path = os.path.join(save_dir, f"{dataset_name}.jsonl")
        with open(dataset_path, 'w', encoding='utf-8') as f:
            for example in dataset:
                f.write(json.dumps(example))
                f.write('\n')
        print(f"Saved {len(dataset)} examples to {dataset_path}")

# Reuse cached filtered datasets when present; otherwise generate, filter and
# cache them. Cached files are rebuilt as CounterfactualDataset objects so the
# downstream cells receive dataset objects on both paths.
# NOTE(review): assumes SconeFilterExperiment.filter returns CounterfactualDataset
# values — confirm.
filtered_datasets_path = "data/counterfactuals_filtered"
cached_files = (
    sorted(name for name in os.listdir(filtered_datasets_path) if name.endswith('.jsonl'))
    if os.path.isdir(filtered_datasets_path)
    else []
)

if not cached_files:
    print("No filtered datasets found - generating new counterfactual datasets...")
    counterfactual_datasets = build_counterfactual_datasets(datasets_dict, scone_model, size=200)
    filtered_datasets = run_filter_experiment(pipeline, scone_model, counterfactual_datasets)
    save_filtered_datasets(filtered_datasets)
else:
    print("Loading existing filtered datasets...")
    filtered_datasets = {}
    for filename in cached_files:
        dataset_name = filename[:-len('.jsonl')]
        dataset_path = os.path.join(filtered_datasets_path, filename)
        with open(dataset_path, 'r', encoding='utf-8') as f:
            examples = [json.loads(line) for line in f if line.strip()]
        # Row-oriented JSON lines -> column-oriented dict for from_dict.
        columns = (
            {key: [example[key] for example in examples] for key in examples[0]}
            if examples
            else {"input": [], "counterfactual_inputs": []}
        )
        filtered_datasets[dataset_name] = CounterfactualDataset.from_dict(columns, id=dataset_name)
        print(f"Loaded {len(examples)} examples from {filename}")

print(f"\nFiltered datasets ready: {list(filtered_datasets.keys())}")

## 10. Token Position Setup

In [None]:
# One TokenPosition per intervention site. Each lambda adapts an index function
# with signature (prompt, pipeline) -> [token_index] into a one-argument callable
# bound to the loaded pipeline.
premise_position = TokenPosition(lambda prompt: get_predicate_premise(prompt, pipeline), pipeline, id="premise_predicate")
hypothesis_position = TokenPosition(lambda prompt: get_predicate_hypothesis(prompt, pipeline), pipeline, id="hypothesis_predicate")
negation_position = TokenPosition(lambda prompt: get_negation_index(prompt, pipeline), pipeline, id="negation")
last_token_position = TokenPosition(lambda prompt: get_last_token_index(prompt, pipeline), pipeline, id="last_token")

token_positions = [premise_position, hypothesis_position, negation_position, last_token_position]

print(f"Created {len(token_positions)} token position functions")

## 11. Experiment Configuration and Execution

In [None]:
# Experiment configuration
start_layer = 22
end_layer = pipeline.get_num_layers()  # exclusive bound: layers start_layer .. end_layer - 1 are patched
if not 0 <= start_layer < end_layer:
    raise ValueError(f"start_layer={start_layer} must lie in [0, {end_layer}) for this model")
target_variables_list = [["FinalRelation"], ["BaseEntailment"], ["Scope"]]
results_dir = "SCONE_demo_results"

config = {
    "batch_size": 16,
    "evaluation_batch_size": 128,
    "training_epoch": 32,
    "n_features": 16,
}

print("Experiment configuration:")
# Report the inclusive range actually used by range(start_layer, end_layer).
print(f"  Layers: {start_layer} to {end_layer - 1}")
print(f"  Target variables: {target_variables_list}")
print(f"  Results directory: {results_dir}")
print(f"  Config: {config}")

In [None]:
# Residual-stream patching over layers [start_layer, end_layer) at each token
# position defined above; `checker` is the output-matching function passed in.
experiment = PatchResidualStream(
    pipeline,
    scone_model,
    [layer for layer in range(start_layer, end_layer)],
    token_positions,
    checker,
    config=config,
)

print("Experiment object created. Starting intervention training...")

In [None]:
# Train DAS interventions.
# NOTE(review): only the Scope variable is trained here, while the next cell
# evaluates every entry of target_variables_list — confirm this is intended.
print("Training interventions for Scope variable...")
experiment.train_interventions(filtered_datasets, ["Scope"], method="DAS", verbose=False)

print("Training completed!")

In [None]:
# Evaluate interventions for every target variable; the results are also
# saved as JSON under results_dir.
print("Performing interventions...")
raw_results = experiment.perform_interventions(
    filtered_datasets,
    verbose=True,
    target_variables_list=target_variables_list,
    save_dir=results_dir,
)

print(f"Interventions completed! Results saved to {results_dir}")

## 12. Results Analysis

In [None]:
# Display a short summary of the run, then peek at the results for the first
# target variable (its type, plus shape or keys when available).
summary_lines = [
    "\n=== EXPERIMENT RESULTS SUMMARY ===",
    f"Results saved to: {results_dir}",
    f"Target variables analyzed: {target_variables_list}",
    f"Layers analyzed: {start_layer} to {end_layer-1}",
    f"Token positions: {[pos.id for pos in token_positions]}",
]
for summary_line in summary_lines:
    print(summary_line)

if raw_results:
    print(f"\nRaw results keys: {list(raw_results.keys())}")
    first_target = target_variables_list[0]
    target_key = "_".join(first_target)

    if target_key in raw_results:
        sample_result = raw_results[target_key]
        print(f"\nSample results for {first_target}:")
        print(f"  Type: {type(sample_result)}")
        if hasattr(sample_result, 'shape'):
            print(f"  Shape: {sample_result.shape}")
        elif isinstance(sample_result, dict):
            print(f"  Keys: {list(sample_result.keys())}")

print("\nExperiment completed successfully!")

## Summary

This notebook implements a causal abstraction experiment on the SCONE dataset to understand how language models process logical entailment and negation scope. The experiment:

1. Loads SCONE data with different negation scoping patterns
2. Defines a causal model with variables for predicates, scope, base entailment, and final relation
3. Generates counterfactual datasets by resampling inputs so that either the base entailment or the negation scope differs from the original example
4. Filters datasets to ensure the model can solve the base task
5. Trains interventions using Distributed Alignment Search (DAS) (currently for the Scope variable only)
6. Performs causal interventions on neural representations at different layers
7. Saves the results as JSON in `SCONE_demo_results` for analyzing which layers and token positions encode which semantic variables