<a href="https://colab.research.google.com/github/menhguin/natural_language_rl/blob/main/Copy_of_Natural_Language_RL_Common_solver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Step 1: Install

In [None]:
"""NLRL_with_Goodfire.ipynb

Natural Language Reinforcement Learning using Goodfire SDK
Following the steps outlined in the paper:
1. Query Selection & Answer Verification
2. COT Decomposition & Critical Token Identification
3. Feature Attribution & Analysis
4. Feature Steering
5. Robustness Testing & Iteration
"""

#####################################
# Step 0: Setup and Initialization
#####################################

# Install required packages
!pip install goodfire --quiet
!pip install datasets --quiet
!pip install tqdm --quiet

from google.colab import userdata
import goodfire
import asyncio
from tqdm.auto import tqdm
import pandas as pd
from typing import List, Dict, Any
import json
from typing import List
from dataclasses import dataclass

@dataclass
class Example:
    """Record pairing a query with a model response and a correctness flag.

    NOTE(review): not referenced anywhere else in the visible notebook.
    """

    query: str  # input that was sent to the model
    response: str  # text the model returned
    is_correct: bool  # whether that response was judged correct

# Add your Goodfire API Key to your Colab secrets
GOODFIRE_API_KEY = userdata.get('GOODFIRE_API_KEY')


# Initialize SDK clients.
# `client` (synchronous) is used for features.search / rerank / AutoSteer;
# `async_client` is used for chat completions and features.inspect.
client = goodfire.Client(GOODFIRE_API_KEY)
async_client = goodfire.AsyncClient(GOODFIRE_API_KEY)

# Instantiate a model variant.
# This one global object is shared by every cell and is mutated in place via
# variant.reset() / variant.set(...) by the steering code, so steering applied
# by one cell presumably persists on it until something calls reset().
variant = goodfire.Variant("meta-llama/Meta-Llama-3.1-8B-Instruct")

# for my reference: Which is bigger, 9.9 or 9.11?
# https://platform.goodfire.ai/chat/ab5cc226-aa37-404c-87fd-4fb7949944c7
# https://platform.goodfire.ai/chat/0a670b58-0b76-44c1-a7cb-8356347d83c9

# Step 2: Load dataset

In [None]:
#####################################
# Cell 1: Load and Prepare Dataset
#####################################

from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm.auto import tqdm

# Load the NYT Connections dataset and expose its three splits as globals
# used by every later cell.
dataset = load_dataset("anishthalamati/nyt-connections")
train_set, validation_set, test_set = (
    dataset['train'],
    dataset['validation'],
    dataset['test'],
)

