In [None]:
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
from statsmodels.tsa.stattools import adfuller
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from sklearn.metrics import mean_absolute_error, mean_squared_error
from pmdarima import auto_arima
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX
import mlflow
import holidays
import dagshub

# Add the parent directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

# Import from src
from src.data_utils import load_and_process_taxi_data, transform_raw_data_into_ts_data

# Initialize MLflow tracking through DagsHub and select the experiment
dagshub.init(repo_owner="gourimenon8", repo_name="sp25_taxi", mlflow=True)
mlflow.set_experiment("improved_arima_model")

# Appearance shared by every figure saved by this script
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_palette("deep")

# Load the 2022 and 2023 ride records (in that order) and stack them
print("Loading data...")
rides1, rides2 = (load_and_process_taxi_data(year=year) for year in (2022, 2023))
rides = pd.concat([rides1, rides2], ignore_index=True)

# Keep only the rows whose pickup_location_id matches the chosen location
LOCATION_ID = 43  # We can parametrize this
print(f"Filtering data for location ID {LOCATION_ID}...")
location_mask = rides["pickup_location_id"] == LOCATION_ID
temp_rides = rides[location_mask]

# Transform into time series data; the location column is constant after the
# filter above, so it is dropped from the result
ts_data = transform_raw_data_into_ts_data(temp_rides).drop(
    columns=["pickup_location_id"]
)

print(f"Time series data shape: {ts_data.shape}")
print(ts_data.head())

# Derive calendar features from the pickup timestamp column
pickup_dt = ts_data["pickup_hour"].dt
ts_data["hour"] = pickup_dt.hour
ts_data["day_of_week"] = pickup_dt.dayofweek
# dayofweek values 5 and 6 are flagged as the weekend
ts_data["is_weekend"] = ts_data["day_of_week"].isin([5, 6]).astype(int)
ts_data["month"] = pickup_dt.month
ts_data["day"] = pickup_dt.day

# Flag rows whose calendar date appears in the US holiday calendar built for
# 2022 and 2023
us_holidays = holidays.US(years=[2022, 2023])
pickup_dates = pickup_dt.date
ts_data["is_holiday"] = pickup_dates.isin(us_holidays).astype(int)

# Line plot of ride counts against pickup time, saved to disk and then closed
fig, ax = plt.subplots(figsize=(15, 8))
ax.plot(ts_data["pickup_hour"], ts_data["rides"])
ax.set_title(f"Taxi Rides for Location ID {LOCATION_ID}")
ax.set_xlabel("Date")
ax.set_ylabel("Number of Rides")
fig.tight_layout()
fig.savefig("taxi_rides_timeseries.png")
plt.close(fig)

# Visualize daily and weekly patterns
def plot_patterns(ts_data):
    """Save bar charts of mean rides per hour of day and per day of week.

    Reads the "hour", "day_of_week" and "rides" columns of ``ts_data`` and
    writes both charts, stacked vertically in one figure, to
    "taxi_rides_patterns.png".
    """
    weekday_labels = [
        "Monday", "Tuesday", "Wednesday", "Thursday",
        "Friday", "Saturday", "Sunday",
    ]

    # Group means for each panel
    rides_by_hour = ts_data.groupby("hour")["rides"].mean()
    rides_by_weekday = ts_data.groupby("day_of_week")["rides"].mean()

    fig, (hour_ax, weekday_ax) = plt.subplots(2, 1, figsize=(15, 10))

    # Daily pattern
    sns.barplot(x=rides_by_hour.index, y=rides_by_hour.values, ax=hour_ax)
    hour_ax.set_title("Average Rides by Hour of Day")
    hour_ax.set_xlabel("Hour of Day")
    hour_ax.set_ylabel("Average Rides")

    # Weekly pattern; the integer day_of_week values index into weekday_labels
    sns.barplot(
        x=[weekday_labels[i] for i in rides_by_weekday.index],
        y=rides_by_weekday.values,
        ax=weekday_ax,
    )
    weekday_ax.set_title("Average Rides by Day of Week")
    weekday_ax.set_xlabel("Day of Week")
    weekday_ax.set_ylabel("Average Rides")
    plt.sca(weekday_ax)
    plt.xticks(rotation=45)

    fig.tight_layout()
    fig.savefig("taxi_rides_patterns.png")
    plt.close(fig)

