<a href="https://colab.research.google.com/github/menhguin/natural_language_rl/blob/main/Another_copy_of_Natural_Language_RL_Common_solver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup 1: Install

In [1]:
"""NLRL_with_Goodfire.ipynb

Natural Language Reinforcement Learning using Goodfire SDK
Following the steps outlined in the paper:
1. Query Selection & Answer Verification
2. COT Decomposition & Critical Token Identification
3. Feature Attribution & Analysis
4. Feature Steering
5. Robustness Testing & Iteration
"""

#####################################
# Step 0: Setup and Initialization
#####################################

# Install required packages
!pip install goodfire --quiet
!pip install datasets --quiet
!pip install tqdm --quiet

from google.colab import userdata
import goodfire
import asyncio
from tqdm.auto import tqdm
import pandas as pd
from typing import List, Dict, Any
import json
from typing import List
from dataclasses import dataclass

@dataclass
class Example:
    query: str
    response: str
    is_correct: bool

# Add your Goodfire API Key to your Colab secrets
GOODFIRE_API_KEY = userdata.get('GOODFIRE_API_KEY')


# Initialize SDK clients
client = goodfire.Client(GOODFIRE_API_KEY)
async_client = goodfire.AsyncClient(GOODFIRE_API_KEY)

# Instantiate a model variant
variant = goodfire.Variant("meta-llama/Meta-Llama-3.1-8B-Instruct")

