# Student Performance Dashboard

This notebook contains the interactive dashboard and analysis
for the Student Performance project.

**Author:** Birhanu Moges


In [None]:
import dash
from dash import dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import base64
import io
from datetime import datetime

# Machine Learning imports
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, GradientBoostingClassifier
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, f1_score
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA
from imblearn.over_sampling import SMOTE

# SHAP for interpretability - OPTIMIZED
import shap

# Additional imports
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
import warnings
warnings.filterwarnings('ignore')
import sys
import os

# Location of the dataset. The historical hard-coded path stays the default;
# setting the STUDENT_DATASET_PATH environment variable overrides it.
DATASET_PATH = os.environ.get(
    "STUDENT_DATASET_PATH",
    r"C:/Users/DELL/AIgravity/ethiopian_students_dataset.csv",
)

# Suppress unnecessary console output while the CSV is parsed. The redirect is
# undone in a ``finally`` block so that a failed read (missing file, parse
# error) cannot leave the session with stdout still pointing at os.devnull,
# and so the devnull handle is always closed.
original_stdout = sys.stdout
_devnull = open(os.devnull, 'w')
sys.stdout = _devnull
try:
    # Read dataset only ONCE at the very beginning
    df_original = pd.read_csv(DATASET_PATH)
finally:
    # Restore stdout for essential messages
    sys.stdout = original_stdout
    _devnull.close()

print(f"Initial dataset shape: {df_original.shape}")

# Palette used across the dashboard, chosen for colour-blindness accessibility.
# The three *_perf entries reuse the danger / warning / success colours.
COLOR_SCHEME = dict(
    primary='#2E86AB',        # blue
    secondary='#A23B72',      # purple
    success='#18A999',        # teal
    warning='#F18F01',        # orange
    danger='#C73E1D',         # red
    light='#F0F3F5',          # light gray
    dark='#2C3E50',           # dark blue-gray
    text='#2C3E50',
    background='#FFFFFF',
    low_perf='#C73E1D',       # red for low performance
    medium_perf='#F18F01',    # orange for medium performance
    high_perf='#18A999',      # teal for high performance
)

# Create the Dash application with the Bootstrap stylesheet and set the
# browser-tab title. Callback exceptions are suppressed via the constructor flag.
app = dash.Dash(
    __name__,
    suppress_callback_exceptions=True,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
)
app.title = "Ethiopian Student Performance Analytics Dashboard"

# =================================================================
# PROVIDED MODEL PERFORMANCE DATA - DEFINED AT GLOBAL SCOPE
# =================================================================

# Provided National Exam Score model performance (global, static table).
# One row per model, listed from highest to lowest R2_Score.
NATIONAL_EXAM_MODEL_PERFORMANCE = pd.DataFrame(
    [
        # (Model,              R2_Score, MAE,      RMSE)
        ('Gradient Boosting',  0.437997, 0.081404, 0.107109),
        ('XGBoost',            0.435344, 0.081569, 0.107362),
        ('Random Forest',      0.425839, 0.082253, 0.108262),
        ('Ridge Regression',   0.405843, 0.083816, 0.110131),
        ('Linear Regression',  0.405831, 0.083816, 0.110132),
        ('Lasso Regression',   0.404305, 0.083893, 0.110273),
    ],
    columns=['Model', 'R2_Score', 'MAE', 'RMSE'],
)

# Provided feature importance for the National Exam Score model (global,
# static table). Each row is (feature name, importance, importance in %),
# listed from most to least important. Both numeric columns are stored as
# supplied; the percentage column is not recomputed from the fraction.
NATIONAL_EXAM_FEATURE_IMPORTANCE = pd.DataFrame(
    [
        ('Score_x_Participation',             0.735596, 73.559627),
        ('Overall_Avg_Homework',              0.071998, 7.199785),
        ('School_Academic_Score',             0.066883, 6.688265),
        ('Overall_Test_Score_Avg',            0.043070, 4.307000),
        ('Overall_Avg_Attendance',            0.017778, 1.777779),
        ('Overall_Avg_Participation',         0.016191, 1.619086),
        ('School_Resources_Score',            0.013264, 1.326450),
        ('Parental_Involvement',              0.011599, 1.159938),
        ('Resource_Efficiency',               0.005805, 0.580545),
        ('Teacher_Student_Ratio',             0.005141, 0.514109),
        ('Student_to_Resources_Ratio',        0.002633, 0.263260),
        ('School_Type_Target',                0.002587, 0.258688),
        ('Overall_Engagement_Score',          0.001936, 0.193604),
        ('Teacher_Load_Adjusted',             0.001591, 0.159140),
        ('Overall_Textbook_Access_Composite', 0.001533, 0.153329),
        ('Field_Choice',                      0.001304, 0.130358),
        ('Career_Interest_Encoded',           0.001090, 0.109039),
    ],
    columns=['Feature', 'Importance', 'Importance_%'],
)

# =================================================================
# DATA PREPROCESSING FUNCTIONS - ENHANCED VERSION
# =================================================================

def _grade_prefixes(grades):
    """Return the column-name prefixes (``'Grade_<n>_'``) for *grades*.

    The trailing underscore keeps ``'Grade_1'`` from also matching the
    ``'Grade_10'``, ``'Grade_11'`` and ``'Grade_12'`` columns, which a bare
    ``startswith('Grade_1')`` test would capture.
    NOTE(review): assumes grade-level columns are named ``Grade_<n>_<metric>``
    -- confirm against the dataset header.
    """
    return tuple(f'{g}_' for g in grades)


def load_and_preprocess_data():
    """Build the aggregated, model-ready frame from ``df_original``.

    Works on a copy of the module-level ``df_original``. Grade-level columns
    are collapsed into per-stage averages and then into overall composites
    (attendance, homework, participation, textbook access, engagement); the
    raw grade columns, national-exam columns and the columns listed in
    ``leak_cols`` are removed.

    Returns:
        pandas.DataFrame: the preprocessed frame (categoricals not yet encoded).

    Raises:
        KeyError: if a column that is indexed directly is missing, e.g.
            ``Field_Choice``, ``Career_Interest``, any of the national-exam
            subject columns, ``Health_Issue`` or the parents' education columns.
    """
    # Use the already loaded dataset
    df = df_original.copy()

    # ----------------------------------------------------------------
    # 1. Initial cleaning & encoding
    # ----------------------------------------------------------------
    # Drop Student_ID (never used in ML)
    df = df.drop(columns=['Student_ID'], errors='ignore')

    # Encode Field_Choice (Social=0, Natural=1); any other value becomes NaN
    df['Field_Choice'] = df['Field_Choice'].map({'Social': 0, 'Natural': 1})

    # Fill missing Career_Interest with "Unknown"
    df['Career_Interest'] = df['Career_Interest'].fillna('Unknown')

    # ----------------------------------------------------------------
    # 2. Education stages
    # ----------------------------------------------------------------
    lower_primary = ['Grade_1', 'Grade_2', 'Grade_3', 'Grade_4']
    upper_primary = ['Grade_5', 'Grade_6', 'Grade_7', 'Grade_8']
    secondary     = ['Grade_9', 'Grade_10']
    preparatory   = ['Grade_11', 'Grade_12']

    stages = {
        'Lower_Primary': lower_primary,
        'Upper_Primary': upper_primary,
        'Secondary': secondary,
        'Preparatory': preparatory
    }

    # ----------------------------------------------------------------
    # 3. Helper to aggregate grade-level columns
    # ----------------------------------------------------------------
    def stage_average(df, grades, metric_keywords):
        """
        Compute the row-wise average over all columns of the given grades whose
        name contains one of the metric keywords (case-insensitive).
        Returns the average series and the list of original columns used.
        """
        prefixes = _grade_prefixes(grades)
        keywords = [k.lower() for k in metric_keywords]
        # Single pass over the columns: no duplicates and a deterministic order.
        cols = [
            c for c in df.columns
            if c.startswith(prefixes) and any(k in c.lower() for k in keywords)
        ]
        return df[cols].mean(axis=1), cols

    # ----------------------------------------------------------------
    # 4. Aggregate test score, attendance, homework, participation
    # ----------------------------------------------------------------
    metrics_dict = {
        'Test_Score': ['Test_Score'],
        'Attendance': ['Attendance'],
        'HW_Completion': ['Homework_Completion'],
        'Participation': ['Participation']
    }

    cols_to_drop = []

    for metric_name, keywords in metrics_dict.items():
        for stage_name, grades in stages.items():
            col_name = f'Avg_{metric_name}_{stage_name}'
            df[col_name], original_cols = stage_average(df, grades, keywords)
            cols_to_drop += original_cols

    # Drop original grade-level columns (only after every average is computed)
    df.drop(columns=list(set(cols_to_drop)), inplace=True)

    # ----------------------------------------------------------------
    # 5. Aggregate textbook access
    # ----------------------------------------------------------------
    # Convert Yes/No -> 1/0 safely
    textbook_cols = [c for c in df.columns if 'Textbook' in c]
    for col in textbook_cols:
        df[col] = df[col].replace({'Yes': 1, 'No': 0}).infer_objects(copy=False)

    # Helper function for textbook access per stage
    def textbook_access(df, grade_prefixes):
        prefixes = _grade_prefixes(grade_prefixes)
        cols = [c for c in df.columns if c.startswith(prefixes) and 'Textbook' in c]
        # No matching column -> constant 0 so the composite below still works.
        return df[cols].mean(axis=1) if cols else pd.Series(0, index=df.index)

    # Create aggregated textbook access per stage
    new_cols_df = pd.DataFrame({
        'Textbook_Access_1_4': textbook_access(df, lower_primary),
        'Textbook_Access_5_8': textbook_access(df, upper_primary),
        'Textbook_Access_9_10': textbook_access(df, secondary),
        'Textbook_Access_11_12': textbook_access(df, preparatory)
    })

    df = pd.concat([df, new_cols_df], axis=1)
    df = df.loc[:, ~df.columns.duplicated()]  # remove duplicates

    # ----------------------------------------------------------------
    # 6. Track-based national exams
    #    (these averages are all listed in leak_cols and dropped below)
    # ----------------------------------------------------------------
    # Subjects per track
    social_subjects = ['National_Exam_History', 'National_Exam_Geography',
                       'National_Exam_Economics', 'National_Exam_Math_Social']

    natural_subjects = ['National_Exam_Biology', 'National_Exam_Chemistry',
                        'National_Exam_Physics', 'National_Exam_Math_Natural']

    # Track-specific averages
    df['Social_Track_Subject_Avg']  = df[social_subjects].mean(axis=1)
    df['Natural_Track_Subject_Avg'] = df[natural_subjects].mean(axis=1)

    # Track-based assignment (anything that is not Social=0 uses the Natural average)
    df['Track_Subject_Average'] = np.where(
        df['Field_Choice'] == 0,
        df['Social_Track_Subject_Avg'],
        df['Natural_Track_Subject_Avg']
    )

    # Common subjects for all students
    common_subjects = ['National_Exam_Aptitude', 'National_Exam_English',
                       'National_Exam_Civics_and_Ethical_Education']
    df['Common_Exam_Average'] = df[common_subjects].mean(axis=1)

    # Overall Track Exam Average
    df['Track_Exam_Average'] = (df['Common_Exam_Average'] + df['Track_Subject_Average']) / 2

    # Drop original high-dimension columns
    drop_cols = [c for c in df.columns if c.startswith('Grade_')]
    drop_cols += [c for c in df.columns if c.startswith('National_Exam_')]

    df = df.drop(columns=drop_cols)

    # ----------------------------------------------------------------
    # Drop leaking exam-average columns (only those actually present)
    # ----------------------------------------------------------------
    leak_cols = [
        'Total_National_Exam_Score',
        'Social_Track_Subject_Avg',
        'Natural_Track_Subject_Avg',
        'Track_Exam_Average',
        'Track_Subject_Average',
        'Common_Exam_Average',
        'Avg_Score_Secondary',
        'Avg_Score_Preparatory',
        'Avg_Score_Lower_Primary',
        'Avg_Score_Upper_Primary',
        'Avg_Test_Score_Secondary',  'Avg_Test_Score_Preparatory',
        'Avg_Test_Score_Lower_Primary',  'Avg_Test_Score_Upper_Primary',
        'School_ID', 'Total_Test_Score']

    df = df.drop(columns=[c for c in leak_cols if c in df.columns])

    # Fill null values
    df['Health_Issue'] = df['Health_Issue'].fillna('No Issue')
    df['Father_Education'] = df['Father_Education'].fillna('Unknown')
    df['Mother_Education'] = df['Mother_Education'].fillna('Unknown')

    # ----------------------------------------------------------------
    # Composite features
    # ----------------------------------------------------------------
    df['Overall_Textbook_Access_Composite'] = df[['Textbook_Access_1_4', 'Textbook_Access_5_8',
                                          'Textbook_Access_9_10', 'Textbook_Access_11_12']].mean(axis=1)

    # Attendance columns
    attendance_cols = [
        'Avg_Attendance_Lower_Primary',
        'Avg_Attendance_Upper_Primary',
        'Avg_Attendance_Secondary',
        'Avg_Attendance_Preparatory'
    ]

    df['Overall_Avg_Attendance'] = df[attendance_cols].mean(axis=1)

    # Homework columns
    homework_cols = [
        'Avg_HW_Completion_Lower_Primary',
        'Avg_HW_Completion_Upper_Primary',
        'Avg_HW_Completion_Secondary',
        'Avg_HW_Completion_Preparatory'
    ]

    df['Overall_Avg_Homework'] = df[homework_cols].mean(axis=1)

    # Participation columns
    participation_cols = [
        'Avg_Participation_Lower_Primary',
        'Avg_Participation_Upper_Primary',
        'Avg_Participation_Secondary',
        'Avg_Participation_Preparatory'
    ]

    df['Overall_Avg_Participation'] = df[participation_cols].mean(axis=1)

    # Composite engagement score: weighted 0.4 / 0.3 / 0.3
    df['Overall_Engagement_Score'] = (
        df['Overall_Avg_Attendance'] * 0.4 +
        df['Overall_Avg_Homework'] * 0.3 +
        df['Overall_Avg_Participation'] * 0.3
    )

    # ----------------------------------------------------------------
    # Drop the per-stage columns now folded into the composites
    # ----------------------------------------------------------------
    drop_cols = []

    # Test Scores
    drop_cols += [c for c in df.columns if c.startswith('Avg_Test_Score_')]

    # Textbook Access
    drop_cols += [c for c in df.columns if c.startswith('Textbook_Access_')]

    # Attendance, Participation, Homework
    drop_cols += [c for c in df.columns if c.startswith('Avg_Attendance_')]
    drop_cols += [c for c in df.columns if c.startswith('Avg_Participation_')]
    drop_cols += [c for c in df.columns if c.startswith('Avg_HW_Completion_')]

    # Drop safely
    df = df.drop(columns=drop_cols, errors='ignore')

    return df

