# Pixelated Empathy: Edge Case Generation Pipeline

This notebook generates challenging therapy scenarios for training models that simulate difficult clients.

---

**Instructions:**
1. Run each cell in order.
2. Configure your API key and model.
3. Generate prompts and conversations.
4. Analyze and export results.

---

In [None]:
# Install requirements if needed (uncomment when running in Colab or a fresh environment)
# !pip install openai anthropic pandas numpy tqdm matplotlib seaborn ipywidgets requests

In [None]:
import os
import sys
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output

# Import the edge case generator
from edge_case_generator import EdgeCaseGenerator

plt.style.use("default")
sns.set_palette("husl")
print("✅ Imports successful!")

## 🔧 Configuration

Choose your API provider and configure settings:
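
A quick check before initializing the generator can catch a missing key early. The sketch below is a hypothetical helper, not part of `EdgeCaseGenerator`, and it assumes (from the API key field's placeholder text) that OpenAI and Anthropic need a key while a local Ollama server does not.

```python
from typing import List, Optional

# Assumption: based on the API key widget's placeholder, OpenAI and Anthropic
# need a key and a local Ollama server does not. EdgeCaseGenerator may differ.
KEY_REQUIRED = {"openai": True, "anthropic": True, "ollama": False}


def validate_provider_config(provider: str, api_key: Optional[str]) -> List[str]:
    """Return a list of configuration problems; an empty list means none were found."""
    problems: List[str] = []
    if provider not in KEY_REQUIRED:
        problems.append(f"Unknown provider: {provider!r}")
    elif KEY_REQUIRED[provider] and not api_key:
        problems.append(f"{provider} requires an API key")
    return problems


print(validate_provider_config("ollama", None))
print(validate_provider_config("openai", ""))
```

In the notebook you would pass `api_provider_widget.value` and `api_key_widget.value`.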

In [None]:
api_provider_widget = widgets.Dropdown(
    options=["openai", "anthropic", "ollama"],
    value="ollama",
    description="API Provider:",
    style={"description_width": "initial"},
)
api_key_widget = widgets.Password(
    placeholder="Enter your API key (not needed for Ollama)",
    description="API Key:",
    style={"description_width": "initial"},
)
model_widget = widgets.Dropdown(
    options={
        "GPT-3.5 Turbo": "gpt-3.5-turbo",
        "GPT-4": "gpt-4",
        "GPT-4 Turbo": "gpt-4-turbo-preview",
        "Claude 3 Haiku": "claude-3-haiku-20240307",
        "Claude 3 Sonnet": "claude-3-sonnet-20240229",
        "Claude 3 Opus": "claude-3-opus-20240229",
        "Custom Ollama Model": "artifish/llama3.2-uncensored",
    },
    value="artifish/llama3.2-uncensored",  # matches the default "ollama" provider
    description="Model:",
    style={"description_width": "initial"},
)
custom_model_widget = widgets.Text(
    placeholder="Enter custom model name (for Ollama)",
    description="Custom Model:",
    style={"description_width": "initial"},
)
scenarios_per_category_widget = widgets.IntSlider(
    value=20,
    min=1,
    max=50,
    step=1,
    description="Scenarios per category:",
    style={"description_width": "initial"},
)
max_conversations_widget = widgets.IntSlider(
    value=100,
    min=10,
    max=500,
    step=10,
    description="Max conversations to generate:",
    style={"description_width": "initial"},
)
output_dir_widget = widgets.Text(
    value="edge_case_output",
    description="Output directory:",
    style={"description_width": "initial"},
)
display(
    widgets.VBox(
        [
            widgets.HTML("<h3>🔧 Configuration</h3>"),
            api_provider_widget,
            api_key_widget,
            model_widget,
            custom_model_widget,
            scenarios_per_category_widget,
            max_conversations_widget,
            output_dir_widget,
        ]
    )
)

## 📊 Preview Edge Case Categories

Let's see what categories we'll be generating:
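
The preview table truncates each category's challenge list to two items, and the total scenario count is simply the number of categories times the per-category slider value. A minimal sketch of both calculations, assuming `edge_case_categories` maps category names to dicts with `description`, `difficulty`, and `challenges` keys (the keys the preview cell reads); the category names and values here are made up.

```python
from typing import Dict, List

# Made-up categories in the shape the preview cell assumes.
example_categories: Dict[str, dict] = {
    "hypothetical_category_a": {
        "description": "Example description",
        "difficulty": "example_level_1",
        "challenges": ["challenge 1", "challenge 2", "challenge 3"],
    },
    "hypothetical_category_b": {
        "description": "Another example",
        "difficulty": "example_level_2",
        "challenges": ["challenge 1"],
    },
}


def summarize_challenges(challenges: List[str], limit: int = 2) -> str:
    """Join the first `limit` challenges, appending '...' when some were cut off."""
    text = ", ".join(challenges[:limit])
    return text + "..." if len(challenges) > limit else text


def total_scenarios(categories: Dict[str, dict], per_category: int) -> int:
    """Every category receives the same number of scenarios."""
    return len(categories) * per_category


for name, details in example_categories.items():
    print(name, "->", summarize_challenges(details["challenges"]))
print(total_scenarios(example_categories, 20))
```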

In [None]:
temp_generator = EdgeCaseGenerator()
categories = temp_generator.edge_case_categories
category_df = pd.DataFrame(
    [
        {
            "Category": cat,
            "Description": details["description"],
            "Difficulty": details["difficulty"],
            "Challenges": (
                ", ".join(details["challenges"][:2]) + "..."
                if len(details["challenges"]) > 2
                else ", ".join(details["challenges"])
            ),
        }
        for cat, details in categories.items()
    ]
)
print(f"📋 Total Categories: {len(category_df)}")
print(f"🎯 Total Scenarios to Generate: {len(category_df) * scenarios_per_category_widget.value}")
display(HTML(category_df.to_html(index=False, escape=False)))
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
difficulty_counts = category_df["Difficulty"].value_counts()
ax[0].pie(difficulty_counts.values, labels=difficulty_counts.index, autopct="%1.1f%%")
ax[0].set_title("Difficulty Level Distribution")
sns.countplot(data=category_df, x="Difficulty", ax=ax[1])
ax[1].set_title("Categories by Difficulty Level")
ax[1].set_xlabel("Difficulty Level")
ax[1].set_ylabel("Number of Categories")
plt.tight_layout()
plt.show()

## 🚀 Initialize Generator

Set up the generator with your chosen configuration:
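
The provider and model dropdowns are independent, so nothing stops you from pairing, say, `ollama` with `gpt-3.5-turbo`. The sketch below flags that kind of mismatch with a simple prefix heuristic. The prefixes are taken from the model IDs in the dropdown, and treating every other name as an Ollama model is an assumption, not a rule enforced by any provider.

```python
from typing import Optional

# Heuristic prefixes taken from the dropdown's model IDs (an assumption).
PREFIX_TO_PROVIDER = {"gpt-": "openai", "claude-": "anthropic"}


def expected_provider(model_name: str) -> str:
    """Guess the provider from the model name; fall back to Ollama."""
    for prefix, provider in PREFIX_TO_PROVIDER.items():
        if model_name.startswith(prefix):
            return provider
    return "ollama"


def provider_mismatch(provider: str, model_name: str) -> Optional[str]:
    """Return a warning if the model looks like it belongs to another provider."""
    expected = expected_provider(model_name)
    if expected != provider:
        return f"Model {model_name!r} looks like an {expected} model, but provider is {provider!r}"
    return None


print(provider_mismatch("ollama", "gpt-3.5-turbo"))
print(provider_mismatch("anthropic", "claude-3-haiku-20240307"))
```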

In [None]:
api_provider = api_provider_widget.value
api_key = api_key_widget.value or None
# A non-empty "Custom Model" entry overrides the dropdown selection
model_name = custom_model_widget.value.strip() or model_widget.value
scenarios_per_category = scenarios_per_category_widget.value
max_conversations = max_conversations_widget.value
output_dir = output_dir_widget.value
print(f"🔧 Configuration:")
print(f"   API Provider: {api_provider}")
print(f"   Model: {model_name}")
print(f"   Scenarios per category: {scenarios_per_category}")
print(f"   Max conversations: {max_conversations}")
print(f"   Output directory: {output_dir}")
try:
    generator = EdgeCaseGenerator(
        api_provider=api_provider,
        api_key=api_key,
        model_name=model_name,
        output_dir=output_dir,
    )
    print("✅ Generator initialized successfully!")
except Exception as e:
    print(f"❌ Error initializing generator: {e}")

## 📝 Step 1: Generate Prompts

First, we'll generate all the prompts for our edge cases:
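
The prompts are saved to `edge_case_prompts.jsonl`. Assuming that file follows the usual JSON Lines convention (one JSON object per line), it can be written and read back with the standard library alone. The record fields below mirror the keys the sample-printing cell uses; the values are placeholders, and `write_jsonl`/`read_jsonl` are hypothetical helpers.

```python
import json
import tempfile
from pathlib import Path
from typing import Iterable, List


def write_jsonl(path: Path, records: Iterable[dict]) -> None:
    """Write one JSON object per line (JSON Lines)."""
    with path.open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path: Path) -> List[dict]:
    """Read a JSON Lines file, skipping blank lines."""
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


sample = [
    {
        "scenario_id": "example_001",
        "category": "example_category",
        "difficulty_level": "example_level",
        "instructions": "placeholder instructions",
    }
]
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "edge_case_prompts.jsonl"
    write_jsonl(path, sample)
    print(read_jsonl(path) == sample)
```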

In [None]:
print("📝 Generating prompts...")
prompts = generator.generate_prompts(scenarios_per_category=scenarios_per_category)
print(f"✅ Generated {len(prompts)} prompts")
print(f"📁 Saved to: {generator.output_dir}/edge_case_prompts.jsonl")
print("\n📋 Sample prompts:")
for i, prompt in enumerate(prompts[:3]):
    print(f"\n{i+1}. {prompt['scenario_id']} ({prompt['category']})")
    print(f"   Difficulty: {prompt['difficulty_level']}")
    print(f"   Instructions: {prompt['instructions'][:100]}...")

## 🤖 Step 2: Generate Conversations

Now we'll use the API to generate realistic therapy conversations:
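
In the generation cell, a single `try` wraps the whole loop, so one failed API call ends the run. A more forgiving pattern records each failure and keeps going. This is a sketch only: `generate_one` stands in for whatever per-prompt call you use (the notebook calls `generator._generate_single_conversation`), and `generate_batch` and `success_rate` are hypothetical helpers.

```python
from typing import Callable, List, Optional, Tuple


def generate_batch(
    prompts: List[dict],
    generate_one: Callable[[dict], Optional[dict]],
    max_items: int,
) -> Tuple[List[dict], List[Tuple[str, str]]]:
    """Run generate_one over at most max_items prompts.

    Empty results are skipped; exceptions are recorded instead of aborting the run.
    """
    results: List[dict] = []
    failures: List[Tuple[str, str]] = []
    for prompt in prompts[:max_items]:
        scenario_id = prompt.get("scenario_id", "?")
        try:
            conv = generate_one(prompt)
        except Exception as exc:  # keep going after a single failed call
            failures.append((scenario_id, str(exc)))
            continue
        if conv:
            results.append(conv)
        else:
            failures.append((scenario_id, "empty result"))
    return results, failures


def success_rate(n_ok: int, n_attempted: int) -> float:
    """Percentage of attempts that produced a conversation (0.0 when nothing ran)."""
    return 100.0 * n_ok / n_attempted if n_attempted else 0.0


def fake_generate(prompt: dict) -> Optional[dict]:
    """Stand-in generator that fails on one scenario to exercise the error path."""
    if prompt["scenario_id"] == "s2":
        raise RuntimeError("rate limited")
    return {"scenario_id": prompt["scenario_id"], "qa_pairs": []}


prompts_demo = [{"scenario_id": f"s{i}"} for i in range(1, 5)]
ok, failed = generate_batch(prompts_demo, fake_generate, max_items=3)
print(len(ok), failed, round(success_rate(len(ok), 3), 1))
```

Keeping the failure list also tells you which scenarios to retry once rate limits reset.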

In [None]:
from tqdm.notebook import tqdm as tqdm_nb
progress_widget = widgets.IntProgress(
    value=0,
    min=0,
    max=min(len(prompts), max_conversations),
    description='Progress:',
    bar_style='info',
    style={'bar_width': '50px'},
    orientation='horizontal'
)
status_widget = widgets.HTML(value="Starting generation...")
display(widgets.VBox([progress_widget, status_widget]))
target = min(len(prompts), max_conversations)
print(f"🤖 Generating conversations using {api_provider} ({model_name})...")
print(f"📊 Target: {target} conversations")
print("\n⏳ This may take a while depending on your API provider and its rate limits...")
conversations = []
try:
    for i, prompt in enumerate(tqdm_nb(prompts[:target])):
        conv = generator._generate_single_conversation(prompt)
        if conv:
            conversations.append(conv)
        progress_widget.value = i + 1
        status_widget.value = f'Processed {i+1}/{target}'
    print('\n✅ Generation completed!')
    print(f'📈 Success rate: {len(conversations)/max(target, 1)*100:.1f}%')
except Exception as e:
    status_widget.value = f'❌ Error: {e}'
    print(f'\n❌ Error during generation: {e}')

## 📊 Step 3: Analyze Results

Let's analyze what we generated:
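
The analysis cell counts a conversation as successful when it has at least one Q&A pair. The same per-category numbers the plots show can be collected into one table with a pandas named aggregation. The records below are placeholders in the shape the analysis cell assumes.

```python
import pandas as pd

# Placeholder conversations in the shape the analysis cell reads.
conversations_demo = [
    {"scenario_id": "a1", "category": "cat_a", "difficulty_level": "level_1",
     "qa_pairs": [{"prompt": "p", "response": "r"}]},
    {"scenario_id": "a2", "category": "cat_a", "difficulty_level": "level_1",
     "qa_pairs": []},
    {"scenario_id": "b1", "category": "cat_b", "difficulty_level": "level_2",
     "qa_pairs": [{"prompt": "p", "response": "r"}, {"prompt": "p2", "response": "r2"}]},
]

df = pd.DataFrame(
    [
        {
            "Category": c["category"],
            "Difficulty": c["difficulty_level"],
            "QA Pairs": len(c.get("qa_pairs", [])),
        }
        for c in conversations_demo
    ]
)
df["Has Content"] = df["QA Pairs"] > 0

# One row per category: volume, average length, and share with any content.
summary = df.groupby("Category").agg(
    conversations=("QA Pairs", "count"),
    mean_qa_pairs=("QA Pairs", "mean"),
    success_rate=("Has Content", "mean"),
)
print(summary)
```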

In [None]:
if conversations:
    analysis_data = []
    total_qa_pairs = 0
    for conv in conversations:
        qa_count = len(conv.get('qa_pairs', []))
        total_qa_pairs += qa_count
        analysis_data.append({
            'Scenario ID': conv['scenario_id'],
            'Category': conv['category'],
            'Difficulty': conv['difficulty_level'],
            'QA Pairs': qa_count,
            'Has Content': qa_count > 0
        })
    analysis_df = pd.DataFrame(analysis_data)
    print(f'📊 Analysis Results:')
    print(f'   Total Conversations: {len(conversations)}')
    print(f'   Total Q&A Pairs: {total_qa_pairs}')
    print(f'   Average Q&A per Conversation: {total_qa_pairs/len(conversations):.1f}')
    print(f'   Success Rate: {analysis_df["Has Content"].sum()/len(analysis_df)*100:.1f}%')
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    category_counts = analysis_df['Category'].value_counts()
    category_counts.plot(kind='bar', ax=axes[0,0], color='skyblue')
    axes[0,0].set_title('Conversations by Category')
    axes[0,0].set_xlabel('Category')
    axes[0,0].set_ylabel('Count')
    axes[0,0].tick_params(axis='x', rotation=45)
    difficulty_counts = analysis_df['Difficulty'].value_counts()
    axes[0,1].pie(difficulty_counts.values, labels=difficulty_counts.index, autopct='%1.1f%%')
    axes[0,1].set_title('Distribution by Difficulty')
    analysis_df['QA Pairs'].hist(bins=20, ax=axes[1,0], color='lightgreen', alpha=0.7)
    axes[1,0].set_title('Q&A Pairs per Conversation')
    axes[1,0].set_xlabel('Number of Q&A Pairs')
    axes[1,0].set_ylabel('Frequency')
    success_by_category = analysis_df.groupby('Category')['Has Content'].mean().sort_values(ascending=False)
    success_by_category.plot(kind='bar', ax=axes[1,1], color='coral')
    axes[1,1].set_title('Success Rate by Category')
    axes[1,1].set_xlabel('Category')
    axes[1,1].set_ylabel('Success Rate')
    axes[1,1].tick_params(axis='x', rotation=45)
    plt.tight_layout()
    plt.show()
    if conversations[0].get('qa_pairs'):
        print('\n💬 Sample Generated Conversation:')
        sample_conv = conversations[0]
        print(f"Scenario: {sample_conv['scenario_id']} ({sample_conv['category']})")
        print(f"Difficulty: {sample_conv['difficulty_level']}")
        print('Dialogue:')
        for qa in sample_conv['qa_pairs'][:2]:
            print(f"\nTherapist: {qa['prompt']}")
            print(f"Client: {qa['response']}")
        if len(sample_conv['qa_pairs']) > 2:
            print(f"\n... ({len(sample_conv['qa_pairs']) - 2} more exchanges)")
else:
    print('❌ No conversations generated. Please check your configuration and try again.')

## 🔄 Step 4: Create Training Format

Convert to the format needed for training:
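
`create_training_format` is the generator's own method, and its exact schema is not shown in this notebook. As a hedged sketch only: assuming each Q&A pair becomes one example that carries the conversation's metadata (the sample cell reads `category`, `difficulty_level`, `purpose`, `prompt`, and `response`), the flattening could look like this. `to_training_examples` is hypothetical and the `purpose` string is a made-up placeholder.

```python
from typing import List

# Hypothetical placeholder; the real generator sets its own `purpose` value.
PLACEHOLDER_PURPOSE = "placeholder_purpose"


def to_training_examples(conversations: List[dict]) -> List[dict]:
    """Flatten each conversation's qa_pairs into prompt/response training rows."""
    examples = []
    for conv in conversations:
        for qa in conv.get("qa_pairs", []):
            examples.append(
                {
                    "prompt": qa["prompt"],        # therapist turn
                    "response": qa["response"],    # simulated client turn
                    "category": conv["category"],
                    "difficulty_level": conv["difficulty_level"],
                    "purpose": PLACEHOLDER_PURPOSE,
                    "scenario_id": conv["scenario_id"],
                }
            )
    return examples


demo = [
    {
        "scenario_id": "x1",
        "category": "example_category",
        "difficulty_level": "example_level",
        "qa_pairs": [
            {"prompt": "Therapist line", "response": "Client line"},
            {"prompt": "Therapist 2", "response": "Client 2"},
        ],
    }
]
rows = to_training_examples(demo)
print(len(rows), rows[0]["prompt"])
```

Keeping `scenario_id` on every row makes it possible to split train and evaluation sets by scenario rather than by individual turn.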

In [None]:
if conversations:
    print('🔄 Converting to training format...')
    training_data = generator.create_training_format(conversations)
    print(f'\n✅ Created {len(training_data)} training examples')
    print(f'📁 Saved to: {generator.output_dir}/edge_cases_training_format.jsonl')
    if training_data:
        print('\n📋 Sample Training Data:')
        sample = training_data[0]
        print(f"Category: {sample['category']}")
        print(f"Difficulty: {sample['difficulty_level']}")
        print(f"Purpose: {sample['purpose']}")
        print(f"Prompt: {sample['prompt'][:100]}...")
        print(f"Response: {sample['response'][:100]}...")
else:
    print('❌ No conversations available for training format conversion.')

## 📄 Step 5: Generate Summary Report

Create a comprehensive report of the generation process:
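
`generate_summary_report` is the generator's own method and its contents are not shown here. If you want a lightweight report of your own, the sketch below renders counts by category and difficulty as Markdown; `build_summary_markdown` is hypothetical and only assumes the `category`, `difficulty_level`, and `qa_pairs` keys used elsewhere in this notebook.

```python
from collections import Counter
from typing import List


def build_summary_markdown(conversations: List[dict]) -> str:
    """Render totals plus counts by category and difficulty as Markdown."""
    total_pairs = sum(len(c.get("qa_pairs", [])) for c in conversations)
    lines = [
        "# Edge Case Generation Summary",
        "",
        f"- Conversations: {len(conversations)}",
        f"- Q&A pairs: {total_pairs}",
        "",
        "## By category",
        "",
        "| Category | Count |",
        "|---|---|",
    ]
    for category, count in sorted(Counter(c["category"] for c in conversations).items()):
        lines.append(f"| {category} | {count} |")
    lines += ["", "## By difficulty", ""]
    for level, count in sorted(Counter(c["difficulty_level"] for c in conversations).items()):
        lines.append(f"- {level}: {count}")
    return "\n".join(lines) + "\n"


demo_conversations = [
    {"category": "cat_a", "difficulty_level": "high", "qa_pairs": [{"prompt": "p", "response": "r"}]},
    {"category": "cat_b", "difficulty_level": "high", "qa_pairs": []},
]
report_md = build_summary_markdown(demo_conversations)
print(report_md)
```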

In [None]:
if conversations:
    print('📄 Generating summary report...')
    report = generator.generate_summary_report(conversations)
    print(f'\n✅ Report generated and saved to: {generator.output_dir}/summary_report.md')
    print('\n' + '='*60)
    print(report)
    print('='*60)
else:
    print('❌ No conversations available for report generation.')

## 📁 File Management

View and manage your generated files:
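
The file-management cell zips the output files flat, writing the archive next to the output directory rather than inside it, so the zip never ends up containing itself. A sketch of the same idea that also verifies the archive after writing it, using `ZipFile.testzip()` (which returns the name of the first corrupt member, or `None`); `zip_directory` is a hypothetical helper.

```python
import tempfile
import zipfile
from pathlib import Path
from typing import List


def zip_directory(src_dir: Path, zip_path: Path) -> List[str]:
    """Zip the regular files directly inside src_dir (flat), then verify the archive.

    zip_path should live outside src_dir, or the listing could include the zip itself.
    """
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for file in sorted(src_dir.iterdir()):
            if file.is_file():
                zf.write(file, arcname=file.name)
    with zipfile.ZipFile(zip_path) as zf:
        bad = zf.testzip()  # first corrupt member name, or None
        if bad is not None:
            raise ValueError(f"Corrupt member in archive: {bad}")
        return zf.namelist()


with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "edge_case_output"
    src.mkdir()
    (src / "a.jsonl").write_text('{"x": 1}\n', encoding="utf-8")
    (src / "summary_report.md").write_text("# Report\n", encoding="utf-8")
    names = zip_directory(src, Path(tmp) / "edge_case_output_complete.zip")
    print(names)
```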

In [None]:
output_path = Path(output_dir)
if output_path.exists():
    files = list(output_path.glob('*'))
    print(f'📁 Files in {output_path}:')
    print('-' * 50)
    total_size = 0
    for file in sorted(files):
        if file.is_file():
            size = file.stat().st_size
            total_size += size
            print(f'📄 {file.name:<30} {size:>10,} bytes')
    print('-' * 50)
    print(f'📊 Total size: {total_size:,} bytes ({total_size/1024/1024:.1f} MB)')
    import zipfile
    zip_path = Path(f'{output_dir}_complete.zip')
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for file in files:
            if file.is_file():
                zipf.write(file, file.name)
    print(f'\n📦 Created zip file: {zip_path}')
    print(f'💾 Zip size: {zip_path.stat().st_size:,} bytes')
else:
    print(f'❌ Output directory {output_path} not found')

## 🎯 Next Steps

- Review sample conversations for realism and appropriateness
- Integrate with your main training pipeline
- Use in evaluation framework
- Download results before your Colab session ends!
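
For the manual review step, drawing a small, reproducible sample from every category keeps rare categories from being overlooked. A minimal sketch, assuming only that each conversation has a `category` key; `review_sample` is a hypothetical helper and the seed makes repeated runs pick the same items.

```python
import random
from collections import defaultdict
from typing import Dict, List


def review_sample(conversations: List[dict], per_category: int = 2, seed: int = 0) -> Dict[str, List[dict]]:
    """Pick up to `per_category` conversations from each category, reproducibly."""
    by_category: Dict[str, List[dict]] = defaultdict(list)
    for conv in conversations:
        by_category[conv["category"]].append(conv)
    rng = random.Random(seed)  # local RNG so the global random state is untouched
    return {
        cat: rng.sample(items, min(per_category, len(items)))
        for cat, items in sorted(by_category.items())
    }


demo_data = [{"category": "cat_a", "scenario_id": f"a{i}"} for i in range(5)]
demo_data.append({"category": "cat_b", "scenario_id": "b0"})
picked = review_sample(demo_data, per_category=2, seed=0)
for cat, items in picked.items():
    print(cat, [c["scenario_id"] for c in items])
```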

---

**Happy generating! 🤖✨**