# SANS Hyperparameter Tuning and Analysis Notebook

This comprehensive notebook provides:
1. **Hyperparameter grid search** for SANS (Self-Adversarial Negative Sampling)
2. **Real-time monitoring** of training metrics and SANS behavior
3. **Comparative analysis** across different configurations
4. **Debugging visualizations** for understanding SANS dynamics

## Key Features:
- Automated experiment management
- Live training visualization
- SANS correlation debugging
- Energy distribution analysis
- Optimal hyperparameter identification

## 1. Environment Setup and Dependencies

In [None]:
# Detect the runtime (Google Colab vs. local) and choose the working directory
import os
import sys

IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    print("Running locally")
    # Local sessions work out of whatever directory the kernel started in
    WORK_DIR = os.getcwd()
else:
    print("Running in Google Colab")
    # Mount Google Drive so experiment artefacts persist across sessions
    from google.colab import drive
    drive.mount('/content/drive')

    # Keep everything under a dedicated Drive folder and work from there
    WORK_DIR = '/content/drive/MyDrive/sans_experiments'
    os.makedirs(WORK_DIR, exist_ok=True)
    os.chdir(WORK_DIR)

print(f"Working directory: {WORK_DIR}")

In [None]:
# Install required packages
!pip install -q accelerate ema-pytorch einops tabulate tqdm matplotlib seaborn pandas plotly ipywidgets
print("✓ Dependencies installed")

# Import standard libraries
import os
import sys
import json
import time
import subprocess
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import itertools

# Import data science libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# Import PyTorch
import torch

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 10

print("✓ Libraries imported")

In [None]:
# Select the compute device and derive a default batch size from it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if not torch.cuda.is_available():
    print("⚠ No GPU detected - training will be slow")
    DEFAULT_BATCH_SIZE = 32
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_memory:.2f} GB")

    # Known GPU families are matched by substring, in order; the first hit
    # wins and anything unrecognised falls back to 128
    DEFAULT_BATCH_SIZE = next(
        (size
         for tag, size in (('T4', 256), ('V100', 512), ('A100', 1024))
         if tag in gpu_name),
        128,
    )

print(f"\nDefault batch size: {DEFAULT_BATCH_SIZE}")

In [None]:
# Clone or update the repository
# REPO_DIR is anchored at WORK_DIR, whereas the `!git clone` below runs in the
# current working directory.
# NOTE(review): this assumes the kernel's cwd is still WORK_DIR on first run - confirm.
REPO_DIR = Path(WORK_DIR) / 'energy-based-model'

if not REPO_DIR.exists():
    print("Cloning repository...")
    # Remove any leftover directory of the same name in the cwd before cloning
    !rm -rf energy-based-model
    !git clone https://github.com/mdkrasnow/energy-based-model.git
else:
    print("Repository exists, pulling latest changes...")
    # `git pull` acts on the cwd, so enter the checkout first
    os.chdir(REPO_DIR)
    !git pull

# From here on the notebook runs from inside the checkout, and the checkout
# is placed first on sys.path so its packages can be imported
os.chdir(REPO_DIR)
sys.path.insert(0, str(REPO_DIR))

# Import project modules
# (analysis helpers provided by the cloned repository's utils.sans_analysis)
from utils.sans_analysis import (
    load_sans_metrics,
    analyze_correlation_quality,
    analyze_entropy_dynamics,
    analyze_energy_separation,
    plot_sans_diagnostics,
    compare_sans_configurations,
    generate_hyperparameter_report
)

print(f"✓ Repository ready at {REPO_DIR}")

## 2. Experiment Configuration

In [None]:
# Search space for SANS: one experiment is generated per combination
HYPERPARAMETER_GRID = dict(
    sans_num_negs=[4, 8, 16, 32, 64],   # Number of negative samples
    sans_temp=[0.5, 1.0, 1.5, 2.0],     # Temperature parameter
    sans_temp_schedule=[True, False],   # Temperature scheduling
)

# Settings shared by every experiment
FIXED_PARAMS = dict(
    dataset='inverse',
    model='mlp',
    rank=20,
    batch_size=DEFAULT_BATCH_SIZE,
    diffusion_steps=10,
    max_steps=10000,  # Reduced for quick experiments
    supervise_energy_landscape=True,
    sans=True,
    sans_debug=True,
    data_workers=2,
)

