# In [None]:
import os
import json
import numpy as np
import pandas as pd
import warnings
from datetime import datetime
from glob import glob

# Visualization libraries
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import plotly.express as px
from sklearn.metrics import confusion_matrix, roc_curve, auc

# Configuration
# Suppress every warning for the rest of the session.
warnings.filterwarnings('ignore')
# Disable pandas' chained-assignment warning.
pd.options.mode.chained_assignment = None

# Global variables
PLOTS_DIR = "results/plots"  # base directory that save_plot() writes into
RESULTS_DIR = "results/evaluation_results"  # where load_evaluation_results() looks for files
# Output formats produced by save_plot(); each flag enables one file format.
SAVE_PNG = True
SAVE_HTML = False
SAVE_SVG= True
SAVE_PDF= True
PNG_SCALE = 2  # scale factor handed to fig.write_image() for PNG output

# Professional plot configuration matching 2.modelos_completos.py
# Figure size applied by save_plot() when the caller gives no width/height.
DEFAULT_WIDTH = 2000
DEFAULT_HEIGHT = 1500
FONT_FAMILY = "Quicksand Medium, Quicksand, sans-serif"
FONT_COLOR = 'black'

# Font sizes
TITLE_FONT_SIZE = 28
AXIS_TITLE_FONT_SIZE = 22
AXIS_TICK_FONT_SIZE = 18
LEGEND_FONT_SIZE = 20

# Professional color palette
# Ten RGB colors; plot_performance_comparison() cycles through them per algorithm.
COLOR_PALETTE = [
    'rgb(31, 119, 180)', 'rgb(255, 127, 14)', 'rgb(44, 160, 44)', 
    'rgb(214, 39, 40)', 'rgb(148, 103, 189)', 'rgb(140, 86, 75)',
    'rgb(227, 119, 194)', 'rgb(127, 127, 127)', 'rgb(188, 189, 34)',
    'rgb(23, 190, 207)'
]

# Create necessary directories
# Import-time side effect: the plots directory is created as soon as this runs.
os.makedirs(PLOTS_DIR, exist_ok=True)


def ensure_dir(file_path):
    """Create the parent directory of ``file_path`` when it is missing.

    Args:
        file_path: Path to a file. Only its directory part is created; the
            file itself is never touched. A path without a directory
            component (e.g. ``"plot.png"``) is a no-op.
    """
    directory = os.path.dirname(file_path)
    if directory:
        # exist_ok=True removes the check-then-create race that occurs when
        # another process creates the directory between the test and the call.
        os.makedirs(directory, exist_ok=True)


def save_plot(fig, filename, width=None, height=None, scale=None):
    """Apply the shared figure styling and export the figure to disk.

    The figure is written below ``PLOTS_DIR`` once for every format whose
    ``SAVE_*`` flag is enabled (HTML, PNG, SVG, PDF).

    Args:
        fig: Figure to style and export. It is modified in place.
        filename: Output path relative to ``PLOTS_DIR``, without extension.
        width: Figure width; falls back to ``DEFAULT_WIDTH`` when falsy.
        height: Figure height; falls back to ``DEFAULT_HEIGHT`` when falsy.
        scale: Scale factor for the PNG export. ``None`` (the default) uses
            ``PNG_SCALE``. Previously this argument was accepted but ignored.
    """
    width = width or DEFAULT_WIDTH
    height = height or DEFAULT_HEIGHT
    png_scale = PNG_SCALE if scale is None else scale

    # NOTE: this replaces any font, size, background and margin settings the
    # caller already placed on the figure.
    fig.update_layout(
        font=dict(family=FONT_FAMILY, size=AXIS_TICK_FONT_SIZE, color=FONT_COLOR),
        width=width,
        height=height,
        plot_bgcolor='white',
        paper_bgcolor='white',
        margin=dict(l=150, r=100, t=150, b=150)
    )

    base_path = os.path.join(PLOTS_DIR, filename)
    # filename may contain sub-directories, so make sure they exist.
    ensure_dir(base_path)

    if SAVE_HTML:
        html_path = f"{base_path}.html"
        fig.write_html(html_path)
        print(f"HTML saved: {html_path}")

    if SAVE_PNG:
        png_path = f"{base_path}.png"
        fig.write_image(png_path, scale=png_scale)
        print(f"PNG saved: {png_path}")

    if SAVE_SVG:
        svg_path = f"{base_path}.svg"
        fig.write_image(svg_path)
        print(f"SVG saved: {svg_path}")

    if SAVE_PDF:
        pdf_path = f"{base_path}.pdf"
        fig.write_image(pdf_path)
        print(f"PDF saved: {pdf_path}")



