<a href="https://colab.research.google.com/github/menhguin/natural_language_rl/blob/main/Another_copy_of_Natural_Language_RL_Common_solver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup 1: Install

In [None]:
"""NLRL_with_Goodfire.ipynb

Natural Language Reinforcement Learning using Goodfire SDK
Following the steps outlined in the paper:
1. Query Selection & Answer Verification
2. COT Decomposition & Critical Token Identification
3. Feature Attribution & Analysis
4. Feature Steering
5. Robustness Testing & Iteration
"""

#####################################
# Step 0: Setup and Initialization
#####################################

# Install required packages
!pip install goodfire --quiet
!pip install datasets --quiet
!pip install tqdm --quiet

from google.colab import userdata
import goodfire
import asyncio
from tqdm.auto import tqdm
import pandas as pd
from typing import List, Dict, Any
import json
from typing import List
from dataclasses import dataclass

@dataclass
class Example:
    """One query/response pair fed to contrastive skill discovery.

    Attributes:
        query: Prompt text placed in the user turn.
        response: Answer text placed in the assistant turn.
        is_correct: True when ``response`` answers ``query`` correctly.
    """

    query: str  # user-turn content
    response: str  # assistant-turn content
    is_correct: bool  # ground-truth correctness flag

# Read the Goodfire API key from Colab secrets (store it there under the
# name GOODFIRE_API_KEY before running this cell).
GOODFIRE_API_KEY = userdata.get('GOODFIRE_API_KEY')

# Synchronous and asynchronous SDK clients, both authenticated with the key.
client, async_client = (
    goodfire.Client(GOODFIRE_API_KEY),
    goodfire.AsyncClient(GOODFIRE_API_KEY),
)

# Model variant passed as `model=` to every feature API call below.
variant = goodfire.Variant("meta-llama/Meta-Llama-3.1-8B-Instruct")

# for my reference: Which is bigger, 9.9 or 9.11?
# https://platform.goodfire.ai/chat/ab5cc226-aa37-404c-87fd-4fb7949944c7
# https://platform.goodfire.ai/chat/0a670b58-0b76-44c1-a7cb-8356347d83c9

# Setup 2: Load dataset

In [None]:
#####################################
# Cell 1: Load and Prepare Dataset
#####################################

from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm.auto import tqdm

# Load the NYT Connections dataset and unpack its three splits.
dataset = load_dataset("anishthalamati/nyt-connections")
train_set, validation_set, test_set = (
    dataset['train'],
    dataset['validation'],
    dataset['test'],
)

# Show the first row and the size of each split to verify the structure.
for _split_name, _split in (
    ("Train", train_set),
    ("Validation", validation_set),
    ("Test", test_set),
):
    print(f"\n{_split_name} set sample:")
    print(_split[0])
    print(f"{_split_name} set size: {len(_split)}")

# Run AutoSteer (skill discovery)

In [None]:
#####################################
# Part 1: Skill Discovery
#####################################

