# **Analyse et prévision de la demande en énergie - Energy Demand Analysis and Forecasting**

##Preliminaries: Installing libraries and downloading data


In [None]:
!pip install pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import plotly.express as px
from pyspark.sql.functions import avg
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler
from prophet import Prophet
import matplotlib.pyplot as plt



In [None]:
spark = SparkSession.builder \
    .appName("EnergyDemandForecasting") \
    .config("spark.sql.shuffle.partitions", "50") \
    .config("spark.executor.memory", "2g") \
    .getOrCreate()

print("Apache Spark version:", spark.version)

Apache Spark version: 3.5.3


In [None]:
!unzip /content/drive/MyDrive/Data/archive.zip -d /content/unzipped_data

Archive:  /content/drive/MyDrive/Data/archive.zip
  inflating: /content/unzipped_data/Consumption_Data/Consumption_Coal.csv  
  inflating: /content/unzipped_data/Consumption_Data/Consumption_NaturalGas.csv  
  inflating: /content/unzipped_data/Consumption_Data/Consumption_Neuclear+renewables.csv  
  inflating: /content/unzipped_data/Consumption_Data/Consumption_Petroleum.csv  
  inflating: /content/unzipped_data/Consumption_Data/Consumption_Total.csv  
  inflating: /content/unzipped_data/Production_Data/Production_Coal.csv  
  inflating: /content/unzipped_data/Production_Data/Production_NaturalGas.csv  
  inflating: /content/unzipped_data/Production_Data/Production_Nuclear+renewables.csv  
  inflating: /content/unzipped_data/Production_Data/Production_Pertroleum.csv  
  inflating: /content/unzipped_data/Production_Data/Production_Total.csv  
  inflating: /content/unzipped_data/World Energy Overview.csv  


##Step 1: Dataset Overview

**About Dataset**\
World Energy Overview: Contains the Monthly data of energy Production, Consumption, Imports, Exports, Stock exchanges and production and consumption from different resources like renewables, nuclear and fossil fuels from 1973 to 2022.

Below files contain the Yearly data from 1980 to 2021

*   Production_Total : Total Energy Production
*   Production_Coal : energy production from coal
*   Production_NaturalGas : energy production from NaturalGas
*   Production_Nuclear+renewables : energy production from Nuclear+renewables
*   Production_Pertroleum : energy production from Pertroleum
*   Consumption_Total : Total Consumption Production
*   Consumption_Coal : energy Consumption from coal
*   Consumption_NaturalGas : energy Consumption from NaturalGas
*   Consumption_Nuclear+renewables : energy Consumption from Nuclear+renewables
*   Consumption_Pertroleum : energy Consumption from Pertroleum

In [None]:
file_path = "/content/unzipped_data/Consumption_Data/Consumption_Total.csv"
df = spark.read.csv(file_path, header=True, inferSchema=True)
df.printSchema()
df.show()

root
 |-- Continent: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- 1980: string (nullable = true)
 |-- 1981: string (nullable = true)
 |-- 1982: string (nullable = true)
 |-- 1983: string (nullable = true)
 |-- 1984: string (nullable = true)
 |-- 1985: string (nullable = true)
 |-- 1986: string (nullable = true)
 |-- 1987: string (nullable = true)
 |-- 1988: string (nullable = true)
 |-- 1989: string (nullable = true)
 |-- 1990: string (nullable = true)
 |-- 1991: string (nullable = true)
 |-- 1992: string (nullable = true)
 |-- 1993: string (nullable = true)
 |-- 1994: string (nullable = true)
 |-- 1995: string (nullable = true)
 |-- 1996: string (nullable = true)
 |-- 1997: string (nullable = true)
 |-- 1998: string (nullable = true)
 |-- 1999: string (nullable = true)
 |-- 2000: string (nullable = true)
 |-- 2001: string (nullable = true)
 |-- 2002: string (nullable = true)
 |-- 2003: string (nullable = true)
 |-- 2004: string (nullable = true)
 |-- 2005: stri

In [None]:
file_path = "/content/unzipped_data/Consumption_Data/Consumption_NaturalGas.csv"
df = spark.read.csv(file_path, header=True, inferSchema=True)
df.show()

