In [1]:
import pandas as pd
import numpy as np
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
import os

pd.set_option("display.max_rows", 500)

In [2]:
%pwd

'/Users/harendrakumar/Documents/Demand_forecast/research'

In [3]:
os.chdir("../")

In [4]:
%pwd

'/Users/harendrakumar/Documents/Demand_forecast'

In [5]:
df = pd.read_csv("artifacts/data_transformation/train.csv")
df.index

RangeIndex(start=0, stop=46, step=1)

In [6]:
df['Date'] = pd.to_datetime(df['Date'])

In [7]:
df = df[['Date', 'Product_ID', 'Demand', 'Inventory']]
df = df.set_index('Date')

In [8]:
df = df['Demand']

In [9]:
df

Date
2023-06-01     51
2023-06-02    141
2023-06-03    172
2023-06-04     91
2023-06-05    198
2023-06-06     70
2023-06-07     95
2023-06-08     53
2023-06-09    136
2023-06-10    168
2023-06-11    126
2023-06-12    135
2023-06-13    198
2023-06-14    135
2023-06-15    120
2023-06-16     67
2023-06-17    190
2023-06-18    196
2023-06-19    125
2023-06-20    143
2023-06-21    107
2023-06-22    108
2023-06-23     56
2023-06-24     69
2023-06-25     52
2023-06-26     93
2023-06-27     83
2023-06-28    135
2023-06-29     56
2023-06-30    152
2023-07-01    142
2023-07-02    183
2023-07-03     98
2023-07-04     95
2023-07-05     78
2023-07-06    108
2023-07-07    191
2023-07-08    146
2023-07-09     84
2023-07-10    125
2023-07-11     70
2023-07-12     96
2023-07-13    130
2023-07-14    174
2023-07-15    157
2023-07-16    128
Name: Demand, dtype: int64

In [10]:
len(df)

46

In [11]:
df

Date
2023-06-01     51
2023-06-02    141
2023-06-03    172
2023-06-04     91
2023-06-05    198
2023-06-06     70
2023-06-07     95
2023-06-08     53
2023-06-09    136
2023-06-10    168
2023-06-11    126
2023-06-12    135
2023-06-13    198
2023-06-14    135
2023-06-15    120
2023-06-16     67
2023-06-17    190
2023-06-18    196
2023-06-19    125
2023-06-20    143
2023-06-21    107
2023-06-22    108
2023-06-23     56
2023-06-24     69
2023-06-25     52
2023-06-26     93
2023-06-27     83
2023-06-28    135
2023-06-29     56
2023-06-30    152
2023-07-01    142
2023-07-02    183
2023-07-03     98
2023-07-04     95
2023-07-05     78
2023-07-06    108
2023-07-07    191
2023-07-08    146
2023-07-09     84
2023-07-10    125
2023-07-11     70
2023-07-12     96
2023-07-13    130
2023-07-14    174
2023-07-15    157
2023-07-16    128
Name: Demand, dtype: int64

In [12]:
model = SARIMAX(df, order= (1, 1, 1), seasonal_order= (1, 1, 1, 2))
model_fit = model.fit(disp=False)
predictions = model_fit.predict(start=len(df), end=len(df) + 16 -1)
predictions = predictions.astype(int)
predictions

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


2023-07-17    118
2023-07-18    131
2023-07-19    118
2023-07-20    132
2023-07-21    119
2023-07-22    132
2023-07-23    119
2023-07-24    133
2023-07-25    120
2023-07-26    133
2023-07-27    120
2023-07-28    134
2023-07-29    120
2023-07-30    134
2023-07-31    121
2023-08-01    134
Freq: D, Name: predicted_mean, dtype: int64

In [None]:
future_dates = pd.date_range(start=df.index[-1] + pd.DateOffset(1), periods=16, freq='D')
future_dates

In [None]:
predictions = predictions.to_list()
pd.Series(data=predictions, index=future_dates)

In [13]:
test = pd.read_csv("artifacts/data_transformation/test.csv")
test['Date'] = pd.to_datetime(test['Date'])

test = test.set_index(test['Date'])
test = test['Demand']
test

Date
2023-07-17    100
2023-07-18    199
2023-07-19     99
2023-07-20     88
2023-07-21    123
2023-07-22     63
2023-07-23    126
2023-07-24    190
2023-07-25    153
2023-07-26     71
2023-07-27    158
2023-07-28    174
2023-07-29     72
2023-07-30     52
2023-07-31    188
2023-08-01    102
Name: Demand, dtype: int64

In [14]:
def eval_metrics(actual, pred):
        rmse = np.sqrt(mean_squared_error(actual, pred))
        mae = mean_absolute_error(actual, pred)
        r2 = r2_score(actual, pred)
        return rmse, mae, r2

In [15]:
eval_metrics(actual=test, pred=predictions)

(48.629980464729776, 43.0, -0.052934751605295594)