In [None]:
# 🚀 Setup for Google Colab
import sys
if 'google.colab' in sys.modules:
    print("🔧 Setting up for Google Colab...")
    
    # Install required dependencies
    !pip install -q matplotlib seaborn scikit-learn numpy pandas
    
    # Note: SSL framework code will be included in subsequent cells for Colab compatibility
    print("✅ Dependencies installed! SSL framework will be defined in the next cells.")
else:
    print("📝 Running locally - using installed SSL framework")

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yourusername/pyssl/blob/main/notebooks/05_hyperparameter_tuning.ipynb)

## 3. Key Insights & Best Practices

### 🧠 Key Tuning Insights:
1. **Strategy Selection**: The choice between ConfidenceThreshold and TopKFixedCount has the biggest impact
2. **Threshold/K Values**: Most critical SSL-specific parameters  
3. **Iteration Limits**: Balance between performance and efficiency
4. **Base Model Regularization**: Important for SSL stability
5. **Method Efficiency**: Random search is often competitive with Bayesian optimization

### 📚 Quick Selection Guide:
- **Conservative approach**: ConfidenceThreshold(0.95)
- **Balanced approach**: ConfidenceThreshold(0.90) 
- **Aggressive approach**: TopKFixedCount with `k` between 10 and 15
- **Integration**: AppendAndGrow() for most cases

This completes our hyperparameter tuning exploration. The framework provides flexible optimization capabilities for finding the best SSL configuration for your specific use case.

In [None]:
# Define parameter grids for systematic exploration.
# Each selection strategy maps its constructor keyword(s) to the candidate
# values that the grid search will try.
param_grids = {
    'ConfidenceThreshold': {
        'threshold': [0.8, 0.85, 0.9, 0.95, 0.99]
    },
    'TopKFixedCount': {
        'k': [3, 5, 8, 10, 15, 20]
    }
}

# Keyword arguments that evaluate_ssl_config forwards to the base model.
base_model_grids = {
    'C': [0.1, 1.0, 10.0],
    'class_weight': [None, 'balanced']
}

# Run grid search on medium dataset
print("🔍 Running Grid Search on Medium Dataset...")
results = []

dataset_name = 'medium'
dataset = datasets[dataset_name]

# The base-model grid is the same for every strategy, so expand it once here
# instead of rebuilding it inside the innermost loop.
base_param_names = list(base_model_grids.keys())
base_param_values = list(base_model_grids.values())
base_configs = [
    dict(zip(base_param_names, base_params))
    for base_params in product(*base_param_values)
]

for strategy_name, strategy_grid in param_grids.items():
    print(f"  Testing {strategy_name}...")

    # Try all strategy parameter combinations
    strategy_param_names = list(strategy_grid.keys())
    strategy_param_values = list(strategy_grid.values())

    for strategy_params in product(*strategy_param_values):
        strategy_config = dict(zip(strategy_param_names, strategy_params))

        # Try all base model parameter combinations.
        # A copy is passed so every result row owns its own parameter dict.
        for base_config in map(dict, base_configs):
            result = evaluate_ssl_config(
                dataset_name, dataset, strategy_name,
                strategy_config, base_config
            )
            results.append(result)

results_df = pd.DataFrame(results)
print(f"✅ Completed {len(results)} configurations")

# evaluate_ssl_config reports a failed run as 0.0 scores with a stopping reason
# starting with 'Error:'. Surface those runs so that zero scores in the table
# are not mistaken for genuinely poor configurations.
failed_mask = results_df['stopping_reason'].astype(str).str.startswith('Error:')
if failed_mask.any():
    print(f"⚠️ {int(failed_mask.sum())} configuration(s) failed and were scored 0.0:")
    for reason in results_df.loc[failed_mask, 'stopping_reason'].unique():
        print(f"    {reason}")

print("\n📊 Top 5 Configurations by F1-Macro:")
top_results = results_df.nlargest(5, 'f1_macro')
for _, row in top_results.iterrows():
    print(f"  {row['strategy']}: F1={row['f1_macro']:.3f}, Params={row['strategy_params']}")