# for my reference: Which is bigger, 9.9 or 9.11?
# https://platform.goodfire.ai/chat/ab5cc226-aa37-404c-87fd-4fb7949944c7
# https://platform.goodfire.ai/chat/0a670b58-0b76-44c1-a7cb-8356347d83c9

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.6/40.6 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gensim 4.3.3 requires scipy<1.14.0,>=1.7.0, but you have scipy 1.15.0 which is incompatible.[0m[31m


# Setup 2: Load dataset

In [2]:
#####################################
# Cell 1: Load and Prepare Dataset
#####################################

from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm.auto import tqdm

# Load all datasets
dataset = load_dataset("anishthalamati/nyt-connections")
train_set = dataset['train']
validation_set = dataset['validation']
test_set = dataset['test']

# Print sample from each dataset to verify structure
print("\nTrain set sample:")
print(train_set[0])
print(f"Train set size: {len(train_set)}")

print("\nValidation set sample:")
print(validation_set[0])
print(f"Validation set size: {len(validation_set)}")

print("\nTest set sample:")
print(test_set[0])
print(f"Test set size: {len(test_set)}")

README.md:   0%|          | 0.00/507 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/28.5k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/5.00k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/732 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/91 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/92 [00:00<?, ? examples/s]


Train set sample:
{'label': 'MUSICAL INSTRUMENTS', 'text': 'BASS BASSOON HARP RECORDER'}
Train set size: 732

Validation set sample:
{'label': 'ANIMAL GROUP NAMES', 'text': 'COLONY HERD PRIDE SWARM'}
Validation set size: 91

Test set sample:
{'label': 'HATS', 'text': 'BERET BOWLER FEDORA FEZ'}
Test set size: 92


In [3]:
from typing import List, Dict, Set, Optional
import random

def combine_puzzle_groups(dataset_dict, indices, split='train') -> List[Dict]:
    """
    Combine 4 consecutive groups into full 16-word puzzles
    Returns list of puzzles, each with all words and group labels
    """
    dataset = dataset_dict[split]
    puzzles = []

    # Get indices in groups of 4
    for i in range(0, len(indices), 4):
        if i + 4 <= len(indices):
            # Get 4 consecutive groups
            groups = [dataset[indices[j]] for j in range(i, i+4)]

            # Combine into one puzzle
            all_words = []
            group_labels = {}
            for group in groups:
                words = group['text'].split()
                label = group['label']
                group_labels[label] = words
                all_words.extend(words)

            # Shuffle the combined words
            random.shuffle(all_words)

            puzzles.append({
                'words': all_words,
                'groups': group_labels
            })

    return puzzles


def extract_groups_from_attempt(attempt: str) -> List[Set[str]]:
    """Try to extract word groups from model's attempt"""
    groups = []
    current_group = set()
    for line in attempt.split('\n'):
        # Look for lines with words in all caps
        words = [w for w in line.split() if w.isupper() and not w.startswith('*')]
        if words:
            current_group.update(words)
        # When we see a new group marker or reasoning section, save the current group
        if ('Group' in line or 'Correct groupings:' in line) and current_group:
            if len(current_group) == 4:  # Only keep complete groups
                groups.append(current_group)
            current_group = set()
    return groups

def grade_attempt(client, variant, attempt: str, puzzle: Dict) -> Dict:
    """Use model to grade the solution attempt"""
    correct_groups = puzzle['groups']
    attempted_groups = extract_groups_from_attempt(attempt)

    # Format the grading prompt piece by piece
    grading_prompt = "Please grade this solution attempt for a NYT Connections puzzle. The task was to group 16 words into 4 groups of 4 related words.\n\n"

    grading_prompt += "Correct groupings:\n"
    for label, words in correct_groups.items():
        grading_prompt += f"{label}: {' '.join(words)}\n"

    grading_prompt += "\nModel's attempted groupings:\n"
    for i, group in enumerate(attempted_groups):
        grading_prompt += f"Group {i+1}: {' '.join(sorted(group))}\n"

    grading_prompt += "\nPlease analyze:\n"
    grading_prompt += "1. How many groups were correctly identified? (A group is correct if all 4 words match a correct group)\n"
    grading_prompt += "2. For incorrect groups, what went wrong in the reasoning?\n"
    grading_prompt += "3. Overall score out of 10\n\n"
    grading_prompt += "Be specific but concise."

    try:
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": grading_prompt}],
            model=variant,
            stream=False,
            max_completion_tokens=500
        )
        grade = response.choices[0].message["content"]
    except Exception as e:
        grade = f"Error grading: {e}"

    # Calculate objective metrics
    correct_sets = [set(words) for words in correct_groups.values()]
    num_correct = sum(1 for group in attempted_groups if group in correct_sets)

    return {
        'grade': grade,
        'num_correct_groups': num_correct,
        'attempted_groups': attempted_groups,
        'correct_groups': correct_groups
    }

def solve_puzzle(client, variant, puzzle):
    """Attempt to solve a single puzzle with the model and grade the result"""
    prompt = format_puzzle_for_prompt(puzzle)

    print("\n=== New Puzzle Attempt ===")
    print(f"Words: {' '.join(puzzle['words'])}")
    print(f"\nModel's solution attempt:")

    try:
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=variant,
            stream=False,
            max_completion_tokens=1000
        )
        solution_attempt = response.choices[0].message["content"]
        print(solution_attempt)
    except Exception as e:
        print(f"Error getting model response: {e}")
        solution_attempt = "Error: Failed to get model response"

    print("\nCorrect groupings:")
    for label, words in puzzle['groups'].items():
        print(f"{label}: {' '.join(words)}")

    # Grade the attempt
    print("\nGrading:")
    grade_result = grade_attempt(client, variant, solution_attempt, puzzle)
    print(grade_result['grade'])

    return {
        'puzzle': puzzle,
        'attempt': solution_attempt,
        'grade': grade_result
    }

def format_puzzle_for_prompt(puzzle) -> str:
    """Format a puzzle into a clear prompt for the model"""
    word_display = ' '.join(puzzle['words'])

    return f"""Here are 16 words that form 4 different groups of 4 related words each:

{word_display}

Think step by step:
1. Look for patterns among the words
2. Form exactly 4 groups of 4 words each
3. Each word must be used exactly once

For each group you identify, clearly list the 4 words in capital letters and explain why they go together."""

# Rest of the code (combine_puzzle_groups and run_connection_solver) remains the same