# Per-task overrides layered on top of FIXED_PARAMS
TASK_CONFIGS = dict(
    inverse=dict(
        dataset='inverse',
        rank=20,
        metric='mse',
    ),
    sudoku=dict(
        dataset='sudoku-rrn',
        model='sudoku',
        metric='sudoku',
        cond_mask=True,
        batch_size=64,
    ),
    connectivity=dict(
        dataset='connectivity',
        model='gnn-conv-1d-v2',
        metric='bce',
        batch_size=64,
    ),
)

# Select task
SELECTED_TASK = 'inverse'  # Change this to switch tasks
print(f"Selected task: {SELECTED_TASK}")

# Apply the selected task's overrides; an unknown task leaves FIXED_PARAMS untouched
FIXED_PARAMS.update(TASK_CONFIGS.get(SELECTED_TASK, {}))
# Expand the hyperparameter grid into concrete experiment configurations
def generate_experiment_configs():
    """Build one ``(name, config)`` pair per point of the hyperparameter grid.

    Every config starts from a copy of ``FIXED_PARAMS`` and is overlaid with
    one combination of the values in ``HYPERPARAMETER_GRID``.  The name
    encodes the number of negatives (``K``), the temperature (``T``) and
    carries a ``_sched`` suffix when temperature scheduling is enabled.

    Returns:
        List of ``(experiment_name, config_dict)`` tuples in grid order.
    """
    grid_keys = list(HYPERPARAMETER_GRID)
    grid_values = [HYPERPARAMETER_GRID[k] for k in grid_keys]

    experiments = []
    for combination in itertools.product(*grid_values):
        # Grid values override the fixed parameters of the same name
        params = {**FIXED_PARAMS, **dict(zip(grid_keys, combination))}

        suffix = "_sched" if params['sans_temp_schedule'] else ""
        label = f"sans_K{params['sans_num_negs']}_T{params['sans_temp']}{suffix}"

        experiments.append((label, params))

    return experiments

EXPERIMENT_CONFIGS = generate_experiment_configs()
print(f"Total experiments to run: {len(EXPERIMENT_CONFIGS)}")

# Reference run placed first in the list: same fixed parameters, but with
# SANS and energy-landscape supervision switched off
baseline_config = {
    **FIXED_PARAMS,
    'sans': False,
    'supervise_energy_landscape': False,
}
EXPERIMENT_CONFIGS.insert(0, ('baseline_no_sans', baseline_config))

print("\nFirst 5 experiment names:")
for exp_label, _ in EXPERIMENT_CONFIGS[:5]:
    print(f"  - {exp_label}")

## 3. Experiment Management