def load_class_mapping():
    """Return a ``{model_class: genus}`` dictionary.

    The mapping is read from ``datos/mapeo_clases.csv`` (columns
    ``model_class`` and ``genus``). If that file is missing, a built-in list
    of 30 bacterial genera indexed 0-29 is used. Any other failure yields
    generic ``Genus_<i>`` names for 30 classes.
    """
    # Default bacterial genera from the paper
    fallback_genera = (
        'Flavobacterium', 'Vibrio', 'Corynebacterium', 'Pseudomonas_E', 'Pelagibacter',
        'Bradyrhizobium', 'Mycobacterium', 'Nocardioides', 'Streptomyces', 'Prevotella',
        'Prochlorococcus_A', 'Streptococcus', 'Bifidobacterium', 'Novosphingobium', 'Pedobacter',
        'Chryseobacterium', 'Micromonospora', 'Nocardia', 'Arthrobacter', 'Polynucleobacter',
        'Pelagibacter_A', 'Collinsella', 'Acinetobacter', 'Mesorhizobium', 'Microbacterium',
        'Methylobacterium', 'Rhizobium', 'Paracoccus', 'Paraburkholderia', 'Sphingomonas',
    )

    try:
        table = pd.read_csv('datos/mapeo_clases.csv')
        mapping = {
            model_class: genus
            for model_class, genus in zip(table['model_class'], table['genus'])
        }
        print(f"Full class mapping loaded: {len(mapping)} classes")
        return mapping
    except FileNotFoundError:
        print("mapeo_clases.csv not found, using default bacterial genera")
        return dict(enumerate(fallback_genera))
    except Exception as e:
        # Anything else (bad CSV, missing columns, ...) degrades to placeholders
        print(f"Error loading class mapping: {e}")
        return {i: f'Genus_{i}' for i in range(30)}


def load_evaluation_data_mapping():
    """Derive the class mapping from the evaluation dataset.

    Reads ``datos/evaluation_data.parquet`` and pairs each distinct
    ``clases_modelos`` value with its ``genus``.

    Returns:
        ``(mapping, sorted_class_ids)`` on success, or ``(None, None)`` when
        the file is absent, the required columns are missing, or any error
        occurs while loading.
    """
    eval_data_path = 'datos/evaluation_data.parquet'

    try:
        if not os.path.exists(eval_data_path):
            print(f"Evaluation data not found at {eval_data_path}")
            return None, None

        frame = pd.read_parquet(eval_data_path)
        print(f"Evaluation data loaded: {len(frame)} samples")

        needed_columns = ('clases_modelos', 'genus')
        if not all(column in frame.columns for column in needed_columns):
            print("Required columns not found in evaluation data")
            return None, None

        # One row per distinct (genus, class id) pair
        pairs = frame[['genus', 'clases_modelos']].drop_duplicates()
        eval_class_mapping = dict(zip(pairs['clases_modelos'], pairs['genus']))
        unique_classes = sorted(eval_class_mapping)

        print(f"Classes present in evaluation: {len(unique_classes)}")
        print(f"Class range: {min(unique_classes)} to {max(unique_classes)}")
        print(f"Sample genera: {list(eval_class_mapping.values())[:5]}...")

        return eval_class_mapping, unique_classes

    except Exception as e:
        print(f"Error loading evaluation data: {e}")
        return None, None