def encode_categorical_features(df):
    """Encode the categorical columns of *df* and return a new frame.

    The input is copied, not modified. Applied in order: missing-value fill,
    binary maps, ordinal parent education, health-issue flag / severity /
    target encoding, smoothed target encoding for region, school type and
    career interest, one-hot encoding of remaining low-cardinality
    categoricals, and conversion of ``Date_of_Birth`` to ``Age``.

    One-hot columns are emitted as integers. On pandas >= 2 the default dummy
    dtype is bool, which the training functions' ``select_dtypes(np.number)``
    filter would silently discard.

    NOTE(review): the target encodings are fitted on every row passed in. In
    this file the function runs before any train/test split, so held-out rows
    contribute to their own encoded features -- consider fitting on the
    training fold only.

    Args:
        df: frame produced by ``load_and_preprocess_data``.

    Returns:
        pandas.DataFrame: encoded copy of *df*.
    """
    df_encoded = df.copy()

    # -------------------------------
    # 0. Configuration
    # -------------------------------
    CURRENT_DATE = pd.Timestamp('2026-01-30')   # reference date for Age
    MAX_UNIQUE_OHE = 8                          # max categories for one-hot
    ALPHA = 10                                  # smoothing strength

    # TARGET variable name (adjust if needed)
    TARGET = 'Overall_Average' if 'Overall_Average' in df_encoded.columns else 'Total_National_Exam_Score'

    # -------------------------------
    # 1. Fill missing values
    # -------------------------------
    if 'Health_Issue' in df_encoded.columns:
        df_encoded['Health_Issue'] = df_encoded['Health_Issue'].fillna('No Issue')

    for col in ['Father_Education', 'Mother_Education']:
        if col in df_encoded.columns:
            df_encoded[col] = df_encoded[col].fillna('Unknown')

    # -------------------------------
    # 2. Binary encoding (two-valued features); unmapped values become NaN
    # -------------------------------
    binary_maps = {
        'Gender': {'Male': 0, 'Female': 1},
        'Home_Internet_Access': {'No': 0, 'Yes': 1},
        'Electricity_Access': {'No': 0, 'Yes': 1},
        'School_Location': {'Rural': 0, 'Urban': 1}
    }

    for col, mapping in binary_maps.items():
        if col in df_encoded.columns:
            df_encoded[col] = df_encoded[col].map(mapping)

    # -------------------------------
    # 3. Ordinal encoding (parents' education)
    # -------------------------------
    edu_map = {
        'Unknown': 0,
        'Primary': 1,
        'High School': 2,
        'College': 3,
        'University': 4
    }

    for col in ['Father_Education', 'Mother_Education']:
        if col in df_encoded.columns:
            df_encoded[col + '_Encoded'] = df_encoded[col].map(edu_map)
            df_encoded.drop(columns=[col], inplace=True)

    # -------------------------------
    # 4. Smoothed target encoding
    # -------------------------------
    def target_encode_smooth(df, col, target, alpha=ALPHA):
        """Blend each category mean with the global mean, weighted by count."""
        global_mean = df[target].mean()
        stats = df.groupby(col)[target].agg(['mean', 'count'])
        smooth = (stats['count'] * stats['mean'] + alpha * global_mean) / (stats['count'] + alpha)
        # Categories absent from the groupby (e.g. NaN keys) get the global mean.
        return df[col].map(smooth).fillna(global_mean)

    # -------------------------------
    # 5. Health issue: flag, severity, target encoding
    # -------------------------------
    if 'Health_Issue' in df_encoded.columns:

        # 5.1 Binary flag: has any health issue
        df_encoded['Health_Issue_Flag'] = np.where(df_encoded['Health_Issue'] == 'No Issue', 0, 1)

        # 5.2 Severity encoding (domain-informed)
        severity_map = {
            'No Issue': 0,
            'Dental Problems': 1,
            'Vision Issues': 1,
            'Hearing Issues': 1,
            'Anemia': 2,
            'Parasitic Infections': 2,
            'Respiratory Issues': 2,
            'Malnutrition': 2,
            'Physical Disability': 3,
            'Chronic Illness': 3
        }

        # Issues not listed in severity_map default to severity 1.
        df_encoded['Health_Issue_Severity'] = (
            df_encoded['Health_Issue']
            .map(severity_map)
            .fillna(1)
            .astype(int)
        )

        # 5.3 Target encoding (impact on outcome)
        if TARGET in df_encoded.columns:
            df_encoded['Health_Issue_Target'] = target_encode_smooth(df_encoded, 'Health_Issue', TARGET)

        # Drop original column
        df_encoded.drop(columns=['Health_Issue'], inplace=True)

    # -------------------------------
    # 6. Region (target encoding)
    # -------------------------------
    if 'Region' in df_encoded.columns and TARGET in df_encoded.columns:
        df_encoded['Region_Encoded'] = target_encode_smooth(df_encoded, 'Region', TARGET)
        df_encoded.drop(columns=['Region'], inplace=True)

    # -------------------------------
    # 7. School type (frequency + target encoding)
    # -------------------------------
    if 'School_Type' in df_encoded.columns:
        freq_map = df_encoded['School_Type'].value_counts(normalize=True).to_dict()
        df_encoded['School_Type_Freq'] = df_encoded['School_Type'].map(freq_map)

        if TARGET in df_encoded.columns:
            df_encoded['School_Type_Target'] = target_encode_smooth(df_encoded, 'School_Type', TARGET)

        df_encoded.drop(columns=['School_Type'], inplace=True)

    # -------------------------------
    # 8. Career interest (target encoding)
    # -------------------------------
    if 'Career_Interest' in df_encoded.columns and TARGET in df_encoded.columns:
        df_encoded['Career_Interest_Encoded'] = target_encode_smooth(df_encoded, 'Career_Interest', TARGET)
        df_encoded.drop(columns=['Career_Interest'], inplace=True)

    # -------------------------------
    # 9. Safe one-hot encoding (low-cardinality)
    # -------------------------------
    remaining_cats = df_encoded.select_dtypes(include=['object', 'category']).columns.tolist()
    safe_ohe_cols = [col for col in remaining_cats if df_encoded[col].nunique() <= MAX_UNIQUE_OHE]

    if safe_ohe_cols:
        # dtype=int keeps the dummies numeric so downstream numeric-only
        # feature selection does not drop them.
        df_encoded = pd.get_dummies(df_encoded, columns=safe_ohe_cols, drop_first=True, dtype=int)

    # -------------------------------
    # 10. Date_of_Birth -> Age (whole 365-day years before CURRENT_DATE)
    # -------------------------------
    if 'Date_of_Birth' in df_encoded.columns:
        df_encoded['Date_of_Birth'] = pd.to_datetime(df_encoded['Date_of_Birth'], errors='coerce')
        df_encoded['Age'] = ((CURRENT_DATE - df_encoded['Date_of_Birth']).dt.days // 365).astype(float)
        df_encoded.drop(columns=['Date_of_Birth'], inplace=True)

    # -------------------------------
    # 11. Drop any raw categorical columns still present
    # -------------------------------
    drop_cols = [
        'Father_Education', 'Mother_Education','Career_Interest',
        'Health_Issue', 'Region','Date_of_Birth',
        'School_ID', 'School_Type','Health_Issue_Binary'
    ]
    df_encoded.drop(columns=[c for c in drop_cols if c in df_encoded.columns], inplace=True)

    return df_encoded

# =================================================================
# MACHINE LEARNING MODELS
# =================================================================

def train_regression_models(df):
    """Train regression models to predict ``Overall_Average``.

    Fits GradientBoosting, RandomForest and XGBoost regressors on an 80/20
    split of the numeric columns of *df* (standard-scaled) and picks the model
    with the highest test R².

    The derived label ``Risk_NotRisk`` is excluded from the features: it is
    computed directly from the target by ``train_risk_classification`` (which
    adds it to the frame in place), so using it would leak the target.

    Args:
        df: encoded frame containing the ``Overall_Average`` column.

    Returns:
        tuple: ``(results, feature_importances, best_model_name, best_model,
        scaler, feature_columns, X_train, X_test)``. ``results`` maps model
        name to ``mae`` / ``rmse`` / ``r2`` / ``y_test`` / ``y_pred``.
        ``X_train`` and ``X_test`` are the unscaled splits. When the target is
        missing or no numeric feature remains, the tuple is
        ``({}, {}, "No Model", None, None, None, None, None)``.
    """
    TARGET = 'Overall_Average'

    if TARGET not in df.columns:
        return {}, {}, "No Model", None, None, None, None, None

    # Exclude the target and the label derived from it.
    X = df.drop(columns=[TARGET, 'Risk_NotRisk'], errors='ignore')
    y = df[TARGET]

    # Remove any remaining non-numeric columns
    X = X.select_dtypes(include=[np.number])

    if X.shape[1] == 0:
        return {}, {}, "No Model", None, None, None, None, None

    # Split Data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Scale features (scaler is fitted on the training split only)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Define Models
    models = {
        "GradientBoosting": GradientBoostingRegressor(
            n_estimators=200,
            learning_rate=0.05,
            max_depth=5,
            random_state=42
        ),
        "RandomForest": RandomForestRegressor(
            n_estimators=200,
            max_depth=10,
            random_state=42,
            n_jobs=-1
        ),
        "XGBoost": XGBRegressor(
            n_estimators=700,
            max_depth=5,
            learning_rate=0.05,
            subsample=0.8,
            colsample_bytree=0.8,
            reg_alpha=1.0,
            reg_lambda=2.0,
            objective="reg:squarederror",
            random_state=42,
            n_jobs=-1
        )
    }

    # Train, Predict, Evaluate
    results = {}
    feature_importances = {}
    trained_models = {}

    for name, model in models.items():
        model.fit(X_train_scaled, y_train)
        y_pred = model.predict(X_test_scaled)

        mae = mean_absolute_error(y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        r2 = r2_score(y_test, y_pred)

        results[name] = {"mae": mae, "rmse": rmse, "r2": r2, "y_test": y_test, "y_pred": y_pred}
        trained_models[name] = model

        # Feature importance, keyed by the original column names
        if hasattr(model, 'feature_importances_'):
            importances = model.feature_importances_
            feature_importance = pd.Series(importances, index=X.columns).sort_values(ascending=False)
            feature_importances[name] = feature_importance

    # Find best model (highest test R²)
    if results:
        best_model_name = max(results, key=lambda x: results[x]['r2'])
        best_model = trained_models[best_model_name]
        print(f"Best model: {best_model_name} with R² = {results[best_model_name]['r2']:.3f}")
    else:
        best_model_name = "No Model"
        best_model = None

    return results, feature_importances, best_model_name, best_model, scaler, X.columns, X_train, X_test

def _empty_risk_result():
    """Return the placeholder result used when no classifier can be trained."""
    return {
        'f1': 0,
        'roc_auc': 0,
        'cm': np.array([[0, 0], [0, 0]]),
        'y_test': np.array([]),
        'y_probs': np.array([]),
        'model': None,
        'feature_importance': None,
        'scaler': None,
        'feature_names': None,
        'X_train': None,
        'X_test': None
    }


def train_risk_classification(df, risk_threshold=50, decision_threshold=0.50):
    """Train a gradient-boosting classifier for Risk / NotRisk prediction.

    A student is labelled at risk (1) when ``Overall_Average`` is below
    *risk_threshold*. The label is written to ``df['Risk_NotRisk']`` IN PLACE,
    so the caller's frame gains that column. Numeric features are
    standard-scaled, the training split is re-balanced with SMOTE, and the
    model is evaluated on a stratified 20% hold-out.

    Args:
        df: encoded frame containing ``Overall_Average``; modified in place.
        risk_threshold: score below which a student counts as at risk
            (default 50, the previously hard-coded value).
        decision_threshold: predicted probability at or above which a student
            is classified as at risk (default 0.50, as before).

    Returns:
        dict: keys ``f1``, ``roc_auc``, ``cm``, ``y_test``, ``y_probs``,
        ``model``, ``feature_importance``, ``scaler``, ``feature_names``,
        ``X_train``, ``X_test`` (the last two unscaled). A placeholder with
        zeros / ``None`` is returned when the score column is missing or no
        numeric feature is available.
    """
    score_col = 'Overall_Average'

    if score_col not in df.columns:
        return _empty_risk_result()

    # Create Risk/NotRisk target (Risk = 1 if score < risk_threshold)
    df['Risk_NotRisk'] = (df[score_col] < risk_threshold).astype(int)

    # Prepare features: everything except the label and the score it came from
    X = df.drop(['Risk_NotRisk', score_col], axis=1, errors='ignore')
    y = df['Risk_NotRisk']

    # Select only numeric columns
    X = X.select_dtypes(include=[np.number])

    if X.shape[1] == 0:
        return _empty_risk_result()

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    # Scale features (scaler is fitted on the training split only)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Handle class imbalance with SMOTE (training split only)
    smote = SMOTE(random_state=42)
    X_train_res, y_train_res = smote.fit_resample(X_train_scaled, y_train)

    # Base model (fast for initial evaluation)
    gb_base = GradientBoostingClassifier(
        n_estimators=100,
        learning_rate=0.1,
        max_depth=5,
        min_samples_split=50,
        min_samples_leaf=20,
        subsample=0.8,
        random_state=42,
        verbose=0
    )

    gb_base.fit(X_train_res, y_train_res)

    # Predict
    y_probs = gb_base.predict_proba(X_test_scaled)[:, 1]
    y_pred = (y_probs >= decision_threshold).astype(int)

    # Calculate metrics
    f1 = f1_score(y_test, y_pred, pos_label=1)
    roc_auc = roc_auc_score(y_test, y_probs)
    cm = confusion_matrix(y_test, y_pred)

    # Feature importance, keyed by the original column names
    feature_importance = pd.Series(gb_base.feature_importances_, index=X.columns).sort_values(ascending=False)

    print(f"Risk Classification Results: F1-Score: {f1:.3f}, ROC-AUC: {roc_auc:.3f}")

    return {
        'f1': f1,
        'roc_auc': roc_auc,
        'cm': cm,
        'y_test': y_test,
        'y_probs': y_probs,
        'model': gb_base,
        'feature_importance': feature_importance,
        'scaler': scaler,
        'feature_names': X.columns.tolist(),
        'X_train': X_train,
        'X_test': X_test
    }

def perform_clustering():
    """Return the student-clustering summary used by the dashboard.

    No clustering is run here: every figure is a fixed value taken from a
    previously provided output and only packaged into pandas objects.

    Returns:
        dict with keys ``silhouette_score`` (float), ``cluster_profile``
        (DataFrame, one row per cluster), ``cluster_sizes`` (Series),
        ``regional_risk`` (Series indexed by region) and
        ``regional_cluster_distribution`` (DataFrame, one row per region).
    """
    # Number of students per cluster
    sizes = pd.Series({'Low': 39380, 'Medium': 38933, 'High': 21687})

    # Per-cluster profile: one row per cluster, one column per metric
    metric_names = [
        'Total_National_Exam_Score',
        'Overall_Average',
        'Overall_Engagement_Score',
        'Overall_Avg_Attendance',
        'Overall_Avg_Homework',
        'Overall_Avg_Participation',
        'Overall_Textbook_Access_Composite',
        'School_Resources_Score',
        'Teacher_Student_Ratio',
        'Student_to_Resources_Ratio',
        'Parental_Involvement',
    ]
    profile_rows = [
        ('High', [334.484428, 62.559605, 73.043863, 86.733333, 62.319139, 65.515959,
                  0.630930, 0.695637, 34.502018, 15.891499, 0.365762]),
        ('Medium', [331.849464, 54.283330, 78.301930, 87.396578, 73.066655, 71.411008,
                    0.375552, 0.445902, 50.049940, 22.578901, 0.484578]),
        ('Low', [286.685256, 47.309911, 68.026416, 85.879107, 52.551327, 59.697917,
                 0.361508, 0.424432, 49.957881, 22.629008, 0.301593]),
    ]
    profile = pd.DataFrame(
        [values for _, values in profile_rows],
        index=[label for label, _ in profile_rows],
        columns=metric_names,
    )

    # Regional risk figures, listed from highest to lowest
    risk_by_region = pd.Series(dict([
        ('Somali', 47.398699),
        ('Benishangul-Gumuz', 45.542895),
        ('Afar', 45.271891),
        ('Tigray', 44.758569),
        ('Sidama', 43.237808),
        ('Gambela', 42.241869),
        ('SNNP', 40.569923),
        ('Oromia', 39.208222),
        ('Amhara', 39.180777),
        ('South West Ethiopia', 39.175258),
        ('Dire Dawa', 31.365403),
        ('Harari', 28.723770),
        ('Addis Ababa', 21.323982),
    ]))

    # Share of each cluster per region (used for the heatmap)
    distribution_rows = [
        # region,               Low,  Medium, High
        ('Addis Ababa',         0.21, 0.60, 0.19),
        ('Afar',                0.45, 0.33, 0.22),
        ('Amhara',              0.39, 0.38, 0.23),
        ('Benishangul-Gumuz',   0.46, 0.32, 0.23),
        ('Dire Dawa',           0.31, 0.48, 0.21),
        ('Gambela',             0.42, 0.36, 0.22),
        ('Harari',              0.29, 0.50, 0.21),
        ('Oromia',              0.39, 0.39, 0.22),
        ('SNNP',                0.41, 0.37, 0.22),
        ('Sidama',              0.43, 0.34, 0.23),
        ('Somali',              0.47, 0.30, 0.23),
        ('South West Ethiopia', 0.39, 0.39, 0.22),
        ('Tigray',              0.45, 0.34, 0.22),
    ]
    distribution = pd.DataFrame(
        [row[1:] for row in distribution_rows],
        index=[row[0] for row in distribution_rows],
        columns=['Low', 'Medium', 'High'],
    )

    print("Clustering analysis completed")

    summary = {'silhouette_score': 0.1742}
    summary['cluster_profile'] = profile
    summary['cluster_sizes'] = sizes
    summary['regional_risk'] = risk_by_region
    summary['regional_cluster_distribution'] = distribution
    return summary

# =================================================================
# GLOBAL VARIABLES INITIALIZATION
# =================================================================

# Module-level placeholders: all start as None and are filled in by the
# training section further down.

# Analysis results
regression_models = classification_model = clustering_analysis = None
prediction_result = None

# Best regression model and the objects needed to reuse it
best_reg_model = reg_model = reg_scaler = reg_features = None

# Train / test feature splits kept for later use
reg_X_train = reg_X_test = None
class_X_train = class_X_test = None

# Pre-computed SHAP data to speed up dashboard
shap_data_precomputed = None

# =================================================================
# LOAD AND PROCESS DATA - WITH ERROR HANDLING
# =================================================================

print("=" * 60)
print("ETHIOPIAN STUDENT PERFORMANCE ANALYTICS")
print("=" * 60)

print("\nLoading and preprocessing data...")
df_raw = load_and_preprocess_data()
# Report the pre-encoding shape before encoding actually runs.
print(f"Starting encoding with shape: {df_raw.shape}")
df_encoded = encode_categorical_features(df_raw)
df_clean = df_encoded.copy()

print(f"Final dataset shape: {df_clean.shape}")
print(f"Final dataset columns ({len(df_clean.columns)}): {df_clean.columns.tolist()}")

print("\n" + "=" * 60)
print("TRAINING MODELS")
print("=" * 60)

# Train models with error handling. Each stage falls back to an empty result
# so the dashboard can still start when a stage fails.
try:
    print("\nTraining regression models...")
    regression_results, feature_importances, best_reg_model, reg_model, reg_scaler, reg_features, reg_X_train, reg_X_test = train_regression_models(df_clean)

    # Store in global variables
    regression_models = regression_results
except Exception as e:
    print(f"Error training regression models: {str(e)}")
    # Every name assigned in the try block needs a fallback here; without
    # `feature_importances` the assignment to `feature_importances_global`
    # below would raise NameError.
    regression_results = {}
    feature_importances = {}
    regression_models = {}
    best_reg_model = "No Model"
    reg_model = None
    reg_scaler = None
    reg_features = []
    reg_X_train = None
    reg_X_test = None

try:
    print("\nTraining risk classification model...")
    classification_results = train_risk_classification(df_clean)
    classification_model = classification_results
    class_X_train = classification_results['X_train']
    class_X_test = classification_results['X_test']

    # Pre-compute SHAP data for faster dashboard loading
    if classification_model['model'] is not None and class_X_test is not None:
        try:
            # Use a very small sample for SHAP to make it fast
            X_sample = class_X_test.sample(min(100, len(class_X_test)), random_state=42)
            X_sample_scaled = classification_model['scaler'].transform(X_sample)

            # Create SHAP explainer
            explainer = shap.TreeExplainer(classification_model['model'])

            # Calculate SHAP values
            shap_values = explainer.shap_values(X_sample_scaled)

            # Handle binary classification output
            if isinstance(shap_values, list):
                shap_values = shap_values[1]  # Take positive class

            # Store pre-computed SHAP data
            shap_data_precomputed = {
                'X_sample': X_sample,
                'shap_values': shap_values,
                'feature_names': X_sample.columns.tolist(),
                'explainer': explainer
            }
        except Exception as shap_error:
            # SHAP is optional: continue without it, but say why it is missing.
            print(f"SHAP pre-computation skipped: {shap_error}")
            shap_data_precomputed = None

except Exception as e:
    print(f"Error training classification model: {str(e)}")
    classification_model = {
        'f1': 0,
        'roc_auc': 0,
        'cm': np.array([[0, 0], [0, 0]]),
        'y_test': np.array([]),
        'y_probs': np.array([]),
        'model': None,
        'feature_importance': None,
        'scaler': None,
        'feature_names': [],
        'X_train': None,
        'X_test': None
    }
    classification_results = classification_model
    class_X_train = None
    class_X_test = None

try:
    print("\nPerforming clustering analysis...")
    clustering_results = perform_clustering()
    clustering_analysis = clustering_results
except Exception as e:
    print(f"Error in clustering analysis: {str(e)}")
    # Create a default clustering analysis structure
    clustering_analysis = {
        'silhouette_score': 0.0,
        'cluster_profile': pd.DataFrame(),
        # Explicit dtype avoids the empty-Series dtype warning on older pandas.
        'cluster_sizes': pd.Series(dtype=float),
        'regional_risk': pd.Series(dtype=float),
        'regional_cluster_distribution': pd.DataFrame()
    }
    clustering_results = clustering_analysis

# Store feature importances globally
feature_importances_global = feature_importances

print("\n" + "=" * 60)
print("DASHBOARD INITIALIZATION COMPLETE")
print("=" * 60)

# =================================================================
# PREDICTION PROCESSING FUNCTIONS - FIXED VERSION
# =================================================================

# Category -> smoothed target mean lookups, filled once at import time and
# reused when encoding prediction inputs.
target_encoders = {}

def prepare_target_encoders(df_clean):
    """Fill ``target_encoders`` with smoothed target means per category.

    For each of Region, School_Type, Career_Interest and Health_Issue that is
    present in the module-level ``df_raw``, a ``{category: smoothed mean of
    Overall_Average}`` dict is stored under that column name. Nothing is
    stored when *df_clean* has no ``Overall_Average`` column. The statistics
    themselves are computed from ``df_raw``; *df_clean* is only used for that
    presence check.
    """
    target_col = 'Overall_Average'
    smoothing = 10

    def smoothed_category_means(frame, column):
        # Blend each category mean with the overall mean, weighted by count.
        overall_mean = frame[target_col].mean()
        per_category = frame.groupby(column)[target_col].agg(['mean', 'count'])
        weighted_sum = per_category['count'] * per_category['mean'] + smoothing * overall_mean
        blended = weighted_sum / (per_category['count'] + smoothing)
        return blended.to_dict()

    for column in ('Region', 'School_Type', 'Career_Interest', 'Health_Issue'):
        if column not in df_raw.columns:
            continue
        if target_col not in df_clean.columns:
            continue
        target_encoders[column] = smoothed_category_means(df_raw, column)

# Prepare target encoders
prepare_target_encoders(df_clean)

def _is_blank_prediction_value(value):
    """Return True when *value* is None or a scalar NaN/NaT."""
    if value is None:
        return True
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        # pd.isna() returns an array for list-likes; those count as provided.
        return False


def _default_prediction_feature(col, provided):
    """Return the fallback value for model feature *col* when no usable value was supplied.

    *provided* is the raw input with blank entries removed; it is consulted
    for the attendance/homework/participation figures.
    """
    if col in ('Gender', 'Home_Internet_Access', 'Electricity_Access',
               'School_Location', 'Field_Choice', 'Health_Issue_Flag',
               'Father_Education_Encoded', 'Mother_Education_Encoded',
               'Health_Issue_Severity'):
        return 0
    if col in ('Parental_Involvement', 'Region_Encoded', 'School_Resources_Score',
               'School_Academic_Score',
               'School_Type_Freq', 'School_Type_Target', 'Overall_Textbook_Access_Composite',
               'Career_Interest_Encoded', 'Health_Issue_Target'):
        return 0.5
    if col == 'Overall_Engagement_Score':
        # Weighted blend of attendance (40%), homework (30%) and participation (30%)
        attendance = provided.get('Overall_Avg_Attendance', 75)
        homework = provided.get('Overall_Avg_Homework', 65)
        participation = provided.get('Overall_Avg_Participation', 70)
        return attendance * 0.4 + homework * 0.3 + participation * 0.3
    if col in ('Overall_Avg_Attendance', 'Overall_Avg_Homework',
               'Overall_Avg_Participation'):
        return provided.get(col, 50)
    if col == 'Teacher_Student_Ratio':
        return 40.0
    if col == 'Student_to_Resources_Ratio':
        return 20.0
    if col == 'Age':
        return 15.0
    if any(token in col for token in ('Region_', 'School_Type_', 'Health_Issue_')):
        return 0
    if col in df_clean.columns:
        return df_clean[col].median()
    return 0


def process_raw_input_for_prediction(raw_input):
    """Turn one raw form submission into a single-row model feature frame.

    Steps:
        1. Convert the raw input to a one-row DataFrame.
        2. Apply the binary, ordinal, frequency and target encodings below.
        3. Derive Age from Date_of_Birth.
        4. Align the columns with ``reg_features`` (same names, same order),
           filling anything without a usable value from fixed defaults.

    Blank entries (None/NaN), categories with no mapping and unparseable
    dates are treated exactly like absent fields: the feature receives its
    default instead of a NaN.

    Args:
        raw_input: Mapping of raw field name to value.

    Returns:
        A one-row float DataFrame whose columns are ``reg_features``, or
        None if processing fails (the error is printed).
    """

    try:
        # Step 1: Convert raw input to DataFrame, ignoring blank fields so
        # they behave like fields that were never sent.
        provided = {
            key: value for key, value in raw_input.items()
            if not _is_blank_prediction_value(value)
        }
        encoded_df = pd.DataFrame([provided])

        # Step 2: Apply the encoding rules
        # Binary encoding
        binary_maps = {
            'Gender': {'Male': 0, 'Female': 1},
            'Home_Internet_Access': {'No': 0, 'Yes': 1},
            'Electricity_Access': {'No': 0, 'Yes': 1},
            'School_Location': {'Rural': 0, 'Urban': 1},
            'Field_Choice': {'Social': 0, 'Natural': 1}
        }
        for col, mapping in binary_maps.items():
            if col in encoded_df.columns:
                encoded_df[col] = encoded_df[col].map(mapping)

        # Ordinal encoding for parent education
        edu_map = {
            'Unknown': 0,
            'Primary': 1,
            'High School': 2,
            'College': 3,
            'University': 4
        }
        for col in ['Father_Education', 'Mother_Education']:
            if col in encoded_df.columns:
                encoded_df[col + '_Encoded'] = encoded_df[col].map(edu_map)

        # Health issue: binary flag plus ordinal severity
        if 'Health_Issue' in encoded_df.columns:
            encoded_df['Health_Issue_Flag'] = np.where(encoded_df['Health_Issue'] == 'No Issue', 0, 1)

            severity_map = {
                'No Issue': 0,
                'Dental Problems': 1,
                'Vision Issues': 1,
                'Hearing Issues': 1,
                'Anemia': 2,
                'Parasitic Infections': 2,
                'Respiratory Issues': 2,
                'Malnutrition': 2,
                'Physical Disability': 3,
                'Chronic Illness': 3
            }
            # Unlisted issues are given severity 1
            encoded_df['Health_Issue_Severity'] = (
                encoded_df['Health_Issue']
                .map(severity_map)
                .fillna(1)
                .astype(int)
            )

        # School type frequency encoding (share of rows in df_raw)
        if 'School_Type' in encoded_df.columns and 'School_Type' in df_raw.columns:
            freq_map = df_raw['School_Type'].value_counts(normalize=True).to_dict()
            encoded_df['School_Type_Freq'] = encoded_df['School_Type'].map(freq_map).fillna(0)

        # Target encodings; categories missing from the encoder fall back to
        # the global target mean. The mean is only needed when encoders exist.
        fallback_mean = df_clean['Overall_Average'].mean() if target_encoders else None
        target_encoded_columns = {
            'Health_Issue': 'Health_Issue_Target',
            'Region': 'Region_Encoded',
            'School_Type': 'School_Type_Target',
            'Career_Interest': 'Career_Interest_Encoded'
        }
        for source_col, encoded_col in target_encoded_columns.items():
            if source_col in encoded_df.columns and source_col in target_encoders:
                encoded_df[encoded_col] = (
                    encoded_df[source_col].map(target_encoders[source_col]).fillna(fallback_mean)
                )

        # Step 3: Date -> Age (whole years relative to a fixed reference date)
        if 'Date_of_Birth' in encoded_df.columns:
            CURRENT_DATE = pd.Timestamp('2026-01-30')
            encoded_df['Date_of_Birth'] = pd.to_datetime(encoded_df['Date_of_Birth'], errors='coerce')
            encoded_df['Age'] = ((CURRENT_DATE - encoded_df['Date_of_Birth']).dt.days // 365).astype(float)

        # Step 4: ALIGN COLUMNS (CRITICAL) - the model expects exactly
        # reg_features, in order.
        expected_columns = reg_features.tolist() if hasattr(reg_features, 'tolist') else list(reg_features)

        aligned_values = {}
        for col in expected_columns:
            value = encoded_df[col].iloc[0] if col in encoded_df.columns else None
            if _is_blank_prediction_value(value):
                # Absent, unmapped or unparseable -> documented default
                value = _default_prediction_feature(col, provided)
            aligned_values[col] = value

        # Ensure numeric types
        return pd.DataFrame(aligned_values, index=[0]).astype(float)

    except Exception as e:
        print(f"Error processing prediction input: {str(e)}")
        return None

def _prediction_field(input_data, key, default):
    """Return input_data[key], or *default* when the key is absent, None or NaN."""
    value = input_data.get(key, default)
    if value is None or (isinstance(value, float) and value != value):
        return default
    return value


def _triggered_risk_factors(input_data):
    """Return ``(label, recommendation)`` pairs for every risk rule the input trips.

    ``recommendation`` is None for rules that have no targeted advice. The
    order of the rules is the order used in the output lists.
    """
    def field(key, default):
        return _prediction_field(input_data, key, default)

    low_education = ('Unknown', 'Primary')
    checks = [
        (field('School_Resources_Score', 0.5) < 0.4,
         "Low School Resources Score", "• Request additional learning materials"),
        (field('Overall_Textbook_Access_Composite', 0.5) < 0.4,
         "Poor Textbook Access", "• Provide access to digital textbooks"),
        (field('Parental_Involvement', 0.5) < 0.3,
         "Low Parental Involvement", "• Organize parent engagement workshop"),
        (field('Teacher_Student_Ratio', 40) > 45,
         "High Teacher-Student Ratio", "• Advocate for reduced class size"),
        (field('Home_Internet_Access', 'No') == 'No',
         "No Internet Access at Home", "• Provide internet access support"),
        (field('Electricity_Access', 'No') == 'No',
         "No Electricity Access", "• Provide Electricity access support"),
        (field('Father_Education', 'Unknown') in low_education,
         "Low Father Education Level", None),
        (field('Mother_Education', 'Unknown') in low_education,
         "Low Mother Education Level", None),
        (field('School_Location', 'Rural') == 'Rural',
         "Rural School Location", None),
        (field('Health_Issue', 'No Issue') != 'No Issue',
         "Health Issues Present", "• Arrange health support services"),
    ]
    return [(label, advice) for triggered, label, advice in checks if triggered]


def make_prediction_corrected(input_data):
    """Predict the overall average and the at-risk flag for one student.

    The raw input is run through ``process_raw_input_for_prediction`` and then
    scored by the regression model and the risk classifier. If the regression
    model or its scaler is missing, the mean of ``df_clean['Overall_Average']``
    is used as the score. If the classifier or its scaler is missing, the risk
    probability is 0.5 and the flag is ``predicted_score < 50``.

    Blank input fields (None/NaN) are treated like absent fields when the
    rule-based risk factors are evaluated.

    Args:
        input_data: Mapping of raw field name to value.

    Returns:
        A dict with the predicted score, risk probability and flag, risk
        causes, recommendations and model metrics, also stored in the
        module-level ``prediction_result``. None is returned if the input
        could not be processed or any step raised (the error is printed);
        ``prediction_result`` is left untouched in that case.
    """
    global prediction_result

    try:
        # Process raw input through all steps
        processed_df = process_raw_input_for_prediction(input_data)

        if processed_df is None or processed_df.empty:
            return None

        # REGRESSION PREDICTION
        if reg_model is not None and reg_scaler is not None:
            X_reg = processed_df.copy()

            # Make sure columns match training exactly: add, drop, reorder
            missing_cols = [col for col in reg_features if col not in X_reg.columns]
            extra_cols = [col for col in X_reg.columns if col not in reg_features]

            for col in missing_cols:
                X_reg[col] = df_clean[col].median() if col in df_clean.columns else 0

            if extra_cols:
                X_reg = X_reg.drop(columns=extra_cols)

            X_reg = X_reg[reg_features]

            X_reg_scaled = reg_scaler.transform(X_reg)
            predicted_score = reg_model.predict(X_reg_scaled)[0]

            # Metrics of the selected regression model, zeros if unknown
            reg_metrics = regression_models[best_reg_model] if best_reg_model in regression_models else {"mae": 0, "rmse": 0, "r2": 0}
        else:
            predicted_score = df_clean['Overall_Average'].mean()
            reg_metrics = {"mae": 0, "rmse": 0, "r2": 0}

        # CLASSIFICATION PREDICTION
        if classification_model['model'] is not None and classification_model['scaler'] is not None:
            classification_features = classification_model['feature_names']
            X_class = processed_df.copy()

            # Fill classifier-only features that the regression frame lacks
            for col in classification_features:
                if col not in X_class.columns:
                    X_class[col] = df_clean[col].median() if col in df_clean.columns else 0

            # Keep only classification features, in training order
            X_class = X_class[classification_features]

            X_class_scaled = classification_model['scaler'].transform(X_class)

            # Probability of the positive class (column 1)
            risk_prob = classification_model['model'].predict_proba(X_class_scaled)[0][1]
            is_risk = risk_prob >= 0.5
        else:
            risk_prob = 0.5
            is_risk = predicted_score < 50

        # Rule-based risk factors derived from the raw input values
        triggered = _triggered_risk_factors(input_data)
        risk_factors = [label for label, _ in triggered]

        if is_risk:
            risk_causes = risk_factors.copy() or [
                "Multiple academic and environmental factors contributing to risk"
            ]
        elif risk_factors:
            risk_causes = [f"Potential area for improvement: {factor}" for factor in risk_factors]
        else:
            risk_causes = ["All indicators are in favorable ranges"]

        # Generate recommendations
        if is_risk:
            prediction_recommendations = [
                "🔴 Immediate Intervention Required",
                "• Schedule academic counseling session",
                "• Implement personalized learning plan",
                "• Increase parent-teacher communication",
            ]
            # Targeted advice for each tripped rule that has one
            prediction_recommendations.extend(
                advice for _, advice in triggered if advice is not None
            )
        else:
            prediction_recommendations = [
                "✅ Student is Performing Well",
                f"• Predicted Overall Average: {predicted_score:.1f}",
                f"• Risk Probability: {risk_prob*100:.1f}% (Low)",
                "• Maintain current study habits",
                "• Encourage participation in extracurricular activities",
            ]
            if risk_factors:
                prediction_recommendations.append("• Areas for continued improvement:")
                prediction_recommendations.extend(
                    f"  - Address {factor.lower()}" for factor in risk_factors
                )
            else:
                prediction_recommendations.append("• All performance indicators are positive")

        # Create prediction result
        prediction_result = {
            'predicted_score': float(predicted_score),
            'risk_probability': float(risk_prob),
            'is_risk': bool(is_risk),
            'risk_causes': risk_causes,
            'recommendations': prediction_recommendations,
            'regression_metrics': {
                'model': best_reg_model,
                'r2': reg_metrics.get('r2', 0),
                'mae': reg_metrics.get('mae', 0),
                'rmse': reg_metrics.get('rmse', 0)
            },
            'classification_metrics': {
                'model': 'Gradient Boosting',
                'f1': classification_model.get('f1', 0),
                'roc_auc': classification_model.get('roc_auc', 0)
            },
            'input_processed': True,
            'processing_steps': {
                'step1': 'Raw input converted to DataFrame',
                'step2': 'Applied same encoding rules as training',
                'step3': 'Columns aligned with training data',
                'step4': 'Date converted to Age',
                'step5': 'Predictions made successfully'
            }
        }

        return prediction_result

    except Exception as e:
        print(f"Error making prediction: {str(e)}")
        return None

# =================================================================
# PLOTTING FUNCTIONS
# =================================================================

def create_datatype_bar_plot():
    """Return a bar chart counting the columns of ``df_original`` per dtype."""
    dtype_counts = df_original.dtypes.value_counts()

    bars = go.Bar(
        x=dtype_counts.index.astype(str),
        y=dtype_counts.values,
        marker_color=COLOR_SCHEME['primary'],
        text=dtype_counts.values,
        textposition='auto'
    )

    layout = {
        'title': "Data Type Distribution",
        'xaxis_title': "Data Type",
        'yaxis_title': "Count",
        'plot_bgcolor': COLOR_SCHEME['background'],
        'paper_bgcolor': COLOR_SCHEME['background'],
        'font': dict(color=COLOR_SCHEME['text']),
    }

    fig = go.Figure(data=[bars])
    fig.update_layout(**layout)
    return fig

def create_feature_category_plot():
    """Return a bar chart of how many listed features per category exist in ``df_clean``."""
    feature_categories = {
        'Student Factors': [
            'Gender', 'Parental_Involvement', 'Home_Internet_Access',
            'Electricity_Access', 'Father_Education_Encoded',
            'Mother_Education_Encoded', 'Age', 'Field_Choice'
        ],
        'Academic Factors': [
            'Overall_Average', 'Total_National_Exam_Score',
            'Overall_Avg_Attendance', 'Overall_Avg_Homework',
            'Overall_Avg_Participation', 'Overall_Engagement_Score',
            'Overall_Textbook_Access_Composite'
        ],
        'School Factors': [
            'School_Location', 'Teacher_Student_Ratio', 'School_Resources_Score',
            'School_Academic_Score', 'Student_to_Resources_Ratio',
            'School_Type_Freq', 'School_Type_Target'
        ],
        'Regional Factors': [
            'Region_Encoded'
        ],
        'Health Factors': [
            'Health_Issue_Flag', 'Health_Issue_Severity', 'Health_Issue_Target'
        ],
        'Other': ['Career_Interest_Encoded']
    }

    # One colour per category, in declaration order
    palette_keys = ('primary', 'secondary', 'success', 'warning', 'danger', 'dark')
    palette = [COLOR_SCHEME[key] for key in palette_keys]

    # Count only the features that are actually present in the cleaned data
    category_counts = {
        category: sum(1 for feature in features if feature in df_clean.columns)
        for category, features in feature_categories.items()
    }
    labels = list(category_counts.keys())
    counts = list(category_counts.values())

    fig = go.Figure(data=[
        go.Bar(
            x=labels,
            y=counts,
            marker_color=palette[:len(category_counts)],
            text=counts,
            textposition='auto'
        )
    ])
    fig.update_layout(
        title="Feature Distribution by Category",
        xaxis_title="Feature Category",
        yaxis_title="Number of Features",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text']),
        xaxis_tickangle=-45
    )
    return fig

def create_correlation_heatmap():
    """Return a correlation heatmap of the numeric ``df_clean`` columns.

    With more than 15 numeric columns, only the 15 with the largest absolute
    correlation to ``Overall_Average`` are shown. ``Risk_NotRisk`` is left out.
    """
    if 'Overall_Average' not in df_clean.columns:
        placeholder = go.Figure()
        placeholder.update_layout(title="Overall Average data not available")
        return placeholder

    numeric_cols = df_clean.select_dtypes(include=['float64', 'int64']).columns.tolist()
    if 'Risk_NotRisk' in numeric_cols:
        numeric_cols.remove('Risk_NotRisk')

    selected = numeric_cols
    if len(numeric_cols) > 15:
        # Rank by strength of association with the target, sign ignored
        strength = df_clean[numeric_cols].corr()['Overall_Average'].abs()
        selected = strength.sort_values(ascending=False).head(15).index.tolist()
    corr_data = df_clean[selected].corr()

    heatmap = go.Heatmap(
        z=corr_data.values,
        x=corr_data.columns,
        y=corr_data.index,
        colorscale='RdBu',
        zmid=0,
        text=np.round(corr_data.values, 2),
        texttemplate='%{text}',
        textfont={"size": 10}
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title="Feature Correlation Heatmap (Top 15 Features)",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text']),
        xaxis_tickangle=-45,
        height=600
    )
    return fig

def create_regression_comparison_plot():
    """Return a bar chart comparing the R² score of every regression model.

    The bar of ``best_reg_model`` is drawn in the secondary colour. When
    ``regression_models`` is empty, an empty titled figure is returned.
    """
    if not regression_models:
        fig = go.Figure()
        fig.update_layout(title="No regression models available")
        return fig

    models = list(regression_models.keys())
    r2_scores = [regression_models[model]['r2'] for model in models]

    # Highlight the selected model
    colors = [
        COLOR_SCHEME['secondary'] if model == best_reg_model else COLOR_SCHEME['primary']
        for model in models
    ]

    fig = go.Figure(data=[
        go.Bar(x=models, y=r2_scores, marker_color=colors,
               text=[f'{score:.3f}' for score in r2_scores],
               textposition='auto')
    ])
    # Title and axis label previously contained a mis-encoded superscript two
    fig.update_layout(
        title="Regression Model R² Score Comparison",
        xaxis_title="Models",
        yaxis_title="R² Score",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text'])
    )
    return fig

def create_actual_vs_predicted_plot():
    """Return a scatter of actual vs predicted scores for ``best_reg_model``.

    A dashed diagonal spanning the range of the actual values marks perfect
    prediction. An empty titled figure is returned when the model is unknown.
    """
    if best_reg_model not in regression_models:
        placeholder = go.Figure()
        placeholder.update_layout(title="Best regression model not available")
        return placeholder

    results = regression_models[best_reg_model]
    actual = results['y_test']
    lowest, highest = actual.min(), actual.max()

    points = go.Scatter(
        x=actual,
        y=results['y_pred'],
        mode='markers',
        name='Predictions',
        marker=dict(color=COLOR_SCHEME['primary'], size=6, opacity=0.6)
    )
    diagonal = go.Scatter(
        x=[lowest, highest],
        y=[lowest, highest],
        mode='lines',
        name='Perfect Prediction',
        line=dict(dash='dash', color=COLOR_SCHEME['secondary'], width=2)
    )

    fig = go.Figure()
    fig.add_trace(points)
    fig.add_trace(diagonal)
    fig.update_layout(
        title=f"Actual vs Predicted Overall Average - {best_reg_model}",
        xaxis_title="Actual Overall Average",
        yaxis_title="Predicted Overall Average",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text'])
    )
    return fig

def create_feature_importance_plot():
    """Return a horizontal bar chart of the first 10 importances of ``best_reg_model``.

    An empty titled figure is returned when no importances are stored for it.
    """
    if best_reg_model not in feature_importances_global:
        placeholder = go.Figure()
        placeholder.update_layout(title="Feature importance not available for this model")
        return placeholder

    top_features = feature_importances_global[best_reg_model].head(10)
    value_labels = [f'{imp:.3f}' for imp in top_features.values]

    fig = go.Figure(go.Bar(
        x=top_features.values,
        y=top_features.index,
        orientation='h',
        marker_color=COLOR_SCHEME['primary'],
        text=value_labels,
        textposition='auto'
    ))
    fig.update_layout(
        title=f"Top 10 Feature Importance - {best_reg_model}",
        xaxis_title="Importance",
        yaxis_title="Features",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text'])
    )
    return fig

def create_national_exam_model_comparison_plot():
    """Return a horizontal bar chart of R² per National Exam Score model.

    Rows come from ``NATIONAL_EXAM_MODEL_PERFORMANCE`` sorted by ``R2_Score``
    ascending. Models whose name contains "Gradient" use the success colour.
    """
    df_national = NATIONAL_EXAM_MODEL_PERFORMANCE.copy()
    sorted_df = df_national.sort_values('R2_Score', ascending=True)

    fig = go.Figure()

    # Trace name, title and axis label previously contained a mis-encoded
    # superscript two.
    fig.add_trace(go.Bar(
        y=sorted_df['Model'],
        x=sorted_df['R2_Score'],
        orientation='h',
        marker_color=[
            COLOR_SCHEME['success'] if 'Gradient' in model else COLOR_SCHEME['primary']
            for model in sorted_df['Model']
        ],
        text=[f'{score:.4f}' for score in sorted_df['R2_Score']],
        textposition='auto',
        name='R² Score'
    ))

    fig.update_layout(
        title="National Exam Score Model Performance (R² Score)",
        xaxis_title="R² Score",
        yaxis_title="Model",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text']),
        height=400,
        # 15% headroom to the right of the longest bar
        xaxis=dict(range=[0, max(sorted_df['R2_Score']) * 1.15])
    )

    return fig

def create_national_exam_feature_importance_plot():
    """Return a horizontal bar chart of ``NATIONAL_EXAM_FEATURE_IMPORTANCE``.

    Bars are ordered by ``Importance`` ascending and labelled as percentages.
    """
    ranked = NATIONAL_EXAM_FEATURE_IMPORTANCE.copy().sort_values('Importance', ascending=True)
    importances = ranked['Importance']

    bars = go.Bar(
        x=importances,
        y=ranked['Feature'],
        orientation='h',
        marker_color=COLOR_SCHEME['secondary'],
        text=[f'{imp:.1%}' for imp in importances],
        textposition='auto'
    )
    fig = go.Figure(bars)

    # 15% headroom to the right of the longest bar
    axis_limit = max(importances) * 1.15
    fig.update_layout(
        title="Feature Importance - National Exam Score (Gradient Boosting)",
        xaxis_title="Importance Score",
        yaxis_title="Features",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text']),
        height=500,
        xaxis=dict(range=[0, axis_limit])
    )

    return fig

def create_national_exam_performance_table():
    """Return a copy of ``NATIONAL_EXAM_MODEL_PERFORMANCE`` for display."""
    # Copy so that callers cannot modify the module-level table
    performance_table = NATIONAL_EXAM_MODEL_PERFORMANCE.copy()
    return performance_table

def create_confusion_matrix_plot():
    """Return a heatmap of the confusion matrix stored in ``classification_model['cm']``."""
    matrix = classification_model['cm']
    class_labels = ['Not Risk', 'Risk']

    heatmap = go.Heatmap(
        z=matrix,
        x=class_labels,
        y=class_labels,
        hoverongaps=False,
        colorscale='Blues',
        text=matrix,
        texttemplate='%{text}',
        textfont={"size": 16}
    )

    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title="Risk Classification Confusion Matrix",
        xaxis_title="Predicted",
        yaxis_title="Actual",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text'])
    )
    return fig

def create_roc_curve_plot():
    """Return the ROC curve of the risk classifier with a chance-level diagonal.

    An empty titled figure is returned when ``classification_model['y_test']``
    has no entries.
    """
    y_true = classification_model['y_test']
    if len(y_true) == 0:
        placeholder = go.Figure()
        placeholder.update_layout(title="Classification data not available")
        return placeholder

    fpr, tpr, _ = roc_curve(y_true, classification_model['y_probs'])
    roc_auc = classification_model['roc_auc']

    curve = go.Scatter(
        x=fpr, y=tpr,
        mode='lines',
        name=f'ROC Curve (AUC = {roc_auc:.3f})',
        line=dict(color=COLOR_SCHEME['primary'], width=2)
    )
    chance_line = go.Scatter(
        x=[0, 1], y=[0, 1],
        mode='lines',
        name='Random Classifier',
        line=dict(dash='dash', color=COLOR_SCHEME['secondary'], width=1)
    )

    fig = go.Figure()
    for trace in (curve, chance_line):
        fig.add_trace(trace)

    fig.update_layout(
        title="ROC Curve - Risk Classification",
        xaxis_title="False Positive Rate",
        yaxis_title="True Positive Rate",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text']),
        # Lock both axes to the same scale
        yaxis=dict(scaleanchor="x", scaleratio=1),
        xaxis=dict(constrain='domain')
    )
    return fig

def create_cluster_distribution_plot():
    """Return a bar chart of the number of students per performance cluster.

    ``High``, ``Medium`` and ``Low`` are always shown; labels missing from
    ``clustering_analysis['cluster_sizes']`` are added with a count of 0.
    An empty titled figure is returned when ``clustering_analysis`` is None.
    """
    if clustering_analysis is None:
        fig = go.Figure()
        fig.update_layout(title="Clustering data not available")
        return fig

    # Work on a copy: padding the missing labels must not alter the shared
    # clustering results held in clustering_analysis.
    cluster_counts = clustering_analysis['cluster_sizes'].copy()

    colors = {
        'High': COLOR_SCHEME['success'],
        'Medium': COLOR_SCHEME['warning'],
        'Low': COLOR_SCHEME['danger']
    }

    for label in ['High', 'Medium', 'Low']:
        if label not in cluster_counts.index:
            cluster_counts[label] = 0

    fig = go.Figure(data=[
        go.Bar(
            x=cluster_counts.index,
            y=cluster_counts.values,
            # Labels outside High/Medium/Low use the primary colour
            marker_color=[colors.get(label, COLOR_SCHEME['primary']) for label in cluster_counts.index],
            text=cluster_counts.values,
            textposition='auto'
        )
    ])
    fig.update_layout(
        title="Student Performance Cluster Distribution",
        xaxis_title="Performance Level",
        yaxis_title="Number of Students",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text'])
    )
    return fig

def create_regional_risk_plot():
    """Return a bar chart of ``clustering_analysis['regional_risk']`` per region.

    An empty titled figure is returned when the clustering results are None
    or the regional series is empty.
    """
    # Both "no clustering" and "no regional rows" share the same placeholder
    if clustering_analysis is None or clustering_analysis['regional_risk'].empty:
        placeholder = go.Figure()
        placeholder.update_layout(title="Regional data not available")
        return placeholder

    regional_risk = clustering_analysis['regional_risk']
    percent_labels = [f'{val:.1f}%' for val in regional_risk.values]

    bars = go.Bar(
        x=regional_risk.index,
        y=regional_risk.values,
        marker_color=COLOR_SCHEME['danger'],
        text=percent_labels,
        textposition='auto'
    )

    fig = go.Figure(data=[bars])
    fig.update_layout(
        title="Regional Risk Analysis (% Low Performance)",
        xaxis_title="Region",
        yaxis_title="% Low Performance",
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text']),
        xaxis_tickangle=-45
    )
    return fig

def create_score_distribution_plot():
    """Return a 30-bin histogram of ``df_clean['Overall_Average']``.

    An empty titled figure is returned when the column is missing.
    """
    fig = go.Figure()

    if 'Overall_Average' not in df_clean.columns:
        fig.update_layout(title="Overall Average data not available")
        return fig

    histogram = go.Histogram(
        x=df_clean['Overall_Average'],
        nbinsx=30,
        marker_color=COLOR_SCHEME['primary'],
        opacity=0.7
    )
    fig.add_trace(histogram)

    background = COLOR_SCHEME['background']
    fig.update_layout(
        title="Distribution of Overall Average Scores",
        xaxis_title="Overall Average Score",
        yaxis_title="Number of Students",
        plot_bgcolor=background,
        paper_bgcolor=background,
        font=dict(color=COLOR_SCHEME['text'])
    )
    return fig

def create_students_by_region_plot():
    """Plot a horizontal bar chart of student counts per region.

    Counts come from the ``Region`` column of the module-level ``df_raw``
    frame and are sorted ascending before plotting.

    Returns:
        go.Figure: the bar chart, or an empty figure titled
        "Region data not available in raw dataset" when ``df_raw`` has no
        ``Region`` column.
    """
    fig = go.Figure()

    if 'Region' not in df_raw.columns:
        fig.update_layout(title="Region data not available in raw dataset")
        return fig

    per_region = df_raw['Region'].value_counts()
    per_region = per_region.sort_values(ascending=True)

    fig.add_trace(go.Bar(
        y=per_region.index,
        x=per_region.values,
        orientation='h',
        marker_color=COLOR_SCHEME['primary'],
        text=per_region.values,
        textposition='auto',
    ))

    background = COLOR_SCHEME['background']
    fig.update_layout(
        title="Number of Students by Region",
        xaxis_title="Number of Students",
        yaxis_title="Region",
        plot_bgcolor=background,
        paper_bgcolor=background,
        font=dict(color=COLOR_SCHEME['text']),
        height=500,
    )
    return fig

def create_feature_summary_table(max_features=10):
    """Build a statistics-by-feature summary table from ``df_original``.

    Each output row is one statistic (Mean, Std, Min, 25%, Median, 75%,
    Max, Missing) and each output column after ``'Statistic'`` is one of
    the first ``max_features`` numeric columns of ``df_original``. All
    values are pre-formatted strings (four decimals, except the missing
    count which is printed without decimals).

    Args:
        max_features: maximum number of numeric columns to include.
            Defaults to 10, the number of columns the table has always
            displayed.

    Returns:
        pd.DataFrame: eight rows, with a ``'Statistic'`` column followed
        by one column per included feature.
    """
    # Only the columns that end up in the table are summarised; previously
    # twice as many were computed and then thrown away.
    shown_cols = (
        df_original.select_dtypes(include=[np.number]).columns[:max_features]
    )

    # Ordered mapping of row label -> formatter for one column's Series.
    formatters = {
        'Mean': lambda s: f"{s.mean():.4f}",
        'Std': lambda s: f"{s.std():.4f}",
        'Min': lambda s: f"{s.min():.4f}",
        '25%': lambda s: f"{s.quantile(0.25):.4f}",
        'Median': lambda s: f"{s.median():.4f}",
        '75%': lambda s: f"{s.quantile(0.75):.4f}",
        'Max': lambda s: f"{s.max():.4f}",
        'Missing': lambda s: f"{s.isnull().sum():.0f}",
    }

    summary_data = []
    for stat, fmt in formatters.items():
        row = {'Statistic': stat}
        for col in shown_cols:
            row[col] = fmt(df_original[col])
        summary_data.append(row)

    return pd.DataFrame(summary_data)

def create_risk_distribution_plot():
    """Plot the risk / not-risk outcome of the current prediction.

    Reads ``'is_risk'`` and ``'risk_probability'`` from the module-level
    ``prediction_result``. The bar heights are a 1/0 count for the
    predicted category; the bar labels show the probability of each
    category as a percentage.

    Returns:
        go.Figure | None: the bar chart, or ``None`` when no prediction
        result is available.
    """
    if prediction_result is None:
        return None

    flagged = prediction_result['is_risk']
    risk_probability = prediction_result['risk_probability']

    # One student: the predicted category gets a count of 1, the other 0.
    counts = [1, 0] if flagged else [0, 1]
    percentages = [risk_probability * 100, (1 - risk_probability) * 100]

    df_risk = pd.DataFrame({
        'Category': ['Risk', 'Not Risk'],
        'Count': counts,
        'Percentage': percentages,
    })

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=df_risk['Category'],
        y=df_risk['Count'],
        marker_color=[COLOR_SCHEME['danger'], COLOR_SCHEME['success']],
        text=[f'{p:.1f}%' for p in df_risk['Percentage']],
        textposition='auto',
    ))

    background = COLOR_SCHEME['background']
    fig.update_layout(
        title="Predicted Risk Distribution",
        xaxis_title="Risk Category",
        yaxis_title="Count",
        plot_bgcolor=background,
        paper_bgcolor=background,
        font=dict(color=COLOR_SCHEME['text']),
        height=300,
    )

    return fig