def run_connection_solver(client, variant, dataset_dict, num_puzzles: int = 1, random_select: bool = False):
    """Run solver on multiple puzzles and collect results"""
    # Get indices for the full dataset
    dataset = dataset_dict['train']
    all_indices = list(range(len(dataset)))

    # Select indices either randomly or sequentially
    if random_select:
        # Need 4x num_puzzles since each puzzle needs 4 groups
        selected_indices = random.sample(all_indices, num_puzzles * 4)
    else:
        selected_indices = all_indices[:num_puzzles * 4]

    # Combine groups into full puzzles
    puzzles = combine_puzzle_groups(dataset_dict, selected_indices)

    # Solve each puzzle
    results = []
    for puzzle in puzzles:
        result = solve_puzzle(client, variant, puzzle)
        results.append(result)

    print(f"\n=== Summary ===")
    print(f"Completed {len(results)} puzzles")

    return results

# Usage:
# dataset = load_dataset("anishthalamati/nyt-connections")
# results = run_connection_solver(client, variant, dataset)

In [4]:
# Example usage:
dataset = load_dataset("anishthalamati/nyt-connections")
results = run_connection_solver(client, variant, dataset)


=== New Puzzle Attempt ===
Words: FOSTER SHOOT CORONA FLARE BUD RADIATION BASSOON BASS NURSE SPROUT HARP RECORDER LIGHT REAR RAISE BLOOM

Model's solution attempt:
Let's break down the words step by step.

First, I'll look for patterns among the words. I'll start by examining the words for any common themes or associations.

After examining the words, I've identified some potential patterns:

- Some words seem to be related to music: BASSOON, BASS, HARP, and RECORDER.
- Some words seem to be related to growth or development: FOSTER, SPROUT, BLOOM, and RAISE.
- Some words seem to be related to light or fire: FLARE, CORONA, RADIATION, and LIGHT.
- Some words seem to be related to a specific context or setting: SHOOT, REAR, NURSE, and BUD.

Now, let's form exactly 4 groups of 4 words each, using each word exactly once.

Group 1: 
**BASSOON BASS HARP RECORDER**
These words are all types of musical instruments, with the BASSOON being a woodwind, BASS being a stringed instrument, HARP being

In [15]:
def decompose_cot(client, variant, solution_attempt: str) -> Dict:
    """Break down a solution attempt into individual reasoning steps"""

    decompose_prompt = f"""Given this solution attempt for a word grouping puzzle, break it down into numbered reasoning steps.
Each step should be a single logical decision or observation. Include both the thought process and any conclusions reached.

Original solution:
{solution_attempt}

Extract and number each reasoning step, being sure to preserve the exact wording used for groups and words. Format as:
1. [reasoning step 1]
2. [reasoning step 2]
etc.

Important: Keep all original group decisions and word groupings exactly as given, just break down the reasoning into clear steps. Include only actual reasoning steps, not introductory statements like Here Are The Steps"""

    try:
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": decompose_prompt}],
            model=variant,
            stream=False,
            max_completion_tokens=2000
        )
        decomposition = response.choices[0].message["content"]

        # Extract steps into a structured format
        steps = []
        current_step = ""
        for line in decomposition.split('\n'):
            if line.strip():
                # Check if line starts with a number followed by period
                if line[0].isdigit() and '.' in line[:3]:
                    if current_step:
                        steps.append(current_step.strip())
                    current_step = line.split('.', 1)[1].strip()
                else:
                    current_step += " " + line.strip()
        if current_step:
            steps.append(current_step.strip())

        return {
            'original': solution_attempt,
            'decomposed_steps': steps,
            'raw_decomposition': decomposition
        }
    except Exception as e:
        print(f"Error in decomposition: {e}")
        return {
            'original': solution_attempt,
            'decomposed_steps': [],
            'raw_decomposition': f"Error: {str(e)}"
        }

