In [None]:
# Install dependencies (Colab)
!pip install -q tensorflow scikit-learn matplotlib gradio

# Imports
import io, sys, math, random
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
import gradio as gr
from google.colab import files

# Global variables to store data and models
# Module-level state shared between the Gradio callbacks defined below.
global_df = None  # DataFrame assigned by load_dataset() and replaced by preprocess_data()
global_models = {}  # country -> {'lr', 'lstm', 'scaler_target', 'lookback'}, filled by train_models()
global_scalers = {}  # not populated anywhere in the visible code
global_country_data = {}  # country -> that country's rows as stored by train_models()

# ----------------- 1. Load dataset -----------------
def load_dataset(file):
    """Read the uploaded CSV into ``global_df`` and build the country dropdown updates.

    Args:
        file: Value delivered by the Gradio ``File`` component. This can be a
            filesystem path (``str``), an object exposing the path as
            ``.name``, or a readable file-like object.

    Returns:
        A 4-tuple ``(status message, HTML preview of the first 10 rows,
        dropdown update, dropdown update)``. On failure the status carries the
        error text, the preview is empty and both dropdowns get no choices.
    """
    global global_df

    if file is None:
        # Button clicked before anything was uploaded.
        return "Error loading file: no file was uploaded", "", gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

    try:
        if hasattr(file, 'read'):
            # If it's a file-like object
            source = file
        else:
            # File(type="filepath") hands over a plain str; wrapper objects
            # expose the path as ``.name``. Fall back to the value itself.
            source = getattr(file, 'name', file)
        df = pd.read_csv(source)

        print("Loaded file successfully")
        print(f"Shape: {df.shape}")

        # Normalize column names (str() guards against non-string headers)
        df.columns = [str(c).strip() for c in df.columns]
        global_df = df

        # Get available countries for dropdown; NaN entries would break sorted()
        if 'Country' in df.columns:
            countries = sorted(df['Country'].dropna().unique().tolist())
        else:
            countries = []
        default_country = countries[0] if countries else None

        return (
            f"Dataset loaded successfully! Shape: {df.shape}",
            df.head(10).to_html(),
            gr.Dropdown(choices=countries, value=default_country),
            gr.Dropdown(choices=countries, value=default_country),
        )
    except Exception as e:
        return f"Error loading file: {str(e)}", "", gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

# ----------------- 2. Data Preprocessing -----------------
def preprocess_data():
    """Clean ``global_df`` in place and draw the global consumption trend.

    Steps: cast ``Year`` to int, sort by country and year, linearly
    interpolate missing values of the numeric columns separately for each
    country, then plot the yearly sum of 'Total Energy Consumption (TWh)'.

    Returns:
        A 5-tuple ``(status, missing-value report, matplotlib figure or None,
        dropdown update, dropdown update)``.
    """
    global global_df
    if global_df is None:
        # The third slot feeds a gr.Plot, so it must be None rather than "".
        return "Please load dataset first!", "", None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

    try:
        df = global_df.copy()

        # Convert Year to integer
        df['Year'] = df['Year'].astype(int)
        df = df.sort_values(['Country', 'Year']).reset_index(drop=True)

        # Handle missing values; only numeric columns can be interpolated
        num_cols = [
            c for c in df.columns
            if c not in ('Country', 'Year') and pd.api.types.is_numeric_dtype(df[c])
        ]
        if num_cols:
            for country in df['Country'].unique():
                mask = df['Country'] == country
                df.loc[mask, num_cols] = df.loc[mask, num_cols].interpolate(
                    method='linear', limit_direction='both', axis=0
                )

        global_df = df

        # Get updated countries list
        countries = sorted(df['Country'].dropna().unique().tolist())
        default_country = countries[0] if countries else None

        # Create EDA plot. The labels are white, so draw on a dark figure;
        # on the default white canvas they would be invisible.
        agg = df.groupby('Year')['Total Energy Consumption (TWh)'].sum().reset_index()
        with plt.style.context('dark_background'):
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.plot(agg['Year'], agg['Total Energy Consumption (TWh)'], marker='o', linewidth=2, color='cyan')
            ax.set_title("Global Total Energy Consumption (TWh) over Years", fontsize=14, color='white')
            ax.set_xlabel("Year", color='white')
            ax.set_ylabel("TWh", color='white')
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
        # Unregister from pyplot so repeated clicks do not pile up open
        # figures; the Figure object itself can still be rendered.
        plt.close(fig)

        return ("Data preprocessing completed!",
                f"Missing values after processing:\n{df[num_cols].isna().sum().to_string()}",
                fig,
                gr.Dropdown(choices=countries, value=default_country),
                gr.Dropdown(choices=countries, value=default_country))
    except Exception as e:
        return f"Error in preprocessing: {str(e)}", "", None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