# Show the first row and size of each split to confirm the expected
# {'label': ..., 'text': ...} structure before anything depends on it.
for _split_name, _split_rows in (
    ("Train", train_set),
    ("Validation", validation_set),
    ("Test", test_set),
):
    print(f"\n{_split_name} set sample:")
    print(_split_rows[0])
    print(f"{_split_name} set size: {len(_split_rows)}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/507 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/28.5k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/5.00k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/732 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/91 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/92 [00:00<?, ? examples/s]


Train set sample:
{'label': 'MUSICAL INSTRUMENTS', 'text': 'BASS BASSOON HARP RECORDER'}
Train set size: 732

Validation set sample:
{'label': 'ANIMAL GROUP NAMES', 'text': 'COLONY HERD PRIDE SWARM'}
Validation set size: 91

Test set sample:
{'label': 'HATS', 'text': 'BERET BOWLER FEDORA FEZ'}
Test set size: 92


# Test 2 (Steps 2-3): Find features shared by correct chain-of-thought solutions, then steer them with AutoSteer

In [None]:
#####################################
# Cell 2: Chain of Thought Analysis
#####################################

import time
from collections import defaultdict, deque
from datetime import datetime, timedelta

class RateLimitHandler:
    """Client-side sliding-window limiter for API calls.

    Keeps the timestamps of requests made in the last minute. When the count
    reaches ``requests_per_minute``, ``wait_if_needed`` sleeps until the oldest
    one leaves the window, re-checks, and only then records the new request.
    """

    def __init__(self, requests_per_minute=50):  # Adjust the limit as needed
        if requests_per_minute < 1:
            raise ValueError("requests_per_minute must be at least 1")
        self.requests_per_minute = requests_per_minute
        self.request_times = deque()  # datetimes of recent requests, oldest first

    async def wait_if_needed(self):
        """Block until one more request fits in the 1-minute window, then record it."""
        window = timedelta(minutes=1)
        while True:
            now = datetime.now()

            # Drop requests that have left the 1-minute window.
            while self.request_times and self.request_times[0] <= now - window:
                self.request_times.popleft()

            if len(self.request_times) < self.requests_per_minute:
                break

            # Sleep exactly until the oldest request expires. total_seconds()
            # keeps the fractional part (timedelta.seconds truncates it), and
            # looping re-reads the clock so the timestamp recorded below is the
            # real send time rather than the pre-sleep one.
            wait_time = (self.request_times[0] + window - now).total_seconds()
            print(f"Rate limit approaching, waiting {wait_time:.1f} seconds...")
            await asyncio.sleep(wait_time)

        self.request_times.append(now)

# One shared limiter instance; the API helpers in this cell await
# rate_limiter.wait_if_needed() before issuing their requests.
rate_limiter = RateLimitHandler(requests_per_minute=50)
from goodfire.api.exceptions import RateLimitException

async def extract_final_guess(response: str, model_variant) -> str:
    """Ask the model to restate the final category proposed in a CoT answer.

    Args:
        response: Full chain-of-thought text produced for one puzzle.
        model_variant: Model/variant used for the extraction call.

    Returns:
        The model's short answer with surrounding whitespace stripped, or
        ``response`` unchanged if the request (or parsing its reply) fails.
    """
    prompt = f"""From this word connection puzzle solution, what is the final category or connection that was guessed?

    Solution:
    {response}

    Extract just the final category/connection that was proposed. Be brief."""
    chat_messages = [{"role": "user", "content": prompt}]

    try:
        await rate_limiter.wait_if_needed()
        completion = await async_client.chat.completions.create(
            messages=chat_messages,
            model=model_variant,
            max_completion_tokens=50,
        )
        guess = completion.choices[0].message["content"].strip()
    except Exception as err:
        print(f"Error extracting final guess: {str(err)}")
        guess = response  # fall back to the whole response
    return guess

async def check_answer_similarity(response: str, category: str, model_variant) -> bool:
    """Grade a CoT answer against the true category using the model as a lenient judge.

    The final guess is first extracted from ``response`` (see
    ``extract_final_guess``), then the model is asked for a yes/no verdict.

    Args:
        response: Full chain-of-thought answer to grade.
        category: Ground-truth category label.
        model_variant: Model/variant used for both extraction and judging.

    Returns:
        True if the judge's reply (lower-cased, stripped) starts with "yes";
        False otherwise or on any error.
    """
    final_guess = await extract_final_guess(response, model_variant)

    prompt = f"""Consider these two answers to a word connection puzzle where words need to be grouped by a common theme:

    Correct Category: {category}
    Model's Guess: {final_guess}

    Questions to consider:
    1. Does the model's guess capture the main concept of the correct category?
    2. Would the model's categorization lead to the same group of words?
    3. Is the guess essentially correct even if expressed differently?

    Based on these criteria, is the model's answer correct? Be lenient and accept answers that capture the core concept.
    Reply with just 'yes' or 'no'."""
    chat_messages = [{"role": "user", "content": prompt}]

    try:
        await rate_limiter.wait_if_needed()
        judgement = await async_client.chat.completions.create(
            messages=chat_messages,
            model=model_variant,
            max_completion_tokens=10,
        )
        reply = judgement.choices[0].message["content"].lower().strip()
        verdict = reply.startswith('yes')
    except Exception as err:
        print(f"Error checking answer similarity: {str(err)}")
        verdict = False  # treat failures as incorrect
    return verdict

async def check_feature_relevance(feature, model_variant) -> bool:
    """Ask the model whether a feature is a general reasoning feature for word puzzles.

    Args:
        feature: Feature object; only its ``label`` is placed in the prompt.
        model_variant: Model/variant that answers the yes/no question.

    Returns:
        True if the reply (lower-cased, stripped) starts with "yes"; False
        otherwise or when the API call fails.
    """
    prompt = f"""Consider this feature detected in an AI model: "{feature.label}"

    For solving word connection puzzles where we group words by common themes:
    1. Is this feature about pattern recognition, categorization, or understanding relationships?
    2. Is this a general reasoning/explanation feature that helps analyze connections?
    3. Is this feature about a specific topic (like anatomy, clothing, etc.) rather than general analysis?

    Examples of relevant features:
    - Features about finding patterns or relationships
    - Features about analyzing or explaining connections
    - Features about general reasoning and categorization
    - Features about understanding word meanings and associations

    Examples of irrelevant features:
    - Features specific to particular topics (e.g., "knowledge about hats")
    - Features about specific text formats or structures
    - Features about physical descriptions or properties
    - Features focused on particular domains of knowledge

    Based on these criteria, is this feature relevant for solving word connection puzzles?
    Reply with just 'yes' or 'no'."""

    try:
        await rate_limiter.wait_if_needed()
        check_response = await async_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model_variant,
            max_completion_tokens=10
        )
        answer = check_response.choices[0].message["content"].lower().strip()
        return answer.startswith('yes')
    except Exception as e:
        print(f"Error checking feature relevance: {str(e)}")
        return False  # Default to false on error

