# OLS Reconciliation with Spark

## Feng Li

### Guanghua School of Management
### Peking University


### [feng.li@gsm.pku.edu.cn](feng.li@gsm.pku.edu.cn)
### Course home page: [https://feng.li/bdcf](https://feng.li/bdcf)

In [1]:
import os, sys # Ensure All environment variables are properly set 
# os.environ["JAVA_HOME"] = os.path.dirname(sys.executable)
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

from pyspark.sql import SparkSession # build Spark Session
spark = SparkSession.builder \
    .config("spark.ui.enabled", "false") \
    .config("spark.executor.memory", "16g") \
    .config("spark.executor.cores", "4") \
    .config("spark.cores.max", "32") \
    .config("spark.driver.memory", "30g") \
    .config("spark.sql.shuffle.partitions", "96") \
    .config("spark.memory.fraction", "0.8") \
    .config("spark.memory.storageFraction", "0.5") \
    .config("spark.dynamicAllocation.enabled", "true") \
    .config("spark.dynamicAllocation.minExecutors", "4") \
    .config("spark.dynamicAllocation.maxExecutors", "8") \
    .appName("Spark Forecasting").getOrCreate()
spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/25 14:23:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
train_sdf = spark.read.csv("../data/tourism/tourism_train.csv", header=True, inferSchema=True)
test_sdf = spark.read.csv("../data/tourism/tourism_test.csv", header=True, inferSchema=True)
forecast_sdf = spark.read.csv("../data/tourism/ets_forecasts.csv", header=True, inferSchema=True)

In [3]:
train_sdf.show()

+----------+---------------+------------------+
|      date|Region_Category|          Visitors|
+----------+---------------+------------------+
|1998-01-01|       TotalAll|45151.071280099975|
|1998-01-01|           AAll|17515.502379600006|
|1998-01-01|           BAll|10393.618015699998|
|1998-01-01|           CAll| 8633.359046599999|
|1998-01-01|           DAll|3504.3133462000005|
|1998-01-01|           EAll|      3121.6191894|
|1998-01-01|           FAll|1850.7357734999998|
|1998-01-01|           GAll|131.92352909999997|
|1998-01-01|          AAAll|      4977.2096105|
|1998-01-01|          ABAll| 5322.738721600001|
|1998-01-01|          ACAll|3569.6213724000004|
|1998-01-01|          ADAll|      1472.9706096|
|1998-01-01|          AEAll|      1560.5142545|
|1998-01-01|          AFAll|        612.447811|
|1998-01-01|          BAAll|       3854.672582|
|1998-01-01|          BBAll|1653.9957826000002|
|1998-01-01|          BCAll|      2138.7473162|
|1998-01-01|          BDAll|      1395.3

In [7]:
test_sdf.show()

+----------+---------------+------------------+
|      date|Region_Category|          Visitors|
+----------+---------------+------------------+
|2015-12-01|       TotalAll|24982.024449599998|
|2015-12-01|           AAll|      7166.1237555|
|2015-12-01|           BAll|      5340.1512778|
|2015-12-01|           CAll| 5621.323498100002|
|2015-12-01|           DAll|1871.7627924999997|
|2015-12-01|           EAll|      3662.9183909|
|2015-12-01|           FAll| 903.1237441000002|
|2015-12-01|           GAll|416.62099070000016|
|2015-12-01|          AAAll|2107.3938028000007|
|2015-12-01|          ABAll|      2084.6943991|
|2015-12-01|          ACAll|       829.2701471|
|2015-12-01|          ADAll| 838.6748597999999|
|2015-12-01|          AEAll|       760.0565663|
|2015-12-01|          AFAll|       546.0339804|
|2015-12-01|          BAAll|2659.1528845999997|
|2015-12-01|          BBAll|       645.4208704|
|2015-12-01|          BCAll|       576.0730749|
|2015-12-01|          BDAll| 768.1817073

In [6]:
forecast_sdf.show()

+----------+---------------+------------------+
|      date|Region_Category|          Forecast|
+----------+---------------+------------------+
|2015-12-01|         AAAAll| 2058.838101212888|
|2016-01-01|         AAAAll|3162.9085270260384|
|2016-02-01|         AAAAll|1744.1909768938476|
|2016-03-01|         AAAAll|2059.3010229302345|
|2016-04-01|         AAAAll|  2060.36170915585|
|2016-05-01|         AAAAll| 1972.482680954841|
|2016-06-01|         AAAAll| 1846.108381522378|
|2016-07-01|         AAAAll| 2151.959971338432|
|2016-08-01|         AAAAll| 1872.768734964083|
|2016-09-01|         AAAAll|2014.0112033543805|
|2016-10-01|         AAAAll|2296.4055573391643|
|2016-11-01|         AAAAll| 2043.693618904705|
|2015-12-01|         AAABus|455.55263433165527|
|2016-01-01|         AAABus| 296.1898590727801|
|2016-02-01|         AAABus| 453.3714726155002|
|2016-03-01|         AAABus|  525.339287022726|
|2016-04-01|         AAABus| 459.2040763240983|
|2016-05-01|         AAABus| 554.7602408