def create_shap_summary_plot():
    """Create a beeswarm-style SHAP scatter plot for the classification model.

    Uses ``shap_data_precomputed['shap_values']`` and
    ``shap_data_precomputed['feature_names']``. Features are ranked by mean
    absolute SHAP value and the top ones (at most 15) get one row each.
    Only the first 50 samples are drawn, and points whose absolute SHAP
    value is 0.01 or less are omitted. Each point gets a small random
    vertical jitter so overlapping values stay visible.

    NOTE(review): the indexing below assumes ``shap_values`` is a 2-D NumPy
    array laid out as (samples, features) -- confirm against the code that
    fills ``shap_data_precomputed``.

    Returns:
        go.Figure: the scatter plot; an empty themed figure titled
        "SHAP Analysis Not Available" when no SHAP data was precomputed;
        or a figure titled "SHAP summary plot failed" if building the plot
        raises.
    """
    fig = go.Figure()

    if shap_data_precomputed is None:
        fig.update_layout(
            title="SHAP Analysis Not Available",
            plot_bgcolor=COLOR_SCHEME['background'],
            paper_bgcolor=COLOR_SCHEME['background'],
            font=dict(color=COLOR_SCHEME['text'])
        )
        return fig

    try:
        shap_values = shap_data_precomputed['shap_values']
        feature_names = shap_data_precomputed['feature_names']

        top_n = min(15, len(feature_names))
        mean_abs_shap = np.abs(shap_values).mean(axis=0)
        # Most important feature first.
        top_indices = np.argsort(mean_abs_shap)[-top_n:][::-1]
        top_features = [feature_names[i] for i in top_indices]

        sample_size = min(50, shap_values.shape[0])
        sampled = shap_values[:sample_size][:, top_indices]

        # np.nonzero walks the mask row by row, i.e. sample by sample and,
        # within a sample, feature rank by feature rank.
        _, feature_rank = np.nonzero(np.abs(sampled) > 0.01)
        x_vals = sampled[np.abs(sampled) > 0.01]
        y_vals = feature_rank + np.random.uniform(-0.1, 0.1, size=feature_rank.size)
        colors = np.where(x_vals > 0, 'blue', 'red').tolist()

        fig.add_trace(go.Scatter(
            x=x_vals,
            y=y_vals,
            mode='markers',
            marker=dict(
                size=6,
                color=colors,
                opacity=0.4
            ),
            showlegend=False
        ))

        fig.update_layout(
            # Report the number of features actually plotted; it is below
            # 15 when the model has fewer features.
            title=f"SHAP Beeswarm Plot (Top {top_n} Features)",
            xaxis_title="SHAP Value (Impact on Risk Probability)",
            yaxis_title="Features",
            plot_bgcolor=COLOR_SCHEME['background'],
            paper_bgcolor=COLOR_SCHEME['background'],
            font=dict(color=COLOR_SCHEME['text']),
            height=400,
            yaxis=dict(
                tickmode='array',
                tickvals=list(range(top_n)),
                ticktext=top_features,
                autorange="reversed"
            )
        )

        fig.add_annotation(
            x=0.02, y=1.05,
            xref="paper", yref="paper",
            text="Blue = Increases risk | Red = Decreases risk",
            showarrow=False,
            font=dict(size=10, color=COLOR_SCHEME['text'])
        )

    except Exception:
        # Best effort: the dashboard should still render if SHAP data is
        # malformed, so fall back to a figure that only carries a title.
        fig.update_layout(title="SHAP summary plot failed")

    return fig

