<a href="https://colab.research.google.com/github/eghib22/Store-Sales-Forecasting/blob/main/model_experiment_sarima.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
! mkdir ~/.kaggle

mkdir: cannot create directory ‘/root/.kaggle’: File exists


In [3]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"anigasitashvili","key":"373c2b9f2c34a9b29562a5a4a9b08fcf"}'}

In [4]:
!mv "kaggle.json" ~/.kaggle/kaggle.json
!chmod 600 ~/.kaggle/kaggle.json

In [5]:
!ls -l ~/.kaggle/

total 4
-rw------- 1 root root 71 Jul  6 13:57 kaggle.json


In [6]:
!kaggle competitions download -c walmart-recruiting-store-sales-forecasting

walmart-recruiting-store-sales-forecasting.zip: Skipping, found more recently modified local copy (use --force to force download)


In [7]:
! unzip walmart-recruiting-store-sales-forecasting

Archive:  walmart-recruiting-store-sales-forecasting.zip
replace features.csv.zip? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: features.csv.zip        
replace sampleSubmission.csv.zip? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: sampleSubmission.csv.zip  
replace stores.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: stores.csv              
replace test.csv.zip? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: test.csv.zip            
replace train.csv.zip? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: train.csv.zip           


In [8]:
!unzip '*.csv.zip'

Archive:  test.csv.zip
replace test.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: test.csv                

Archive:  features.csv.zip
replace features.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: features.csv            

Archive:  sampleSubmission.csv.zip
replace sampleSubmission.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: sampleSubmission.csv    

Archive:  train.csv.zip
replace train.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: train.csv               

4 archives were successfully processed.


In [9]:
!unzip '*.csv.zip'

Archive:  test.csv.zip
replace test.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: test.csv                

Archive:  features.csv.zip
replace features.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: features.csv            

Archive:  sampleSubmission.csv.zip
replace sampleSubmission.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: sampleSubmission.csv    

Archive:  train.csv.zip
replace train.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: train.csv               

4 archives were successfully processed.


In [10]:
!pip install wandb
import wandb
wandb.login()
wandb.init(project="Store-Sales-Forecasting", entity="agasi22-free-university-of-tbilisi-", name="sarima-training-run")