def load_evaluation_results():
    """Load the COMPLETE evaluation results with improved error handling.

    Looks in ``RESULTS_DIR`` for the main results (CSV and/or detailed JSON),
    the optional Excel workbook and the optional summary CSVs, then resolves
    which class mapping to use.

    Returns:
        A 5-tuple ``(full_results, df_main, df_extras, class_mapping,
        unique_classes)``:

        * ``full_results`` - parsed JSON content, or ``None`` if it was not
          found or failed to load.
        * ``df_main`` - DataFrame read from the results CSV, or ``None``.
        * ``df_extras`` - dict of optional DataFrames (Excel sheets,
          ``predictions_summary``, ``data_quality``); may be empty.
        * ``class_mapping`` - class id to genus name mapping.
        * ``unique_classes`` - list of class ids considered present.

        When neither the JSON nor the CSV file exists, all five values are
        ``None``.
    """
    
    print("Searching for results files...")
    
    # Search for files
    json_files = glob(os.path.join(RESULTS_DIR, "evaluation_detailed.json"))
    csv_files = glob(os.path.join(RESULTS_DIR, "evaluation_results.csv"))
    excel_files = glob(os.path.join(RESULTS_DIR, "evaluation_complete.xlsx"))
    
    # NEW: look for the additional files
    pred_csv_files = glob(os.path.join(RESULTS_DIR, "predictions_summary.csv"))
    quality_csv_files = glob(os.path.join(RESULTS_DIR, "data_quality_report.csv"))
    
    if not json_files and not csv_files:
        print("No results files found")
        return None, None, None, None, None
    
    # Get most recent files
    latest_json = max(json_files, key=os.path.getctime) if json_files else None
    latest_csv = max(csv_files, key=os.path.getctime) if csv_files else None
    latest_excel = max(excel_files, key=os.path.getctime) if excel_files else None
    
    # At least one of the two is set here because of the early return above
    print(f"Main file: {os.path.basename(latest_csv or latest_json)}")
    
    # Load main data
    df_main = None
    if latest_csv:
        try:
            df_main = pd.read_csv(latest_csv)
            print(f"Main DataFrame loaded: {len(df_main)} rows")
            
            # NEW: check whether the new (enhanced) columns are present
            new_columns = ['data_loss_percentage', 'original_samples', 'valid_samples', 'unique_classes_count']
            has_new_columns = all(col in df_main.columns for col in new_columns)
            print(f"Has enhanced columns: {has_new_columns}")
            
            if has_new_columns:
                avg_loss = df_main['data_loss_percentage'].mean()
                total_samples = df_main['valid_samples'].sum()
                print(f"Average data loss: {avg_loss:.2f}%")
                print(f"Total valid samples: {total_samples:,}")
            
        except Exception as e:
            # Best effort: a broken CSV is reported and df_main keeps whatever
            # value it had when the error occurred
            print(f"Error loading CSV: {e}")
    
    # IMPROVEMENT: Load complete JSON with ALL detailed results
    full_results = None
    if latest_json:
        try:
            with open(latest_json, 'r', encoding='utf-8') as f:
                full_results = json.load(f)
            print("JSON results loaded")
            
            # NEW: inspect the JSON content
            if 'results' in full_results:
                total_results = len(full_results['results'])
                successful_results = len([r for r in full_results['results'] if r.get('success', False)])
                print(f"JSON contains: {successful_results}/{total_results} successful results")
                
                # Check whether complete predictions are included, using the
                # first successful result as a sample
                sample_result = next((r for r in full_results['results'] if r.get('success', False)), None)
                if sample_result:
                    has_predictions = 'predictions' in sample_result
                    has_y_true = has_predictions and 'y_true' in sample_result['predictions']
                    has_class_metrics = 'evaluation_metrics' in sample_result and 'class_metrics' in sample_result['evaluation_metrics']
                    
                    print(f"Sample result analysis:")
                    print(f"  Has predictions: {has_predictions}")
                    print(f"  Has y_true: {has_y_true}")
                    print(f"  Has class metrics: {has_class_metrics}")
                    
                    if has_y_true:
                        sample_y_true = sample_result['predictions']['y_true']
                        print(f"  Sample y_true length: {len(sample_y_true)}")
                        print(f"  Sample classes: {len(set(sample_y_true))} unique")
            
        except Exception as e:
            print(f"Error loading JSON: {e}")
    
    # NEW: Load additional data from Excel and CSVs
    df_extras = {}
    if latest_excel:
        try:
            # Load every available sheet from the known list
            excel_file = pd.ExcelFile(latest_excel)
            available_sheets = excel_file.sheet_names
            print(f"Excel sheets available: {available_sheets}")
            
            for sheet_name in ['Metricas_por_Clase', 'Tiempos_Ejecucion', 'Predicciones_Detalle', 'Calidad_Datos']:
                if sheet_name in available_sheets:
                    df_extras[sheet_name] = pd.read_excel(latest_excel, sheet_name=sheet_name)
                    print(f"  Loaded {sheet_name}: {len(df_extras[sheet_name])} rows")
        except Exception as e:
            print(f"Excel warning: {e}")
    
    # NEW: load the additional CSV files
    if pred_csv_files:
        try:
            latest_pred_csv = max(pred_csv_files, key=os.path.getctime)
            df_extras['predictions_summary'] = pd.read_csv(latest_pred_csv)
            print(f"Predictions summary loaded: {len(df_extras['predictions_summary'])} rows")
        except Exception as e:
            print(f"Error loading predictions CSV: {e}")
    
    if quality_csv_files:
        try:
            latest_quality_csv = max(quality_csv_files, key=os.path.getctime)
            df_extras['data_quality'] = pd.read_csv(latest_quality_csv)
            print(f"Data quality report loaded: {len(df_extras['data_quality'])} rows")
        except Exception as e:
            print(f"Error loading quality CSV: {e}")
    
    # Load class mappings (as before, but with better logging)
    full_class_mapping = load_class_mapping()
    eval_class_mapping, eval_unique_classes = load_evaluation_data_mapping()
    
    # Determine which mapping to use
    # The mapping derived from the evaluation data takes precedence
    if eval_class_mapping is not None:
        class_mapping = eval_class_mapping
        unique_classes = eval_unique_classes
        print(f"Using evaluation class mapping: {len(class_mapping)} classes")
    else:
        class_mapping = full_class_mapping
        unique_classes = None
        
        # IMPROVEMENT: Try to infer classes from ENHANCED results
        if full_results and 'results' in full_results:
            all_classes = set()
            for result in full_results['results']:
                if result.get('success', False):
                    predictions = result.get('predictions', {})
                    
                    # NEW: use unique_classes when it is available
                    unique_cls = predictions.get('unique_classes', [])
                    if unique_cls:
                        all_classes.update(unique_cls)
                    elif 'y_true' in predictions:
                        y_true = predictions['y_true']
                        if isinstance(y_true, list) and len(y_true) > 0:
                            all_classes.update(y_true)
            
            if all_classes:
                unique_classes = sorted(list(all_classes))
                print(f"Inferred classes from enhanced results: {len(unique_classes)} classes")
                # Keep only the mapping entries for classes seen in the results
                class_mapping = {k: v for k, v in class_mapping.items() if k in unique_classes}
        
        # Last resort: assume class ids 0..len(class_mapping)-1
        if unique_classes is None:
            unique_classes = list(range(len(class_mapping)))
            print(f"Using default class range: {len(unique_classes)} classes")
    
    return full_results, df_main, df_extras, class_mapping, unique_classes


def extract_model_predictions(full_results, algorithm, encoding):
    """Extract prediction data for one model - SIMPLIFIED VERSION.

    Scans ``full_results['results']`` for the first successful entry that
    matches ``algorithm`` and ``encoding`` and has a non-empty ``y_true``.

    Returns:
        ``(y_true, y_pred, y_proba, class_metrics)``. The first three are
        numpy arrays (``y_pred`` / ``y_proba`` stay ``None`` when absent) and
        ``class_metrics`` is the entry's ``class_metrics`` dict. Four ``None``
        values are returned when nothing usable is found.
    """
    not_found = (None, None, None, None)
    if not full_results or 'results' not in full_results:
        return not_found

    for entry in full_results['results']:
        matches = (
            entry.get('algorithm') == algorithm
            and entry.get('encoding') == encoding
            and entry.get('success', False)
        )
        if not matches:
            continue

        stored = entry.get('predictions', {})
        metrics = entry.get('evaluation_metrics', {})

        y_true = stored.get('y_true')
        y_pred = stored.get('y_pred')
        y_proba = stored.get('y_proba')
        class_metrics = metrics.get('class_metrics', {})

        # A matching entry without ground truth is skipped; keep searching
        if y_true is None or len(y_true) == 0:
            continue

        print(f"  Datos encontrados: {len(y_true)} muestras")

        # Hand back numpy arrays; optional fields remain None
        y_true = np.array(y_true)
        y_pred = np.array(y_pred) if y_pred is not None else y_pred
        y_proba = np.array(y_proba) if y_proba is not None else y_proba
        return y_true, y_pred, y_proba, class_metrics

    return not_found