# Write the hour-of-day and day-of-week average charts for the filtered series
plot_patterns(ts_data)

# Check stationarity
def check_stationarity(timeseries):
    """Plot rolling statistics and run an Augmented Dickey-Fuller test.

    Saves the series together with its 24-observation rolling mean and
    rolling standard deviation to "stationarity_check.png", then prints the
    test statistic, p-value, lag count, observation count and critical values.

    Args:
        timeseries: pandas Series to test; missing values are dropped before
            the test is run.

    Returns:
        True when the test p-value is at or below 0.05 (treated here as
        stationary), otherwise False.
    """
    # One rolling window object feeds both statistics
    window = timeseries.rolling(window=24)
    curves = (
        (timeseries, 'Original'),
        (window.mean(), 'Rolling Mean'),
        (window.std(), 'Rolling Std'),
    )

    plt.figure(figsize=(15, 8))
    for curve, label in curves:
        plt.plot(curve, label=label)
    plt.legend()
    plt.title('Rolling Mean & Standard Deviation')
    plt.tight_layout()
    plt.savefig("stationarity_check.png")
    plt.close()

    # Dickey-Fuller test; only the first five result fields are reported
    print('Results of Dickey-Fuller Test:')
    adf_result = adfuller(timeseries.dropna(), autolag='AIC')
    statistic, p_value, lags_used, n_obs, critical_values = adf_result[:5]

    report = {
        'Test Statistic': statistic,
        'p-value': p_value,
        '#Lags Used': lags_used,
        'Number of Observations Used': n_obs,
    }
    for level, threshold in critical_values.items():
        report[f'Critical Value ({level})'] = threshold
    print(pd.Series(report))

    return p_value <= 0.05  # True means the series is treated as stationary

is_stationary = check_stationarity(ts_data["rides"])
print(f"Is the time series stationary? {is_stationary}")

# Use the raw series when it already tests as stationary; otherwise take a
# first difference and test again
if is_stationary:
    diff_order = 0
    diff_series = ts_data["rides"].copy()
else:
    diff_order = 1
    diff_series = ts_data["rides"].diff().dropna()
    is_diff_stationary = check_stationarity(diff_series)
    print(f"Is the differenced series stationary? {is_diff_stationary}")

# Plot ACF and PACF to help identify (p,d,q) orders manually
fig, (acf_ax, pacf_ax) = plt.subplots(2, 1, figsize=(15, 10))
correlogram_input = diff_series.dropna()
plot_acf(correlogram_input, ax=acf_ax, lags=48)
plot_pacf(correlogram_input, ax=pacf_ax, lags=48)
fig.tight_layout()
fig.savefig("acf_pacf_plots.png")
plt.close(fig)

# Positional 80/20 split: the first 80% of rows train the models and the
# remaining rows are held out for evaluation
train_size = int(len(ts_data) * 0.8)
train, test = ts_data.iloc[:train_size], ts_data.iloc[train_size:]
print(f"Training set size: {train.shape}, Test set size: {test.shape}")