+---------+--------------------+---------+---------+---------+---------+---------+--------+---------+--------+--------+--------+-----------+--------+-----------+-----------+----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------------------+
|Continent|             Country|     1980|     1981|     1982|     1983|     1984|    1985|     1986|    1987|    1988|    1989|       1990|    1991|       1992|       1993|      1994|       1995|       1996|       1997|       1998|       1999|       2000|       2001|       2002|       2003|       2004|       2005|       2006|       2007|       2008|       2009|       2010|       2011|       2012|       2013|       2014|       2015|       2016|       2017|       2

In [None]:
from pyspark.sql.functions import col
file_path = "/content/unzipped_data/Consumption_Data/Consumption_Total.csv"
production = spark.read.csv(file_path, header=True, inferSchema=True)

print('Total Continents')
production.groupBy("Continent").count().orderBy(col("count")).show()

Total Continents
+--------------------+-----+
|           Continent|count|
+--------------------+-----+
|       North America|    7|
|             Eurasia|   13|
|         Middle East|   14|
|Central & South A...|   45|
|              Europe|   45|
|      Asia & Oceania|   49|
|              Africa|   57|
+--------------------+-----+



In [None]:
from pyspark.sql.functions import col, desc
file_path = "/content/unzipped_data/Consumption_Data/Consumption_Total.csv"
production = spark.read.csv(file_path, header=True, inferSchema=True)

top_production = production.select("Country", "Total_Consumption") \
                           .orderBy(desc("Total_Consumption")) \
                           .limit(10)

top_production.show(10)

+--------------------+-----------------+
|             Country|Total_Consumption|
+--------------------+-----------------+
|        United St...|      3830.095107|
|               China|       2953.32024|
|              Russia|       894.408884|
|               Japan|      827.2707178|
|        Former U....|      697.8822618|
|               India|      632.3043422|
|              Canada|       521.849654|
|             Germany|      437.3495472|
|              France|      422.0276212|
|        United Ki...|      381.0669734|
+--------------------+-----------------+



##Step 2: Data Exploration and Cleaning:



1-Convert Yearly Columns to Numeric: The columns 1980, 1981, etc., are
stored as strings. Convert them to numeric for further processing



In [None]:
from pyspark.sql.functions import col

year_columns = [str(year) for year in range(1980, 2022)]
df = df.select(*[col(c).cast("double").alias(c) if c in year_columns else col(c) for c in df.columns])
df.printSchema()

root
 |-- Continent: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- 1980: double (nullable = true)
 |-- 1981: double (nullable = true)
 |-- 1982: double (nullable = true)
 |-- 1983: double (nullable = true)
 |-- 1984: double (nullable = true)
 |-- 1985: double (nullable = true)
 |-- 1986: double (nullable = true)
 |-- 1987: double (nullable = true)
 |-- 1988: double (nullable = true)
 |-- 1989: double (nullable = true)
 |-- 1990: double (nullable = true)
 |-- 1991: double (nullable = true)
 |-- 1992: double (nullable = true)
 |-- 1993: double (nullable = true)
 |-- 1994: double (nullable = true)
 |-- 1995: double (nullable = true)
 |-- 1996: double (nullable = true)
 |-- 1997: double (nullable = true)
 |-- 1998: double (nullable = true)
 |-- 1999: double (nullable = true)
 |-- 2000: double (nullable = true)
 |-- 2001: double (nullable = true)
 |-- 2002: double (nullable = true)
 |-- 2003: double (nullable = true)
 |-- 2004: double (nullable = true)
 |-- 2005: doub

2-Handle Null or Missing Values: Check for missing or inconsistent data (e.g., null or NA values):

In [None]:
# Check for null values
df.select([col(c).isNull().alias(c) for c in df.columns]).show()

# Replace null values with 0 (or an appropriate default)
df = df.fillna(0)
df.show()

+---------+-------+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+----------------------+
|Continent|Country| 1980| 1981| 1982| 1983| 1984| 1985| 1986| 1987| 1988| 1989| 1990| 1991| 1992| 1993| 1994| 1995| 1996| 1997| 1998| 1999| 2000| 2001| 2002| 2003| 2004| 2005| 2006| 2007| 2008| 2009| 2010| 2011| 2012| 2013| 2014| 2015| 2016| 2017| 2018| 2019| 2020| 2021|Consumption_NaturalGas|
+---------+-------+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+----------------------+
|    false|  false|false|false|false|false|false|false|false|false|false|false|false|false|false|false|false|false|