# ----------------- 3. Model Training -----------------
def create_sequences(series, lookback=5):
    """Turn a 1-D series into sliding windows and their next-step targets.

    Window ``k`` holds ``series[k:k + lookback]`` and its target is
    ``series[k + lookback]``.

    Args:
        series: Indexable, sliceable sequence of values.
        lookback: Number of consecutive values per window.

    Returns:
        Tuple ``(X, y)`` of NumPy arrays built from the windows and targets.
        Both are empty when the series is not longer than ``lookback``.
    """
    starts = range(len(series) - lookback)
    windows = [series[start:start + lookback] for start in starts]
    targets = [series[start + lookback] for start in starts]
    return np.array(windows), np.array(targets)

def build_lstm(input_shape, units=32, dropout=0.2):
    """Assemble and compile a single-layer LSTM regressor.

    Args:
        input_shape: Shape passed to the LSTM layer as ``input_shape``.
        units: Number of LSTM units.
        dropout: Dropout rate; the Dropout layer is skipped when this is
            falsy or not positive.

    Returns:
        A compiled ``Sequential`` model (Adam optimizer, MSE loss) ending in a
        single-unit Dense layer.
    """
    stack = [LSTM(units, input_shape=input_shape, return_sequences=False)]
    use_dropout = bool(dropout) and dropout > 0
    if use_dropout:
        stack.append(Dropout(dropout))
    stack.append(Dense(1))

    network = Sequential()
    for layer in stack:
        network.add(layer)
    network.compile(optimizer='adam', loss='mse')
    return network

