# Natural Language Reinforcement Learning using Goodfire SDK (WIP)
Following the steps outlined in the paper: https://docs.google.com/document/d/1GbzS4d0Ml9BMvK97i_TOEOwnGg-Zg270W2WNN9KHrG4/edit?tab=t.0

Rough steps (these do not yet map one-to-one onto the cells below):
1. Query Selection & Answer Verification
2. COT Decomposition & Critical Token Identification
3. Feature Attribution & Analysis
4. Feature Steering
5. Robustness Testing & Iteration

# Cell 1: Setup and Initialization
Set up the required packages, credentials, and base model.
- Installs required packages
- Loads API keys
- Initializes Goodfire client and model variant

In [None]:
# Install required packages
!pip install goodfire --quiet
!pip install datasets --quiet
!pip install tqdm --quiet

from google.colab import userdata
import goodfire
import asyncio
from tqdm.auto import tqdm
import pandas as pd
from typing import List, Dict, Any
import json
from typing import List
from dataclasses import dataclass

@dataclass
class Example:
    """Plain record holding a query, a response and a correctness flag."""

    query: str  # query text
    response: str  # response text associated with the query
    is_correct: bool  # whether the response is considered correct
# Add your Goodfire API Key to your Colab secrets
# (it is read here from google.colab.userdata under the name 'GOODFIRE_API_KEY').
GOODFIRE_API_KEY = userdata.get('GOODFIRE_API_KEY')


# Initialize SDK clients: one synchronous and one asynchronous client,
# both constructed from the same API key.
client = goodfire.Client(GOODFIRE_API_KEY)
async_client = goodfire.AsyncClient(GOODFIRE_API_KEY)

# Instantiate a model variant
# (the `variant` object is what the analyzer cells pass as `model=` to the SDK).
variant = goodfire.Variant("meta-llama/Meta-Llama-3.1-8B-Instruct")