[34m[1mwandb[0m: Currently logged in as: [33magasi22[0m ([33magasi22-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import gc

from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error


In [12]:
train = pd.read_csv('train.csv')
features = pd.read_csv('features.csv')
stores = pd.read_csv('stores.csv')

train['Date'] = pd.to_datetime(train['Date'])
features['Date'] = pd.to_datetime(features['Date'])


In [13]:
df = train.merge(features, on=['Store', 'Date', 'IsHoliday'], how='left')
df = df.merge(stores, on='Store', how='left')


In [14]:
exog_cols = ['Temperature', 'Fuel_Price', 'CPI', 'Unemployment']

df[exog_cols] = df[exog_cols].fillna(method='ffill').fillna(method='bfill')


In [15]:
def weighted_mae(y_true, y_pred, weights):
    return np.sum(weights * np.abs(y_true - y_pred)) / np.sum(weights)


In [16]:
import warnings
warnings.filterwarnings("ignore")

import logging
logging.getLogger("cmdstanpy").setLevel(logging.WARNING)
logging.getLogger("prophet").setLevel(logging.WARNING)

results = []
all_preds = []

store_dept_groups = df.groupby(['Store', 'Dept'])
total_groups = len(store_dept_groups)

print(f"--- Starting SARIMAX for {total_groups} Store-Dept combos ---")

for idx, ((store_id, dept_id), group) in enumerate(store_dept_groups, start=1):
    print(f"\n--- Processing Store: {store_id}, Dept: {dept_id} ({idx}/{total_groups}) ---")

    g = group.sort_values('Date')
    g = g.set_index('Date')

    y = g['Weekly_Sales']
    X = g[exog_cols]
    weights = g['IsHoliday'].apply(lambda x: 5 if x else 1)

    y_train = y[y.index < '2012-01-01']
    y_val = y[(y.index >= '2012-01-01') & (y.index < '2012-07-01')]
    X_train = X.loc[y_train.index]
    X_val = X.loc[y_val.index]
    weights_val = weights.loc[y_val.index]

    if len(y_train) < 100 or len(y_val) < 20:
        print(f"   Skipped: Not enough data (Train: {len(y_train)}, Val: {len(y_val)})")
        continue

    try:
        model = SARIMAX(
            y_train,
            exog=X_train,
            order=(1, 1, 1),
            seasonal_order=(1, 1, 1, 52),
            enforce_stationarity=False,
            enforce_invertibility=False
        )
        result = model.fit(disp=False)

        pred = result.predict(
            start=y_val.index[0],
            end=y_val.index[-1],
            exog=X_val
        )

        wmae = weighted_mae(y_val, pred, weights_val)
        rmse = np.sqrt(mean_squared_error(y_val, pred))

        print(f"   WMAE: {wmae:.2f}")

        results.append({
            'Store': store_id,
            'Dept': dept_id,
            'RMSE': rmse,
            'WMAE': wmae
        })

        all_preds.append(pd.DataFrame({
            'Date': y_val.index,
            'Store': store_id,
            'Dept': dept_id,
            'y_true': y_val.values,
            'y_pred': pred.values,
            'weight': weights_val.values
        }))

        wandb.log({
            'Store': store_id,
            'Dept': dept_id,
            'RMSE': rmse,
            'WMAE': wmae
        })

    except Exception as e:
        print(f"   Failed: {e}")
        continue

    gc.collect()

print("\n--- SARIMAX Loop Finished ---")


--- Starting SARIMAX for 3331 Store-Dept combos ---

--- Processing Store: 1, Dept: 1 (1/3331) ---
   WMAE: 24582.05

--- Processing Store: 1, Dept: 2 (2/3331) ---
   WMAE: 5603.48

--- Processing Store: 1, Dept: 3 (3/3331) ---
   WMAE: 1129.18

--- Processing Store: 1, Dept: 4 (4/3331) ---
   WMAE: 5532.35

--- Processing Store: 1, Dept: 5 (5/3331) ---
   WMAE: 10827.34

--- Processing Store: 1, Dept: 6 (6/3331) ---
   WMAE: 5312.25

--- Processing Store: 1, Dept: 7 (7/3331) ---
   WMAE: 13384.97

--- Processing Store: 1, Dept: 8 (8/3331) ---
   WMAE: 1590.26

--- Processing Store: 1, Dept: 9 (9/3331) ---
   WMAE: 2687.69

--- Processing Store: 1, Dept: 10 (10/3331) ---
   WMAE: 5272.90

--- Processing Store: 1, Dept: 11 (11/3331) ---
   WMAE: 6205.35

--- Processing Store: 1, Dept: 12 (12/3331) ---
   WMAE: 1086.13

--- Processing Store: 1, Dept: 13 (13/3331) ---
   WMAE: 1935.06

--- Processing Store: 1, Dept: 14 (14/3331) ---
   WMAE: 2875.56

--- Processing Store: 1, Dept: 16 (15/



[1;30;43mStreaming output truncated to the last 5000 lines.[0m

--- Processing Store: 22, Dept: 77 (1666/3331) ---
   Skipped: Not enough data (Train: 3, Val: 2)

--- Processing Store: 22, Dept: 78 (1667/3331) ---
   Skipped: Not enough data (Train: 6, Val: 0)

--- Processing Store: 22, Dept: 79 (1668/3331) ---
   WMAE: 1793.91

--- Processing Store: 22, Dept: 80 (1669/3331) ---
   WMAE: 278.20

--- Processing Store: 22, Dept: 81 (1670/3331) ---
   WMAE: 1397.22

--- Processing Store: 22, Dept: 82 (1671/3331) ---
   WMAE: 6430.23

--- Processing Store: 22, Dept: 83 (1672/3331) ---
   WMAE: 307.20

--- Processing Store: 22, Dept: 85 (1673/3331) ---
   WMAE: 2149.66

--- Processing Store: 22, Dept: 87 (1674/3331) ---
   WMAE: 2708.61

--- Processing Store: 22, Dept: 90 (1675/3331) ---
   WMAE: 1394.64

--- Processing Store: 22, Dept: 91 (1676/3331) ---
   WMAE: 2596.57

--- Processing Store: 22, Dept: 92 (1677/3331) ---
   WMAE: 3351.56

--- Processing Store: 22, Dept: 93 (1678/3331) -

In [17]:
all_df = pd.concat(all_preds)
overall_wmae = np.sum(all_df['weight'] * np.abs(all_df['y_true'] - all_df['y_pred'])) / np.sum(all_df['weight'])

print("Overall WMAE:", overall_wmae)

results_df = pd.DataFrame(results)
print(results_df.head())

wandb.log({'Overall_WMAE': overall_wmae})


Overall WMAE: 1.210416545752286e+31
   Store  Dept          RMSE          WMAE
0      1     1  25627.020132  24582.054540
1      1     2   6098.173468   5603.477791
2      1     3   1490.437595   1129.175814
3      1     4   6188.019766   5532.350510
4      1     5  11507.791605  10827.343425


In [18]:

wandb.finish()