def train_models(country, lookback=5, lstm_units=32):
    """Train a linear-regression baseline and an LSTM for one country.

    The country's 'Total Energy Consumption (TWh)' series is cut into
    sliding windows of ``lookback`` years and split chronologically into
    70% train, 15% validation and the remainder for test. The trained models
    and scaler are stored in ``global_models[country]`` and the country's
    rows in ``global_country_data[country]``.

    Args:
        country: Value of the 'Country' column to model.
        lookback: Window length in years (coerced to int).
        lstm_units: Number of LSTM units (coerced to int).

    Returns:
        A 3-tuple ``(status message, matplotlib figure or None, metrics text)``.
    """
    global global_df, global_models, global_country_data

    if global_df is None:
        return "Please load and preprocess data first!", None, ""

    if not country:
        return "Please select a country!", None, ""

    try:
        # UI sliders may deliver floats; slicing and range() need ints.
        lookback = int(lookback)
        lstm_units = int(lstm_units)

        # Filter country data; sort so windows are chronological even if
        # preprocessing was skipped.
        country_df = (
            global_df[global_df['Country'] == country]
            .sort_values('Year')
            .reset_index(drop=True)
        )

        if len(country_df) < lookback + 3:
            return f"Not enough data for {country}. Need at least {lookback + 3} years.", None, ""

        # Prepare target series
        target_col = 'Total Energy Consumption (TWh)'
        target_series = country_df[target_col].values.astype(float)
        if np.isnan(target_series).any():
            return f"Missing values in '{target_col}' for {country}. Please preprocess data first.", None, ""

        # Scale data
        # NOTE(review): the scaler is fitted on the full series, so the test
        # range leaks into the scaling; kept because forecasting reuses it.
        scaler_target = MinMaxScaler()
        scaled_target = scaler_target.fit_transform(target_series.reshape(-1, 1)).flatten()

        # Create sequences
        X_all, y_all = create_sequences(scaled_target, lookback=lookback)

        # Split data chronologically: train | validation | test
        n_samples = len(X_all)
        train_n = int(0.7 * n_samples)
        val_n = int(0.15 * n_samples) + train_n

        X_train = X_all[:train_n].reshape((-1, lookback, 1))
        y_train = y_all[:train_n]
        X_val = X_all[train_n:val_n].reshape((-1, lookback, 1))
        y_val = y_all[train_n:val_n]
        X_test = X_all[val_n:].reshape((-1, lookback, 1))
        y_test = y_all[val_n:]

        # Train Linear Regression baseline on the unscaled windows
        X_raw, y_raw = create_sequences(target_series, lookback=lookback)
        y_test_raw = y_raw[val_n:]

        lr = LinearRegression()
        lr.fit(X_raw[:train_n], y_raw[:train_n])
        y_pred_lr = lr.predict(X_raw[val_n:])
        rmse_lr = math.sqrt(mean_squared_error(y_test_raw, y_pred_lr))
        mae_lr = mean_absolute_error(y_test_raw, y_pred_lr)

        # Train LSTM. Use the validation slice for early stopping when it is
        # non-empty; very short series have none, so fall back to train loss.
        model = build_lstm((lookback, 1), units=lstm_units, dropout=0.2)
        has_val = len(y_val) > 0
        early_stop = EarlyStopping(
            monitor='val_loss' if has_val else 'loss',
            patience=10,
            restore_best_weights=True,
        )
        model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val) if has_val else None,
            epochs=100, batch_size=8, callbacks=[early_stop], verbose=0,
        )

        # Evaluate LSTM in original units
        y_pred_lstm_scaled = model.predict(X_test, verbose=0).flatten()
        y_pred_lstm = scaler_target.inverse_transform(y_pred_lstm_scaled.reshape(-1, 1)).flatten()
        y_test_unscaled = scaler_target.inverse_transform(y_test.reshape(-1, 1)).flatten()
        rmse_lstm = math.sqrt(mean_squared_error(y_test_unscaled, y_pred_lstm))
        mae_lstm = mean_absolute_error(y_test_unscaled, y_pred_lstm)

        # Store models and data
        global_models[country] = {
            'lr': lr,
            'lstm': model,
            'scaler_target': scaler_target,
            'lookback': lookback
        }
        global_country_data[country] = country_df

        # Create comparison plot. Target i of the windows belongs to row
        # lookback + i, hence the offset for the test years. The labels are
        # white, so draw on a dark figure to keep them readable.
        years = country_df['Year'].values[lookback + val_n:]
        with plt.style.context('dark_background'):
            fig, ax = plt.subplots(figsize=(12, 6))
            ax.plot(years, y_test_unscaled, label='Actual', marker='o', linewidth=2, color='#00ff00')
            ax.plot(years, y_pred_lstm, label='LSTM Prediction', marker='s', linewidth=2, color='#ff00ff')
            ax.plot(years, y_pred_lr, label='Linear Regression', marker='^', linewidth=2, color='#00ffff')
            ax.set_title(f"Energy Consumption Forecast - {country}", fontsize=14, color='white')
            ax.set_xlabel("Year", color='white')
            ax.set_ylabel("Total Energy Consumption (TWh)", color='white')
            ax.legend()
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
        # Unregister from pyplot so repeated runs do not accumulate figures.
        plt.close(fig)

        results_text = f"""
         Model Results for {country}:

         Linear Regression:
           RMSE: {rmse_lr:.4f}
           MAE: {mae_lr:.4f}

         LSTM Neural Network:
           RMSE: {rmse_lstm:.4f}
           MAE: {mae_lstm:.4f}

         Dataset Info:
           Data points: {len(country_df)} years
           Lookback window: {lookback} years
           Test samples: {len(y_test_unscaled)} years
        """

        return f" Models trained successfully for {country}!", fig, results_text

    except Exception as e:
        return f" Error training models: {str(e)}", None, ""