In [None]:
class ExperimentManager:
    """Run training experiments as subprocesses and keep a persistent record.

    Each experiment gets its own directory under ``base_dir`` containing its
    configuration (``config.json``) and a ``results`` folder.  A summary of
    every run is written to ``base_dir/experiments.json`` so that experiments
    already recorded are skipped when the notebook is re-run.
    """

    def __init__(self, base_dir: str = 'experiments'):
        """Create ``base_dir`` (including missing parents) and load records.

        Args:
            base_dir: Directory that holds all experiment folders and the
                ``experiments.json`` index.
        """
        self.base_dir = Path(base_dir)
        # parents=True so a nested base_dir such as "runs/sweep1" also works
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.experiments_file = self.base_dir / 'experiments.json'
        self.load_experiments()

    def load_experiments(self):
        """Load the experiment index from disk, or start with an empty one."""
        if self.experiments_file.exists():
            with open(self.experiments_file, 'r') as f:
                self.experiments = json.load(f)
        else:
            self.experiments = {}

    def save_experiments(self):
        """Write the experiment index to ``experiments.json``."""
        with open(self.experiments_file, 'w') as f:
            json.dump(self.experiments, f, indent=2)

    @staticmethod
    def _build_command(config: dict) -> list:
        """Translate a config dict into the ``train.py`` command line.

        Keys become ``--kebab-case`` options.  ``True`` emits the bare flag,
        ``False`` emits ``--no-<flag>``, and any other value is passed as the
        option's argument.  A ``results_dir`` key is never forwarded.
        """
        # Use the interpreter running this notebook so the subprocess sees the
        # same installed packages; fall back to PATH lookup if it is unknown.
        cmd = [sys.executable or 'python', 'train.py']
        for key, value in config.items():
            if key == 'results_dir':
                continue
            flag = key.replace('_', '-')
            if isinstance(value, bool):
                cmd.append(f'--{flag}' if value else f'--no-{flag}')
            else:
                cmd.extend([f'--{flag}', str(value)])

        # Export configuration
        cmd.extend(['--export-config', 'experiment_config'])
        return cmd

    def run_experiment(self, name: str, config: dict, force: bool = False):
        """Run a single experiment and record its outcome.

        Args:
            name: Experiment name; also used as the directory name.
            config: Training options, converted to command-line arguments.
            force: Re-run even if ``name`` is already in the index.

        Returns:
            The record stored for this experiment (config, timings, success
            flag and results directory).
        """
        # Check if already run
        if name in self.experiments and not force:
            print(f"Experiment '{name}' already exists. Skipping...")
            return self.experiments[name]

        # Create experiment directory
        exp_dir = self.base_dir / name
        exp_dir.mkdir(parents=True, exist_ok=True)

        # Save configuration
        with open(exp_dir / 'config.json', 'w') as f:
            json.dump(config, f, indent=2)

        cmd = self._build_command(config)

        # The child runs inside REPO_DIR, so hand it an absolute path:
        # a relative one would be resolved against the wrong directory
        # whenever the notebook's cwd differs from REPO_DIR.
        results_path = (exp_dir / 'results').resolve()
        results_path.mkdir(parents=True, exist_ok=True)
        results_dir = str(results_path)

        print(f"\n{'='*60}")
        print(f"Running experiment: {name}")
        print(f"{'='*60}")
        print(f"Command: {' '.join(cmd)}")
        print(f"Results directory: {results_dir}")

        # Run training
        start_time = time.time()
        success = False

        try:
            # cwd= confines the directory change to the child process instead
            # of mutating the notebook's own working directory; the context
            # manager closes the pipe and reaps the process.
            with subprocess.Popen(
                cmd,
                cwd=str(REPO_DIR),
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
                env={**os.environ, 'RESULTS_FOLDER': results_dir},
            ) as process:
                # Stream output
                for line in process.stdout:
                    print(line, end='')
                process.wait()
            success = process.returncode == 0
        except Exception as e:
            # Deliberately broad: one broken run must not abort a whole sweep
            print(f"Error running experiment: {e}")
            success = False

        end_time = time.time()
        duration = end_time - start_time

        # Record experiment
        self.experiments[name] = {
            'config': config,
            'start_time': start_time,
            'end_time': end_time,
            'duration': duration,
            'success': success,
            'results_dir': results_dir
        }

        self.save_experiments()

        outcome = 'completed' if success else 'FAILED'
        print(f"\nExperiment '{name}' {outcome} in {duration:.1f} seconds")
        return self.experiments[name]

    def run_all_experiments(self, configs: list, parallel: bool = False):
        """Run every ``(name, config)`` pair in ``configs`` sequentially.

        Args:
            configs: Iterable of ``(name, config)`` tuples.
            parallel: Accepted for interface compatibility; currently unused,
                experiments always run one after another.
        """
        for name, config in configs:
            self.run_experiment(name, config)

            # Add small delay between experiments
            time.sleep(2)

    def get_results_paths(self) -> dict:
        """Return ``{name: results_dir}`` for all successful experiments."""
        return {
            name: exp['results_dir']
            for name, exp in self.experiments.items()
            if exp.get('success')
        }

# Create experiment manager; the timestamp in the directory name gives each
# notebook session its own experiment folder
exp_manager = ExperimentManager(
    base_dir='experiments_{}_{:%Y%m%d_%H%M%S}'.format(SELECTED_TASK, datetime.now())
)
print(f"Experiment manager initialized at: {exp_manager.base_dir}")

## 4. Run Experiments