def create_shap_global_plot_classification():
    """Plot the ten features with the largest mean absolute SHAP value.

    Importance is the column-wise mean of ``|shap_values|`` taken from
    ``shap_data_precomputed``; bars are horizontal with the most important
    feature at the top.

    Returns:
        go.Figure: the bar chart; an empty themed figure titled
        "SHAP Analysis Not Available" when no SHAP data was precomputed;
        or a figure titled "SHAP calculation failed" if building the chart
        raises.
    """
    fig = go.Figure()
    themed = dict(
        plot_bgcolor=COLOR_SCHEME["background"],
        paper_bgcolor=COLOR_SCHEME["background"],
        font=dict(color=COLOR_SCHEME["text"]),
    )

    if shap_data_precomputed is None:
        fig.update_layout(title="SHAP Analysis Not Available", **themed)
        return fig

    try:
        mean_abs = np.abs(shap_data_precomputed['shap_values']).mean(axis=0)
        names = shap_data_precomputed['feature_names']

        ranking = pd.DataFrame({"feature": names, "importance": mean_abs})
        ranking = ranking.sort_values("importance", ascending=False)
        top_ten = ranking.head(10)

        importance_bars = go.Bar(
            x=top_ten["importance"],
            y=top_ten["feature"],
            orientation="h",
            marker_color=COLOR_SCHEME["secondary"],
            name="Mean |SHAP Value|",
        )
        fig.add_trace(importance_bars)

        fig.update_layout(
            title="Global SHAP Feature Importance (Classification)",
            xaxis_title="Mean |SHAP Value|",
            yaxis_title="Features",
            height=350,
            # Reversed so the first (largest) bar is drawn at the top.
            yaxis=dict(autorange="reversed"),
            **themed,
        )

    except Exception:
        fig.update_layout(title="SHAP calculation failed")

    return fig

def create_regional_cluster_heatmap():
    """Draw the region-by-cluster distribution as an annotated heatmap.

    The data comes from
    ``clustering_analysis['regional_cluster_distribution']``: its index is
    placed on the y-axis, its columns on the x-axis, and each cell shows
    its value with two decimals.

    Returns:
        go.Figure: the heatmap; an empty themed figure when the clustering
        results or the distribution entry are missing; or an empty figure
        titled "No regional cluster data available" when the table is
        empty.
    """
    fig = go.Figure()
    themed = dict(
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text']),
        height=500,
    )

    has_distribution = (
        clustering_analysis is not None
        and 'regional_cluster_distribution' in clustering_analysis
    )
    if not has_distribution:
        fig.update_layout(title="Regional Cluster Distribution Heatmap", **themed)
        return fig

    distribution = clustering_analysis['regional_cluster_distribution']
    if distribution.empty or len(distribution.columns) == 0:
        fig.update_layout(title="No regional cluster data available")
        return fig

    cell_values = distribution.values
    fig.add_trace(go.Heatmap(
        z=cell_values,
        x=distribution.columns,
        y=distribution.index,
        colorscale='RdYlGn_r',
        text=cell_values.round(2),
        texttemplate='%{text:.2f}',
        textfont={"size": 12},
        hoverongaps=False,
        colorbar=dict(title="Proportion"),
    ))

    fig.update_layout(
        title="Regional Cluster Distribution Heatmap",
        xaxis_title="Performance Cluster",
        yaxis_title="Region",
        xaxis=dict(tickangle=0),
        **themed,
    )

    return fig

def create_regional_cluster_barchart():
    """Draw the region-by-cluster distribution as a stacked bar chart.

    One bar series is added for each of the 'Low', 'Medium' and 'High'
    columns present in
    ``clustering_analysis['regional_cluster_distribution']``, in that
    order, with regions (the table's index) on the x-axis.

    Returns:
        go.Figure: the stacked chart; an empty themed figure when the
        clustering results or the distribution entry are missing; or an
        empty figure titled "No regional cluster data available" when the
        table is empty.
    """
    fig = go.Figure()
    themed = dict(
        plot_bgcolor=COLOR_SCHEME['background'],
        paper_bgcolor=COLOR_SCHEME['background'],
        font=dict(color=COLOR_SCHEME['text']),
        height=500,
    )

    has_distribution = (
        clustering_analysis is not None
        and 'regional_cluster_distribution' in clustering_analysis
    )
    if not has_distribution:
        fig.update_layout(title="Regional Cluster Distribution (Stacked)", **themed)
        return fig

    distribution = clustering_analysis['regional_cluster_distribution']
    if distribution.empty or len(distribution.columns) == 0:
        fig.update_layout(title="No regional cluster data available")
        return fig

    # Insertion order doubles as stacking order: Low, then Medium, then High.
    level_colors = {
        'Low': COLOR_SCHEME['danger'],
        'Medium': COLOR_SCHEME['warning'],
        'High': COLOR_SCHEME['success'],
    }

    for level, bar_color in level_colors.items():
        if level not in distribution.columns:
            continue
        proportions = distribution[level]
        fig.add_trace(go.Bar(
            name=level,
            x=distribution.index,
            y=proportions,
            marker_color=bar_color,
            text=proportions.round(2),
            textposition='inside',
            textfont=dict(size=10),
        ))

    legend_above_plot = dict(
        title="Performance Level",
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1,
    )
    fig.update_layout(
        title="Regional Cluster Distribution (Stacked)",
        xaxis_title="Region",
        yaxis_title="Proportion",
        barmode='stack',
        xaxis=dict(tickangle=-45),
        legend=legend_above_plot,
        **themed,
    )

    return fig