def analyze_reasoning_patterns(puzzle_results):
    """Analyze patterns in reasoning steps across multiple solutions"""
    all_steps = []
    for result in puzzle_results:
        if result and 'decomposed_steps' in result:
            all_steps.extend(result['decomposed_steps'])

    return {
        'total_solutions': len(puzzle_results),
        'avg_steps_per_solution': len(all_steps) / len(puzzle_results) if puzzle_results else 0,
        'steps_for_analysis': all_steps
    }

def get_correct_cot(client, variant, puzzle):
    """Get and decompose the correct solution's reasoning"""
    # Format puzzle with correct answers revealed
    correct_groups = puzzle['groups']
    correct_prompt = f"""Here are the correct groupings for these words:
{' '.join(puzzle['words'])}

Correct groups:
{json.dumps(correct_groups, indent=2)}

Please explain the reasoning process that leads to these correct groupings. Break down:
1. What patterns should be noticed first
2. How the connections between words become apparent
3. Why each word belongs in its group
4. What makes these the optimal groupings"""

    try:
        correct_response = client.chat.completions.create(
            messages=[{"role": "user", "content": correct_prompt}],
            model=variant,
            stream=False,
            max_completion_tokens=3000
        )
        correct_reasoning = correct_response.choices[0].message["content"]

        # Decompose the correct reasoning
        decomposed_correct = decompose_cot(client, variant, correct_reasoning)

        return {
            'correct_cot': correct_reasoning,
            'decomposed_steps': decomposed_correct['decomposed_steps']
        }
    except Exception as e:
        print(f"Error getting correct reasoning: {e}")
        return None

def solve_puzzle(client, variant, puzzle):
    """Attempt to solve puzzle with model and analyze reasoning"""
    # Get basic puzzle solution
    prompt = format_puzzle_for_prompt(puzzle)

    try:
        print("Getting naive solution...")
        # Get initial naive solution
        naive_response = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=variant,
            stream=False,
            max_completion_tokens=1000
        )
        naive_solution = naive_response.choices[0].message["content"]
        print(f"Got naive solution of length: {len(naive_solution)}")

        print("Expanding into COT...")
        # Expand naive solution into COT
        cot_prompt = f"""Given this initial solution attempt for a word grouping puzzle, expand it into a detailed chain-of-thought reasoning:

Original solution:
{naive_solution}

Please provide a step-by-step explanation of:
1. What patterns you noticed first
2. How you formed each group
3. Why words belong together
4. What made you confident or uncertain about each grouping

Explain your reasoning thoroughly."""

        cot_response = client.chat.completions.create(
            messages=[{"role": "user", "content": cot_prompt}],
            model=variant,
            stream=False,
            max_completion_tokens=3000
        )
        cot_reasoning = cot_response.choices[0].message["content"]
        print("Got COT reasoning")

        print("Getting correct solution COT...")
        # Get correct solution COT
        correct_cot = get_correct_cot(client, variant, puzzle)
        print("Got correct COT")

        # Decompose both COTs
        print("Decomposing naive COT...")
        naive_decomposed = decompose_cot(client, variant, cot_reasoning)
        print(f"Got {len(naive_decomposed['decomposed_steps'])} naive steps")

        # Store results
        result = {
            'puzzle': puzzle,
            'naive_solution': naive_solution,
            'naive_cot': cot_reasoning,
            'naive_steps': naive_decomposed['decomposed_steps'],
            'correct_cot': correct_cot['correct_cot'] if correct_cot else None,
            'correct_steps': correct_cot['decomposed_steps'] if correct_cot else None
        }

        # Grade the attempt
        print("Grading attempt...")
        grade_result = grade_attempt(client, variant, naive_solution, puzzle)
        result['grade'] = grade_result
        print("Finished grading")

        return result

    except Exception as e:
        print(f"Error in puzzle solving: {str(e)}")
        traceback.print_exc()
        return None