3-Reshape data from a wide format to a long format in PySpark

In [None]:
file_path = "/content/unzipped_data/Consumption_Data/Consumption_Total.csv"
df = spark.read.csv(file_path, header=True, inferSchema=True)

year_columns = [col for col in df.columns if col.isdigit()]
expr = "stack({0}, {1}) as (Year, Value)".format(
    len(year_columns),
    ", ".join([f"'{col}', `{col}`" for col in year_columns])
)
reshaped_df = df.selectExpr("Continent", "Country", expr)
reshaped_df.show(truncate=False)

+---------+---------------+----+-----------+
|Continent|Country        |Year|Value      |
+---------+---------------+----+-----------+
|Africa   |        Algeria|1980|0.780695167|
|Africa   |        Algeria|1981|0.663391323|
|Africa   |        Algeria|1982|0.952188116|
|Africa   |        Algeria|1983|1.070561843|
|Africa   |        Algeria|1984|1.130786713|
|Africa   |        Algeria|1985|1.046418247|
|Africa   |        Algeria|1986|1.066300962|
|Africa   |        Algeria|1987|1.138318654|
|Africa   |        Algeria|1988|1.213119365|
|Africa   |        Algeria|1989|1.17528245 |
|Africa   |        Algeria|1990|1.217674713|
|Africa   |        Algeria|1991|1.354022185|
|Africa   |        Algeria|1992|1.301817553|
|Africa   |        Algeria|1993|1.19719019 |
|Africa   |        Algeria|1994|1.233153274|
|Africa   |        Algeria|1995|1.289340856|
|Africa   |        Algeria|1996|1.263667894|
|Africa   |        Algeria|1997|1.20373288 |
|Africa   |        Algeria|1998|1.250799348|
|Africa   

##Step 3: Data Visualization

In [None]:
import plotly.express as px
pandas_df = reshaped_df.toPandas()

fig = px.choropleth(
    pandas_df,
    locations="Country",
    locationmode="country names",
    color="Value",
    animation_frame="Year",
    title="Consumption by Country Over Time"
)
fig.show()


##Step 4: Analysis and Insights

In [None]:
from pyspark.sql.functions import avg
#Find the countries with the highest average consumption over time.

average_consumption_by_country = reshaped_df.groupBy("Country").agg(avg("Value").alias("AverageConsumption"))
top_countries = average_consumption_by_country.orderBy(desc("AverageConsumption")).limit(10)
top_countries.show()

+--------------------+------------------+
|             Country|AverageConsumption|
+--------------------+------------------+
|        United St...|  91.1927406509524|
|               China| 70.31714856357144|
|        Former U....| 58.15685514916666|
|              Russia|29.813629465000005|
|               Japan| 19.69692185333333|
|               India|15.054865291238096|
|             Germany|14.108049909032257|
|              Canada|12.424991761285716|
|        Germany, ...|11.280354672727272|
|              France|10.048276694928568|
+--------------------+------------------+



In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, year

windowSpec = Window.partitionBy("Country").orderBy("Year")
global_consumption_changes = reshaped_df.withColumn("PreviousYearValue", lag("Value", 1, 0).over(windowSpec)) \
    .withColumn("YearOverYearChange", (col("Value") - col("PreviousYearValue")) / col("PreviousYearValue")) \
    .filter(col("Year") > 1980)
global_consumption_changes.show()

+--------------+-------------------+----+-----------+-----------------+--------------------+
|     Continent|            Country|Year|      Value|PreviousYearValue|  YearOverYearChange|
+--------------+-------------------+----+-----------+-----------------+--------------------+
|Asia & Oceania|        Afghanistan|1981|0.029949458|      0.026583217| 0.12663030964235827|
|Asia & Oceania|        Afghanistan|1982|0.031897815|      0.029949458| 0.06505483337962246|
|Asia & Oceania|        Afghanistan|1983|0.039021803|      0.031897815| 0.22333780542648446|
|Asia & Oceania|        Afghanistan|1984| 0.03920224|      0.039021803|0.004624004687840...|
|Asia & Oceania|        Afghanistan|1985|0.038274961|       0.03920224|-0.02365372488918...|
|Asia & Oceania|        Afghanistan|1986|0.039111199|      0.038274961|  0.0218481737969634|
|Asia & Oceania|        Afghanistan|1987|0.063065158|      0.039111199|  0.6124578026871536|
|Asia & Oceania|        Afghanistan|1988|0.111220787|      0.063065158