def plot_class_metrics(metrics_dict, model_name, data_name, class_names=None):
    """
    Build a 2x2 grid of per-class bar charts (2.modelos_completos.py style).

    Args:
        metrics_dict: Dict that may hold ``sensitivity_per_class``,
            ``specificity_per_class``, ``precision_per_class`` and
            ``recall_per_class``; each maps a class to a value. Panels whose
            key is missing are left empty.
        model_name: Shown in the figure title.
        data_name: Not used by this function.
        class_names: Labels for the x axis. Defaults to the keys of
            ``metrics_dict['sensitivity_per_class']``.

    Returns:
        The assembled figure.
    """
    # Without explicit names, label the bars with the sensitivity table keys
    if class_names is None:
        class_names = list(metrics_dict['sensitivity_per_class'].keys())

    panel_titles = (
        '<b>Sensitivity per Class</b>',
        '<b>Specificity per Class</b>',
        '<b>Precision per Class</b>',
        '<b>Recall per Class</b>',
    )
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=panel_titles,
        vertical_spacing=0.12,
        horizontal_spacing=0.08,
    )

    # (metric key, subplot row, subplot column, bar color)
    panels = (
        ('sensitivity_per_class', 1, 1, 'rgb(31, 119, 180)'),
        ('specificity_per_class', 1, 2, 'rgb(255, 127, 14)'),
        ('precision_per_class', 2, 1, 'rgb(44, 160, 44)'),
        ('recall_per_class', 2, 2, 'rgb(214, 39, 40)'),
    )

    for metric_key, row, col, bar_color in panels:
        if metric_key not in metrics_dict:
            continue

        label = metric_key.replace('_per_class', '').title()
        bars = go.Bar(
            x=class_names,
            y=list(metrics_dict[metric_key].values()),
            name=label,
            marker_color=bar_color,
            showlegend=False,
            hovertemplate="Class: %{x}<br>" + f"{label}: " + "%{y:.3f}<br><extra></extra>",
        )
        fig.add_trace(bars, row=row, col=col)

    # Figure title
    fig.update_layout(
        title=dict(
            text=f'<b>Class Metrics - {model_name}</b>',
            x=0.5,
            y=0.95,
            xanchor='center',
            yanchor='top',
            font=dict(size=TITLE_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR),
        ),
        template="plotly_white",
    )

    # Axis styling applied to all four panels
    tick_font = dict(size=AXIS_TICK_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR)
    axis_title_font = dict(size=AXIS_TITLE_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR)
    grid_color = 'rgb(240,240,240)'

    # Class labels are rotated so long names do not overlap
    fig.update_xaxes(
        tickangle=45,
        title_text="<b>Class</b>",
        showgrid=True,
        gridwidth=1,
        gridcolor=grid_color,
        tickfont=tick_font,
        title_font=axis_title_font,
    )
    fig.update_yaxes(
        title_text="<b>Value</b>",
        showgrid=True,
        gridwidth=1,
        gridcolor=grid_color,
        range=[0, 1.05],
        tickfont=tick_font,
        title_font=axis_title_font,
    )

    return fig