# ----------------- 4. Forecasting -----------------
def forecast_future(country, years_ahead=5):
    """Forecast future consumption for a country with its trained LSTM.

    The last ``lookback`` scaled values seed a recursive forecast: each
    prediction is appended to the window and the oldest value is dropped.

    Args:
        country: Key into ``global_models`` / ``global_country_data``.
        years_ahead: Number of yearly steps to predict (coerced to int).

    Returns:
        A 2-tuple ``(matplotlib figure or None, text)``. The text is the
        forecast listing on success, or a prompt/error message otherwise.
    """
    global global_models, global_country_data

    if not country:
        return None, "Please select a country first!"

    if country not in global_models:
        return None, "Please train models for this country first!"

    try:
        # UI sliders may deliver floats; range() needs an int.
        years_ahead = int(years_ahead)
        if years_ahead < 1:
            return None, "Please request at least one year to forecast!"

        model_data = global_models[country]
        country_df = global_country_data[country]
        lookback = model_data['lookback']
        lstm = model_data['lstm']
        scaler_target = model_data['scaler_target']

        # Get last window (in the scaler's units)
        target_series = country_df['Total Energy Consumption (TWh)'].values.astype(float)
        scaled_target = scaler_target.transform(target_series.reshape(-1, 1)).flatten()
        window = scaled_target[-lookback:].copy()

        # Recursive forecasting: feed every prediction back into the window
        scaled_preds = []
        for _ in range(years_ahead):
            step_input = window.reshape(1, window.shape[0], 1)
            next_scaled = lstm.predict(step_input, verbose=0).flatten()[0]
            scaled_preds.append(next_scaled)
            window = np.roll(window, -1)
            window[-1] = next_scaled
        forecast_vals = scaler_target.inverse_transform(
            np.array(scaled_preds).reshape(-1, 1)
        ).flatten()

        historical_years = country_df['Year'].values
        historical_values = country_df['Total Energy Consumption (TWh)'].values
        last_year = int(historical_years[-1])
        forecast_years = list(range(last_year + 1, last_year + years_ahead + 1))

        # Create forecast plot. The labels are white, so draw on a dark
        # figure to keep them readable.
        with plt.style.context('dark_background'):
            fig, ax = plt.subplots(figsize=(12, 6))
            ax.plot(historical_years[-10:], historical_values[-10:], label='Historical', marker='o', linewidth=2, color='#00ff00')
            ax.plot(forecast_years, forecast_vals, label='Forecast', marker='s', linewidth=2, color='#ff00ff', linestyle='--')
            ax.set_title(f" Energy Consumption Forecast - {country}", fontsize=14, color='white')
            ax.set_xlabel("Year", color='white')
            ax.set_ylabel("Total Energy Consumption (TWh)", color='white')
            ax.legend()
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
        # Unregister from pyplot so repeated runs do not accumulate figures.
        plt.close(fig)

        text_parts = [f" Forecast for {country} (next {years_ahead} years):\n\n"]
        text_parts.extend(
            f" Year {year}: {val:.2f} TWh\n"
            for year, val in zip(forecast_years, forecast_vals)
        )
        text_parts.append(f"\n Based on {len(country_df)} years of historical data")
        text_parts.append(f"\n Using {lookback}-year lookback window")
        forecast_text = "".join(text_parts)

        return fig, forecast_text

    except Exception as e:
        return None, f" Error in forecasting: {str(e)}"