def run_connection_solver(client, variant, dataset_dict, num_puzzles: int = 1, random_select: bool = False):
    """Run solver and analyze reasoning patterns"""
    # Get indices for the full dataset
    dataset = dataset_dict['train']
    all_indices = list(range(len(dataset)))

    # Select indices either randomly or sequentially
    if random_select:
        selected_indices = random.sample(all_indices, num_puzzles * 4)
    else:
        selected_indices = all_indices[:num_puzzles * 4]

    # Combine groups into full puzzles
    puzzles = combine_puzzle_groups(dataset_dict, selected_indices)

    # Solve each puzzle
    results = []
    for puzzle in puzzles:
        print(f"\nProcessing new puzzle...")
        result = solve_puzzle(client, variant, puzzle)
        if result:
            results.append(result)

    print(f"\n=== Completed {len(results)} puzzles ===")

    # Print detailed results for each puzzle
    for i, result in enumerate(results):
        try:
            print(f"\n=== Puzzle {i+1} Details ===")
            print("\nNaive Solution Analysis:")
            if isinstance(result, dict):
                if 'naive_solution' in result:
                    print(f"Original solution: {result['naive_solution'][:100]}...")
                if 'naive_cot' in result:
                    print(f"Naive COT reasoning: {result['naive_cot'][:100]}...")
                if 'naive_steps' in result:
                    print(f"Naive decomposed steps: {len(result['naive_steps'])} steps")
                    for j, step in enumerate(result['naive_steps'], 1):
                        print(f"{j}. {step}")

                print("\nCorrect Solution Analysis:")
                if result.get('correct_cot'):
                    print(f"Correct COT reasoning: {result['correct_cot'][:100]}...")
                    print(f"Correct decomposed steps: {len(result['correct_steps'])} steps")
                    for j, step in enumerate(result['correct_steps'], 1):
                        print(f"{j}. {step}")
                else:
                    print("No correct solution analysis available")

                if 'grade' in result:
                    print("\nGrade:")
                    print(result['grade'].get('grade', 'No grade available'))
                    print(f"Correct groups: {result['grade'].get('num_correct_groups', '?')}/4")
            else:
                print(f"Invalid result format: {type(result)}")
        except Exception as e:
            print(f"Error printing results: {str(e)}")
            traceback.print_exc()

    return {
        'puzzle_results': results,
        'reasoning_analysis': analyze_reasoning_patterns(results)
    }

# Add import
import traceback

In [17]:
# Load dataset and initialize model
from datasets import load_dataset
dataset = load_dataset("anishthalamati/nyt-connections")

# Run solver with COT analysis
results = run_connection_solver(client, variant, dataset, num_puzzles=2)

# Print overall results
print("\n=== Overall Analysis ===")
for i, puzzle_result in enumerate(results['puzzle_results'], 1):
    print(f"\n=== Puzzle {i} Details ===")
    print(f"Grade: {puzzle_result['grade']['grade']}")
    print(f"Correct groups: {puzzle_result['grade']['num_correct_groups']}/4")
    print("\nNaive Reasoning steps:")
    for j, step in enumerate(puzzle_result['naive_steps'], 1):
        print(f"{j}. {step}")
    print("\nCorrect Reasoning steps:")
    if puzzle_result.get('correct_steps'):
        for j, step in enumerate(puzzle_result['correct_steps'], 1):
            print(f"{j}. {step}")
    else:
        print("No correct reasoning steps available")

print("\n=== Overall Analysis ===")
print(f"Total puzzles solved: {results['reasoning_analysis']['total_solutions']}")
print(f"Average steps per solution: {results['reasoning_analysis']['avg_steps_per_solution']:.1f}")


Processing new puzzle...
Getting naive solution...
Got naive solution of length: 1661
Expanding into COT...
Got COT reasoning
Getting correct solution COT...
Got correct COT
Decomposing naive COT...
Got 11 naive steps
Grading attempt...
Finished grading

Processing new puzzle...
Getting naive solution...
Got naive solution of length: 1665
Expanding into COT...
Got COT reasoning
Getting correct solution COT...
Got correct COT
Decomposing naive COT...
Got 11 naive steps
Grading attempt...
Finished grading

=== Completed 2 puzzles ===

=== Puzzle 1 Details ===