def create_recommendations_summary_table():
    """Create comprehensive summary table for Recommendations & Summary dashboard"""

    total_students = len(df_clean)
    if 'Overall_Average' in df_clean.columns:
        risk_students = (df_clean['Overall_Average'] < 50).sum()
        avg_score = df_clean['Overall_Average'].mean()
        risk_percentage = (risk_students / total_students * 100) if total_students > 0 else 0
    else:
        risk_students = 0
        avg_score = 0
        risk_percentage = 0

    if clustering_analysis is not None:
        high_performers = clustering_analysis['cluster_sizes'].get('High', 0)
        medium_performers = clustering_analysis['cluster_sizes'].get('Medium', 0)
        low_performers = clustering_analysis['cluster_sizes'].get('Low', 0)
    else:
        high_performers = 0
        medium_performers = 0
        low_performers = 0

    if best_reg_model in regression_models:
        best_model_r2 = regression_models[best_reg_model]['r2']
        best_model_mae = regression_models[best_reg_model]['mae']
    else:
        best_model_r2 = 0
        best_model_mae = 0

    f1_score_val = classification_model.get('f1', 0)
    roc_auc_val = classification_model.get('roc_auc', 0)

    if clustering_analysis is not None and 'regional_risk' in clustering_analysis:
        regional_risk = clustering_analysis['regional_risk']
        if not regional_risk.empty:
            top_risk_region = regional_risk.idxmax()
            top_risk_value = regional_risk.max()
        else:
            top_risk_region = "N/A"
            top_risk_value = 0
    else:
        top_risk_region = "N/A"
        top_risk_value = 0

    summary_data = [
        {'Metric': 'Dataset Statistics', 'Value': '', 'Details': ''},
        {'Metric': '  Total Students', 'Value': f'{total_students:,}', 'Details': '100%'},
        {'Metric': '  At-Risk Students', 'Value': f'{risk_students:,}', 'Details': f'{risk_percentage:.1f}%'},
        {'Metric': '  Average Overall Score', 'Value': f'{avg_score:.2f}', 'Details': 'out of 100'},
        {'Metric': '', 'Value': '', 'Details': ''},

        {'Metric': 'Performance Clusters', 'Value': '', 'Details': ''},
        {'Metric': '  High Performers', 'Value': f'{high_performers:,}', 'Details': f'{(high_performers/total_students*100):.1f}%' if total_students > 0 else '0%'},
        {'Metric': '  Medium Performers', 'Value': f'{medium_performers:,}', 'Details': f'{(medium_performers/total_students*100):.1f}%' if total_students > 0 else '0%'},
        {'Metric': '  Low Performers', 'Value': f'{low_performers:,}', 'Details': f'{(low_performers/total_students*100):.1f}%' if total_students > 0 else '0%'},
        {'Metric': '', 'Value': '', 'Details': ''},

        {'Metric': 'Model Performance', 'Value': '', 'Details': ''},
        {'Metric': '  Best Regression Model', 'Value': best_reg_model, 'Details': f'R¬≤ = {best_model_r2:.3f}'},
        {'Metric': '  Regression MAE', 'Value': f'{best_model_mae:.2f}', 'Details': 'avg error in points'},
        {'Metric': '  Risk Classification F1', 'Value': f'{f1_score_val:.3f}', 'Details': '>0.75 is good'},
        {'Metric': '  Risk Classification ROC-AUC', 'Value': f'{roc_auc_val:.3f}', 'Details': '>0.80 is excellent'},
        {'Metric': '', 'Value': '', 'Details': ''},

        {'Metric': 'National Exam Model', 'Value': '', 'Details': ''},
        {'Metric': '  Best Model', 'Value': 'Gradient Boosting', 'Details': 'R¬≤ = 0.4380'},
        {'Metric': '  Top Feature', 'Value': 'Score_x_Participation', 'Details': '73.6% importance'},
        {'Metric': '', 'Value': '', 'Details': ''},

        {'Metric': 'Regional Risk Analysis', 'Value': '', 'Details': ''},
        {'Metric': '  Highest Risk Region', 'Value': top_risk_region, 'Details': f'{top_risk_value:.1f}% low performers'},
        {'Metric': '  Lowest Risk Region', 'Value': 'Addis Ababa', 'Details': '21.3% low performers'},
        {'Metric': '', 'Value': '', 'Details': ''},

        {'Metric': 'Key Predictors', 'Value': '', 'Details': ''},
        {'Metric': '  Top Factor', 'Value': 'School Resources Score', 'Details': 'Strongest predictor'},
        {'Metric': '  2nd Factor', 'Value': 'Textbook Access', 'Details': 'Critical for learning'},
        {'Metric': '  3rd Factor', 'Value': 'Parental Involvement', 'Details': 'Significant impact'},
        {'Metric': '  4th Factor', 'Value': 'Teacher-Student Ratio', 'Details': 'Lower is better'},
    ]

    return pd.DataFrame(summary_data)

# =================================================================
# DASHBOARD LAYOUT
# =================================================================

# Two-column shell: a fixed sidebar (page buttons plus quick stats) on the
# left and the "page-content" placeholder, wrapped in a loading indicator,
# on the right.
app.layout = dbc.Container([
    dbc.Row([
        dbc.Col([
            html.Div([
                html.H2("Student Analytics", className="mb-4", style={'color': COLOR_SCHEME['dark']}),

                # The page buttons differ only in label, component id and
                # the COLOR_SCHEME key used for their background.
                html.Div([
                    dbc.Button(
                        button_label,
                        id=button_id,
                        color="primary",
                        className="mb-3 w-100",
                        style={
                            'textAlign': 'left',
                            'backgroundColor': COLOR_SCHEME[color_key],
                            'border': 'none',
                        },
                    )
                    for button_label, button_id, color_key in (
                        ("Overview Dashboard", "btn-overview", 'primary'),
                        ("Models Analysis", "btn-models", 'danger'),
                        ("Make Prediction", "btn-prediction", 'secondary'),
                        ("Student Clustering", "btn-clustering", 'success'),
                        ("Recommendation & Summary", "btn-recommendations", 'warning'),
                    )
                ], id="nav-buttons"),

                html.Hr(),

                # Quick stats are computed once, when the layout is built.
                html.Div([
                    html.H5("Quick Stats", style={'color': COLOR_SCHEME['dark']}),
                    html.P(f"Students: {len(df_clean):,}"),
                    html.P(f"After Preprocessing: {df_clean.shape[1]-2} Features"),
                    html.P(
                        f"Risk Students: {(df_clean['Overall_Average'] < 50).sum():,}"
                        if 'Overall_Average' in df_clean.columns
                        else "Risk: N/A"
                    ),
                    html.P(f"Best Model: {best_reg_model}"),
                ], style={'fontSize': '14px'}),
            ], style={
                'backgroundColor': COLOR_SCHEME['light'],
                'padding': '20px',
                'height': '100vh',
                'borderRight': f'2px solid {COLOR_SCHEME["primary"]}',
                'overflowY': 'auto',
                'position': 'fixed',
                'width': '25%'
            })
        ], width=3),

        dbc.Col([
            dcc.Loading(
                id="loading-1",
                type="default",
                children=html.Div(
                    id="page-content",
                    style={'padding': '20px', 'height': '100vh', 'overflowY': 'auto'},
                ),
            )
        ], width=9, style={'marginLeft': '25%'})
    ])
], fluid=True, style={'backgroundColor': COLOR_SCHEME['background'], 'height': '100vh', 'overflow': 'hidden'})

# =================================================================
# PAGE CONTENT FUNCTIONS
# =================================================================

def create_overview_page():
    """Assemble the component tree for the "Overview Dashboard" page.

    The page is a vertical stack of rows: heading, four headline stat
    cards, the dashboard objectives, the feature summary table, two rows of
    paired charts, the correlation heatmap and a textual dataset summary.
    Headline numbers come from ``df_clean`` and ``df_original``; a student
    counts as at-risk when ``Overall_Average`` is below 50, and both the
    risk count and the average fall back to 0 when that column is absent.

    Returns:
        html.Div: the page content.
    """
    if 'Overall_Average' in df_clean.columns:
        overall = df_clean['Overall_Average']
        risk_count = (overall < 50).sum()
        avg_score = overall.mean()
    else:
        risk_count = avg_score = 0

    feature_summary_df = create_feature_summary_table()

    primary = COLOR_SCHEME['primary']
    secondary = COLOR_SCHEME['secondary']
    danger = COLOR_SCHEME['danger']

    def section_title(text):
        """Dark H4 heading used at the top of each section."""
        return html.H4(text, style={'color': COLOR_SCHEME['dark']})

    def stat_card(label, value_text, label_color, value_color, border_color):
        """Quarter-width card showing one headline number."""
        return dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    html.Div([
                        html.H6(label, style={'color': label_color}),
                        html.H3(value_text, style={'color': value_color})
                    ], className="text-center")
                ])
            ], style={'border': f"2px solid {border_color}"})
        ], width=3)

    def graph_pair_card(left_figure, right_figure):
        """Card holding two half-width, 400px-high graphs side by side."""
        return dbc.Card([
            dbc.CardBody([
                dbc.Row([
                    dbc.Col([
                        dcc.Graph(figure=left_figure, style={'height': '400px'})
                    ], width=6),
                    dbc.Col([
                        dcc.Graph(figure=right_figure, style={'height': '400px'})
                    ], width=6)
                ])
            ])
        ])

    heading_row = dbc.Row([
        dbc.Col([
            html.H1(
                "Overview Dashboard",
                className="mb-4",
                style={'color': COLOR_SCHEME['dark'], 'fontWeight': 'bold'},
            ),
            html.P(
                "Comprehensive analysis of Ethiopian students' academic performance",
                className="text-muted mb-4",
            )
        ])
    ])

    stats_row = dbc.Row([
        stat_card("Total Students", f"{len(df_clean):,}", primary, secondary, primary),
        stat_card("All Columns", f"{df_original.shape[1]}", primary, secondary, primary),
        stat_card("Avg Overall Score", f"{avg_score:.1f}", primary, secondary, primary),
        stat_card("Risk Students", f"{risk_count:,}", danger, danger, danger)
    ], className="mb-4")

    objectives = (
        "Analyze Ethiopian student performance patterns",
        "Predict individual student academic outcomes",
        "Identify at-risk students for early intervention",
        "Understand school and regional disparities",
        "Provide actionable recommendations for educators",
        "Cluster students based on performance characteristics",
    )
    objectives_row = dbc.Row([
        dbc.Col([
            dbc.Card([
                dbc.CardHeader(
                    "Dashboard Objectives",
                    style={'backgroundColor': primary, 'color': 'white'},
                ),
                dbc.CardBody([
                    html.Ul([html.Li(objective) for objective in objectives])
                ])
            ])
        ])
    ], className="mb-4")

    summary_table = dash.dash_table.DataTable(
        id='feature-summary-table',
        columns=[{"name": col, "id": col} for col in feature_summary_df.columns],
        data=feature_summary_df.to_dict('records'),
        style_table={
            'overflowX': 'auto',
            'height': '300px',
            'overflowY': 'auto',
            'maxWidth': '100%',
        },
        style_cell={
            'textAlign': 'center',
            'padding': '8px',
            'fontSize': '11px',
            'minWidth': '100px',
            'maxWidth': '150px',
            'whiteSpace': 'normal'
        },
        style_header={
            'backgroundColor': primary,
            'color': 'white',
            'fontWeight': 'bold',
            'textAlign': 'center'
        },
        style_data_conditional=[
            # Zebra striping for readability.
            {
                'if': {'row_index': 'odd'},
                'backgroundColor': COLOR_SCHEME['light']
            },
            # Make the statistic-name column stand out as a row header.
            {
                'if': {'column_id': 'Statistic'},
                'fontWeight': 'bold',
                'backgroundColor': COLOR_SCHEME['dark'],
                'color': 'white'
            }
        ],
        fixed_columns={'headers': True, 'data': 1},
        fixed_rows={'headers': True}
    )
    feature_table_row = dbc.Row([
        dbc.Col([
            section_title("Feature Summary Statistics (First 10 Features)"),
            html.P("Features as columns, Statistics as rows", className="text-muted mb-3"),
            dbc.Card([
                dbc.CardBody([
                    html.Div([summary_table]),
                    html.P(
                        "Note: Showing first 10 features for readability. Statistics include mean, standard deviation, quartiles, and missing values.",
                        className="text-muted mt-3",
                    )
                ])
            ])
        ])
    ], className="mb-4")

    distribution_row = dbc.Row([
        dbc.Col([
            section_title("Overall Average & Regional Distribution"),
            graph_pair_card(
                create_score_distribution_plot(),
                create_students_by_region_plot(),
            )
        ])
    ], className="mb-4")

    feature_analysis_row = dbc.Row([
        dbc.Col([
            section_title("Feature Analysis"),
            graph_pair_card(
                create_feature_category_plot(),
                create_datatype_bar_plot(),
            )
        ])
    ], className="mb-4")

    correlation_row = dbc.Row([
        dbc.Col([
            section_title("Feature Correlation Analysis"),
            dbc.Card([
                dbc.CardBody([
                    dcc.Graph(figure=create_correlation_heatmap())
                ])
            ])
        ])
    ], className="mb-4")

    preprocessing_steps = (
        f"Loaded {df_raw.shape[0]:,} student records with {df_original.shape[1]} original features",
        "Aggregated grade-level scores into education stages",
        "Created engagement and textbook access composites",
        "Encoded categorical variables",
        f"Final dataset: {df_clean.shape[1]-2} features after preprocessing",
    )
    feature_categories = (
        "Student Factors: Demographic and personal characteristics",
        "Academic Factors: Performance metrics and engagement",
        "School Factors: Institutional resources and environment",
        "Regional Factors: Geographic and regional indicators",
        "Health Factors: Health-related conditions",
        "Other: Miscellaneous features",
    )
    dataset_summary_row = dbc.Row([
        dbc.Col([
            section_title("Dataset Summary"),
            dbc.Card([
                dbc.CardBody([
                    html.H5("Preprocessing Steps:", style={'color': primary}),
                    html.Ul([html.Li(step) for step in preprocessing_steps]),
                    html.H5("Feature Categories:", style={'color': primary, 'marginTop': '20px'}),
                    html.Ul([html.Li(category) for category in feature_categories])
                ])
            ], style={'border': f"2px solid {primary}"})
        ])
    ])

    return html.Div([
        heading_row,
        stats_row,
        objectives_row,
        feature_table_row,
        distribution_row,
        feature_analysis_row,
        correlation_row,
        dataset_summary_row,
    ])

def create_models_page():
    """Build the layout for the models (risk analysis) page.

    The page is assembled top to bottom from: the title, the National Exam
    Score model card, the Overall Average regression card, the regression
    feature-importance and actual-vs-predicted plots, the regression metrics
    table, the SHAP card, the risk-classification metrics with the confusion
    matrix, the ROC curve and the risk-factor / intervention summary.

    Returns:
        html.Div: the complete page content.
    """
    national_exam_df = create_national_exam_performance_table()

    # Sections are listed in display order; each helper (and each figure
    # function called inline here) runs when its list element is evaluated.
    return html.Div([
        _models_page_title(),
        _models_page_national_exam(national_exam_df),
        _models_page_overall_average(),
        _models_page_graph_row(create_feature_importance_plot(), '450px',
                               title="Regression Feature Importance"),
        _models_page_graph_row(create_actual_vs_predicted_plot(), '400px',
                               title="Actual vs Predicted Values"),
        _models_page_regression_table(),
        _models_page_shap(),
        _models_page_classification(),
        _models_page_graph_row(create_roc_curve_plot(), '400px'),
        _models_page_risk_factors(),
    ])


def _models_page_bullets(items, **ul_props):
    """Return an ``html.Ul`` holding one ``html.Li`` per string in ``items``.

    Args:
        items: iterable of bullet texts.
        **ul_props: extra keyword arguments forwarded to ``html.Ul``.
    """
    return html.Ul([html.Li(text) for text in items], **ul_props)


def _models_page_metrics_table(table_id, frame, header_style, conditional_styles):
    """Render ``frame`` as a DataTable with one column per DataFrame column.

    Args:
        table_id: component id assigned to the table.
        frame: DataFrame whose columns and records are displayed.
        header_style: dict passed as ``style_header``.
        conditional_styles: list passed as ``style_data_conditional``.

    Returns:
        html.Div: wrapper containing the DataTable.
    """
    return html.Div([
        dash.dash_table.DataTable(
            id=table_id,
            columns=[{"name": col, "id": col} for col in frame.columns],
            data=frame.to_dict('records'),
            style_table={'overflowX': 'auto', 'marginBottom': '20px'},
            style_cell={
                'textAlign': 'center',
                'padding': '12px',
                'fontSize': '14px'
            },
            style_header=header_style,
            style_data_conditional=conditional_styles
        )
    ])


def _models_page_graph_row(figure, height, title=None):
    """Return a ``mb-4`` row showing ``figure`` inside a card.

    Args:
        figure: Plotly figure handed to ``dcc.Graph``.
        height: CSS height of the graph, e.g. ``'400px'``.
        title: optional H4 heading placed above the card.
    """
    column_children = []
    if title is not None:
        column_children.append(html.H4(title, style={'color': COLOR_SCHEME['dark']}))
    column_children.append(
        dbc.Card([
            dbc.CardBody([
                dcc.Graph(figure=figure, style={'height': height})
            ])
        ])
    )
    return dbc.Row([dbc.Col(column_children)], className="mb-4")


def _models_page_title():
    """Return the row holding the page heading and its subtitle."""
    return dbc.Row([
        dbc.Col([
            html.H1("Overall Models Analysis Dashboard",
                    className="mb-4",
                    style={'color': COLOR_SCHEME['dark'], 'fontWeight': 'bold'}),
            html.P("Comprehensive model performance analysis and risk assessment with SHAP explanations",
                   className="text-muted mb-4")
        ])
    ])