class ExampleManager:
    """Holds the dataset splits and converts rows into ``Example`` objects.

    Rows are read via ``row['text']`` (the puzzle words) and ``row['label']``
    (the category); see ``prepare_examples``.
    """

    def __init__(self, train_set, validation_set, test_set):
        self.train_set, self.validation_set, self.test_set = (
            train_set,
            validation_set,
            test_set,
        )
        print("\nInitialized ExampleManager with:")
        split_sizes = (
            (len(train_set), "training"),
            (len(validation_set), "validation"),
            (len(test_set), "test"),
        )
        for size, split_name in split_sizes:
            print(f"- {size} {split_name} examples")

        # Debug aid: dump the raw structure of the first training row.
        print("\nExample data structure:")
        print(json.dumps(train_set[0], indent=2))

    def prepare_examples(self, dataset, max_examples: int = 100) -> List[Example]:
        """Wrap the leading rows of ``dataset`` as ``Example`` objects.

        Args:
            dataset: Split exposing ``select`` and ``len``; each row must
                have ``'text'`` and ``'label'`` keys.
            max_examples: Upper bound on how many rows are converted.

        Returns:
            One ``Example`` per selected row. The prompt is built from
            ``'text'``, the response states ``'label'``, and every example
            is flagged ``is_correct=True``.
        """
        print(f"\nPreparing up to {max_examples} examples...")

        # Keep only the first `limit` rows (by index) of the split.
        limit = min(max_examples, len(dataset))
        subset = dataset.select(range(limit))

        prepared = []
        for idx, row in enumerate(tqdm(subset, desc="Converting examples")):
            if idx < 3:  # echo a few rows so the mapping can be eyeballed
                print(f"\nExample {idx + 1}:")
                print(f"Words: {row['text']}")
                print(f"Category: {row['label']}")

            prepared.append(
                Example(
                    query=f"puzzle: Consider these words and explain how they might be related: {row['text']}",
                    response=f"These words belong to the category: {row['label']}",
                    is_correct=True,
                )
            )

            if idx == 3:
                print("\n... (continuing conversion)")

        print(f"\nPrepared {len(prepared)} examples total")
        return prepared