def plot_confusion_matrix(y_true, y_pred, classes, model_name, data_name):
    """
    Generate enhanced confusion matrix matching 2.modelos_completos.py style.

    Each non-zero cell is annotated with its count and its share of the row
    (true class) total.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        classes: Display names used for both axes; entry ``i`` labels row and
            column ``i`` of the matrix.
        model_name: Shown in the title.
        data_name: Shown in the title.

    Returns:
        The heatmap figure.
    """

    # Calculate confusion matrix and normalize each row to percentages.
    cm = confusion_matrix(y_true, y_pred)
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    # A label that occurs only in y_pred produces an all-zero row; dividing by
    # its total would give NaN ("nan%" in the hover text), so such rows are
    # reported as 0 %.
    cm_percent = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100

    n_classes = len(classes)
    # Cells above half of the largest count get white text for contrast.
    dark_cutoff = cm.max() / 2 if cm.size else 0

    # Configure custom color scale
    color_scale = [
        [0, 'rgb(255,255,255)'],      # White for 0
        [0.000001, 'rgb(240,248,255)'],  # Almost white for very small values
        [0.3, 'rgb(65,105,225)'],      # Royal blue for medium values
        [1, 'rgb(0,0,139)']            # Dark blue for maximum values
    ]

    # Create hover text
    hover_text = [[f'Real: {classes[i]}<br>' +
                   f'Predicted: {classes[j]}<br>' +
                   f'Count: {cm[i][j]}<br>' +
                   f'Percentage: {cm_percent[i][j]:.1f}%'
                   for j in range(n_classes)] for i in range(n_classes)]

    fig = go.Figure(data=go.Heatmap(
        z=cm,
        x=classes,
        y=classes,
        hoverongaps=False,
        colorscale=color_scale,
        hoverinfo='text',
        text=hover_text,
        showscale=True,
        colorbar=dict(
            title=dict(
                text="<b>Count</b>",
                font=dict(size=AXIS_TITLE_FONT_SIZE, family=FONT_FAMILY)
            ),
            tickfont=dict(size=AXIS_TICK_FONT_SIZE, family=FONT_FAMILY),
            len=0.75,
            thickness=20,
            x=1.02
        )
    ))

    # Add annotations
    for i in range(n_classes):
        for j in range(n_classes):
            count = cm[i][j]
            if count <= 0:  # Only show values greater than 0
                continue

            text_color = "white" if count > dark_cutoff else "black"
            # Main value, shifted up so the percentage fits underneath
            fig.add_annotation(
                x=j,
                y=i,
                text=f"<b>{count}</b>",
                showarrow=False,
                font=dict(color=text_color, size=12, family=FONT_FAMILY),
                yshift=10
            )
            # Percentage below
            fig.add_annotation(
                x=j,
                y=i,
                text=f"({cm_percent[i][j]:.1f}%)",
                showarrow=False,
                font=dict(color=text_color, size=10, family=FONT_FAMILY),
                yshift=-10
            )

    # Update layout
    fig.update_layout(
        title=dict(
            text=f'<b>Confusion Matrix - {model_name} - {data_name}</b>',
            x=0.5,
            y=0.95,
            xanchor='center',
            yanchor='top',
            font=dict(size=TITLE_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR)
        ),
        xaxis_title='<b>Predicted Class</b>',
        yaxis_title='<b>Real Class</b>',
        xaxis=dict(
            tickfont=dict(size=AXIS_TICK_FONT_SIZE, family=FONT_FAMILY),
            title_font=dict(size=AXIS_TITLE_FONT_SIZE, family=FONT_FAMILY),
            tickangle=45,
            side='bottom',
            gridcolor='white',
            showgrid=False
        ),
        yaxis=dict(
            tickfont=dict(size=AXIS_TICK_FONT_SIZE, family=FONT_FAMILY),
            title_font=dict(size=AXIS_TITLE_FONT_SIZE, family=FONT_FAMILY),
            gridcolor='white',
            showgrid=False
        ),
        template="plotly_white"
    )

    return fig

def plot_roc_curve(y_true, y_score, n_classes, model_name, data_name, class_names=None):
    """
    Generate ROC curves matching 2.modelos_completos.py style - ALL CLASSES VERSION.

    One one-vs-rest curve is drawn per class index ``i`` in
    ``range(min(n_classes, len(class_names)))``: samples with
    ``y_true == i`` are the positives and column ``i`` of ``y_score`` is the
    score. Classes with only positives or only negatives, or without a score
    column, are skipped with a warning.

    Args:
        y_true: True labels (compared against the class index).
        y_score: 2-D scores, one column per class index.
        n_classes: Number of class indices to try.
        model_name: Shown in the title.
        data_name: Shown in the title.
        class_names: Legend labels by class index; defaults to ``Class <i>``.

    Returns:
        The figure, or ``None`` when the inputs are missing/empty, ``y_score``
        is 1-D, or no curve could be drawn.
    """

    # Use real class names if provided
    if class_names is None:
        class_names = [f'Class {i}' for i in range(n_classes)]

    # Validate inputs
    if y_true is None or y_score is None:
        print("  Warning: No prediction data available for ROC curves")
        return None

    if len(y_true) == 0 or len(y_score) == 0:
        print("  Warning: Empty prediction data for ROC curves")
        return None

    # Convert to numpy arrays if they aren't already
    y_true = np.array(y_true)
    y_score = np.array(y_score)

    # Ensure y_score is 2D
    if len(y_score.shape) == 1:
        print("  Warning: y_score should be 2D for multiclass ROC")
        return None

    # Extended color palette for many classes
    colors = [
        '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
        '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
        '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5',
        '#c49c94', '#f7b6d3', '#c7c7c7', '#dbdb8d', '#9edae5',
        '#ff6b6b', '#4ecdc4', '#45b7d1', '#96ceb4', '#feca57',
        '#ff9ff3', '#54a0ff', '#5f27cd', '#00d2d3', '#ff9f43'
    ]

    fig = go.Figure()
    classes_plotted = 0

    print(f"  Generating ROC curves for all {n_classes} classes...")

    # Every class is plotted; there is no cap on the number of curves
    for class_idx in range(min(n_classes, len(class_names))):
        try:
            # Convert to binary classification for current class
            y_true_binary = (y_true == class_idx).astype(int)

            # A ROC curve needs both positive and negative samples
            if len(np.unique(y_true_binary)) <= 1:
                print(f"  Warning: Only one class present for {class_names[class_idx]} (skipping)")
                continue

            if class_idx >= y_score.shape[1]:
                print(f"  Warning: No probability data for class {class_idx} ({class_names[class_idx]})")
                continue

            # Calculate ROC curve from the scores of the current class
            fpr, tpr, _ = roc_curve(y_true_binary, y_score[:, class_idx])
            roc_auc = auc(fpr, tpr)

            # Create hover text, one entry per curve point
            hover_text = [
                f'Class: {class_names[class_idx]}<br>' +
                f'FPR: {fpr_value:.3f}<br>' +
                f'TPR: {tpr_value:.3f}<br>' +
                f'AUC: {roc_auc:.3f}'
                for fpr_value, tpr_value in zip(fpr, tpr)
            ]

            # Add trace to plot
            fig.add_trace(go.Scatter(
                x=fpr,
                y=tpr,
                mode='lines',
                name=f'{class_names[class_idx]} (AUC = {roc_auc:.3f})',
                line=dict(
                    color=colors[classes_plotted % len(colors)],
                    width=2.0  # Slightly thinner lines for better visibility with many classes
                ),
                hoverinfo='text',
                hovertext=hover_text
            ))

            classes_plotted += 1

        except Exception as e:
            # Best effort: one failing class must not prevent the others
            print(f"  Warning: Could not plot ROC for class {class_idx} ({class_names[class_idx]}): {str(e)}")
            continue

    # Only proceed if we successfully plotted some curves
    if classes_plotted == 0:
        print(f"  Warning: No ROC curves could be generated for {model_name}")
        return None

    # Add random classifier line
    fig.add_trace(go.Scatter(
        x=[0, 1],
        y=[0, 1],
        mode='lines',
        name='Random Classifier',
        line=dict(
            color='gray',
            width=2,
            dash='dash'
        ),
        hoverinfo='skip'
    ))

    # Title shows total number of classes plotted
    title_text = f'<b>ROC Curves - {model_name} - {data_name} ({classes_plotted} classes)</b>'

    # Layout tuned for many classes
    fig.update_layout(
        title={
            'text': title_text,
            'x': 0.5,
            'y': 0.98,  # Moved title higher to make room for legend
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(size=TITLE_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR)
        },
        xaxis_title='<b>False Positive Rate</b>',
        yaxis_title='<b>True Positive Rate</b>',
        xaxis=dict(
            showgrid=True,
            gridwidth=1,
            gridcolor='rgb(240,240,240)',
            zeroline=True,
            zerolinewidth=1,
            zerolinecolor='rgb(180,180,180)',
            tickfont=dict(size=AXIS_TICK_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR),
            title_font=dict(size=AXIS_TITLE_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR),
            range=[0, 1]
        ),
        yaxis=dict(
            showgrid=True,
            gridwidth=1,
            gridcolor='rgb(240,240,240)',
            zeroline=True,
            zerolinewidth=1,
            zerolinecolor='rgb(180,180,180)',
            scaleanchor="x",
            scaleratio=1,
            tickfont=dict(size=AXIS_TICK_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR),
            title_font=dict(size=AXIS_TITLE_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR),
            range=[0, 1]
        ),
        # Legend configuration for many classes
        legend=dict(
            font=dict(size=16, color=FONT_COLOR, family=FONT_FAMILY),  # Smaller font for many entries
            bgcolor='rgba(255,255,255,0.9)',
            bordercolor='rgb(180,180,180)',
            borderwidth=1,
            orientation="v",
            yanchor="top",
            y=0.98,
            xanchor="left",
            x=1.02,  # Position legend to the right of the plot
            itemsizing="constant",
            itemwidth=30  # Consistent item width
        ),
        template="plotly_white",
        showlegend=True,
        # Increased right margin to accommodate the legend
        margin=dict(l=150, r=300, t=150, b=150)
    )

    print(f"  ROC plot completed with {classes_plotted} curves")
    return fig