In [10]:
%%script echo skipping

# If you want to have alternative forecasts, change the following code to obtain `forecast_sdf`

from pyspark.sql.functions import explode, col
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, DoubleType, DateType
import pandas as pd
from pandas.tseries.offsets import MonthBegin
from statsmodels.tsa.holtwinters import ExponentialSmoothing

# Define schema for the forecast output
forecast_schema = StructType([
    StructField("date", DateType(), False),
    StructField("Region_Category", StringType(), False),
    StructField("Forecast", DoubleType(), False)
])


def ets_forecast(pdf):
    """Fits an ETS model for a single region and forecasts 12 months ahead."""
    region = pdf["Region_Category"].iloc[0]  # Extract region name
    pdf = pdf.sort_values("date")  # Ensure time series is sorted

    try:
        # Drop missing values
        ts = pdf["Visitors"].dropna()

        if len(ts) >= 24:  # Ensure at least 24 observations
            model = ExponentialSmoothing(ts, trend="add", seasonal="add", seasonal_periods=12)
            fitted_model = model.fit()
            forecast = fitted_model.forecast(steps=12)
        else:
            forecast = [None] * 12  # Not enough data
    except:
        forecast = [None] * 12  # Handle errors

    # Adjust forecast dates to start of the month
    last_date = pdf["date"].max()
    future_dates = pd.date_range(start=last_date, periods=12, freq="ME") + MonthBegin(1)

    # Return results as a DataFrame
    return pd.DataFrame({"date": future_dates, "Region_Category": region, "Forecast": forecast})

# Apply the ETS model in parallel using applyInPandas
forecast_sdf = train_sdf.groupBy("Region_Category").applyInPandas(ets_forecast, schema=forecast_schema)

# Show forecasted results
forecast_sdf.show()

# Save forecasts if needed
forecast_sdf.write.csv(os.path.expanduser("~/ets_forecasts.csv"), header=True, mode="overwrite")


skipping


## MinT-OLS approximation

Since PySpark doesn't support matrix operations like NumPy, we'll use the **MinT-OLS approximation**, which assumes the forecast error covariance matrix \( W = I \) (identity matrix). This simplifies the formula:

$$
\tilde{y} = S (S^\top S)^{-1} S^\top \hat{y}
$$



- `forecast_sdf` contains **base forecasts** for each `Region_Category` and `date`.
- `test_sdf` is your **test set** with actual `Visitors` by `Region_Category` and `date`.
- You have `summing_sdf_long` with:  
  - `Parent_Group`  
  - `Region_Category`  
  - `Weight` (usually 0 or 1)


In [11]:
from pyspark.sql.functions import col, sum as spark_sum


# Load the summing matrix file
summing_matrix_path = "../data/tourism/agg_mat.csv"  # Update with actual path

# Load the summing matrix file (skip the first column)
summing_sdf = spark.read.csv(summing_matrix_path, header=True, inferSchema=True)


# Convert from wide format to long format (Region_Category, Parent_Group, Weight)
summing_sdf_long = summing_sdf.selectExpr(
    "Parent_Group",
    "stack(" + str(len(summing_sdf.columns) - 1) + ", " +
    ", ".join([f"'{col}', {col}" for col in summing_sdf.columns if col != "Parent_Group"]) +
    ") as (Region_Category, Weight)"
)

# Show the reshaped summing matrix
summing_sdf_long.show()

25/03/25 14:59:17 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+------------+---------------+------+
|Parent_Group|Region_Category|Weight|
+------------+---------------+------+
|    TotalAll|         AAAHol|   1.0|
|    TotalAll|         AAAVis|   1.0|
|    TotalAll|         AAABus|   1.0|
|    TotalAll|         AAAOth|   1.0|
|    TotalAll|         AABHol|   1.0|
|    TotalAll|         AABVis|   1.0|
|    TotalAll|         AABBus|   1.0|
|    TotalAll|         AABOth|   1.0|
|    TotalAll|         ABAHol|   1.0|
|    TotalAll|         ABAVis|   1.0|
|    TotalAll|         ABABus|   1.0|
|    TotalAll|         ABAOth|   1.0|
|    TotalAll|         ABBHol|   1.0|
|    TotalAll|         ABBVis|   1.0|
|    TotalAll|         ABBBus|   1.0|
|    TotalAll|         ABBOth|   1.0|
|    TotalAll|         ACAHol|   1.0|
|    TotalAll|         ACAVis|   1.0|
|    TotalAll|         ACABus|   1.0|
|    TotalAll|         ACAOth|   1.0|
+------------+---------------+------+
only showing top 20 rows