class SkillDiscoverer:
    """Discovers reasoning skills using AutoSteer.

    Contrasts feature activations of correct vs. incorrect chat examples
    (``client.features.contrast``), reranks the candidates for relevance to
    word-puzzle categorization (``client.features.rerank``), then inspects
    every example (``client.features.inspect``) to log how strongly each
    discovered feature fires.

    Attributes:
        client: Goodfire client whose ``features`` API is called.
        variant: Model variant passed as ``model=`` to every API call.
        discovered_skills: Reranked features from the latest run (None before
            the first run).
        activation_logs: Example index -> dict with keys ``'query'``,
            ``'response'``, ``'is_correct'`` and ``'activations'`` (list of
            ``(feature_label, activation)`` pairs), for the latest run only.
    """

    def __init__(self, client, variant):
        self.client = client
        self.variant = variant
        self.discovered_skills = None
        self.activation_logs = {}

    @staticmethod
    def _to_messages(example: Example) -> List[Dict[str, str]]:
        """Render an example as a user/assistant chat transcript."""
        return [
            {"role": "user", "content": example.query},
            {"role": "assistant", "content": example.response},
        ]

    @classmethod
    def _format_examples(cls, examples: List[Example], heading: str) -> List[List[Dict[str, str]]]:
        """Convert examples to transcripts, echoing the first three under ``heading``."""
        print(f"\n{heading}")
        transcripts = []
        for i, ex in enumerate(examples):
            transcripts.append(cls._to_messages(ex))
            if i < 3:  # Show first 3 examples
                print(f"\nExample {i+1}:")
                print(f"Query: {ex.query}")
                print(f"Response: {ex.response}")
        return transcripts

    async def discover_skills(self, correct_examples: List[Example], incorrect_examples: List[Example], top_k: int = 5):
        """Use AutoSteer to discover skills that distinguish correct from incorrect examples.

        Args:
            correct_examples: Examples whose responses are right.
            incorrect_examples: Examples whose responses are wrong.
            top_k: Number of features kept after reranking. The contrast
                step requests ``2 * top_k`` candidates.

        Returns:
            The reranked features, also stored in ``self.discovered_skills``.
        """
        print("\n=== Skill Discovery and Analysis ===")
        # Start from a clean log. Otherwise entries from an earlier, larger
        # run would survive (their indices are never overwritten) and leak
        # into the summary statistics and the visualization.
        self.activation_logs = {}

        print("\nFormatting examples for contrast...")
        correct_messages = self._format_examples(correct_examples, "Correct examples:")
        incorrect_messages = self._format_examples(incorrect_examples, "Incorrect examples:")

        print(f"\nFinding contrasting features between {len(correct_messages)} correct and {len(incorrect_messages)} incorrect examples...")
        # The second element of the result is kept. Presumably these are the
        # features characteristic of dataset_2 (the correct examples); confirm
        # against the Goodfire `contrast` docs.
        _, skill_features = self.client.features.contrast(
            dataset_1=incorrect_messages,
            dataset_2=correct_messages,
            model=self.variant,
            top_k=top_k * 2,
        )
        print(f"\nFound {len(skill_features)} initial features:")
        for i, feat in enumerate(skill_features):
            print(f"{i+1}. {feat.label}")

        print("\nReranking features...")
        self.discovered_skills = self.client.features.rerank(
            features=skill_features,
            query="pattern recognition and categorization for word puzzles",
            model=self.variant,
            top_k=top_k,
        )

        self._log_activations(correct_examples, incorrect_examples)
        self._summarize_features()
        return self.discovered_skills

    def _log_activations(self, correct_examples: List[Example], incorrect_examples: List[Example]):
        """Inspect every example and record activations of the discovered skills.

        Entries are keyed by position in ``correct_examples + incorrect_examples``.
        ``is_correct`` reflects which of the two lists the example came from.
        """
        print("\n=== Feature Activation Analysis ===")
        n_correct = len(correct_examples)
        all_examples = correct_examples + incorrect_examples
        for i, ex in enumerate(tqdm(all_examples, desc="Analyzing examples")):
            context = self.client.features.inspect(
                messages=self._to_messages(ex),
                model=self.variant,
                features=self.discovered_skills,
            )
            activations = context.top(k=len(self.discovered_skills))
            is_correct = i < n_correct

            self.activation_logs[i] = {
                'query': ex.query,
                'response': ex.response,
                'is_correct': is_correct,
                'activations': [(act.feature.label, float(act.activation)) for act in activations],
            }

            print(f"\nExample {i+1}:")
            print(f"Query: {ex.query[:100]}...")
            print(f"{'Correct' if is_correct else 'Incorrect'} example")
            print("Top Feature Activations:")
            for feat, strength in sorted(self.activation_logs[i]['activations'],
                                         key=lambda x: abs(x[1]),
                                         reverse=True):
                print(f"  {feat}: {strength:.3f}")

    def _summarize_features(self):
        """Print each discovered feature's mean activation, overall and split by correctness."""
        print("\n=== Feature Importance Summary ===")
        for feat in self.discovered_skills:
            activations = []
            correct_activations = []
            incorrect_activations = []

            for log in self.activation_logs.values():
                for label, strength in log['activations']:
                    if label != feat.label:
                        continue
                    activations.append(strength)
                    if log['is_correct']:
                        correct_activations.append(strength)
                    else:
                        incorrect_activations.append(strength)

            avg_activation = np.mean(activations) if activations else 0
            avg_correct = np.mean(correct_activations) if correct_activations else 0
            avg_incorrect = np.mean(incorrect_activations) if incorrect_activations else 0

            print(f"\nFeature: {feat.label}")
            print(f"Average activation: {avg_activation:.3f}")
            print(f"Average for correct examples: {avg_correct:.3f}")
            print(f"Average for incorrect examples: {avg_incorrect:.3f}")
            print(f"Activation difference: {(avg_correct - avg_incorrect):.3f}")


#####################################
# Main Pipeline
#####################################