def plot_performance_comparison(df_main):
    """Scatter plot of total time against accuracy, one series per algorithm.

    Args:
        df_main: DataFrame with ``algorithm``, ``encoding``, ``total_time``
            and ``accuracy`` columns. Each point is labelled with its encoding.

    Returns:
        The figure, or ``None`` when ``df_main`` is ``None`` or empty.
    """
    if df_main is None or df_main.empty:
        print("No data for performance comparison")
        return None

    fig = go.Figure()

    # One trace per algorithm, colors cycling through the shared palette
    for position, algorithm in enumerate(df_main['algorithm'].unique()):
        rows = df_main[df_main['algorithm'] == algorithm]

        marker_style = dict(
            size=15,
            color=COLOR_PALETTE[position % len(COLOR_PALETTE)],
            line=dict(width=2, color='black'),
            opacity=0.8,
        )
        hover = (
            '<b>%{text}</b><br>Algorithm: ' + algorithm +
            '<br>Time: %{x:.2f}s<br>Accuracy: %{y:.4f}<extra></extra>'
        )

        fig.add_trace(go.Scatter(
            x=rows['total_time'],
            y=rows['accuracy'],
            mode='markers+text',
            text=rows['encoding'],
            textposition="top center",
            name=algorithm,
            marker=marker_style,
            hovertemplate=hover,
        ))

    tick_font = dict(size=AXIS_TICK_FONT_SIZE, family=FONT_FAMILY)
    axis_title_font = dict(size=AXIS_TITLE_FONT_SIZE, family=FONT_FAMILY)

    fig.update_layout(
        title=dict(
            text='<b>Time vs Accuracy Comparison by Algorithm</b>',
            x=0.5,
            y=0.95,
            xanchor='center',
            yanchor='top',
            font=dict(size=TITLE_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR),
        ),
        xaxis_title='<b>Total Time (seconds)</b>',
        yaxis_title='<b>Accuracy</b>',
        xaxis=dict(tickfont=tick_font, title_font=axis_title_font),
        yaxis=dict(tickfont=tick_font, title_font=axis_title_font),
        legend=dict(font=dict(size=LEGEND_FONT_SIZE, family=FONT_FAMILY)),
        template="plotly_white",
    )

    return fig