async def filter_relevant_features(features: List[Any], model_variant) -> List[Any]:
    """Keep only the features the model judges relevant to word categorization.

    Args:
        features: Candidate feature objects.
        model_variant: Model/variant used by ``check_feature_relevance``.

    Returns:
        The subset of ``features`` (original order) judged relevant.
    """
    relevant_features = []

    print("\nChecking feature relevance...")
    for feat in tqdm(features):
        if await check_feature_relevance(feat, model_variant):
            relevant_features.append(feat)

    return relevant_features


class FeatureTracker:
    """Accumulate per-feature usage, activation and success counts across puzzles.

    Features are keyed by ``feature.uuid``; a "success" is an appearance in the
    top features of a solution that was judged correct.
    """

    def __init__(self):
        self.features = {}  # uuid -> feature object (kept for labels / steering)
        self.feature_counts = defaultdict(int)  # uuid -> number of appearances
        self.feature_activations = defaultdict(list)  # uuid -> activation per appearance
        self.feature_success = defaultdict(int)  # uuid -> appearances on correct solutions

    def add_feature(self, feature, activation, success=False):
        """Record one appearance of ``feature`` with the given activation."""
        self.features[feature.uuid] = feature  # Store the full feature object
        self.feature_counts[feature.uuid] += 1
        self.feature_activations[feature.uuid].append(activation)
        if success:
            self.feature_success[feature.uuid] += 1

    def get_stats(self):
        """Return per-feature summaries sorted by success rate, highest first.

        Each entry has ``feature``, ``count``, ``avg_activation`` and
        ``success_rate``; ties keep insertion order (``sorted`` is stable).
        """
        stats = []
        for uuid in self.feature_counts:
            activations = self.feature_activations[uuid]
            feature = self.features[uuid]
            stats.append({
                'feature': feature,
                'count': self.feature_counts[uuid],
                'avg_activation': sum(activations) / len(activations),
                'success_rate': self.feature_success[uuid] / self.feature_counts[uuid] if self.feature_counts[uuid] > 0 else 0
            })
        return sorted(stats, key=lambda x: x['success_rate'], reverse=True)

async def get_cot_solution(words: List[str], model_variant) -> str:
    """Ask the model for a step-by-step (chain-of-thought) solution to one puzzle.

    Args:
        words: The puzzle words, joined with ", " into the prompt.
        model_variant: Model/variant that generates the answer (600-token cap).

    Returns:
        The raw assistant message content.

    Raises:
        RateLimitException: if the third attempt is still rate-limited. Between
            attempts it sleeps 60 s, then 120 s.
    """
    attempts_allowed = 3
    backoff_seconds = 60  # long enough for the next minute's quota
    attempt_no = 0

    while True:
        attempt_no += 1
        try:
            await rate_limiter.wait_if_needed()
            prompt = f"""puzzle: Consider these words and explain step by step how they might be related: {', '.join(words)}
            1. First, look for obvious patterns or relationships
            2. Consider different categories like: synonyms, parts of larger things, elements of a phrase
            3. Explain your thinking process
            4. Make a final guess about what category connects these words"""

            completion = await async_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=model_variant,
                max_completion_tokens=600,
            )
            return completion.choices[0].message["content"]
        except RateLimitException:
            if attempt_no >= attempts_allowed:
                raise
            print(f"Rate limit hit, waiting {backoff_seconds} seconds...")
            await asyncio.sleep(backoff_seconds)
            backoff_seconds *= 2