async def run_skill_discovery(num_samples: int = 20):
    """Run the skill-discovery stage and return ``(skills, discoverer)``.

    The prepared training rows are split in half.

    - The first half is used as-is: each query is paired with its own
      category, so these are the correct examples.
    - The second half becomes the incorrect contrast set. Each query is
      paired with the category response of the next row (wrapping around),
      so the stated category is wrong. This gives the correct-vs-incorrect
      contrast real negatives instead of two samples of correct data.

    With fewer than two rows no mismatch is possible, and a warning is
    printed. If two rows happen to share a category, the corresponding
    negative is not truly wrong (rare for this dataset — verify if it
    matters).

    Args:
        num_samples: Maximum number of training rows to draw.

    Returns:
        Tuple of (reranked skill features, the ``SkillDiscoverer`` instance
        whose ``activation_logs`` feed the visualization step).
    """
    print(f"\n=== Starting Skill Discovery Pipeline (using {num_samples} samples) ===")

    # Initialize components
    print("\nInitializing components...")
    example_manager = ExampleManager(train_set, validation_set, test_set)
    skill_discoverer = SkillDiscoverer(client, variant)

    # Prepare training examples
    train_examples = example_manager.prepare_examples(train_set, max_examples=num_samples)

    # Split examples: first half stays correct, second half gets mismatched labels.
    mid_point = len(train_examples) // 2
    correct_examples = train_examples[:mid_point]

    if len(train_examples) < 2:
        print("Warning: need at least 2 examples to build mismatched (incorrect) pairs.")
    # rotated[j] is train_examples[(j + 1) % n]; it is always a different row when n >= 2.
    rotated = train_examples[1:] + train_examples[:1]
    incorrect_examples = [
        Example(query=ex.query, response=wrong.response, is_correct=False)
        for ex, wrong in zip(train_examples[mid_point:], rotated[mid_point:])
    ]

    print(f"\nSplit examples into {len(correct_examples)} correct and {len(incorrect_examples)} incorrect")

    # Discover skills
    skills = await skill_discoverer.discover_skills(correct_examples, incorrect_examples)

    print("\nDiscovered Skills:")
    for i, skill in enumerate(skills, 1):
        print(f"{i}. {skill}")
        print(f"   Label: {skill.label}")
        print(f"   UUID: {skill.uuid}")

    print("\n=== Skill Discovery Complete ===")

    return skills, skill_discoverer  # Return both

# Run first step
if __name__ == "__main__":
    num_samples = 20  # Adjust this number as needed
    discovered_skills, skill_discoverer = await run_skill_discovery(num_samples)


In [None]:
#####################################
# Step 3: Visualization and Analysis
#####################################

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display, HTML

def create_analysis_dataframe(activation_logs):
    """Flatten activation logs into a long-format DataFrame, one row per activation.

    Args:
        activation_logs: Mapping of example id -> dict with keys ``'query'``,
            ``'response'``, ``'is_correct'`` and ``'activations'`` (an
            iterable of ``(feature_label, activation)`` pairs).

    Returns:
        DataFrame with columns ``example_id``, ``query``, ``response``,
        ``is_correct``, ``feature`` and ``activation``, in that order. The
        columns are present even when there are no activations, so
        downstream column lookups (``groupby``, ``pivot_table``) do not fail
        on an empty result.
    """
    columns = ['example_id', 'query', 'response', 'is_correct', 'feature', 'activation']
    rows = [
        (example_id, log['query'], log['response'], log['is_correct'], feature, activation)
        for example_id, log in activation_logs.items()
        for feature, activation in log['activations']
    ]
    return pd.DataFrame(rows, columns=columns)

