# 1. Environment Setup

In [None]:
import os
import subprocess


try:
    java_path = subprocess.check_output(["which", "java"]).decode("utf-8").strip()
    if os.path.islink(java_path):
        java_path = os.path.realpath(java_path)
    java_home = os.path.dirname(os.path.dirname(java_path))

    print(f"Detected Java Home: {java_home}")
    os.environ["JAVA_HOME"] = java_home

except Exception as e:
    print(f"Failed to detect Java: {e}")
    os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-17-openjdk-amd64"

# Install PySpark & Findspark
!pip install pyspark findspark -q

# Mount Google Drive
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
else:
    print("Drive already mounted.")

# Initialize Spark
import findspark
findspark.init()

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .master("local[*]") \
    .appName("SupplyChain_Forecast_Auto") \
    .config("spark.ui.port", "4050") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

print("\n" + "="*30)
print("PySpark Initialized Successfully!")
print(f"Spark Version: {spark.version}")
print("="*30)

Detected Java Home: /usr/lib/jvm/java-17-openjdk-amd64
Mounted at /content/drive

PySpark Initialized Successfully!
Spark Version: 4.0.1


# 2. Data Ingestion

In [None]:
# NOTE: Update 'file_path' to point to your local dataset or mounted drive
file_path = "/content/drive/MyDrive/Colab Notebooks/Data/train.csv"

# Read CSV with schema inference
# inferSchema=True allows Spark to guess data types (e.g., Integer, String)
df = spark.read.csv(file_path, header=True, inferSchema=True)

print("Data Loaded Successfully")
print(f"Total Records: {df.count()}")
print(f"Total Columns: {len(df.columns)}")

print("\n--- Raw Data Schema ---")
df.printSchema()

print("\n--- Data Preview (Top 3 Rows) ---")
df.show(3)

Data Loaded Successfully
Total Records: 9800
Total Columns: 18

--- Raw Data Schema ---
root
 |-- Row ID: integer (nullable = true)
 |-- Order ID: string (nullable = true)
 |-- Order Date: string (nullable = true)
 |-- Ship Date: string (nullable = true)
 |-- Ship Mode: string (nullable = true)
 |-- Customer ID: string (nullable = true)
 |-- Customer Name: string (nullable = true)
 |-- Segment: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- City: string (nullable = true)
 |-- State: string (nullable = true)
 |-- Postal Code: integer (nullable = true)
 |-- Region: string (nullable = true)
 |-- Product ID: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Sub-Category: string (nullable = true)
 |-- Product Name: string (nullable = true)
 |-- Sales: string (nullable = true)


--- Data Preview (Top 3 Rows) ---
+------+--------------+----------+----------+------------+-----------+---------------+---------+-------------+-----------+----------+-------

# 3. Data Cleaning

In [None]:
from pyspark.sql.functions import col, to_date, regexp_replace
from pyspark.sql.types import DoubleType

# Disable ANSI Strict Mode
# Ensure Spark returns NULL instead of throwing exceptions for malformed data
spark.conf.set("spark.sql.ansi.enabled", "false")

# Define expected date format (DD/MM/YYYY based on raw data inspection)
date_fmt = 'dd/MM/yyyy'

# Cleaning Logic
# 1. Sales: Remove '$' and ',' characters, then cast to Double
# 2. Order Date: Parse string to DateType
df_cleaned = df.withColumn("Sales",
                           regexp_replace(col("Sales"), "[$,]", "")
                           .cast(DoubleType())) \
               .withColumn("Order Date",
                           to_date(col("Order Date"), date_fmt))

# Data Quality Check
total_count = df.count()
null_sales = df_cleaned.filter(col("Sales").isNull()).count()
null_date = df_cleaned.filter(col("Order Date").isNull()).count()

print(f"\n--- Data Quality Report ---")
print(f"Total Rows Processed: {total_count}")
print(f"Invalid Sales Records (Dropped): {null_sales}")
print(f"Invalid Date Records: {null_date}")