# for my reference: Which is bigger, 9.9 or 9.11?
# https://platform.goodfire.ai/chat/ab5cc226-aa37-404c-87fd-4fb7949944c7
# https://platform.goodfire.ai/chat/0a670b58-0b76-44c1-a7cb-8356347d83c9

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.6/40.6 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gensim 4.3.3 requires scipy<1.14.0,>=1.7.0, but you have scipy 1.15.0 which is incompatible.[0m[31m


# Cell 2: Data Structures and Core Classes
Core data structures and utility classes.
- PuzzleData: Represents puzzle structure
- Solution: Holds solution attempts and analysis
- DatasetManager: Handles dataset operations
- SolverLogger: Manages logging and debugging

In [None]:
import json
from dataclasses import dataclass
from typing import List, Dict, Optional
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm

@dataclass
class Response:
    """One graded model response parsed from a results JSONL record."""

    question: str  # record["doc"]["question"]
    answer: str  # record["doc"]["answer"]
    response: str  # first generated text in record["resps"], or "" if none
    is_correct: bool  # True when record["exact_match"] equals 1.0
    target: str  # record["target"]
    doc_id: int  # record["doc_id"]; key used to pair the two runs


class GSM8KAnalyzer:
    """Contrast zero-shot and few-shot responses using Goodfire features.

    The analyzer loads two JSONL result files, pairs the responses by
    ``doc_id``, and inspects feature activations on the pairs where the
    zero-shot answer was wrong but the few-shot answer was right.
    """

    def __init__(self, client, variant):
        """Store the SDK client and model variant; start with no responses."""
        self.client = client
        self.variant = variant
        self.zeroshot_responses: List[Response] = []
        self.fewshot_responses: List[Response] = []

    @staticmethod
    def _as_conversation(resp):
        """Return the user/assistant message list for one response."""
        return [
            {"role": "user", "content": resp.question},
            {"role": "assistant", "content": resp.response},
        ]

    def load_responses(self, zeroshot_path: str, fewshot_path: str):
        """Load responses from JSONL files.

        Args:
            zeroshot_path: Path of the zero-shot results file.
            fewshot_path: Path of the few-shot results file.

        Each non-blank line must be a JSON object with ``doc`` (holding
        ``question`` and ``answer``), ``target`` and ``doc_id`` keys. A missing
        or empty ``resps`` entry yields an empty response string, and a missing
        ``exact_match`` is treated as incorrect.
        """
        def load_file(path: str) -> List[Response]:
            responses = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    # Extract the actual model response from the nested
                    # structure; guard both `[]` and `[[]]`.
                    resps = data.get("resps")
                    model_response = resps[0][0] if resps and resps[0] else ""
                    responses.append(Response(
                        question=data["doc"]["question"],
                        answer=data["doc"]["answer"],
                        response=model_response,
                        is_correct=data.get("exact_match", 0) == 1.0,
                        target=data["target"],
                        doc_id=data["doc_id"]
                    ))
            return responses

        self.zeroshot_responses = load_file(zeroshot_path)
        self.fewshot_responses = load_file(fewshot_path)

        print(f"Loaded {len(self.zeroshot_responses)} zero-shot responses")
        print(f"Loaded {len(self.fewshot_responses)} few-shot responses")

    def find_matching_pairs(self) -> List[tuple[Response, Response]]:
        """Find pairs of responses to the same question.

        Returns:
            ``(zero_shot, few_shot)`` tuples for every ``doc_id`` present in
            both runs, ordered by ``doc_id`` so results are reproducible.
        """
        zero_dict = {r.doc_id: r for r in self.zeroshot_responses}
        few_dict = {r.doc_id: r for r in self.fewshot_responses}

        shared_ids = sorted(zero_dict.keys() & few_dict.keys())
        return [(zero_dict[doc_id], few_dict[doc_id]) for doc_id in shared_ids]

    async def analyze_feature_differences(self, max_pairs: int = 50):
        """Analyze feature differences between zero-shot and few-shot responses.

        Only pairs where zero-shot was incorrect and few-shot was correct are
        analyzed, and at most ``max_pairs`` of them.

        Note: this coroutine performs only synchronous client calls; it is
        kept ``async`` so existing ``await`` call sites continue to work.

        Returns:
            A DataFrame with one row per (pair, feature) holding the zero-shot
            activation, the few-shot activation and their difference. The
            frame is empty when there is no improved pair to analyze.
        """
        # Filter first, then cap: capping before filtering could discard
        # every improved pair even though some exist.
        improved_pairs = [
            (zero, few) for zero, few in self.find_matching_pairs()
            if not zero.is_correct and few.is_correct
        ][:max_pairs]
        print(f"\nFound {len(improved_pairs)} pairs where few-shot improved accuracy")

        if not improved_pairs:
            print("No improved pairs found to analyze")
            return pd.DataFrame()

        # Format for contrast analysis
        improved_zero = [self._as_conversation(zero) for zero, _ in improved_pairs]
        improved_few = [self._as_conversation(few) for _, few in improved_pairs]

        print("\nRunning contrastive analysis...")
        # Get features that distinguish incorrect zero-shot from correct few-shot
        zero_features, few_features = self.client.features.contrast(
            dataset_1=improved_zero,
            dataset_2=improved_few,
            model=self.variant,
            top_k=20
        )

        # Rerank few-shot features to focus on mathematical reasoning
        few_features = self.client.features.rerank(
            features=few_features,
            query="step by step mathematical reasoning and problem solving",
            model=self.variant,
            top_k=10
        )

        print("\nAnalyzing feature activations...")
        activation_data = []
        feature_count = len(few_features)

        for i, (zero, few) in enumerate(improved_pairs):
            zero_context = self.client.features.inspect(
                messages=improved_zero[i],
                model=self.variant,
                features=few_features
            )
            few_context = self.client.features.inspect(
                messages=improved_few[i],
                model=self.variant,
                features=few_features
            )

            # `top()` ranks activations per context, so the n-th entry of each
            # list may be a different feature. Index by label so both
            # activations in a row refer to the same feature.
            zero_acts = {act.feature.label: act.activation
                         for act in zero_context.top(k=feature_count)}
            few_acts = {act.feature.label: act.activation
                        for act in few_context.top(k=feature_count)}

            for feature in few_features:
                label = feature.label
                zero_value = zero_acts.get(label, 0)
                few_value = few_acts.get(label, 0)
                activation_data.append({
                    'pair_id': i,
                    'question': zero.question,
                    'feature': label,
                    'zero_activation': zero_value,
                    'few_activation': few_value,
                    'activation_diff': few_value - zero_value
                })

        return pd.DataFrame(activation_data)

    def summarize_differences(self, df: pd.DataFrame):
        """Generate summary statistics and key findings.

        Args:
            df: Frame produced by ``analyze_feature_differences``.

        Returns:
            A dict with ``avg_differences``, ``top_features`` and
            ``question_differences``; an empty dict when ``df`` is empty.
        """
        if df.empty:
            print("No data to summarize")
            return {}

        print("\n=== Feature Activation Analysis ===")

        # Average activation differences by feature
        avg_diffs = df.groupby('feature')['activation_diff'].agg(['mean', 'std']).round(3)
        print("\nAverage activation differences (few-shot - zero-shot):")
        print(avg_diffs)

        # Features with largest differences (by absolute mean)
        print("\nTop features with largest activation differences:")
        top_features = avg_diffs.sort_values('mean', key=abs, ascending=False).head()
        print(top_features)

        # Per-question analysis
        print("\nQuestions with largest feature differences:")
        q_diffs = df.groupby(['question', 'feature'])['activation_diff'].mean()
        print(q_diffs.sort_values(key=abs, ascending=False).head(10))

        return {
            'avg_differences': avg_diffs,
            'top_features': top_features,
            'question_differences': q_diffs
        }

# Example usage:
# analyzer = GSM8KAnalyzer(client, variant)
# analyzer.load_responses('zeroshot.jsonl', 'fewshot.jsonl')
# df = await analyzer.analyze_feature_differences()
# results = analyzer.summarize_differences(df)

# Cell 3: Solution Generation
Solution generation and reasoning components.
- SolutionGenerator: Handles generating solutions and COT reasoning
- Includes prompt formatting and decomposition logic
- Manages both naive and correct solution generation

In [None]:
# Mount Google Drive
# (the analyzer cells below read their JSONL files from under /content/drive).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import json
from dataclasses import dataclass
from typing import List, Dict, Optional
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
import numpy as np

@dataclass
class Response:
    """One graded model response parsed from a results JSONL record."""

    question: str  # record["doc"]["question"]
    answer: str  # record["doc"]["answer"]
    response: str  # first generated text in record["resps"], or "" if none
    is_correct: bool  # True when record["exact_match"] equals 1.0
    target: str  # record["target"]
    doc_id: int  # record["doc_id"]; key used to pair the two runs


class GSM8KAnalyzer:
    """Contrast zero-shot and few-shot responses using Goodfire features.

    Result files are resolved through ``get_full_path`` before being read;
    responses are then paired by ``doc_id`` and the pairs that few-shot
    prompting fixed are inspected feature by feature.
    """

    def __init__(self, client, variant):
        """Store the SDK client and model variant; start with no responses."""
        self.client = client
        self.variant = variant
        self.zeroshot_responses: List[Response] = []
        self.fewshot_responses: List[Response] = []

    @staticmethod
    def _as_conversation(resp):
        """Return the user/assistant message list for one response."""
        return [
            {"role": "user", "content": resp.question},
            {"role": "assistant", "content": resp.response},
        ]

    def load_responses(self, zeroshot_path: str, fewshot_path: str):
        """Load responses from JSONL files.

        Args:
            zeroshot_path: File name of the zero-shot results, resolved with
                ``get_full_path``.
            fewshot_path: File name of the few-shot results, resolved the
                same way.

        Blank lines are skipped. A missing or empty ``resps`` entry yields an
        empty response string, and a missing ``exact_match`` counts as
        incorrect.
        """
        # Get full paths in Google Drive
        zeroshot_path = get_full_path(zeroshot_path)
        fewshot_path = get_full_path(fewshot_path)

        def load_file(path: str) -> List[Response]:
            responses = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    # Extract the actual model response from the nested
                    # structure; guard both `[]` and `[[]]`.
                    resps = data.get("resps")
                    model_response = resps[0][0] if resps and resps[0] else ""
                    responses.append(Response(
                        question=data["doc"]["question"],
                        answer=data["doc"]["answer"],
                        response=model_response,
                        is_correct=data.get("exact_match", 0) == 1.0,
                        target=data["target"],
                        doc_id=data["doc_id"]
                    ))
            return responses

        print(f"Loading files from:\n{zeroshot_path}\n{fewshot_path}")
        self.zeroshot_responses = load_file(zeroshot_path)
        self.fewshot_responses = load_file(fewshot_path)

        print(f"Loaded {len(self.zeroshot_responses)} zero-shot responses")
        print(f"Loaded {len(self.fewshot_responses)} few-shot responses")

    def find_matching_pairs(self) -> List[tuple[Response, Response]]:
        """Find pairs of responses to the same question.

        Returns:
            ``(zero_shot, few_shot)`` tuples for every ``doc_id`` present in
            both runs, ordered by ``doc_id`` so results are reproducible.
        """
        zero_dict = {r.doc_id: r for r in self.zeroshot_responses}
        few_dict = {r.doc_id: r for r in self.fewshot_responses}

        shared_ids = sorted(zero_dict.keys() & few_dict.keys())
        pairs = [(zero_dict[doc_id], few_dict[doc_id]) for doc_id in shared_ids]

        print(f"Found {len(pairs)} matching question pairs")
        return pairs

    async def analyze_feature_differences(self, max_pairs: int = 50):
        """Analyze feature differences between zero-shot and few-shot responses.

        Only pairs where zero-shot was incorrect and few-shot was correct are
        analyzed, and at most ``max_pairs`` of them.

        Note: this coroutine performs only synchronous client calls; it is
        kept ``async`` so existing ``await`` call sites continue to work.

        Returns:
            A DataFrame with one row per (pair, feature) holding the zero-shot
            activation, the few-shot activation and their difference. The
            frame is empty when there is no improved pair to analyze.
        """
        print(f"\nAnalyzing up to {max_pairs} response pairs...")

        # Filter first, then cap: capping before filtering could discard
        # every improved pair even though some exist.
        improved_pairs = [
            (zero, few) for zero, few in self.find_matching_pairs()
            if not zero.is_correct and few.is_correct
        ][:max_pairs]
        print(f"\nFound {len(improved_pairs)} pairs where few-shot improved accuracy")

        if not improved_pairs:
            print("No improved pairs found to analyze")
            return pd.DataFrame()

        # Format for contrast analysis
        improved_zero = [self._as_conversation(zero) for zero, _ in improved_pairs]
        improved_few = [self._as_conversation(few) for _, few in improved_pairs]

        print("\nRunning contrastive analysis...")
        # Get features that distinguish incorrect zero-shot from correct few-shot
        zero_features, few_features = self.client.features.contrast(
            dataset_1=improved_zero,
            dataset_2=improved_few,
            model=self.variant,
            top_k=20
        )

        # Rerank few-shot features to focus on mathematical reasoning
        few_features = self.client.features.rerank(
            features=few_features,
            query="step by step mathematical reasoning and problem solving",
            model=self.variant,
            top_k=10
        )

        print("\nAnalyzing feature activations...")
        activation_data = []
        feature_count = len(few_features)

        for i, (zero, few) in enumerate(tqdm(improved_pairs)):
            zero_context = self.client.features.inspect(
                messages=improved_zero[i],
                model=self.variant,
                features=few_features
            )
            few_context = self.client.features.inspect(
                messages=improved_few[i],
                model=self.variant,
                features=few_features
            )

            # `top()` ranks activations per context, so the n-th entry of each
            # list may be a different feature. Index by label so both
            # activations in a row refer to the same feature.
            zero_acts = {act.feature.label: act.activation
                         for act in zero_context.top(k=feature_count)}
            few_acts = {act.feature.label: act.activation
                        for act in few_context.top(k=feature_count)}

            for feature in few_features:
                label = feature.label
                zero_value = zero_acts.get(label, 0)
                few_value = few_acts.get(label, 0)
                activation_data.append({
                    'pair_id': i,
                    'question': zero.question[:100] + "...",  # Truncate for display
                    'feature': label,
                    'zero_activation': float(zero_value),
                    'few_activation': float(few_value),
                    'activation_diff': float(few_value - zero_value)
                })

        return pd.DataFrame(activation_data)

    def summarize_differences(self, df: pd.DataFrame):
        """Generate summary statistics and key findings.

        Args:
            df: Frame produced by ``analyze_feature_differences``.

        Returns:
            A dict of plain dictionaries (``avg_differences``,
            ``top_features``, ``question_differences``); an empty dict when
            ``df`` is empty.
        """
        if df.empty:
            print("No data to summarize")
            return {}

        print("\n=== Feature Activation Analysis ===")

        # Average activation differences by feature
        avg_diffs = df.groupby('feature')['activation_diff'].agg(['mean', 'std']).round(3)
        print("\nAverage activation differences (few-shot - zero-shot):")
        print(avg_diffs)

        # Features with largest differences (by absolute mean)
        print("\nTop features with largest activation differences:")
        top_features = avg_diffs.sort_values('mean', key=abs, ascending=False).head()
        print(top_features)

        # Per-question analysis: keep entries more than one standard deviation
        # away from zero.
        print("\nQuestions with largest feature differences:")
        q_diffs = df.groupby(['question', 'feature'])['activation_diff'].mean()
        threshold = q_diffs.std()
        if pd.isna(threshold):
            # A single entry has no standard deviation; comparing against NaN
            # would silently drop it, so keep everything instead.
            significant_diffs = q_diffs
        else:
            significant_diffs = q_diffs[abs(q_diffs) > threshold]
        print(significant_diffs.sort_values(key=abs, ascending=False).head(10))

        return {
            'avg_differences': avg_diffs.to_dict(),
            'top_features': top_features.to_dict(),
            'question_differences': significant_diffs.to_dict()
        }

In [None]:
import os
import glob

def list_jsonl_files(directory='natural_language_rl'):
    """List all JSONL files in the specified directory.

    Args:
        directory: Folder to search, joined onto ``/content/drive/MyDrive``
            (an absolute path is used as given).

    Returns:
        The ``*.jsonl`` file names (without directory part) in alphabetical
        order, or an empty list when the directory does not exist.
    """
    # Construct full path to directory in Google Drive
    drive_dir = os.path.join('/content/drive/MyDrive', directory)

    # A regular file at this path is not a usable directory either.
    if not os.path.isdir(drive_dir):
        print(f"Directory not found: {drive_dir}")
        return []

    # Escape the directory so glob metacharacters in its name stay literal.
    pattern = os.path.join(glob.escape(drive_dir), '*.jsonl')

    # Sort so the numbered listing shown to the user is stable between runs.
    return sorted(os.path.basename(path) for path in glob.glob(pattern))

async def setup_analysis():
    """List the available JSONL files and build an analyzer.

    Returns:
        ``(analyzer, files)`` where ``analyzer`` is a ``GSM8KAnalyzer`` built
        from the module-level ``client`` and ``variant``, or ``(None, [])``
        when no JSONL file is found.
    """
    available = list_jsonl_files()
    if not available:
        print("No JSONL files found in the natural_language_rl directory")
        return None, []

    # Show a 1-based menu of the files that were found.
    print("Available JSONL files:")
    for position, name in enumerate(available, start=1):
        print(f"{position}. {name}")

    return GSM8KAnalyzer(client, variant), available

# Update GSM8KAnalyzer's load_responses method to use the correct path
def get_full_path(filename: str) -> str:
    """Return *filename* joined onto the project's Google Drive folder."""
    project_dir = '/content/drive/MyDrive/natural_language_rl'
    full_path = os.path.join(project_dir, filename)
    return full_path

# Cell 4: Evaluation and Analysis
Evaluation and analysis components.
- Evaluator: Handles grading and pattern analysis
- ConnectionsSolver: Main solving orchestration
- Includes detailed result printing and analysis

In [None]:
# First all the imports and base class
import json
from dataclasses import dataclass
from typing import List, Dict, Optional
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
import numpy as np
import os
import glob
import re
from google.colab import drive

# Mount drive if not already mounted
# (the presence of the /content/drive path is used as the "already mounted" signal).
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

@dataclass
class Response:
    """One graded model response parsed from a results JSONL record."""

    question: str  # record["doc"]["question"]
    answer: str  # record["doc"]["answer"]
    response: str  # first generated text in record["resps"], or "" if none
    is_correct: bool  # True when record["exact_match"] equals 1.0
    target: str  # record["target"]
    doc_id: int  # record["doc_id"]; key used to pair the two runs

def list_jsonl_files(directory='natural_language_rl'):
    """List all JSONL files in the specified directory.

    Args:
        directory: Folder to search, joined onto ``/content/drive/MyDrive``
            (an absolute path is used as given).

    Returns:
        The ``*.jsonl`` file names (without directory part) in alphabetical
        order, or an empty list when the directory does not exist.
    """
    drive_dir = os.path.join('/content/drive/MyDrive', directory)

    # A regular file at this path is not a usable directory either.
    if not os.path.isdir(drive_dir):
        print(f"Directory not found: {drive_dir}")
        return []

    # Escape the directory so glob metacharacters in its name stay literal,
    # and sort so the listing is stable between runs.
    jsonl_files = glob.glob(os.path.join(glob.escape(drive_dir), '*.jsonl'))
    return sorted(os.path.basename(f) for f in jsonl_files)

class GSM8KAnalyzer:
    """Contrast zero-shot and few-shot responses using Goodfire features.

    ``analysis_mode`` selects what is analyzed: ``"final"`` uses only the
    final number extracted from each response, any other value uses the
    full response text.
    """

    # A number with optional thousands separators and optional decimals.
    # A trailing sentence period is deliberately not part of the match.
    _NUMBER_PATTERN = re.compile(
        r'-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+(?:\.\d+)?'
    )

    def __init__(self, client, variant, analysis_mode="full"):
        """Store the SDK client, model variant and analysis mode."""
        self.client = client
        self.variant = variant
        self.analysis_mode = analysis_mode
        self.zeroshot_responses: List[Response] = []
        self.fewshot_responses: List[Response] = []

    def load_responses(self, zeroshot_path: str, fewshot_path: str):
        """Load responses from JSONL files.

        Args:
            zeroshot_path: Zero-shot results file, joined onto
                ``/content/drive/MyDrive/natural_language_rl``.
            fewshot_path: Few-shot results file, joined the same way.

        Blank lines are skipped. A missing or empty ``resps`` entry yields an
        empty response string, and a missing ``exact_match`` counts as
        incorrect.
        """
        def load_file(path: str) -> "List[Response]":
            responses = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    # Guard both `[]` and `[[]]`.
                    resps = data.get("resps")
                    model_response = resps[0][0] if resps and resps[0] else ""
                    responses.append(Response(
                        question=data["doc"]["question"],
                        answer=data["doc"]["answer"],
                        response=model_response,
                        is_correct=data.get("exact_match", 0) == 1.0,
                        target=data["target"],
                        doc_id=data["doc_id"]
                    ))
            return responses

        zeroshot_full = os.path.join('/content/drive/MyDrive/natural_language_rl', zeroshot_path)
        fewshot_full = os.path.join('/content/drive/MyDrive/natural_language_rl', fewshot_path)

        print("Loading response files...")
        self.zeroshot_responses = load_file(zeroshot_full)
        self.fewshot_responses = load_file(fewshot_full)

        print(f"Loaded {len(self.zeroshot_responses)} zero-shot responses")
        print(f"Loaded {len(self.fewshot_responses)} few-shot responses")

    @staticmethod
    def _accuracy_line(label, responses):
        """Format one accuracy line; safe when no responses are loaded."""
        total = len(responses)
        correct = sum(1 for r in responses if r.is_correct)
        if total == 0:
            return f"{label}: 0/0 correct (n/a)"
        return f"{label}: {correct}/{total} correct ({correct/total*100:.1f}%)"

    # The return annotation is quoted so the class body does not need
    # `Response` to be defined at class-creation time.
    def find_matching_pairs(self) -> "List[tuple[Response, Response]]":
        """Find pairs of responses to the same question.

        Prints accuracy statistics for both runs and how many pairs
        improved, degraded or stayed the same.

        Returns:
            ``(zero_shot, few_shot)`` tuples for every ``doc_id`` present in
            both runs, ordered by ``doc_id`` so results are reproducible.
        """
        zero_dict = {r.doc_id: r for r in self.zeroshot_responses}
        few_dict = {r.doc_id: r for r in self.fewshot_responses}

        print("\nAccuracy stats:")
        print(self._accuracy_line("Zero-shot", self.zeroshot_responses))
        print(self._accuracy_line("Few-shot", self.fewshot_responses))

        pairs = []
        improved_count = 0
        degraded_count = 0
        same_count = 0

        for doc_id in sorted(zero_dict.keys() & few_dict.keys()):
            zero_resp = zero_dict[doc_id]
            few_resp = few_dict[doc_id]
            pairs.append((zero_resp, few_resp))

            if not zero_resp.is_correct and few_resp.is_correct:
                improved_count += 1
            elif zero_resp.is_correct and not few_resp.is_correct:
                degraded_count += 1
            else:
                same_count += 1

        print("\nPair analysis:")
        print(f"Total matching pairs: {len(pairs)}")
        print(f"Improved (0→1): {improved_count} pairs")
        print(f"Degraded (1→0): {degraded_count} pairs")
        print(f"Same result: {same_count} pairs")

        return pairs

    def extract_final_answer(self, response: str) -> str:
        """Extract the final numerical answer from a response.

        Returns the last number in ``response`` as a string with thousands
        separators removed (``"1,234"`` becomes ``"1234"``), or ``""`` when
        the text contains no number.
        """
        numbers = self._NUMBER_PATTERN.findall(response)
        return numbers[-1].replace(',', '') if numbers else ""

    async def analyze_feature_differences(self, max_pairs: int = 50):
        """Analyze feature differences between zero-shot and few-shot responses.

        Only pairs where zero-shot was incorrect and few-shot was correct are
        analyzed, and at most ``max_pairs`` of them. In ``"final"`` mode the
        logits of each extracted final answer are also requested and printed.

        Note: this coroutine performs only synchronous client calls; it is
        kept ``async`` so existing ``await`` call sites continue to work.

        Returns:
            A DataFrame with one row per (pair, labelled feature); empty when
            there is no improved pair to analyze.
        """
        print(f"\nAnalysis mode: {self.analysis_mode}")
        final_mode = self.analysis_mode == "final"

        improved_pairs = [(zero, few) for zero, few in self.find_matching_pairs()
                          if not zero.is_correct and few.is_correct]

        print(f"\nFound {len(improved_pairs)} pairs where few-shot improved accuracy")
        if not improved_pairs:
            # Nothing to contrast; avoid sending empty datasets to the API.
            print("No improved pairs found to analyze")
            return pd.DataFrame()

        print("\nExample improvement cases:")
        for i, (zero, few) in enumerate(improved_pairs[:3]):
            print(f"\nImprovement Case {i+1}:")
            print(f"Question prompt being analyzed:\n{zero.question}")
            print(f"Correct target answer: {zero.target}")
            print(f"\nZero-shot full response being analyzed:\n{zero.response}")
            print(f"\nFew-shot full response being analyzed:\n{few.response}")
            if final_mode:
                print("\nFinal answer tokens being analyzed:")
                print(f"Zero-shot final: {self.extract_final_answer(zero.response)}")
                print(f"Few-shot final: {self.extract_final_answer(few.response)}")
            print("-" * 80)

        improved_pairs = improved_pairs[:max_pairs]
        if not improved_pairs:
            print("No improved pairs found to analyze")
            return pd.DataFrame()

        # Format messages based on analysis mode
        def format_message_pair(zero, few):
            if final_mode:
                zero_text = self.extract_final_answer(zero.response)
                few_text = self.extract_final_answer(few.response)
            else:
                zero_text = zero.response
                few_text = few.response
            return (
                [{"role": "user", "content": zero.question},
                 {"role": "assistant", "content": zero_text}],
                [{"role": "user", "content": few.question},
                 {"role": "assistant", "content": few_text}]
            )

        # Build each conversation once and reuse it for contrast and inspect.
        message_pairs = [format_message_pair(zero, few) for zero, few in improved_pairs]
        improved_zero = [zero_msg for zero_msg, _ in message_pairs]
        improved_few = [few_msg for _, few_msg in message_pairs]

        print("\nRunning contrastive analysis...")
        zero_features, few_features = self.client.features.contrast(
            dataset_1=improved_zero,
            dataset_2=improved_few,
            model=self.variant,
            top_k=50
        )

        if final_mode:
            query = "numerical output format and calculation result patterns"
        else:
            query = "mathematical problem solving patterns and calculation strategies"

        few_features = self.client.features.rerank(
            features=few_features,
            query=query,
            model=self.variant,
            top_k=10
        )

        print("\nTop solution patterns found:")
        meaningful_features = [f for f in few_features if not f.label.startswith('feature_')][:5]
        for i, feat in enumerate(meaningful_features, 1):
            print(f"{i}. {feat.label}")

        print("\nAnalyzing feature activations and logits...")
        activation_data = []
        feature_count = len(few_features)

        for pair_idx, (zero, few) in enumerate(tqdm(improved_pairs)):
            zero_msg, few_msg = message_pairs[pair_idx]

            # Inspect features
            zero_context = self.client.features.inspect(
                messages=zero_msg,
                model=self.variant,
                features=few_features
            )

            few_context = self.client.features.inspect(
                messages=few_msg,
                model=self.variant,
                features=few_features
            )

            # Get logits for final answers if in final mode
            if final_mode:
                zero_final = self.extract_final_answer(zero.response)
                few_final = self.extract_final_answer(few.response)

                zero_logits = self.client.chat.logits(
                    messages=zero_msg[:-1],  # Exclude the last message
                    model=self.variant,
                    filter_vocabulary=[zero_final]
                ).logits if zero_final else {}

                few_logits = self.client.chat.logits(
                    messages=few_msg[:-1],  # Exclude the last message
                    model=self.variant,
                    filter_vocabulary=[few_final]
                ).logits if few_final else {}

                print(f"\nLogits comparison for pair {pair_idx + 1}:")
                print(f"Zero-shot final answer '{zero_final}' logits: {zero_logits}")
                print(f"Few-shot final answer '{few_final}' logits: {few_logits}")

            # Index activations by label so both values in a row refer to the
            # same feature regardless of each context's ranking.
            zero_acts = {act.feature.label: act.activation
                         for act in zero_context.top(k=feature_count)}
            few_acts = {act.feature.label: act.activation
                        for act in few_context.top(k=feature_count)}

            for feature in few_features:
                label = feature.label
                # Skip features whose label is a 'feature_' placeholder.
                if label.startswith('feature_'):
                    continue
                zero_value = zero_acts.get(label, 0)
                few_value = few_acts.get(label, 0)
                activation_data.append({
                    'pair_id': pair_idx,
                    'question': zero.question[:100] + "...",
                    'feature': label,
                    'zero_activation': float(zero_value),
                    'few_activation': float(few_value),
                    'activation_diff': float(few_value - zero_value)
                })

        return pd.DataFrame(activation_data)

    def summarize_differences(self, df: pd.DataFrame):
        """Generate summary statistics and key findings.

        Rows whose absolute ``activation_diff`` is at most 0.1 are ignored.

        Returns:
            A dict with ``significant_features`` and ``feature_examples``
            (both plain dictionaries), or an empty dict when there is nothing
            meaningful to report.
        """
        if df.empty:
            print("No data to summarize")
            return {}

        print("\n=== Feature Activation Analysis ===")

        # Filter out features with no meaningful difference
        meaningful_diffs = df[abs(df['activation_diff']) > 0.1]
        if meaningful_diffs.empty:
            print("No meaningful feature activation differences found")
            return {}

        # Average activation differences by feature
        avg_diffs = meaningful_diffs.groupby('feature').agg({
            'activation_diff': ['mean', 'std', 'count']
        }).round(3)
        avg_diffs = avg_diffs.sort_values(('activation_diff', 'mean'), key=abs, ascending=False)

        print("\nMost significant feature differences (few-shot vs zero-shot):")
        for feature in avg_diffs.index[:5]:
            stats = avg_diffs.loc[feature]
            mean = stats[('activation_diff', 'mean')]
            std = stats[('activation_diff', 'std')]
            # Row access upcasts the count to float; print it as an integer.
            count = int(stats[('activation_diff', 'count')])
            print(f"\n{feature}:")
            print(f"  Mean difference: {mean:+.3f}")
            print(f"  Std deviation: {std:.3f}")
            print(f"  Found in {count} pairs")

        return {
            'significant_features': avg_diffs.to_dict(),
            'feature_examples': meaningful_diffs.to_dict()
        }

# Show the mode menu, then read the user's choice as a raw string
print(
    "Select analysis mode:",
    "1. Full response analysis",
    "2. Final answer token analysis",
    sep="\n",
)
mode = input("Enter choice (1 or 2): ")

Select analysis mode:
1. Full response analysis
2. Final answer token analysis
Enter choice (1 or 2): 2


In [None]:
import json
from dataclasses import dataclass
from typing import List, Dict, Optional
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
import numpy as np
import os
import glob

# First mount drive if not already mounted
# (the presence of the /content/drive path is used as the "already mounted" signal).
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

@dataclass
class Response:
    """One graded model response parsed from a results JSONL record."""

    question: str  # record["doc"]["question"]
    answer: str  # record["doc"]["answer"]
    response: str  # first generated text in record["resps"], or "" if none
    is_correct: bool  # True when record["exact_match"] equals 1.0
    target: str  # record["target"]
    doc_id: int  # record["doc_id"]

class GSM8KAnalyzer:
    """Contrast zero-shot and few-shot GSM8K responses using Goodfire features.

    Typical flow: ``load_responses`` -> ``analyze_feature_differences`` ->
    ``summarize_differences``.
    """

    def __init__(self, client, variant):
        """Store the Goodfire client and model variant used for all feature calls."""
        self.client = client
        self.variant = variant
        self.zeroshot_responses: "List[Response]" = []
        self.fewshot_responses: "List[Response]" = []

    def load_responses(self, zeroshot_path: str, fewshot_path: str):
        """Load zero-shot and few-shot responses from JSONL files.

        Both paths are joined onto '/content/drive/MyDrive/natural_language_rl',
        so callers normally pass bare file names.

        Args:
            zeroshot_path: File holding the zero-shot (baseline) records.
            fewshot_path: File holding the few-shot (target) records.
        """
        def load_file(path: str) -> "List[Response]":
            responses = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Tolerate blank or trailing empty lines in the JSONL file
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    # Extract the actual model response from the nested structure;
                    # fall back to "" when the outer or inner list is empty
                    resps = data.get("resps") or []
                    model_response = resps[0][0] if resps and resps[0] else ""
                    is_correct = data.get("exact_match", 0) == 1.0
                    # Only echo the first two records: printing every record
                    # floods the cell output on files with thousands of rows
                    if len(responses) < 2:
                        print(f"Reading response {len(responses)+1}:")
                        print(f"exact_match value: {data.get('exact_match', 'not found')}")
                        print(f"is_correct parsed as: {is_correct}")
                        print(f"Question: {data['doc']['question'][:100]}...")
                        print(f"Target: {data['target']}")
                        print("---")
                    responses.append(Response(
                        question=data["doc"]["question"],
                        answer=data["doc"]["answer"],
                        response=model_response,
                        is_correct=is_correct,
                        target=data["target"],
                        doc_id=data["doc_id"]
                    ))
            return responses

        # Get full paths in Drive
        zeroshot_full = os.path.join('/content/drive/MyDrive/natural_language_rl', zeroshot_path)
        fewshot_full = os.path.join('/content/drive/MyDrive/natural_language_rl', fewshot_path)

        print(f"Loading files from:\n{zeroshot_full}\n{fewshot_full}")
        self.zeroshot_responses = load_file(zeroshot_full)
        self.fewshot_responses = load_file(fewshot_full)

        print(f"Loaded {len(self.zeroshot_responses)} zero-shot responses")
        print(f"Loaded {len(self.fewshot_responses)} few-shot responses")

    def find_matching_pairs(self) -> "List[tuple[Response, Response]]":
        """Pair zero-shot and few-shot responses that share a ``doc_id``.

        Prints per-set accuracy, improved/degraded/unchanged counts and up to
        three example pairs.

        Returns:
            ``(zero_shot, few_shot)`` tuples ordered by ``doc_id``; an empty
            list when nothing has been loaded or no ids overlap.
        """
        pairs = []
        zero_dict = {r.doc_id: r for r in self.zeroshot_responses}
        few_dict = {r.doc_id: r for r in self.fewshot_responses}

        def accuracy_line(name, responses):
            correct = sum(1 for r in responses if r.is_correct)
            total = len(responses)
            # An empty set reports 0.0% instead of dividing by zero
            pct = correct / total * 100 if total else 0.0
            return f"{name}: {correct}/{total} correct ({pct:.1f}%)"

        print(f"\nAccuracy stats:")
        print(accuracy_line("Zero-shot", self.zeroshot_responses))
        print(accuracy_line("Few-shot", self.fewshot_responses))

        # Find matching pairs and analyze improvements
        improved_count = 0
        degraded_count = 0
        same_count = 0

        # Sorted ids keep the pair order (and any later [:max_pairs] cut) reproducible
        for doc_id in sorted(zero_dict.keys() & few_dict.keys()):
            zero_resp = zero_dict[doc_id]
            few_resp = few_dict[doc_id]
            pairs.append((zero_resp, few_resp))

            # Track changes in accuracy
            if not zero_resp.is_correct and few_resp.is_correct:
                improved_count += 1
            elif zero_resp.is_correct and not few_resp.is_correct:
                degraded_count += 1
            else:
                same_count += 1

        print(f"\nPair analysis:")
        print(f"Total matching pairs: {len(pairs)}")
        print(f"Improved (0→1): {improved_count} pairs")
        print(f"Degraded (1→0): {degraded_count} pairs")
        print(f"Same result: {same_count} pairs")

        # Print a few examples to verify correctness
        print("\nExample pair details:")
        for i, (zero, few) in enumerate(pairs[:3]):
            print(f"\nPair {i+1}:")
            print(f"Question: {zero.question[:100]}...")
            print(f"Zero-shot correct: {zero.is_correct}")
            print(f"Few-shot correct: {few.is_correct}")

        return pairs

    async def analyze_feature_differences(self, max_pairs: int = 50):
        """Analyze feature differences between zero-shot and few-shot responses.

        Only pairs where zero-shot was wrong and few-shot was right are used.

        Args:
            max_pairs: Maximum number of improved pairs to send for analysis.

        Returns:
            A DataFrame with one row per (pair, feature) holding ``pair_id``,
            ``question``, ``feature``, ``zero_activation``, ``few_activation``
            and ``activation_diff`` (few minus zero). It is empty when there
            are no improved pairs.
        """
        print(f"\nAnalyzing up to {max_pairs} response pairs...")
        all_pairs = self.find_matching_pairs()

        # Keep only the pairs that few-shot prompting turned from wrong to right
        improved_pairs = [(zero, few) for zero, few in all_pairs
                         if not zero.is_correct and few.is_correct]

        print(f"\nFound {len(improved_pairs)} pairs where few-shot improved accuracy")
        print("\nExample improvement cases:")
        for i, (zero, few) in enumerate(improved_pairs[:3]):
            print(f"\nImprovement Case {i+1}:")
            print(f"Question: {zero.question}")
            print(f"Correct answer: {zero.target}")
            print(f"\nZero-shot response (incorrect):\n{zero.response}")
            print(f"\nFew-shot response (correct):\n{few.response}")
            print("-" * 80)

        # Take only up to max_pairs
        improved_pairs = improved_pairs[:max_pairs]
        print(f"\nAnalyzing {len(improved_pairs)} improvement cases in detail...")

        # With nothing to contrast, skip the API calls and return an empty frame
        # (summarize_differences handles an empty frame)
        if not improved_pairs:
            return pd.DataFrame()

        # Format for contrast analysis
        improved_zero = [[
            {"role": "user", "content": zero.question},
            {"role": "assistant", "content": zero.response}
        ] for zero, few in improved_pairs]

        improved_few = [[
            {"role": "user", "content": few.question},
            {"role": "assistant", "content": few.response}
        ] for zero, few in improved_pairs]

        print("\nRunning contrastive analysis...")
        # Get features that distinguish incorrect zero-shot from correct few-shot;
        # only the second (few-shot side) group is used below
        _zero_features, few_features = self.client.features.contrast(
            dataset_1=improved_zero,
            dataset_2=improved_few,
            model=self.variant,
            top_k=50  # Get more features initially
        )

        # Rerank for mathematical reasoning patterns
        few_features = self.client.features.rerank(
            features=few_features,
            query="step by step mathematical problem solving, arithmetic reasoning, and calculation patterns",
            model=self.variant,
            top_k=10
        )

        print("\nTop distinguishing features in few-shot responses:")
        for i, feat in enumerate(few_features, 1):
            # Skip features whose label is just a 'feature_' placeholder
            if not feat.label.startswith('feature_'):
                print(f"{i}. {feat.label}")

        print("\nAnalyzing feature activations...")
        activation_data = []
        significant_examples = []

        # Analyze each pair
        for pair_idx, (zero, few) in enumerate(tqdm(improved_pairs)):
            zero_context = self.client.features.inspect(
                messages=[
                    {"role": "user", "content": zero.question},
                    {"role": "assistant", "content": zero.response}
                ],
                model=self.variant,
                features=few_features
            )

            few_context = self.client.features.inspect(
                messages=[
                    {"role": "user", "content": few.question},
                    {"role": "assistant", "content": few.response}
                ],
                model=self.variant,
                features=few_features
            )

            # Get activations keyed by feature label
            zero_acts = {act.feature.label: act.activation for act in zero_context.top(k=len(few_features))}
            few_acts = {act.feature.label: act.activation for act in few_context.top(k=len(few_features))}

            # Record the few-minus-zero difference for every labelled feature;
            # a feature missing from either side counts as activation 0
            diffs = {}
            for feature in few_features:
                label = feature.label
                if not label.startswith('feature_'):  # Skip placeholder labels
                    z_act = float(zero_acts.get(label, 0))
                    f_act = float(few_acts.get(label, 0))
                    diff = f_act - z_act
                    diffs[label] = diff

                    activation_data.append({
                        'pair_id': pair_idx,
                        'question': zero.question[:100] + "...",
                        'feature': label,
                        'zero_activation': z_act,
                        'few_activation': f_act,
                        'activation_diff': diff
                    })

            # Store the pair's largest difference if it clears the threshold
            max_diff_feature = max(diffs.items(), key=lambda x: abs(x[1]), default=(None, 0))
            if abs(max_diff_feature[1]) > 0.5:  # Threshold for significance
                significant_examples.append({
                    'pair_id': pair_idx,
                    'question': zero.question,
                    'feature': max_diff_feature[0],
                    'activation_diff': max_diff_feature[1],
                    'zero_response': zero.response[:200],
                    'few_response': few.response[:200]
                })

        df = pd.DataFrame(activation_data)

        # Print significant examples
        if significant_examples:
            print("\nMost significant feature activation differences:")
            for example in sorted(significant_examples,
                                key=lambda x: abs(x['activation_diff']),
                                reverse=True)[:3]:
                print(f"\nQuestion: {example['question']}")
                print(f"Feature: {example['feature']}")
                print(f"Activation difference: {example['activation_diff']:.3f}")
                print(f"Zero-shot: {example['zero_response']}...")
                print(f"Few-shot: {example['few_response']}...")

        return df

    def summarize_differences(self, df: pd.DataFrame):
        """Generate summary statistics and key findings.

        Args:
            df: Frame produced by ``analyze_feature_differences``.

        Returns:
            ``{}`` when ``df`` is empty or no row has ``|activation_diff| > 0.1``;
            otherwise a dict with ``'significant_features'`` (per-feature
            mean/std/count of the differences) and ``'feature_examples'``
            (the rows that passed the 0.1 filter).
        """
        if df.empty:
            print("No data to summarize")
            return {}

        print("\n=== Feature Activation Analysis ===")

        # Filter out features with no meaningful difference
        meaningful_diffs = df[abs(df['activation_diff']) > 0.1]
        if meaningful_diffs.empty:
            print("No meaningful feature activation differences found")
            return {}

        # Average activation differences by feature, largest |mean| first
        avg_diffs = meaningful_diffs.groupby('feature').agg({
            'activation_diff': ['mean', 'std', 'count']
        }).round(3)
        avg_diffs = avg_diffs.sort_values(('activation_diff', 'mean'), key=abs, ascending=False)

        print("\nMost significant feature differences (few-shot vs zero-shot):")
        for feature in avg_diffs.index[:5]:
            stats = avg_diffs.loc[feature]
            mean = stats[('activation_diff', 'mean')]
            std = stats[('activation_diff', 'std')]
            # Row access upcasts the count to float; cast back so it prints as "3", not "3.0"
            count = int(stats[('activation_diff', 'count')])
            print(f"\n{feature}:")
            print(f"  Mean difference: {mean:+.3f}")
            print(f"  Std deviation: {std:.3f}")
            print(f"  Found in {count} pairs")

        # Find questions with largest differences for each significant feature
        print("\nExample questions with largest feature differences:")
        for feature in avg_diffs.index[:3]:
            feature_data = df[df['feature'] == feature]
            max_diff_row = feature_data.loc[feature_data['activation_diff'].abs().idxmax()]
            print(f"\nFeature: {feature}")
            print(f"Question: {max_diff_row['question']}")
            print(f"Activation difference: {max_diff_row['activation_diff']:.3f}")
            print(f"Zero-shot activation: {max_diff_row['zero_activation']:.3f}")
            print(f"Few-shot activation: {max_diff_row['few_activation']:.3f}")

        return {
            'significant_features': avg_diffs.to_dict(),
            'feature_examples': meaningful_diffs.to_dict()
        }

def list_jsonl_files(directory='natural_language_rl'):
    """List the JSONL files in a Google Drive folder, sorted by name.

    Args:
        directory: Folder joined onto '/content/drive/MyDrive'.

    Returns:
        Base names of the ``*.jsonl`` files in that folder in sorted order, so
        a numbered menu built from the list is the same on every run. Returns
        an empty list (after printing a notice) when the folder is missing.
    """
    drive_dir = os.path.join('/content/drive/MyDrive', directory)

    if not os.path.isdir(drive_dir):
        print(f"Directory not found: {drive_dir}")
        return []

    # Escape the directory so characters such as '[' in its name are not
    # treated as glob wildcards; only the '*.jsonl' part is a pattern
    pattern = os.path.join(glob.escape(drive_dir), '*.jsonl')
    return sorted(os.path.basename(f) for f in glob.glob(pattern))

# Initialize analyzer
analyzer = GSM8KAnalyzer(client, variant)

# List available files
files = list_jsonl_files()
if files:
    print("\nAvailable JSONL files:")
    for i, file in enumerate(files):
        print(f"{i+1}. {file}")

    # Get user input
    baseline_idx = int(input("\nEnter baseline file number: ")) - 1
    target_idx = int(input("Enter target file number: ")) - 1
    num_samples = int(input("Enter number of samples to analyze (default 10): ") or "10")

    if 0 <= baseline_idx < len(files) and 0 <= target_idx < len(files):
        baseline_file = files[baseline_idx]
        target_file = files[target_idx]

        print(f"\nAnalyzing:")
        print(f"Baseline: {baseline_file}")
        print(f"Target: {target_file}")
        print(f"Samples: {num_samples}")

        # Run analysis
        analyzer.load_responses(baseline_file, target_file)
        df = await analyzer.analyze_feature_differences(max_pairs=num_samples)
        results = analyzer.summarize_differences(df)

        # Save results to Drive
        output_path = '/content/drive/MyDrive/natural_language_rl/feature_analysis_results.csv'
        df.to_csv(output_path)
        print(f"\nResults saved to {output_path}")
else:
    print("No JSONL files found in the natural_language_rl directory")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
is_correct parsed as: True
Reading response 1022:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1023:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1024:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1025:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1026:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1027:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1028:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1029:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1030:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1031:
exact_match value: 0.0
is_correct parsed as: False
Reading response 1032:
exact_match value: 0.0
is_correct parsed as: False
Reading response 1033:
exact_match value: 0.0
is_correct parsed as: False
Reading response 1034:
exact_

  0%|          | 0/10 [00:00<?, ?it/s]


=== Feature Activation Analysis ===
No meaningful feature activation differences found

Results saved to /content/drive/MyDrive/natural_language_rl/feature_analysis_results.csv


In [None]:
import json
from dataclasses import dataclass
from typing import List, Dict, Optional
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
import numpy as np
import os
import glob

# First mount drive if not already mounted
# The existence check on the mount point means drive.mount is only called
# when '/content/drive' is not present yet (e.g. on the first run of the cell).
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

@dataclass
class Response:
    """One evaluated model response, built from a single results JSONL record."""
    # Question text (record["doc"]["question"])
    question: str
    # Reference answer text (record["doc"]["answer"])
    answer: str
    # Model output (record["resps"][0][0], or "" when "resps" is empty)
    response: str
    # True when the record's "exact_match" value equals 1.0
    is_correct: bool
    # Expected target value (record["target"])
    target: str
    # Identifier used to pair zero-shot and few-shot records (record["doc_id"])
    doc_id: int

class GSM8KAnalyzer:
    """Contrast zero-shot and few-shot GSM8K responses using Goodfire features.

    Typical flow: ``load_responses`` -> ``analyze_feature_differences`` ->
    ``summarize_differences``.
    """

    def __init__(self, client, variant):
        """Store the Goodfire client and model variant used for all feature calls."""
        self.client = client
        self.variant = variant
        self.zeroshot_responses: "List[Response]" = []
        self.fewshot_responses: "List[Response]" = []

    def load_responses(self, zeroshot_path: str, fewshot_path: str):
        """Load zero-shot and few-shot responses from JSONL files.

        Both paths are joined onto '/content/drive/MyDrive/natural_language_rl',
        so callers normally pass bare file names.

        Args:
            zeroshot_path: File holding the zero-shot (baseline) records.
            fewshot_path: File holding the few-shot (target) records.
        """
        def load_file(path: str) -> "List[Response]":
            responses = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Tolerate blank or trailing empty lines in the JSONL file
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    # Extract the actual model response from the nested structure;
                    # fall back to "" when the outer or inner list is empty
                    resps = data.get("resps") or []
                    model_response = resps[0][0] if resps and resps[0] else ""
                    is_correct = data.get("exact_match", 0) == 1.0
                    # Only echo the first two records: printing every record
                    # floods the cell output on files with thousands of rows
                    if len(responses) < 2:
                        print(f"Reading response {len(responses)+1}:")
                        print(f"exact_match value: {data.get('exact_match', 'not found')}")
                        print(f"is_correct parsed as: {is_correct}")
                        print(f"Question: {data['doc']['question'][:100]}...")
                        print(f"Target: {data['target']}")
                        print("---")
                    responses.append(Response(
                        question=data["doc"]["question"],
                        answer=data["doc"]["answer"],
                        response=model_response,
                        is_correct=is_correct,
                        target=data["target"],
                        doc_id=data["doc_id"]
                    ))
            return responses

        # Get full paths in Drive
        zeroshot_full = os.path.join('/content/drive/MyDrive/natural_language_rl', zeroshot_path)
        fewshot_full = os.path.join('/content/drive/MyDrive/natural_language_rl', fewshot_path)

        print(f"Loading files from:\n{zeroshot_full}\n{fewshot_full}")
        self.zeroshot_responses = load_file(zeroshot_full)
        self.fewshot_responses = load_file(fewshot_full)

        print(f"Loaded {len(self.zeroshot_responses)} zero-shot responses")
        print(f"Loaded {len(self.fewshot_responses)} few-shot responses")

    def find_matching_pairs(self) -> "List[tuple[Response, Response]]":
        """Pair zero-shot and few-shot responses that share a ``doc_id``.

        Prints per-set accuracy, improved/degraded/unchanged counts and up to
        three example pairs.

        Returns:
            ``(zero_shot, few_shot)`` tuples ordered by ``doc_id``; an empty
            list when nothing has been loaded or no ids overlap.
        """
        pairs = []
        zero_dict = {r.doc_id: r for r in self.zeroshot_responses}
        few_dict = {r.doc_id: r for r in self.fewshot_responses}

        def accuracy_line(name, responses):
            correct = sum(1 for r in responses if r.is_correct)
            total = len(responses)
            # An empty set reports 0.0% instead of dividing by zero
            pct = correct / total * 100 if total else 0.0
            return f"{name}: {correct}/{total} correct ({pct:.1f}%)"

        print(f"\nAccuracy stats:")
        print(accuracy_line("Zero-shot", self.zeroshot_responses))
        print(accuracy_line("Few-shot", self.fewshot_responses))

        # Find matching pairs and analyze improvements
        improved_count = 0
        degraded_count = 0
        same_count = 0

        # Sorted ids keep the pair order (and any later [:max_pairs] cut) reproducible
        for doc_id in sorted(zero_dict.keys() & few_dict.keys()):
            zero_resp = zero_dict[doc_id]
            few_resp = few_dict[doc_id]
            pairs.append((zero_resp, few_resp))

            # Track changes in accuracy
            if not zero_resp.is_correct and few_resp.is_correct:
                improved_count += 1
            elif zero_resp.is_correct and not few_resp.is_correct:
                degraded_count += 1
            else:
                same_count += 1

        print(f"\nPair analysis:")
        print(f"Total matching pairs: {len(pairs)}")
        print(f"Improved (0→1): {improved_count} pairs")
        print(f"Degraded (1→0): {degraded_count} pairs")
        print(f"Same result: {same_count} pairs")

        # Print a few examples to verify correctness
        print("\nExample pair details:")
        for i, (zero, few) in enumerate(pairs[:3]):
            print(f"\nPair {i+1}:")
            print(f"Question: {zero.question[:100]}...")
            print(f"Zero-shot correct: {zero.is_correct}")
            print(f"Few-shot correct: {few.is_correct}")

        return pairs

    async def analyze_feature_differences(self, max_pairs: int = 50):
        """Analyze feature differences between zero-shot and few-shot responses.

        Only pairs where zero-shot was wrong and few-shot was right are used.

        Args:
            max_pairs: Maximum number of improved pairs to send for analysis.

        Returns:
            A DataFrame with one row per (pair, feature) holding ``pair_id``,
            ``question``, ``feature``, ``zero_activation``, ``few_activation``
            and ``activation_diff`` (few minus zero). It is empty when there
            are no improved pairs.
        """
        print(f"\nAnalyzing up to {max_pairs} response pairs...")
        all_pairs = self.find_matching_pairs()

        # Keep only the pairs that few-shot prompting turned from wrong to right
        improved_pairs = [(zero, few) for zero, few in all_pairs
                         if not zero.is_correct and few.is_correct]

        print(f"\nFound {len(improved_pairs)} pairs where few-shot improved accuracy")
        print("\nExample improvement cases:")
        for i, (zero, few) in enumerate(improved_pairs[:3]):
            print(f"\nImprovement Case {i+1}:")
            print(f"Question: {zero.question}")
            print(f"Correct answer: {zero.target}")
            print(f"\nZero-shot response (incorrect):\n{zero.response}")
            print(f"\nFew-shot response (correct):\n{few.response}")
            print("-" * 80)

        # Take only up to max_pairs
        improved_pairs = improved_pairs[:max_pairs]
        print(f"\nAnalyzing {len(improved_pairs)} improvement cases in detail...")

        # With nothing to contrast, skip the API calls and return an empty frame
        # (summarize_differences handles an empty frame)
        if not improved_pairs:
            return pd.DataFrame()

        # Format for contrast analysis
        improved_zero = [[
            {"role": "user", "content": zero.question},
            {"role": "assistant", "content": zero.response}
        ] for zero, few in improved_pairs]

        improved_few = [[
            {"role": "user", "content": few.question},
            {"role": "assistant", "content": few.response}
        ] for zero, few in improved_pairs]

        print("\nRunning contrastive analysis...")
        # Get features that distinguish incorrect zero-shot from correct few-shot;
        # only the second (few-shot side) group is used below
        _zero_features, few_features = self.client.features.contrast(
            dataset_1=improved_zero,
            dataset_2=improved_few,
            model=self.variant,
            top_k=50  # Get more features initially
        )

        # Rerank for mathematical reasoning patterns
        few_features = self.client.features.rerank(
            features=few_features,
            query="step by step mathematical problem solving, arithmetic reasoning, and calculation patterns",
            model=self.variant,
            top_k=10
        )

        print("\nTop distinguishing features in few-shot responses:")
        for i, feat in enumerate(few_features, 1):
            # Skip features whose label is just a 'feature_' placeholder
            if not feat.label.startswith('feature_'):
                print(f"{i}. {feat.label}")

        print("\nAnalyzing feature activations...")
        activation_data = []
        significant_examples = []

        # Analyze each pair
        for pair_idx, (zero, few) in enumerate(tqdm(improved_pairs)):
            zero_context = self.client.features.inspect(
                messages=[
                    {"role": "user", "content": zero.question},
                    {"role": "assistant", "content": zero.response}
                ],
                model=self.variant,
                features=few_features
            )

            few_context = self.client.features.inspect(
                messages=[
                    {"role": "user", "content": few.question},
                    {"role": "assistant", "content": few.response}
                ],
                model=self.variant,
                features=few_features
            )

            # Get activations keyed by feature label
            zero_acts = {act.feature.label: act.activation for act in zero_context.top(k=len(few_features))}
            few_acts = {act.feature.label: act.activation for act in few_context.top(k=len(few_features))}

            # Record the few-minus-zero difference for every labelled feature;
            # a feature missing from either side counts as activation 0
            diffs = {}
            for feature in few_features:
                label = feature.label
                if not label.startswith('feature_'):  # Skip placeholder labels
                    z_act = float(zero_acts.get(label, 0))
                    f_act = float(few_acts.get(label, 0))
                    diff = f_act - z_act
                    diffs[label] = diff

                    activation_data.append({
                        'pair_id': pair_idx,
                        'question': zero.question[:100] + "...",
                        'feature': label,
                        'zero_activation': z_act,
                        'few_activation': f_act,
                        'activation_diff': diff
                    })

            # Store the pair's largest difference if it clears the threshold
            max_diff_feature = max(diffs.items(), key=lambda x: abs(x[1]), default=(None, 0))
            if abs(max_diff_feature[1]) > 0.5:  # Threshold for significance
                significant_examples.append({
                    'pair_id': pair_idx,
                    'question': zero.question,
                    'feature': max_diff_feature[0],
                    'activation_diff': max_diff_feature[1],
                    'zero_response': zero.response[:200],
                    'few_response': few.response[:200]
                })

        df = pd.DataFrame(activation_data)

        # Print significant examples
        if significant_examples:
            print("\nMost significant feature activation differences:")
            for example in sorted(significant_examples,
                                key=lambda x: abs(x['activation_diff']),
                                reverse=True)[:3]:
                print(f"\nQuestion: {example['question']}")
                print(f"Feature: {example['feature']}")
                print(f"Activation difference: {example['activation_diff']:.3f}")
                print(f"Zero-shot: {example['zero_response']}...")
                print(f"Few-shot: {example['few_response']}...")

        return df

    def summarize_differences(self, df: pd.DataFrame):
        """Generate summary statistics and key findings.

        Args:
            df: Frame produced by ``analyze_feature_differences``.

        Returns:
            ``{}`` when ``df`` is empty or no row has ``|activation_diff| > 0.1``;
            otherwise a dict with ``'significant_features'`` (per-feature
            mean/std/count of the differences) and ``'feature_examples'``
            (the rows that passed the 0.1 filter).
        """
        if df.empty:
            print("No data to summarize")
            return {}

        print("\n=== Feature Activation Analysis ===")

        # Filter out features with no meaningful difference
        meaningful_diffs = df[abs(df['activation_diff']) > 0.1]
        if meaningful_diffs.empty:
            print("No meaningful feature activation differences found")
            return {}

        # Average activation differences by feature, largest |mean| first
        avg_diffs = meaningful_diffs.groupby('feature').agg({
            'activation_diff': ['mean', 'std', 'count']
        }).round(3)
        avg_diffs = avg_diffs.sort_values(('activation_diff', 'mean'), key=abs, ascending=False)

        print("\nMost significant feature differences (few-shot vs zero-shot):")
        for feature in avg_diffs.index[:5]:
            stats = avg_diffs.loc[feature]
            mean = stats[('activation_diff', 'mean')]
            std = stats[('activation_diff', 'std')]
            # Row access upcasts the count to float; cast back so it prints as "3", not "3.0"
            count = int(stats[('activation_diff', 'count')])
            print(f"\n{feature}:")
            print(f"  Mean difference: {mean:+.3f}")
            print(f"  Std deviation: {std:.3f}")
            print(f"  Found in {count} pairs")

        # Find questions with largest differences for each significant feature
        print("\nExample questions with largest feature differences:")
        for feature in avg_diffs.index[:3]:
            feature_data = df[df['feature'] == feature]
            max_diff_row = feature_data.loc[feature_data['activation_diff'].abs().idxmax()]
            print(f"\nFeature: {feature}")
            print(f"Question: {max_diff_row['question']}")
            print(f"Activation difference: {max_diff_row['activation_diff']:.3f}")
            print(f"Zero-shot activation: {max_diff_row['zero_activation']:.3f}")
            print(f"Few-shot activation: {max_diff_row['few_activation']:.3f}")

        return {
            'significant_features': avg_diffs.to_dict(),
            'feature_examples': meaningful_diffs.to_dict()
        }

def list_jsonl_files(directory='natural_language_rl'):
    """List the JSONL files in a Google Drive folder, sorted by name.

    Args:
        directory: Folder joined onto '/content/drive/MyDrive'.

    Returns:
        Base names of the ``*.jsonl`` files in that folder in sorted order, so
        a numbered menu built from the list is the same on every run. Returns
        an empty list (after printing a notice) when the folder is missing.
    """
    drive_dir = os.path.join('/content/drive/MyDrive', directory)

    if not os.path.isdir(drive_dir):
        print(f"Directory not found: {drive_dir}")
        return []

    # Escape the directory so characters such as '[' in its name are not
    # treated as glob wildcards; only the '*.jsonl' part is a pattern
    pattern = os.path.join(glob.escape(drive_dir), '*.jsonl')
    return sorted(os.path.basename(f) for f in glob.glob(pattern))

# Initialize analyzer
analyzer = GSM8KAnalyzer(client, variant)

# List available files
files = list_jsonl_files()
if files:
    print("\nAvailable JSONL files:")
    for i, file in enumerate(files):
        print(f"{i+1}. {file}")

    # Get user input
    baseline_idx = int(input("\nEnter baseline file number: ")) - 1
    target_idx = int(input("Enter target file number: ")) - 1
    num_samples = int(input("Enter number of samples to analyze (default 10): ") or "10")

    if 0 <= baseline_idx < len(files) and 0 <= target_idx < len(files):
        baseline_file = files[baseline_idx]
        target_file = files[target_idx]

        print(f"\nAnalyzing:")
        print(f"Baseline: {baseline_file}")
        print(f"Target: {target_file}")
        print(f"Samples: {num_samples}")

        # Run analysis
        analyzer.load_responses(baseline_file, target_file)
        df = await analyzer.analyze_feature_differences(max_pairs=num_samples)
        results = analyzer.summarize_differences(df)

        # Save results to Drive
        output_path = '/content/drive/MyDrive/natural_language_rl/feature_analysis_results.csv'
        df.to_csv(output_path)
        print(f"\nResults saved to {output_path}")
else:
    print("No JSONL files found in the natural_language_rl directory")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
is_correct parsed as: True
Reading response 1022:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1023:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1024:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1025:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1026:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1027:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1028:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1029:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1030:
exact_match value: 1.0
is_correct parsed as: True
Reading response 1031:
exact_match value: 0.0
is_correct parsed as: False
Reading response 1032:
exact_match value: 0.0
is_correct parsed as: False
Reading response 1033:
exact_match value: 0.0
is_correct parsed as: False
Reading response 1034:
exact_

  0%|          | 0/20 [00:00<?, ?it/s]


=== Feature Activation Analysis ===
No meaningful feature activation differences found

Results saved to /content/drive/MyDrive/natural_language_rl/feature_analysis_results.csv


# Cell 5: Execution
Main execution and result presentation.
- Runs solver on specified number of puzzles
- Displays detailed solution analysis
- Shows overall statistics and patterns
Usage: solutions, analysis = run_solver(client, variant, dataset_name, num_puzzles)

In [None]:
# After your initial setup cell, add the GSM8KAnalyzer class

# Then create an instance and set up the display
analyzer = GSM8KAnalyzer(client, variant)

# Show the configured analyzer. GSM8KAnalyzer requires both `client` and
# `variant`, so constructing a second, argument-less instance raises TypeError
# and stops the cell before run_analysis is defined.
# NOTE(review): if a separate UI widget was meant to be shown here, display
# that widget instead.
display(analyzer)

# Function to handle analysis when triggered from the UI
async def run_analysis(baseline_file: str, target_file: str, num_samples: int):
    """Run the zero-shot vs few-shot feature analysis for two result files.

    Args:
        baseline_file: File name of the baseline (zero-shot) JSONL results.
        target_file: File name of the target (few-shot) JSONL results.
        num_samples: Maximum number of improved pairs to analyze.

    Returns:
        The activation rows as a list of dicts (``DataFrame.to_dict('records')``).
    """
    # load_responses already joins names onto the natural_language_rl Drive
    # folder, so pass the bare file names; prefixing the folder again would
    # point at .../natural_language_rl/natural_language_rl/<file>
    analyzer.load_responses(baseline_file, target_file)
    df = await analyzer.analyze_feature_differences(max_pairs=num_samples)
    # Called for its printed summary; the returned dict is not needed here
    analyzer.summarize_differences(df)
    return df.to_dict('records')  # Convert to format expected by visualization

# Run AutoSteer (old version with visualization, deprecated)

In [None]:
#####################################
# Part 1: Skill Discovery
#####################################

class ExampleManager:
    """Manages and filters training examples.

    Keeps references to the train/validation/test splits and converts raw
    dataset rows into ``Example`` objects.
    """

    def __init__(self, train_set, validation_set, test_set):
        """Store the three splits and print a short summary of them.

        Args:
            train_set: Training split; must support ``len()`` and indexing.
            validation_set: Validation split; must support ``len()``.
            test_set: Test split; must support ``len()``.
        """
        self.train_set = train_set
        self.validation_set = validation_set
        self.test_set = test_set
        print(f"\nInitialized ExampleManager with:")
        print(f"- {len(train_set)} training examples")
        print(f"- {len(validation_set)} validation examples")
        print(f"- {len(test_set)} test examples")

        # Debug: Print structure of first example.
        # Guarded so that an empty training split does not crash the constructor.
        print("\nExample data structure:")
        if len(train_set) > 0:
            print(json.dumps(train_set[0], indent=2))
        else:
            print("(training set is empty)")

    def prepare_examples(self, dataset, max_examples: int = 100) -> List[Example]:
        """Convert dataset examples into Example objects.

        Each row is read through its ``'text'`` and ``'label'`` keys. Every
        produced ``Example`` is marked ``is_correct=True``.

        Args:
            dataset: Rows to convert. Objects exposing ``select(indices)`` are
                subset through it; any other indexable sequence is subset by
                position.
            max_examples: Upper bound on the number of rows converted.

        Returns:
            A list with at most ``max_examples`` ``Example`` objects.
        """
        print(f"\nPreparing up to {max_examples} examples...")
        examples = []

        # Take only the leading rows. Prefer the dataset's own `select`, but fall
        # back to positional indexing so plain lists/tuples work as well.
        limit = min(max_examples, len(dataset))
        if hasattr(dataset, "select"):
            dataset_subset = dataset.select(range(limit))
        else:
            dataset_subset = [dataset[idx] for idx in range(limit)]

        for i, item in enumerate(tqdm(dataset_subset, desc="Converting examples")):
            # Print first few examples to verify conversion
            if i < 3:
                print(f"\nExample {i + 1}:")
                print(f"Words: {item['text']}")
                print(f"Category: {item['label']}")

            examples.append(
                Example(
                    query=f"puzzle: Consider these words and explain how they might be related: {item['text']}",
                    response=f"These words belong to the category: {item['label']}",
                    is_correct=True,
                )
            )

            # Printed once, right after the previews stop.
            if i == 3:
                print("\n... (continuing conversion)")

        print(f"\nPrepared {len(examples)} examples total")
        return examples

class SkillDiscoverer:
    """Discovers features ("skills") that separate correct from incorrect examples.

    Uses the client's ``features.contrast``, ``features.rerank`` and
    ``features.inspect`` calls. Per-example activations from the most recent
    ``discover_skills`` run are kept in ``activation_logs``.
    """

    def __init__(self, client, variant):
        """Store the SDK client and model variant and start with empty results."""
        self.client = client
        self.variant = variant
        self.discovered_skills = None
        self.activation_logs = {}

    @staticmethod
    def _as_conversation(example):
        """Return the user/assistant message pair for one example."""
        return [
            {"role": "user", "content": example.query},
            {"role": "assistant", "content": example.response},
        ]

    def _format_examples(self, heading, examples, preview: int = 3):
        """Print ``heading``, preview the first few examples, return all as messages."""
        print(heading)
        conversations = []
        for i, ex in enumerate(examples):
            conversations.append(self._as_conversation(ex))
            if i < preview:  # Show first few examples only
                print(f"\nExample {i+1}:")
                print(f"Query: {ex.query}")
                print(f"Response: {ex.response}")
        return conversations

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or 0 when it is empty."""
        # Standard-library mean keeps this class independent of a numpy import.
        from statistics import fmean

        return fmean(values) if values else 0

    async def discover_skills(self, correct_examples: List[Example], incorrect_examples: List[Example], top_k: int = 5):
        """Find, rerank and analyze features that contrast the two example sets.

        Args:
            correct_examples: Examples treated as the "correct" group.
            incorrect_examples: Examples treated as the "incorrect" group.
            top_k: Number of features kept after reranking; twice as many are
                requested from the initial contrast step.

        Returns:
            The reranked features, also stored in ``self.discovered_skills``.
            ``self.activation_logs`` is rebuilt on every call and maps the
            example's position (correct examples first) to its query, response,
            group flag and ``(feature label, activation)`` pairs.
        """
        print("\n=== Skill Discovery and Analysis ===")

        # Convert examples to message format
        print("\nFormatting examples for contrast...")
        correct_messages = self._format_examples("\nCorrect examples:", correct_examples)
        incorrect_messages = self._format_examples("\nIncorrect examples:", incorrect_examples)

        # Original skill discovery: keep the features returned for dataset_2 (correct).
        print(f"\nFinding contrasting features between {len(correct_messages)} correct and {len(incorrect_messages)} incorrect examples...")
        _, skill_features = self.client.features.contrast(
            dataset_1=incorrect_messages,
            dataset_2=correct_messages,
            model=self.variant,
            top_k=top_k*2
        )
        print(f"\nFound {len(skill_features)} initial features:")
        for i, feat in enumerate(skill_features):
            print(f"{i+1}. {feat.label}")

        # Rerank for task relevance
        print("\nReranking features...")
        self.discovered_skills = self.client.features.rerank(
            features=skill_features,
            query="pattern recognition and categorization for word puzzles",
            model=self.variant,
            top_k=top_k
        )

        # Analyze example activations
        print("\n=== Feature Activation Analysis ===")
        # Start from a fresh log so entries from an earlier, larger run cannot
        # leak into this run's statistics.
        self.activation_logs = {}
        num_correct = len(correct_examples)
        for i, ex in enumerate(tqdm(correct_examples + incorrect_examples, desc="Analyzing examples")):
            context = self.client.features.inspect(
                messages=self._as_conversation(ex),
                model=self.variant,
                features=self.discovered_skills
            )

            activations = context.top(k=len(self.discovered_skills))

            # Correct examples come first in the concatenated list.
            is_correct = i < num_correct

            # Store activation data
            self.activation_logs[i] = {
                'query': ex.query,
                'response': ex.response,
                'is_correct': is_correct,
                'activations': [(act.feature.label, float(act.activation)) for act in activations]
            }

            # Print analysis for this example, strongest activations first
            print(f"\nExample {i+1}:")
            print(f"Query: {ex.query[:100]}...")
            print(f"{'Correct' if is_correct else 'Incorrect'} example")
            print("Top Feature Activations:")
            for feat, strength in sorted(self.activation_logs[i]['activations'],
                                         key=lambda x: abs(x[1]),
                                         reverse=True):
                print(f"  {feat}: {strength:.3f}")

        # Calculate feature statistics
        print("\n=== Feature Importance Summary ===")
        for feat in self.discovered_skills:
            correct_activations = []
            incorrect_activations = []

            # Activations are matched to the feature by label.
            for log in self.activation_logs.values():
                bucket = correct_activations if log['is_correct'] else incorrect_activations
                bucket.extend(strength for label, strength in log['activations'] if label == feat.label)

            avg_activation = self._mean(correct_activations + incorrect_activations)
            avg_correct = self._mean(correct_activations)
            avg_incorrect = self._mean(incorrect_activations)

            print(f"\nFeature: {feat.label}")
            print(f"Average activation: {avg_activation:.3f}")
            print(f"Average for correct examples: {avg_correct:.3f}")
            print(f"Average for incorrect examples: {avg_incorrect:.3f}")
            print(f"Activation difference: {(avg_correct - avg_incorrect):.3f}")

        return self.discovered_skills


#####################################
# Main Pipeline
#####################################

async def run_skill_discovery(num_samples: int = 20):
    """Run skill discovery end to end.

    Builds an ``ExampleManager`` and a ``SkillDiscoverer`` from the module-level
    datasets, client and variant, prepares up to ``num_samples`` training
    examples, and runs discovery on the two halves of that list.

    Args:
        num_samples: Maximum number of training examples to prepare.

    Returns:
        A ``(skills, skill_discoverer)`` tuple: the discovered features and the
        discoverer instance that produced them.
    """
    print(f"\n=== Starting Skill Discovery Pipeline (using {num_samples} samples) ===")

    # Set up the two collaborators
    print("\nInitializing components...")
    manager = ExampleManager(train_set, validation_set, test_set)
    discoverer = SkillDiscoverer(client, variant)

    # Build the pool of examples from the training split
    prepared = manager.prepare_examples(train_set, max_examples=num_samples)

    # NOTE(review): the split is purely positional (first half vs. second half);
    # prepare_examples marks every example is_correct=True, so the "incorrect"
    # group here is not actually made of wrong answers.
    half = len(prepared) // 2
    correct_examples, incorrect_examples = prepared[:half], prepared[half:]

    print(f"\nSplit examples into {len(correct_examples)} correct and {len(incorrect_examples)} incorrect")

    # Contrast the two groups
    skills = await discoverer.discover_skills(correct_examples, incorrect_examples)

    print("\nDiscovered Skills:")
    for position, skill in enumerate(skills, start=1):
        print(f"{position}. {skill}")
        print(f"   Label: {skill.label}")
        print(f"   UUID: {skill.uuid}")

    print("\n=== Skill Discovery Complete ===")

    # Hand back the discoverer too so its activation logs stay accessible
    return skills, discoverer

# Run first step
# NOTE: `await` is used at the top level here, which only works where the
# environment supplies a running event loop for cell code (e.g. a notebook /
# IPython session); as a plain Python script this would be a syntax error.
if __name__ == "__main__":
    num_samples = 20  # Adjust this number as needed
    # Keep both results: the skills and the discoverer holding the activation logs.
    discovered_skills, skill_discoverer = await run_skill_discovery(num_samples)


In [None]:
#####################################
# Step 3: Visualization and Analysis
#####################################

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display, HTML

def create_analysis_dataframe(activation_logs):
    """Convert activation logs to a long-format pandas DataFrame.

    Args:
        activation_logs: Mapping of example id to a dict with the keys
            ``'query'``, ``'response'``, ``'is_correct'`` and ``'activations'``,
            where ``'activations'`` is an iterable of ``(feature, activation)``
            pairs.

    Returns:
        A DataFrame with one row per (example, feature) pair and the columns
        ``example_id``, ``query``, ``response``, ``is_correct``, ``feature`` and
        ``activation``. The columns are present even when there are no rows, so
        downstream column access does not fail on empty logs.
    """
    columns = ['example_id', 'query', 'response', 'is_correct', 'feature', 'activation']

    # One row for each feature activation of each example
    rows = [
        {
            'example_id': example_id,
            'query': log['query'],
            'response': log['response'],
            'is_correct': log['is_correct'],
            'feature': feature,
            'activation': activation,
        }
        for example_id, log in activation_logs.items()
        for feature, activation in log['activations']
    ]

    # Passing the columns explicitly fixes the schema for the empty case.
    return pd.DataFrame(rows, columns=columns)

def visualize_results(skill_discoverer):
    """Create visualizations and an interactive table from discovery results.

    Reads ``skill_discoverer.activation_logs``, writes the flattened data to
    ``feature_analysis.csv`` in the working directory, then shows a heatmap, a
    bar chart of mean activation per feature, and a filterable HTML table.

    Args:
        skill_discoverer: Object exposing an ``activation_logs`` mapping in the
            format accepted by ``create_analysis_dataframe``.

    Returns:
        None. If there is no activation data, only the CSV is written and the
        function returns before plotting.
    """
    # Local import: used to escape feature labels embedded in HTML below.
    from html import escape

    # Convert to DataFrame
    df = create_analysis_dataframe(skill_discoverer.activation_logs)

    # Save to CSV
    df.to_csv('feature_analysis.csv', index=False)
    print("Saved detailed results to feature_analysis.csv")

    # Nothing to plot: pivoting/grouping an empty frame would raise.
    if df.empty:
        print("No activation data to visualize")
        return

    # 1. Feature Activation Heatmap
    pivot_df = df.pivot_table(
        values='activation',
        index='example_id',
        columns='feature',
        aggfunc='first'
    )

    fig1 = px.imshow(
        pivot_df,
        title='Feature Activation Heatmap',
        labels=dict(x='Feature', y='Example ID', color='Activation Strength'),
        aspect='auto'
    )
    fig1.show()

    # 2. Feature Importance Bar Chart
    avg_activations = df.groupby('feature')['activation'].agg(['mean', 'std']).reset_index()
    fig2 = px.bar(
        avg_activations,
        x='feature',
        y='mean',
        error_y='std',
        title='Average Feature Activation Strength',
        labels={'mean': 'Average Activation', 'feature': 'Feature'}
    )
    fig2.update_layout(xaxis_tickangle=45)
    fig2.show()

    # 3. Interactive Table
    def generate_interactive_table(df):
        """Return HTML (styles, filter dropdowns, table, script) for ``df``."""
        # Create unique example IDs for filtering
        examples = df[['example_id', 'query', 'response', 'is_correct']].drop_duplicates()
        features = df['feature'].unique()

        example_options = ''.join(
            f'<option value="{i}">Example {i+1}</option>' for i in examples['example_id']
        )
        # Feature labels are free text: escape them so quotes, '<' or '&' cannot
        # break the markup. The browser decodes the entities again, so the
        # dropdown value still equals the table cell's text in filterTable().
        feature_options = ''.join(
            f'<option value="{escape(str(f), quote=True)}">{escape(str(f))}</option>'
            for f in features
        )

        # Create HTML for filtering controls
        filter_html = f"""
        <div style="margin-bottom: 20px;">
            <h3>Filters:</h3>
            <select id="example_filter" onchange="filterTable()">
                <option value="all">All Examples</option>
                {example_options}
            </select>

            <select id="feature_filter" onchange="filterTable()">
                <option value="all">All Features</option>
                {feature_options}
            </select>

            <select id="correct_filter" onchange="filterTable()">
                <option value="all">All Results</option>
                <option value="correct">Correct Only</option>
                <option value="incorrect">Incorrect Only</option>
            </select>
        </div>
        """

        # Create table
        table_html = df.to_html(classes='data-table', index=False)

        # Add styling and JavaScript.
        # The script addresses cells by position: 0 = example_id, 3 = is_correct,
        # 4 = feature, which matches the DataFrame's column order.
        return f"""
        <style>
            .data-table {{
                width: 100%;
                border-collapse: collapse;
            }}
            .data-table th, .data-table td {{
                padding: 8px;
                border: 1px solid #ddd;
            }}
            .data-table tr:nth-child(even) {{
                background-color: #f9f9f9;
            }}
            .filter-controls {{
                margin-bottom: 20px;
            }}
            select {{
                margin-right: 10px;
                padding: 5px;
            }}
        </style>
        {filter_html}
        {table_html}
        <script>
        function filterTable() {{
            var example = document.getElementById('example_filter').value;
            var feature = document.getElementById('feature_filter').value;
            var correct = document.getElementById('correct_filter').value;

            var rows = document.querySelectorAll('.data-table tbody tr');
            rows.forEach(function(row) {{
                var showRow = true;

                if (example !== 'all' && row.cells[0].textContent !== example) showRow = false;
                if (feature !== 'all' && row.cells[4].textContent !== feature) showRow = false;
                if (correct !== 'all') {{
                    var isCorrect = row.cells[3].textContent === 'True';
                    if (correct === 'correct' && !isCorrect) showRow = false;
                    if (correct === 'incorrect' && isCorrect) showRow = false;
                }}

                row.style.display = showRow ? '' : 'none';
            }});
        }}
        </script>
        """

    # Display interactive table
    display(HTML(generate_interactive_table(df)))

# Usage (add to main):
# Expects `skill_discoverer` to already exist in the session; it is assigned by
# the `run_skill_discovery(...)` call in the skill-discovery cell.
print("\nGenerating visualizations...")
visualize_results(skill_discoverer)