In [None]:
from pyspark.sql.functions import sum, col
import plotly.express as px

global_consumption = reshaped_df.groupBy("Year").agg(sum(col("Value").cast("double")).alias("TotalConsumption"))
# Sort the data by year before converting to pandas DataFrame
global_consumption = global_consumption.orderBy("Year")
pandas_global_consumption = global_consumption.toPandas()
fig = px.line(pandas_global_consumption, x="Year", y="TotalConsumption", title="Global Consumption Over Time")
fig.show()

##Step 5: Forecasting energy consumption

In [None]:
reshaped_df = reshaped_df.withColumn('quarter', (col('Year') % 4) + 1)
reshaped_df.show()

+---------+---------------+----+-----------+-------+
|Continent|        Country|Year|      Value|quarter|
+---------+---------------+----+-----------+-------+
|   Africa|        Algeria|1980|0.780695167|    1.0|
|   Africa|        Algeria|1981|0.663391323|    2.0|
|   Africa|        Algeria|1982|0.952188116|    3.0|
|   Africa|        Algeria|1983|1.070561843|    4.0|
|   Africa|        Algeria|1984|1.130786713|    1.0|
|   Africa|        Algeria|1985|1.046418247|    2.0|
|   Africa|        Algeria|1986|1.066300962|    3.0|
|   Africa|        Algeria|1987|1.138318654|    4.0|
|   Africa|        Algeria|1988|1.213119365|    1.0|
|   Africa|        Algeria|1989| 1.17528245|    2.0|
|   Africa|        Algeria|1990|1.217674713|    3.0|
|   Africa|        Algeria|1991|1.354022185|    4.0|
|   Africa|        Algeria|1992|1.301817553|    1.0|
|   Africa|        Algeria|1993| 1.19719019|    2.0|
|   Africa|        Algeria|1994|1.233153274|    3.0|
|   Africa|        Algeria|1995|1.289340856|  

In [None]:
# Splitting dataset into training and testing sets based on Year
train_data = reshaped_df.filter(col('Year') < 2015)  # Training data
test_data = reshaped_df.filter(col('Year') >= 2015)  # Test data


a. Linear Regression with Time as Feature
A simple method for forecasting energy consumption is to use linear regression, where we treat the Year as a feature and the Value as the target to predict.

In [None]:
from pyspark.sql.types import IntegerType, DoubleType
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler

# Cast 'Year' column to IntegerType before assembling features
train_data = train_data.withColumn("Year", train_data["Year"].cast(IntegerType()))
# Cast 'Value' column to DoubleType before training
train_data = train_data.withColumn("Value", train_data["Value"].cast(DoubleType()))

# Remove rows with null or NaN in the 'Value' column (label column)
train_data = train_data.dropna(subset=['Value'])  # This line removes the problematic rows

# Check if 'features' column already exists before applying VectorAssembler
if 'features' not in train_data.columns:  # This check prevents the error
    assembler = VectorAssembler(inputCols=["Year"], outputCol="features")
    train_data = assembler.transform(train_data)

# Initialize the linear regression model
lr = LinearRegression(featuresCol="features", labelCol="Value")

# Train the model
lr_model = lr.fit(train_data)

# Make predictions on the test data
# Ensure 'Year' is also cast to IntegerType in test data
test_data = test_data.withColumn("Year", test_data["Year"].cast(IntegerType()))
# Ensure 'Value' is also cast to DoubleType in test data
test_data = test_data.withColumn("Value", test_data["Value"].cast(DoubleType()))

# Remove rows with null or NaN in the 'Value' column of test_data
test_data = test_data.dropna(subset=['Value']) # Apply the same fix to test_data

# Check if 'features' column already exists in test_data before applying VectorAssembler
if 'features' not in test_data.columns:  # This check prevents the error on test_data
    test_data = assembler.transform(test_data)

predictions = lr_model.transform(test_data)

