# In [None]:
# -*- coding: utf-8 -*-
"""Untitled6.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Rllywmf7_wZO-d5BH4TSvDNrbB2W9qJX
"""

#!/usr/bin/env python3
"""
Healthy Eating Classification Model
Analyzes nutritional data to predict if food items are healthy
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score,
                          roc_curve, precision_recall_curve, auc, f1_score,
                          precision_score, recall_score, accuracy_score)
from sklearn.inspection import permutation_importance, partial_dependence
from sklearn.metrics import accuracy_score
try:
    from imblearn.over_sampling import SMOTE
    SMOTE_AVAILABLE = True
except ImportError:
    SMOTE_AVAILABLE = False
    print("Warning: imbalanced-learn not installed. SMOTE will not be used.")
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: pin the global RNG seeds of TensorFlow and NumPy to one
# shared value before anything random (model init, SMOTE, etc.) runs.
tf.random.set_seed(42)
np.random.seed(42)

class HealthyEatingClassifier:
    def __init__(self):
        # Populated later by the pipeline (model build / preprocess_data()).
        self.model = None
        self.feature_columns = None
        # Feature scaler; fitted on the training split in split_and_scale_data().
        self.scaler = StandardScaler()
        # Per-100g upper limits used by create_healthy_target(): an item is
        # labelled healthy only when it is at or below every one of them.
        self.target_thresholds = dict(
            calories_per_100g=300,       # High calorie threshold
            saturated_fat_per_100g=5,    # High saturated fat threshold
            sugar_per_100g=15,           # High sugar threshold
            sodium_per_100g=400,         # High sodium threshold (mg)
        )

    def load_and_explore_data(self, file_path):
        """Load and perform exploratory data analysis"""
        print("Loading Healthy Eating Dataset...")
        self.df = pd.read_csv(file_path)

        print(f"Dataset shape: {self.df.shape}")
        print(f"\nColumn names in dataset:")
        print(list(self.df.columns))

        print("\nDataset Info:")
        print(self.df.info())
        print("\nFirst few rows:")
        print(self.df.head())
        print("\nDataset Statistics:")
        print(self.df.describe())

        # Check for missing values
        print("\nMissing values:")
        missing_summary = self.df.isnull().sum()
        print(missing_summary[missing_summary > 0] if (missing_summary > 0).any() else "No missing values")

        # Check which nutrient columns exist
        print(f"\n{'='*80}")
        print("NUTRIENT COLUMN CHECK")
        print(f"{'='*80}")
        expected_nutrients = ['calories', 'protein', 'carbs', 'fat', 'fiber', 'sugar', 'sodium', 'saturated_fat']
        expected_per_100g = [f'{n}_per_100g' for n in expected_nutrients]

        print("\nExpected base nutrients:")
        for nutrient in expected_nutrients:
            exists = "✓" if nutrient in self.df.columns else "✗"
            print(f"  {exists} {nutrient}")

        print("\nExpected per_100g columns:")
        for nutrient in expected_per_100g:
            exists = "✓" if nutrient in self.df.columns else "✗"
            print(f"  {exists} {nutrient}")

        # Check if serving_size_g exists
        if 'serving_size_g' in self.df.columns:
            print(f"\n✓ serving_size_g found")
            print(f"  Mean: {self.df['serving_size_g'].mean():.2f}g")
            print(f"  Min: {self.df['serving_size_g'].min():.2f}g")
            print(f"  Max: {self.df['serving_size_g'].max():.2f}g")
        else:
            print(f"\n✗ serving_size_g NOT found - assuming data is already per 100g")

        return self.df

    def create_healthy_target(self):
        """Create or verify is_healthy target variable"""
        print("Checking healthy eating target variable...")

        # Check if target already exists in dataset
        if 'is_healthy' in self.df.columns:
            print("✓ Target variable 'is_healthy' already exists in dataset - using it directly!")
            print("  (Not creating new target from thresholds)")
        else:
            print("Creating new target variable from nutritional thresholds...")

            # Verify required columns exist
            required_cols = ['calories_per_100g', 'saturated_fat_per_100g', 'sugar_per_100g', 'sodium_per_100g']
            missing = [col for col in required_cols if col not in self.df.columns]
            if missing:
                print(f"ERROR: Missing required columns: {missing}")
                print(f"Available columns: {list(self.df.columns)}")
                raise KeyError(f"Missing required columns for target creation: {missing}")

            # Define healthy criteria (lower is better for these nutrients)
            healthy_criteria = (
                (self.df['calories_per_100g'] <= self.target_thresholds['calories_per_100g']) &
                (self.df['saturated_fat_per_100g'] <= self.target_thresholds['saturated_fat_per_100g']) &
                (self.df['sugar_per_100g'] <= self.target_thresholds['sugar_per_100g']) &
                (self.df['sodium_per_100g'] <= self.target_thresholds['sodium_per_100g'])
            )

            self.df['is_healthy'] = healthy_criteria.astype(int)

        print(f"\nTarget Distribution:")
        print(f"  Healthy food items: {self.df['is_healthy'].sum()}")
        print(f"  Unhealthy food items: {(~self.df['is_healthy'].astype(bool)).sum()}")
        print(f"  Healthy percentage: {self.df['is_healthy'].mean():.2%}")

        # Show sample of each class to verify target makes sense
        if 'meal_name' in self.df.columns:
            print(f"\n{'='*80}")
            print("SAMPLE UNHEALTHY FOODS (first 3):")
            print(f"{'='*80}")
            display_cols = ['meal_name', 'calories_per_100g', 'protein_per_100g', 'carbs_per_100g', 'fat_per_100g']
            available_cols = [c for c in display_cols if c in self.df.columns]
            unhealthy_samples = self.df[self.df['is_healthy'] == 0][available_cols].head(3)
            print(unhealthy_samples.to_string())

            print(f"\n{'='*80}")
            print("SAMPLE HEALTHY FOODS (first 3):")
            print(f"{'='*80}")
            healthy_samples = self.df[self.df['is_healthy'] == 1][available_cols].head(3)
            print(healthy_samples.to_string())

        return self.df

    def normalize_units(self):
        """Ensure all nutrients are per 100g consistently"""
        print("Normalizing units to per 100g...")

        # Map YOUR dataset's column names to expected format
        column_mapping = {
            'protein_g': 'protein',
            'carbs_g': 'carbs',
            'fat_g': 'fat',
            'fiber_g': 'fiber',
            'sugar_g': 'sugar',
            'sodium_mg': 'sodium',
            # Note: Your dataset doesn't have saturated_fat, we'll handle this below
        }

        # Rename columns if they exist
        for old_name, new_name in column_mapping.items():
            if old_name in self.df.columns:
                self.df[new_name] = self.df[old_name]
                print(f"  Mapped {old_name} → {new_name}")

        nutrient_cols = ['calories', 'protein', 'carbs', 'fat', 'fiber', 'sugar', 'sodium', 'saturated_fat']

        # First, check if *_per_100g columns already exist in the dataset
        per_100g_exists = any(f'{col}_per_100g' in self.df.columns for col in nutrient_cols)

        if per_100g_exists:
            # Dataset already has per_100g columns - just ensure all exist
            print("Dataset already contains per_100g columns")
            for col in nutrient_cols:
                per_col = f'{col}_per_100g'
                if per_col not in self.df.columns:
                    # Try to find the base column
                    if col in self.df.columns:
                        self.df[per_col] = self.df[col].fillna(0)
                    else:
                        print(f"Warning: Missing both {col} and {per_col}, creating with zeros")
                        self.df[per_col] = 0.0
                else:
                    # Fill any NaNs in existing per_100g columns
                    self.df[per_col] = self.df[per_col].fillna(0)

        elif 'serving_size_g' in self.df.columns:
            # Has serving size, need to normalize
            print("Normalizing from serving size to per 100g")
            serving = self.df['serving_size_g'].replace(0, np.nan)

            for col in nutrient_cols:
                per_col = f'{col}_per_100g'
                if col in self.df.columns:
                    # For sodium, convert mg to mg (already correct unit)
                    self.df[per_col] = (self.df[col] / serving * 100).fillna(0)
                    print(f"  Created {per_col} from {col}")
                else:
                    # Column doesn't exist - check if it's saturated_fat
                    if col == 'saturated_fat':
                        print(f"  Warning: {col} not in dataset, setting to 0")
                        self.df[per_col] = 0.0
                    else:
                        print(f"  ERROR: {col} not found in dataset!")
                        self.df[per_col] = 0.0

        else:
            # No per_100g columns and no serving_size - assume raw columns are per 100g
            print("Assuming raw columns are per 100g")
            for col in nutrient_cols:
                per_col = f'{col}_per_100g'
                if col in self.df.columns:
                    self.df[per_col] = self.df[col].fillna(0)
                else:
                    print(f"Warning: Missing {col}, creating {per_col} with zeros")
                    self.df[per_col] = 0.0

        # Verify all required columns exist
        missing_cols = [f'{col}_per_100g' for col in nutrient_cols if f'{col}_per_100g' not in self.df.columns]
        if missing_cols:
            print(f"ERROR: Still missing columns after normalization: {missing_cols}")
            raise ValueError(f"Failed to create required columns: {missing_cols}")

        # Show summary of created columns with sample values
        print("\nUnit normalization complete - Summary:")
        for col in nutrient_cols:
            per_col = f'{col}_per_100g'
            if per_col in self.df.columns:
                non_zero = (self.df[per_col] != 0).sum()
                mean_val = self.df[per_col].mean()
                print(f"  {per_col}: {non_zero}/{len(self.df)} non-zero, mean={mean_val:.2f}")

        return self.df

    def engineer_features(self):
        """Create meaningful derived features WITHOUT leaking the target"""
        print("Engineering features...")

        # Core nutrients (ensure they exist)
        core_nutrients = ['calories_per_100g', 'protein_per_100g', 'carbs_per_100g',
                         'fat_per_100g', 'fiber_per_100g', 'sugar_per_100g', 'sodium_per_100g', 'saturated_fat_per_100g']

        # Add missing columns with zeros if they don't exist
        for nutrient in core_nutrients:
            if nutrient not in self.df.columns:
                base_name = nutrient.replace('_per_100g', '')
                if base_name in self.df.columns:
                    self.df[nutrient] = self.df[base_name].fillna(0)
                else:
                    self.df[nutrient] = 0

        # Safe derived features that don't directly leak
        calories = self.df['calories_per_100g'].replace(0, 1)
        protein = self.df['protein_per_100g']
        carbs = self.df['carbs_per_100g']
        fat = self.df['fat_per_100g']
        fiber = self.df['fiber_per_100g'].replace(0, 1)
        sugar = self.df['sugar_per_100g']
        sodium = self.df['sodium_per_100g']
        sat_fat = self.df['saturated_fat_per_100g']

        # SAFE nutrient densities (only use non-leaking nutrients)
        self.df['protein_density'] = protein / calories
        self.df['fiber_density'] = fiber / calories
        # DO NOT USE: sugar_density, sodium_density (leak the target thresholds)

        # SAFE nutrient ratios (relationships between safe nutrients)
        self.df['protein_to_fat_ratio'] = protein / fat.replace(0, 1)
        self.df['protein_to_carb_ratio'] = protein / carbs.replace(0, 1)
        self.df['fiber_to_carb_ratio'] = fiber / carbs.replace(0, 1)
        # DO NOT USE: sat_fat_ratio (leaks saturated fat threshold)
        self.df['carb_fiber_ratio'] = carbs / fiber
        self.df['fat_to_carb_ratio'] = fat / carbs.replace(0, 1)

        # Macronutrient balance features (SAFE - relative proportions)
        total_macro = protein + carbs + fat
        self.df['protein_pct'] = protein / total_macro.replace(0, 1)
        self.df['carb_pct'] = carbs / total_macro.replace(0, 1)
        self.df['fat_pct'] = fat / total_macro.replace(0, 1)

        # SAFE quality indicators (only positive nutrients)
        self.df['nutrient_density_score'] = (protein + fiber) / calories
        self.df['protein_fiber_ratio'] = protein / fiber
        # DO NOT USE: empty_calorie_ratio, sodium_to_protein_ratio (leak thresholds)

        # SAFE interaction features (only non-leaking nutrients)
        self.df['protein_fiber_product'] = protein * fiber
        self.df['protein_carb_product'] = protein * carbs
        # DO NOT USE: sugar_fat_product (leaks sugar threshold)

        # Categorical features (one-hot encoding)
        categorical_cols = ['cuisine', 'type', 'meal_category', 'ingredients']
        for col in categorical_cols:
            if col in self.df.columns:
                # Limit to top categories to avoid too many features
                top_categories = self.df[col].value_counts().head(10).index
                for category in top_categories:
                    self.df[f'{col}_{category}'] = (self.df[col] == category).astype(int)

        print(f"Total features after engineering: {self.df.shape[1]}")
        return self.df

    def clean_data(self):
        """Handle missing values and outliers"""
        print("Cleaning data...")

        # Handle missing values
        for col in self.df.columns:
            if self.df[col].dtype in ['int64', 'float64']:
                # Fill numerical columns with median
                self.df[col] = self.df[col].fillna(self.df[col].median())
            else:
                # Fill categorical columns with 'Unknown'
                self.df[col] = self.df[col].fillna('Unknown')

        # Handle outliers by clipping to 1st-99th percentiles
        numerical_cols = self.df.select_dtypes(include=[np.number]).columns
        for col in numerical_cols:
            if col != 'is_healthy':  # Don't clip the target
                lower_bound = self.df[col].quantile(0.01)
                upper_bound = self.df[col].quantile(0.99)
                self.df[col] = self.df[col].clip(lower_bound, upper_bound)

        print("Data cleaning completed!")
        return self.df

    def visualize_data(self):
        """Plot class balance, per-class nutrient boxplots and a correlation heatmap.

        Saves ``healthy_eating_analysis.png`` and ``correlation_matrix_healthy.png``
        (300 dpi) to the working directory and shows each figure.
        """
        plt.style.use('default')
        df = self.df

        # 1. Overview grid: target pie chart + nutrient boxplots split by class
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))

        df['is_healthy'].value_counts().plot(kind='pie', ax=axes[0, 0], autopct='%1.1f%%')
        axes[0, 0].set_title('Healthy vs Unhealthy Distribution')

        boxplot_specs = (
            ('calories_per_100g', 'Calories', axes[0, 1]),
            ('sugar_per_100g', 'Sugar', axes[0, 2]),
            ('protein_per_100g', 'Protein', axes[1, 0]),
            ('fiber_per_100g', 'Fiber', axes[1, 1]),
            ('sodium_per_100g', 'Sodium', axes[1, 2]),
        )
        for column, label, ax in boxplot_specs:
            df.boxplot(column=column, by='is_healthy', ax=ax)
            ax.set_title(f'{label} by Health Status')

        # DataFrame.boxplot(by=...) replaces the figure suptitle with
        # "Boxplot grouped by is_healthy", so set ours after all boxplots.
        fig.suptitle('Healthy Eating Dataset Analysis', fontsize=16, fontweight='bold')
        fig.tight_layout(rect=(0, 0, 1, 0.96))  # leave room for the suptitle
        fig.savefig('healthy_eating_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)  # release the figure so open figures don't accumulate

        # 2. Correlation heatmap over every numeric column
        corr_fig = plt.figure(figsize=(12, 8))
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        correlation_matrix = df[numeric_cols].corr()
        sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', center=0, square=True)
        plt.title('Feature Correlation Matrix')
        plt.tight_layout()
        plt.savefig('correlation_matrix_healthy.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(corr_fig)

    def preprocess_data(self):
        """Preprocess data for deep learning model"""
        print("Preprocessing data...")

        # ----------------- 1) SAFE features that DON'T leak the target -----------------
        feature_cols = [
            # Base nutrients (NOT used to define target)
            'protein_per_100g', 'carbs_per_100g', 'fat_per_100g', 'fiber_per_100g',

            # SAFE nutrient densities (only non-leaking)
            'protein_density', 'fiber_density',

            # SAFE nutrient ratios (only non-leaking nutrients)
            'protein_to_fat_ratio', 'protein_to_carb_ratio', 'fiber_to_carb_ratio',
            'carb_fiber_ratio', 'fat_to_carb_ratio',

            # Macronutrient percentages (SAFE - relative proportions)
            'protein_pct', 'carb_pct', 'fat_pct',

            # SAFE quality indicators
            'nutrient_density_score', 'protein_fiber_ratio',

            # SAFE interaction features
            'protein_fiber_product', 'protein_carb_product',
        ]

        # ----------------- 2) BAN: ALL features that directly or indirectly leak -----------------
        BAN = {
            # Direct threshold features
            'calories_per_100g',        # Used in target definition
            'saturated_fat_per_100g',   # Used in target definition
            'sugar_per_100g',           # Used in target definition
            'sodium_per_100g',          # Used in target definition

            # Indirect leakage through ratios/densities
            'sugar_density',            # sugar/calories (leaks sugar)
            'sodium_density',           # sodium/calories (leaks sodium)
            'sat_fat_ratio',            # sat_fat/fat (leaks sat_fat)
            'empty_calorie_ratio',      # sugar/calories (same as sugar_density)
            'sodium_to_protein_ratio',  # sodium/protein (leaks sodium)
            'sugar_fat_product',        # sugar*fat (leaks sugar)

            # Derived features based on leaking nutrients
            'energy_density',           # == calories_per_100g
            'has_added_sugar',          # Based on sugar threshold
            'is_ultra_processed',       # Based on calorie threshold

            # The target itself
            'is_healthy',
        }

        # ----------------- 3) Add categorical features -----------------
        categorical_prefixes = ['cuisine_', 'type_', 'meal_category_', 'ingredients_']
        for prefix in categorical_prefixes:
            matching_cols = [col for col in self.df.columns if col.startswith(prefix)]
            feature_cols.extend(matching_cols)

        # ----------------- 4) Filter and validate -----------------
        # Keep only existing columns
        feature_cols = [c for c in feature_cols if c in self.df.columns]
        # Remove any banned features
        feature_cols = [c for c in feature_cols if c not in BAN]

        # Save for later use
        self.feature_columns = feature_cols

        print(f"\n{'='*80}")
        print(f"FEATURE SELECTION SUMMARY")
        print(f"{'='*80}")
        print(f"Total features selected: {len(self.feature_columns)}")
        print(f"\nAll features being used:")
        for i, feat in enumerate(self.feature_columns, 1):
            print(f"  {i}. {feat}")

        # Check for any banned features that snuck through
        banned_in_use = [f for f in self.feature_columns if f in BAN]
        if banned_in_use:
            print(f"\n⚠️  WARNING: BANNED FEATURES DETECTED IN USE: {banned_in_use}")
            raise ValueError(f"Banned features found in feature list: {banned_in_use}")

        # ----------------- 5) Build X, y with final checks -----------------
        X = self.df[self.feature_columns].copy()
        y = self.df['is_healthy'].copy()

        # ----------------- 6) DATA QUALITY CHECK -----------------
        print(f"\n{'='*80}")
        print(f"DATA QUALITY CHECK")
        print(f"{'='*80}")

        # Check for constant/zero-variance features
        constant_features = []
        for col in X.columns:
            if X[col].nunique() <= 1:
                constant_features.append(col)
                print(f"⚠️  WARNING: {col} has constant values (variance=0)!")

        if constant_features:
            print(f"\n🚨 CRITICAL ERROR: {len(constant_features)} features have no variance!")
            print("This means all values are the same (likely all zeros).")
            print("The dataset or preprocessing has a serious problem!")
            print("\nConstant features detected:")
            for feat in constant_features:
                unique_vals = X[feat].unique()
                print(f"  - {feat}: only has value(s) {unique_vals}")

        # Check basic statistics
        print(f"\nFeature Statistics:")
        print(f"Total features: {len(X.columns)}")
        print(f"Features with variance: {len([c for c in X.columns if X[c].nunique() > 1])}")
        print(f"Features with constant values: {len(constant_features)}")

        # Show sample of feature values
        print(f"\nSample feature values (first 5 rows):")
        print(X.head().to_string())

        # ----------------- 7) LEAK DETECTION: Check correlations with target -----------------
        print(f"\n{'='*80}")
        print(f"CORRELATION ANALYSIS WITH TARGET")
        print(f"{'='*80}")
        correlations = []
        for col in self.feature_columns:
            if X[col].dtype in ['int64', 'float64']:
                if X[col].nunique() > 1:  # Only compute if not constant
                    corr = abs(X[col].corr(y))
                    if not pd.isna(corr):
                        correlations.append((col, corr))
                else:
                    correlations.append((col, 0.0))

        # Sort by correlation
        correlations.sort(key=lambda x: x[1], reverse=True)

        print("\nTop 10 features most correlated with target:")
        for i, (feat, corr) in enumerate(correlations[:10], 1):
            if pd.isna(corr) or corr == 0.0:
                print(f"  {i}. {feat}: NO VARIANCE (constant values)")
            else:
                warning = " ⚠️  SUSPICIOUSLY HIGH!" if corr > 0.9 else ""
                print(f"  {i}. {feat}: {corr:.4f}{warning}")

        # Check for extremely high correlations (potential leakage)
        high_corr_features = [feat for feat, corr in correlations if not pd.isna(corr) and corr > 0.9]
        if high_corr_features:
            print(f"\n⚠️  WARNING: Features with correlation > 0.9 detected!")
            print(f"These features may be leaking target information:")
            for feat in high_corr_features:
                corr_val = next(c for f, c in correlations if f == feat)
                print(f"  - {feat}: {corr_val:.4f}")
            print("\nConsider removing these features or investigating the data!")

        # Critical check: if most features are constant, stop
        if len(constant_features) > len(X.columns) * 0.5:
            raise ValueError(f"CRITICAL: More than 50% of features have constant values! "
                           f"Dataset preprocessing has failed. Check the normalize_units() "
                           f"and engineer_features() methods.")

        # Final validation
        if X.isnull().any().any():
            print("Warning: Features contain NaN values, filling with median...")
            X = X.fillna(X.median())

        if y.isnull().any():
            print("Warning: Target contains NaN values, filling with mode...")
            y = y.fillna(y.mode()[0])

        print(f"Feature matrix shape: {X.shape}")
        print(f"Target shape: {y.shape}")
        print(f"Features used (after BAN): {len(self.feature_columns)}")
        # Optional: print a few banned features found in df (for sanity)
        banned_present = [c for c in BAN if c in self.df.columns]
        if banned_present:
            print(f"Banned features detected (excluded): {banned_present}")

        return X, y


    def split_and_scale_data(self, X, y, use_smote=True):
        """Split into stratified train/val/test sets and standardize features.

        The split is roughly 70/15/15 (fixed ``random_state=42``). ``self.scaler``
        is fitted on the real (pre-SMOTE) training rows only and applied to all
        three splits. When ``use_smote`` is true and imbalanced-learn is
        installed, SMOTE oversampling is then applied to the scaled training
        set only; validation and test data are never resampled.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)`` where the
            feature sets are scaled arrays.
        """
        print("Splitting and scaling data...")

        # Stratified split: 70% train, 15% val, 15% test
        X_temp, X_test, y_temp, y_test = train_test_split(
            X, y, test_size=0.15, random_state=42, stratify=y
        )

        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp, test_size=0.176, random_state=42, stratify=y_temp  # 0.176 * 0.85 ≈ 0.15
        )

        print(f"Original training set: {X_train.shape}")
        print(f"Original class distribution - Train: {np.bincount(y_train)}")

        # Scale BEFORE SMOTE. SMOTE picks k nearest neighbours by Euclidean
        # distance, so on unscaled data large-magnitude features (e.g. nutrient
        # products) dominate neighbour selection. Fitting on real rows also keeps
        # synthetic samples out of the scaler statistics.
        X_train_scaled = self.scaler.fit_transform(X_train)
        X_val_scaled = self.scaler.transform(X_val)
        X_test_scaled = self.scaler.transform(X_test)

        # Apply SMOTE to training data only (not validation or test)
        if use_smote and SMOTE_AVAILABLE:
            print("\n⚠️  WARNING: SMOTE is enabled but might cause overfitting!")
            print("If you're getting unrealistically high accuracy (>95%), try disabling SMOTE.")
            print("To disable: change use_smote=False in split_and_scale_data() call\n")

            print("Applying SMOTE to balance training data...")
            smote = SMOTE(random_state=42, k_neighbors=5)
            try:
                X_train_scaled, y_train = smote.fit_resample(X_train_scaled, y_train)
                print(f"After SMOTE - Training set: {X_train_scaled.shape}")
                print(f"After SMOTE - Class distribution: {np.bincount(y_train)}")
            except Exception as e:
                # Typically too few minority samples for k_neighbors; best-effort.
                print(f"SMOTE failed: {e}. Continuing without SMOTE.")
        else:
            print("\nSMOTE is disabled - using original imbalanced data")

        print(f"\nFinal training set: {X_train_scaled.shape}")
        print(f"Validation set: {X_val_scaled.shape}")
        print(f"Test set: {X_test_scaled.shape}")
        print(f"Final class distribution - Train: {np.bincount(y_train)}")
        print(f"Class distribution - Val: {np.bincount(y_val)}")
        print(f"Class distribution - Test: {np.bincount(y_test)}")

        return (X_train_scaled, X_val_scaled, X_test_scaled,
                y_train, y_val, y_test)

    def build_model(self, input_dim, y_train=None):
        """Build and compile the feed-forward binary classifier.

        Args:
            input_dim: Number of input features.
            y_train: Optional labels that will actually be passed to ``fit``;
                class weights are computed from them. When omitted, weights
                come from ``self.df['is_healthy']`` (the whole dataset, the
                previous behaviour). Pass the post-SMOTE labels when SMOTE is
                used: they are already balanced, so the weights come out ~1.0
                instead of correcting the imbalance a second time.

        Returns:
            ``(model, class_weights)``. ``class_weights`` maps 0 (unhealthy)
            and 1 (healthy) to ``n_samples / (2 * class_count)``, or 1.0 for a
            class with no samples.
        """
        print("Building healthy eating classification model...")

        # Calculate class weights for imbalance
        labels = self.df['is_healthy'] if y_train is None else y_train
        # minlength=2 keeps indexing safe when one class is absent.
        class_counts = np.bincount(np.asarray(labels).astype(int), minlength=2)
        total_samples = class_counts.sum()
        class_weights = {
            cls: float(total_samples / (2 * count)) if count else 1.0
            for cls, count in enumerate(class_counts[:2])
        }
        print(f"Class weights: {class_weights}")
        print(f"Class distribution: Unhealthy={class_counts[0]}, Healthy={class_counts[1]}")

        # Balanced architecture - deep enough but not overfit.
        # keras.Input replaces the deprecated `input_shape=` argument on Dense.
        model = keras.Sequential([
            keras.Input(shape=(input_dim,)),

            # First hidden layer with batch normalization
            layers.Dense(64),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.Dropout(0.4),

            # Second hidden layer
            layers.Dense(32, kernel_regularizer=regularizers.l2(1e-3)),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.Dropout(0.4),

            # Third hidden layer
            layers.Dense(16, kernel_regularizer=regularizers.l2(1e-3)),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.Dropout(0.3),

            # Output layer: probability of the "healthy" class
            layers.Dense(1, activation='sigmoid')
        ])

        # Compile with a lowered learning rate
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=5e-4),
            loss='binary_crossentropy',
            metrics=['accuracy', keras.metrics.AUC(name='auc'),
                     keras.metrics.Precision(name='precision'),
                     keras.metrics.Recall(name='recall')]
        )

        print("Model Architecture:")
        model.summary()

        return model, class_weights

    def train_model(self, model, X_train, y_train, X_val, y_val, class_weights,
                    epochs=150, batch_size=16, checkpoint_path='best_model_temp.keras'):
        """Fit *model* with early stopping, LR decay and best-model checkpointing.

        Args:
            model: Compiled Keras model. It must report an ``auc`` metric,
                because the callbacks monitor ``val_auc``.
            X_train, y_train: Training data.
            X_val, y_val: Validation data used by all callbacks.
            class_weights: Mapping passed to ``fit(class_weight=...)``.
            epochs: Maximum number of epochs.
            batch_size: Mini-batch size.
            checkpoint_path: Where the best model (by ``val_auc``) is saved.
                Defaults to the native ``.keras`` format; Keras 3 expects that
                extension for full-model checkpoints, whereas ``.h5`` is the
                legacy HDF5 format.

        Returns:
            The Keras ``History`` object. Because of ``restore_best_weights``,
            *model* ends up holding the best-``val_auc`` weights.
        """
        print("Training model...")

        callbacks = [
            keras.callbacks.EarlyStopping(
                monitor='val_auc',
                patience=20,
                restore_best_weights=True,
                mode='max',
                verbose=1
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=10,
                min_lr=1e-7,
                verbose=1
            ),
            keras.callbacks.ModelCheckpoint(
                checkpoint_path,
                monitor='val_auc',
                save_best_only=True,
                mode='max',
                verbose=0
            )
        ]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            class_weight=class_weights,
            callbacks=callbacks,
            verbose=1
        )

        return history

    def evaluate_model(self, model, X_test, y_test, threshold=None, X_val=None, y_val=None):
        """Evaluate *model* on the test split, print metrics and save plots.

        Args:
            model: Fitted model whose ``predict`` output is flattened to one
                score per row and compared against ``threshold``.
            X_test, y_test: Test features and labels.
            threshold: Decision cut-off on the predicted score. When ``None``
                it is chosen to maximize F1 — on ``(X_val, y_val)`` if both are
                given, otherwise on the test set itself. The test-set fallback
                is the legacy behaviour and biases the reported test metrics
                upward.
            X_val, y_val: Optional held-out data used only for threshold tuning.

        Returns:
            ``(y_pred, y_pred_proba, threshold)`` for the test set.
        """
        print("Evaluating model...")

        y_pred_proba = model.predict(X_test).flatten()

        if threshold is None:
            if X_val is not None and y_val is not None:
                tune_true, tune_proba = y_val, model.predict(X_val).flatten()
            else:
                tune_true, tune_proba = y_test, y_pred_proba
            threshold = self._best_f1_threshold(tune_true, tune_proba)
            print(f"Optimal threshold: {threshold:.4f}")

        y_pred = (y_pred_proba >= threshold).astype(int)

        # Accuracy at the chosen threshold; AUC from the raw probabilities
        test_accuracy = accuracy_score(y_test, y_pred)
        fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
        test_auc = auc(fpr, tpr)

        # Macro and weighted averages; zero_division=0 returns 0.0 (the default
        # result) without a warning when a class is never predicted.
        precision_macro = precision_score(y_test, y_pred, average='macro', zero_division=0)
        precision_weighted = precision_score(y_test, y_pred, average='weighted', zero_division=0)
        recall_macro = recall_score(y_test, y_pred, average='macro', zero_division=0)
        recall_weighted = recall_score(y_test, y_pred, average='weighted', zero_division=0)
        f1_macro = f1_score(y_test, y_pred, average='macro', zero_division=0)
        f1_weighted = f1_score(y_test, y_pred, average='weighted', zero_division=0)

        print("\nTest Results:")
        print(f"Accuracy: {test_accuracy:.4f}")
        print(f"AUC: {test_auc:.4f}")
        print(f"Precision (Macro): {precision_macro:.4f}")
        print(f"Precision (Weighted): {precision_weighted:.4f}")
        print(f"Recall (Macro): {recall_macro:.4f}")
        print(f"Recall (Weighted): {recall_weighted:.4f}")
        print(f"F1-Score (Macro): {f1_macro:.4f}")
        print(f"F1-Score (Weighted): {f1_weighted:.4f}")

        print("\nDetailed Classification Report:")
        print(classification_report(y_test, y_pred, target_names=['Unhealthy', 'Healthy']))

        # Confusion matrix
        cm = confusion_matrix(y_test, y_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Unhealthy', 'Healthy'],
                    yticklabels=['Unhealthy', 'Healthy'])
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig('confusion_matrix_healthy.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()

        # Summary plot and table (macro-averaged metrics)
        self.plot_metrics_summary(test_accuracy, precision_macro, recall_macro, f1_macro)
        self.create_metrics_table(test_accuracy, precision_macro, recall_macro, f1_macro)

        return y_pred, y_pred_proba, threshold

    @staticmethod
    def _best_f1_threshold(y_true, y_proba):
        """Return the score cut-off that maximizes F1 for (y_true, y_proba); 0.5 if none exists."""
        precision, recall, thresholds = precision_recall_curve(y_true, y_proba)
        # precision/recall carry one extra final point (precision=1, recall=0)
        # that has no threshold; drop it so the arrays line up with thresholds.
        p, r = precision[:-1], recall[:-1]
        f1_scores = 2 * (p * r) / (p + r + 1e-8)
        if f1_scores.size == 0:
            return 0.5
        return float(thresholds[np.argmax(f1_scores)])

    def plot_metrics_summary(self, accuracy, precision, recall, f1):
        """Render the four headline scores as an annotated bar chart.

        Each bar is labelled with its value to three decimals. The y-axis is
        fixed to [0, 1]. The figure is saved to ``metrics_summary_healthy.png``
        and then shown.
        """
        scores = {
            'Accuracy': accuracy,
            'Precision': precision,
            'Recall': recall,
            'F1-Score': f1,
        }
        palette = ['skyblue', 'lightgreen', 'lightcoral', 'gold']

        plt.figure(figsize=(10, 6))
        bar_container = plt.bar(list(scores.keys()), list(scores.values()), color=palette)
        plt.title('Model Performance Metrics Summary', fontsize=14, fontweight='bold')
        plt.ylabel('Score', fontsize=12)
        plt.ylim(0, 1)

        # Put the numeric value just above the top of each bar.
        for rect, score in zip(bar_container, scores.values()):
            x_center = rect.get_x() + rect.get_width() / 2
            y_top = rect.get_height() + 0.01
            plt.text(x_center, y_top, f'{score:.3f}',
                     ha='center', va='bottom', fontweight='bold')

        plt.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        plt.savefig('metrics_summary_healthy.png', dpi=300, bbox_inches='tight')
        plt.show()

    def create_metrics_table(self, accuracy, precision, recall, f1):
        """Print, draw and persist a table of the headline scores.

        Scores are formatted as strings with four decimals. Side effects:
        prints the table, saves ``metrics_table_healthy.png`` and writes
        ``model_metrics_healthy.csv``.

        Returns:
            pandas.DataFrame with string columns ``Metric`` and ``Score``.
        """
        banner = "=" * 80
        print("\n" + banner)
        print("COMPREHENSIVE MODEL PERFORMANCE METRICS TABLE")
        print(banner)

        names = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
        formatted = [f"{value:.4f}" for value in (accuracy, precision, recall, f1)]
        metrics_df = pd.DataFrame({'Metric': names, 'Score': formatted})

        print("\nModel Performance Metrics:")
        print(metrics_df.to_string(index=False))

        # Matplotlib rendering of the same rows.
        _fig, ax = plt.subplots(figsize=(8, 6))
        ax.axis('tight')
        ax.axis('off')

        headers = ['Metric', 'Score']
        rows = [[name, score] for name, score in zip(names, formatted)]
        table = ax.table(cellText=rows, colLabels=headers,
                         cellLoc='center', loc='center')

        table.auto_set_font_size(False)
        table.set_fontsize(12)
        table.scale(1.2, 2)

        # Row 0 is the header row; rows 1..len(rows) hold the data.
        for row_idx in range(len(rows) + 1):
            is_header = row_idx == 0
            for col_idx in range(len(headers)):
                cell = table[row_idx, col_idx]
                if is_header:
                    cell.set_facecolor('#4CAF50')
                    cell.set_text_props(weight='bold', color='white')
                else:
                    cell.set_facecolor('#F5F5F5')

        plt.title('Model Performance Metrics Table', fontsize=16, fontweight='bold', pad=20)
        plt.tight_layout()
        plt.savefig('metrics_table_healthy.png', dpi=300, bbox_inches='tight')
        plt.show()

        metrics_df.to_csv('model_metrics_healthy.csv', index=False)
        print("\nMetrics table saved to 'model_metrics_healthy.csv'")

        return metrics_df

    def _pick_threshold_on_validation(self, y_val, val_proba, mode="macro_f1"):
        """
        Choose a probability threshold using the validation set that optimizes
        a metric considering BOTH classes. Supported modes: 'macro_f1', 'balanced_acc'.
        """
        # Candidate thresholds: all unique probs plus 0 and 1
        thr_candidates = np.r_[0.0, np.sort(np.unique(val_proba)), 1.0]

        best_thr, best_score = 0.5, -1.0
        for t in thr_candidates:
            y_hat = (val_proba >= t).astype(int)
            if mode == "macro_f1":
                score = f1_score(y_val, y_hat, average="macro", zero_division=0)
            elif mode == "balanced_acc":
                # balanced accuracy = avg(sensitivity, specificity)
                tn, fp, fn, tp = confusion_matrix(y_val, y_hat).ravel()
                tpr = tp / (tp + fn + 1e-12)
                tnr = tn / (tn + fp + 1e-12)
                score = 0.5 * (tpr + tnr)
            else:
                raise ValueError("Unsupported mode for threshold picking.")

            if score > best_score:
                best_score, best_thr = score, t

        print(f"[Threshold search] mode={mode} best={best_score:.4f} @ thr={best_thr:.4f}")
        return best_thr


    def create_validation_test_comparison(self, model, X_val, y_val, X_test, y_test, threshold):
        """Compare validation and test performance at one decision threshold.

        Both sets are scored with ``model.predict`` and labelled with the same
        ``threshold``. Accuracy and macro precision/recall/F1 are computed for
        each from those labels. The comparison is printed, plotted to
        ``validation_test_comparison_healthy.png`` and written to
        ``validation_test_comparison_healthy.csv``.

        Returns:
            pandas.DataFrame with columns Metric, Validation, Test and
            Difference (test minus validation), rounded to 4 decimals.
        """
        print("\n" + "="*80)
        print("VALIDATION vs TEST SET PERFORMANCE COMPARISON")
        print("="*80)

        def _threshold_metrics(X, y):
            """Accuracy and macro P/R/F1 of ``model`` on (X, y) at ``threshold``."""
            y_hat = (model.predict(X).flatten() >= threshold).astype(int)
            return [
                accuracy_score(y, y_hat),
                precision_score(y, y_hat, average='macro', zero_division=0),
                recall_score(y, y_hat, average='macro', zero_division=0),
                f1_score(y, y_hat, average='macro', zero_division=0),
            ]

        # Accuracy is derived from the same thresholded labels as the other
        # metrics (and as in evaluate_model). model.evaluate(...)[1] would
        # report whichever metric was compiled first, at Keras' fixed 0.5
        # cut-off, making that row inconsistent with the rest of the table.
        val_scores = _threshold_metrics(X_val, y_val)
        test_scores = _threshold_metrics(X_test, y_test)

        comparison_df = pd.DataFrame({
            'Metric': ['Accuracy', 'Precision', 'Recall', 'F1-Score'],
            'Validation': val_scores,
            'Test': test_scores,
            'Difference': [t - v for v, t in zip(val_scores, test_scores)],
        })
        numeric_cols = ['Validation', 'Test', 'Difference']
        comparison_df[numeric_cols] = comparison_df[numeric_cols].round(4)

        print("\nValidation vs Test Performance Comparison:")
        print(comparison_df.to_string(index=False))

        fig, ax = plt.subplots(figsize=(10, 6))
        x = np.arange(len(comparison_df))
        width = 0.35

        bars1 = ax.bar(x - width/2, comparison_df['Validation'], width,
                       label='Validation', color='skyblue', alpha=0.8)
        bars2 = ax.bar(x + width/2, comparison_df['Test'], width,
                       label='Test', color='lightcoral', alpha=0.8)

        ax.set_xlabel('Metrics', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.set_title('Validation vs Test Set Performance Comparison', fontsize=14, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(comparison_df['Metric'])
        ax.legend()
        ax.set_ylim(0, 1)
        ax.grid(axis='y', alpha=0.3)

        for bars in (bars1, bars2):
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                        f'{height:.3f}', ha='center', va='bottom', fontsize=9)

        plt.tight_layout()
        plt.savefig('validation_test_comparison_healthy.png', dpi=300, bbox_inches='tight')
        plt.show()

        comparison_df.to_csv('validation_test_comparison_healthy.csv', index=False)
        print("\nComparison table saved to 'validation_test_comparison_healthy.csv'")

        return comparison_df

    def plot_training_history(self, history):
        """Plot per-epoch accuracy, loss and (when logged) AUC curves.

        Reads ``history.history`` (the dict from ``model.fit``) for the keys
        accuracy/val_accuracy, loss/val_loss and auc/val_auc. A missing
        accuracy or loss key plots as an empty line. The AUC panel is hidden
        when neither AUC key was logged. Saves ``training_history_healthy.png``
        and shows the figure.
        """
        logs = history.history
        _fig, (acc_ax, loss_ax, auc_ax) = plt.subplots(1, 3, figsize=(18, 5))

        def draw(ax, train_series, val_series, metric_label):
            # Training vs validation curve for one metric on one panel.
            ax.plot(train_series, label=f'Training {metric_label}')
            ax.plot(val_series, label=f'Validation {metric_label}')
            ax.set_title(f'Model {metric_label}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(metric_label)
            ax.legend()
            ax.grid(True)

        draw(acc_ax, logs.get('accuracy', []), logs.get('val_accuracy', []), 'Accuracy')
        draw(loss_ax, logs.get('loss', []), logs.get('val_loss', []), 'Loss')

        train_auc = logs.get('auc', None)
        valid_auc = logs.get('val_auc', None)
        if train_auc is None and valid_auc is None:
            auc_ax.set_visible(False)  # nothing to show without an AUC metric
        else:
            draw(auc_ax, train_auc or [], valid_auc or [], 'AUC')

        plt.tight_layout()
        plt.savefig('training_history_healthy.png', dpi=300, bbox_inches='tight')
        plt.show()


    def analyze_feature_importance(self, model, X_test, y_test, threshold=0.5):
        """Rank features by permutation importance on a held-out set.

        Each feature column is shuffled ``n_repeats=10`` times. The mean drop
        in accuracy of ``model`` (labels = probability >= ``threshold``) is
        recorded. Prints the top 10 features and plots the top 15 to
        ``feature_importance_healthy.png``.

        Args:
            model: Fitted Keras-style model whose ``predict(X, verbose=0)``
                output is flattened and treated as positive-class probabilities.
            X_test: Feature matrix. Presumably already scaled like the
                training data (TODO confirm at the call site).
            y_test: True 0/1 labels.
            threshold: Probability cut-off used to turn predictions into
                labels. Defaults to 0.5. Pass the validation-tuned threshold
                to measure importance at the operating point actually used.

        Returns:
            pandas.DataFrame with columns ``feature``, ``importance`` and
            ``std``, sorted by ascending importance.
        """
        print("Analyzing feature importance...")

        class KerasPIWrapper:
            """Minimal sklearn-estimator facade so permutation_importance can score a Keras model."""

            def __init__(self, trained_model, cutoff):
                self.model = trained_model
                self.cutoff = cutoff

            def fit(self, X, y):
                # No-op: sklearn's scorer checks require ``fit``; the model
                # is already trained.
                return self

            def predict(self, X):
                proba = self.model.predict(X, verbose=0).flatten()
                return (proba >= self.cutoff).astype(int)

            def score(self, X, y):
                return accuracy_score(y, self.predict(X))

        est = KerasPIWrapper(model, threshold)

        perm = permutation_importance(
            estimator=est,
            X=X_test,
            y=y_test,
            n_repeats=10,
            random_state=42,
            scoring="accuracy",
            # Run in-process: n_jobs=-1 would make joblib pickle the Keras
            # model into worker processes, which is slow and not supported
            # by every TensorFlow/Keras version.
            n_jobs=1,
        )

        # Fall back to generic names when feature names are missing or do not
        # match the matrix width. Compare against None explicitly: a pandas
        # Index has no unambiguous truth value.
        n_features = X_test.shape[1]
        columns = getattr(self, "feature_columns", None)
        if columns is None or len(columns) != n_features:
            feature_names = [f"feature_{i}" for i in range(n_features)]
        else:
            feature_names = list(columns)

        feature_importance_df = (
            pd.DataFrame({
                "feature": feature_names,
                "importance": perm.importances_mean,
                "std": perm.importances_std,
            })
            .sort_values("importance", ascending=True)
            .reset_index(drop=True)
        )

        print("\nTop 10 Most Important Features:")
        print(feature_importance_df.tail(10).to_string(index=False))

        # Ascending sort => the last rows are the most important.
        top_k = min(15, len(feature_importance_df))
        top_features = feature_importance_df.tail(top_k)

        plt.figure(figsize=(12, 8))
        bars = plt.barh(
            range(len(top_features)),
            top_features["importance"],
            xerr=top_features["std"],
            capsize=5,
            color='skyblue',
            alpha=0.7,
        )
        plt.yticks(range(len(top_features)), top_features["feature"])
        plt.xlabel("Feature Importance (mean decrease in accuracy)")
        plt.title(f"Top {top_k} Most Important Features for Healthy Eating Prediction")
        plt.grid(axis='x', alpha=0.3)

        for bar, importance in zip(bars, top_features["importance"]):
            plt.text(
                importance + (0.001 if np.isfinite(importance) else 0),
                bar.get_y() + bar.get_height()/2,
                f"{importance:.3f}",
                va="center",
                fontsize=9,
            )

        plt.tight_layout()
        plt.savefig("feature_importance_healthy.png", dpi=300, bbox_inches="tight")
        plt.show()

        return feature_importance_df

    def plot_roc_pr_curves(self, y_test, y_pred_proba):
        """Draw side-by-side ROC and precision-recall curves.

        The PR area is the trapezoidal ``auc(recall, precision)``, not
        average precision. Saves ``roc_pr_curves_healthy.png`` and shows the
        figure.

        Returns:
            tuple: ``(roc_auc, pr_auc)``.
        """
        _fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(15, 6))

        def finish(ax, xlabel, ylabel, title, legend_loc):
            # Shared axis limits, labels, legend and grid for both panels.
            ax.set_xlim([0.0, 1.0])
            ax.set_ylim([0.0, 1.05])
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.legend(loc=legend_loc)
            ax.grid(True)

        false_pos, true_pos, _ = roc_curve(y_test, y_pred_proba)
        roc_auc = auc(false_pos, true_pos)
        roc_ax.plot(false_pos, true_pos, color='darkorange', lw=2,
                    label=f'ROC curve (AUC = {roc_auc:.3f})')
        roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance line
        finish(roc_ax, 'False Positive Rate', 'True Positive Rate', 'ROC Curve', "lower right")

        prec, rec, _ = precision_recall_curve(y_test, y_pred_proba)
        pr_auc = auc(rec, prec)
        pr_ax.plot(rec, prec, color='darkorange', lw=2,
                   label=f'PR curve (AUC = {pr_auc:.3f})')
        finish(pr_ax, 'Recall', 'Precision', 'Precision-Recall Curve', "lower left")

        plt.tight_layout()
        plt.savefig('roc_pr_curves_healthy.png', dpi=300, bbox_inches='tight')
        plt.show()

        return roc_auc, pr_auc

    def run_complete_analysis(self, file_path):
        """Run the complete healthy-eating pipeline on the CSV at ``file_path``.

        Steps, in order:
          1. Load, normalize units, engineer features.
          2. Build the target, clean and visualize the data.
          3. Split/scale, then build and train the model.
          4. Pick a decision threshold on the validation set.
          5. Evaluate on the test set with that threshold.
          6. Produce ROC/PR curves, feature importance and the
             validation-vs-test comparison.
          7. Save the model to ``healthy_eating_model.h5``.

        Returns:
            dict with keys model, history, feature_importance, predictions,
            probabilities, threshold, roc_auc and pr_auc.

        Raises:
            Any exception from the pipeline is re-raised unchanged after
            printing likely causes.
        """
        print("="*60)
        print("HEALTHY EATING DEEP LEARNING ANALYSIS")
        print("="*60)

        try:
            self.load_and_explore_data(file_path)
            self.normalize_units()
            self.engineer_features()
            self.create_healthy_target()
            self.clean_data()
            self.visualize_data()

            X, y = self.preprocess_data()

            # SMOTE is meant to counter the strong class imbalance (the
            # original author observed ~9% healthy vs ~91% unhealthy). Only
            # request it when imbalanced-learn could actually be imported.
            X_train, X_val, X_test, y_train, y_val, y_test = self.split_and_scale_data(
                X, y, use_smote=SMOTE_AVAILABLE)

            model, class_weights = self.build_model(X_train.shape[1])
            history = self.train_model(model, X_train, y_train, X_val, y_val, class_weights)
            self.plot_training_history(history)

            # Pick the threshold on VALIDATION data only, so the test set
            # is not used for tuning. Alternative mode: "balanced_acc".
            val_proba = model.predict(X_val).flatten()
            val_opt_threshold = self._pick_threshold_on_validation(y_val, val_proba, mode="macro_f1")

            y_pred, y_pred_proba, _ = self.evaluate_model(
                model, X_test, y_test, threshold=val_opt_threshold
            )

            roc_auc, pr_auc = self.plot_roc_pr_curves(y_test, y_pred_proba)
            feature_importance = self.analyze_feature_importance(model, X_test, y_test)

            self.create_validation_test_comparison(
                model, X_val, y_val, X_test, y_test, threshold=val_opt_threshold)

            model.save('healthy_eating_model.h5')
            print("\nModel saved as 'healthy_eating_model.h5'")

            return {
                'model': model,
                'history': history,
                'feature_importance': feature_importance,
                'predictions': y_pred,
                'probabilities': y_pred_proba,
                'threshold': val_opt_threshold,
                'roc_auc': roc_auc,
                'pr_auc': pr_auc
            }

        except Exception as e:
            print(f"\nError during analysis: {str(e)}")
            print("This might be due to:")
            print("1. Missing required columns in the dataset")
            print("2. Incompatible data types")
            print("3. Missing values in critical columns")
            # Bare raise re-raises the active exception unchanged.
            raise

def main(file_path="/content/sample_data/healthy_eating_dataset.csv"):
    """Run the healthy-eating analysis end to end and list the artifacts.

    Args:
        file_path: Location of the dataset CSV. The default points at the
            Colab sample-data folder; pass (or edit) a different path for
            other environments.

    Returns:
        The results dict from ``run_complete_analysis`` on success, or
        ``None`` if the dataset is missing or the analysis fails.
    """
    import traceback

    classifier = HealthyEatingClassifier()

    try:
        results = classifier.run_complete_analysis(file_path)
    except FileNotFoundError:
        print(f"Error: Could not find the dataset file at '{file_path}'")
        print("Please update the file_path variable with the correct path to your dataset.")
        return None
    except Exception as e:
        print(f"Error during analysis: {str(e)}")
        # Keep the stack trace: the one-line message alone hides where in
        # the pipeline the failure happened.
        traceback.print_exc()
        return None

    print("\n" + "="*60)
    print("ANALYSIS COMPLETE!")
    print("="*60)
    print("Generated files:")
    print("- healthy_eating_analysis.png")
    print("- correlation_matrix_healthy.png")
    print("- confusion_matrix_healthy.png")
    print("- training_history_healthy.png")
    print("- feature_importance_healthy.png")
    print("- roc_pr_curves_healthy.png")
    print("- metrics_summary_healthy.png")
    print("- metrics_table_healthy.png")
    print("- validation_test_comparison_healthy.png")
    print("- model_metrics_healthy.csv")
    print("- validation_test_comparison_healthy.csv")
    print("- healthy_eating_model.h5")

    return results

if __name__ == "__main__":
    main()