In [None]:
# Option 1: Run a single short smoke-test experiment
#
# The test works on a *copy* of the first configuration and is registered
# under its own name. Mutating the shared dict in EXPERIMENT_CONFIGS (or
# reusing the real experiment name) would make a later full sweep skip that
# experiment and silently compare a 1000-step run against full-length runs.
_base_name, _base_params = EXPERIMENT_CONFIGS[0]
test_config = (
    f"{_base_name}_smoke_test",
    {**_base_params, 'max_steps': 1000},  # Reduce steps for test
)

# Print the configuration that is actually run (after the step reduction)
print(f"Running test experiment: {test_config[0]}")
print(f"Configuration: {json.dumps(test_config[1], indent=2)}")

result = exp_manager.run_experiment(test_config[0], test_config[1])
print(f"\nTest complete: {result['success']}")

In [None]:
# Option 2: Run all experiments (WARNING: This will take a long time!)
# Uncomment to run all experiments

# print(f"Running {len(EXPERIMENT_CONFIGS)} experiments...")
# print("This will take several hours. Consider running overnight.")
# 
# # Confirm before running
# confirm = input("Are you sure you want to run all experiments? (yes/no): ")
# if confirm.lower() == 'yes':
#     exp_manager.run_all_experiments(EXPERIMENT_CONFIGS)
#     print("\n✓ All experiments complete!")
# else:
#     print("Cancelled.")

## 5. Real-time Training Monitor