In [13]:
mint_input_sdf = forecast_sdf.join(summing_sdf_long, on="Region_Category", how="inner")

# Multiply base forecast by structure weights
mint_input_sdf = mint_input_sdf.withColumn("Weighted_Forecast", col("Forecast") * col("Weight"))

# Sum up forecasts to each parent group
mint_reconciled_sdf = mint_input_sdf.groupBy("date", "Parent_Group").agg(
    spark_sum("Weighted_Forecast").alias("Reconciled_Forecast")
)
mint_reconciled_sdf.show()



+----------+------------+-------------------+
|      date|Parent_Group|Reconciled_Forecast|
+----------+------------+-------------------+
|2015-12-01|      DBBHol|  93.30359324977445|
|2015-12-01|      BEDBus|  9.587146785063556|
|2015-12-01|      BEAHol|  5.495956927738096|
|2015-12-01|      ADBBus|  33.48580576600097|
|2015-12-01|      ADBHol|  67.11613382657367|
|2015-12-01|       FCVis|  49.72621582615672|
|2015-12-01|       EABus|  337.9295021569872|
|2015-12-01|       ACBus|  72.84836998814711|
|2015-12-01|      GACAll| 25.191637997601127|
|2016-01-01|      DCBHol|  122.5716884443587|
|2016-01-01|      BBAHol| 1215.3210703251198|
|2016-01-01|      AFAHol| 231.28278439478635|
|2016-01-01|      ABAVis|  444.3775568094911|
|2016-01-01|      ABAHol|  738.9134710470801|
|2016-01-01|       GAVis|   93.9877077354365|
|2016-01-01|       BEHol|  499.4794722311132|
|2016-01-01|       AABus| 324.74260821819576|
|2016-01-01|      DABAll|  53.60525695425169|
|2016-01-01|        GAll| 360.1925

                                                                                

In [14]:
test_parent_sdf = test_sdf.join(summing_sdf_long, on="Region_Category", how="inner") \
    .groupBy("date", "Parent_Group") \
    .agg(spark_sum("Visitors").alias("Actual_Visitors"))

from pyspark.sql.functions import abs, mean

evaluation_sdf = mint_reconciled_sdf.join(test_parent_sdf, on=["date", "Parent_Group"], how="inner") \
    .withColumn("APE", abs((col("Reconciled_Forecast") - col("Actual_Visitors")) / col("Actual_Visitors")))

mape_sdf = evaluation_sdf.groupBy("Parent_Group").agg(mean("APE").alias("MAPE"))
mape_sdf.show()



+------------+------------------+
|Parent_Group|              MAPE|
+------------+------------------+
|      CBDAll|0.9918551321816805|
|       BCHol|0.9857934130236617|
|      BCBOth|0.9998273964952409|
|      DDBHol|0.9965713950479036|
|      CCBAll|0.9872982701563636|
|       CCOth|0.9971415937249763|
|      DCCAll|0.9987860340371489|
|      BDEAll|0.9986953765843504|
|      FBAVis|0.9996472874910226|
|      EABVis|0.9849344998635591|
|      GABVis|0.9997748506792165|
|      ADBAll|0.9903498750139902|
|      FAAHol|0.9929209612350367|
|      BDFAll|0.9989229459704517|
|      CBCHol| 0.996186434565634|
|      GBCAll|0.9974573023722869|
|      CDBHol| 0.998463935084155|
|      BEGAll|0.9955394793307324|
|       DABus|0.9938826685744151|
|      DAAVis|0.9899752829805119|
+------------+------------------+
only showing top 20 rows



                                                                                

In [15]:
# Compute overall mean MAPE across all Region_Category
overall_mape_ols = mape_sdf.agg(mean("MAPE").alias("Overall_MAPE"))

# Show the result
overall_mape_ols.show()



+------------------+
|      Overall_MAPE|
+------------------+
|0.9860567355464035|
+------------------+



                                                                                

## Summary of MinT-OLS

- **MinT-OLS (simple projection)**  
- **No covariance matrix needed**  
- **Coherent forecasts at Parent_Group level**  
- **Evaluated using MAPE**