def _models_page_national_exam(national_exam_df):
    """Return the National Exam Score card: comparison plot, table, importances.

    The best-model figures and the insight bullets in this card are fixed
    text; only the plots and the table come from other helpers.

    Args:
        national_exam_df: DataFrame shown in the performance table.
    """
    warning = COLOR_SCHEME['warning']
    primary_text = {'color': COLOR_SCHEME['primary']}
    dark_text = {'color': COLOR_SCHEME['dark']}
    metric_style = {'fontSize': '15px'}
    bullet_style = {'fontSize': '13px'}

    comparison_row = dbc.Row([
        dbc.Col([
            dcc.Graph(
                figure=create_national_exam_model_comparison_plot(),
                style={'height': '400px', 'width': '100%'}
            )
        ], width=8),
        dbc.Col([
            html.H5("Best Model Details", style=primary_text),
            html.P("Best Model: Gradient Boosting",
                   style={'fontSize': '16px', 'fontWeight': 'bold'}),
            html.P("R² Score: 0.4380", style=metric_style),
            html.P("MAE: 0.0814", style=metric_style),
            html.P("RMSE: 0.1071", style=metric_style),
            html.Hr(),
            html.H6("Interpretation:", style=primary_text),
            # A heading is not a valid child of <ul>, so the Durbin-Watson
            # heading sits between two lists instead of inside one.
            _models_page_bullets([
                "Gradient Boosting achieved the best performance",
                "Model explains ~43.8% of variance in National Exam Scores",
                "Average prediction error: ~0.08 points",
            ], style=bullet_style),
            html.H6("Durbin-Watson Statistic: 2.00",
                    style={'color': COLOR_SCHEME['success']}),
            _models_page_bullets([
                "Durbin-Watson shows independent residuals (good model fit)",
            ], style=bullet_style)
        ], width=4)
    ])

    performance_table = _models_page_metrics_table(
        'national-exam-performance-table',
        national_exam_df,
        header_style={
            'backgroundColor': warning,
            'color': 'white',
            'fontWeight': 'bold',
            'textAlign': 'center'
        },
        conditional_styles=[
            {
                'if': {'column_id': 'R2_Score'},
                'backgroundColor': COLOR_SCHEME['light'],
                'fontWeight': 'bold'
            },
            {
                'if': {'row_index': 0},
                'backgroundColor': '#D1E7DD'
            },
            {
                'if': {'row_index': 'odd'},
                'backgroundColor': '#F8F9FA'
            }
        ]
    )

    importance_card = dbc.Card([
        dbc.CardBody([
            dcc.Graph(
                figure=create_national_exam_feature_importance_plot(),
                style={'height': '550px', 'width': '100%'}
            ),
            html.Div([
                html.H6("Key Insights:", className="mt-4", style=primary_text),
                _models_page_bullets([
                    "Score_x_Participation is the most important feature (73.6%)",
                    "Overall Homework completion contributes 7.2%",
                    "School Academic Score contributes 6.7%",
                    "Overall Test Score Average contributes 4.3%",
                    "Other factors have smaller but meaningful contributions",
                ], style={'fontSize': '14px'})
            ])
        ])
    ], style={'border': f'1px solid {warning}'})

    return dbc.Row([
        dbc.Col([
            html.H4("National Exam Score Model Analysis", style=dark_text),
            dbc.Card([
                dbc.CardHeader("Model Performance for Total National Exam Score Prediction",
                               style={'backgroundColor': warning, 'color': 'white', 'fontWeight': 'bold'}),
                dbc.CardBody([
                    html.Div([
                        comparison_row,
                        html.H5("Complete Model Performance Table",
                                style={'color': COLOR_SCHEME['dark'], 'marginTop': '30px'}),
                        performance_table,
                        html.Hr(style={'marginTop': '30px', 'marginBottom': '30px'}),
                        html.H5("Feature Importance - National Exam Score Model", style=dark_text),
                        importance_card
                    ])
                ])
            ], style={'border': f'2px solid {warning}'})
        ])
    ], className="mb-5")


def _models_page_overall_average():
    """Return the Overall Average regression card for ``best_reg_model``.

    Metrics are read from ``regression_models[best_reg_model]``. When that
    entry is absent the three metric lines read "N/A" and the metric-based
    interpretation bullets are left out rather than rendered empty.
    """
    metric_style = {'fontSize': '15px'}

    if best_reg_model in regression_models:
        best_metrics = regression_models[best_reg_model]
        metric_lines = [
            f"R² Score: {best_metrics['r2']:.3f}",
            f"MAE: {best_metrics['mae']:.2f}",
            f"RMSE: {best_metrics['rmse']:.2f}",
        ]
        interpretation = [
            f"{best_reg_model} achieved the best performance",
            f"Model explains ~{best_metrics['r2']*100:.1f}% of variance in scores",
            f"Average prediction error: ~{best_metrics['mae']:.1f} points",
        ]
    else:
        metric_lines = ["R²: N/A", "MAE: N/A", "RMSE: N/A"]
        interpretation = [f"{best_reg_model} achieved the best performance"]

    return dbc.Row([
        dbc.Col([
            html.H4("Overall Average Model Performance", style={'color': COLOR_SCHEME['dark']}),
            dbc.Card([
                dbc.CardHeader("Model Performance for Overall Average Prediction",
                               style={'backgroundColor': COLOR_SCHEME['primary'], 'color': 'white'}),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            dcc.Graph(figure=create_regression_comparison_plot(),
                                      style={'height': '400px'})
                        ], width=8),
                        dbc.Col([
                            html.H5("Model Details", style={'color': COLOR_SCHEME['secondary']}),
                            html.P(f"Best Model: {best_reg_model}",
                                   style={'fontSize': '16px', 'fontWeight': 'bold'}),
                            *[html.P(line, style=metric_style) for line in metric_lines],
                            html.Hr(),
                            html.H6("Interpretation:", style={'color': COLOR_SCHEME['primary']}),
                            _models_page_bullets(interpretation, style={'fontSize': '13px'})
                        ], width=4)
                    ])
                ])
            ])
        ])
    ], className="mb-4")


def _models_page_regression_table():
    """Return the regression metrics table with its reading guide.

    NOTE: the GradientBoosting and RandomForest rows are fixed literals; only
    the XGBoost row is read from ``regression_models``, with 0 substituted
    for any metric that is missing.
    """
    r2_column = 'R²'
    xgb_metrics = regression_models.get('XGBoost', {})
    performance_df = pd.DataFrame({
        'Model': ['GradientBoosting', 'RandomForest', 'XGBoost'],
        'MAE': [2.988437, 3.067115, xgb_metrics.get('mae', 0)],
        'RMSE': [3.730292, 3.833611, xgb_metrics.get('rmse', 0)],
        r2_column: [0.7855, 0.773, xgb_metrics.get('r2', 0)],
    })

    return dbc.Row([
        dbc.Col([
            html.H4("Model Performance Comparison", style={'color': COLOR_SCHEME['dark']}),
            dbc.Card([
                dbc.CardBody([
                    _models_page_metrics_table(
                        'model-performance-table',
                        performance_df,
                        header_style={
                            'backgroundColor': COLOR_SCHEME['primary'],
                            'color': 'white',
                            'fontWeight': 'bold'
                        },
                        conditional_styles=[
                            {
                                # Must match the DataFrame column key above.
                                'if': {'column_id': r2_column},
                                'backgroundColor': COLOR_SCHEME['light'],
                                'fontWeight': 'bold'
                            },
                            {
                                'if': {'row_index': 'odd'},
                                'backgroundColor': '#F8F9FA'
                            }
                        ]
                    ),
                    html.Hr(),
                    html.H5("Performance Metrics Interpretation:",
                            style={'color': COLOR_SCHEME['primary'], 'marginTop': '20px'}),
                    _models_page_bullets([
                        "MAE (Mean Absolute Error): Lower is better (average prediction error in points)",
                        "RMSE (Root Mean Square Error): Lower is better (penalizes larger errors)",
                        "R² Score: Higher is better (variance explained by model)",
                    ])
                ])
            ])
        ])
    ], className="mb-4")


def _models_page_shap():
    """Return the SHAP card: global bar plot, summary plot and fixed notes."""
    primary_text = {'color': COLOR_SCHEME['primary']}
    caption_class = "text-muted text-center mt-2"

    plots_row = dbc.Row([
        dbc.Col([
            html.H5("Global SHAP Importance (Bar)", style=primary_text),
            dcc.Graph(figure=create_shap_global_plot_classification(), style={'height': '400px'}),
            html.P("Average impact magnitude of each feature across all predictions",
                   className=caption_class)
        ], width=6),
        dbc.Col([
            html.H5("SHAP Summary (Beeswarm)", style=primary_text),
            dcc.Graph(figure=create_shap_summary_plot(), style={'height': '400px'}),
            html.P("Distribution of SHAP values for top features",
                   className=caption_class)
        ], width=6)
    ])

    return dbc.Row([
        dbc.Col([
            html.H4("SHAP Analysis for Risk Classification", style={'color': COLOR_SCHEME['dark']}),
            dbc.Card([
                dbc.CardHeader("Model Interpretability with SHAP",
                               style={'backgroundColor': COLOR_SCHEME['secondary'], 'color': 'white'}),
                dbc.CardBody([
                    html.P("SHAP (SHapley Additive exPlanations) values explain how each feature contributes to individual predictions, providing both global and local interpretability for the risk classification model.",
                           className="text-muted mb-4"),
                    plots_row,
                    html.Hr(),
                    html.H5("SHAP Value Interpretation:", style=primary_text),
                    _models_page_bullets([
                        "Blue: Feature increases the risk probability",
                        "Red: Feature decreases the risk probability",
                        "Magnitude: Larger absolute values indicate stronger influence",
                        "Global importance: Average of absolute SHAP values across all predictions",
                    ]),
                    html.H5("Key Insights from SHAP Analysis:",
                            style={'color': COLOR_SCHEME['primary'], 'marginTop': '20px'}),
                    _models_page_bullets([
                        "School resources are the most important feature for risk prediction",
                        "Textbook access strongly impacts individual student risk",
                        "Parental involvement shows significant protective effect",
                        "Teacher-student ratio has negative impact when too high",
                        "Health issues show varied impact depending on severity",
                    ])
                ])
            ])
        ])
    ], className="mb-4")


def _models_page_classification():
    """Return the risk-classification metrics next to the confusion matrix.

    F1 and ROC-AUC are read from ``classification_model``; accuracy is the
    confusion-matrix diagonal sum divided by its total, or "N/A" when the
    matrix total is not positive.
    """
    value_style = {'fontSize': '16px'}
    primary_text = {'color': COLOR_SCHEME['primary']}

    f1_text = f"F1-Score: {classification_model['f1']:.3f}"
    roc_auc_text = f"ROC-AUC: {classification_model['roc_auc']:.3f}"
    confusion = classification_model['cm']
    sample_total = confusion.sum()
    if sample_total > 0:
        accuracy_text = f"Accuracy: {(confusion.diagonal().sum() / sample_total):.3f}"
    else:
        accuracy_text = "Accuracy: N/A"

    return dbc.Row([
        dbc.Col([
            html.H4("Risk Classification Performance", style={'color': COLOR_SCHEME['dark']}),
            dbc.Card([
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.H5("Gradient Boosting Classifier Metrics", style=primary_text),
                            html.P(f1_text, style=value_style),
                            html.P(roc_auc_text, style=value_style),
                            html.P(accuracy_text, style=value_style),
                            html.Hr(),
                            html.H6("Interpretation:", style=primary_text),
                            _models_page_bullets([
                                "F1-Score > 0.75 indicates good performance",
                                "ROC-AUC > 0.89 shows excellent discrimination",
                                "Model effectively identifies at-risk students",
                            ])
                        ], width=6),
                        dbc.Col([
                            dcc.Graph(
                                figure=create_confusion_matrix_plot(),
                                style={'height': '350px'}
                            )
                        ], width=6)
                    ])
                ])
            ])
        ])
    ], className="mb-4")


def _models_page_risk_factors():
    """Return the closing card listing risk factors and intervention tiers."""
    danger_text = {'color': COLOR_SCHEME['danger']}

    return dbc.Row([
        dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    html.H4("Top Risk Factors & Intervention Framework",
                            style={'color': COLOR_SCHEME['dark']}),
                    html.Hr(),
                    dbc.Row([
                        dbc.Col([
                            html.H5("Most Important Risk Factors", style=danger_text),
                            _models_page_bullets([
                                "Low School Resources Score",
                                "Poor Textbook Access",
                                "Low Student Engagement",
                                "High Teacher–Student Ratio",
                                "Low Parental Involvement and other...",
                            ])
                        ], width=6),
                        dbc.Col([
                            html.H5("Risk Intervention Framework", style=danger_text),
                            _models_page_bullets([
                                "Tier 1 (High Risk): Multiple risk factors present",
                                "Tier 2 (Medium Risk): 2–3 risk factors present",
                                "Tier 3 (Low Risk): 0–1 risk factors present",
                            ])
                        ], width=6)
                    ])
                ])
            ])
        ], width=12)
    ])

def create_prediction_page():
    """Create prediction page with input form for specified columns"""

    default_values = {
        'Gender': 'Male',
        'Date_of_Birth': '2005-06-15',
        'Region': 'Oromia',
        'Health_Issue': 'No Issue',
        'Father_Education': 'High School',
        'Mother_Education': 'High School',
        'Parental_Involvement': 0.5,
        'Home_Internet_Access': 'No',
        'Electricity_Access': 'No',
        'School_Type': 'Public',
        'School_Location': 'Rural',
        'Teacher_Student_Ratio': 40,
        'School_Resources_Score': 0.5,
        'School_Academic_Score': 0.5,
        'Student_to_Resources_Ratio': 20,
        'Field_Choice': 'Social',
        'Career_Interest': 'Teacher',
        'Overall_Textbook_Access_Composite': 0.5,
        'Overall_Avg_Attendance': 75,
        'Overall_Avg_Homework': 65,
        'Overall_Avg_Participation': 70
    }

    return html.Div([
        dbc.Row([
            dbc.Col([
                html.H1("Make Student Performance Prediction",
                       className="mb-4",
                       style={'color': COLOR_SCHEME['dark'], 'fontWeight': 'bold'}),
                html.P("Enter student details to predict academic performance and risk level",
                      className="text-muted mb-4"),
                html.P("Please enter values for all required columns:", className="text-muted"),
                html.Code("['Gender', 'Date_of_Birth', 'Region', 'Health_Issue', 'Father_Education', 'Mother_Education', 'Parental_Involvement', 'Home_Internet_Access', 'Electricity_Access', 'School_Type', 'School_Location', 'Teacher_Student_Ratio', 'School_Resources_Score','School_Academic_Score', 'Student_to_Resources_Ratio','Field_Choice','Career_Interest','Overall_Textbook_Access_Composite', 'Overall_Avg_Attendance', 'Overall_Avg_Homework', 'Overall_Avg_Participation']",
                        style={'display': 'block', 'padding': '10px', 'backgroundColor': COLOR_SCHEME['light'], 'margin': '10px 0'}),
                html.P("Note: Engagement metrics (Attendance, Homework, Participation) use 1-100 scale",
                      className="text-muted")
            ])
        ]),

        dbc.Row([
            dbc.Col([
                dbc.Card([
                    dbc.CardHeader("Student Information Input - ALL REQUIRED COLUMNS",
                                 style={'backgroundColor': COLOR_SCHEME['primary'], 'color': 'white'}),
                    dbc.CardBody([
                        dbc.Row([
                            dbc.Col([
                                html.Label("Gender", className="mb-2"),
                                dcc.Dropdown(
                                    id='gender-input',
                                    options=[
                                        {'label': 'Male', 'value': 'Male'},
                                        {'label': 'Female', 'value': 'Female'}
                                    ],
                                    value=default_values['Gender'],
                                    clearable=False
                                ),
                            ], width=3),
                            dbc.Col([
                                html.Label("Date of Birth (YYYY-MM-DD)", className="mb-2"),
                                dbc.Input(
                                    id='dob-input',
                                    type='text',
                                    placeholder="2005-06-15",
                                    value=default_values['Date_of_Birth']
                                ),
                                html.Small("Format: YYYY-MM-DD", className="text-muted")
                            ], width=3),
                            dbc.Col([
                                html.Label("Region", className="mb-2"),
                                dcc.Dropdown(
                                    id='region-input',
                                    options=[
                                        {'label': 'Addis Ababa', 'value': 'Addis Ababa'},
                                        {'label': 'Afar', 'value': 'Afar'},
                                        {'label': 'Amhara', 'value': 'Amhara'},
                                        {'label': 'Benishangul-Gumuz', 'value': 'Benishangul-Gumuz'},
                                        {'label': 'Dire Dawa', 'value': 'Dire Dawa'},
                                        {'label': 'Gambela', 'value': 'Gambela'},
                                        {'label': 'Harari', 'value': 'Harari'},
                                        {'label': 'Oromia', 'value': 'Oromia'},
                                        {'label': 'Sidama', 'value': 'Sidama'},
                                        {'label': 'SNNP', 'value': 'SNNP'},
                                        {'label': 'Somali', 'value': 'Somali'},
                                        {'label': 'South West Ethiopia', 'value': 'South West Ethiopia'},
                                        {'label': 'Tigray', 'value': 'Tigray'}
                                    ],
                                    value=default_values['Region'],
                                    clearable=False
                                ),
                            ], width=3),
                            dbc.Col([
                                html.Label("Health Issue", className="mb-2"),
                                dcc.Dropdown(
                                    id='health-input',
                                    options=[
                                        {'label': 'No Issue', 'value': 'No Issue'},
                                        {'label': 'Dental Problems', 'value': 'Dental Problems'},
                                        {'label': 'Vision Issues', 'value': 'Vision Issues'},
                                        {'label': 'Hearing Issues', 'value': 'Hearing Issues'},
                                        {'label': 'Anemia', 'value': 'Anemia'},
                                        {'label': 'Parasitic Infections', 'value': 'Parasitic Infections'},
                                        {'label': 'Respiratory Issues', 'value': 'Respiratory Issues'},
                                        {'label': 'Malnutrition', 'value': 'Malnutrition'},
                                        {'label': 'Physical Disability', 'value': 'Physical Disability'},
                                        {'label': 'Chronic Illness', 'value': 'Chronic Illness'}
                                    ],
                                    value=default_values['Health_Issue'],
                                    clearable=False
                                ),
                            ], width=3),
                        ], className="mb-3"),

                        dbc.Row([
                            dbc.Col([
                                html.Label("Father Education", className="mb-2"),
                                dcc.Dropdown(
                                    id='father-edu-input',
                                    options=[
                                        {'label': 'Unknown', 'value': 'Unknown'},
                                        {'label': 'Primary', 'value': 'Primary'},
                                        {'label': 'High School', 'value': 'High School'},
                                        {'label': 'College', 'value': 'College'},
                                        {'label': 'University', 'value': 'University'}
                                    ],
                                    value=default_values['Father_Education'],
                                    clearable=False
                                ),
                            ], width=3),
                            dbc.Col([
                                html.Label("Mother Education", className="mb-2"),
                                dcc.Dropdown(
                                    id='mother-edu-input',
                                    options=[
                                        {'label': 'Unknown', 'value': 'Unknown'},
                                        {'label': 'Primary', 'value': 'Primary'},
                                        {'label': 'High School', 'value': 'High School'},
                                        {'label': 'College', 'value': 'College'},
                                        {'label': 'University', 'value': 'University'}
                                    ],
                                    value=default_values['Mother_Education'],
                                    clearable=False
                                ),
                            ], width=3),
                            dbc.Col([
                                html.Label("Parental Involvement (0-1)", className="mb-2"),
                                dbc.Input(
                                    id='parental-input',
                                    type='number',
                                    min=0,
                                    max=1,
                                    step=0.1,
                                    value=default_values['Parental_Involvement']
                                ),
                                html.Small("0 = None, 1 = High", className="text-muted")
                            ], width=3),
                            dbc.Col([
                                html.Label("Home Internet Access", className="mb-2"),
                                dcc.Dropdown(
                                    id='internet-input',
                                    options=[
                                        {'label': 'No', 'value': 'No'},
                                        {'label': 'Yes', 'value': 'Yes'}
                                    ],
                                    value=default_values['Home_Internet_Access'],
                                    clearable=False
                                ),
                            ], width=3),
                        ], className="mb-3"),

                        dbc.Row([
                            dbc.Col([
                                html.Label("Electricity Access", className="mb-2"),
                                dcc.Dropdown(
                                    id='electricity-input',
                                    options=[
                                        {'label': 'No', 'value': 'No'},
                                        {'label': 'Yes', 'value': 'Yes'}
                                    ],
                                    value=default_values['Electricity_Access'],
                                    clearable=False
                                ),
                            ], width=3),
                            dbc.Col([
                                html.Label("School Type", className="mb-2"),
                                dcc.Dropdown(
                                    id='school-type-input',
                                    options=[
                                        {'label': 'Public', 'value': 'Public'},
                                        {'label': 'Private', 'value': 'Private'},
                                        {'label': 'NGO-operated', 'value': 'NGO-operated'},
                                        {'label': 'Faith-based', 'value': 'Faith-based'}
                                    ],
                                    value=default_values['School_Type'],
                                    clearable=False
                                ),
                            ], width=3),
                            dbc.Col([
                                html.Label("School Location", className="mb-2"),
                                dcc.Dropdown(
                                    id='location-input',
                                    options=[
                                        {'label': 'Rural', 'value': 'Rural'},
                                        {'label': 'Urban', 'value': 'Urban'}
                                    ],
                                    value=default_values['School_Location'],
                                    clearable=False
                                ),
                            ], width=3),
                            dbc.Col([
                                html.Label("Teacher-Student Ratio", className="mb-2"),
                                dbc.Input(
                                    id='ratio-input',
                                    type='number',
                                    min=10,
                                    max=100,
                                    step=1,
                                    value=int(default_values['Teacher_Student_Ratio'])
                                ),
                                html.Small("e.g., 40 for 40:1 ratio", className="text-muted")
                            ], width=3),
                        ], className="mb-3"),

                        dbc.Row([
                            dbc.Col([
                                html.Label("School Resources Score (0-1)", className="mb-2"),
                                dbc.Input(
                                    id='resources-input',
                                    type='number',
                                    min=0,
                                    max=1,
                                    step=0.1,
                                    value=round(default_values['School_Resources_Score'], 2)
                                ),
                                html.Small("0 = Poor, 1 = Excellent", className="text-muted")
                            ], width=3),
                            dbc.Col([
                                html.Label("School Academic Score (0-1)", className="mb-2"),
                                dbc.Input(
                                    id='academic-input',
                                    type='number',
                                    min=0,
                                    max=1,
                                    step=0.1,
                                    value=round(default_values['School_Academic_Score'], 2)
                                ),
                                html.Small("0 = Poor, 1 = Excellent", className="text-muted")
                            ], width=3),
                            dbc.Col([
                                html.Label("Student-to-Resources Ratio", className="mb-2"),
                                dbc.Input(
                                    id='student-resources-input',
                                    type='number',
                                    min=5,
                                    max=50,
                                    step=1,
                                    value=int(default_values['Student_to_Resources_Ratio'])
                                ),
                                html.Small("Lower is better", className="text-muted")
                            ], width=3),
                            dbc.Col([
                                html.Label("Field Choice", className="mb-2"),
                                dcc.Dropdown(
                                    id='field-input',
                                    options=[
                                        {'label': 'Social', 'value': 'Social'},
                                        {'label': 'Natural', 'value': 'Natural'}
                                    ],
                                    value=default_values['Field_Choice'],
                                    clearable=False
                                ),
                            ], width=3),
                        ], className="mb-3"),

                        dbc.Row([
                            dbc.Col([
                                html.Label("Career Interest", className="mb-2"),
                                dcc.Dropdown(
                                    id='career-input',
                                    options=[
                                        {'label': 'Teacher', 'value': 'Teacher'},
                                        {'label': 'Doctor', 'value': 'Doctor'},
                                        {'label': 'Engineer', 'value': 'Engineer'},
                                        {'label': 'Farmer', 'value': 'Farmer'},
                                        {'label': 'Business', 'value': 'Business'},
                                        {'label': 'Government', 'value': 'Government'},
                                        {'label': 'Unknown', 'value': 'Unknown'}
                                    ],
                                    value=default_values['Career_Interest'],
                                    clearable=False
                                ),
                            ], width=3),
                            dbc.Col([
                                html.Label("Overall Textbook Access (0-1)", className="mb-2"),
                                dbc.Input(
                                    id='textbook-input',
                                    type='number',
                                    min=0,
                                    max=1,
                                    step=0.01,
                                    value=round(default_values['Overall_Textbook_Access_Composite'], 2)
                                ),
                                html.Small("0 = None, 1 = Full access", className="text-muted")
                            ], width=3),
                            dbc.Col([
                                html.Label("Overall Avg Attendance (0-100)", className="mb-2"),
                                dbc.Input(
                                    id='attendance-input',
                                    type='number',
                                    min=0,
                                    max=100,
                                    step=1,
                                    value=round(default_values['Overall_Avg_Attendance'], 2)
                                ),
                                html.Small("0 = 0%, 100 = 100%", className="text-muted")
                            ], width=3),
                            dbc.Col([
                                html.Label("Overall Avg Homework (0-100)", className="mb-2"),
                                dbc.Input(
                                    id='homework-input',
                                    type='number',
                                    min=0,
                                    max=100,
                                    step=1,
                                    value=round(default_values['Overall_Avg_Homework'], 2)
                                ),
                                html.Small("0 = 0%, 100 = 100%", className="text-muted")
                            ], width=3),
                        ], className="mb-3"),

                        dbc.Row([
                            dbc.Col([
                                html.Label("Overall Avg Participation (0-100)", className="mb-2"),
                                dbc.Input(
                                    id='participation-input',
                                    type='number',
                                    min=0,
                                    max=100,
                                    step=1,
                                    value=round(default_values['Overall_Avg_Participation'], 2)
                                ),
                                html.Small("0 = 0%, 100 = 100%", className="text-muted")
                            ], width=3),
                            dbc.Col([
                                dbc.Button(
                                    "Make Prediction",
                                    id="predict-button",
                                    color="success",
                                    className="w-100 mt-4",
                                    size="lg",
                                    style={'height': '50px', 'fontSize': '16px'}
                                ),
                            ], width=3),
                            dbc.Col([
                                dbc.Button(
                                    "Reset Form",
                                    id="reset-button",
                                    color="warning",
                                    className="w-100 mt-4",
                                    size="lg",
                                    style={'height': '50px', 'fontSize': '16px'}
                                ),
                            ], width=3),
                        ])
                    ])
                ])
            ])
        ], className="mb-4"),

        html.Div(id="prediction-results")
    ])