In [None]:
class TrainingMonitor:
    """Real-time monitoring of a training run via its CSV logs.

    Reads ``metrics.csv`` and ``sans_debug.csv`` from a results directory and
    renders them as a 3x3 Plotly dashboard.  The files may be missing, empty
    or lack individual columns while training is starting up; in those cases
    the affected data or panels are skipped instead of raising.
    """

    def __init__(self, results_dir: str):
        """Remember the results directory and the two log files inside it."""
        self.results_dir = Path(results_dir)
        self.metrics_file = self.results_dir / 'metrics.csv'
        self.sans_file = self.results_dir / 'sans_debug.csv'

    @staticmethod
    def _read_csv_or_empty(path: Path) -> pd.DataFrame:
        """Read ``path`` as CSV; return an empty frame if absent or empty."""
        try:
            return pd.read_csv(path)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # The trainer may not have created the file yet, or may have
            # created it without having written the header so far.
            return pd.DataFrame()

    @staticmethod
    def _add_line(fig, df, column, label, color, row, col):
        """Add ``df[column]`` vs. ``df['step']`` to a subplot if both exist."""
        if 'step' not in df.columns or column not in df.columns:
            return
        fig.add_trace(
            go.Scatter(x=df['step'], y=df[column],
                       mode='lines', name=label,
                       line=dict(color=color)),
            row=row, col=col
        )

    def load_current_metrics(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Load the current training and SANS debug metrics.

        Returns:
            ``(metrics_df, sans_df)``; either is an empty DataFrame when its
            file does not exist or has no content yet.
        """
        metrics_df = self._read_csv_or_empty(self.metrics_file)
        sans_df = self._read_csv_or_empty(self.sans_file)
        return metrics_df, sans_df

    def create_live_dashboard(self):
        """Build the Plotly dashboard from the metrics currently on disk.

        Returns:
            The Plotly figure, or ``None`` when no training metrics are
            available yet.
        """
        metrics_df, sans_df = self.load_current_metrics()

        if len(metrics_df) == 0:
            print("No metrics available yet...")
            return None

        # Create subplots
        fig = make_subplots(
            rows=3, cols=3,
            subplot_titles=(
                'Training Loss', 'Loss Components', 'Learning Rate',
                'SANS Correlation', 'Entropy Ratio', 'Energy Separation',
                'Temperature Schedule', 'Negative Energy Stats', 'Training Speed'
            ),
            specs=[[{}, {}, {}],
                   [{}, {}, {}],
                   [{}, {}, {}]]
        )

        # (column, trace name, colour, row, col) for the training metrics
        metric_traces = [
            ('loss', 'Total Loss', 'blue', 1, 1),        # 1. Training Loss
            ('loss_denoise', 'Denoise', 'green', 1, 2),  # 2. Loss Components
            ('loss_energy', 'Energy', 'red', 1, 2),
            ('lr', 'LR', 'purple', 1, 3),                # 3. Learning Rate
        ]
        for column, label, color, row, col in metric_traces:
            self._add_line(fig, metrics_df, column, label, color, row, col)

        # SANS metrics if available
        if len(sans_df) > 0:
            sans_traces = [
                ('weight_energy_corr', 'Correlation', 'orange', 2, 1),  # 4. SANS Correlation
                ('entropy_ratio', 'Entropy', 'brown', 2, 2),            # 5. Entropy Ratio
                ('real_energy_mean', 'Real Energy', 'green', 2, 3),     # 6. Energy Separation
                ('neg_energy_mean', 'Neg Energy', 'red', 2, 3),
                ('alpha_effective', 'Alpha', 'purple', 3, 1),           # 7. Temperature Schedule
                ('neg_energy_std', 'Neg Std', 'cyan', 3, 2),            # 8. Negative Energy Stats
            ]
            for column, label, color, row, col in sans_traces:
                self._add_line(fig, sans_df, column, label, color, row, col)

            # Zero reference line for the correlation panel
            if 'weight_energy_corr' in sans_df.columns:
                fig.add_hline(y=0, line_dash="dash", line_color="gray",
                              row=2, col=1)

        # 9. Training Speed
        # NOTE(review): presumably 'time' is elapsed seconds since training
        # started, making this a running average - confirm against train.py.
        if 'time' in metrics_df.columns and 'step' in metrics_df.columns:
            steps_per_sec = metrics_df['step'] / metrics_df['time']
            fig.add_trace(
                go.Scatter(x=metrics_df['step'], y=steps_per_sec,
                           mode='lines', name='Steps/sec',
                           line=dict(color='black')),
                row=3, col=3
            )

        # Update layout
        fig.update_layout(
            height=900,
            showlegend=False,
            title_text="Training Dashboard",
            title_font_size=20
        )

        # Update axes
        fig.update_xaxes(title_text="Step")
        fig.update_yaxes(title_text="Value")

        return fig

    def monitor_live(self, refresh_interval: int = 5):
        """Redraw the dashboard every ``refresh_interval`` seconds.

        Runs until interrupted (Ctrl+C / kernel interrupt).

        Args:
            refresh_interval: Seconds to sleep between refreshes.
        """
        print(f"Monitoring: {self.results_dir}")
        print(f"Refresh interval: {refresh_interval} seconds")
        print("Press Ctrl+C to stop monitoring")

        try:
            while True:
                clear_output(wait=True)

                # Load and display current metrics
                metrics_df, sans_df = self.load_current_metrics()

                if len(metrics_df) > 0:
                    current_step = metrics_df['step'].iloc[-1]
                    current_loss = metrics_df['loss'].iloc[-1]

                    print(f"Step: {current_step} | Loss: {current_loss:.6f}")

                    # Create and show dashboard; compare against None
                    # explicitly rather than relying on Figure truthiness
                    fig = self.create_live_dashboard()
                    if fig is not None:
                        fig.show()
                else:
                    print("Waiting for training to start...")

                time.sleep(refresh_interval)

        except KeyboardInterrupt:
            print("\nMonitoring stopped.")

# Example: Monitor a specific experiment
# monitor = TrainingMonitor('experiments_inverse_20231124_120000/baseline_no_sans/results')
# monitor.monitor_live(refresh_interval=10)

## 6. Analysis and Visualization

In [None]:
# Collect the results directory of every successful run and tabulate them
results_paths = exp_manager.get_results_paths()

if not results_paths:
    print("No completed experiments found. Run experiments first.")
else:
    print(f"Found {len(results_paths)} completed experiments")

    # One summary row per experiment
    comparison_df = compare_sans_configurations(results_paths)

    # Show only the headline columns for the first ten experiments
    display(
        comparison_df[[
            'experiment',
            'final_loss',
            'convergence_step',
            'corr_mean_correlation',
            'entropy_mean_entropy_ratio',
        ]].head(10)
    )

In [None]:
# Overlay the training curves of all completed experiments (needs at least two)
if len(results_paths) > 1:
    # Load metrics for all experiments
    all_metrics = {}
    for name, path in results_paths.items():
        metrics_df, sans_df = load_sans_metrics(path)
        all_metrics[name] = {'metrics': metrics_df, 'sans': sans_df}

    # 2x2 grid: total loss, energy loss, SANS correlation, SANS entropy ratio
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Experiment Comparison', fontsize=16)
    (ax_loss, ax_energy), (ax_corr, ax_entropy) = axes

    for name, data in all_metrics.items():
        metrics_df, sans_df = data['metrics'], data['sans']

        if len(metrics_df) > 0:
            ax_loss.plot(metrics_df['step'], metrics_df['loss'],
                         label=name, alpha=0.7)
            ax_energy.plot(metrics_df['step'], metrics_df['loss_energy'],
                           label=name, alpha=0.7)

        if len(sans_df) > 0:
            ax_corr.plot(sans_df['step'], sans_df['weight_energy_corr'],
                         label=name, alpha=0.7)
            ax_entropy.plot(sans_df['step'], sans_df['entropy_ratio'],
                            label=name, alpha=0.7)

    # Panel-specific decoration
    ax_loss.set_yscale('log')
    ax_corr.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    ax_corr.axhline(y=-0.3, color='g', linestyle='--', alpha=0.3)

    # Titles, axis labels, legend and grid are uniform across the panels
    for ax, title, ylabel in (
        (ax_loss, 'Training Loss', 'Loss'),
        (ax_energy, 'Energy Loss Component', 'Energy Loss'),
        (ax_corr, 'SANS Weight-Energy Correlation', 'Correlation'),
        (ax_entropy, 'SANS Entropy Ratio', 'Entropy Ratio'),
    ):
        ax.set_title(title)
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
# Heatmap of final loss over (temperature, number of negatives)
if 'comparison_df' in locals() and len(comparison_df) > 0:
    # Recover the hyperparameters from names of the form
    # "sans_K<negs>_T<temp>[_sched]"; other experiments are ignored
    hp_data = []
    for _, row in comparison_df.iterrows():
        name = row['experiment']
        if 'sans_K' not in name:
            continue

        tokens = name.split('_')
        hp_data.append({
            'K': int(tokens[1][1:]),             # strip the leading "K"
            'Temperature': float(tokens[2][1:]),  # strip the leading "T"
            'Scheduled': 'sched' in name,
            'Final Loss': row.get('final_loss', None),
            'Correlation': row.get('corr_mean_correlation', None)
        })

    if hp_data:
        hp_df = pd.DataFrame(hp_data)

        # Rows: temperature, columns: K; scheduled and unscheduled runs that
        # share a cell are averaged
        pivot_loss = hp_df.pivot_table(
            index='Temperature',
            columns='K',
            values='Final Loss',
            aggfunc='mean'
        )

        # Plot heatmap
        plt.figure(figsize=(10, 6))
        sns.heatmap(
            pivot_loss,
            annot=True,
            fmt='.4f',
            cmap='RdYlGn_r',
            cbar_kws={'label': 'Final Loss'},
        )
        plt.title('Hyperparameter Impact on Final Loss')
        plt.xlabel('Number of Negatives (K)')
        plt.ylabel('Temperature (α)')
        plt.show()

## 7. Best Configuration Analysis

In [None]:
# Find and analyze best configuration
def format_metric(value, spec='.3f'):
    """Format a numeric metric with ``spec``; fall back to ``str(value)``.

    Missing metrics arrive as the placeholder string 'N/A' (or ``None``),
    which cannot take a float format spec, so they are printed verbatim.
    """
    try:
        return format(float(value), spec)
    except (TypeError, ValueError):
        return str(value)


if 'comparison_df' in locals() and len(comparison_df) > 0:
    # Only experiments that actually reported a final loss can be ranked
    valid_losses = comparison_df['final_loss'].dropna()

    if valid_losses.empty:
        print("No experiment reported a final loss; cannot select a best configuration.")
    else:
        # Find best by final loss
        best_idx = valid_losses.idxmin()
        best_exp = comparison_df.loc[best_idx]

        print("="*60)
        print("BEST CONFIGURATION FOUND")
        print("="*60)
        print(f"Experiment: {best_exp['experiment']}")
        print(f"Final Loss: {best_exp['final_loss']:.6f}")
        print(f"Convergence Step: {best_exp.get('convergence_step', 'N/A')}")
        print(f"Mean Correlation: {format_metric(best_exp.get('corr_mean_correlation', 'N/A'))}")
        print(f"Mean Entropy Ratio: {format_metric(best_exp.get('entropy_mean_entropy_ratio', 'N/A'))}")
        print(f"Mean Energy Separation: {format_metric(best_exp.get('energy_mean_separation', 'N/A'))}")

        # Load and plot detailed diagnostics for best experiment
        if best_exp['experiment'] in results_paths:
            best_path = results_paths[best_exp['experiment']]
            metrics_df, sans_df = load_sans_metrics(best_path)

            # The diagnostics need SANS debug data, which a run without SANS
            # may not have produced
            if len(sans_df) > 0:
                print("\nGenerating detailed SANS diagnostics...")
                fig = plot_sans_diagnostics(sans_df,
                                           save_path=f"{exp_manager.base_dir}/best_diagnostics.png")
                plt.show()

## 8. Generate Final Report

In [None]:
# Generate comprehensive report
if len(results_paths) > 0:
    report_path = exp_manager.base_dir / 'hyperparameter_report.txt'
    report = generate_hyperparameter_report(results_paths, save_path=str(report_path))

    print(report)
    print(f"\n✓ Report saved to {report_path}")

    # Save comparison dataframe. It is created by the comparison cell, which
    # may not have run (or may have failed), so guard it the same way the
    # other analysis cells do instead of raising NameError here.
    if 'comparison_df' in locals():
        csv_path = exp_manager.base_dir / 'comparison_results.csv'
        comparison_df.to_csv(csv_path, index=False)
        print(f"✓ Comparison data saved to {csv_path}")
    else:
        print("⚠ comparison_df is not defined - run the comparison cell to export comparison data")

## 9. Recommendations and Next Steps

In [None]:
# Summarise the sweep: improvement over baseline, correlation quality and
# the recommended hyperparameters
print("="*60)
print("ANALYSIS COMPLETE - RECOMMENDATIONS")
print("="*60)
print()

if 'comparison_df' in locals() and len(comparison_df) > 0:
    # Analyze results
    is_baseline = comparison_df['experiment'] == 'baseline_no_sans'
    baseline_loss = comparison_df[is_baseline]['final_loss'].values
    if len(baseline_loss) > 0:
        baseline_loss = baseline_loss[0]
        sans_losses = comparison_df[~is_baseline]['final_loss'].dropna()

        if sans_losses.empty:
            # Without any SANS result min() would be NaN, and the comparison
            # below would wrongly report that SANS did not improve
            print("⚠ No SANS experiments with a final loss yet - cannot compare against baseline")
        else:
            improvement = (baseline_loss - sans_losses.min()) / baseline_loss * 100

            if improvement > 0:
                print(f"✓ SANS shows {improvement:.1f}% improvement over baseline")
            else:
                print("⚠ SANS did not improve over baseline in this experiment")

    # Check correlation quality (the column may be absent when no run
    # produced SANS correlation statistics)
    if 'corr_mean_correlation' in comparison_df.columns:
        mean_corrs = comparison_df['corr_mean_correlation'].dropna()
        if len(mean_corrs) > 0:
            # Fraction of runs whose mean weight-energy correlation is below -0.2
            good_corr_ratio = (mean_corrs < -0.2).mean()
            print("\nCorrelation Quality:")
            print(f"  - {good_corr_ratio*100:.1f}% of experiments show good negative correlation")
            print(f"  - Best correlation: {mean_corrs.min():.3f}")

    print("\nRecommended Hyperparameters:")
    if 'best_exp' in locals():
        name = best_exp['experiment']
        # Names follow "sans_K<negs>_T<temp>[_sched]"; the baseline has no
        # SANS hyperparameters to recommend
        if 'sans_K' in name:
            parts = name.split('_')
            print(f"  - Number of negatives (K): {parts[1][1:]}")
            print(f"  - Temperature (α): {parts[2][1:]}")
            print(f"  - Temperature scheduling: {'Yes' if 'sched' in name else 'No'}")

print("\nNext Steps:")
print("1. Run longer training with best configuration (50k+ steps)")
print("2. Test on more challenging datasets (sudoku, connectivity)")
print("3. Fine-tune temperature schedule for your specific task")
print("4. Consider adaptive K based on training progress")
print("5. Implement curriculum learning with increasing K over time")

print("\n" + "="*60)
print("NOTEBOOK COMPLETE")
print("="*60)