async def analyze_cot_features(cot_response: str, model_variant, top_k=10) -> List[Any]:
    """Return the ``top_k`` most active features on a chain-of-thought response.

    The response is inspected as a single assistant message. Like the other
    API helpers in this cell, every attempt first goes through the shared
    ``rate_limiter``, so inspect calls count against the same per-minute budget
    instead of bypassing it.

    Raises:
        RateLimitException: if the third attempt is still rate-limited. Between
            attempts it sleeps 2 s, then 4 s.
    """
    max_retries = 3
    retry_delay = 2

    for attempt in range(max_retries):
        try:
            await rate_limiter.wait_if_needed()
            context = await async_client.features.inspect(
                messages=[{
                    "role": "assistant",
                    "content": cot_response
                }],
                model=model_variant,
                features=None
            )
            return context.top(k=top_k)
        except RateLimitException:
            if attempt < max_retries - 1:
                print(f"Rate limit hit, waiting {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)
                retry_delay *= 2
            else:
                raise

async def analyze_dataset_with_cot(dataset_subset, num_samples=30):
    """Baseline pass: solve puzzles with CoT, inspect features, grade the answers.

    For each of the first ``num_samples`` rows, the pass does four things:
    - generates a CoT answer with the global ``variant``;
    - takes the top activating features of that answer;
    - grades it with ``check_answer_similarity``;
    - records every feature in a ``FeatureTracker`` together with the verdict.
    
    Failures on a single row are printed and skipped.

    Returns:
        ``(results, feature_tracker)``. ``results`` holds one dict per
        processed row, with keys ``words``, ``category``, ``cot_response``,
        ``features`` and ``is_correct``.
    """
    results = []
    feature_tracker = FeatureTracker()
    row_limit = min(num_samples, len(dataset_subset))

    for row_idx in tqdm(range(row_limit)):
        try:
            row = dataset_subset[row_idx]
            words = row['text'].split()
            category = row['label']

            cot_response = await get_cot_solution(words, variant)
            top_acts = await analyze_cot_features(cot_response, variant)
            is_correct = await check_answer_similarity(cot_response, category, variant)

            for act in top_acts:
                feature_tracker.add_feature(act.feature, act.activation, is_correct)

            results.append({
                'words': words,
                'category': category,
                'cot_response': cot_response,
                'features': [(act.feature, act.activation) for act in top_acts],
                'is_correct': is_correct,
            })

            verdict_mark = "✓" if is_correct else "✗"
            print(f"\nAnalyzing group {row_idx + 1}:")
            print(f"Words: {words}")
            print(f"Category: {category}")
            print("Solution correct:", verdict_mark)
            print("\nChain of Thought:")
            print(cot_response)
            print("\nTop Features:")
            for act in top_acts:
                print(f"- {act.feature.label}: {act.activation:.3f}")
            print("-" * 50)
        except Exception as err:
            print(f"Error processing example {row_idx}: {str(err)}")

    # Summary of the ten features with the highest success rate.
    print("\n=== Feature Statistics ===")
    for summary in feature_tracker.get_stats()[:10]:
        print(f"\nFeature: {summary['feature'].label}")
        print(f"Used {summary['count']} times")
        print(f"Average activation: {summary['avg_activation']:.3f}")
        print(f"Success rate: {summary['success_rate']:.2%}")

    return results, feature_tracker

#####################################
# Cell 3: Feature Selection and Steering
#####################################

async def identify_successful_features(feature_tracker, cot_analysis_results):
    """Pick features that co-occur with correct baseline answers, then rerank them.

    Args:
        feature_tracker: ``FeatureTracker`` filled by ``analyze_dataset_with_cot``.
        cot_analysis_results: Per-example baseline results. They are accepted
            for interface symmetry but not used here.

    Returns:
        The result of ``client.features.rerank`` (top_k=10) over features whose
        success rate exceeds 0.5, ranked against "pattern recognition and
        category identification". Returns ``[]`` when no feature qualifies.

    Raises:
        RateLimitException: if the third rerank attempt is still rate-limited.
    """
    candidates = [
        summary['feature']
        for summary in feature_tracker.get_stats()
        if summary['success_rate'] > 0.5
    ]
    if not candidates:
        return []

    backoff = 2
    for attempt_no in range(1, 4):
        try:
            return client.features.rerank(
                features=goodfire.FeatureGroup(candidates),
                query="pattern recognition and category identification",
                model=variant,
                top_k=10
            )
        except RateLimitException:
            if attempt_no == 3:
                raise
            print(f"Rate limit hit, waiting {backoff} seconds...")
            await asyncio.sleep(backoff)
            backoff *= 2

    return []

async def get_optimal_feature_values(successful_features):
    """Map each selected feature to a steering value, preferring AutoSteer's.

    AutoSteer is asked for edits matching a "solve word categorization
    puzzles" specification. A feature present in those edits gets AutoSteer's
    value; any other feature gets 0.5.

    Returns:
        Dict of feature -> steering value.

    Raises:
        RateLimitException: if the third AutoSteer attempt is still rate-limited
            (sleeps 2 s, then 4 s, between attempts).
    """
    backoff = 2
    attempts_remaining = 3

    while attempts_remaining:
        attempts_remaining -= 1
        try:
            edits = client.features.AutoSteer(
                specification="solve word categorization puzzles by identifying patterns and relationships",
                model=variant
            )
            return {
                feature: (edits[feature] if feature in edits else 0.5)
                for feature in successful_features
            }
        except RateLimitException:
            if not attempts_remaining:
                raise
            print(f"Rate limit hit, waiting {backoff} seconds...")
            await asyncio.sleep(backoff)
            backoff *= 2

async def test_enhanced_solver(words: List[str], category: str, feature_values: Dict,
                               max_completion_tokens: int = 600):
    """Solve one puzzle with the steered model, then grade it with the unsteered model.

    Args:
        words: The puzzle words.
        category: Ground-truth category label.
        feature_values: Mapping of feature -> steering value. It is applied via
            ``variant.set`` for the generation step only.
        max_completion_tokens: Generation budget. It defaults to 600, the same
            budget ``get_cot_solution`` gives the baseline, so steered and
            baseline answers are compared on equal terms.

    Returns:
        ``{'solution': str, 'correct': bool}``.

    Raises:
        RateLimitException: if the third attempt is still rate-limited (sleeps
            2 s, then 4 s, between attempts).

    The shared global ``variant`` is reset right after generation, even if
    generation fails. As a result, ``check_answer_similarity`` judges with the
    same unsteered model used to grade the baseline, and no steering leaks
    into later cells.
    """
    max_retries = 3
    retry_delay = 2

    # Same prompt as get_cot_solution, so steering is the only difference from the baseline.
    prompt = f"""puzzle: Consider these words and explain step by step how they might be related: {', '.join(words)}
            1. First, look for obvious patterns or relationships
            2. Consider different categories like: synonyms, parts of larger things, elements of a phrase
            3. Explain your thinking process
            4. Make a final guess about what category connects these words"""

    for attempt in range(max_retries):
        try:
            await rate_limiter.wait_if_needed()
            variant.reset()
            variant.set(feature_values)
            try:
                response = await async_client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=variant,
                    max_completion_tokens=max_completion_tokens
                )
            finally:
                # Clear steering before grading and before anything else uses `variant`.
                variant.reset()

            solution_text = response.choices[0].message["content"]
            is_correct = await check_answer_similarity(solution_text, category, variant)
            return {
                'solution': solution_text,
                'correct': is_correct
            }

        except RateLimitException:
            if attempt < max_retries - 1:
                print(f"Rate limit hit, waiting {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)
                retry_delay *= 2
            else:
                raise

# Main execution flow for cells 2 and 3
async def run_enhanced_solver(num_samples=30):
    """Run the full pipeline: baseline analysis, feature selection, steering, evaluation.

    Args:
        num_samples: Number of ``test_set`` puzzles to use, for both the baseline
            chain-of-thought analysis and the steered evaluation. Capped at
            ``len(test_set)``.

    Returns:
        ``(results, feature_values, feature_tracker)``. ``results`` holds one
        ``{'solution', 'correct'}`` dict per successfully evaluated puzzle.

    NOTE(review): features are selected on the same test_set rows they are
    evaluated on, so the reported accuracy is optimistic. Consider selecting
    on train_set / validation_set instead.
    """
    n_eval = min(num_samples, len(test_set))

    # First, analyze examples with chain of thought (same sample count as evaluation).
    print("Starting Chain of Thought analysis...")
    cot_analysis, feature_tracker = await analyze_dataset_with_cot(test_set, num_samples=num_samples)

    print("\nIdentifying successful features...")
    successful_features = await identify_successful_features(feature_tracker, cot_analysis)

    # With no candidate features there is nothing to steer; skip the AutoSteer request.
    print("\nDetermining optimal feature values...")
    feature_values = await get_optimal_feature_values(successful_features) if successful_features else {}

    print("\nTesting enhanced solver...")
    results = []
    for i in range(n_eval):
        try:
            example = test_set[i]  # Get actual dataset row
            result = await test_enhanced_solver(
                example['text'].split(),
                example['label'],
                feature_values
            )
            results.append(result)

            print(f"\nPuzzle words: {example['text']}")
            print(f"True category: {example['label']}")
            print(f"Solution: {result['solution'][:100]}...")
            print(f"Correct: {'✓' if result['correct'] else '✗'}")

        except Exception as e:
            print(f"Error processing example: {str(e)}")
            continue

    # Accuracy is undefined (not 0) when every example failed.
    if results:
        accuracy = sum(1 for r in results if r['correct']) / len(results)
        print(f"\nFinal accuracy: {accuracy:.2%}")
    else:
        print("\nNo examples were evaluated successfully; accuracy is undefined.")

    return results, feature_values, feature_tracker

In [None]:
#####################################
# NYT Connections Solver Execution
#####################################

# Configure number of samples to analyze
NUM_SAMPLES = 10  # Adjust this number as needed

print("============ Feature Analysis and Enhanced Solving ============")

# Debug print the dataset
print("\nFirst test set item:", test_set[0])
print("Type:", type(test_set[0]))
print("Keys:", test_set[0].keys() if hasattr(test_set[0], 'keys') else "No keys")

# Step 1: Initial Chain of Thought Analysis
print(f"\nAnalyzing {NUM_SAMPLES} examples with Chain of Thought...")
cot_analysis, feature_tracker = await analyze_dataset_with_cot(test_set, num_samples=NUM_SAMPLES)

print("\n=== Initial Feature Statistics ===")
stats = feature_tracker.get_stats()
print(f"Total unique features tracked: {len(stats)}")

print("\nTop 10 Most Successful Features:")
for stat in stats[:10]:
    print(f"\nFeature: {stat['feature'].label}")
    print(f"Used {stat['count']} times")
    print(f"Average activation: {stat['avg_activation']:.3f}")
    print(f"Success rate: {stat['success_rate']:.2%}")

baseline_correct = sum(1 for result in cot_analysis if result['is_correct'])
baseline_accuracy = baseline_correct/len(cot_analysis) if cot_analysis else 0
print(f"\nBaseline accuracy: {baseline_accuracy:.2%}")

# Step 2: Enhanced Solving with Feature Steering
print("\n=== Testing Enhanced Solver ===")
try:
    # Identify successful features and get optimal values
    successful_features = await identify_successful_features(feature_tracker, cot_analysis)
    if successful_features:
        feature_values = await get_optimal_feature_values(successful_features)
        print("\nOptimal Feature Values:")
        for feature, value in feature_values.items():
            print(f"Feature: {feature.label}")
            print(f"Value: {value:.3f}")

        # Test enhanced solver with debug prints
        results = []
        print("\nTesting enhanced solver...")
        for i in range(NUM_SAMPLES):
            try:
                example = test_set[i]  # Get actual dataset row
                print(f"\nProcessing example {i}:")
                print("Example type:", type(example))
                print("Example content:", example)
                result = await test_enhanced_solver(
                    example['text'].split(),
                    example['label'],
                    feature_values
                )
                if result:
                    results.append(result)
                    print(f"\nPuzzle words: {example['text']}")
                    print(f"True category: {example['label']}")
                    print(f"Solution: {result['solution'][:100]}...")
                    print(f"Correct: {'✓' if result['correct'] else '✗'}")

            except Exception as e:
                print(f"Error processing example: {str(e)}")
                print("Failed example:", example)
                continue

        # Calculate final accuracy
        if results:
            valid_results = [r for r in results if isinstance(r, dict) and 'correct' in r]
            if valid_results:
                enhanced_accuracy = sum(1 for r in valid_results if r['correct']) / len(valid_results)
                print(f"\nEnhanced accuracy: {enhanced_accuracy:.2%}")
                print(f"Improvement: {(enhanced_accuracy - baseline_accuracy)*100:.1f} percentage points")
            else:
                print("\nNo valid results to calculate enhanced accuracy.")
        else:
            print("\nNo results were generated.")
    else:
        print("\nNo successful features identified for enhancement.")

except Exception as e:
    print(f"\nError in enhancement testing: {str(e)}")
    print("Please check the implementation and try again.")


First test set item: {'label': 'HATS', 'text': 'BERET BOWLER FEDORA FEZ'}
Type: <class 'dict'>
Keys: dict_keys(['label', 'text'])

Analyzing 10 examples with Chain of Thought...


  0%|          | 0/10 [00:00<?, ?it/s]


Analyzing group 1:
Words: ['BERET', 'BOWLER', 'FEDORA', 'FEZ']
Category: HATS
Solution correct: ✓

Chain of Thought:
Let's break down the words step by step.

1. **Obvious patterns or relationships**: At first glance, these words all seem to be related to headwear. Each word is a type of hat. 

2. **Consider different categories**: Now, let's think about other categories that these words might fit into. We've already established that they're all types of hats. Within that category, we could also consider the fact that they're all popular or well-known types of hats. However, this still doesn't seem to be the most specific or interesting connection.

3. **Explain my thinking process**: As I continue to think about these words, I start to wonder if there's a more specific connection between them. I consider the fact that each of these words is a type of hat that's often associated with a particular culture or profession. For example, the fedora is often associated with detectives, while

CancelledError: 

# Test (Steps 2-3): NYT Connections, steering directly on features retrieved from the answer category

In [None]:
#####################################
# Cell 2: Answer-Guided Feature Analysis
#####################################

async def analyze_answer_features(category: str, model_variant, top_k=5) -> List[Any]:
    """Search the feature index for features matching a category label.

    Args:
        category: Ground-truth category (the answer) used as the search query.
        model_variant: Model/variant whose features are searched.
        top_k: Number of features to request.

    Returns:
        List of ``(feature, 1.0)`` pairs. The 1.0 is a placeholder score that
        gives the result the same ``(feature, activation)`` shape as the
        inspect-based analysis; no activation is measured here.
    """
    features = client.features.search(
        query=category,  # Just search for the answer/category
        model=model_variant,  # honour the caller's variant instead of the global
        top_k=top_k
    )
    return [(feat, 1.0) for feat in features]

async def analyze_dataset_features(dataset_subset, num_samples=30):
    results = []
    print("\nAnalyzing features for each category...")

    for i in tqdm(range(min(num_samples, len(dataset_subset)))):
        example = dataset_subset[i]
        words = example['text'].split()
        category = example['label']

        features = await analyze_answer_features(category, variant)
        results.append({
            'words': words,
            'category': category,
            'features': features
        })

        # Print results immediately for each example
        print(f"\nGroup {i+1}:")
        print(f"Words: {words}")
        print(f"Category: {category}")
        print("Answer-relevant Features:")
        for feat, _ in features:
            print(f"- {feat.label}")
        print("-" * 50)

    return results

# Run feature analysis on test set
print("Starting feature analysis...")
feature_analysis = await analyze_dataset_features(test_set)
print("\nFeature analysis complete.")

In [None]:
#####################################
# Cell 3: Steering Test with Rate Limit Handling
#####################################

import asyncio
from typing import List
from goodfire.api.exceptions import RateLimitException

# Run steering tests on examples
async def run_steering_tests(feature_analysis, num_samples=30):
    """Run ``test_steering_strengths`` on the first ``num_samples`` analysed puzzles.

    Args:
        feature_analysis: Dicts with ``words``, ``category`` and ``features``.
        num_samples: Maximum number of entries to test.

    Returns:
        One dict per successfully tested puzzle, with ``words``, ``category``
        and ``results`` (the per-condition dict from ``test_steering_strengths``).

    Raises:
        NameError: if ``test_steering_strengths`` has not been defined yet.
            It lives in the later "Cell 3: Feature Steering Test" cell. Without
            this check, an in-order run would log the same NameError for every
            example and then divide by zero in the accuracy report.
    """
    if 'test_steering_strengths' not in globals():
        raise NameError(
            "test_steering_strengths is not defined; run the cell that defines it "
            "('Cell 3: Feature Steering Test') before running steering tests."
        )

    all_results = []

    print("Running steering tests...")
    for example in tqdm(feature_analysis[:num_samples], desc="Testing examples"):
        try:
            results = await test_steering_strengths(
                example['words'],
                example['category'],
                example['features']
            )
            all_results.append({
                'words': example['words'],
                'category': example['category'],
                'results': results
            })

            # Print ongoing results for this example
            print(f"\nJust tested: {example['words']} (Category: {example['category']})")
            for condition, result in results.items():
                print(f"{condition}: {'✓' if result['matches_category'] else '✗'}")

        except Exception as e:
            print(f"Error processing example {example['words']}: {str(e)}")

    return all_results

# Run tests and display results with accuracy metrics
steering_results = await run_steering_tests(feature_analysis)

# Calculate and display accuracy for each steering strength
accuracies = {'baseline': 0}
for strength in [0.1, 0.2, 0.5, 1.0]:
    accuracies[f'strength_{strength}'] = 0

total_tests = len(steering_results)

print("\n=== Final Results ===")
print(f"Total examples tested: {total_tests}")

for test in steering_results:
    print(f"\nTest for words: {test['words']}")
    print(f"True category: {test['category']}")
    for condition, response in test['results'].items():
        print(f"\n{condition}:")
        print(response['response'][:100] + "..." if len(response['response']) > 100 else response['response'])
        print(f"Matches category: {response['matches_category']}")
        accuracies[condition] += response['matches_category']

print("\n=== Accuracy Results ===")
for condition, correct in accuracies.items():
    accuracy = (correct / total_tests) * 100
    print(f"{condition}: {accuracy:.1f}% ({correct}/{total_tests})")

# Test (Steps 2-3): NYT Connections, steering on features activated by the puzzle's group of words

In [None]:
#####################################
# Cell 2: Feature Activation Analysis
#####################################

async def analyze_features_for_group(words: List[str], model_variant) -> List[Any]:
    """Inspect which features fire when the model reads a "what do these have in common?" puzzle.

    Args:
        words: Puzzle words, joined with ", " into a single user message.
        model_variant: Model/variant to inspect.

    Returns:
        The top 5 feature activations from the inspection context.
    """
    puzzle_text = f"puzzle: what do these have in common? {', '.join(words)}"
    user_turn = {"role": "user", "content": puzzle_text}

    inspection = await async_client.features.inspect(
        messages=[user_turn],
        model=model_variant,
    )
    top_activations = inspection.top(k=5)  # small k for an initial look
    return top_activations

async def analyze_dataset_features(dataset_subset, num_samples=30):  # Changed to match Cell 3 default
    """Analyze features for a subset of the dataset."""
    results = []

    for i in tqdm(range(min(num_samples, len(dataset_subset)))):
        example = dataset_subset[i]

        # Split the space-separated words into a list
        words = example['text'].split()
        category = example['label']

        activations = await analyze_features_for_group(words, variant)
        results.append({
            'words': words,
            'category': category,
            'features': [(feat.feature, feat.activation) for feat in activations]  # Store actual feature object
        })

        # Print results as we go
        print(f"\nWords: {words}")
        print(f"Category: {category}")
        print("Top Features:")
        for feat in activations:
            print(f"- {feat.feature.label}: {feat.activation:.3f}")

    return results

# Run feature analysis on test set
feature_analysis = await analyze_dataset_features(test_set)

In [None]:
#####################################
# Cell 3: Feature Steering Test
#####################################

from tqdm.auto import tqdm

async def test_steering_strengths(words: List[str], category: str, feature_pairs, strengths=(0.1, 0.2, 0.5, 1.0)):
    """Compare an unsteered answer with answers steered at several strengths.

    Args:
        words: The puzzle words.
        category: Ground-truth category label.
        feature_pairs: Iterable of ``(feature, score)`` pairs. Only the feature
            is used, and every feature is set to the same strength.
        strengths: Steering values to try, in order.

    Returns:
        Dict keyed by ``'baseline'`` and ``'strength_<s>'``. Each value holds
        the raw ``'response'`` and ``'matches_category'``, a case-insensitive
        substring test of the category label in the response.

    The shared global ``variant`` is always reset before returning, even on
    error, so steering from this test does not leak into later cells.
    """
    prompt = f"puzzle: what do these have in common? {', '.join(words)}"
    # Materialise once: feature_pairs may be a one-shot iterator, and the same
    # features are reused for every strength.
    features = [feat for feat, _ in feature_pairs]

    async def _ask() -> Dict[str, Any]:
        # Each test issues len(strengths) + 1 requests; throttle them through
        # the shared limiter like the other API helpers.
        await rate_limiter.wait_if_needed()
        completion = await async_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=variant,
            max_completion_tokens=100
        )
        content = completion.choices[0].message["content"]
        return {
            'response': content,
            'matches_category': category.lower() in content.lower()
        }

    results = {}
    try:
        # Baseline (no steering)
        variant.reset()
        results['baseline'] = await _ask()

        # Test each steering strength
        for strength in strengths:
            variant.reset()
            if features:
                variant.set({feat: strength for feat in features})
            results[f'strength_{strength}'] = await _ask()
    finally:
        variant.reset()

    return results