# Show predictions
predictions.select("Continent","Country","Year", "Value", "prediction").show()

+---------+---------------+----+-----------+------------------+
|Continent|        Country|Year|      Value|        prediction|
+---------+---------------+----+-----------+------------------+
|   Africa|        Algeria|2015| 2.35214754| 2.424900379450463|
|   Africa|        Algeria|2016| 2.33598238| 2.455077855369126|
|   Africa|        Algeria|2017|2.365986905|2.4852553312877887|
|   Africa|        Algeria|2018|2.522119152|2.5154328072064516|
|   Africa|        Algeria|2019|2.624873096|2.5456102831251144|
|   Africa|        Algeria|2020|2.480810774| 2.575787759043777|
|   Africa|        Algeria|2021|2.581123314|  2.60596523496244|
|   Africa|         Angola|2015| 0.39921871| 2.424900379450463|
|   Africa|         Angola|2016|0.364251517| 2.455077855369126|
|   Africa|         Angola|2017|0.347458949|2.4852553312877887|
|   Africa|         Angola|2018|0.359226946|2.5154328072064516|
|   Africa|         Angola|2019|0.404865374|2.5456102831251144|
|   Africa|         Angola|2020|0.402317

b. ARIMA and Other Time Series Models
While ARIMA is not natively supported by PySpark's MLlib, you can still use libraries like statsmodels (in Python) outside Spark for time-series forecasting. For example, using ARIMA for energy consumption prediction:

c-Using Prophet for Forecasting
If you are dealing with more complex seasonality or trends, Prophet (a forecasting tool developed by Facebook) is a great choice. You can use Prophet via PySpark or outside Spark.

**Model Evaluation**
Once we have trained the model, evaluate it using the test data. We can use metrics such as Root Mean Squared Error (RMSE), Mean Absolute Error (MAE), or R-squared (R²) to assess prediction accuracy

In [None]:
from pyspark.ml.evaluation import RegressionEvaluator

