In [1]:
import os
import json
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from datetime import datetime
import pandas as pd
from collections import Counter
from typing import Dict, List, Tuple, Optional
import torch.nn.functional as F
from tools_clean.config_clean import device, DATA_FILES, CONFIG, cuda_manager
from tools_clean.classes_clean import EnsembleModel, SleepDataManager, SleepStageEvaluator
from tools_clean.functions_clean import convert_to_serializable, format_class_distribution
from tools_clean.utils_clean import *

Setting NUMEXPR_MAX_THREADS to 24 out of 48 total cores.


2024-11-18 13:30:18,712 - INFO - 
GPU Configuration:
    Physical GPU ID: 0 (NVIDIA TITAN V)
    Logical Device ID: cuda:0
    Compute Capability: 7.0
    Total Memory: 12.65GB
    Free Memory: 12.65GB
    Memory Fraction: 0.80
    Multi-Processors: 80
                    


In [2]:
def load_sleep_data(model_dir: str, 
                   new_data_files: List[str], 
                   model_params: Optional[Dict] = None) -> Tuple[EnsembleModel, SleepDataManager]:
    """
    Load the model and data
    
    Args:
        model_dir: Directory containing the saved model
        new_data_files: List of data files to evaluate
        model_params: Optional model parameters
        
    Returns:
        Tuple of (loaded model, data manager)
    """
    # Load model configuration if not provided
    if model_params is None:
        config_path = os.path.join(model_dir, 'model_config.json')
        with open(config_path, 'r') as f:
            config = json.load(f)
            model_params = config['model_params']
    
    # Load model
    model_path = os.path.join(model_dir, 'best_model.pt')
    model = EnsembleModel(model_params).to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    
    # Load data
    logging.info(f"\nLoading new data from {len(new_data_files)} files...")
    data_manager = SleepDataManager(
        data_files=new_data_files,
        val_ratio=0.0
    )
    data_manager.load_and_preprocess()
    
    # Verify data shapes
    logging.info("Verifying data shapes...")
    for night_idx in range(len(new_data_files)):
        night_mask = data_manager.data['night_idx'] == night_idx
        night_data = data_manager.data['x'][night_mask]
        if night_data.shape[1] != 4 or night_data.shape[2] != 3000:
            logging.warning(f"Unexpected data shape in night {night_idx}: {night_data.shape}")
            
    return model, data_manager



In [3]:

model_dir = './models/new7/'
new_data_files = DATA_FILES[60:76]
model, data_manager = load_sleep_data(model_dir, new_data_files)



  model.load_state_dict(torch.load(model_path))
2024-11-18 13:30:23,445 - INFO - 
Loading new data from 16 files...
Loading Data Files:   0%|          | 0/16 [00:00<?, ?file/s]2024-11-18 13:30:23,449 - INFO - Loading file 1/16: preprocessed_data_209_N1.mat
2024-11-18 13:30:25,192 - INFO - Spectral features shape: (996, 16)
2024-11-18 13:30:25,212 - INFO - Successfully loaded night 1: 996 samples
Loading Data Files:   6%|▋         | 1/16 [00:01<00:27,  1.80s/file]2024-11-18 13:30:25,254 - INFO - Loading file 2/16: preprocessed_data_289_N1.mat
2024-11-18 13:30:25,408 - ERROR - Error loading /userdata/jkrolik/eeg-sleepstage-classifier/tools_clean/../preprocessing/preprocessed_data/preprocessed_data_289_N1.mat: could not read bytes
2024-11-18 13:30:25,409 - ERROR - Skipping this file and continuing...
2024-11-18 13:30:25,471 - INFO - Loading file 3/16: preprocessed_data_321_N1.mat
2024-11-18 13:30:27,141 - INFO - Spectral features shape: (1012, 16)
2024-11-18 13:30:27,150 - INFO - Successf

2024-11-18 13:30:28,809 - INFO - Spectral features shape: (1041, 16)
2024-11-18 13:30:28,819 - INFO - Successfully loaded night 4: 1041 samples
Loading Data Files:  19%|█▉        | 3/16 [00:05<00:22,  1.77s/file]2024-11-18 13:30:28,821 - INFO - Loading file 5/16: preprocessed_data_207_N2.mat
2024-11-18 13:30:30,864 - INFO - Spectral features shape: (1009, 16)
2024-11-18 13:30:30,877 - INFO - Successfully loaded night 5: 1009 samples
Loading Data Files:  25%|██▌       | 4/16 [00:07<00:22,  1.89s/file]2024-11-18 13:30:30,914 - INFO - Loading file 6/16: preprocessed_data_247_N2.mat
2024-11-18 13:30:32,555 - INFO - Spectral features shape: (1125, 16)
2024-11-18 13:30:32,567 - INFO - Successfully loaded night 6: 1125 samples
Loading Data Files:  31%|███▏      | 5/16 [00:09<00:20,  1.82s/file]2024-11-18 13:30:32,613 - INFO - Loading file 7/16: preprocessed_data_334_N2.mat
2024-11-18 13:30:34,129 - INFO - Spectral features shape: (1013, 16)
2024-11-18 13:30:34,138 - INFO - Successfully loaded

In [4]:
import numpy as np
from typing import Dict, List, Tuple, Optional
import torch
import torch.nn.functional as F