In [None]:
def evaluate_ssl_config(dataset_name, dataset, strategy_name, strategy_params,
                        base_model_params=None, max_iter=10,
                        labeling_convergence_threshold=5):
    """Evaluate a single SSL configuration.

    Args:
        dataset_name: Label stored under the 'dataset' key of the result.
        dataset: Mapping whose 'data' entry is an 8-tuple
            (X_labeled, y_labeled, X_unlabeled, X_val, y_val, X_test, y_test,
            y_unlabeled_true).
        strategy_name: 'ConfidenceThreshold' or 'TopKFixedCount'.
        strategy_params: Keyword arguments passed to the strategy constructor.
        base_model_params: Optional keyword arguments for LogisticRegression.
            They are merged over the defaults random_state=42 and
            max_iter=1000, so either default may be overridden.
        max_iter: Forwarded to SelfTrainingClassifier(max_iter=...).
        labeling_convergence_threshold: Forwarded to
            SelfTrainingClassifier(labeling_convergence_threshold=...).

    Returns:
        dict with keys 'dataset', 'strategy', 'strategy_params',
        'base_model_params', 'accuracy', 'f1_macro', 'final_labeled_count',
        'iterations' and 'stopping_reason'. If fitting, predicting or scoring
        raises, the metrics and counts are 0 and 'stopping_reason' is
        'Error: <message>'.

    Raises:
        ValueError: If strategy_name is not one of the supported strategies.
    """
    # The last element (true labels of the unlabeled pool) is not needed here.
    X_labeled, y_labeled, X_unlabeled, X_val, y_val, X_test, y_test, _ = dataset['data']

    # Create base model. Merging dicts (instead of mixing explicit keywords
    # with **base_model_params) avoids a "multiple values for keyword
    # argument" TypeError when the caller tunes random_state or max_iter.
    if base_model_params is None:
        base_model_params = {}
    model_kwargs = {'random_state': 42, 'max_iter': 1000}
    model_kwargs.update(base_model_params)
    base_model = LogisticRegression(**model_kwargs)

    # Create strategy
    if strategy_name == 'ConfidenceThreshold':
        strategy = ConfidenceThreshold(**strategy_params)
    elif strategy_name == 'TopKFixedCount':
        strategy = TopKFixedCount(**strategy_params)
    else:
        raise ValueError(f"Unknown strategy: {strategy_name}")

    # Create SSL model
    ssl_model = SelfTrainingClassifier(
        base_model=base_model,
        selection_strategy=strategy,
        integration_strategy=AppendAndGrow(),
        max_iter=max_iter,
        labeling_convergence_threshold=labeling_convergence_threshold
    )

    # Fields shared by the success and failure records.
    record = {
        'dataset': dataset_name,
        'strategy': strategy_name,
        'strategy_params': strategy_params,
        'base_model_params': base_model_params,
    }

    try:
        # Train
        ssl_model.fit(X_labeled, y_labeled, X_unlabeled, X_val, y_val)

        # Evaluate
        y_pred = ssl_model.predict(X_test)
        outcome = {
            'accuracy': accuracy_score(y_test, y_pred),
            'f1_macro': f1_score(y_test, y_pred, average='macro'),
            # Training info exposed by the fitted model
            'final_labeled_count': len(ssl_model.X_labeled_),
            'iterations': len(ssl_model.history_),
            'stopping_reason': ssl_model.stopping_reason_,
        }
    except Exception as e:
        # Deliberate best-effort: one failing configuration must not abort a
        # whole search, so it is recorded with zero scores and the error text.
        outcome = {
            'accuracy': 0.0,
            'f1_macro': 0.0,
            'final_labeled_count': 0,
            'iterations': 0,
            'stopping_reason': f'Error: {str(e)}',
        }

    record.update(outcome)
    return record

print("✅ Evaluation function ready!")

## 2. Grid Search for SSL Parameters

Let's systematically explore the parameter space to understand which settings work best:

In [None]:
# Generate test datasets with different characteristics.
# Every dataset uses the same split proportions and random seed.
shared_split_kwargs = dict(test_size=0.2, val_size=0.1, random_state=42)

# Easy dataset - well-separated classes
easy_data = generate_ssl_dataset(
    dataset_type="moons", n_samples=800, n_labeled=20, noise=0.1,
    **shared_split_kwargs,
)

# Medium dataset - moderate separation
medium_data = generate_ssl_dataset(
    dataset_type="classification", n_samples=1000, n_labeled=40,
    n_features=10, n_classes=3, class_sep=0.8,
    **shared_split_kwargs,
)

# Hard dataset - challenging separation
hard_data = generate_ssl_dataset(
    dataset_type="classification", n_samples=1200, n_labeled=60,
    n_features=20, n_classes=4, class_sep=0.6,
    **shared_split_kwargs,
)

# Registry keyed by difficulty; each entry holds the generated data tuple
# together with a display name and a short description.
datasets = {
    'easy': {
        'data': easy_data,
        'name': 'Easy (Moons)',
        'description': 'Well-separated 2D data',
    },
    'medium': {
        'data': medium_data,
        'name': 'Medium (3-class)',
        'description': '10D, 3 classes, moderate separation',
    },
    'hard': {
        'data': hard_data,
        'name': 'Hard (4-class)',
        'description': '20D, 4 classes, difficult separation',
    },
}

print("📊 Dataset Summary:")
for key, dataset in datasets.items():
    # Each 'data' entry unpacks into the eight arrays below.
    (X_labeled, y_labeled, X_unlabeled,
     X_val, y_val,
     X_test, y_test,
     y_unlabeled_true) = dataset['data']
    print(f"  {dataset['name']}: {len(X_labeled)} labeled, {len(X_unlabeled)} unlabeled, {len(X_test)} test")
    print(f"    Features: {X_labeled.shape[1]}, Classes: {len(np.unique(y_labeled))}")
    print(f"    Description: {dataset['description']}")
    print()

## 1. Setup & Dataset Generation

Let's create datasets with different characteristics to test our hyperparameter optimization:

# 🎯 SSL Hyperparameter Tuning - Finding Optimal Configurations

This notebook demonstrates systematic hyperparameter optimization for semi-supervised learning. We'll explore which parameters matter most and how to tune them efficiently.

**What you'll learn:**
- Which SSL parameters have the biggest impact on performance
- How to set up systematic hyperparameter search for SSL
- GridSearch, RandomSearch, and Bayesian optimization for SSL
- Parameter sensitivity analysis and visualization
- Best practices for SSL hyperparameter tuning

**Dataset:** Multi-class classification with varying difficulty levels

In [None]:
# Numerical, plotting and scikit-learn dependencies used by the notebook.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
from itertools import product
import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# including convergence and deprecation warnings — consider narrowing the filter.
warnings.filterwarnings('ignore')

# Import our SSL framework
import sys
# Add the parent directory to the import path; presumably `ssl_framework` and
# `utils` live there relative to this notebook — confirm against the repo layout.
sys.path.append('../')
from ssl_framework.main import SelfTrainingClassifier
from ssl_framework.strategies import (
    ConfidenceThreshold, TopKFixedCount, 
    AppendAndGrow, FullReLabeling, ConfidenceWeighting
)

# Import utilities
from utils.data_generation import generate_ssl_dataset, create_ssl_benchmark

# Set style: default matplotlib style with seaborn's "husl" colour palette
plt.style.use('default')
sns.set_palette("husl")

print("✅ All imports successful!")