# Evaluate the model using RMSE, MAE, and R2
evaluator = RegressionEvaluator(labelCol="Value", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print(f"Root Mean Squared Error (RMSE): {rmse}")

evaluator.setMetricName("mae")
mae = evaluator.evaluate(predictions)
print(f"Mean Absolute Error (MAE): {mae}")


Root Mean Squared Error (RMSE): 12.370504914303442
Mean Absolute Error (MAE): 3.7739328466705437


**Forecasting Future Consumption**

In [None]:
from pyspark.sql.functions import lit

# Select a sample row for Continent and Country
# (or use any other suitable method to obtain the desired values)
sample_row = reshaped_df.select("Continent", "Country").first()

# Create a DataFrame with the desired Continent and Country values
continent_country_df = spark.createDataFrame([(sample_row.Continent, sample_row.Country)], ["Continent", "Country"])

# Create future years with Continent and Country
future_years = spark.createDataFrame([(2025,), (2026,), (2027,),(2028,),(2029,),(2030,)], ["Year"])
future_years = future_years.crossJoin(continent_country_df)  # Join to add Continent and Country
future_years = assembler.transform(future_years)  # Add features column

# Make predictions
future_predictions = lr_model.transform(future_years)

# Show predictions with Continent and Country
future_predictions.select("Continent", "Country", "Year", "prediction").show()

+---------+---------------+----+------------------+
|Continent|        Country|Year|        prediction|
+---------+---------------+----+------------------+
|   Africa|        Algeria|2025|2.7266751386370913|
|   Africa|        Algeria|2026| 2.756852614555754|
|   Africa|        Algeria|2027| 2.787030090474417|
|   Africa|        Algeria|2028| 2.817207566393087|
|   Africa|        Algeria|2029|2.8473850423117497|
|   Africa|        Algeria|2030|2.8775625182304125|
+---------+---------------+----+------------------+



#Gradio-based interface

In [None]:
pip install gradio


Collecting gradio
  Downloading gradio-5.10.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.6-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.5.3 (from gradio)
  Downloading gradio_client-1.5.3-py3-none-any.whl.metadata (7.1 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.2.2 (from gradio)
  Downloading ruff-0.8.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.meta

In [None]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q http://www-us.apache.org/dist/spark/spark-3.3.0/spark-3.3.0-bin-hadoop3.tgz
!tar xf spark-3.3.0-bin-hadoop3.tgz
!pip install -q findspark


tar: spark-3.3.0-bin-hadoop3.tgz: Cannot open: No such file or directory
tar: Error is not recoverable: exiting now


In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
from pyspark.sql.types import IntegerType, DoubleType
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler
import plotly.express as px
import gradio as gr

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("EnergyDemandForecasting") \
    .config("spark.sql.shuffle.partitions", "50") \
    .config("spark.executor.memory", "2g") \
    .getOrCreate()

# Define global variables
assembler = VectorAssembler(inputCols=["Year"], outputCol="features")
linear_regression_model = None
processed_df = None  # Store processed DataFrame globally

def upload_and_process(file_path):
    """Load and preprocess uploaded CSV file."""
    try:
        df = spark.read.csv(file_path.name, header=True, inferSchema=True)
        year_columns = [col for col in df.columns if col.isdigit()]
        expr = "stack({0}, {1}) as (Year, Value)".format(
            len(year_columns),
            ", ".join([f"'{col}', `{col}`" for col in year_columns])
        )
        reshaped_df = df.selectExpr("Continent", "Country", expr)
        reshaped_df = reshaped_df.withColumn("Year", reshaped_df["Year"].cast("int"))
        reshaped_df = reshaped_df.withColumn("Value", reshaped_df["Value"].cast("double"))
        global processed_df  # Store DataFrame globally for reuse
        processed_df = reshaped_df
        return reshaped_df.toPandas()  # Return as Pandas DataFrame for Gradio
    except Exception as e:
        return f"Error processing data: {e}"

def visualize_data():
    """Generate a plotly chart for energy consumption trends."""
    if processed_df is None:
        return "No data processed. Please upload and process a file first."
    pandas_df = processed_df.toPandas()
    fig = px.line(
        pandas_df,
        x="Year",
        y="Value",
        color="Country",
        title="Energy Consumption Trends"
    )
    return fig

def train_model():
    """Train a linear regression model for forecasting."""
    global linear_regression_model
    if processed_df is None:
        return "No data processed. Please upload and process a file first."
    try:
        train_data = processed_df.filter(col("Year") < 2015).dropna(subset=["Value"])
        train_data = assembler.transform(train_data)
        lr = LinearRegression(featuresCol="features", labelCol="Value")
        linear_regression_model = lr.fit(train_data)
        return "Model trained successfully!"
    except Exception as e:
        return f"Error training model: {e}"

def forecast():
    """Forecast future energy consumption."""
    if not linear_regression_model:
        return "Model not trained yet!"
    if processed_df is None:
        return "No data processed. Please upload and process a file first."
    try:
        future_years = spark.createDataFrame([(y,) for y in range(2025, 2031)], ["Year"])
        country = processed_df.select("Continent", "Country").first()
        future_years = future_years.withColumn("Continent", lit(country.Continent))
        future_years = future_years.withColumn("Country", lit(country.Country))
        future_years = assembler.transform(future_years)
        predictions = linear_regression_model.transform(future_years)
        return predictions.select("Year", "prediction").toPandas()
    except Exception as e:
        return f"Error during forecasting: {e}"

# Define Gradio Interface
with gr.Blocks() as app:
    gr.Markdown("# Energy Demand Analysis and Forecasting")

    with gr.Tab("Upload and Process"):
        file_input = gr.File(label="Upload Energy Data CSV")
        data_preview = gr.Dataframe()
        process_button = gr.Button("Process Data")
        process_button.click(upload_and_process, inputs=file_input, outputs=data_preview)

    with gr.Tab("Visualize Trends"):
        chart_output = gr.Plot()
        visualize_button = gr.Button("Generate Chart")
        visualize_button.click(visualize_data, outputs=chart_output)

    with gr.Tab("Train Model"):
        train_button = gr.Button("Train Forecast Model")
        train_status = gr.Textbox()
        train_button.click(train_model, outputs=train_status)

    with gr.Tab("Forecast"):
        forecast_output = gr.Dataframe()
        forecast_button = gr.Button("Forecast Future Consumption")
        forecast_button.click(forecast, outputs=forecast_output)

# Launch Gradio App
app.launch()


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://b15da1690289cb779f.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