def create_prediction_results():
    """Build the results panel for the most recent prediction.

    Reads the module-level ``prediction_result`` mapping (it is only read
    here, so no ``global`` declaration is required). When it is ``None`` an
    info alert asking the user to run a prediction is returned instead.

    Keys read from ``prediction_result``: ``is_risk``, ``predicted_score``,
    ``risk_probability``, ``risk_causes``, ``recommendations``,
    ``regression_metrics`` (``model``, ``r2``), ``classification_metrics``
    (``f1``, ``roc_auc``) and ``processing_steps`` (``step1`` .. ``step5``).

    The panel shows, top to bottom: the predicted score with a confidence
    bar, the risk verdict with its probability bar and classifier metrics,
    the risk-factor and recommendation lists, the figure returned by
    ``create_risk_distribution_plot()``, and the processing-steps trail.

    Returns:
        html.Div: The alert wrapper, or the populated results card.
    """
    result = prediction_result

    if result is None:
        return html.Div([
            dbc.Alert(
                "Please make a prediction first by clicking the 'Make Prediction' button above.",
                color="info",
                className="mt-3"
            )
        ])

    card_color = "danger" if result['is_risk'] else "success"
    accent = COLOR_SCHEME[card_color]

    # R^2 is shown as a percentage "confidence". R^2 can be negative (or NaN)
    # for a poor model, so clamp to the 0-100 range a progress bar can draw;
    # with this argument order a NaN collapses to 0.0 rather than 100.
    r2 = result['regression_metrics']['r2']
    prediction_confidence = max(0.0, min(r2 * 100, 100.0))
    if prediction_confidence > 70:
        confidence_color = "success"
    elif prediction_confidence > 50:
        confidence_color = "warning"
    else:
        confidence_color = "danger"

    risk_plot = create_risk_distribution_plot()

    risk_percent = result['risk_probability'] * 100
    classification = result['classification_metrics']
    steps = result['processing_steps']

    # --- Row 1, left: predicted score and model confidence -----------------
    score_column = dbc.Col([
        html.H4("Academic Performance Prediction", className="mb-3"),
        dbc.Card([
            dbc.CardBody([
                html.H1(f"{result['predicted_score']:.1f}",
                        className="text-center",
                        style={'color': COLOR_SCHEME['primary'], 'fontSize': '48px'}),
                html.H6("Predicted Overall Average Score", className="text-center text-muted"),
                html.Hr(),
                html.Div([
                    html.H6("Model Used:", className="mt-2"),
                    html.P(f"{result['regression_metrics']['model']}",
                           style={'fontSize': '14px'}),
                    html.H6("Prediction Confidence:", className="mt-2"),
                    dbc.Progress(
                        value=prediction_confidence,
                        label=f"{prediction_confidence:.1f}%",
                        color=confidence_color,
                        className="mb-2"
                    ),
                    # "\u00b2" is the superscript two; escaped so the label
                    # survives any re-encoding of this source file.
                    html.Small(f"Based on R\u00b2 score of {r2:.3f}",
                               className="text-muted")
                ])
            ])
        ])
    ], width=6)

    # --- Row 1, right: risk verdict and classifier metrics -----------------
    risk_column = dbc.Col([
        html.H4("Risk Assessment", className="mb-3"),
        dbc.Card([
            dbc.CardBody([
                html.H1("AT RISK" if result['is_risk'] else "NOT AT RISK",
                        className="text-center",
                        style={'color': accent, 'fontSize': '36px'}),
                html.H6(f"Risk Probability: {risk_percent:.1f}%",
                        className="text-center"),
                html.Div([
                    dbc.Progress(
                        value=risk_percent,
                        label=f"{risk_percent:.1f}%",
                        color=card_color,
                        className="mb-2",
                        style={'height': '30px'}
                    )
                ], style={'margin': '20px 0'}),
                html.Div([
                    html.H6("Risk Classification Model:", className="mt-2"),
                    html.P("Gradient Boosting Classifier", style={'fontSize': '14px'}),
                    html.P(f"F1-Score: {classification['f1']:.3f}",
                           style={'fontSize': '12px', 'marginBottom': '5px'}),
                    html.P(f"ROC-AUC: {classification['roc_auc']:.3f}",
                           style={'fontSize': '12px'})
                ])
            ])
        ])
    ], width=6)

    # --- Row 2, left: risk factors (or a placeholder when there are none) --
    if result['risk_causes']:
        risk_factor_body = html.Ul(
            [html.Li(html.Strong(cause)) for cause in result['risk_causes']]
        )
    else:
        risk_factor_body = html.P("No specific risk factors identified",
                                  className="text-muted text-center")

    risk_factors_column = dbc.Col([
        html.H4("Risk Factors & Areas for Improvement", className="mb-3"),
        dbc.Card([
            dbc.CardBody([risk_factor_body])
        ])
    ], width=6)

    # --- Row 2, right: recommendations --------------------------------------
    # The <ul> already draws bullets, so strip any textual bullet marker.
    # Both the real bullet (U+2022) and its mis-decoded three-character form
    # are removed, so this works whichever one upstream strings contain.
    bullet_markers = ("\u2022 ", "\u201a\u00c4\u00a2 ")
    recommendation_items = []
    for rec in result['recommendations']:
        for marker in bullet_markers:
            rec = rec.replace(marker, '')
        recommendation_items.append(html.Li(html.Span(rec)))

    recommendations_column = dbc.Col([
        html.H4("Recommendations", className="mb-3"),
        dbc.Card([
            dbc.CardBody([
                html.Ul(recommendation_items)
            ])
        ])
    ], width=6)

    # --- Row 3: risk distribution figure -------------------------------------
    distribution_column = dbc.Col([
        html.H4("Risk Distribution", className="mb-3"),
        dbc.Card([
            dbc.CardBody([
                dcc.Graph(
                    figure=risk_plot if risk_plot else {},
                    style={'height': '300px'}
                ),
                html.P("This shows the predicted risk status distribution for this student.",
                       className="text-muted text-center mt-2")
            ])
        ])
    ], width=12)

    # --- Row 4: processing trail "step1 -> step2 -> ... -> step5" ----------
    step_trail = []
    for step_key in ('step1', 'step2', 'step3', 'step4'):
        step_trail.append(html.Span(steps[step_key], style={'marginRight': '10px'}))
        # "\u2192" is the rightwards arrow separating consecutive steps.
        step_trail.append(
            html.Span("\u2192", style={'marginRight': '10px', 'color': COLOR_SCHEME['primary']})
        )
    # The final step has no trailing arrow and no right margin.
    step_trail.append(html.Span(steps['step5']))

    steps_column = dbc.Col([
        html.H6("Processing Steps Completed:", className="mt-3"),
        html.Div(step_trail,
                 style={'fontSize': '12px', 'backgroundColor': COLOR_SCHEME['light'],
                        'padding': '10px', 'borderRadius': '5px'})
    ])

    # --- Assemble the card; header and border share the risk accent color --
    results_card = dbc.Card([
        dbc.CardHeader("Prediction Results",
                       style={'backgroundColor': accent, 'color': 'white'}),
        dbc.CardBody([
            dbc.Row([score_column, risk_column]),

            html.Hr(),

            dbc.Row([risk_factors_column, recommendations_column]),

            html.Hr(),

            dbc.Row([distribution_column]),

            html.Hr(),

            dbc.Row([steps_column])
        ])
    ], style={'border': f"3px solid {accent}"})

    return html.Div([
        dbc.Row([
            dbc.Col([results_card])
        ], className="mt-4")
    ])

def create_clustering_page():
    """Build the clustering analysis page.

    Reads the module-level ``clustering_analysis`` mapping; when it is
    ``None`` a warning alert is returned instead of the page. From the
    mapping only ``cluster_sizes`` (looked up by the labels ``'High'``,
    ``'Medium'``, ``'Low'``) and ``silhouette_score`` are used here.

    The page contains, in order: the cluster distribution plot with summary
    text, a table of hard-coded per-cluster profile values, the regional
    heatmap, the regional stacked bar chart, the regional risk plot, and a
    hard-coded "Key Cluster Insights" summary. The figures come from the
    sibling ``create_*`` plot builders.

    Returns:
        html.Div: The alert wrapper, or the full page layout.
    """
    if clustering_analysis is None:
        return html.Div([
            dbc.Alert(
                "Clustering analysis data is not available. Please wait for initialization to complete.",
                color="warning",
                className="mt-3"
            )
        ])

    # A missing *or* None entry falls back to an empty Series; the
    # ``.get(label, 0)`` lookups below then report 0 students per cluster.
    cluster_counts = clustering_analysis.get('cluster_sizes')
    if cluster_counts is None:
        cluster_counts = pd.Series(dtype='int64')

    # Static per-cluster profile values (not recomputed from the data).
    # NOTE(review): 'School_Resources_Score' carries exactly the same values
    # as 'School_Academic_Score' - looks like a copy/paste; confirm the
    # intended figures.
    cluster_profile = pd.DataFrame({
        'Performance_Cluster': ['Low', 'Medium', 'High'],
        'Overall_Engagement_Score': [68.026416, 78.301930, 73.043863],
        'School_Academic_Score': [0.424432, 0.445902, 0.695637],
        'Teacher_Student_Ratio': [49.957881, 50.049940, 34.502018],
        'Student_to_Resources_Ratio': [22.629008, 22.578901, 15.891499],
        'Parental_Involvement': [0.301593, 0.484578, 0.365762],
        'Overall_Textbook_Access_Composite': [0.361508, 0.375552, 0.630930],
        'Total_National_Exam_Score': [286.685256, 331.849464, 334.484428],
        'Overall_Average': [47.309911, 54.283330, 62.559605],
        'Overall_Avg_Attendance': [85.879107, 87.396578, 86.733333],
        'Overall_Avg_Homework': [52.551327, 73.066655, 62.319139],
        'Overall_Avg_Participation': [59.697917, 71.411008, 65.515959],
        'School_Resources_Score': [0.424432, 0.445902, 0.695637]
    })

    # Row tints first, label-column rule last: within style_data_conditional
    # later entries take priority, so the cluster label column keeps its dark
    # background under the white text instead of inheriting a light row tint.
    profile_cell_rules = [
        {
            'if': {'row_index': 0},
            'backgroundColor': COLOR_SCHEME['light']
        },
        {
            'if': {'row_index': 1},
            'backgroundColor': '#FFF3CD'
        },
        {
            'if': {'row_index': 2},
            'backgroundColor': '#D1E7DD'
        },
        {
            'if': {'column_id': 'Performance_Cluster'},
            'fontWeight': 'bold',
            'backgroundColor': COLOR_SCHEME['dark'],
            'color': 'white'
        }
    ]

    # --- Page title ----------------------------------------------------------
    title_row = dbc.Row([
        dbc.Col([
            html.H1("Student Clustering Analysis",
                    className="mb-4",
                    style={'color': COLOR_SCHEME['dark'], 'fontWeight': 'bold'}),
            html.P("Grouping students based on academic performance patterns",
                   className="text-muted mb-4")
        ])
    ])

    # --- Cluster distribution plot + textual summary -------------------------
    distribution_row = dbc.Row([
        dbc.Col([
            html.H4("Performance Cluster Distribution", style={'color': COLOR_SCHEME['dark']}),
            dbc.Card([
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            dcc.Graph(
                                figure=create_cluster_distribution_plot(),
                                style={'height': '400px'}
                            )
                        ], width=6),
                        dbc.Col([
                            html.H5("Cluster Analysis", style={'color': COLOR_SCHEME['primary']}),
                            html.P(f"Silhouette Score: {clustering_analysis.get('silhouette_score', 0):.4f}"),
                            html.P("Three distinct student groups identified:"),
                            html.Ul([
                                html.Li("High Performers: Top academic achievement"),
                                html.Li("Medium Performers: Average performance"),
                                html.Li("Low Performers: Require intervention")
                            ]),
                            html.Hr(),
                            html.H6("Cluster Sizes:", style={'color': COLOR_SCHEME['primary']}),
                            html.P(f"High: {cluster_counts.get('High', 0):,} students"),
                            html.P(f"Medium: {cluster_counts.get('Medium', 0):,} students"),
                            html.P(f"Low: {cluster_counts.get('Low', 0):,} students"),
                            html.Hr(),
                            html.H6("Cluster Mapping:", style={'color': COLOR_SCHEME['success']}),
                            # "\u2192" is the rightwards arrow, escaped so it
                            # survives any re-encoding of this source file.
                            html.Ul([
                                html.Li("Cluster 0 \u2192 Low Performance"),
                                html.Li("Cluster 1 \u2192 Medium Performance"),
                                html.Li("Cluster 2 \u2192 High Performance")
                            ])
                        ], width=6)
                    ])
                ])
            ])
        ])
    ], className="mb-4")

    # --- Cluster profile table -----------------------------------------------
    profile_row = dbc.Row([
        dbc.Col([
            html.H4("Complete Cluster Profile Table", style={'color': COLOR_SCHEME['dark']}),
            dbc.Card([
                dbc.CardHeader("All Cluster Characteristics",
                               style={'backgroundColor': COLOR_SCHEME['primary'], 'color': 'white'}),
                dbc.CardBody([
                    html.Div([
                        dash.dash_table.DataTable(
                            id='cluster-profile-table',
                            columns=[{"name": col, "id": col} for col in cluster_profile.columns],
                            data=cluster_profile.to_dict('records'),
                            style_table={
                                'overflowX': 'auto',
                                'height': '160px',
                                'overflowY': 'auto',
                                'maxWidth': '100%',
                            },
                            style_cell={
                                'textAlign': 'center',
                                'padding': '8px',
                                'fontSize': '12px',
                                'minWidth': '120px',
                                'maxWidth': '150px',
                                'whiteSpace': 'normal'
                            },
                            style_header={
                                'backgroundColor': COLOR_SCHEME['primary'],
                                'color': 'white',
                                'fontWeight': 'bold',
                                'textAlign': 'center'
                            },
                            style_data_conditional=profile_cell_rules,
                            fixed_columns={'headers': True, 'data': 1},
                            fixed_rows={'headers': True}
                        )
                    ]),
                    html.P("Note: This table shows all characteristics for each performance cluster based on the provided data.",
                           className="text-muted mt-3")
                ])
            ])
        ])
    ], className="mb-4")

    # --- Regional heatmap ----------------------------------------------------
    heatmap_row = dbc.Row([
        dbc.Col([
            html.H4("Regional Cluster Distribution", style={'color': COLOR_SCHEME['dark']}),
            dbc.Card([
                dbc.CardHeader("Heatmap: Proportion of Students by Performance Level in Each Region",
                               style={'backgroundColor': COLOR_SCHEME['secondary'], 'color': 'white'}),
                dbc.CardBody([
                    dcc.Graph(
                        figure=create_regional_cluster_heatmap(),
                        style={'height': '500px'}
                    ),
                    html.Div([
                        html.H6("Interpretation:", className="mt-3", style={'color': COLOR_SCHEME['primary']}),
                        html.Ul([
                            html.Li("Red indicates higher proportion of Low performers"),
                            html.Li("Yellow indicates higher proportion of Medium performers"),
                            html.Li("Green indicates higher proportion of High performers"),
                            html.Li("Each row sums to 1.0 (100% of students in that region)")
                        ])
                    ])
                ])
            ])
        ])
    ], className="mb-4")

    # --- Regional stacked bar chart ------------------------------------------
    barchart_row = dbc.Row([
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("Stacked Bar Chart: Regional Performance Distribution",
                               style={'backgroundColor': COLOR_SCHEME['warning'], 'color': 'white'}),
                dbc.CardBody([
                    dcc.Graph(
                        figure=create_regional_cluster_barchart(),
                        style={'height': '500px'}
                    )
                ])
            ])
        ])
    ], className="mb-4")

    # --- Regional risk plot --------------------------------------------------
    risk_row = dbc.Row([
        dbc.Col([
            html.H4("Regional Risk Analysis (% Low Performance)", style={'color': COLOR_SCHEME['dark']}),
            dbc.Card([
                dbc.CardBody([
                    dcc.Graph(figure=create_regional_risk_plot())
                ])
            ])
        ])
    ], className="mb-4")

    # --- Key insights: (heading, COLOR_SCHEME key, bullet texts) -------------
    insight_groups = [
        ("High Performers", 'success', [
            "Highest Overall Average: 62.6",
            "Best National Exam Scores: 334.5",
            "Best Textbook Access: 0.63",
            "Best School Resources: 0.70",
            "Lowest Teacher-Student Ratio: 34.5:1",
            "Moderate Engagement Score: 73.0",
        ]),
        ("Medium Performers", 'warning', [
            "Medium Overall Average: 54.3",
            "Good National Exam Scores: 331.8",
            "Highest Engagement Score: 78.3",
            "Highest Homework Completion: 73.1",
            "Highest Parental Involvement: 0.48",
            "Best Attendance: 87.4%",
        ]),
        ("Low Performers", 'danger', [
            "Lowest Overall Average: 47.3",
            "Lowest National Exam Scores: 286.7",
            "Lowest Textbook Access: 0.36",
            "Lowest Homework Completion: 52.6",
            "Lowest Parental Involvement: 0.30",
            "Poorest School Resources: 0.42",
        ]),
    ]
    insight_columns = [
        dbc.Col([
            html.H5(heading, style={'color': COLOR_SCHEME[color_key]}),
            html.Ul([html.Li(text) for text in bullet_texts])
        ], width=4)
        for heading, color_key, bullet_texts in insight_groups
    ]

    insights_row = dbc.Row([
        dbc.Col([
            html.H4("Key Cluster Insights", style={'color': COLOR_SCHEME['dark']}),
            dbc.Card([
                dbc.CardBody([
                    dbc.Row(insight_columns)
                ])
            ])
        ])
    ])

    return html.Div([
        title_row,
        distribution_row,
        profile_row,
        heatmap_row,
        barchart_row,
        risk_row,
        insights_row,
    ])