# Drop Malformed Records
# Remove rows where critical business fields are invalid
df_final = df_cleaned.dropna(subset=["Sales", "Order Date"])
print(f"Final Cleaned Row Count: {df_final.count()}")

df_final.show(5)


--- Data Quality Report ---
Total Rows Processed: 9800
Invalid Sales Records (Dropped): 292
Invalid Date Records: 0
Final Cleaned Row Count: 9508
+------+--------------+----------+----------+--------------+-----------+---------------+---------+-------------+---------------+----------+-----------+------+---------------+---------------+------------+--------------------+--------+
|Row ID|      Order ID|Order Date| Ship Date|     Ship Mode|Customer ID|  Customer Name|  Segment|      Country|           City|     State|Postal Code|Region|     Product ID|       Category|Sub-Category|        Product Name|   Sales|
+------+--------------+----------+----------+--------------+-----------+---------------+---------+-------------+---------------+----------+-----------+------+---------------+---------------+------------+--------------------+--------+
|     1|CA-2017-152156|2017-11-08|11/11/2017|  Second Class|   CG-12520|    Claire Gute| Consumer|United States|      Henderson|  Kentucky|      42420|

# 4. Feature Engineering

In [None]:
from pyspark.sql.functions import year, weekofyear, sum as _sum, avg, col, lag
from pyspark.sql.window import Window

# 1. Aggregation: Order Level -> Weekly Category Level
# Aggregate data to "Weekly Sales per Sub-Category" to reduce sparsity
df_weekly = df_final.withColumn("Year", year("Order Date")) \
                    .withColumn("Week", weekofyear("Order Date")) \
                    .groupBy("Year", "Week", "Sub-Category") \
                    .agg(_sum("Sales").alias("Weekly_Sales"))

# 2. Window Definition for Time-Series Features
# Partition by Product Category, Ordered by Time
window_spec = Window.partitionBy("Sub-Category").orderBy("Year", "Week")

# 3. Feature Generation (Lag & Rolling Metrics)
# Lag_1_Week: Sales from the previous week (Autocorrelation)
# Rolling_Avg_4_Weeks: Trend indicator over the last month
df_features = df_weekly.withColumn("Lag_1_Week", lag("Weekly_Sales", 1).over(window_spec)) \
                       .withColumn("Lag_4_Weeks", lag("Weekly_Sales", 4).over(window_spec)) \
                       .withColumn("Rolling_Avg_4_Weeks", avg("Weekly_Sales").over(window_spec.rowsBetween(-3, 0)))

# 4. Drop Nulls Created by Lag
df_model_data = df_features.dropna()

print("Feature Engineering Complete")
df_model_data.show(5)

Feature Engineering Complete
+----+----+------------+------------------+------------------+------------------+-------------------+
|Year|Week|Sub-Category|      Weekly_Sales|        Lag_1_Week|       Lag_4_Weeks|Rolling_Avg_4_Weeks|
+----+----+------------+------------------+------------------+------------------+-------------------+
|2015|   7| Accessories|474.41999999999996|            115.36|              31.2|           463.8425|
|2015|   8| Accessories|             62.31|474.41999999999996|            796.69| 280.24749999999995|
|2015|  10| Accessories|            479.97|             62.31|             468.9|            283.015|
|2015|  11| Accessories|115.75999999999999|            479.97|            115.36|            283.115|
|2015|  12| Accessories|            170.24|115.75999999999999|474.41999999999996|             207.07|
+----+----+------------+------------------+------------------+------------------+-------------------+
only showing top 5 rows


# 5. Modeling

In [None]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator

# 1. Vectorization
# Assemble input features into a single vector column for Spark ML
feature_cols = ["Week", "Lag_1_Week", "Lag_4_Weeks", "Rolling_Avg_4_Weeks"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
df_final_prep = assembler.transform(df_model_data)

# 2. Time-Based Train/Test Split
# Use time-based split instead of random split to prevent data leakage
split_year = 2017
train_data = df_final_prep.filter(col("Year") < split_year)
test_data = df_final_prep.filter(col("Year") >= split_year)

print(f"Training Records (Pre-{split_year}): {train_data.count()}")
print(f"Test Records ({split_year}+): {test_data.count()}")

# 3. Model Training
# RandomForest is selected for its ability to handle non-linear relationships
rf = RandomForestRegressor(featuresCol="features", labelCol="Weekly_Sales", numTrees=100)
model = rf.fit(train_data)

# 4. Prediction & Evaluation
predictions = model.transform(test_data)
evaluator = RegressionEvaluator(labelCol="Weekly_Sales", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)

print("\n" + "="*40)
print(f"Model Performance (RMSE): {rmse:.2f}")
print("="*40)

predictions.select("Year", "Week", "Sub-Category", "Weekly_Sales", "prediction").show(5)

Training Records (Pre-2017): 1200
Test Records (2017+): 1413

Model Performance (RMSE): 1164.46
+----+----+------------+------------------+------------------+
|Year|Week|Sub-Category|      Weekly_Sales|        prediction|
+----+----+------------+------------------+------------------+
|2017|   1| Accessories|387.15200000000004|2716.7355204159117|
|2017|   2| Accessories|             674.9|402.74854129549954|
|2017|   3| Accessories|           863.706|511.50262330915444|
|2017|   4| Accessories|             99.98|374.82522257151754|
|2017|   6| Accessories|221.82999999999998|  541.020372455817|
+----+----+------------+------------------+------------------+
only showing top 5 rows


# 6. Inventory Optimization

In [None]:
from pyspark.sql.functions import sqrt, mean, round, lit

# 1. Calculate Category-Specific Error (RMSE)
# Different categories have different volatilities
df_with_error = predictions.withColumn("Error", col("Weekly_Sales") - col("prediction"))

category_stats = df_with_error.groupBy("Sub-Category").agg(
    (sqrt(mean(col("Error")**2))).alias("Category_RMSE")
)

# 2. Inventory Policy Definition
# Target Service Level: 95% (Z-Score approx. 1.65)
Z_SCORE = 1.65

# 3. Generate Replenishment Plan
# Total Inventory = Cycle Stock (Predicted Demand) + Safety Stock (Buffer)
df_inventory = predictions.join(category_stats, "Sub-Category", "left")

df_final_plan = df_inventory.withColumn("Safety_Stock", round(col("Category_RMSE") * Z_SCORE, 0)) \
                            .withColumn("Cycle_Stock", round(col("prediction"), 0)) \
                            .withColumn("Total_Inventory_Needed", col("Cycle_Stock") + col("Safety_Stock"))

print("=== Final Inventory Recommendations (Top Needs) ===")
df_final_plan.select("Year", "Week", "Sub-Category",
                     "Cycle_Stock", "Safety_Stock", "Total_Inventory_Needed") \
             .orderBy(col("Total_Inventory_Needed").desc()) \
             .show(10)

=== Final Inventory Recommendations (Top Needs) ===
+----+----+------------+-----------+------------+----------------------+
|Year|Week|Sub-Category|Cycle_Stock|Safety_Stock|Total_Inventory_Needed|
+----+----+------------+-----------+------------+----------------------+
|2018|  30|     Copiers|     4337.0|      8235.0|               12572.0|
|2018|  46|     Copiers|     4153.0|      8235.0|               12388.0|
|2017|  26|     Copiers|     4137.0|      8235.0|               12372.0|
|2017|  21|     Copiers|     4088.0|      8235.0|               12323.0|
|2018|  12|     Copiers|     3980.0|      8235.0|               12215.0|
|2017|  49|     Copiers|     3912.0|      8235.0|               12147.0|
|2017|  51|     Copiers|     3839.0|      8235.0|               12074.0|
|2018|  52|     Copiers|     3779.0|      8235.0|               12014.0|
|2018|  44|     Copiers|     3488.0|      8235.0|               11723.0|
|2018|  19|     Copiers|     3482.0|      8235.0|               11717.0|