# Define a function to evaluate models
def evaluate_model(y_true, y_pred, model_name):
    """Score a forecast against observed values and save a comparison plot.

    Observations and predictions are compared by position, so a forecast
    whose index labels differ from those of ``y_true`` is still scored
    element by element instead of being re-aligned by pandas.

    Args:
        y_true: Observed values (pandas Series or any 1-D array-like).
        y_pred: Predicted values, same length and order as ``y_true``.
        model_name: Label used in the printed report, the plot title and the
            output file name (lower-cased, spaces replaced by underscores,
            suffixed with "_prediction.png").

    Returns:
        dict with keys "mae", "rmse" and "mape". MAPE is a percentage computed
        only over observations whose actual value is non-zero; it is NaN when
        every actual value is zero.
    """
    # Work on plain arrays so every metric uses the same positional pairing
    actual = np.asarray(y_true, dtype=float).ravel()
    predicted = np.asarray(y_pred, dtype=float).ravel()

    mae = mean_absolute_error(actual, predicted)
    rmse = np.sqrt(mean_squared_error(actual, predicted))

    # A zero actual makes the percentage error undefined (division by zero),
    # which would turn the whole mean into inf/NaN, so those rows are skipped
    nonzero = actual != 0
    if nonzero.any():
        relative_errors = (actual[nonzero] - predicted[nonzero]) / actual[nonzero]
        mape = np.mean(np.abs(relative_errors)) * 100
    else:
        mape = float("nan")

    print(f"{model_name} Performance Metrics:")
    print(f"MAE: {mae:.2f}")
    print(f"RMSE: {rmse:.2f}")
    print(f"MAPE: {mape:.2f}%")

    # Plot actual vs predicted; use the observations' own index on the x-axis
    # when there is one, otherwise fall back to positions
    if isinstance(y_true, (pd.Series, pd.DataFrame)):
        x_values = y_true.index
    else:
        x_values = np.arange(actual.size)

    plt.figure(figsize=(15, 8))
    plt.plot(x_values, actual, label='Actual')
    plt.plot(x_values, predicted, label='Predicted', alpha=0.7)
    plt.title(f'{model_name}: Actual vs Predicted')
    plt.xlabel('Date')
    plt.ylabel('Number of Rides')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{model_name.lower().replace(' ', '_')}_prediction.png")
    plt.close()

    return {"mae": mae, "rmse": rmse, "mape": mape}



# 1. Auto ARIMA without seasonality
print("\nFitting Auto ARIMA without seasonality...")
# Search settings; d is fixed to the differencing order chosen by the
# stationarity check earlier in the script
auto_arima_options = {
    "seasonal": False,
    "stepwise": True,
    "trace": True,
    "error_action": "ignore",
    "suppress_warnings": True,
    "max_order": None,
    "d": diff_order,
}
model_auto = auto_arima(train["rides"], **auto_arima_options)
print(model_auto.summary())

# Get the best parameters
best_p, best_d, best_q = model_auto.order

# Forecast one value per held-out row and score it against the test set
forecast_auto = model_auto.predict(n_periods=len(test))
auto_metrics = evaluate_model(
    test["rides"], forecast_auto, "Auto ARIMA (Non-Seasonal)"
)

# 2. SARIMA Model (Seasonal ARIMA)
print("\nFitting SARIMA with seasonality...")
# Use daily seasonality (24 hours)
seasonal_period = 24

# Search settings; d is fixed to the differencing order chosen by the
# stationarity check earlier in the script
sarima_options = {
    "seasonal": True,
    "m": seasonal_period,
    "stepwise": True,
    "trace": True,
    "error_action": "ignore",
    "suppress_warnings": True,
    "d": diff_order,
    "max_order": 10,
}
model_sarima = auto_arima(train["rides"], **sarima_options)
print(model_sarima.summary())

# Get the best parameters for SARIMA (non-seasonal, then seasonal)
best_p, best_d, best_q = model_sarima.order
best_P, best_D, best_Q, best_m = model_sarima.seasonal_order

# Forecast one value per held-out row and score it against the test set
forecast_sarima = model_sarima.predict(n_periods=len(test))
sarima_metrics = evaluate_model(
    test["rides"], forecast_sarima, "SARIMA (Seasonal)"
)

# 3. Manual ARIMA with external regressors (features)

# Prepare external regressors (features): the same columns, as plain arrays,
# for the training rows and the held-out rows
regressor_columns = ["hour", "day_of_week", "is_weekend", "is_holiday"]
exog_train = train[regressor_columns].values
exog_test = test[regressor_columns].values