def create_recommendations_page():
    """Assemble the layout of the "Recommendations & Summary" page.

    Returns an ``html.Div`` made of five stacked rows: the page title, the
    summary ``DataTable`` (rows come from
    ``create_recommendations_summary_table()``), one recommendation card per
    student performance group, the institutional recommendations and a
    closing "next steps" card. Apart from the table rows, all text on the
    page is static.
    """

    summary_rows = create_recommendations_summary_table().to_dict('records')
    dark = COLOR_SCHEME['dark']

    def bullet_list(lines):
        # Every string becomes its own list item, in the order given.
        return html.Ul([html.Li(line) for line in lines])

    def group_card(title, color_key, traits, actions):
        # One third-width column holding a card for a single student group.
        header_style = {'backgroundColor': COLOR_SCHEME[color_key], 'color': 'white'}
        return dbc.Col([
            dbc.Card([
                dbc.CardHeader(title, style=header_style),
                dbc.CardBody([
                    html.H6("Characteristics:", className="mb-3"),
                    bullet_list(traits),
                    html.H6("Recommendations:", className="mt-3"),
                    bullet_list(actions)
                ])
            ])
        ], width=4)

    def strategy_column(heading, color_key, lines):
        # Half-width column: coloured heading followed by a bullet list.
        return dbc.Col([
            html.H5(heading, style={'color': COLOR_SCHEME[color_key]}),
            bullet_list(lines)
        ], width=6)

    # ---- Title ---------------------------------------------------------
    title_row = dbc.Row([
        dbc.Col([
            html.H1(
                "Recommendations & Summary Dashboard",
                className="mb-4",
                style={'color': dark, 'fontWeight': 'bold'}
            ),
            html.P(
                "Actionable insights and recommendations based on data analysis",
                className="text-muted mb-4"
            )
        ])
    ])

    # ---- Summary table -------------------------------------------------
    # Rows whose "Metric" text contains one of these labels get a solid
    # coloured, bold, white-on-colour "Metric" cell.
    section_highlights = (
        ("Dataset Statistics", '#2E86AB'),
        ("Performance Clusters", '#18A999'),
        ("Model Performance", '#A23B72'),
        ("National Exam Model", '#F18F01'),
        ("Regional Risk Analysis", '#C73E1D'),
        ("Key Predictors", '#2C3E50'),
    )
    highlight_rules = [
        {
            'if': {
                'column_id': 'Metric',
                'filter_query': '{Metric} contains "' + label + '"'
            },
            'backgroundColor': fill,
            'color': 'white',
            'fontWeight': 'bold'
        }
        for label, fill in section_highlights
    ]
    zebra_rule = {
        'if': {'row_index': 'odd'},
        'backgroundColor': '#F8F9FA'
    }
    # Rows with an empty "Metric" are rendered as thin, borderless cells.
    blank_row_rule = {
        'if': {'column_id': 'Metric', 'filter_query': '{Metric} = ""'},
        'backgroundColor': 'transparent',
        'border': 'none',
        'height': '10px'
    }

    table_columns = [
        {"name": label, "id": column_id, "type": "text"}
        for label, column_id in (
            ("Metric", "Metric"),
            ("Value", "Value"),
            ("Details / Interpretation", "Details"),
        )
    ]

    summary_table = dash.dash_table.DataTable(
        id='recommendations-summary-table',
        columns=table_columns,
        data=summary_rows,
        style_table={
            'overflowX': 'auto',
            'width': '100%',
            'marginBottom': '20px'
        },
        style_cell={
            'textAlign': 'left',
            'padding': '12px',
            'fontSize': '14px',
            'fontFamily': 'Arial, sans-serif',
            'border': '1px solid #dee2e6'
        },
        style_header={
            'backgroundColor': dark,
            'color': 'white',
            'fontWeight': 'bold',
            'textAlign': 'center',
            'border': '1px solid #dee2e6'
        },
        style_data_conditional=[zebra_rule, *highlight_rules, blank_row_rule],
        style_cell_conditional=[
            {'if': {'column_id': 'Metric'}, 'width': '35%', 'textAlign': 'left', 'paddingLeft': '20px'},
            {'if': {'column_id': 'Value'}, 'width': '25%', 'textAlign': 'center', 'fontWeight': 'bold'},
            {'if': {'column_id': 'Details'}, 'width': '40%', 'textAlign': 'left', 'color': '#555'}
        ]
    )

    summary_row = dbc.Row([
        dbc.Col([
            html.H4("Comprehensive Dashboard Summary Table", style={'color': dark}),
            dbc.Card([
                dbc.CardHeader(
                    "Complete Project Overview",
                    style={'backgroundColor': dark, 'color': 'white', 'fontWeight': 'bold'}
                ),
                dbc.CardBody([
                    html.Div([summary_table]),
                    html.P(
                        "This table provides a complete summary of all key metrics, model performance, and insights from the entire dashboard.",
                        className="text-muted mt-3"
                    )
                ])
            ], style={'border': f"2px solid {dark}", 'marginBottom': '30px'})
        ])
    ], className="mb-4")

    # ---- Per-group recommendations --------------------------------------
    group_columns = [
        group_card(
            "High Performers", 'success',
            traits=[
                "Average Score: 62.6",
                "Best school resources",
                "Good textbook access",
                "Low teacher-student ratio (34:1)",
            ],
            actions=[
                "Enrichment programs for advanced learning",
                "Leadership and mentorship opportunities",
                "Preparation for national competitions",
                "College readiness programs",
                "STEM/STEAM initiatives",
            ],
        ),
        group_card(
            "Medium Performers", 'warning',
            traits=[
                "Average Score: 54.3",
                "Highest engagement levels",
                "Moderate school resources",
                "High teacher-student ratio (50:1)",
            ],
            actions=[
                "Targeted academic support",
                "Study skills workshops",
                "Regular progress monitoring",
                "Peer tutoring programs",
                "Career guidance sessions",
            ],
        ),
        group_card(
            "Low Performers", 'danger',
            traits=[
                "Average Score: 47.3",
                "Lowest school resources",
                "Poor textbook access",
                "Limited parental involvement",
            ],
            actions=[
                "Immediate academic intervention",
                "Small group tutoring",
                "Resource allocation priority",
                "Parent engagement programs",
                "Social-emotional support",
            ],
        ),
    ]

    groups_row = dbc.Row([
        dbc.Col([
            html.H4("Targeted Recommendations for Student Groups", style={'color': dark}),
            dbc.Card([
                dbc.CardBody([
                    dbc.Row(group_columns)
                ])
            ])
        ])
    ], className="mb-4")

    # ---- Institutional recommendations -----------------------------------
    strategy_row = dbc.Row([
        dbc.Col([
            html.H4("Strategic Institutional Recommendations", style={'color': dark}),
            dbc.Card([
                dbc.CardBody([
                    dbc.Row([
                        strategy_column("Resource Allocation:", 'primary', [
                            "Prioritize textbook distribution to under-resourced schools",
                            "Reduce class sizes in high-risk regions",
                            "Improve digital infrastructure and internet access",
                            "Provide teacher training in differentiated instruction",
                        ]),
                        strategy_column("Program Development:", 'secondary', [
                            "Establish parent-teacher collaboration programs",
                            "Implement tiered intervention systems",
                            "Develop early warning systems for at-risk students",
                            "Create recognition programs for high achievers",
                        ]),
                    ])
                ])
            ], style={'border': f"2px solid {COLOR_SCHEME['primary']}"})
        ])
    ], className="mb-4")

    # ---- Next steps / dashboard utility -----------------------------------
    next_steps_column = dbc.Col([
        html.H5("Next Steps:", style={'color': dark}),
        bullet_list([
            "Implement targeted interventions for at-risk students",
            "Allocate resources to high-need regions",
            "Develop teacher training programs",
            "Establish monitoring and evaluation systems",
            "Expand data collection for continuous improvement",
        ])
    ], md=6)

    utility_column = dbc.Col([
        html.Div([
            html.H6("Dashboard Utility:"),
            html.P("This dashboard provides actionable insights for educators, policymakers, and administrators to:"),
            bullet_list([
                "Identify at-risk students early",
                "Allocate resources effectively",
                "Monitor intervention effectiveness",
                "Make data-driven decisions",
                "Improve overall educational outcomes",
            ])
        ], style={
            'backgroundColor': COLOR_SCHEME['light'],
            'padding': '15px',
            'borderRadius': '5px'
        })
    ], md=6)

    closing_row = dbc.Row([
        dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    dbc.Row([next_steps_column, utility_column])
                ])
            ], style={'border': f"2px solid {dark}"})
        ])
    ])

    return html.Div([title_row, summary_row, groups_row, strategy_row, closing_row])

# =================================================================
# CALLBACKS
# =================================================================

@app.callback(
    Output("page-content", "children"),
    [Input("btn-overview", "n_clicks"),
     Input("btn-models", "n_clicks"),
     Input("btn-prediction", "n_clicks"),
     Input("btn-clustering", "n_clicks"),
     Input("btn-recommendations", "n_clicks")],
    prevent_initial_call=False
)
def display_page(overview_clicks, models_clicks, prediction_clicks, clustering_clicks, recommendations_clicks):
    """Render the page chosen via the navigation buttons into "page-content".

    The click counters themselves are not inspected; they only cause the
    callback to fire. Which button fired is read from ``callback_context``.

    Args:
        overview_clicks, models_clicks, prediction_clicks, clustering_clicks,
        recommendations_clicks: ``n_clicks`` of the five navigation buttons.

    Returns:
        The layout built by the matching ``create_*_page()`` function. The
        overview page is used when nothing triggered the callback (this
        callback also runs on load because ``prevent_initial_call=False``) or
        when the trigger is not one of the known buttons. If building a page
        raises, a danger alert describing the error is returned instead.
    """
    ctx = callback_context

    # No trigger information means no button has been pressed yet: treat it
    # exactly like a click on the overview button.
    if ctx.triggered:
        button_id = ctx.triggered[0]['prop_id'].split('.')[0]
    else:
        button_id = "btn-overview"

    # Every page build, including the default overview, runs inside the same
    # handler so a failure always produces the alert below rather than an
    # unhandled callback error.
    try:
        if button_id == "btn-models":
            return create_models_page()
        if button_id == "btn-prediction":
            return create_prediction_page()
        if button_id == "btn-clustering":
            # The clustering page cannot be built without its analysis results.
            if clustering_analysis is None:
                return html.Div([
                    dbc.Alert(
                        "Clustering analysis is still loading or not available. Please try again in a moment.",
                        color="warning",
                        className="mt-3"
                    )
                ])
            return create_clustering_page()
        if button_id == "btn-recommendations":
            return create_recommendations_page()

        # "btn-overview", the initial load and any unrecognised trigger.
        return create_overview_page()
    except Exception as e:
        print(f"Error in display_page: {e}")
        return html.Div([
            dbc.Alert(
                f"Error loading page: {e}",
                color="danger",
                className="mt-3"
            ),
            html.P("Please try refreshing the page or selecting a different section.", className="text-muted")
        ])

@app.callback(
    Output("prediction-results", "children"),
    [Input("predict-button", "n_clicks")],
    [State("gender-input", "value"),
     State("dob-input", "value"),
     State("region-input", "value"),
     State("health-input", "value"),
     State("father-edu-input", "value"),
     State("mother-edu-input", "value"),
     State("parental-input", "value"),
     State("internet-input", "value"),
     State("electricity-input", "value"),
     State("school-type-input", "value"),
     State("location-input", "value"),
     State("ratio-input", "value"),
     State("resources-input", "value"),
     State("academic-input", "value"),
     State("student-resources-input", "value"),
     State("field-input", "value"),
     State("career-input", "value"),
     State("textbook-input", "value"),
     State("attendance-input", "value"),
     State("homework-input", "value"),
     State("participation-input", "value")]
)
def update_prediction(n_clicks, gender, dob, region, health, father_edu, mother_edu,
                     parental, internet, electricity, school_type, location, ratio,
                     resources, academic, student_resources, field, career,
                     textbook, attendance, homework, participation):
    """Run a prediction from the form values when "predict-button" is clicked.

    Each form value that is ``None`` is replaced by a fixed fallback before
    the values are handed, as one dict keyed by feature name, to
    ``make_prediction_corrected``.

    Returns:
        An empty ``html.Div`` while the button has not been clicked, a danger
        alert when ``make_prediction_corrected`` returns ``None``, otherwise
        whatever ``create_prediction_results()`` builds.
    """

    # Nothing to show until the button has actually been pressed.
    if n_clicks in (None, 0):
        return html.Div()

    # (feature name, value from the form, fallback used when the value is None)
    # The order here defines the key order of the dict passed to the model.
    form_fields = (
        ('Gender', gender, 'Male'),
        ('Date_of_Birth', dob, '2005-06-15'),
        ('Region', region, 'Oromia'),
        ('Health_Issue', health, 'No Issue'),
        ('Father_Education', father_edu, 'High School'),
        ('Mother_Education', mother_edu, 'High School'),
        ('Parental_Involvement', parental, 0.5),
        ('Home_Internet_Access', internet, 'No'),
        ('Electricity_Access', electricity, 'No'),
        ('School_Type', school_type, 'Public'),
        ('School_Location', location, 'Rural'),
        ('Teacher_Student_Ratio', ratio, 40),
        ('School_Resources_Score', resources, 0.5),
        ('School_Academic_Score', academic, 0.5),
        ('Student_to_Resources_Ratio', student_resources, 20),
        ('Field_Choice', field, 'Social'),
        ('Career_Interest', career, 'Teacher'),
        ('Overall_Textbook_Access_Composite', textbook, 0.5),
        ('Overall_Avg_Attendance', attendance, 75),
        ('Overall_Avg_Homework', homework, 65),
        ('Overall_Avg_Participation', participation, 70),
    )
    input_data = {
        feature: (fallback if supplied is None else supplied)
        for feature, supplied, fallback in form_fields
    }

    outcome = make_prediction_corrected(input_data)
    if outcome is None:
        return dbc.Alert("Prediction failed. Please check your inputs and ensure all required fields are filled.", color="danger")

    # NOTE(review): the results view takes no arguments, so it presumably reads
    # state stored by make_prediction_corrected - confirm against that helper.
    return create_prediction_results()

@app.callback(
    [Output("gender-input", "value"),
     Output("dob-input", "value"),
     Output("region-input", "value"),
     Output("health-input", "value"),
     Output("father-edu-input", "value"),
     Output("mother-edu-input", "value"),
     Output("parental-input", "value"),
     Output("internet-input", "value"),
     Output("electricity-input", "value"),
     Output("school-type-input", "value"),
     Output("location-input", "value"),
     Output("ratio-input", "value"),
     Output("resources-input", "value"),
     Output("academic-input", "value"),
     Output("student-resources-input", "value"),
     Output("field-input", "value"),
     Output("career-input", "value"),
     Output("textbook-input", "value"),
     Output("attendance-input", "value"),
     Output("homework-input", "value"),
     Output("participation-input", "value")],
    [Input("reset-button", "n_clicks")]
)
def reset_form(n_clicks):
    """Put every prediction-form input back to its default value.

    Returns:
        A list with one value per ``Output`` of this callback, in the same
        order as they are declared above.

    Raises:
        dash.exceptions.PreventUpdate: when ``n_clicks`` is falsy (the reset
            button has not been clicked), so the form is left untouched.
    """
    if not n_clicks:
        raise dash.exceptions.PreventUpdate

    return [
        'Male',          # gender-input
        '2005-06-15',    # dob-input
        'Oromia',        # region-input
        'No Issue',      # health-input
        'High School',   # father-edu-input
        'High School',   # mother-edu-input
        0.5,             # parental-input
        'No',            # internet-input
        'No',            # electricity-input
        'Public',        # school-type-input
        'Rural',         # location-input
        40,              # ratio-input
        0.5,             # resources-input
        0.5,             # academic-input
        20,              # student-resources-input
        'Social',        # field-input
        'Teacher',       # career-input
        0.5,             # textbook-input
        75,              # attendance-input
        65,              # homework-input
        70               # participation-input
    ]

# =================================================================
# RUN THE APP
# =================================================================

if __name__ == "__main__":
    # Print a short start-up banner, then start the Dash server on port 8056
    # in debug mode with hot reload switched off.
    rule = "=" * 60
    banner_lines = (
        "\n" + rule,
        "DASHBOARD STARTING",
        rule,
        "\nDashboard available at: http://localhost:8056",
        rule,
    )
    for banner_line in banner_lines:
        print(banner_line)

    app.run(debug=True, port=8056, dev_tools_hot_reload=False)