class EnhancedSleepTransitionAnalyzer:
    """Enhanced analyzer for sleep stage transitions and prediction confidence"""
    
    def __init__(self, window_size: int = 3, min_stable_length: int = 5):
        """
        Args:
            window_size: Base window size for transition analysis
            min_stable_length: Minimum length of a sleep stage to be considered stable
        """
        self.window_size = window_size
        self.min_stable_length = min_stable_length
        
    def find_stage_boundaries(self, labels: np.ndarray) -> List[Tuple[int, int, int]]:
        """
        Find boundaries of each sleep stage sequence.
        
        Returns:
            List of tuples (start_idx, end_idx, stage_length)
        """
        boundaries = []
        current_start = 0
        current_stage = labels[0]
        
        for i in range(1, len(labels)):
            if labels[i] != current_stage:
                boundaries.append((
                    current_start,
                    i - 1,
                    i - current_start
                ))
                current_start = i
                current_stage = labels[i]
                
        # Add final stage
        boundaries.append((
            current_start,
            len(labels) - 1,
            len(labels) - current_start
        ))
        
        return boundaries
    
    def calculate_dynamic_transition_tendency(self, labels: np.ndarray) -> np.ndarray:
        """
        Calculate transition tendency with dynamic window sizes based on stage length.
        
        Args:
            labels: Array of sleep stage labels
            
        Returns:
            Array of transition tendency scores (0-1)
        """
        n = len(labels)
        tendencies = np.zeros(n)
        stage_boundaries = self.find_stage_boundaries(labels)
        
        for start_idx, end_idx, stage_length in stage_boundaries:
            # Adjust window size based on stage length
            dynamic_window = min(
                max(self.window_size, stage_length // 4),
                stage_length // 2
            )
            
            # Calculate position-based tendency within the stage
            for pos in range(start_idx, end_idx + 1):
                # Distance from start and end of stage
                dist_start = pos - start_idx
                dist_end = end_idx - pos
                
                # Calculate tendencies from both boundaries
                start_tendency = np.exp(-dist_start / dynamic_window) if dist_start < dynamic_window else 0
                end_tendency = np.exp(-dist_end / dynamic_window) if dist_end < dynamic_window else 0
                
                # Combine tendencies
                tendencies[pos] = max(start_tendency, end_tendency)
                
                # Adjust for very short stages
                if stage_length < self.min_stable_length:
                    tendencies[pos] = max(tendencies[pos], 0.5)
        
        return tendencies
    
    def analyze_predictions(self, 
                          model_outputs: torch.Tensor, 
                          true_labels: np.ndarray) -> Dict:
        """
        Analyze model predictions with confidence scores.
        
        Args:
            model_outputs: Raw model outputs (before softmax)
            true_labels: Ground truth labels
            
        Returns:
            Dictionary containing detailed analysis
        """
        # Calculate probabilities and confidence scores
        probs = F.softmax(model_outputs, dim=1)
        predictions = probs.argmax(dim=1).cpu().numpy()
        confidences = probs.max(dim=1)[0].cpu().numpy()
        
        # Get top-2 predictions and their probabilities
        top2_values, top2_indices = torch.topk(probs, k=2, dim=1)
        top2_values = top2_values.cpu().numpy()
        top2_indices = top2_indices.cpu().numpy()
        
        # Calculate prediction uncertainty
        entropy = -torch.sum(probs * torch.log2(probs + 1e-10), dim=1).cpu().numpy()
        max_entropy = -np.log2(1/probs.shape[1])  # Maximum possible entropy
        uncertainty = entropy / max_entropy
        
        # Calculate transition tendencies
        tendencies = self.calculate_dynamic_transition_tendency(true_labels)
        
        # Analyze prediction patterns
        analysis = []
        for i in range(len(true_labels)):
            analysis.append({
                'position': i,
                'true_label': int(true_labels[i]),
                'predicted_label': int(predictions[i]),
                'is_correct': predictions[i] == true_labels[i],
                'confidence': float(confidences[i]),
                'uncertainty': float(uncertainty[i]),
                'transition_tendency': float(tendencies[i]),
                'top_predictions': [
                    {
                        'label': int(top2_indices[i, j]),
                        'probability': float(top2_values[i, j])
                    } for j in range(2)
                ],
                'stage_context': {
                    'in_transition': tendencies[i] > 0.5,
                    'confidence_ratio': float(top2_values[i, 0] / (top2_values[i, 1] + 1e-10))
                }
            })
        
        return {
            'detailed_analysis': analysis,
            'summary': {
                'overall_accuracy': (predictions == true_labels).mean(),
                'transition_accuracy': (predictions[tendencies > 0.5] == true_labels[tendencies > 0.5]).mean(),
                'stable_accuracy': (predictions[tendencies <= 0.5] == true_labels[tendencies <= 0.5]).mean(),
                'mean_confidence': confidences.mean(),
                'mean_uncertainty': uncertainty.mean()
            }
        }

In [5]:
def evaluate_sleep_data(model: EnsembleModel, 
                       data_manager: SleepDataManager,
                       batch_size: int = 32) -> Dict:
    """Evaluate model with enhanced analysis"""
    model.eval()
    all_outputs = []
    
    with torch.no_grad():
        for i in range(0, len(data_manager.data['x']), batch_size):
            batch_x = data_manager.data['x'][i:i+batch_size].to(device)
            batch_x_spectral = data_manager.data['x_spectral'][i:i+batch_size].to(device)
            outputs = model(batch_x, batch_x_spectral)
            all_outputs.append(outputs)
    
    # Combine all outputs
    model_outputs = torch.cat(all_outputs, dim=0)
    true_labels = data_manager.data['y'].numpy()
    
    # Initialize enhanced analyzer
    analyzer = EnhancedSleepTransitionAnalyzer(window_size=3, min_stable_length=5)
    
    # Get detailed analysis
    analysis = analyzer.analyze_predictions(model_outputs, true_labels)
    
    # Extract results for visualization
    detailed = analysis['detailed_analysis']
    
    return {
        'predictions': np.array([d['predicted_label'] for d in detailed]),
        'confidences': np.array([d['confidence'] for d in detailed]),
        'uncertainties': np.array([d['uncertainty'] for d in detailed]),
        'true_labels': true_labels,
        'tendencies': np.array([d['transition_tendency'] for d in detailed]),
        'confidence_ratios': np.array([d['stage_context']['confidence_ratio'] for d in detailed]),
        'detailed_analysis': analysis['detailed_analysis'],
        'summary': analysis['summary']
    }

In [6]:
# 2. Evaluate
results = evaluate_sleep_data(model, data_manager)

# 3. Visualize
# save_path = os.path.join(model_dir, 'evaluation', 'sleep_analysis_interactive.html')


In [7]:
def plot_sleep_analysis(results: Dict, save_path: Optional[str] = None) -> go.Figure:
    sleep_stages = {
        0: 'N3 (Deep Sleep)',
        1: 'N2 (Light Sleep)', 
        2: 'N1 (Light Sleep)',
        3: 'REM Sleep',
        4: 'Wake'
    }
    
    colors = {
        'true': '#4CAF50',    # Green
        'pred': '#2196F3',    # Blue
        'error': '#EF5350',   # Red
        'tendency': '#9C27B0', # Purple
        'confidence': '#FFA726' # Orange
    }
    
    # Create figure with three subplots
    fig = make_subplots(
        rows=3, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.1,
        row_heights=[0.5, 0.25, 0.25],
        subplot_titles=('Sleep Stage Comparison', 'Transition Tendency', 'Model Confidence')
    )
    
    total_epochs = len(results['true_labels'])
    epochs = np.arange(1, total_epochs + 1)
    
    def create_hover_template(epoch_idx, tendency, true_label, pred_label, confidence):
        is_correct = true_label == pred_label
        status_color = 'green' if is_correct else 'red'
        confidence_color = 'green' if confidence > 0.8 else 'orange' if confidence > 0.6 else 'red'
        
        return (f"<b>Epoch {epoch_idx + 1}/{total_epochs}</b><br><br>"
                f"<b>True State:</b> {sleep_stages[true_label]}<br>"
                f"<b>Predicted:</b> {sleep_stages[pred_label]} <span style='color:{status_color}'>{is_correct and '✓' or '✗'}</span><br><br>"
                f"<b>Confidence:</b> <span style='color:{confidence_color}'>{confidence:.1%}</span><br>"
                f"<b>Transition Tendency:</b> {tendency:.2f}")
    
    # Combine hover data
    hover_data = [
        create_hover_template(i, t, tl, pl, c)
        for i, (t, tl, pl, c) in enumerate(zip(
            results['tendencies'],
            results['true_labels'],
            results['predictions'],
            results['confidences']
        ))
    ]
    
    # Sleep Stage Plot
    fig.add_trace(
        go.Scatter(
            x=epochs,
            y=results['true_labels'],
            name='True Sleep Stages',
            line=dict(color=colors['true'], width=2),
            hoverinfo='skip'
        ),
        row=1, col=1
    )
    
    fig.add_trace(
        go.Scatter(
            x=epochs,
            y=results['predictions'],
            name='Model Predictions',
            line=dict(color=colors['pred'], width=2, dash='dot'),
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_data
        ),
        row=1, col=1
    )
    
    # Add error highlighting
    errors = results['predictions'] != results['true_labels']
    error_epochs = epochs[errors]
    error_values = results['predictions'][errors]
    
    fig.add_trace(
        go.Scatter(
            x=error_epochs,
            y=error_values,
            mode='markers',
            name='Prediction Errors',
            marker=dict(color=colors['error'], size=8, symbol='x'),
            hoverinfo='skip'
        ),
        row=1, col=1
    )
    
    # Confidence Plot
    fig.add_trace(
        go.Scatter(
            x=epochs,
            y=results['tendencies'],
            name='Transition Tendency',
            line=dict(color=colors['tendency']),
            fill='tozeroy',
            fillcolor=f"rgba{tuple(list(int(colors['tendency'][i:i+2], 16) for i in (1, 3, 5)) + [0.2])}",
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_data
        ),
        row=2, col=1
    )

    # Confidence Plot (now at bottom)
    fig.add_trace(
        go.Scatter(
            x=epochs,
            y=results['confidences'],
            name='Prediction Confidence',
            line=dict(color=colors['confidence']),
            fill='tozeroy',
            fillcolor=f"rgba{tuple(list(int(colors['confidence'][i:i+2], 16) for i in (1, 3, 5)) + [0.2])}",
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_data
        ),
        row=3, col=1
    )

    # Add confidence thresholds to bottom subplot
    for threshold, label in [(0.8, 'High'), (0.6, 'Medium')]:
        fig.add_shape(
            type="line",
            x0=epochs[0],
            x1=epochs[-1],
            y0=threshold,
            y1=threshold,
            line=dict(color="gray", width=1, dash="dash"),
            row=3, col=1
        )
        fig.add_annotation(
            x=epochs[0],
            y=threshold,
            text=f"{label} ({threshold:.0%})",
            showarrow=False,
            xanchor="left",
            yanchor="bottom",
            row=3, col=1
        )

    # Update y-axis titles for reordered plots
    fig.update_yaxes(title="Transition Tendency", range=[-0.05, 1.05], row=2, col=1)
    fig.update_yaxes(title="Confidence", range=[-0.05, 1.05], tickformat=".0%", row=3, col=1)
    
    # Update layout
    fig.update_layout(
        title=dict(text="Sleep Stage Analysis with Confidence", x=0.5, font=dict(size=24)),
        height=1000,
        showlegend=True,
        hovermode='x unified',
        template='plotly_white',
        legend=dict(
            yanchor="top",
            y=0.98,
            xanchor="left",
            x=0.01,
            bgcolor="rgba(255, 255, 255, 0.8)"
        ),
        margin=dict(t=100, b=50)
    )
    
    # Update axes
    fig.update_yaxes(
        title="Sleep Stage",
        ticktext=list(sleep_stages.values()),
        tickvals=list(sleep_stages.keys()),
        range=[-0.5, 4.5],
        row=1, col=1
    )
    
    fig.update_yaxes(
        title="Confidence",
        range=[-0.05, 1.05],
        tickformat=".0%",
        row=2, col=1
    )
    
    fig.update_yaxes(
        title="Transition Tendency",
        range=[-0.05, 1.05],
        row=3, col=1
    )
    
    # Update x-axes
    for row in [1, 2]:
        fig.update_xaxes(
            showspikes=True,
            spikethickness=1,
            spikecolor="gray",
            spikemode="across",
            showline=True,
            showgrid=True,
            gridcolor='#E5E5E5',
            row=row, col=1
        )
    
    # Add range slider and selector to bottom plot
    fig.update_xaxes(
        rangeslider=dict(visible=True),
        rangeselector=dict(
            buttons=list([
                dict(count=100, label="100 epochs", step="all", stepmode="backward"),
                dict(count=500, label="500 epochs", step="all", stepmode="backward"),
                dict(count=1000, label="1000 epochs", step="all", stepmode="backward"),
                dict(step="all", label="All epochs")
            ])
        ),
        title="Epoch",
        row=3, col=1
    )
    
    if save_path:
        fig.write_html(save_path)
    
    return fig

In [8]:
# create fig
fig = plot_sleep_analysis(results)
# fig.show()

save_path = './sleep_analysis.html'  # Can be any path you want
fig.write_html(save_path)


In [9]:
# def create_confidence_transition_analysis(results: Dict) -> go.Figure:
#     """Create analysis of model confidence vs transition tendency"""
    
#     # Create bins for transition tendency
#     n_bins = 5
#     tendency_bins = np.linspace(0, 1, n_bins+1)
#     confidence_matrix = {stage: np.zeros((n_bins, n_bins)) for stage in range(5)}
#     counts_matrix = {stage: np.zeros((n_bins, n_bins)) for stage in range(5)}
    
#     # Bin the data
#     for i in range(len(results['true_labels'])):
#         true_label = results['true_labels'][i]
#         # Ensure values are clipped to valid range
#         tendency_val = np.clip(results['tendencies'][i], 0, 1)
#         conf_val = np.clip(results['confidences'][i], 0, 1)
        
#         # Get bin indices, ensuring they're within bounds
#         tendency_bin = min(int(tendency_val * n_bins), n_bins-1)
#         conf_bin = min(int(conf_val * n_bins), n_bins-1)
        
#         confidence_matrix[true_label][tendency_bin, conf_bin] += results['confidences'][i]
#         counts_matrix[true_label][tendency_bin, conf_bin] += 1
    
#     # Calculate average confidence per bin
#     for stage in confidence_matrix:
#         mask = counts_matrix[stage] > 0
#         confidence_matrix[stage][mask] /= counts_matrix[stage][mask]
    
#     sleep_stages = {
#         0: 'N3 (Deep)',
#         1: 'N2 (Light)',
#         2: 'N1 (Light)',
#         3: 'REM',
#         4: 'Wake'
#     }
    
#     # Create subplots
#     fig = make_subplots(
#         rows=2, cols=3,
#         subplot_titles=[f"{sleep_stages[i]}" for i in range(5)] + ["Average"],
#         specs=[[{}, {}, {}], [{}, {}, None]],
#     )
    
#     # Plot heatmaps for each sleep stage
#     row_col = [(0,0), (0,1), (0,2), (1,0), (1,1)]
    
#     bin_labels = [f"{v:.1f}-{tendency_bins[i+1]:.1f}" for i, v in enumerate(tendency_bins[:-1])]
    
#     for stage, (row, col) in zip(range(5), row_col):
#         fig.add_trace(
#             go.Heatmap(
#                 z=confidence_matrix[stage],
#                 x=bin_labels,
#                 y=bin_labels,
#                 colorscale='RdYlBu',
#                 zmin=0,
#                 zmax=1,
#                 showscale=True,
#                 hoverongaps=False,
#                 hovertemplate=(
#                     "Transition: %{x}<br>" +
#                     "Confidence: %{y}<br>" +
#                     "Avg Confidence: %{z:.2f}<br>" +
#                     "Count: %{customdata}<extra></extra>"
#                 ),
#                 customdata=counts_matrix[stage]
#             ),
#             row=row+1, col=col+1
#         )
    
#     # Add average heatmap
#     avg_matrix = sum(confidence_matrix.values()) / len(confidence_matrix)
#     fig.add_trace(
#         go.Heatmap(
#             z=avg_matrix,
#             x=bin_labels,
#             y=bin_labels,
#             colorscale='RdYlBu',
#             zmin=0,
#             zmax=1,
#             showscale=True,
#             hovertemplate=(
#                 "Transition: %{x}<br>" +
#                 "Confidence: %{y}<br>" +
#                 "Avg Confidence: %{z:.2f}<extra></extra>"
#             )
#         ),
#         row=2, col=2
#     )
    
#     fig.update_layout(
#         title="Model Confidence vs Transition Tendency by Sleep Stage",
#         height=800,
#         width=1200,
#         showlegend=False
#     )
    
#     for i in range(6):
#         row = i // 3 + 1
#         col = i % 3 + 1
#         if i < 5 or (row == 2 and col == 2):
#             fig.update_xaxes(title="Transition Tendency", row=row, col=col)
#             fig.update_yaxes(title="Confidence", row=row, col=col)
    
#     return fig
# # create fig
# fig2 = create_confidence_transition_analysis(results)
# fig2.show()

In [10]:
# import plotly.express as px

# def create_confidence_transition_analysis(results: Dict) -> go.Figure:
#     """
#     Create bar graph analysis of model confidence vs transition tendency
#     """
#     # Create bins for transition tendency
#     n_bins = 5
#     tendency_bins = np.linspace(0, 1, n_bins+1)
#     bin_labels = [f"{tendency_bins[i]:.1f}-{tendency_bins[i+1]:.1f}" 
#                  for i in range(n_bins)]
    
#     sleep_stages = {
#         0: 'N3 (Deep)',
#         1: 'N2 (Light)',
#         2: 'N1 (Light)',
#         3: 'REM',
#         4: 'Wake'
#     }
    
#     # Initialize data structures
#     stage_data = {stage: {bin_label: [] for bin_label in bin_labels} 
#                  for stage in sleep_stages.keys()}
    
#     # Bin the data
#     for i in range(len(results['true_labels'])):
#         true_label = results['true_labels'][i]
#         tendency_val = np.clip(results['tendencies'][i], 0, 1)
#         confidence_val = results['confidences'][i]
        
#         bin_idx = min(int(tendency_val * n_bins), n_bins-1)
#         stage_data[true_label][bin_labels[bin_idx]].append(confidence_val)
    
#     # Calculate statistics
#     stats = {}
#     overall_stats = {bin_label: [] for bin_label in bin_labels}
    
#     for stage in sleep_stages.keys():
#         stats[stage] = {}
#         for bin_label in bin_labels:
#             values = stage_data[stage][bin_label]
#             stats[stage][bin_label] = {
#                 'mean': np.mean(values) if values else 0,
#                 'std': np.std(values) if values else 0,
#                 'count': len(values)
#             }
#             overall_stats[bin_label].extend(values)
    
#     # Calculate overall statistics
#     overall_means = {bin_label: np.mean(values) if values else 0 
#                     for bin_label, values in overall_stats.items()}
#     overall_stds = {bin_label: np.std(values) if values else 0 
#                    for bin_label, values in overall_stats.items()}
    
#     # Create figure with subplots
#     fig = make_subplots(
#         rows=2, cols=3,
#         subplot_titles=[f"<b>{sleep_stages[i]}</b>" for i in range(5)] + 
#                       ["<b>Average Across Stages</b>"],
#         specs=[[{}, {}, {}], [{}, {}, {}]],
#         vertical_spacing=0.2,
#         horizontal_spacing=0.1
#     )
    
#     # Color scheme
#     colors = px.colors.qualitative.Set3
    
#     # Plot individual stage bars
#     row_col = [(1,1), (1,2), (1,3), (2,1), (2,2)]
    
#     for stage, (row, col) in zip(sleep_stages.keys(), row_col):
#         stage_stats = stats[stage]
        
#         # Create bar trace
#         fig.add_trace(
#             go.Bar(
#                 name=sleep_stages[stage],
#                 x=bin_labels,
#                 y=[stage_stats[bin]['mean'] for bin in bin_labels],
#                 error_y=dict(
#                     type='data',
#                     array=[stage_stats[bin]['std'] for bin in bin_labels],
#                     visible=True
#                 ),
#                 text=[f"n={stage_stats[bin]['count']}<br>σ={stage_stats[bin]['std']:.3f}" 
#                       for bin in bin_labels],
#                 textposition='auto',
#                 marker_color=colors[stage],
#                 hovertemplate=(
#                     f"<b>{sleep_stages[stage]}</b><br>" +
#                     "Transition Range: %{x}<br>" +
#                     "Avg Confidence: %{y:.3f}<br>" +
#                     "%{text}<extra></extra>"
#                 )
#             ),
#             row=row, col=col
#         )
        
#         # Update axes for each subplot
#         fig.update_yaxes(
#             title="Average Confidence",
#             range=[0, 1],
#             tickformat='.0%',
#             row=row, col=col
#         )
#         fig.update_xaxes(
#             title="Transition Tendency",
#             tickangle=45,
#             row=row, col=col
#         )
    
#     # Add overall average plot
#     fig.add_trace(
#         go.Bar(
#             name='Overall Average',
#             x=bin_labels,
#             y=[overall_means[bin] for bin in bin_labels],
#             error_y=dict(
#                 type='data',
#                 array=[overall_stds[bin] for bin in bin_labels],
#                 visible=True
#             ),
#             text=[f"n={sum(stats[s][bin]['count'] for s in sleep_stages.keys())}"
#                   for bin in bin_labels],
#             textposition='auto',
#             marker_color='rgba(100,100,100,0.6)',
#             hovertemplate=(
#                 "<b>Overall Average</b><br>" +
#                 "Transition Range: %{x}<br>" +
#                 "Avg Confidence: %{y:.3f}<br>" +
#                 "%{text}<extra></extra>"
#             )
#         ),
#         row=2, col=3
#     )
    
#     # Update layout
#     fig.update_layout(
#         title=dict(
#             text="Model Confidence vs Transition Tendency Analysis",
#             x=0.5,
#             font=dict(size=24)
#         ),
#         height=900,
#         width=1400,
#         showlegend=False,
#         bargap=0.2,
#         bargroupgap=0.1
#     )
    
#     # Add unified y-axis label
#     fig.add_annotation(
#         text="Average Confidence",
#         textangle=-90,
#         xref="paper",
#         yref="paper",
#         x=-0.07,
#         y=0.5,
#         showarrow=False
#     )
    
#     # Add unified x-axis label
#     fig.add_annotation(
#         text="Transition Tendency (0=Stable → 1=Transitioning)",
#         xref="paper",
#         yref="paper",
#         x=0.5,
#         y=-0.15,
#         showarrow=False
#     )
    
#     return fig

# create_confidence_transition_analysis(results).show()

In [17]:
import plotly.graph_objects as go

def create_confidence_transition_analysis(results: Dict) -> go.Figure:
    """
    Create grouped bar analysis of model confidence vs transition tendency
    """
    # Create bins for transition tendency
    n_bins = 5
    tendency_bins = np.linspace(0, 1, n_bins+1)
    bin_labels = [f"{tendency_bins[i]:.1f}-{tendency_bins[i+1]:.1f}" 
                 for i in range(n_bins)]
    
    sleep_stages = {
        0: 'N3 (Deep)',
        1: 'N2 (Light)',
        2: 'N1 (Light)',
        3: 'REM',
        4: 'Wake'
    }
    
    # Initialize data structures
    stage_data = {stage: {bin_label: [] for bin_label in bin_labels} 
                 for stage in sleep_stages.keys()}
    
    # Bin the data
    for i in range(len(results['true_labels'])):
        true_label = results['true_labels'][i]
        tendency_val = np.clip(results['tendencies'][i], 0, 1)
        confidence_val = results['confidences'][i]
        
        bin_idx = min(int(tendency_val * n_bins), n_bins-1)
        stage_data[true_label][bin_labels[bin_idx]].append(confidence_val)
    
    # Calculate statistics
    stats = {}
    overall_stats = {bin_label: [] for bin_label in bin_labels}
    
    for stage in sleep_stages.keys():
        stats[stage] = {}
        for bin_label in bin_labels:
            values = stage_data[stage][bin_label]
            stats[stage][bin_label] = {
                'mean': np.mean(values) if values else 0,
                'std': np.std(values) if values else 0,
                'count': len(values)
            }
            overall_stats[bin_label].extend(values)
    
    # Calculate overall statistics
    overall_means = {bin_label: np.mean(values) if values else 0 
                    for bin_label, values in overall_stats.items()}
    overall_stds = {bin_label: np.std(values) if values else 0 
                   for bin_label, values in overall_stats.items()}
    
    # Create figure
    fig = go.Figure()
    
    # Color scheme - using a more distinct palette
    colors = {
        0: '#1f77b4',  # Deep blue for N3
        1: '#2ca02c',  # Green for N2
        2: '#98df8a',  # Light green for N1
        3: '#ff7f0e',  # Orange for REM
        4: '#d62728',  # Red for Wake
        'overall': '#7f7f7f'  # Gray for overall
    }
    
    # Add bars for each sleep stage
    for stage in sleep_stages.keys():
        stage_stats = stats[stage]
        
        fig.add_trace(
            go.Bar(
                name=sleep_stages[stage],
                x=bin_labels,
                y=[stage_stats[bin]['mean'] for bin in bin_labels],
                error_y=dict(
                    type='data',
                    array=[stage_stats[bin]['std'] for bin in bin_labels],
                    visible=True
                ),
                text=[f"n={stage_stats[bin]['count']}" for bin in bin_labels],
                textposition='auto',
                marker_color=colors[stage],
                hovertemplate=(
                    f"<b>{sleep_stages[stage]}</b><br>" +
                    "Transition Range: %{x}<br>" +
                    "Avg Confidence: %{y:.3f}<br>" +
                    "Std Dev: %{error_y.array:.3f}<br>" +
                    "%{text}<extra></extra>"
                )
            )
        )
    
    # Add overall average line
    fig.add_trace(
        go.Scatter(
            name='Overall Average',
            x=bin_labels,
            y=[overall_means[bin] for bin in bin_labels],
            mode='lines+markers',
            line=dict(
                color=colors['overall'],
                width=3,
                dash='dot'
            ),
            marker=dict(
                size=10,
                symbol='diamond'
            ),
            error_y=dict(
                type='data',
                array=[overall_stds[bin] for bin in bin_labels],
                visible=True
            ),
            hovertemplate=(
                "<b>Overall Average</b><br>" +
                "Transition Range: %{x}<br>" +
                "Avg Confidence: %{y:.3f}<br>" +
                "Std Dev: %{error_y.array:.3f}<br>" +
                "<extra></extra>"
            )
        )
    )
    
    # Update layout
    fig.update_layout(
        title=dict(
            text="Model Confidence vs Transition Tendency by Sleep Stage",
            x=0.5,
            font=dict(size=24)
        ),
        xaxis=dict(
            title=dict(
                text="Transition Tendency (0=Stable → 1=Transitioning)",
                font=dict(size=14)
            ),
            tickangle=0
        ),
        yaxis=dict(
            title=dict(
                text="Average Confidence",
                font=dict(size=14)
            ),
            range=[0, 1],
            tickformat='.0%'
        ),
        barmode='group',
        bargap=0.15,
        bargroupgap=0.1,
        height=700,
        width=1200,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99,
            bgcolor="rgba(255, 255, 255, 0.8)"
        ),
        template="plotly_white"
    )
    
    # Add annotation for sample sizes
    total_samples = sum(sum(stats[s][bin]['count'] for bin in bin_labels) 
                       for s in sleep_stages.keys())
    fig.add_annotation(
        text=f"Total Samples: {total_samples:,}",
        xref="paper",
        yref="paper",
        x=0.01,
        y=0.99,
        showarrow=False,
        font=dict(size=12),
        bgcolor="rgba(255, 255, 255, 0.8)"
    )
    
    return fig


