In [2]:
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from sklearn.metrics import mean_squared_error
import pickle
import os

# Load the airline-passengers dataset, using the parsed 'Month' column as a
# DatetimeIndex.
url = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv'
df = pd.read_csv(url, index_col='Month', parse_dates=True)

# read_csv leaves the index without a frequency, which forces statsmodels to
# infer one (and warn from `_init_dates`) for every model built on this data.
# Declare the month-start frequency explicitly; pandas raises ValueError here
# if the dates do not actually form a regular month-start sequence.
df.index.freq = 'MS'

# Chronological train/test split: the first 80% of rows are used for fitting,
# the remaining rows are held out for evaluation (no shuffling).
train_size = int(len(df) * 0.8)
train, test = df.iloc[:train_size], df.iloc[train_size:]

# Registry of results, keyed by model name. Each entry holds the fitted
# results object ('model'), its forecast over the test span ('forecast') and
# the test-set mean squared error ('mse').
models = {}

# --- Model 1: ARIMA ---
# Fitted on the training split only, with order (5, 1, 0).
arima_model = ARIMA(train, order=(5, 1, 0))
arima_fit = arima_model.fit()
# One forecast step per held-out row so the forecast can be scored against `test`.
arima_forecast = arima_fit.forecast(steps=len(test))
arima_mse = mean_squared_error(y_true=test, y_pred=arima_forecast)
models['ARIMA'] = dict(model=arima_fit, forecast=arima_forecast, mse=arima_mse)

# --- Model 2: SARIMA ---
# SARIMAX with order (1, 1, 1) and seasonal order (1, 1, 1, 12); disp=False
# silences the optimizer's convergence output during fitting.
sarima_model = SARIMAX(train, order=(1, 1, 1), seasonal_order=(1, 1, 1, 12))
sarima_fit = sarima_model.fit(disp=False)
sarima_forecast = sarima_fit.forecast(steps=len(test))
sarima_mse = mean_squared_error(y_true=test, y_pred=sarima_forecast)
models['SARIMA'] = dict(model=sarima_fit, forecast=sarima_forecast, mse=sarima_mse)

# --- Model 3: Exponential Smoothing (Holt-Winters) ---
# Additive seasonality with a 12-period season.
# NOTE(review): no `trend` argument is passed, so no trend component is
# requested here -- confirm that is intended for this series.
hw_model = ExponentialSmoothing(train, seasonal='add', seasonal_periods=12)
hw_fit = hw_model.fit()
hw_forecast = hw_fit.forecast(steps=len(test))
hw_mse = mean_squared_error(y_true=test, y_pred=hw_forecast)
models['Holt-Winters'] = dict(model=hw_fit, forecast=hw_forecast, mse=hw_mse)

# Persist the whole `models` registry (fitted results, forecasts and MSEs) as
# one pickle file in the current working directory.
# NOTE: pickle files can execute code when loaded -- only unpickle this file
# if it comes from a trusted source.
with open('time_series_models.pkl', 'wb') as pkl_file:
    pickle.Pickler(pkl_file).dump(models)

# Plot and save one actual-vs-forecast figure per model under ./static.
# exist_ok=True creates the directory only when needed, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs('static', exist_ok=True)

for model_name, model_info in models.items():
    plt.figure(figsize=(10, 6))
    plt.plot(test, label='Actual')
    plt.plot(model_info['forecast'], label=f'{model_name} Forecast')
    plt.title(f'{model_name} Model - Test MSE: {model_info["mse"]:.4f}')
    plt.legend()
    plt.savefig(f'static/{model_name}_forecast.png')
    # Close the figure after saving so figures do not accumulate across
    # iterations.
    plt.close()

# Report each model's test-set MSE to four decimals, one line per model, in
# the order the models were added to the registry.
for model_name, model_info in models.items():
    mse_value = model_info["mse"]
    print('{} Test MSE: {:.4f}'.format(model_name, mse_value))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


ARIMA Test MSE: 6506.6721
SARIMA Test MSE: 908.3542
Holt-Winters Test MSE: 6104.7073