# ----------------- 5. Clustering Analysis -----------------
def perform_clustering(n_clusters=4):
    """Cluster countries by their average energy profile and plot the result.

    Per-country means of six indicator columns are standardised, grouped with
    K-means and projected onto two principal components for the scatter plot.

    Args:
        n_clusters: Number of K-means clusters (coerced to int).

    Returns:
        A 2-tuple ``(matplotlib figure or None, text)``. The text is the
        per-cluster summary on success, or a prompt/error message otherwise.
    """
    global global_df

    if global_df is None:
        return None, "Please load dataset first!"

    try:
        # UI sliders may deliver floats; KMeans requires an integer.
        n_clusters = int(n_clusters)

        # Aggregate features by country
        agg_cols = ['Per Capita Energy Use (kWh)','Renewable Energy Share (%)','Fossil Fuel Dependency (%)',
                   'Industrial Energy Use (%)','Household Energy Use (%)','Carbon Emissions (Million Tons)']
        agg_df = global_df.groupby('Country')[agg_cols].mean().dropna()

        # PCA(2) needs at least 2 rows and K-means at least one row per cluster.
        min_countries = max(n_clusters, 2)
        if len(agg_df) < min_countries:
            return None, (f" Error in clustering: need at least {min_countries} countries "
                          f"with complete data, found {len(agg_df)}")

        # Standardise first: the columns mix kWh, percentages and million
        # tons, so the largest-valued column would otherwise dominate both
        # the distances and the principal components.
        features = StandardScaler().fit_transform(agg_df.values)

        # Perform PCA and clustering
        X_pca = PCA(n_components=2).fit_transform(features)
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
        clusters = kmeans.fit_predict(features)

        # Create cluster plot. The labels are white, so draw on a dark
        # figure to keep them readable.
        with plt.style.context('dark_background'):
            fig, ax = plt.subplots(figsize=(12, 8))
            scatter = ax.scatter(X_pca[:, 0], X_pca[:, 1], c=clusters, cmap='viridis', s=100, alpha=0.7)
            fig.colorbar(scatter, ax=ax, label='Cluster')
            ax.set_title(f' Country Clustering (K-means, k={n_clusters})', fontsize=14, color='white')
            ax.set_xlabel('Principal Component 1', color='white')
            ax.set_ylabel('Principal Component 2', color='white')

            # Add country labels for some points
            for i, country in enumerate(agg_df.index):
                if i % 3 == 0:  # Label every 3rd country to avoid clutter
                    ax.annotate(country, (X_pca[i, 0], X_pca[i, 1]), fontsize=8, alpha=0.7, color='white')

            ax.grid(True, alpha=0.3)
            fig.tight_layout()
        # Unregister from pyplot so repeated runs do not accumulate figures.
        plt.close(fig)

        # Create cluster summary (means are reported in the original units)
        agg_df['Cluster'] = clusters
        cluster_sizes = agg_df['Cluster'].value_counts()
        cluster_summary = agg_df.groupby('Cluster').agg({
            'Per Capita Energy Use (kWh)': 'mean',
            'Renewable Energy Share (%)': 'mean',
            'Fossil Fuel Dependency (%)': 'mean',
            'Carbon Emissions (Million Tons)': 'mean'
        }).round(2)

        summary_parts = [" Cluster Summary:\n\n"]
        for cluster_id in sorted(cluster_summary.index):
            cluster_data = cluster_summary.loc[cluster_id]
            summary_parts.append(f" Cluster {cluster_id}:\n")
            summary_parts.append(f"    Countries: {cluster_sizes[cluster_id]}\n")
            summary_parts.append(f"    Per Capita: {cluster_data['Per Capita Energy Use (kWh)']} kWh\n")
            summary_parts.append(f"    Renewable: {cluster_data['Renewable Energy Share (%)']}%\n")
            summary_parts.append(f"    Fossil Fuel: {cluster_data['Fossil Fuel Dependency (%)']}%\n")
            summary_parts.append(f"    Carbon Emissions: {cluster_data['Carbon Emissions (Million Tons)']} MT\n\n")
        summary_text = "".join(summary_parts)

        return fig, summary_text

    except Exception as e:
        return None, f" Error in clustering: {str(e)}"