fig2 = create_confidence_transition_analysis(results)
fig2.show()
fig2.write_html("confidence_analysis.html")

In [30]:
def create_confidence_transition_analysis(results: Dict) -> go.Figure:
    """
    Create separate bar graphs for confidence and accuracy vs transition tendency
    """
    # Create bins for transition tendency
    n_bins = 5
    tendency_bins = np.linspace(0, 1, n_bins+1)
    bin_labels = [f"{tendency_bins[i]:.1f}-{tendency_bins[i+1]:.1f}" 
                 for i in range(n_bins)]
    
    sleep_stages = {
        0: 'N3 (Deep)',
        1: 'N2 (Light)',
        2: 'N1 (Light)',
        3: 'REM',
        4: 'Wake'
    }
    
    # Initialize data structures
    stage_data = {stage: {bin_label: {'confidences': [], 'correct': []} 
                         for bin_label in bin_labels} 
                 for stage in sleep_stages.keys()}
    
    # Bin the data
    for i in range(len(results['true_labels'])):
        true_label = results['true_labels'][i]
        predicted_label = results['predictions'][i]
        tendency_val = np.clip(results['tendencies'][i], 0, 1)
        confidence_val = results['confidences'][i]
        is_correct = true_label == predicted_label
        
        bin_idx = min(int(tendency_val * n_bins), n_bins-1)
        bin_label = bin_labels[bin_idx]
        stage_data[true_label][bin_label]['confidences'].append(confidence_val)
        stage_data[true_label][bin_label]['correct'].append(is_correct)
    
    # Calculate statistics
    stats = {}
    overall_stats = {bin_label: {'confidences': [], 'correct': []} 
                    for bin_label in bin_labels}
    
    for stage in sleep_stages.keys():
        stats[stage] = {}
        for bin_label in bin_labels:
            bin_data = stage_data[stage][bin_label]
            confidences = bin_data['confidences']
            correct = bin_data['correct']
            
            stats[stage][bin_label] = {
                'mean_confidence': np.mean(confidences) if confidences else 0,
                'std_confidence': np.std(confidences) if confidences else 0,
                'accuracy': np.mean(correct) if correct else 0,
                'count': len(confidences)
            }
            
            overall_stats[bin_label]['confidences'].extend(confidences)
            overall_stats[bin_label]['correct'].extend(correct)
    
    # Calculate overall statistics
    overall_means = {}
    for bin_label in bin_labels:
        confidences = overall_stats[bin_label]['confidences']
        correct = overall_stats[bin_label]['correct']
        overall_means[bin_label] = {
            'mean_confidence': np.mean(confidences) if confidences else 0,
            'std_confidence': np.std(confidences) if confidences else 0,
            'accuracy': np.mean(correct) if correct else 0,
            'count': len(confidences)
        }
    
    # Create figure with two subplots
    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=("<b>Model Confidence by Sleep Stage and Transition Tendency</b>",
                       "<b>Model Accuracy by Sleep Stage and Transition Tendency</b>"),
        vertical_spacing=0.15,
        row_heights=[0.5, 0.5]
    )
    
    # Color scheme
    colors = {
        0: '#1f77b4',  # Deep blue for N3
        1: '#2ca02c',  # Green for N2
        2: '#98df8a',  # Light green for N1
        3: '#ff7f0e',  # Orange for REM
        4: '#d62728',  # Red for Wake
        'overall': '#7f7f7f'  # Gray for overall
    }
    
    # Add confidence bars
    for stage in sleep_stages.keys():
        stage_stats = stats[stage]
        
        fig.add_trace(
            go.Bar(
                name=sleep_stages[stage],
                x=bin_labels,
                y=[stage_stats[bin]['mean_confidence'] for bin in bin_labels],
                error_y=dict(
                    type='data',
                    array=[stage_stats[bin]['std_confidence'] for bin in bin_labels],
                    visible=True
                ),
                text=[f"n={stage_stats[bin]['count']}" for bin in bin_labels],
                textposition='auto',
                marker_color=colors[stage],
                hovertemplate=(
                    f"<b>{sleep_stages[stage]}</b><br>" +
                    "Transition Range: %{x}<br>" +
                    "Avg Confidence: %{y:.3f}<br>" +
                    "Std Dev: %{error_y.array:.3f}<br>" +
                    "%{text}<extra></extra>"
                )
            ),
            row=1, col=1
        )
    
    # Add overall confidence line
    fig.add_trace(
        go.Scatter(
            name='Overall Confidence/Accuracy',
            x=bin_labels,
            y=[overall_means[bin]['mean_confidence'] for bin in bin_labels],
            mode='lines+markers',
            line=dict(
                color=colors['overall'],
                width=3,
                dash='dot'
            ),
            marker=dict(
                size=10,
                symbol='star'
            ),
            hovertemplate=(
                "<b>Overall Confidence</b><br>" +
                "Transition Range: %{x}<br>" +
                "Avg Confidence: %{y:.3f}<br>" +
                "<extra></extra>"
            )
        ),
        row=1, col=1
    )
    
    # Add accuracy bars
    for stage in sleep_stages.keys():
        stage_stats = stats[stage]
        
        fig.add_trace(
            go.Bar(
                name=sleep_stages[stage],
                x=bin_labels,
                y=[stage_stats[bin]['accuracy'] for bin in bin_labels],
                marker_color=colors[stage],
                text=[f"{stage_stats[bin]['accuracy']:.1%}" for bin in bin_labels],
                textposition='auto',
                showlegend=False,
                hovertemplate=(
                    f"<b>{sleep_stages[stage]}</b><br>" +
                    "Transition Range: %{x}<br>" +
                    "Accuracy: %{y:.1%}<br>" +
                    "<extra></extra>"
                )
            ),
            row=2, col=1
        )
    
    # Add overall accuracy line
    fig.add_trace(
        go.Scatter(
            name='Overall Confidence/Accuracy',
            x=bin_labels,
            y=[overall_means[bin]['accuracy'] for bin in bin_labels],
            mode='lines+markers',
            line=dict(
                color=colors['overall'],
                width=3,
            ),
            marker=dict(
                size=10,
                symbol='star'
            ),
            showlegend=False,
            hovertemplate=(
                "<b>Overall Accuracy</b><br>" +
                "Transition Range: %{x}<br>" +
                "Accuracy: %{y:.1%}<br>" +
                "<extra></extra>"
            )
        ),
        row=2, col=1
    )
    
    # # Update layout
    # fig.update_layout(
    #     title=dict(
    #         text="Model Performance Analysis by Transition Tendency",
    #         x=0.5,
    #         font=dict(size=24)
    #     ),
    #     barmode='group',
    #     bargap=0.15,
    #     bargroupgap=0.1,
    #     height=1200,
    #     width=1400,
    #     showlegend=True,
    #     legend=dict(
    #         yanchor="top",
    #         y=0.95,
    #         xanchor="right",
    #         x=0.99,
    #         bgcolor="rgba(255, 255, 255, 0.8)"
    #     ),
    #     template="plotly_white"
    # )
    # Update layout
    fig.update_layout(
        title=dict(
            # text="Model Performance Analysis by Transition Tendency",
            x=0.5,
            font=dict(size=24)
        ),
        barmode='group',
        bargap=0.15,
        bargroupgap=0.1,
        height=1200,
        width=1400,
        showlegend=True,
        legend=dict(
            orientation="h",  # Set legend to horizontal orientation
            yanchor="bottom",
            y=.5,  # Position it slightly above the plot area
            xanchor="center",
            x=0.5,
            bgcolor="rgba(255, 255, 255, 0.8)"
        ),
        template="plotly_white"
    )
    
    # Update axes
    for row in [1, 2]:
        fig.update_xaxes(
            title=dict(
                text="Transition Tendency (0=Stable → 1=Transitioning)",
                font=dict(size=14)
            ),
            tickangle=0,
            row=row, col=1
        )
        
        fig.update_yaxes(
            range=[0, 1],
            tickformat='.0%',
            title=dict(
                text="Average Confidence" if row == 1 else "Accuracy",
                font=dict(size=14)
            ),
            row=row, col=1
        )
    
    # Add total samples annotation
    total_samples = sum(sum(stats[s][bin]['count'] for bin in bin_labels) 
                       for s in sleep_stages.keys())
    fig.add_annotation(
        text=f"Total Samples: {total_samples:,}",
        xref="paper",
        yref="paper",
        x=0.01,
        y=0.99,
        showarrow=False,
        font=dict(size=12),
        bgcolor="rgba(255, 255, 255, 0.8)"
    )
    
    return fig

# Usage example:
fig2 = create_confidence_transition_analysis(results)
fig2.show()  # For interactive display
fig2.write_html("confidence_analysis2.html")  # To save as HTML file