Naive Solution Analysis:
Original solution: Let's break down the words step by step.

First, I'll look for patterns among the words. I'll start ...
Naive COT reasoning: I'd be happy to guide you through the thought process behind the original solution.

**Step 1: Initi...
Naive decomposed steps: 11 steps
1. Here are the extracted reasoning steps:
2. To start forming groups, I examined the meanings of each word. I looked at the den

# Run Autosteer

In [None]:
#####################################
# Part 1: Skill Discovery
#####################################

class ExampleManager:
    """Manages and filters training examples"""

    def __init__(self, train_set, validation_set, test_set):
        self.train_set = train_set
        self.validation_set = validation_set
        self.test_set = test_set
        print(f"\nInitialized ExampleManager with:")
        print(f"- {len(train_set)} training examples")
        print(f"- {len(validation_set)} validation examples")
        print(f"- {len(test_set)} test examples")

        # Debug: Print structure of first example
        print("\nExample data structure:")
        example = train_set[0]
        print(json.dumps(example, indent=2))

    def prepare_examples(self, dataset, max_examples: int = 100) -> List[Example]:
        """Convert dataset examples into Example objects"""
        print(f"\nPreparing up to {max_examples} examples...")
        examples = []

        # Convert to list to use slicing
        dataset_subset = dataset.select(range(min(max_examples, len(dataset))))

        for i, item in enumerate(tqdm(dataset_subset, desc="Converting examples")):
            # Print first few examples to verify conversion
            if i < 3:
                print(f"\nExample {i + 1}:")
                print(f"Words: {item['text']}")
                print(f"Category: {item['label']}")

            example = Example(
                query=f"puzzle: Consider these words and explain how they might be related: {item['text']}",
                response=f"These words belong to the category: {item['label']}",
                is_correct=True
            )
            examples.append(example)

            if i == 3:
                print("\n... (continuing conversion)")

        print(f"\nPrepared {len(examples)} examples total")
        return examples