print("\nFitting ARIMAX with external regressors...")
# NOTE(review): best_p/best_d/best_q hold whatever was assigned last; as the
# script stands that is the non-seasonal order of the SARIMA search - confirm
# this is the intended order for the regressor model.
arimax_order = (best_p, best_d, best_q)
model_arimax = SARIMAX(
    train["rides"],
    exog=exog_train,
    order=arimax_order,
    seasonal_order=(0, 0, 0, 0),  # No seasonality as we're using external regressors
    enforce_stationarity=False,
)
arimax_result = model_arimax.fit(disp=False)
print(arimax_result.summary())

# Forecast over the held-out rows, supplying their regressor values
forecast_arimax = arimax_result.forecast(steps=len(test), exog=exog_test)
arimax_metrics = evaluate_model(
    test["rides"], forecast_arimax, "ARIMAX with Regressors"
)

# Compare all models: collect each metric in the same order as the model names
models = ["Auto ARIMA", "SARIMA", "ARIMAX"]
metrics_by_model = (auto_metrics, sarima_metrics, arimax_metrics)
maes = [metrics["mae"] for metrics in metrics_by_model]
rmses = [metrics["rmse"] for metrics in metrics_by_model]
mapes = [metrics["mape"] for metrics in metrics_by_model]

# One bar chart per metric, stacked vertically: (values, title, y-axis label)
comparison_panels = (
    (maes, "Mean Absolute Error (MAE)", "MAE"),
    (rmses, "Root Mean Squared Error (RMSE)", "RMSE"),
    (mapes, "Mean Absolute Percentage Error (MAPE)", "MAPE (%)"),
)

plt.figure(figsize=(15, 10))
for panel_number, (values, panel_title, y_label) in enumerate(comparison_panels, start=1):
    plt.subplot(3, 1, panel_number)
    plt.bar(models, values)
    plt.title(panel_title)
    plt.ylabel(y_label)

plt.tight_layout()
plt.savefig("model_comparison.png")
plt.close()

# Determine the best model: the one with the smallest MAE
best_model_index = np.argmin(maes)
best_model_name = models[best_model_index]
print(f"\nBest model based on MAE: {best_model_name}")

# Log results to MLflow

# Prediction plots are written by evaluate_model(), which builds the file name
# from the display name it is given (lower-cased, spaces replaced by "_",
# then "_prediction.png"). The names below are the files that actually exist
# on disk for the display names used in this script.
prediction_plots = {
    "Auto ARIMA": "auto_arima_(non-seasonal)_prediction.png",
    "SARIMA": "sarima_(seasonal)_prediction.png",
    "ARIMAX": "arimax_with_regressors_prediction.png",
}

# Figures saved earlier in the script, logged regardless of which model wins
shared_artifacts = (
    "taxi_rides_timeseries.png",
    "taxi_rides_patterns.png",
    "stationarity_check.png",
    "acf_pacf_plots.png",
    "model_comparison.png",
)

with mlflow.start_run():
    # Log parameters
    mlflow.log_param("best_model", best_model_name)

    if best_model_name == "Auto ARIMA":
        mlflow.log_param("order", model_auto.order)
    elif best_model_name == "SARIMA":
        mlflow.log_param("order", model_sarima.order)
        mlflow.log_param("seasonal_order", model_sarima.seasonal_order)
    else:  # ARIMAX
        mlflow.log_param("order", arimax_result.model.order)
        mlflow.log_param("used_regressors", ["hour", "day_of_week", "is_weekend", "is_holiday"])

    # Log metrics of the winning model only
    mlflow.log_metric("mae", maes[best_model_index])
    mlflow.log_metric("rmse", rmses[best_model_index])
    mlflow.log_metric("mape", mapes[best_model_index])

    # Log artifacts: the shared figures first, then the winner's prediction plot
    for artifact_path in shared_artifacts:
        mlflow.log_artifact(artifact_path)
    mlflow.log_artifact(prediction_plots[best_model_name])

print("Model training and evaluation complete!")