# ----------------- 6. Gradio Interface -----------------
def create_interface():
    """Build the five-tab Gradio dashboard and wire its buttons to the callbacks.

    Tabs: data loading, preprocessing, model training, forecasting and
    country clustering. All event handlers are attached after every component
    exists, so the load and preprocess steps can both refresh the country
    dropdowns that live in the training and forecasting tabs.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    # Custom CSS for dark theme with neon buttons
    custom_css = """
    .gradio-container {
        background: linear-gradient(135deg, #0c0c0c 0%, #1a1a2e 50%, #16213e 100%);
        color: white;
        font-family: 'Arial', sans-serif;
    }
    .gradio-container .tab-nav {
        background: rgba(255, 255, 255, 0.1) !important;
        backdrop-filter: blur(10px);
        border-radius: 10px;
        margin: 10px;
        padding: 10px;
    }
    .gradio-container .tab-nav button {
        background: transparent !important;
        color: white !important;
        border: none !important;
        padding: 12px 24px !important;
        margin: 5px !important;
        border-radius: 8px !important;
        transition: all 0.3s ease !important;
    }
    .gradio-container .tab-nav button:hover {
        background: rgba(0, 255, 255, 0.2) !important;
        transform: translateY(-2px);
    }
    .gradio-container .tab-nav button.selected {
        background: linear-gradient(45deg, #00ffff, #0080ff) !important;
        color: black !important;
        font-weight: bold;
        box-shadow: 0 0 20px #00ffff;
    }
    .gradio-button {
        background: linear-gradient(45deg, #ff00ff, #00ffff) !important;
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        padding: 12px 30px !important;
        border-radius: 25px !important;
        transition: all 0.3s ease !important;
        box-shadow: 0 0 15px rgba(0, 255, 255, 0.5);
        margin: 10px 0px !important;
    }
    .gradio-button:hover {
        transform: translateY(-3px);
        box-shadow: 0 0 25px rgba(0, 255, 255, 0.8);
    }
    .gradio-plot {
        border-radius: 15px;
        background: rgba(255, 255, 255, 0.05) !important;
        backdrop-filter: blur(10px);
        padding: 20px;
        border: 1px solid rgba(255, 255, 255, 0.1) !important;
    }
    .gradio-textbox, .gradio-number, .gradio-dropdown {
        background: rgba(255, 255, 255, 0.1) !important;
        border: 1px solid rgba(255, 255, 255, 0.3) !important;
        color: white !important;
        border-radius: 10px !important;
        padding: 12px !important;
    }
    .gradio-textbox label, .gradio-number label, .gradio-dropdown label {
        color: #00ffff !important;
        font-weight: bold;
    }
    .gradio-markdown {
        color: white !important;
    }
    .gradio-container h1, .gradio-container h2, .gradio-container h3 {
        background: linear-gradient(45deg, #00ffff, #ff00ff);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        text-align: center;
    }
    .plot-container {
        background: transparent !important;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
        gr.Markdown(
            """
            #  Energy Consumption Forecasting Dashboard
            ### Advanced Analytics with Machine Learning
            """
        )

        with gr.Tabs():
            # Tab 1: Data Loading
            with gr.TabItem(" Data Loading"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Step 1: Upload Your Dataset")
                        file_input = gr.File(
                            label=" Upload Energy Dataset (CSV)",
                            file_types=[".csv"],
                            type="filepath"
                        )
                        load_btn = gr.Button(" Load Dataset", variant="primary", size="lg")
                    with gr.Column():
                        gr.Markdown("### Status & Preview")
                        load_status = gr.Textbox(
                            label=" Status",
                            interactive=False,
                            lines=2
                        )
                        data_preview = gr.HTML(
                            label="👀 Data Preview",
                            value="<div style='color: white; text-align: center; padding: 20px;'>Upload a CSV file to begin</div>"
                        )

            # Tab 2: Data Preprocessing
            with gr.TabItem(" Preprocessing"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Step 2: Preprocess Data")
                        preprocess_btn = gr.Button(" Preprocess Data", variant="primary", size="lg")
                    with gr.Column():
                        gr.Markdown("### Processing Results")
                        preprocess_status = gr.Textbox(
                            label=" Status",
                            interactive=False,
                            lines=2
                        )
                        missing_info = gr.Textbox(
                            label=" Missing Values Info",
                            interactive=False,
                            lines=4
                        )

                with gr.Row():
                    eda_plot = gr.Plot(
                        label=" Global Energy Consumption Trend",
                        value=None
                    )

            # Tab 3: Model Training
            with gr.TabItem(" Model Training"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Step 3: Train Prediction Models")
                        country_dropdown = gr.Dropdown(
                            label=" Select Country",
                            choices=[],
                            interactive=True,
                            value=None
                        )
                        with gr.Row():
                            lookback_slider = gr.Slider(
                                minimum=3, maximum=10, value=5, step=1,
                                label=" Lookback Window (years)"
                            )
                            units_slider = gr.Slider(
                                minimum=16, maximum=64, value=32, step=16,
                                label=" LSTM Units"
                            )
                        train_btn = gr.Button("Train Models", variant="primary", size="lg")

                    with gr.Column():
                        gr.Markdown("### Training Results")
                        train_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=2
                        )
                        results_text = gr.Textbox(
                            label="Model Performance",
                            interactive=False,
                            lines=12
                        )

                with gr.Row():
                    training_plot = gr.Plot(
                        label="Model Performance Comparison",
                        value=None
                    )

            # Tab 4: Forecasting
            with gr.TabItem("Forecasting"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Step 4: Generate Forecasts")
                        forecast_country = gr.Dropdown(
                            label="Select Country",
                            choices=[],
                            interactive=True,
                            value=None
                        )
                        forecast_years = gr.Slider(
                            minimum=1, maximum=10, value=5, step=1,
                            label="Years to Forecast"
                        )
                        forecast_btn = gr.Button(" Generate Forecast", variant="primary", size="lg")

                    with gr.Column():
                        gr.Markdown("### Forecast Results")
                        forecast_output = gr.Textbox(
                            label="Forecast Results",
                            interactive=False,
                            lines=8
                        )

                with gr.Row():
                    forecast_plot = gr.Plot(
                        label="Energy Consumption Forecast",
                        value=None
                    )

            # Tab 5: Clustering
            with gr.TabItem(" Country Clustering"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Step 5: Analyze Country Clusters")
                        cluster_slider = gr.Slider(
                            minimum=2, maximum=8, value=4, step=1,
                            label="Number of Clusters"
                        )
                        cluster_btn = gr.Button("🔍 Analyze Clusters", variant="primary", size="lg")

                    with gr.Column():
                        gr.Markdown("### Cluster Analysis")
                        cluster_summary = gr.Textbox(
                            label="Cluster Summary",
                            interactive=False,
                            lines=15
                        )

                with gr.Row():
                    cluster_plot = gr.Plot(
                        label=" Country Clusters Visualization",
                        value=None
                    )

        # Connect the load button. It is wired here, after all tabs exist, so
        # the country list lands in the visible dropdowns of the training and
        # forecasting tabs rather than in throw-away hidden components.
        load_btn.click(
            load_dataset,
            inputs=[file_input],
            outputs=[load_status, data_preview, country_dropdown, forecast_country]
        )

        # Connect the preprocessing to update dropdowns
        preprocess_btn.click(
            preprocess_data,
            outputs=[preprocess_status, missing_info, eda_plot, country_dropdown, forecast_country]
        )

        # Connect training button
        train_btn.click(
            train_models,
            inputs=[country_dropdown, lookback_slider, units_slider],
            outputs=[train_status, training_plot, results_text]
        )

        # Connect forecasting button
        forecast_btn.click(
            forecast_future,
            inputs=[forecast_country, forecast_years],
            outputs=[forecast_plot, forecast_output]
        )

        # Connect clustering button
        cluster_btn.click(
            perform_clustering,
            inputs=[cluster_slider],
            outputs=[cluster_plot, cluster_summary]
        )

        gr.Markdown(
            """
            ---
            ###  Energy Forecasting System
            *Built with TensorFlow, Scikit-learn, and Gradio*
            *Features: LSTM Neural Networks, Linear Regression, K-means Clustering, PCA*

            ** Expected CSV Format:**
            - Country, Year, Total Energy Consumption (TWh), Per Capita Energy Use (kWh)
            - Renewable Energy Share (%), Fossil Fuel Dependency (%), Industrial Energy Use (%)
            - Household Energy Use (%), Carbon Emissions (Million Tons), Energy Price Index (USD/kWh)
            """
        )

    return demo

# ----------------- 7. Main Execution -----------------
if __name__ == "__main__":
    # For Colab deployment: print the start-up banner, then build and serve
    # the dashboard.
    print(
        " Starting Energy Forecasting Dashboard...",
        " Please upload your CSV file when prompted",
        " The interface will open automatically",
        sep="\n",
    )

    demo = create_interface()
    # share=True requests a public URL; debug=True keeps the cell running so
    # errors and logs stay visible; show_error=True surfaces errors in the UI.
    demo.launch(share=True, debug=True, show_error=True)

 Starting Energy Forecasting Dashboard...
 Please upload your CSV file when prompted
 The interface will open automatically
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://c5ca7865352416a14f.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