# Run steering tests on examples
async def run_steering_tests(feature_analysis, num_samples=30):
    """Run ``test_steering_strengths`` on the first ``num_samples`` analysed puzzles.

    Per-puzzle failures are printed and skipped.

    Returns:
        One dict per successfully tested puzzle, with ``words``, ``category``
        and ``results`` (the per-condition outcome dict).
    """
    collected = []
    print("Running steering tests...")

    for entry in tqdm(feature_analysis[:num_samples], desc="Testing examples"):
        try:
            outcome = await test_steering_strengths(
                entry['words'],
                entry['category'],
                entry['features'],
            )
            collected.append({
                'words': entry['words'],
                'category': entry['category'],
                'results': outcome,
            })

            # Quick ✓/✗ summary for this puzzle.
            print(f"\nJust tested: {entry['words']} (Category: {entry['category']})")
            for condition_name, verdict in outcome.items():
                mark = '✓' if verdict['matches_category'] else '✗'
                print(f"{condition_name}: {mark}")
        except Exception as err:
            print(f"Error processing example {entry['words']}: {str(err)}")

    return collected

# Run tests and display results with accuracy metrics
steering_results = await run_steering_tests(feature_analysis)

# Calculate and display accuracy for each steering strength
accuracies = {'baseline': 0}
for strength in [0.1, 0.2, 0.5, 1.0]:
    accuracies[f'strength_{strength}'] = 0

total_tests = len(steering_results)

print("\n=== Final Results ===")
print(f"Total examples tested: {total_tests}")

for test in steering_results:
    print(f"\nTest for words: {test['words']}")
    print(f"True category: {test['category']}")
    for condition, response in test['results'].items():
        print(f"\n{condition}:")
        print(response['response'][:100] + "..." if len(response['response']) > 100 else response['response'])
        print(f"Matches category: {response['matches_category']}")
        accuracies[condition] += response['matches_category']

print("\n=== Accuracy Results ===")
for condition, correct in accuracies.items():
    accuracy = (correct / total_tests) * 100
    print(f"{condition}: {accuracy:.1f}% ({correct}/{total_tests})")