In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error
from prophet import Prophet

import sys
import os

# Add the parent directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

# Now you can import from src
from src.data_utils import load_and_process_taxi_data

import pandas as pd

import dagshub
dagshub.init(repo_owner="gourimenon8", repo_name="sp25_taxi", mlflow=True)


# Create your own data transformation function
def create_time_series_data(df):
    """
    Aggregate raw taxi trips into hourly ride counts per pickup location.

    Args:
        df: DataFrame of individual trips. Must contain a datetime-typed
            'pickup_datetime' column and a 'pickup_location_id' column
            (both are read below). The input is not modified.

    Returns:
        DataFrame with columns 'pickup_location_id', 'pickup_hour' and
        'rides', one row per (location, hour) pair that has at least one
        trip. Hours with no trips are not present in the output. An empty
        input yields an empty DataFrame that still carries these columns,
        so downstream column lookups do not raise KeyError.
    """
    output_columns = ['pickup_location_id', 'pickup_hour', 'rides']
    if df.empty:
        return pd.DataFrame(columns=output_columns)

    # Build the hourly key on a narrow frame instead of writing a new column
    # into the caller's DataFrame (the original mutated its argument).
    # 'h' is the current pandas alias for hourly; 'H' is deprecated.
    hourly = df[['pickup_location_id']].assign(
        pickup_hour=df['pickup_datetime'].dt.floor('h')
    )

    # Count rides by location and hour
    ts_data = (
        hourly.groupby(['pickup_location_id', 'pickup_hour'])
        .size()
        .reset_index(name='rides')
    )

    return ts_data

# Main analysis
#
# Loads 2022 and 2023 trips, aggregates them to hourly ride counts, fits
# Prophet for a single pickup location, tunes changepoint_prior_scale, and
# writes three plots to the working directory:
#   forecast_plot.png, forecast_components.png, best_forecast_plot.png
import traceback

# Number of hours to forecast past the end of the observed series.
FORECAST_HOURS = 12
# Most recent hours withheld from fitting so that MAE is measured on data the
# model has not seen. Scoring on the training rows would make the tuning loop
# favour the most flexible changepoint_prior_scale regardless of quality.
HOLDOUT_HOURS = 7 * 24
# Candidate values tried for Prophet's changepoint_prior_scale.
CHANGEPOINT_PRIOR_SCALES = (0.001, 0.01, 0.1, 0.5)


def _make_sample_ts_data(locations=(43, 151, 239)):
    """Return synthetic hourly ride counts for the given example location IDs.

    The counts follow a sinusoidal hour-of-day pattern, are scaled by 0.7 on
    Saturdays/Sundays, get Gaussian noise (sigma=3), and are clipped at zero.
    The global NumPy random state is used unseeded, so values differ per run.
    """
    hours = pd.date_range(start='2022-01-01', end='2023-12-31', freq='h')
    # Hourly pattern
    hour_factor = np.sin(hours.hour.to_numpy() / 24 * 2 * np.pi) + 1.5
    # Weekly pattern (weekdays vs weekends)
    day_factor = np.where(hours.dayofweek.to_numpy() >= 5, 0.7, 1.0)
    base = 10 * hour_factor * day_factor

    frames = []
    for loc in locations:
        noisy = base + np.random.normal(0, 3, size=len(hours))
        frames.append(pd.DataFrame({
            'pickup_location_id': loc,
            'pickup_hour': hours,
            # Same as int(max(0, x)) per element: clip, then truncate.
            'rides': np.maximum(0, noisy).astype(int),
        }))
    return pd.concat(frames, ignore_index=True)


def _holdout_mae(train_df, holdout_df, **prophet_kwargs):
    """Fit a fresh Prophet on train_df and return its MAE on holdout_df.

    Both frames use Prophet's 'ds'/'y' column convention. Predictions are
    joined back to the actuals on 'ds' so row order cannot misalign them.
    """
    candidate = Prophet(**prophet_kwargs)
    candidate.fit(train_df)
    predicted = candidate.predict(holdout_df[['ds']])
    scored = holdout_df.merge(
        predicted[['ds', 'yhat']], on='ds', how='inner'
    ).dropna()
    return mean_absolute_error(scored['y'], scored['yhat'])