def generate_interactive_table(df):
    """Return a self-contained HTML snippet: filter dropdowns, a table, and the filter JS.

    Args:
        df: Frame from ``create_analysis_dataframe``. The JS filter reads
            cells by position (0 = ``example_id``, 3 = ``is_correct``,
            4 = ``feature``), so the column order must match that function's
            output.

    Feature labels are HTML-escaped before being placed into ``<option>``
    elements, so labels containing quotes, ``<`` or ``&`` cannot break the
    markup. The browser decodes the escaped value back to the raw label,
    which is what the filter compares against each cell's text.
    """
    import html  # only this helper needs escaping

    example_options = ''.join(
        f'<option value="{i}">Example {i+1}</option>' for i in df['example_id'].unique()
    )
    feature_options = ''.join(
        f'<option value="{html.escape(str(f))}">{html.escape(str(f))}</option>'
        for f in df['feature'].unique()
    )

    # Filtering controls
    filter_html = f"""
        <div class="filter-controls">
            <h3>Filters:</h3>
            <select id="example_filter" onchange="filterTable()">
                <option value="all">All Examples</option>
                {example_options}
            </select>

            <select id="feature_filter" onchange="filterTable()">
                <option value="all">All Features</option>
                {feature_options}
            </select>

            <select id="correct_filter" onchange="filterTable()">
                <option value="all">All Results</option>
                <option value="correct">Correct Only</option>
                <option value="incorrect">Incorrect Only</option>
            </select>
        </div>
        """

    # pandas may shorten long cell text according to display.max_colwidth.
    # Disable that so each feature cell's text equals its dropdown value
    # exactly; the JS filter compares them with strict equality.
    with pd.option_context('display.max_colwidth', None):
        table_html = df.to_html(classes='data-table', index=False)

    # Add styling and JavaScript
    return f"""
        <style>
            .data-table {{
                width: 100%;
                border-collapse: collapse;
            }}
            .data-table th, .data-table td {{
                padding: 8px;
                border: 1px solid #ddd;
            }}
            .data-table tr:nth-child(even) {{
                background-color: #f9f9f9;
            }}
            .filter-controls {{
                margin-bottom: 20px;
            }}
            select {{
                margin-right: 10px;
                padding: 5px;
            }}
        </style>
        {filter_html}
        {table_html}
        <script>
        function filterTable() {{
            var example = document.getElementById('example_filter').value;
            var feature = document.getElementById('feature_filter').value;
            var correct = document.getElementById('correct_filter').value;

            var rows = document.querySelectorAll('.data-table tbody tr');
            rows.forEach(function(row) {{
                var showRow = true;

                if (example !== 'all' && row.cells[0].textContent !== example) showRow = false;
                if (feature !== 'all' && row.cells[4].textContent !== feature) showRow = false;
                if (correct !== 'all') {{
                    var isCorrect = row.cells[3].textContent === 'True';
                    if (correct === 'correct' && !isCorrect) showRow = false;
                    if (correct === 'incorrect' && isCorrect) showRow = false;
                }}

                row.style.display = showRow ? '' : 'none';
            }});
        }}
        </script>
        """


def visualize_results(skill_discoverer):
    """Create visualizations and an interactive table from ``skill_discoverer.activation_logs``.

    The steps are:

    1. Write ``feature_analysis.csv``.
    2. Show a feature-activation heatmap.
    3. Show a mean (± std) activation bar chart per feature.
    4. Display the filterable HTML table.

    After writing the CSV, returns early when there are no activations to
    plot.
    """
    df = create_analysis_dataframe(skill_discoverer.activation_logs)

    # Save to CSV
    df.to_csv('feature_analysis.csv', index=False)
    print("Saved detailed results to feature_analysis.csv")

    if df.empty:
        # Nothing to pivot or plot; pivot_table/imshow would likely fail here.
        print("No feature activations to visualize.")
        return

    # 1. Feature Activation Heatmap
    pivot_df = df.pivot_table(
        values='activation',
        index='example_id',
        columns='feature',
        aggfunc='first'
    )

    fig1 = px.imshow(
        pivot_df,
        title='Feature Activation Heatmap',
        labels=dict(x='Feature', y='Example ID', color='Activation Strength'),
        aspect='auto'
    )
    fig1.show()

    # 2. Feature Importance Bar Chart
    avg_activations = df.groupby('feature')['activation'].agg(['mean', 'std']).reset_index()
    fig2 = px.bar(
        avg_activations,
        x='feature',
        y='mean',
        error_y='std',
        title='Average Feature Activation Strength',
        labels={'mean': 'Average Activation', 'feature': 'Feature'}
    )
    fig2.update_layout(xaxis_tickangle=45)
    fig2.show()

    # 3. Interactive Table
    display(HTML(generate_interactive_table(df)))

# Render the charts and the filterable table for the discovery run above.
print("\nGenerating visualizations...")
visualize_results(skill_discoverer=skill_discoverer)