In [None]:
# üöÄ Setup for Google Colab
import sys
if 'google.colab' in sys.modules:
    print("üîß Setting up for Google Colab...")
    
    # Install required dependencies
    !pip install -q matplotlib seaborn scikit-learn numpy pandas
    
    # Note: SSL framework code will be included in subsequent cells for Colab compatibility
    print("‚úÖ Dependencies installed! SSL framework will be defined in the next cells.")
else:
    print("üìù Running locally - using installed SSL framework")

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yourusername/pyssl/blob/main/notebooks/02_classification_comparison.ipynb)

# 🔬 SSL Strategy Comparison - Find the Best Approach

This notebook compares different semi-supervised learning strategies to help you choose the best approach for your data.

**What you'll learn:**
- When to use `ConfidenceThreshold` vs `TopKFixedCount` strategies
- How different integration methods affect performance
- How to handle imbalanced datasets with SSL
- Which strategy works best for your use case

**Dataset:** Imbalanced 3-class problem (simulating real-world scenarios)

## 1. Setup & Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
import warnings

# Silence only library deprecation noise. A blanket `filterwarnings('ignore')` would also hide
# sklearn's ConvergenceWarning and UndefinedMetricWarning (e.g. a minority class that is never
# predicted), which are exactly the signals needed to interpret results on imbalanced data.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

# Import our SSL framework, which is expected in the parent of the current working directory.
import sys
# Add the entry only once, so re-running this cell does not keep growing sys.path.
if '../' not in sys.path:
    sys.path.append('../')
from ssl_framework.main import SelfTrainingClassifier
from ssl_framework.strategies import (
    ConfidenceThreshold, TopKFixedCount,
    AppendAndGrow, FullReLabeling, ConfidenceWeighting
)

# Import our utilities
from utils.data_generation import make_imbalanced_classification, generate_ssl_dataset

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

print("✅ All imports successful!")

## 2. Generate Challenging Imbalanced Dataset

We'll create a 3-class imbalanced dataset (10% / 30% / 60% class split) and keep only 60 labeled samples, mimicking the label-scarce, skewed setting in which SSL is most likely to help.

In [None]:
# Create imbalanced dataset
N_CLASSES = 3
X, y = make_imbalanced_classification(
    n_samples=2000,
    n_features=10,
    n_classes=N_CLASSES,
    weights=[0.1, 0.3, 0.6],  # Highly imbalanced: 10%, 30%, 60%
    n_informative=6,
    n_redundant=2,
    class_sep=0.7,  # Moderate separation - challenging but learnable
    random_state=42
)

# Apply our standard SSL splits using the project's (private) split helper.
# `y_unlabeled_true` holds the hidden labels of the unlabeled pool; no model below is given it.
from utils.data_generation import _apply_custom_splits
X_labeled, y_labeled, X_unlabeled, X_val, y_val, X_test, y_test, y_unlabeled_true = _apply_custom_splits(
    X, y, n_labeled=60, test_size=0.2, val_size=0.15, random_state=42
)


def print_class_distribution(title, labels):
    """Print per-class sample counts and percentages for ``labels`` and return the counts.

    ``minlength=N_CLASSES`` keeps classes with zero samples in the listing; plain
    ``np.bincount`` would silently drop missing trailing classes.
    """
    counts = np.bincount(labels, minlength=N_CLASSES)
    print(title)
    for class_id, count in enumerate(counts):
        print(f"   Class {class_id}: {count} samples ({count / len(labels) * 100:.1f}%)")
    return counts


print("📊 Dataset Statistics:")
print(f"   Total samples: {len(X)}")
print(f"   Labeled: {len(X_labeled)}, Unlabeled: {len(X_unlabeled)}")
print(f"   Validation: {len(X_val)}, Test: {len(X_test)}")
print(f"   Features: {X.shape[1]}")

print_class_distribution("\n🎯 Class Distribution in Full Dataset:", y)
labeled_counts = print_class_distribution("\n🏷️ Class Distribution in Labeled Set:", y_labeled)

# With only 60 labeled samples the 10% class can end up (nearly) empty. A class with no labeled
# examples cannot be predicted by a classifier fitted on this split (and presumably can then
# never be pseudo-labeled either), so surface it instead of letting it pass silently.
missing_classes = np.flatnonzero(labeled_counts == 0)
if missing_classes.size:
    print(f"\n⚠️ No labeled samples for class(es) {missing_classes.tolist()} - "
          "models trained on this split cannot predict them.")

## 3. Define Experiment Framework

Let's create a systematic way to test different strategies and compare their performance.

In [None]:
def _score_predictions(y_true, y_pred):
    """Return accuracy, macro-F1 and weighted-F1 of ``y_pred`` against ``y_true``.

    ``zero_division=0`` spells out the value sklearn's default already uses (with a warning)
    when a class is never predicted - a realistic case for the minority class here.
    """
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1_macro': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'f1_weighted': f1_score(y_true, y_pred, average='weighted', zero_division=0),
    }


def run_ssl_experiment(strategy_name, selection_strategy, integration_strategy, base_model,
                       max_iter=15, labeling_convergence_threshold=5):
    """
    Fit a SelfTrainingClassifier with the given strategies and score it on the test split.

    Uses the notebook-level splits X_labeled, y_labeled, X_unlabeled, X_val, y_val
    (for fitting) and X_test, y_test (for scoring).

    Parameters
    ----------
    strategy_name : str
        Label stored under 'strategy' in the result and shown in tables/plots.
    selection_strategy, integration_strategy
        Strategy objects forwarded to SelfTrainingClassifier.
    base_model
        Estimator forwarded to SelfTrainingClassifier as ``base_model``.
    max_iter, labeling_convergence_threshold : int
        Forwarded to SelfTrainingClassifier; the defaults are the values this notebook
        has always used.

    Returns
    -------
    dict
        Keys: 'strategy', 'model', 'accuracy', 'f1_macro', 'f1_weighted',
        'predictions', 'history', 'stopping_reason'.
    """
    print(f"🔄 Running {strategy_name}...")

    ssl_model = SelfTrainingClassifier(
        base_model=base_model,
        selection_strategy=selection_strategy,
        integration_strategy=integration_strategy,
        max_iter=max_iter,
        labeling_convergence_threshold=labeling_convergence_threshold
    )
    ssl_model.fit(X_labeled, y_labeled, X_unlabeled, X_val, y_val)

    y_pred = ssl_model.predict(X_test)

    return {
        'strategy': strategy_name,
        'model': ssl_model,
        **_score_predictions(y_test, y_pred),
        'predictions': y_pred,
        'history': ssl_model.history_,
        'stopping_reason': ssl_model.stopping_reason_
    }


def run_baseline(base_model):
    """
    Fit ``base_model`` on the labeled split only and score it on the test split.

    The estimator passed in is fitted in place and returned under 'model'; pass a fresh
    instance for each call. The result has the same keys as ``run_ssl_experiment`` (with an
    empty history and stopping reason 'N/A') so both can be tabulated together.
    """
    print("🔄 Running Baseline (Supervised Only)...")

    base_model.fit(X_labeled, y_labeled)
    y_pred = base_model.predict(X_test)

    return {
        'strategy': 'Baseline (Supervised)',
        'model': base_model,
        **_score_predictions(y_test, y_pred),
        'predictions': y_pred,
        'history': [],
        'stopping_reason': 'N/A'
    }


print("✅ Experiment framework ready!")

## 4. Strategy Comparison Experiments

Let's test different combinations of selection and integration strategies:

In [None]:
def make_base_model():
    """Return a new, unfitted LogisticRegression with the configuration shared by all runs.

    Every experiment gets its own instance so no fitted state leaks between runs
    (``run_baseline`` fits the estimator it receives in place).
    """
    return LogisticRegression(random_state=42, max_iter=1000)


# Experiments to run: (name, selection strategy, integration strategy).
# The 'Baseline' entry has no strategies and is trained on the labeled split only.
experiments = [
    # Baseline
    ('Baseline', None, None),

    # Confidence Threshold strategies
    ('Confidence-0.95 + AppendGrow',
     ConfidenceThreshold(threshold=0.95),
     AppendAndGrow()),

    ('Confidence-0.90 + AppendGrow',
     ConfidenceThreshold(threshold=0.90),
     AppendAndGrow()),

    ('Confidence-0.85 + AppendGrow',
     ConfidenceThreshold(threshold=0.85),
     AppendAndGrow()),

    # TopK strategies
    ('TopK-5 + AppendGrow',
     TopKFixedCount(k=5),
     AppendAndGrow()),

    ('TopK-10 + AppendGrow',
     TopKFixedCount(k=10),
     AppendAndGrow()),

    ('TopK-15 + AppendGrow',
     TopKFixedCount(k=15),
     AppendAndGrow()),

    # Different integration strategies
    ('Confidence-0.90 + FullReLabel',
     ConfidenceThreshold(threshold=0.90),
     FullReLabeling(X_labeled, y_labeled)),

    ('TopK-10 + ConfWeight',
     TopKFixedCount(k=10),
     ConfidenceWeighting()),
]

# Run all experiments
results = []

for exp_name, selection_strategy, integration_strategy in experiments:
    if exp_name == 'Baseline':
        result = run_baseline(make_base_model())
    else:
        result = run_ssl_experiment(
            exp_name,
            selection_strategy,
            integration_strategy,
            make_base_model()
        )
    results.append(result)
    print(f"   ✅ {exp_name}: Accuracy = {result['accuracy']:.3f}, F1-Macro = {result['f1_macro']:.3f}")

print("\n🏆 All experiments completed!")

## 5. Performance Comparison

Let's visualize and compare the performance of different strategies:

In [None]:
# One row per experiment with the headline test-set metrics.
results_df = pd.DataFrame([
    {
        'Strategy': r['strategy'],
        'Accuracy': r['accuracy'],
        'F1-Macro': r['f1_macro'],
        'F1-Weighted': r['f1_weighted']
    }
    for r in results
])

# Rank by macro-F1, which weights every class equally (the relevant metric for imbalanced data).
# An explicitly stable sort keeps tied strategies in experiment order, so the ranking (and any
# "best" pick taken from it) is deterministic.
results_df = results_df.sort_values('F1-Macro', ascending=False, kind='mergesort')

print("📊 Performance Ranking (by F1-Macro):")
print(results_df.round(3))

In [None]:
# Create comparison visualization: accuracy, macro-F1 and relative macro-F1 change per strategy.
def _annotated_bar_panel(ax, values, labels, title, ylabel, *, color=None,
                         label_fmt='{:.3f}', label_offset=0.01):
    """Draw one bar per entry of ``values`` on ``ax`` and write each value at the bar's end.

    Labels of negative bars go below the bar (``va='top'``) so they do not sit inside it.
    """
    positions = np.arange(len(values))
    ax.bar(positions, values, alpha=0.7, color=color)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    for x, v in zip(positions, values):
        if v < 0:
            ax.text(x, v - label_offset, label_fmt.format(v), ha='center', va='top', fontsize=9)
        else:
            ax.text(x, v + label_offset, label_fmt.format(v), ha='center', va='bottom', fontsize=9)


fig, axes = plt.subplots(1, 3, figsize=(18, 6))
strategy_labels = results_df['Strategy']

# Plot 1: Accuracy comparison
_annotated_bar_panel(axes[0], results_df['Accuracy'], strategy_labels,
                     'Accuracy Comparison', 'Accuracy')
axes[0].set_ylim(0, 1)

# Plot 2: F1-Macro comparison
_annotated_bar_panel(axes[1], results_df['F1-Macro'], strategy_labels,
                     'F1-Macro Comparison', 'F1-Macro Score', color='orange')
axes[1].set_ylim(0, 1)

# Plot 3: Relative macro-F1 change over the supervised baseline, in percent
is_baseline = results_df['Strategy'].str.contains('Baseline')
baseline_f1 = results_df.loc[is_baseline, 'F1-Macro'].iloc[0]
if baseline_f1 > 0:
    improvements = (results_df['F1-Macro'] - baseline_f1) / baseline_f1 * 100
else:
    # A relative change is undefined against a zero baseline.
    improvements = pd.Series(np.nan, index=results_df.index)
colors = ['red' if x < 0 else 'green' for x in improvements]
_annotated_bar_panel(axes[2], improvements, strategy_labels,
                     'Improvement over Baseline (%)', 'Improvement (%)',
                     color=colors, label_fmt='{:.1f}%', label_offset=0.5)
axes[2].axhline(y=0, color='black', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

# Best *SSL* strategy, baseline excluded. The confusion-matrix section uses `best_strategy`
# as the best SSL run, so it must not resolve to the baseline even when the baseline ranks first.
ssl_rows = results_df[~is_baseline]
best_strategy = ssl_rows.loc[ssl_rows['F1-Macro'].idxmax()]
print(f"\n🏆 Best SSL Strategy: {best_strategy['Strategy']}")
print(f"   F1-Macro: {best_strategy['F1-Macro']:.3f}")
# `:+` prints the sign itself, so a regression reads "-3.2%".
print(f"   Improvement over baseline: {improvements.loc[best_strategy.name]:+.1f}%")

## 6. Learning Progress Analysis

Let's examine how different strategies learn over time:

In [None]:
# Plot learning curves (one panel per history metric) for a few representative SSL runs.
strategies_to_plot = [
    'Confidence-0.95 + AppendGrow',
    'Confidence-0.90 + AppendGrow',
    'TopK-10 + AppendGrow',
    'TopK-5 + AppendGrow'
]

# (history key, panel title, y-axis label, marker)
history_panels = [
    ('labeled_data_count', 'Labeled Data Growth', 'Number of Labeled Samples', 'o'),
    ('new_labels_count', 'New Labels Added per Iteration', 'New Labels Count', 's'),
    ('average_confidence', 'Average Confidence of New Labels', 'Average Confidence', '^'),
    ('validation_score', 'Validation Score Progression', 'Validation Score', 'd'),
]


def _plot_history_metric(ax, runs, strategy_names, key, marker):
    """Plot ``h[key]`` against ``h['iteration']`` over each selected run's history on ``ax``.

    Missing entries (absent key or ``None``) become NaN, so matplotlib leaves a gap instead
    of drawing a spurious drop to 0; runs with no values at all for ``key`` are skipped.
    """
    for run in runs:
        history = run['history']
        if run['strategy'] not in strategy_names or not history:
            continue
        raw_values = [h.get(key) for h in history]
        if all(v is None for v in raw_values):
            continue
        iterations = [h['iteration'] for h in history]
        values = [np.nan if v is None else v for v in raw_values]
        ax.plot(iterations, values, marker=marker, label=run['strategy'])


fig, axes = plt.subplots(2, 2, figsize=(15, 10))
for ax, (key, title, ylabel, marker) in zip(axes.flat, history_panels):
    _plot_history_metric(ax, results, strategies_to_plot, key, marker)
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Iteration')
    ax.set_ylabel(ylabel)
    # Only add a legend when something was plotted (an empty legend triggers a warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Confusion Matrix Analysis

Let's see how well each strategy handles the imbalanced classes:

In [None]:
# Compare confusion matrices for the baseline, the best SSL run and two reference strategies.
results_by_name = {r['strategy']: r for r in results}

# dict.fromkeys de-duplicates while preserving order: the best strategy may coincide with one
# of the fixed reference picks (or with the baseline), which would otherwise be shown twice.
strategies_to_analyze = list(dict.fromkeys([
    'Baseline (Supervised)',
    best_strategy['Strategy'],      # Best SSL strategy
    'TopK-10 + AppendGrow',         # Popular TopK strategy
    'Confidence-0.90 + AppendGrow'  # Popular Confidence strategy
]))

# Class ids from the full dataset, passed explicitly as `labels=` so every matrix and report has
# the same rows/columns even if a class is absent from y_test or never predicted.
class_ids = np.unique(y)
class_names = [f'Class {c}' for c in class_ids]

n_panels = len(strategies_to_analyze)
n_rows = (n_panels + 1) // 2
fig, axes = plt.subplots(n_rows, 2, figsize=(15, 6 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, strategy_name in zip(axes, strategies_to_analyze):
    result = results_by_name[strategy_name]

    cm = confusion_matrix(y_test, result['predictions'], labels=class_ids)

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=class_names,
                yticklabels=class_names)

    ax.set_title(f'{strategy_name}\nF1-Macro: {result["f1_macro"]:.3f}', fontweight='bold')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')

# Hide grid cells left empty after de-duplication.
for ax in axes[n_panels:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

# Per-class performance analysis
print("\n📊 Per-Class Performance Analysis:")
print("=" * 60)

for strategy_name in strategies_to_analyze:
    result = results_by_name[strategy_name]
    print(f"\n{strategy_name}:")
    print(classification_report(y_test, result['predictions'],
                                labels=class_ids,
                                target_names=class_names,
                                digits=3,
                                zero_division=0))

## 8. Strategy Insights & Recommendations

Based on our experiments, let's derive insights about when to use each strategy:

In [None]:
# Summarize each SSL run's self-training history (the baseline has no history and is skipped).
strategy_analysis = []

for result in results:
    history = result['history']
    if result['strategy'] == 'Baseline (Supervised)' or not history:
        continue

    # Tolerate iterations that recorded no confidence value (the history schema is defined by the
    # framework, not here) instead of letting np.mean fail on None; NaN if none were recorded.
    confidences = [h['average_confidence'] for h in history
                   if h['average_confidence'] is not None]

    strategy_analysis.append({
        'Strategy': result['strategy'],
        'F1-Macro': result['f1_macro'],
        'Total Iterations': len(history),
        # NOTE(review): for FullReLabeling this sum may count samples re-labeled in several
        # iterations - confirm against the framework's `new_labels_count` semantics.
        'Labels Added': sum(h['new_labels_count'] for h in history),
        'Avg Confidence': np.mean(confidences) if confidences else np.nan,
        'Stopping Reason': result['stopping_reason']
    })

analysis_df = pd.DataFrame(strategy_analysis)
print("🔍 Strategy Characteristics:")
print(analysis_df.round(3))

## 9. Key Insights & Recommendations

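The headline numbers below come from this single run (one dataset, one random seed). The recommendations that follow are general heuristics to validate on your own data — a single comparison like this cannot establish them on its own.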

In [None]:
# Generate insights. All headline numbers are derived from the ranking table `results_df`.
is_baseline = results_df['Strategy'].str.contains('Baseline')
baseline_f1 = results_df.loc[is_baseline, 'F1-Macro'].iloc[0]
ssl_df = results_df[~is_baseline]
# The best SSL run itself, so its name and its F1 are always reported together.
best_ssl = ssl_df.loc[ssl_df['F1-Macro'].idxmax()]
best_ssl_f1 = best_ssl['F1-Macro']
# Relative change in percent; undefined (NaN) against a zero baseline.
improvement = (best_ssl_f1 - baseline_f1) / baseline_f1 * 100 if baseline_f1 > 0 else np.nan

print("🎯 KEY INSIGHTS & RECOMMENDATIONS")
print("=" * 50)

print("\n📈 Overall SSL Performance:")
# `:+` prints the sign itself, so a regression reads "-3.2%" rather than "+-3.2%".
print(f"   • Best SSL improvement: {improvement:+.1f}% over baseline")
print(f"   • Baseline F1-Macro: {baseline_f1:.3f}")
print(f"   • Best SSL F1-Macro: {best_ssl_f1:.3f}")

# The overall top-ranked row can be the baseline, so report the SSL run that achieved best_ssl_f1.
print(f"\n🏆 Best SSL Strategy: {best_ssl['Strategy']}")

print("\n💡 Strategy Recommendations:")

# Confidence Threshold analysis
conf_strategies = [r for r in results if 'Confidence' in r['strategy'] and 'AppendGrow' in r['strategy']]
if conf_strategies:
    best_conf = max(conf_strategies, key=lambda x: x['f1_macro'])
    print("\n   🎯 Confidence Threshold:")
    print(f"      • Best threshold: {best_conf['strategy']}")
    print("      • Use when: You want high-quality pseudo-labels")
    print("      • Trade-off: Conservative (fewer labels, higher quality)")

# TopK analysis
topk_strategies = [r for r in results if 'TopK' in r['strategy'] and 'AppendGrow' in r['strategy']]
if topk_strategies:
    best_topk = max(topk_strategies, key=lambda x: x['f1_macro'])
    print("\n   📊 TopK Fixed Count:")
    print(f"      • Best K value: {best_topk['strategy']}")
    print("      • Use when: You want consistent progress each iteration")
    print("      • Trade-off: Aggressive (more labels, variable quality)")

# Integration strategy analysis (static guidance, not computed from the results)
print("\n   🔄 Integration Strategies:")
print("      • AppendAndGrow: Best for most cases (monotonic growth)")
print("      • FullReLabeling: Use when early pseudo-labels might be wrong")
print("      • ConfidenceWeighting: Use with noisy pseudo-labels")

# These statements are fixed text, not measured above - label them as heuristics to check.
print("\n🎪 Imbalanced Data Insights (general heuristics - verify against the per-class reports):")
print("   • SSL helps most with minority classes")
print("   • Confidence-based strategies may bias toward majority class")
print("   • TopK strategies provide more balanced pseudo-labeling")

print("\n⚡ Quick Selection Guide:")
print("   • Conservative approach: ConfidenceThreshold(0.95)")
print("   • Balanced approach: ConfidenceThreshold(0.90)")
print("   • Aggressive approach: TopKFixedCount(10-15)")
print("   • Integration: AppendAndGrow() for most cases")

## 10. Next Steps

Now that you understand the different SSL strategies, here's how to apply this knowledge:

### 🎯 Strategy Selection Guide:

**Use `ConfidenceThreshold`** when:
- You prefer quality over quantity in pseudo-labels
- Your base model is well-calibrated (probabilities are meaningful)
- You have enough unlabeled data to be selective

**Use `TopKFixedCount`** when:
- You want predictable progress each iteration
- You have limited unlabeled data
- You're dealing with imbalanced classes

**Integration strategies:**
- `AppendAndGrow`: Default choice, works well in most cases
- `FullReLabeling`: When you suspect early iterations produce poor pseudo-labels
- `ConfidenceWeighting`: When you want to down-weight uncertain pseudo-labels

### 🔗 Explore More:
- **`03_text_classification.ipynb`** - Apply these strategies to NLP tasks
- **`04_tabular_data_pipeline.ipynb`** - Integration with production pipelines
- **`05_hyperparameter_tuning.ipynb`** - Optimize strategy parameters