class SkillDiscoverer:
    """Discovers reasoning skills using AutoSteer"""

    def __init__(self, client, variant):
        self.client = client
        self.variant = variant
        self.discovered_skills = None
        self.activation_logs = {}

    async def discover_skills(self, correct_examples: List[Example], incorrect_examples: List[Example], top_k: int = 5):
        """Use AutoSteer to discover skills that distinguish correct from incorrect examples"""
        print("\n=== Skill Discovery and Analysis ===")

        # Convert examples to message format
        print("\nFormatting examples for contrast...")
        print("\nCorrect examples:")
        correct_messages = []
        for i, ex in enumerate(correct_examples):
            messages = [
                {"role": "user", "content": ex.query},
                {"role": "assistant", "content": ex.response}
            ]
            correct_messages.append(messages)
            if i < 3:  # Show first 3 examples
                print(f"\nExample {i+1}:")
                print(f"Query: {ex.query}")
                print(f"Response: {ex.response}")

        print("\nIncorrect examples:")
        incorrect_messages = []
        for i, ex in enumerate(incorrect_examples):
            messages = [
                {"role": "user", "content": ex.query},
                {"role": "assistant", "content": ex.response}
            ]
            incorrect_messages.append(messages)
            if i < 3:  # Show first 3 examples
                print(f"\nExample {i+1}:")
                print(f"Query: {ex.query}")
                print(f"Response: {ex.response}")

        # Original skill discovery
        print(f"\nFinding contrasting features between {len(correct_messages)} correct and {len(incorrect_messages)} incorrect examples...")
        _, skill_features = self.client.features.contrast(
            dataset_1=incorrect_messages,
            dataset_2=correct_messages,
            model=self.variant,
            top_k=top_k*2
        )
        print(f"\nFound {len(skill_features)} initial features:")
        for i, feat in enumerate(skill_features):
            print(f"{i+1}. {feat.label}")

        # Rerank for task relevance
        print("\nReranking features...")
        self.discovered_skills = self.client.features.rerank(
            features=skill_features,
            query="pattern recognition and categorization for word puzzles",
            model=self.variant,
            top_k=top_k
        )

        # Analyze example activations
        print("\n=== Feature Activation Analysis ===")
        for i, ex in enumerate(tqdm(correct_examples + incorrect_examples, desc="Analyzing examples")):
            context = self.client.features.inspect(
                messages=[
                    {"role": "user", "content": ex.query},
                    {"role": "assistant", "content": ex.response}
                ],
                model=self.variant,
                features=self.discovered_skills
            )

            activations = context.top(k=len(self.discovered_skills))

            # Store activation data
            self.activation_logs[i] = {
                'query': ex.query,
                'response': ex.response,
                'is_correct': i < len(correct_examples),
                'activations': [(act.feature.label, float(act.activation)) for act in activations]
            }

            # Print analysis for this example
            print(f"\nExample {i+1}:")
            print(f"Query: {ex.query[:100]}...")
            print(f"{'Correct' if i < len(correct_examples) else 'Incorrect'} example")
            print("Top Feature Activations:")
            for feat, strength in sorted(self.activation_logs[i]['activations'],
                                      key=lambda x: abs(x[1]),
                                      reverse=True):
                print(f"  {feat}: {strength:.3f}")

        # Calculate feature statistics
        print("\n=== Feature Importance Summary ===")
        feature_stats = {}
        for feat in self.discovered_skills:
            activations = []
            correct_activations = []
            incorrect_activations = []

            for log in self.activation_logs.values():
                for f, strength in log['activations']:
                    if f == feat.label:
                        activations.append(strength)
                        if log['is_correct']:
                            correct_activations.append(strength)
                        else:
                            incorrect_activations.append(strength)

            avg_activation = np.mean(activations) if activations else 0
            avg_correct = np.mean(correct_activations) if correct_activations else 0
            avg_incorrect = np.mean(incorrect_activations) if incorrect_activations else 0

            print(f"\nFeature: {feat.label}")
            print(f"Average activation: {avg_activation:.3f}")
            print(f"Average for correct examples: {avg_correct:.3f}")
            print(f"Average for incorrect examples: {avg_incorrect:.3f}")
            print(f"Activation difference: {(avg_correct - avg_incorrect):.3f}")

        return self.discovered_skills


#####################################
# Main Pipeline
#####################################

async def run_skill_discovery(num_samples: int = 20):
    """Run the skill discovery part and return both skills and the discoverer instance"""
    print(f"\n=== Starting Skill Discovery Pipeline (using {num_samples} samples) ===")

    # Initialize components
    print("\nInitializing components...")
    example_manager = ExampleManager(train_set, validation_set, test_set)
    skill_discoverer = SkillDiscoverer(client, variant)

    # Prepare training examples
    train_examples = example_manager.prepare_examples(train_set, max_examples=num_samples)

    # Split examples
    mid_point = len(train_examples) // 2
    correct_examples = train_examples[:mid_point]
    incorrect_examples = train_examples[mid_point:]

    print(f"\nSplit examples into {len(correct_examples)} correct and {len(incorrect_examples)} incorrect")

    # Discover skills
    skills = await skill_discoverer.discover_skills(correct_examples, incorrect_examples)

    print("\nDiscovered Skills:")
    for i, skill in enumerate(skills, 1):
        print(f"{i}. {skill}")
        print(f"   Label: {skill.label}")
        print(f"   UUID: {skill.uuid}")

    print("\n=== Skill Discovery Complete ===")

    return skills, skill_discoverer  # Return both

# Run first step
if __name__ == "__main__":
    num_samples = 20  # Adjust this number as needed
    discovered_skills, skill_discoverer = await run_skill_discovery(num_samples)


In [None]:
#####################################
# Step 3: Visualization and Analysis
#####################################

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display, HTML