try:
    # NOTE(review): load_and_process_taxi_data is presumed to return a
    # DataFrame with 'pickup_datetime' and 'pickup_location_id' -- confirm
    # against src.data_utils.
    rides1 = load_and_process_taxi_data(year=2022)
    rides2 = load_and_process_taxi_data(year=2023)
    # Concatenate once; the original built this large frame twice.
    rides = pd.concat([rides1, rides2], ignore_index=True)

    # If data loading worked, process it
    if not (rides1.empty or rides2.empty):
        ts_data = create_time_series_data(rides)
    else:
        # Create sample data for demonstration
        print("Using sample data for demonstration")
        ts_data = _make_sample_ts_data()

    # Select a location for analysis
    location_id = 43  # Choose a location ID relevant to your analysis
    prop_df = ts_data[ts_data["pickup_location_id"] == location_id].copy()

    # Drop unnecessary columns and rename for Prophet
    prop_df = prop_df.drop(columns=["pickup_location_id"])
    prop_df = prop_df.rename(columns={'pickup_hour': 'ds', 'rides': 'y'})

    # Ensure correct data types
    prop_df['ds'] = pd.to_datetime(prop_df['ds'])
    prop_df['y'] = pd.to_numeric(prop_df['y'])
    prop_df = prop_df.sort_values('ds').reset_index(drop=True)

    # Fail with a clear message instead of an opaque Prophet error.
    if prop_df.empty:
        raise ValueError(
            f"No rows found for pickup_location_id={location_id}"
        )

    # Chronological split: everything after the cutoff is held out.
    cutoff = prop_df['ds'].max() - pd.Timedelta(hours=HOLDOUT_HOURS)
    train_df = prop_df[prop_df['ds'] <= cutoff]
    holdout_df = prop_df[prop_df['ds'] > cutoff]
    if len(train_df) < 2:
        raise ValueError(
            f"Not enough history for location {location_id}: need more than "
            f"{HOLDOUT_HOURS} hours of data to hold out a validation window"
        )

    # Train Prophet model on the full series for the actual forecast
    model = Prophet()
    model.fit(prop_df)

    # Create forecast
    future = model.make_future_dataframe(periods=FORECAST_HOURS, freq='h')
    forecast = model.predict(future)

    # Display results
    print("Last 12 predictions:")
    print(forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(FORECAST_HOURS))

    # Plot results; save via the figure handle rather than the implicit
    # "current figure" so the right plot is always written.
    fig = model.plot(forecast)
    fig.axes[0].set_title(f"Ride Forecast for Location {location_id}")
    fig.savefig("forecast_plot.png")

    # Plot components
    fig_comp = model.plot_components(forecast)
    fig_comp.savefig("forecast_components.png")

    # Evaluate the default configuration on the held-out window
    mae = _holdout_mae(train_df, holdout_df)
    print(f"Mean Absolute Error: {mae:.2f}")

    # Try hyperparameter tuning
    print("\nTuning model hyperparameters...")
    maes = []

    # Test different changepoint_prior_scale values
    for changepoint_prior_scale in CHANGEPOINT_PRIOR_SCALES:
        current_mae = _holdout_mae(
            train_df, holdout_df,
            changepoint_prior_scale=changepoint_prior_scale,
        )
        maes.append(current_mae)
        print(f"  changepoint_prior_scale={changepoint_prior_scale}, MAE: {current_mae:.2f}")

    # Find best setting (lowest hold-out MAE)
    best_idx = int(np.argmin(maes))
    best_scale = CHANGEPOINT_PRIOR_SCALES[best_idx]
    best_params = f"changepoint_prior_scale={best_scale}"
    print(f"\nBest model: {best_params}, MAE: {maes[best_idx]:.2f}")

    # Refit the winning configuration on all data so the final forecast
    # starts from the true end of the series, not the end of the train split.
    best_model = Prophet(changepoint_prior_scale=best_scale)
    best_model.fit(prop_df)

    # Generate final forecast with best model
    future = best_model.make_future_dataframe(periods=FORECAST_HOURS, freq='h')
    final_forecast = best_model.predict(future)

    # Plot final forecast
    fig = best_model.plot(final_forecast)
    fig.axes[0].set_title(f"Best Model Forecast (Location {location_id}, {best_params})")
    fig.savefig("best_forecast_plot.png")

except Exception as e:
    # Top-level notebook boundary: keep the friendly message but also show the
    # traceback so the real failure location is not lost.
    print(f"An error occurred: {e}")
    print("Make sure you have the required data files and libraries installed.")
    traceback.print_exc()

File already exists for 2022-01.
Loading data for 2022-01...
Total records: 2,463,931
Valid records: 2,415,141
Records dropped: 48,790 (1.98%)
Successfully processed data for 2022-01.
File already exists for 2022-02.
Loading data for 2022-02...
Total records: 2,979,431
Valid records: 2,921,118
Records dropped: 58,313 (1.96%)
Successfully processed data for 2022-02.
File already exists for 2022-03.
Loading data for 2022-03...
Total records: 3,627,882
Valid records: 3,551,986
Records dropped: 75,896 (2.09%)
Successfully processed data for 2022-03.
File already exists for 2022-04.
Loading data for 2022-04...
Total records: 3,599,920
Valid records: 3,522,113
Records dropped: 77,807 (2.16%)
Successfully processed data for 2022-04.
File already exists for 2022-05.
Loading data for 2022-05...
Total records: 3,588,295
Valid records: 3,509,056
Records dropped: 79,239 (2.21%)
Successfully processed data for 2022-05.
File already exists for 2022-06.
Loading data for 2022-06...
Total records: 3,55