def plot_performance_heatmap(df_main):
    """Heatmap of mean accuracy per (algorithm, encoding) combination.

    Args:
        df_main: DataFrame with ``algorithm``, ``encoding`` and ``accuracy``
            columns.

    Returns:
        The figure, or ``None`` when ``df_main`` is ``None`` or empty.
    """
    if df_main is None or df_main.empty:
        print("No data for heatmap")
        return None

    # Rows are algorithms, columns are encodings, cells hold the mean accuracy
    pivot_data = df_main.pivot_table(
        values='accuracy',
        index='algorithm',
        columns='encoding',
        aggfunc='mean',
    )

    accuracy_bar = dict(
        title=dict(
            text="<b>Accuracy</b>",
            font=dict(size=AXIS_TITLE_FONT_SIZE, family=FONT_FAMILY),
        ),
        tickfont=dict(size=AXIS_TICK_FONT_SIZE, family=FONT_FAMILY),
    )
    heatmap = go.Heatmap(
        z=pivot_data.values,
        x=pivot_data.columns,
        y=pivot_data.index,
        colorscale='Viridis',
        hovertemplate='<b>Model: %{y}</b><br><b>Encoding: %{x}</b><br><b>Accuracy: %{z:.4f}</b><extra></extra>',
        colorbar=accuracy_bar,
    )
    fig = go.Figure(data=heatmap)

    # Write the value into every cell that has one (missing combinations are NaN)
    n_rows, n_cols = pivot_data.shape
    for row in range(n_rows):
        for col in range(n_cols):
            cell = pivot_data.iloc[row, col]
            if pd.isna(cell):
                continue
            fig.add_annotation(
                x=col,
                y=row,
                text=f'<b>{cell:.3f}</b>',
                showarrow=False,
                font=dict(color="white", size=12, family=FONT_FAMILY),
            )

    tick_font = dict(size=AXIS_TICK_FONT_SIZE, family=FONT_FAMILY)
    axis_title_font = dict(size=AXIS_TITLE_FONT_SIZE, family=FONT_FAMILY)

    fig.update_layout(
        title=dict(
            text='<b>Accuracy Heatmap: Model vs Encoding</b>',
            x=0.5,
            y=0.95,
            xanchor='center',
            yanchor='top',
            font=dict(size=TITLE_FONT_SIZE, family=FONT_FAMILY, color=FONT_COLOR),
        ),
        xaxis_title='<b>Encoding</b>',
        yaxis_title='<b>Model</b>',
        xaxis=dict(tickangle=45, tickfont=tick_font, title_font=axis_title_font),
        yaxis=dict(tickfont=tick_font, title_font=axis_title_font),
        template="plotly_white",
    )

    return fig


def generate_model_visualizations(df_main, full_results, class_mapping, unique_classes):
    """Generate per-model plots (class metrics, confusion matrix, ROC) from real data only.

    Models are processed in descending order of ``f1_weighted``. A model is
    skipped (and counted as failed) when it has fewer than 10 true labels,
    when its per-class metrics are missing, or when any class present in
    ``y_true`` has fewer than 2 samples.

    Args:
        df_main: DataFrame with one row per model; the ``algorithm``,
            ``encoding`` and ``f1_weighted`` columns are read.
        full_results: Passed through unchanged to ``extract_model_predictions``.
        class_mapping: Mapping from class id to display name; ids without an
            entry are shown as ``Class_<id>``.
        unique_classes: Sized iterable of the class ids in the evaluation set.

    Returns:
        None. Progress and a final success/failure summary are printed, and
        figures are written through ``save_plot``.
    """
    if df_main is None or df_main.empty:
        print("No data for model visualizations")
        return

    # Rank models by weighted F1, best first.
    models_ranked = df_main.sort_values('f1_weighted', ascending=False).reset_index(drop=True)

    # Resolve display names for the classes that are actually present.
    present_class_names = [class_mapping.get(class_id, f'Class_{class_id}') for class_id in unique_classes]
    n_classes = len(unique_classes)

    # Per-class metric entries that must all be available for a model to be plotted.
    required_metric_keys = (
        'sensitivity_per_class', 'specificity_per_class',
        'precision_per_class', 'recall_per_class',
    )

    print(f"Generating visualizations for {len(models_ranked)} models")
    print(f"Classes in evaluation: {n_classes}")

    successful_plots = 0
    failed_plots = 0

    for ranking, (_, model_row) in enumerate(models_ranked.iterrows(), start=1):
        algorithm = model_row['algorithm']
        encoding = model_row['encoding']
        f1_score = model_row['f1_weighted']

        print(f"Processing Rank {ranking}: {algorithm} - {encoding} (F1: {f1_score:.4f})")

        # Pull the stored predictions and per-class metrics for this model.
        y_true, y_pred, y_proba, real_class_metrics = extract_model_predictions(full_results, algorithm, encoding)

        # Continue only when there is enough real data.
        if y_true is None or y_pred is None or len(y_true) < 10:
            print(f"  SKIPPING: Insufficient real data ({len(y_true) if y_true is not None else 0} samples)")
            failed_plots += 1
            continue

        print(f"  Using real data: {len(y_true)} samples")

        # Only real per-class metrics are plotted; never synthesize them.
        if not real_class_metrics or any(key not in real_class_metrics for key in required_metric_keys):
            print("  SKIPPING: No real class metrics available")
            failed_plots += 1
            continue

        # Check the minimum number of samples per class. A single
        # np.unique pass counts every class at once and, unlike
        # `y_true == cls`, also works when y_true is a plain list.
        _, class_counts = np.unique(np.asarray(y_true), return_counts=True)
        min_samples_per_class = class_counts.min()

        if min_samples_per_class < 2:
            print(f"  SKIPPING: Some classes have too few samples (min: {min_samples_per_class})")
            failed_plots += 1
            continue

        try:
            # File name prefix shared by all plots of this model.
            base_filename = f"rank_{ranking:02d}_f1_{f1_score:.3f}_{algorithm}_{encoding}".replace(' ', '_').replace('-', '_')

            # 1. Class metrics
            print("  - Generating class metrics plot")
            fig_metrics = plot_class_metrics(real_class_metrics, algorithm, encoding, present_class_names)
            if fig_metrics:
                save_plot(fig_metrics, f"{base_filename}_class_metrics")

            # 2. Confusion matrix
            print("  - Generating confusion matrix")
            fig_cm = plot_confusion_matrix(y_true, y_pred, present_class_names, algorithm, encoding)
            if fig_cm:
                save_plot(fig_cm, f"{base_filename}_confusion_matrix")

            # 3. ROC curves (only when probabilities are available)
            if y_proba is not None:
                print("  - Generating ROC curves")
                fig_roc = plot_roc_curve(y_true, y_proba, n_classes, algorithm, encoding, present_class_names)
                if fig_roc:
                    save_plot(fig_roc, f"{base_filename}_roc_curves")
            else:
                print("  - Skipping ROC curves (no probabilities)")

            successful_plots += 1
            print("  Success")

        except Exception as e:
            # Per-model boundary: report the failure and move on so one bad
            # model does not abort the remaining ones.
            print(f"  Error: {e}")
            failed_plots += 1

    print(f"\nSummary: {successful_plots} successful, {failed_plots} failed")