def create_analysis_dataframe(activation_logs):
    """Convert activation logs to a pandas DataFrame for analysis"""
    rows = []

    for example_id, log in activation_logs.items():
        # For each feature activation in this example
        for feature, activation in log['activations']:
            rows.append({
                'example_id': example_id,
                'query': log['query'],
                'response': log['response'],
                'is_correct': log['is_correct'],
                'feature': feature,
                'activation': activation
            })

    return pd.DataFrame(rows)

def visualize_results(skill_discoverer):
    """Create visualizations and interactive tables from results"""

    # Convert to DataFrame
    df = create_analysis_dataframe(skill_discoverer.activation_logs)

    # Save to CSV
    df.to_csv('feature_analysis.csv', index=False)
    print("Saved detailed results to feature_analysis.csv")

    # 1. Feature Activation Heatmap
    pivot_df = df.pivot_table(
        values='activation',
        index='example_id',
        columns='feature',
        aggfunc='first'
    )

    fig1 = px.imshow(
        pivot_df,
        title='Feature Activation Heatmap',
        labels=dict(x='Feature', y='Example ID', color='Activation Strength'),
        aspect='auto'
    )
    fig1.show()

    # 2. Feature Importance Bar Chart
    avg_activations = df.groupby('feature')['activation'].agg(['mean', 'std']).reset_index()
    fig2 = px.bar(
        avg_activations,
        x='feature',
        y='mean',
        error_y='std',
        title='Average Feature Activation Strength',
        labels={'mean': 'Average Activation', 'feature': 'Feature'}
    )
    fig2.update_layout(xaxis_tickangle=45)
    fig2.show()

    # 3. Interactive Table
    def generate_interactive_table(df):
        # Create unique example IDs for filtering
        examples = df[['example_id', 'query', 'response', 'is_correct']].drop_duplicates()
        features = df['feature'].unique()

        # Create HTML for filtering controls
        filter_html = f"""
        <div style="margin-bottom: 20px;">
            <h3>Filters:</h3>
            <select id="example_filter" onchange="filterTable()">
                <option value="all">All Examples</option>
                {''.join(f'<option value="{i}">Example {i+1}</option>' for i in examples['example_id'])}
            </select>

            <select id="feature_filter" onchange="filterTable()">
                <option value="all">All Features</option>
                {''.join(f'<option value="{f}">{f}</option>' for f in features)}
            </select>

            <select id="correct_filter" onchange="filterTable()">
                <option value="all">All Results</option>
                <option value="correct">Correct Only</option>
                <option value="incorrect">Incorrect Only</option>
            </select>
        </div>
        """

        # Create table
        table_html = df.to_html(classes='data-table', index=False)

        # Add styling and JavaScript
        return f"""
        <style>
            .data-table {{
                width: 100%;
                border-collapse: collapse;
            }}
            .data-table th, .data-table td {{
                padding: 8px;
                border: 1px solid #ddd;
            }}
            .data-table tr:nth-child(even) {{
                background-color: #f9f9f9;
            }}
            .filter-controls {{
                margin-bottom: 20px;
            }}
            select {{
                margin-right: 10px;
                padding: 5px;
            }}
        </style>
        {filter_html}
        {table_html}
        <script>
        function filterTable() {{
            var example = document.getElementById('example_filter').value;
            var feature = document.getElementById('feature_filter').value;
            var correct = document.getElementById('correct_filter').value;

            var rows = document.querySelectorAll('.data-table tbody tr');
            rows.forEach(function(row) {{
                var showRow = true;

                if (example !== 'all' && row.cells[0].textContent !== example) showRow = false;
                if (feature !== 'all' && row.cells[4].textContent !== feature) showRow = false;
                if (correct !== 'all') {{
                    var isCorrect = row.cells[3].textContent === 'True';
                    if (correct === 'correct' && !isCorrect) showRow = false;
                    if (correct === 'incorrect' && isCorrect) showRow = false;
                }}

                row.style.display = showRow ? '' : 'none';
            }});
        }}
        </script>
        """

    # Display interactive table
    display(HTML(generate_interactive_table(df)))

# Usage (add to main):
print("\nGenerating visualizations...")
visualize_results(skill_discoverer)