In [3]:
# Simplified main routine for this cell. The bare string below is a no-op
# expression statement kept from the original notebook.
"""Función main simplificada"""

print("Starting visualization generation...")

# Load the evaluation results
full_results, df_main, df_extras, class_mapping, unique_classes = load_evaluation_results()

print("Data loaded successfully")
print(f"Models: {len(df_main)}")
print(f"Classes: {len(unique_classes)}")

# Generate the main (summary) visualizations
print("\nGenerating main visualizations...")

# # 1. Performance comparison
# print("  - Performance comparison")
# fig = plot_performance_comparison(df_main)
# if fig:
#     save_plot(fig, "01_performance_comparison")

# # 2. Performance heatmap
# print("  - Performance heatmap")
# fig = plot_performance_heatmap(df_main)
# if fig:
#     save_plot(fig, "02_performance_heatmap")

# Generate the per-model visualizations
print("\nGenerating individual model visualizations...")
print("(Using ONLY real data - no synthetic data generated)")

generate_model_visualizations(df_main, full_results, class_mapping, unique_classes)

# Count the generated PNG files by plot kind
if os.path.exists(PLOTS_DIR):
    plot_files = [f for f in os.listdir(PLOTS_DIR) if f.endswith('.png')]
    total_plots = len(plot_files)

    class_metrics_files = [f for f in plot_files if 'class_metrics' in f]
    confusion_matrix_files = [f for f in plot_files if 'confusion_matrix' in f]
    roc_curve_files = [f for f in plot_files if 'roc_curves' in f]

    class_metrics_plots = len(class_metrics_files)
    confusion_matrix_plots = len(confusion_matrix_files)
    roc_curve_plots = len(roc_curve_files)
    summary_plots = len([f for f in plot_files if f.startswith(('01_', '02_'))])

    print("\nVisualization Results:")
    print(f"  Total plots: {total_plots}")
    print(f"  - Summary plots: {summary_plots}")
    print(f"  - Class metrics: {class_metrics_plots}")
    print(f"  - Confusion matrices: {confusion_matrix_plots}")
    print(f"  - ROC curves: {roc_curve_plots}")

    # A model is "complete" only when all three plot kinds exist for the same
    # file name prefix. Averaging the three counts over-reports whenever ROC
    # curves were skipped for some models (e.g. 10 + 10 + 0 would give 6).
    complete_prefixes = (
        {f.rpartition('class_metrics')[0] for f in class_metrics_files}
        & {f.rpartition('confusion_matrix')[0] for f in confusion_matrix_files}
        & {f.rpartition('roc_curves')[0] for f in roc_curve_files}
    )
    models_with_plots = len(complete_prefixes)
    print(f"  - Models with complete plots: {models_with_plots}")

print("\nVisualization completed")
print(f"Plots saved in: {PLOTS_DIR}")

Starting visualization generation...
Searching for results files...
Main file: evaluation_results.csv
Main DataFrame loaded: 48 rows
Has enhanced columns: False
JSON results loaded
JSON contains: 48/48 successful results
Sample result analysis:
  Has predictions: True
  Has y_true: True
  Has class metrics: True
  Sample y_true length: 8555
  Sample classes: 29 unique
Excel sheets available: ['Resumen', 'Metricas_por_Clase', 'Tiempos_Ejecucion']
  Loaded Metricas_por_Clase: 1392 rows
  Loaded Tiempos_Ejecucion: 48 rows
Full class mapping loaded: 29 classes
Evaluation data loaded: 8555 samples
Classes present in evaluation: 29
Class range: 0 to 28
Sample genera: ['Flavobacterium', 'Vibrio', 'Corynebacterium', 'Pseudomonas', 'Pelagibacter']...
Using evaluation class mapping: 29 classes
Data loaded successfully
Models: 48
Classes: 29

Generating main visualizations...

Generating individual model visualizations...
(Using ONLY real data - no synthetic data generated)
Generating visualizati