# NYC Taxi Data: Exploratory Data Analysis

In [24]:
import pandas as pd
import numpy as np
import duckdb
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import folium
from PIL import Image
import warnings
import os
from datetime import datetime, timedelta
from IPython.display import display, HTML

# Suppress warnings
warnings.filterwarnings('ignore')

In [25]:
# Set styling
plt.style.use('fivethirtyeight')
pd.set_option('display.max_columns', None)

### Connect to DuckDB and Load Data

In [26]:
# Initialize DuckDB connection
conn = duckdb.connect('../db/nyc_taxi.duckdb')

In [27]:
# Function to execute SQL queries with error handling
def execute_query(query, error_message="Error executing query"):
    try:
        return conn.execute(query).fetchdf()
    except Exception as e:
        print(f"{error_message}: {str(e)}")
        return pd.DataFrame()  # Return empty dataframe instead of failing

In [28]:
# Check available tables in the database
tables = execute_query("SHOW TABLES", "Error listing tables")
tables

Unnamed: 0,name
0,all_fhv_trips
1,all_taxi_trips
2,fhv_trips
3,green_taxi_trips
4,hvfhv_trips
5,nyc_weather
6,taxi_zone_lookup
7,yellow_taxi_trips


In [29]:
# First, let's examine the structure of the tables to understand available columns
all_taxi_columns = execute_query("PRAGMA table_info(all_taxi_trips)", "Error getting taxi table structure")
all_taxi_columns

Unnamed: 0,cid,name,type,notnull,dflt_value,pk
0,0,payment_type,DOUBLE,False,,False
1,1,extra,DOUBLE,False,,False
2,2,PULocationID,BIGINT,False,,False
3,3,congestion_surcharge,DOUBLE,False,,False
4,4,DOLocationID,BIGINT,False,,False
5,5,VendorID,BIGINT,False,,False
6,6,tip_amount,DOUBLE,False,,False
7,7,mta_tax,DOUBLE,False,,False
8,8,fare_amount,DOUBLE,False,,False
9,9,total_amount,DOUBLE,False,,False


In [30]:
# List column names for easy reference
if not all_taxi_columns.empty:
    print("\nAvailable columns in all_taxi_trips:")
    column_names = all_taxi_columns['name'].tolist()
    display(column_names)
else:
    column_names = []
    print("No taxi columns available.")


Available columns in all_taxi_trips:


['payment_type',
 'extra',
 'PULocationID',
 'congestion_surcharge',
 'DOLocationID',
 'VendorID',
 'tip_amount',
 'mta_tax',
 'fare_amount',
 'total_amount',
 'store_and_fwd_flag',
 'passenger_count',
 'improvement_surcharge',
 'tolls_amount',
 'RatecodeID',
 'trip_distance']

In [31]:
# Check structure of weather table if it exists
if 'nyc_weather' in tables['name'].values:
    weather_columns = execute_query("PRAGMA table_info(nyc_weather)", "Error getting weather table structure")
    display(weather_columns)

Unnamed: 0,cid,name,type,notnull,dflt_value,pk
0,0,name,VARCHAR,False,,False
1,1,datetime,DATE,False,,False
2,2,tempmax,DOUBLE,False,,False
3,3,tempmin,DOUBLE,False,,False
4,4,temp,DOUBLE,False,,False
5,5,feelslikemax,DOUBLE,False,,False
6,6,feelslikemin,DOUBLE,False,,False
7,7,feelslike,DOUBLE,False,,False
8,8,dew,DOUBLE,False,,False
9,9,humidity,DOUBLE,False,,False


### Data Overview and Basic Statistics

In [32]:
# Sample of taxi data
taxi_sample = execute_query("""
    SELECT * FROM all_taxi_trips
    LIMIT 5
""", "Error getting taxi sample")
taxi_sample

Unnamed: 0,payment_type,extra,PULocationID,congestion_surcharge,DOLocationID,VendorID,tip_amount,mta_tax,fare_amount,total_amount,store_and_fwd_flag,passenger_count,improvement_surcharge,tolls_amount,RatecodeID,trip_distance
0,1.0,3.0,142,2.5,236,1,3.65,0.5,14.5,21.95,N,2.0,0.3,0.0,1.0,3.8
1,1.0,0.5,236,0.0,42,1,4.0,0.5,8.0,13.3,N,1.0,0.3,0.0,1.0,2.1
2,1.0,0.5,166,0.0,166,2,1.76,0.5,7.5,10.56,N,1.0,0.3,0.0,1.0,0.97
3,2.0,0.5,114,2.5,68,2,0.0,0.5,8.0,11.8,N,1.0,0.3,0.0,1.0,1.09
4,1.0,0.5,68,2.5,163,2,3.0,0.5,23.5,30.3,N,1.0,0.3,0.0,1.0,4.3


In [33]:
# Check for datetime columns specifically
datetime_columns = execute_query("""
    SELECT column_name 
    FROM information_schema.columns 
    WHERE table_name = 'all_taxi_trips' 
    AND (data_type LIKE '%TIMESTAMP%' OR data_type LIKE '%DATE%' OR column_name LIKE '%time%' OR column_name LIKE '%date%')
""", "Error checking for datetime columns")
display(datetime_columns)

Unnamed: 0,column_name


In [34]:
# Sample of weather data if available
if 'nyc_weather' in tables['name'].values:
    weather_sample = execute_query("""
        SELECT * FROM nyc_weather
        LIMIT 3
    """, "Error getting weather sample")
    print("\nSample of weather data:")
    display(weather_sample)


Sample of weather data:


Unnamed: 0,name,datetime,tempmax,tempmin,temp,feelslikemax,feelslikemin,feelslike,dew,humidity,precip,precipprob,precipcover,preciptype,snow,snowdepth,windgust,windspeed,winddir,sealevelpressure,cloudcover,visibility,solarradiation,solarenergy,uvindex,severerisk,sunrise,sunset,moonphase,conditions,description,icon,stations
0,new york,2022-01-01,56.1,50.0,52.7,56.1,50.0,52.7,50.6,92.5,0.735,100,62.5,rain,0.0,0.0,13.2,8.1,139.9,1008.3,100.0,5.5,14.5,1.3,1,,2022-01-01 07:20:13,2022-01-01 16:39:22,0.96,"Rain, Overcast",Cloudy skies throughout the day with rain.,rain,"72505394728,KEWR,KLGA,72502014734,F8726,KNYC,F..."
1,new york,2022-01-02,58.4,38.3,50.4,58.4,31.5,48.9,45.6,84.1,0.087,100,16.67,rain,0.0,0.0,31.4,14.9,307.9,1004.8,92.4,7.5,25.9,2.1,2,,2022-01-02 07:20:18,2022-01-02 16:40:13,0.0,"Rain, Overcast",Cloudy skies throughout the day with rain clea...,rain,"72505394728,KEWR,KLGA,72502014734,KNYC,F1417,7..."
2,new york,2022-01-03,37.6,23.5,30.0,31.1,12.2,20.2,13.4,50.2,0.0,0,0.0,"rain,snow",0.7,0.4,31.1,17.3,0.4,1019.3,67.1,9.9,12.4,1.1,1,,2022-01-03 07:20:21,2022-01-03 16:41:06,0.03,Partially cloudy,Partly cloudy throughout the day.,partly-cloudy-day,"72505394728,KEWR,KLGA,72502014734,KNYC,F1417,7..."


### Basic statistics for numerical columns in taxi data

In [35]:
# Use only columns we know exist from our earlier check
numeric_stats_query = """
    SELECT 
        COUNT(*) as total_trips"""

# Add numeric columns that should be common in taxi data
if 'fare_amount' in column_names:
    numeric_stats_query += """,
        AVG(fare_amount) as avg_fare,
        MIN(fare_amount) as min_fare,
        MAX(fare_amount) as max_fare,
        STDDEV(fare_amount) as std_fare"""

if 'trip_distance' in column_names:
    numeric_stats_query += """,
        AVG(trip_distance) as avg_distance,
        MIN(trip_distance) as min_distance,
        MAX(trip_distance) as max_distance,
        STDDEV(trip_distance) as std_distance"""

if 'tip_amount' in column_names:
    numeric_stats_query += """,
        AVG(tip_amount) as avg_tip"""

if 'total_amount' in column_names:
    numeric_stats_query += """,
        AVG(total_amount) as avg_total"""

if 'passenger_count' in column_names:
    numeric_stats_query += """,
        AVG(passenger_count) as avg_passengers"""

numeric_stats_query += """
    FROM all_taxi_trips
"""

taxi_stats = execute_query(numeric_stats_query, "Error calculating basic statistics")
taxi_stats

Unnamed: 0,total_trips,avg_fare,min_fare,max_fare,std_fare,avg_distance,min_distance,max_distance,std_distance,avg_tip,avg_total,avg_passengers
0,121423028,16.401456,-133391414.0,401092.32,12760.427062,5.723947,0.0,398608.62,529.804328,4.638736,25.914899,1.367551


In [36]:
# Check for missing values in key columns
missing_values_query = """
    SELECT 
        COUNT(*) as total_rows"""

# Add checks for missing values in important columns
columns_to_check = ['payment_type', 'fare_amount', 'trip_distance', 
                   'passenger_count', 'PULocationID', 'DOLocationID', 
                   'tip_amount', 'total_amount']

for col in columns_to_check:
    if col in column_names:
        missing_values_query += f""",
        COUNT(*) - COUNT({col}) as missing_{col}"""

missing_values_query += """
    FROM all_taxi_trips
"""

missing_values = execute_query(missing_values_query, "Error checking for missing values")
missing_values

Unnamed: 0,total_rows,missing_payment_type,missing_fare_amount,missing_trip_distance,missing_passenger_count,missing_PULocationID,missing_DOLocationID,missing_tip_amount,missing_total_amount
0,121423028,170281,0,0,6939172,0,0,0,0


### Temporal Analysis

In [37]:
# # Check if we have a specific datetime column in the taxi data
# datetime_col = None
# for possible_name in ['pickup_datetime', 'pickup_date', 'tpep_pickup_datetime', 'lpep_pickup_datetime']:
#     if possible_name in column_names:
#         datetime_col = possible_name
#         break

# if datetime_col:
#     print(f"Found datetime column: {datetime_col}")
    
#     # Check data type to ensure we can use it in time functions
#     try:
#         # Try to do a basic query to check column usability
#         date_test = execute_query(f"""
#             SELECT {datetime_col}, COUNT(*) 
#             FROM all_taxi_trips 
#             GROUP BY {datetime_col} 
#             LIMIT 1
#         """)
        
#         # Trips by date
#         trips_by_date = execute_query(f"""
#             SELECT 
#                 DATE_TRUNC('day', {datetime_col}) as trip_date,
#                 COUNT(*) as trip_count
#             FROM all_taxi_trips
#             GROUP BY trip_date
#             ORDER BY trip_date
#             LIMIT 100  -- Limiting results for performance
#         """)
        
#         if not trips_by_date.empty:
#             print("\nTrips by date (first 100 days):")
#             display(trips_by_date.head())
            
#             # Plot trips by date using Plotly
#             fig = px.line(trips_by_date, x='trip_date', y='trip_count', 
#                         title='Daily Taxi Trip Volume',
#                         labels={'trip_date': 'Date', 'trip_count': 'Number of Trips'})
#             fig.update_layout(xaxis_title='Date', yaxis_title='Number of Trips')
#             fig.show()
        
#         # Trips by hour of day
#         trips_by_hour = execute_query(f"""
#             SELECT 
#                 EXTRACT(hour FROM {datetime_col}) as hour_of_day,
#                 COUNT(*) as trip_count
#             FROM all_taxi_trips
#             GROUP BY hour_of_day
#             ORDER BY hour_of_day
#         """)
        
#         if not trips_by_hour.empty:
#             print("\nTrips by hour of day:")
#             display(trips_by_hour)
            
#             # Create hour of day plot
#             fig = px.bar(trips_by_hour, x='hour_of_day', y='trip_count',
#                         title='Taxi Trips by Hour of Day',
#                         labels={'hour_of_day': 'Hour of Day', 'trip_count': 'Number of Trips'})
#             fig.update_layout(xaxis_title='Hour of Day', yaxis_title='Number of Trips',
#                             xaxis=dict(tickmode='linear', tick0=0, dtick=1))
#             fig.show()
        
#         # Trips by day of week using ISO day of week (1=Monday, 7=Sunday)
#         trips_by_dow = execute_query(f"""
#             SELECT 
#                 EXTRACT(DOW FROM {datetime_col}) as day_of_week_num,
#                 CASE EXTRACT(DOW FROM {datetime_col})
#                     WHEN 0 THEN 'Sunday'
#                     WHEN 1 THEN 'Monday'
#                     WHEN 2 THEN 'Tuesday'
#                     WHEN 3 THEN 'Wednesday'
#                     WHEN 4 THEN 'Thursday'
#                     WHEN 5 THEN 'Friday'
#                     WHEN 6 THEN 'Saturday'
#                 END as day_of_week,
#                 COUNT(*) as trip_count
#             FROM all_taxi_trips
#             GROUP BY day_of_week_num, day_of_week
#             ORDER BY day_of_week_num
#         """)
        
#         if not trips_by_dow.empty:
#             print("\nTrips by day of week:")
#             display(trips_by_dow)
            
#             # Create day of week plot
#             fig = px.bar(trips_by_dow, x='day_of_week', y='trip_count',
#                         title='Taxi Trips by Day of Week',
#                         labels={'day_of_week': 'Day of Week', 'trip_count': 'Number of Trips'})
#             fig.update_layout(xaxis_title='Day of Week', yaxis_title='Number of Trips')
#             fig.show()
            
#     except Exception as e:
#         print(f"Could not perform temporal analysis with {datetime_col} column: {e}")
#         print("The column may not be in a proper datetime format.")
# else:
#     print("No datetime column found in the taxi data. Skipping temporal analysis.")

## Fare Analysis

In [38]:
if 'fare_amount' in column_names:
    # Use a faster approach with binned fare amounts for the histogram
    fare_bins_query = """
        SELECT 
            FLOOR(fare_amount) as fare_bin,
            COUNT(*) as trip_count
        FROM all_taxi_trips
        WHERE fare_amount BETWEEN 0 AND 100  -- Filter out extreme outliers
        GROUP BY fare_bin
        ORDER BY fare_bin
    """
    
    fare_bins = execute_query(fare_bins_query, "Error creating fare distribution")
    
    if not fare_bins.empty:
        print("\nFare amount distribution (binned):")
        display(fare_bins.head())
        
        # Create fare distribution plot
        fig = px.bar(
            fare_bins,
            x="fare_bin",
            y="trip_count",
            title="Distribution of Fare Amounts (up to $100)",
            labels={"fare_bin": "Fare Amount ($)", "trip_count": "Number of Trips"}
        )
        fig.update_layout(xaxis_title="Fare Amount ($)", yaxis_title="Number of Trips")
        fig.show()
    
    # Analyze the relationship between trip distance and fare amount with sampling
    if 'trip_distance' in column_names:
        distance_fare_relation = execute_query("""
            SELECT 
                trip_distance, 
                fare_amount
            FROM all_taxi_trips
            WHERE trip_distance BETWEEN 0 AND 20
            AND fare_amount BETWEEN 0 AND 100
            LIMIT 100  -- Limiting for performance in visualization
        """, "Error analyzing distance-fare relationship")
        
        if not distance_fare_relation.empty:
            print("\nRelationship between trip distance and fare (sample):")
            display(distance_fare_relation.head())
            
            # Create scatter plot of distance vs fare
            fig = px.scatter(
                distance_fare_relation, 
                x="trip_distance", 
                y="fare_amount",
                title="Relationship between Trip Distance and Fare Amount",
                labels={"trip_distance": "Trip Distance (miles)", "fare_amount": "Fare Amount ($)"},
                opacity=0.5
            )
            
            # Add regression line
            fig.update_layout(xaxis_title="Trip Distance (miles)", yaxis_title="Fare Amount ($)")
            fig.show()
        
        # Calculate fare per mile with sampling
        fare_per_mile_query = """
            SELECT 
                trip_distance,
                fare_amount,
                CASE 
                    WHEN trip_distance > 0 THEN fare_amount / trip_distance 
                    ELSE NULL 
                END as fare_per_mile
            FROM all_taxi_trips
            WHERE trip_distance BETWEEN 0.5 AND 20  -- Avoid division by zero or very small values
            AND fare_amount BETWEEN 2.5 AND 100     -- Filter out potential errors
            LIMIT 100 -- Limiting for performance in visualization
        """
        
        fare_per_mile = execute_query(fare_per_mile_query, "Error calculating fare per mile")
        
        if not fare_per_mile.empty:
            # Create distance bucket categories
            fare_per_mile['distance_bucket'] = pd.cut(
                fare_per_mile['trip_distance'], 
                bins=[0, 1, 2, 5, 10, 20],
                labels=['0-1', '1-2', '2-5', '5-10', '10-20']
            )
            
            # Aggregate by distance bucket
            fare_per_mile_agg = fare_per_mile.groupby('distance_bucket').agg({
                'fare_per_mile': ['mean', 'median', 'std', 'count']
            }).reset_index()
            
            fare_per_mile_agg.columns = ['distance_bucket', 'mean', 'median', 'std', 'count']
            
            print("\nFare per mile statistics by distance bucket:")
            display(fare_per_mile_agg)
            
            # Create boxplot of fare per mile by distance buckets
            fig = px.box(
                fare_per_mile,
                x="distance_bucket",
                y="fare_per_mile",
                title="Fare per Mile by Trip Distance",
                labels={"distance_bucket": "Trip Distance (miles)", "fare_per_mile": "Fare per Mile ($)"}
            )
            fig.update_layout(xaxis_title="Trip Distance Range (miles)", yaxis_title="Fare per Mile ($)")
            fig.show()
else:
    print("Fare amount column not found. Skipping fare analysis.")


Fare amount distribution (binned):


Unnamed: 0,fare_bin,trip_count
0,0.0,99204
1,1.0,47809
2,2.0,212335
3,3.0,1079437
4,4.0,2430872



Relationship between trip distance and fare (sample):


Unnamed: 0,trip_distance,fare_amount
0,3.8,14.5
1,2.1,8.0
2,0.97,7.5
3,1.09,8.0
4,4.3,23.5



Fare per mile statistics by distance bucket:


Unnamed: 0,distance_bucket,mean,median,std,count
0,0-1,6.90227,6.5,1.139147,16
1,1-2,5.197518,5.244755,0.744081,23
2,2-5,4.301244,4.028509,0.911283,38
3,5-10,3.361138,3.330375,0.31312,16
4,10-20,2.868884,2.764487,0.204459,7


## Geographic Analysis

In [39]:
# Check if we have location ID columns
has_location_data = 'PULocationID' in column_names and 'DOLocationID' in column_names

if has_location_data:
    # Load taxi zone lookup table if available
    has_zone_lookup = False
    try:
        taxi_zones = pd.read_csv('taxi_zone_lookup.csv')
        has_zone_lookup = True
        print("Loaded taxi zone lookup table")
        display(taxi_zones.head())
    except:
        has_zone_lookup = False
        print("Taxi zone lookup table not available, using location IDs only")
    
    # Top pickup locations
    top_pickups = execute_query("""
        SELECT 
            PULocationID,
            COUNT(*) as pickup_count
        FROM all_taxi_trips
        GROUP BY PULocationID
        ORDER BY pickup_count DESC
        LIMIT 20
    """, "Error getting top pickup locations")
    
    if not top_pickups.empty:
        print("\nTop 20 pickup locations:")
        display(top_pickups)
        
        # Create bar chart for top pickup locations
        fig = px.bar(
            top_pickups,
            x="PULocationID",
            y="pickup_count",
            title="Top 20 Pickup Locations",
            labels={"PULocationID": "Pickup Location ID", "pickup_count": "Number of Pickups"}
        )
        fig.update_layout(xaxis_title="Pickup Location ID", yaxis_title="Number of Pickups")
        fig.show()
    
    # Top dropoff locations
    top_dropoffs = execute_query("""
        SELECT 
            DOLocationID,
            COUNT(*) as dropoff_count
        FROM all_taxi_trips
        GROUP BY DOLocationID
        ORDER BY dropoff_count DESC
        LIMIT 20
    """, "Error getting top dropoff locations")
    
    if not top_dropoffs.empty:
        print("\nTop 20 dropoff locations:")
        display(top_dropoffs)
        
        # Create bar chart for top dropoff locations
        fig = px.bar(
            top_dropoffs,
            x="DOLocationID",
            y="dropoff_count",
            title="Top 20 Dropoff Locations",
            labels={"DOLocationID": "Dropoff Location ID", "dropoff_count": "Number of Dropoffs"}
        )
        fig.update_layout(xaxis_title="Dropoff Location ID", yaxis_title="Number of Dropoffs")
        fig.show()
    
    # Top location pairs (pickup to dropoff)
    top_location_pairs = execute_query("""
        SELECT 
            PULocationID,
            DOLocationID,
            COUNT(*) as trip_count
        FROM all_taxi_trips
        GROUP BY PULocationID, DOLocationID
        ORDER BY trip_count DESC
        LIMIT 20
    """, "Error getting top location pairs")
    
    if not top_location_pairs.empty:
        print("\nTop 20 pickup-dropoff location pairs:")
        display(top_location_pairs)
        
        # Create sunburst chart for pickup-dropoff pairs
        fig = px.sunburst(
            top_location_pairs,
            path=['PULocationID', 'DOLocationID'],
            values='trip_count',
            title="Top Pickup-Dropoff Location Pairs"
        )
        fig.update_layout(margin=dict(t=40, b=40, l=0, r=0))
        fig.show()
    
    # Load taxi zone maps and attempt to overlay data
    try:
        # Create a figure for taxi zone visualization
        print("Attempting to load and display taxi zone maps...")
        
        # Check if the map files exist first
        zone_map_files = {
            'Bronx': '../data/raw/taxi_zone_map_bronx.jpg',
            'Brooklyn': '../data/raw/Taxi_Zone_Map_Brooklyn.jpg',
            'Manhattan': '../data/raw/Taxi_Zone_Map_Manhattan.jpg',
            'Queens': '../data/raw/Taxi_Zone_Map_Queens.jpg',
            'Staten_Island': '../data/raw/Taxi_Zone_Map_Staten_Island.jpg'
        }
        
        # For demonstration, display the first available map
        map_found = False
        for borough, file_name in zone_map_files.items():
            try:
                img = Image.open(file_name)
                print(f"Successfully loaded map for {borough}")
                
                # Plot the image with annotations
                fig = px.imshow(img)
                fig.update_layout(title=f"Taxi Zone Map - {borough}")
                fig.show()
                map_found = True
                break
            except Exception as e:
                print(f"Could not load map for {borough}: {e}")
        
        if not map_found:
            print("No taxi zone maps found. Make sure the map files are in the current directory.")
        
        print("Note: For a complete integration of zone maps with taxi data, ")
        print("we would need geographic coordinates for each zone ID to overlay the data accurately.")
        
    except Exception as e:
        print(f"Could not process taxi zone maps: {e}")
        print("Continuing with analysis without map visualization")
else:
    print("Location ID columns not found. Skipping geographic analysis.")

Taxi zone lookup table not available, using location IDs only

Top 20 pickup locations:


Unnamed: 0,PULocationID,pickup_count
0,132,5898868
1,237,5566271
2,161,5253823
3,236,4997465
4,162,4064701
5,186,3961770
6,230,3882503
7,142,3867582
8,138,3690412
9,170,3522290



Top 20 dropoff locations:


Unnamed: 0,DOLocationID,dropoff_count
0,236,5258673
1,237,4978028
2,161,4488000
3,230,3667387
4,170,3518898
5,239,3338298
6,162,3331539
7,142,3324967
8,141,3165837
9,48,3028207



Top 20 pickup-dropoff location pairs:


Unnamed: 0,PULocationID,DOLocationID,trip_count
0,237,236,797164
1,236,237,683358
2,264,264,564575
3,237,237,550691
4,236,236,531785
5,161,237,356502
6,237,161,349334
7,161,236,306596
8,142,239,302724
9,239,142,298586


Attempting to load and display taxi zone maps...
Successfully loaded map for Bronx


Note: For a complete integration of zone maps with taxi data, 
we would need geographic coordinates for each zone ID to overlay the data accurately.


## Analysis of Other Features

In [40]:
# Payment type distribution
if 'payment_type' in column_names:
    payment_dist = execute_query("""
        SELECT 
            payment_type,
            COUNT(*) as count,
            100.0 * COUNT(*) / (SELECT COUNT(*) FROM all_taxi_trips) as percentage
        FROM all_taxi_trips
        GROUP BY payment_type
        ORDER BY count DESC
    """, "Error analyzing payment types")
    
    if not payment_dist.empty:
        print("Payment type distribution:")
        display(payment_dist)
        
        # Create pie chart for payment types
        fig = px.pie(
            payment_dist,
            values='count',
            names='payment_type',
            title="Distribution of Payment Types"
        )
        fig.update_traces(textposition='inside', textinfo='percent+label')
        fig.show()

# Passenger count distribution
if 'passenger_count' in column_names:
    passenger_dist = execute_query("""
        SELECT 
            passenger_count,
            COUNT(*) as count
        FROM all_taxi_trips
        WHERE passenger_count BETWEEN 1 AND 9  -- Filter out potential errors
        GROUP BY passenger_count
        ORDER BY passenger_count
    """, "Error analyzing passenger counts")
    
    if not passenger_dist.empty:
        print("\nPassenger count distribution:")
        display(passenger_dist)
        
        # Create bar chart for passenger counts
        fig = px.bar(
            passenger_dist,
            x="passenger_count",
            y="count",
            title="Distribution of Passenger Counts",
            labels={"passenger_count": "Number of Passengers", "count": "Number of Trips"}
        )
        fig.update_layout(xaxis_title="Number of Passengers", yaxis_title="Number of Trips",
                         xaxis=dict(tickmode='linear', tick0=1, dtick=1))
        fig.show()

# Tip amount analysis
if all(col in column_names for col in ['payment_type', 'tip_amount', 'total_amount']):
    tip_analysis = execute_query("""
        SELECT 
            payment_type,
            AVG(tip_amount) as avg_tip,
            AVG(CASE WHEN total_amount > 0 THEN tip_amount / total_amount * 100 ELSE 0 END) as avg_tip_percentage
        FROM all_taxi_trips
        GROUP BY payment_type
        ORDER BY avg_tip DESC
    """, "Error analyzing tip amounts")
    
    if not tip_analysis.empty:
        print("\nTip amount analysis by payment type:")
        display(tip_analysis)
        
        # Create bar chart for average tip by payment type
        fig = px.bar(
            tip_analysis,
            x="payment_type",
            y=["avg_tip", "avg_tip_percentage"],
            title="Average Tip by Payment Type",
            barmode="group",
            labels={"payment_type": "Payment Type", "value": "Amount", "variable": "Metric"}
        )
        fig.update_layout(xaxis_title="Payment Type", yaxis_title="Value")
        fig.show()

Payment type distribution:


Unnamed: 0,payment_type,count,percentage
0,1.0,91808036,75.61007
1,2.0,20393274,16.795228
2,0.0,6768891,5.574635
3,4.0,1541190,1.269273
4,3.0,741288,0.6105
5,,170281,0.140238
6,5.0,68,5.6e-05



Passenger count distribution:


Unnamed: 0,passenger_count,count
0,1.0,86505723
1,2.0,17073922
2,3.0,4248100
3,4.0,2324358
4,5.0,1541183
5,6.0,1024885
6,7.0,536
7,8.0,759
8,9.0,235



Tip amount analysis by payment type:


Unnamed: 0,payment_type,avg_tip,avg_tip_percentage
0,0.0,28.0301,36.821387
1,1.0,4.060519,14.85439
2,,3.507404,11.779765
3,4.0,0.050554,0.024052
4,3.0,0.020813,0.039735
5,2.0,0.001854,0.001265
6,5.0,0.0,0.0


## Weather Impact Analysis

In [41]:
# # Check if we have the necessary weather table and datetime column
# has_weather_table = 'nyc_weather' in tables['name'].values
# has_datetime_column = datetime_col is not None

# if has_weather_table and has_datetime_column:
#     try:
#         # Verify weather data date range
#         weather_range = execute_query("""
#             SELECT 
#                 MIN(datetime) as min_date,
#                 MAX(datetime) as max_date
#             FROM nyc_weather
#         """, "Error checking weather date range")
        
#         if not weather_range.empty:
#             print("Weather data date range:")
#             display(weather_range)
            
#             # Try to analyze trip volume by weather conditions
#             try:
#                 weather_impact = execute_query(f"""
#                     SELECT 
#                         w.conditions,
#                         COUNT(*) as trip_count,
#                         AVG(t.fare_amount) as avg_fare,
#                         AVG(t.trip_distance) as avg_distance
#                     FROM all_taxi_trips t
#                     JOIN nyc_weather w ON DATE(t.{datetime_col}) = DATE(w.datetime)
#                     GROUP BY w.conditions
#                     ORDER BY trip_count DESC
#                 """, "Error analyzing weather impact")
                
#                 if not weather_impact.empty:
#                     print("Trip analysis by weather conditions:")
#                     display(weather_impact)
                    
#                     # Create a grouped bar chart for weather impact
#                     fig = make_subplots(specs=[[{"secondary_y": True}]])
                    
#                     # Add trip count bars
#                     fig.add_trace(
#                         go.Bar(
#                             x=weather_impact['conditions'],
#                             y=weather_impact['trip_count'],
#                             name="Trip Count"
#                         ),
#                         secondary_y=False
#                     )
                    
#                     # Add average fare line
#                     fig.add_trace(
#                         go.Scatter(
#                             x=weather_impact['conditions'],
#                             y=weather_impact['avg_fare'],
#                             name="Avg Fare ($)",
#                             mode="lines+markers"
#                         ),
#                         secondary_y=True
#                     )
                    
#                     fig.update_layout(
#                         title_text="Trip Count and Average Fare by Weather Conditions",
#                         xaxis_title="Weather Conditions"
#                     )
                    
#                     fig.update_yaxes(title_text="Number of Trips", secondary_y=False)
#                     fig.update_yaxes(title_text="Average Fare ($)", secondary_y=True)
                    
#                     fig.show()
                
#                 # Analyze impact of temperature on trip volume
#                 temp_impact = execute_query(f"""
#                     SELECT 
#                         FLOOR(w.temp / 5) * 5 as temp_bucket,
#                         COUNT(*) as trip_count,
#                         AVG(t.fare_amount) as avg_fare
#                     FROM all_taxi_trips t
#                     JOIN nyc_weather w ON DATE(t.{datetime_col}) = DATE(w.datetime)
#                     GROUP BY temp_bucket
#                     ORDER BY temp_bucket
#                 """, "Error analyzing temperature impact")
                
#                 if not temp_impact.empty:
#                     temp_impact['temp_range'] = temp_impact['temp_bucket'].apply(lambda x: f"{x}-{x+5}°F")
                    
#                     print("\nTrip analysis by temperature range:")
#                     display(temp_impact)
                    
#                     # Create a dual-axis chart for temperature impact
#                     fig = make_subplots(specs=[[{"secondary_y": True}]])
                    
#                     # Add trip count bars
#                     fig.add_trace(
#                         go.Bar(
#                             x=temp_impact['temp_range'],
#                             y=temp_impact['trip_count'],
#                             name="Trip Count"
#                         ),
#                         secondary_y=False
#                     )
                    
#                     # Add average fare line
#                     fig.add_trace(
#                         go.Scatter(
#                             x=temp_impact['temp_range'],
#                             y=temp_impact['avg_fare'],
#                             name="Avg Fare ($)",
#                             mode="lines+markers"
#                         ),
#                         secondary_y=True
#                     )
                    
#                     fig.update_layout(
#                         title_text="Trip Count and Average Fare by Temperature",
#                         xaxis_title="Temperature Range (°F)"
#                     )
                    
#                     fig.update_yaxes(title_text="Number of Trips", secondary_y=False)
#                     fig.update_yaxes(title_text="Average Fare ($)", secondary_y=True)
                    
#                     fig.show()
                
#                 # Analyze impact of precipitation on trip volume
#                 precip_impact = execute_query(f"""
#                     SELECT 
#                         CASE 
#                             WHEN w.precip = 0 THEN 'No Rain'
#                             WHEN w.precip < 0.1 THEN 'Light Rain'
#                             WHEN w.precip < 0.5 THEN 'Moderate Rain'
#                             ELSE 'Heavy Rain'
#                         END as precipitation_category,
#                         COUNT(*) as trip_count,
#                         AVG(t.fare_amount) as avg_fare,
#                         AVG(t.trip_distance) as avg_distance
#                     FROM all_taxi_trips t
#                     JOIN nyc_weather w ON DATE(t.{datetime_col}) = DATE(w.datetime)
#                     GROUP BY precipitation_category
#                     ORDER BY avg_fare DESC
#                 """, "Error analyzing precipitation impact")
                
#                 if not precip_impact.empty:
#                     print("\nTrip analysis by precipitation level:")
#                     display(precip_impact)
                    
#                     # Create visualization for precipitation impact
#                     fig = px.bar(
#                         precip_impact,
#                         x="precipitation_category",
#                         y=["trip_count", "avg_fare", "avg_distance"],
#                         title="Impact of Precipitation on Taxi Trips",
#                         barmode="group",
#                         labels={
#                             "precipitation_category": "Precipitation Level",
#                             "value": "Value",
#                             "variable": "Metric"
#                         }
#                     )
#                     fig.update_layout(xaxis_title="Precipitation Level")
#                     fig.show()
                
#             except Exception as e:
#                 print(f"Error in weather impact analysis: {e}")
#                 print("Check if the datetime formats in taxi and weather tables are compatible.")
#         else:
#             print("No weather data found.")
#     except Exception as e:
#         print(f"Error accessing weather data: {e}")
#         print("Skipping weather impact analysis")
# else:
#     print("Weather table or datetime column not available. Skipping weather impact analysis.")

## Correlation Analysis

In [42]:
# Create a correlation matrix for numerical columns
numeric_columns = [col for col in ['fare_amount', 'tip_amount', 'total_amount', 
                                   'trip_distance', 'passenger_count', 'extra', 
                                   'mta_tax', 'tolls_amount'] if col in column_names]

if numeric_columns:
    # First check if we have data to correlate
    try:
        # Create a query that selects only existing columns
        correlation_query = "SELECT " + ", ".join(numeric_columns) + """
            FROM all_taxi_trips
            WHERE """
        
        # # Add sensible filters for each column to exclude outliers
        # filters = []
        # for col in numeric_columns:
        #     if 'amount' in col.lower() or 'fare' in col.lower() or 'tip' in col.lower():
        #         filters.append(f"{col} BETWEEN 0 AND 100")
        #     elif 'distance' in col.lower():
        #         filters.append(f"{col} BETWEEN 0

        # Add sensible filters for each column to exclude outliers
        filters = []
        for col in numeric_columns:
            if 'amount' in col.lower() or 'fare' in col.lower() or 'tip' in col.lower():
                filters.append(f"{col} BETWEEN 0 AND 100")
            elif 'distance' in col.lower():
                filters.append(f"{col} BETWEEN 0 AND 50")
            elif 'passenger' in col.lower():
                filters.append(f"{col} BETWEEN 1 AND 9")
            else:
                filters.append(f"{col} IS NOT NULL")
        
        correlation_query += " AND ".join(filters)
        correlation_query += " LIMIT 50000"  # Limit for performance
        
        numerical_data = execute_query(correlation_query, "Error getting correlation data")
        
        if not numerical_data.empty and len(numerical_data) > 100:  # Ensure we have enough data
            # Calculate correlation matrix
            corr_matrix = numerical_data.corr()
            print("Correlation matrix between numerical variables:")
            display(corr_matrix)
            
            # Create heatmap using Plotly
            fig = px.imshow(
                corr_matrix,
                text_auto=True,
                aspect="auto",
                title="Correlation Matrix of Numerical Features",
                color_continuous_scale='RdBu_r'
            )
            fig.update_layout(width=800, height=800)
            fig.show()
        else:
            print("Not enough numerical data available for correlation analysis.")
    except Exception as e:
        print(f"Could not perform correlation analysis: {e}")
        print("Some columns might not be numeric or contain valid data.")
else:
    print("No suitable numerical columns found for correlation analysis.")

Correlation matrix between numerical variables:


Unnamed: 0,fare_amount,tip_amount,total_amount,trip_distance,passenger_count,extra,mta_tax,tolls_amount
fare_amount,1.0,0.509427,0.97769,0.9051,0.033945,-0.054719,-0.291278,0.603118
tip_amount,0.509427,1.0,0.653356,0.503353,0.009714,-0.006035,-0.090866,0.3476
total_amount,0.97769,0.653356,1.0,0.904156,0.032294,-0.03919,-0.271638,0.676263
trip_distance,0.9051,0.503353,0.904156,1.0,0.028734,-0.031969,-0.066334,0.597463
passenger_count,0.033945,0.009714,0.032294,0.028734,1.0,-0.024988,-0.017411,0.019142
extra,-0.054719,-0.006035,-0.03919,-0.031969,-0.024988,1.0,0.074922,-0.029166
mta_tax,-0.291278,-0.090866,-0.271638,-0.066334,-0.017411,0.074922,1.0,-0.278714
tolls_amount,0.603118,0.3476,0.676263,0.597463,0.019142,-0.029166,-0.278714,1.0


## Features Engineering Preview

In [43]:
# Create examples of engineered features
feature_engineering_columns = []
feature_engineering_expressions = []

# We'll only add features if the relevant columns exist
if 'trip_distance' in column_names:
    feature_engineering_columns.append('trip_distance')
    
if 'fare_amount' in column_names:
    feature_engineering_columns.append('fare_amount')
    
    # Add cost per mile if both fare and distance are available
    if 'trip_distance' in column_names:
        feature_engineering_expressions.append("""
            CASE 
                WHEN trip_distance > 0 THEN fare_amount / trip_distance 
                ELSE NULL 
            END as cost_per_mile
        """)
        
        # Add distance category
        feature_engineering_expressions.append("""
            CASE
                WHEN trip_distance < 1 THEN 'very_short'
                WHEN trip_distance < 3 THEN 'short'
                WHEN trip_distance < 10 THEN 'medium'
                ELSE 'long'
            END as distance_category
        """)

if 'passenger_count' in column_names:
    feature_engineering_columns.append('passenger_count')
    
    # Add passenger group
    feature_engineering_expressions.append("""
        CASE
            WHEN passenger_count = 1 THEN 'solo'
            WHEN passenger_count = 2 THEN 'couple'
            WHEN passenger_count <= 4 THEN 'small_group'
            ELSE 'large_group'
        END as passenger_group
    """)

if 'payment_type' in column_names:
    feature_engineering_columns.append('payment_type')
    
    # Add payment category
    feature_engineering_expressions.append("""
        CASE
            WHEN payment_type = 2 THEN 'cash'  -- Assuming payment_type 2 is cash
            ELSE 'non_cash'
        END as payment_category
    """)

if 'tip_amount' in column_names:
    feature_engineering_columns.append('tip_amount')
    
    # Add has tip feature
    feature_engineering_expressions.append("""
        tip_amount > 0 as has_tip
    """)

if 'PULocationID' in column_names and 'DOLocationID' in column_names:
    feature_engineering_columns.extend(['PULocationID', 'DOLocationID'])
    
    # Add is round trip
    feature_engineering_expressions.append("""
        PULocationID = DOLocationID as is_round_trip
    """)

# Build the query if we have columns to work with
if feature_engineering_columns and feature_engineering_expressions:
    feature_sample_query = """
        SELECT
            """ + ", ".join(feature_engineering_columns) + """,
            """ + ", ".join(feature_engineering_expressions) + """
        FROM all_taxi_trips
        LIMIT 10
    """
    
    feature_sample = execute_query(feature_sample_query, "Error creating engineered features")
    
    if not feature_sample.empty:
        print("Example of engineered features:")
        display(feature_sample)
    
    # Define fare categories based on quantiles (if fare_amount exists)
    if 'fare_amount' in column_names:
        try:
            fare_quantiles = execute_query("""
                SELECT
                    PERCENTILE_CONT(0.33) WITHIN GROUP (ORDER BY fare_amount) as q33,
                    PERCENTILE_CONT(0.66) WITHIN GROUP (ORDER BY fare_amount) as q66
                FROM all_taxi_trips
                WHERE fare_amount BETWEEN 0 AND 100  -- Filter extreme outliers
            """, "Error calculating fare quantiles")
            
            if not fare_quantiles.empty:
                q33 = fare_quantiles['q33'].iloc[0]
                q66 = fare_quantiles['q66'].iloc[0]
                
                fare_categories_query = f"""
                    SELECT
                        fare_amount,
                        CASE
                            WHEN fare_amount < {q33} THEN 'low'
                            WHEN fare_amount < {q66} THEN 'medium'
                            ELSE 'high'
                        END as fare_category
                    FROM all_taxi_trips
                    WHERE fare_amount BETWEEN 0 AND 100  -- Filter extreme outliers
                    LIMIT 10
                """
                
                fare_categories = execute_query(fare_categories_query, "Error creating fare categories")
                
                if not fare_categories.empty:
                    print("\nExample of fare categories:")
                    display(fare_categories)
                
                # Distribution of fare categories
                fare_category_dist_query = f"""
                    SELECT
                        CASE
                            WHEN fare_amount < {q33} THEN 'low'
                            WHEN fare_amount < {q66} THEN 'medium'
                            ELSE 'high'
                        END as fare_category,
                        COUNT(*) as count
                    FROM all_taxi_trips
                    WHERE fare_amount BETWEEN 0 AND 100  -- Filter extreme outliers
                    GROUP BY fare_category
                """
                
                fare_category_dist = execute_query(fare_category_dist_query, "Error calculating fare category distribution")
                
                if not fare_category_dist.empty:
                    print("\nDistribution of fare categories:")
                    display(fare_category_dist)
                    
                    # Create pie chart for fare categories
                    fig = px.pie(
                        fare_category_dist,
                        values='count',
                        names='fare_category',
                        title="Distribution of Fare Categories",
                        color='fare_category',
                        color_discrete_map={'low': '#2ecc71', 'medium': '#f39c12', 'high': '#e74c3c'}
                    )
                    fig.update_traces(textposition='inside', textinfo='percent+label')
                    fig.show()
        except Exception as e:
            print(f"Could not analyze fare categories: {e}")
else:
    print("Not enough data columns available to demonstrate feature engineering.")

Example of engineered features:


Unnamed: 0,trip_distance,fare_amount,passenger_count,payment_type,tip_amount,PULocationID,DOLocationID,cost_per_mile,distance_category,passenger_group,payment_category,has_tip,is_round_trip
0,3.8,14.5,2.0,1.0,3.65,142,236,3.815789,medium,couple,non_cash,True,False
1,2.1,8.0,1.0,1.0,4.0,236,42,3.809524,short,solo,non_cash,True,False
2,0.97,7.5,1.0,1.0,1.76,166,166,7.731959,very_short,solo,non_cash,True,True
3,1.09,8.0,1.0,2.0,0.0,114,68,7.33945,short,solo,cash,False,False
4,4.3,23.5,1.0,1.0,3.0,68,163,5.465116,medium,solo,non_cash,True,False
5,10.3,33.0,1.0,1.0,13.0,138,161,3.203883,long,solo,non_cash,True,False
6,5.07,17.0,1.0,1.0,5.2,233,87,3.353057,medium,solo,non_cash,True,False
7,2.02,9.0,1.0,2.0,0.0,238,152,4.455446,short,solo,cash,False,False
8,2.71,12.0,1.0,1.0,2.25,166,236,4.428044,short,solo,non_cash,True,False
9,0.78,5.0,1.0,2.0,0.0,236,141,6.410256,very_short,solo,cash,False,False



Example of fare categories:


Unnamed: 0,fare_amount,fare_category
0,14.5,medium
1,8.0,low
2,7.5,low
3,8.0,low
4,23.5,high
5,33.0,high
6,17.0,high
7,9.0,low
8,12.0,medium
9,5.0,low



Distribution of fare categories:


Unnamed: 0,fare_category,count
0,medium,40171272
1,low,38431115
2,high,41123471


## Outlier Analysis

In [44]:
# Identify outliers in fare amounts and trip distances
outlier_queries = []

if 'fare_amount' in column_names:
    outlier_queries.append("""
        SELECT
            'fare_amount' as feature,
            MIN(fare_amount) as min_value,
            MAX(fare_amount) as max_value,
            AVG(fare_amount) as mean,
            PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY fare_amount) as q1,
            PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY fare_amount) as q3,
            PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY fare_amount) - 
            PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY fare_amount) as iqr,
            COUNT(*) FILTER (WHERE fare_amount < 0) as negative_values,
            COUNT(*) FILTER (WHERE fare_amount > 100) as extreme_high_values
        FROM all_taxi_trips
    """)

if 'trip_distance' in column_names:
    outlier_queries.append("""
        SELECT
            'trip_distance' as feature,
            MIN(trip_distance) as min_value,
            MAX(trip_distance) as max_value,
            AVG(trip_distance) as mean,
            PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY trip_distance) as q1,
            PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY trip_distance) as q3,
            PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY trip_distance) - 
            PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY trip_distance) as iqr,
            COUNT(*) FILTER (WHERE trip_distance < 0) as negative_values,
            COUNT(*) FILTER (WHERE trip_distance > 50) as extreme_high_values
        FROM all_taxi_trips
    """)

if outlier_queries:
    outliers_query = " UNION ALL ".join(outlier_queries)
    outliers = execute_query(outliers_query, "Error in outlier analysis")
    
    if not outliers.empty:
        print("Outlier analysis for key features:")
        display(outliers)
    
    # Get examples of outlier trips
    outlier_examples_queries = []
    
    if 'fare_amount' in column_names:
        outlier_examples_queries.append("""
            SELECT
                'Negative fare' as outlier_type,
                *
            FROM all_taxi_trips
            WHERE fare_amount < 0
            LIMIT 3
        """)
        
        outlier_examples_queries.append("""
            SELECT
                'Extremely high fare' as outlier_type,
                *
            FROM all_taxi_trips
            WHERE fare_amount > 100
            LIMIT 3
        """)
    
    if 'trip_distance' in column_names and 'fare_amount' in column_names:
        outlier_examples_queries.append("""
            SELECT
                'Zero distance but non-zero fare' as outlier_type,
                *
            FROM all_taxi_trips
            WHERE trip_distance = 0 AND fare_amount > 10
            LIMIT 3
        """)
    
    if 'trip_distance' in column_names:
        outlier_examples_queries.append("""
            SELECT
                'Unreasonable distance' as outlier_type,
                *
            FROM all_taxi_trips
            WHERE trip_distance > 100
            LIMIT 3
        """)
    
    if outlier_examples_queries:
        outlier_examples_query = " UNION ALL ".join(outlier_examples_queries)
        outlier_examples = execute_query(outlier_examples_query, "Error getting outlier examples")
        
        if not outlier_examples.empty:
            print("\nExamples of outlier trips:")
            display(outlier_examples)
    
    # Visualize the outliers in fare amounts
    if 'fare_amount' in column_names:
        fare_boxplot = execute_query("""
            SELECT fare_amount
            FROM all_taxi_trips
            WHERE fare_amount BETWEEN -10 AND 100  -- Include some negative values but cap the upper limit
            LIMIT 10000  -- Limit for performance
        """, "Error creating fare boxplot data")
        
        if not fare_boxplot.empty:
            fig = px.box(
                fare_boxplot,
                y="fare_amount",
                title="Boxplot of Fare Amounts (capped at $100)",
                labels={"fare_amount": "Fare Amount ($)"}
            )
            fig.update_layout(yaxis_title="Fare Amount ($)")
            fig.show()
    
    # Visualize the outliers in trip distances
    if 'trip_distance' in column_names:
        distance_boxplot = execute_query("""
            SELECT trip_distance
            FROM all_taxi_trips
            WHERE trip_distance BETWEEN 0 AND 30  -- Cap the upper limit for better visualization
            LIMIT 10000  -- Limit for performance
        """, "Error creating distance boxplot data")
        
        if not distance_boxplot.empty:
            fig = px.box(
                distance_boxplot,
                y="trip_distance",
                title="Boxplot of Trip Distances (capped at 30 miles)",
                labels={"trip_distance": "Trip Distance (miles)"}
            )
            fig.update_layout(yaxis_title="Trip Distance (miles)")
            fig.show()
else:
    print("Required columns for outlier analysis not available.")

Outlier analysis for key features:


Unnamed: 0,feature,min_value,max_value,mean,q1,q3,iqr,negative_values,extreme_high_values
0,fare_amount,-133391414.0,401092.32,16.401456,8.5,20.5,12.0,1372078,325092
1,trip_distance,0.0,398608.62,5.723947,1.06,3.44,2.38,0,21459


Error getting outlier examples: Parser Error: syntax error at or near "UNION"


## Trip Duration Analysis

In [None]:
# # Check if we have trip duration data (depends on having both pickup and dropoff times)
# pickup_col = None
# dropoff_col = None

# # Look for pickup and dropoff datetime columns with various possible names
# for p_col in ['pickup_datetime', 'lpep_pickup_datetime', 'tpep_pickup_datetime']:
#     if p_col in column_names:
#         pickup_col = p_col
#         break

# for d_col in ['dropoff_datetime', 'lpep_dropoff_datetime', 'tpep_dropoff_datetime']:
#     if d_col in column_names:
#         dropoff_col = d_col
#         break

# if pickup_col and dropoff_col:
#     print(f"Found pickup column: {pickup_col} and dropoff column: {dropoff_col}")
    
#     try:
#         # Analyze trip durations
#         duration_stats_query = f"""
#             SELECT
#                 AVG(EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 60) as avg_duration_minutes,
#                 MIN(EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 60) as min_duration_minutes,
#                 MAX(EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 60) as max_duration_minutes,
#                 PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 60) as median_duration_minutes
#             FROM all_taxi_trips
#             WHERE 
#                 {dropoff_col} > {pickup_col}
#                 AND EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 60 BETWEEN 1 AND 120  -- Between 1 min and 2 hours
#         """
        
#         duration_stats = execute_query(duration_stats_query, "Error calculating trip duration statistics")
        
#         if not duration_stats.empty:
#             print("Trip duration statistics (in minutes):")
#             display(duration_stats)
            
#             # Distribution of trip durations
#             duration_dist_query = f"""
#                 SELECT
#                     FLOOR(EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 60 / 5) * 5 as duration_bucket,
#                     COUNT(*) as count
#                 FROM all_taxi_trips
#                 WHERE 
#                     {dropoff_col} > {pickup_col}
#                     AND EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 60 BETWEEN 1 AND 120
#                 GROUP BY duration_bucket
#                 ORDER BY duration_bucket
#             """
            
#             duration_dist = execute_query(duration_dist_query, "Error calculating trip duration distribution")
            
#             if not duration_dist.empty:
#                 duration_dist['duration_range'] = duration_dist['duration_bucket'].apply(lambda x: f"{x}-{x+5}")
                
#                 print("\nDistribution of trip durations:")
#                 display(duration_dist.head())
                
#                 fig = px.bar(
#                     duration_dist,
#                     x="duration_range",
#                     y="count",
#                     title="Distribution of Trip Durations",
#                     labels={"duration_range": "Duration (minutes)", "count": "Number of Trips"}
#                 )
#                 fig.update_layout(xaxis_title="Duration (minutes)", yaxis_title="Number of Trips")
#                 fig.show()
            
#             # Calculate average speed (only if we have trip_distance)
#             if 'trip_distance' in column_names:
#                 speed_analysis_query = f"""
#                     SELECT
#                         trip_distance,
#                         EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 3600 as duration_hours,
#                         trip_distance / NULLIF(EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 3600, 0) as speed_mph
#                     FROM all_taxi_trips
#                     WHERE 
#                         {dropoff_col} > {pickup_col}
#                         AND EXTRACT(EPOCH FROM ({dropoff_col} - {pickup_col})) / 60 BETWEEN 1 AND 120
#                         AND trip_distance > 0
#                         AND trip_distance < 50
#                     LIMIT 10000  -- Limit for visualization performance
#                 """
                
#                 speed_analysis = execute_query(speed_analysis_query, "Error calculating trip speeds")
                
#                 if not speed_analysis.empty:
#                     # Filter out unreasonable speeds (e.g., > 100 mph)
#                     speed_analysis = speed_analysis[speed_analysis['speed_mph'] <= 100]
                    
#                     print("\nTrip speed analysis (sample):")
#                     display(speed_analysis.head())
                    
#                     # Create scatter plot of distance vs. duration with speed as color
#                     fig = px.scatter(
#                         speed_analysis,
#                         x="duration_hours",
#                         y="trip_distance",
#                         color="speed_mph",
#                         title="Trip Distance vs. Duration (with Speed)",
#                         labels={
#                             "duration_hours": "Trip Duration (hours)",
#                             "trip_distance": "Trip Distance (miles)",
#                             "speed_mph": "Average Speed (mph)"
#                         },
#                         color_continuous_scale="Viridis"
#                     )
#                     fig.update_layout(
#                         xaxis_title="Trip Duration (hours)",
#                         yaxis_title="Trip Distance (miles)",
#                         coloraxis_colorbar=dict(title="Speed (mph)")
#                     )
#                     fig.show()
#     except Exception as e:
#         print(f"Could not perform trip duration analysis: {e}")
#         print("There might be issues with the datetime columns or their format.")
# else:
#     print("Required pickup and dropoff datetime columns not found. Skipping trip duration analysis.")

Required pickup and dropoff datetime columns not found. Skipping trip duration analysis.


## Rate Code Analysis

In [46]:
if 'RatecodeID' in column_names:
    try:
        # Analyze rate code distribution
        ratecode_dist = execute_query("""
            SELECT
                RatecodeID,
                COUNT(*) as count,
                100.0 * COUNT(*) / (SELECT COUNT(*) FROM all_taxi_trips) as percentage
            FROM all_taxi_trips
            GROUP BY RatecodeID
            ORDER BY count DESC
        """, "Error analyzing rate codes")
        
        if not ratecode_dist.empty:
            # Add rate code descriptions if available
            ratecode_descriptions = {
                1: "Standard rate",
                2: "JFK",
                3: "Newark",
                4: "Nassau or Westchester",
                5: "Negotiated fare",
                6: "Group ride"
            }
            
            ratecode_dist['description'] = ratecode_dist['RatecodeID'].map(
                lambda x: ratecode_descriptions.get(x, f"Unknown ({x})")
            )
            
            print("Rate code distribution:")
            display(ratecode_dist)
            
            # Create pie chart for rate code distribution
            fig = px.pie(
                ratecode_dist,
                values='count',
                names='description',
                title="Distribution of Rate Codes",
                hover_data=['percentage']
            )
            fig.update_traces(textposition='inside', textinfo='percent+label')
            fig.show()
            
            # Analyze fare amounts by rate code (if fare_amount exists)
            if 'fare_amount' in column_names and 'trip_distance' in column_names:
                rate_fare_query = """
                    SELECT
                        RatecodeID,
                        COUNT(*) as count,
                        AVG(fare_amount) as avg_fare,
                        AVG(trip_distance) as avg_distance,
                        AVG(CASE WHEN trip_distance > 0 THEN fare_amount / trip_distance ELSE NULL END) as avg_fare_per_mile
                    FROM all_taxi_trips
                    WHERE fare_amount BETWEEN 0 AND 200
                    AND trip_distance BETWEEN 0 AND 100
                    GROUP BY RatecodeID
                    ORDER BY avg_fare DESC
                """
                
                rate_fare_analysis = execute_query(rate_fare_query, "Error analyzing fares by rate code")
                
                if not rate_fare_analysis.empty:
                    # Add rate code descriptions
                    rate_fare_analysis['description'] = rate_fare_analysis['RatecodeID'].map(
                        lambda x: ratecode_descriptions.get(x, f"Unknown ({x})")
                    )
                    
                    print("\nFare analysis by rate code:")
                    display(rate_fare_analysis)
                    
                    # Create grouped bar chart for rate code metrics
                    fig = make_subplots(specs=[[{"secondary_y": True}]])
                    
                    # Add bars for average fare
                    fig.add_trace(
                        go.Bar(
                            x=rate_fare_analysis['description'],
                            y=rate_fare_analysis['avg_fare'],
                            name="Avg Fare ($)"
                        ),
                        secondary_y=False
                    )
                    
                    # Add line for average distance
                    fig.add_trace(
                        go.Scatter(
                            x=rate_fare_analysis['description'],
                            y=rate_fare_analysis['avg_distance'],
                            name="Avg Distance (miles)",
                            mode="lines+markers"
                        ),
                        secondary_y=True
                    )
                    
                    fig.update_layout(
                        title_text="Average Fare and Distance by Rate Code",
                        xaxis_title="Rate Code"
                    )
                    
                    fig.update_yaxes(title_text="Average Fare ($)", secondary_y=False)
                    fig.update_yaxes(title_text="Average Distance (miles)", secondary_y=True)
                    
                    fig.show()
    except Exception as e:
        print(f"Could not perform rate code analysis: {e}")
else:
    print("RatecodeID column not found. Skipping rate code analysis.")

Rate code distribution:


Unnamed: 0,RatecodeID,count,percentage,description
0,1.0,107727695,88.720976,Standard rate
1,,6939172,5.714873,Unknown (nan)
2,2.0,4388898,3.614552,JFK
3,5.0,951118,0.783309,Negotiated fare
4,99.0,814390,0.670705,Unknown (99.0)
5,3.0,368072,0.303132,Newark
6,4.0,233174,0.192034,Nassau or Westchester
7,6.0,509,0.000419,Group ride



Fare analysis by rate code:


Unnamed: 0,RatecodeID,count,avg_fare,avg_distance,avg_fare_per_mile,description
0,4.0,208767,89.95987,18.550923,5.12769,Nassau or Westchester
1,3.0,353283,79.965613,16.22882,29.58851,Newark
2,2.0,4301868,63.940195,17.462328,45.991865,JFK
3,5.0,903262,60.971895,4.603308,458.460433,Negotiated fare
4,99.0,814144,33.629908,7.244079,8.44877,Unknown (99.0)
5,,6800224,21.070285,3.490498,17.033666,Unknown (nan)
6,1.0,106624782,15.312545,2.740673,7.557727,Standard rate
7,6.0,460,4.776978,1.444609,34.416578,Group ride


## Borough-Level Analysis with Taxi Zone Maps

In [47]:
# Attempt to load and use taxi zone lookup information if available
try:
    # Check if we have a zone lookup table
    zone_lookup_available = False
    try:
        taxi_zones = pd.read_csv('taxi_zone_lookup.csv')
        zone_lookup_available = True
        print("Loaded taxi zone lookup table")
        display(taxi_zones.head())
    except:
        zone_lookup_available = False
        print("Taxi zone lookup table not available")
    
    # Try to work with zone maps
    zone_map_files = {
        'Bronx': 'Taxi_Zone_Map_Bronx.jpg',
        'Brooklyn': 'Taxi_Zone_Map_Brooklyn.jpg',
        'Manhattan': 'Taxi_Zone_Map_Manhattan.jpg',
        'Queens': 'Taxi_Zone_Map_Queens.jpg',
        'Staten_Island': 'Taxi_Zone_Map_Staten_Island.jpg'
    }
    
    # Load available maps
    loaded_maps = {}
    print("Checking for available taxi zone maps...")
    for borough, file_name in zone_map_files.items():
        try:
            img = Image.open(file_name)
            loaded_maps[borough] = img
            print(f"✓ Loaded map for {borough}")
        except Exception as e:
            print(f"✗ Could not load map for {borough}: {e}")
    
    # If we have both location IDs, zone lookup, and map files, we can do detailed analysis
    if 'PULocationID' in column_names and 'DOLocationID' in column_names and zone_lookup_available and loaded_maps:
        print("\nPerforming borough-level analysis with zone information and maps")
        
        # Create a temporary table with the zone lookup data
        zones_csv_path = os.path.abspath('taxi_zone_lookup.csv')
        conn.execute(f"""
            CREATE OR REPLACE TEMP TABLE taxi_zones AS
            SELECT * FROM read_csv_auto('{zones_csv_path}')
        """)
        
        # Analyze trips by borough
        borough_analysis = execute_query("""
            SELECT
                z1.Borough as pickup_borough,
                COUNT(*) as trip_count,
                AVG(t.fare_amount) as avg_fare,
                AVG(t.trip_distance) as avg_distance
            FROM all_taxi_trips t
            JOIN taxi_zones z1 ON t.PULocationID = z1.LocationID
            GROUP BY pickup_borough
            ORDER BY trip_count DESC
        """, "Error analyzing trips by borough")
        
        if not borough_analysis.empty:
            print("\nTrips by pickup borough:")
            display(borough_analysis)
            
            # Create visualization for borough analysis
            fig = px.bar(
                borough_analysis,
                x="pickup_borough",
                y="trip_count",
                color="avg_fare",
                title="Trips by Pickup Borough",
                labels={
                    "pickup_borough": "Pickup Borough",
                    "trip_count": "Number of Trips",
                    "avg_fare": "Average Fare ($)"
                },
                color_continuous_scale="Viridis"
            )
            fig.update_layout(xaxis_title="Pickup Borough", yaxis_title="Number of Trips")
            fig.show()
        
        # Inter-borough travel analysis
        inter_borough = execute_query("""
            SELECT
                z1.Borough as pickup_borough,
                z2.Borough as dropoff_borough,
                COUNT(*) as trip_count,
                AVG(t.fare_amount) as avg_fare,
                AVG(t.trip_distance) as avg_distance
            FROM all_taxi_trips t
            JOIN taxi_zones z1 ON t.PULocationID = z1.LocationID
            JOIN taxi_zones z2 ON t.DOLocationID = z2.LocationID
            GROUP BY pickup_borough, dropoff_borough
            ORDER BY trip_count DESC
        """, "Error analyzing inter-borough travel")
        
        if not inter_borough.empty:
            print("\nInter-borough travel patterns:")
            display(inter_borough)
            
            # Create a heatmap of inter-borough travel
            inter_borough_pivot = inter_borough.pivot(
                index="pickup_borough",
                columns="dropoff_borough",
                values="trip_count"
            ).fillna(0)
            
            fig = px.imshow(
                inter_borough_pivot,
                text_auto=True,
                aspect="auto",
                title="Inter-Borough Travel Patterns (Trip Count)",
                color_continuous_scale='Viridis'
            )
            fig.update_layout(
                xaxis_title="Dropoff Borough",
                yaxis_title="Pickup Borough"
            )
            fig.show()
            
            # Create another heatmap for average fare between boroughs
            inter_borough_fare_pivot = inter_borough.pivot(
                index="pickup_borough",
                columns="dropoff_borough",
                values="avg_fare"
            ).fillna(0)
            
            fig = px.imshow(
                inter_borough_fare_pivot,
                text_auto=".2f",
                aspect="auto",
                title="Average Fare Between Boroughs ($)",
                color_continuous_scale='RdBu_r'
            )
            fig.update_layout(
                xaxis_title="Dropoff Borough",
                yaxis_title="Pickup Borough"
            )
            fig.show()
        
        # For each borough with an available map, show additional analysis
        for borough, img in loaded_maps.items():
            # Get top pickup and dropoff zones in this borough
            top_zones_query = f"""
                SELECT
                    z.Zone as zone_name,
                    COUNT(*) as pickup_count
                FROM all_taxi_trips t
                JOIN taxi_zones z ON t.PULocationID = z.LocationID
                WHERE z.Borough = '{borough}'
                GROUP BY zone_name
                ORDER BY pickup_count DESC
                LIMIT 10
            """
            
            top_zones = execute_query(top_zones_query, f"Error getting top zones for {borough}")
            
            if not top_zones.empty:
                print(f"\nTop 10 pickup zones in {borough}:")
                display(top_zones)
                
                # Display the map with annotation for the top zone
                fig = px.imshow(img)
                fig.update_layout(
                    title=f"Taxi Zone Map - {borough}",
                    annotations=[
                        dict(
                            text=f"Top pickup zone: {top_zones['zone_name'].iloc[0]}",
                            x=0.5,
                            y=0.05,
                            xref="paper",
                            yref="paper",
                            showarrow=False,
                            font=dict(size=14, color="white", family="Arial"),
                            bgcolor="rgba(0,0,0,0.7)",
                            bordercolor="white",
                            borderwidth=1,
                            borderpad=4
                        )
                    ]
                )
                fig.show()
    
    # If we have maps but no zone lookup or location IDs
    elif loaded_maps:
        print("\nShowing taxi zone maps without data overlay (missing zone lookup or location data)")
        
        # Just display the available maps
        for borough, img in loaded_maps.items():
            fig = px.imshow(img)
            fig.update_layout(title=f"Taxi Zone Map - {borough}")
            fig.show()
    
    # If we have neither maps nor zone lookup
    else:
        print("\nNo map files or zone lookup data available for geographic visualization.")
        
        # If we have location IDs, we can still do basic analysis
        if 'PULocationID' in column_names and 'DOLocationID' in column_names:
            print("Performing basic location ID analysis without zone information.")
            
            # Most common pairs
            location_pairs = execute_query("""
                SELECT 
                    PULocationID, 
                    DOLocationID,
                    COUNT(*) as trip_count
                FROM all_taxi_trips
                GROUP BY PULocationID, DOLocationID
                ORDER BY trip_count DESC
                LIMIT 20
            """, "Error analyzing location pairs")
            
            if not location_pairs.empty:
                print("\nMost common location ID pairs:")
                display(location_pairs)
                
                # Create a simple visualization
                fig = px.scatter(
                    location_pairs,
                    x="PULocationID",
                    y="DOLocationID",
                    size="trip_count",
                    color="trip_count",
                    title="Most Common Pickup-Dropoff Location Pairs",
                    labels={
                        "PULocationID": "Pickup Location ID",
                        "DOLocationID": "Dropoff Location ID",
                        "trip_count": "Number of Trips"
                    }
                )
                fig.update_layout(
                    xaxis_title="Pickup Location ID",
                    yaxis_title="Dropoff Location ID"
                )
                fig.show()
                
                print("\nNote: For better geographic visualization, obtain the taxi zone lookup table")
                print("and map files to overlay data on actual zones.")

except Exception as e:
    print(f"Could not perform borough-level analysis: {e}")
    print("Skipping borough-level analysis with taxi zone maps")


Taxi zone lookup table not available
Checking for available taxi zone maps...
✗ Could not load map for Bronx: [Errno 2] No such file or directory: 'Taxi_Zone_Map_Bronx.jpg'
✗ Could not load map for Brooklyn: [Errno 2] No such file or directory: 'Taxi_Zone_Map_Brooklyn.jpg'
✗ Could not load map for Manhattan: [Errno 2] No such file or directory: 'Taxi_Zone_Map_Manhattan.jpg'
✗ Could not load map for Queens: [Errno 2] No such file or directory: 'Taxi_Zone_Map_Queens.jpg'
✗ Could not load map for Staten_Island: [Errno 2] No such file or directory: 'Taxi_Zone_Map_Staten_Island.jpg'

No map files or zone lookup data available for geographic visualization.
Performing basic location ID analysis without zone information.

Most common location ID pairs:


Unnamed: 0,PULocationID,DOLocationID,trip_count
0,237,236,797164
1,236,237,683358
2,264,264,564575
3,237,237,550691
4,236,236,531785
5,161,237,356502
6,237,161,349334
7,161,236,306596
8,142,239,302724
9,239,142,298586



Note: For better geographic visualization, obtain the taxi zone lookup table
and map files to overlay data on actual zones.


In [48]:

# ------------------------------------------------
# 14. Conclusion and Summary Insights
# ------------------------------------------------

print("\n14. Conclusion and Summary Insights")
print("-----------------------------------------")

print("""
Based on the exploratory data analysis of NYC taxi trips, here are the key insights:

1. **Data Overview**: 
   - The dataset contains detailed records of taxi trips including fare information, locations, and trip characteristics
   - We've identified key numerical and categorical features for modeling

2. **Fare Analysis**:
   - There are clear patterns in fare amounts related to trip distance
   - Fare per mile varies by trip length, with shorter trips typically costing more per mile
   - We can effectively categorize fares into low, medium, and high segments

3. **Geographic Patterns**:
   - Certain pickup and dropoff locations show significantly higher activity
   - Location pairs exhibit strong patterns that can be leveraged for prediction
   - Different boroughs show distinct fare and trip characteristics

4. **Feature Engineering Opportunities**:
   - Created meaningful features like cost_per_mile, distance categories, and passenger groups
   - Identified key categorical features for prediction models
   - Developed fare categories that can be used for classification tasks

5. **Outliers and Data Quality**:
   - Identified patterns in data outliers that need attention during preprocessing
   - Specific thresholds for fare amounts and distances established for data cleaning
   - Time-based patterns revealed important trends for demand prediction

6. **Next Steps for Modeling**:
   - Multi-stage pipeline approach validated by data characteristics
   - Clustering, classification, and specialized regression approach aligns with data patterns
   - Weather and location data integration will enhance predictive power
""")

# ------------------------------------------------
# 15. Next Steps for Modeling
# ------------------------------------------------

print("\n15. Next Steps for Modeling")
print("-----------------------------------------")

print("""
## Next Steps for Multi-Stage Modeling Pipeline

Based on the EDA findings, here's the recommended approach for implementing the multi-stage modeling pipeline:

### 1. Data Preparation

**Cleaning Tasks:**
- Remove trips with negative fares or distances
- Handle extreme outliers in fare amounts (>$100) and trip distances (>30 miles)
- Address inconsistencies between fare components and total amounts
- Standardize categorical variables (payment types, rate codes)

**Feature Engineering:**
- Create cost per mile feature
- Develop trip distance categories (very short, short, medium, long)
- Generate passenger group features (solo, couple, small/large group)
- Construct payment type features (cash vs. non-cash)
- Create location-based features using pickup/dropoff patterns
- Incorporate weather data as external features
- Add time-based features (hour, day, week patterns) if timestamp data is available

### 2. Implementation of Multi-Stage Model

**Stage 1: Clustering**
- Use K-means clustering with 3-5 clusters
- Features: trip_distance, fare_amount, passenger_count, time features
- Expected clusters: short commuter trips, airport trips, long-distance trips
- Validate with silhouette score and analyze cluster characteristics

**Stage 2: Fare Category Classification**
- Define fare categories as low/medium/high based on quantiles
- Train Random Forest and Gradient Boosting classifiers
- Features: trip characteristics, passenger info, location, cluster assignment
- Evaluate using accuracy, F1-score, and confusion matrix
- Analyze feature importance across segments

**Stage 3: Specialized Regression Models**
- Train separate regression models for each fare category
- Use Random Forest Regressor for non-linear relationships
- Consider Gradient Boosting for optimal performance
- Features: full feature set including engineered features
- Evaluate using MAE, RMSE, R-squared metrics

### 3. Weather and Demand Integration

- Incorporate weather data as predictive features
- Analyze demand fluctuations by weather conditions
- Develop separate models for demand forecasting
- Integrate fare and demand predictions for business applications

### 4. Model Evaluation Framework

- Implement temporal validation (train on earlier data, test on later)
- Use geographic validation for spatial robustness
- Develop business impact metrics beyond statistical measures
- Create visualizations to communicate model insights

This approach will provide a comprehensive solution for both fare price prediction and demand forecasting based on the taxi trip characteristics and external factors.
""")

# Close the database connection
conn.close()

print("\nExploratory Data Analysis Complete!")
print("====================================")


14. Conclusion and Summary Insights
-----------------------------------------

Based on the exploratory data analysis of NYC taxi trips, here are the key insights:

1. **Data Overview**: 
   - The dataset contains detailed records of taxi trips including fare information, locations, and trip characteristics
   - We've identified key numerical and categorical features for modeling

2. **Fare Analysis**:
   - There are clear patterns in fare amounts related to trip distance
   - Fare per mile varies by trip length, with shorter trips typically costing more per mile
   - We can effectively categorize fares into low, medium, and high segments

3. **Geographic Patterns**:
   - Certain pickup and dropoff locations show significantly higher activity
   - Location pairs exhibit strong patterns that can be leveraged for prediction
   - Different boroughs show distinct fare and trip characteristics

4. **Feature Engineering Opportunities**:
   - Created meaningful features like cost_per_mile, d

In [49]:
import pandas as pd
import numpy as np
import duckdb
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
import os

# Suppress warnings
warnings.filterwarnings('ignore')

# Set styling
plt.style.use('fivethirtyeight')
pd.set_option('display.max_columns', None)

# Function to execute SQL queries with error handling
def execute_query(conn, query, error_message="Error executing query"):
    try:
        return conn.execute(query).fetchdf()
    except Exception as e:
        print(f"{error_message}: {str(e)}")
        return pd.DataFrame()  # Return empty dataframe instead of failing

# ------------------------------------------------
# 1. Connect to DuckDB and Load Data
# ------------------------------------------------

print("1. Connecting to DuckDB and Loading Data")
print("-----------------------------------------")

# Initialize DuckDB connection - update path as needed for your environment
conn = duckdb.connect('../db/nyc_taxi.duckdb')

# Check available tables in the database
tables = execute_query(conn, "SHOW TABLES", "Error listing tables")
print(f"Available tables in the database:")
print(tables)

# Get column names from the taxi trips table
all_taxi_columns = execute_query(conn, "PRAGMA table_info(all_taxi_trips)", "Error getting taxi table structure")
if not all_taxi_columns.empty:
    column_names = all_taxi_columns['name'].tolist()
else:
    column_names = []
    print("No taxi columns available.")

# ------------------------------------------------
# 2. Improved Geographic Analysis Visualizations
# ------------------------------------------------

print("\n2. Improved Geographic Analysis Visualizations")
print("-----------------------------------------")

# Top pickup locations - convert to horizontal bar chart
if 'PULocationID' in column_names:
    top_pickups = execute_query(conn, """
        SELECT 
            PULocationID,
            COUNT(*) as pickup_count
        FROM all_taxi_trips
        GROUP BY PULocationID
        ORDER BY pickup_count DESC
        LIMIT 20
    """, "Error getting top pickup locations")
    
    if not top_pickups.empty:
        print("\nTop 20 pickup locations:")
        print(top_pickups)
        
        # Create horizontal bar chart for top pickup locations
        fig = px.bar(
            top_pickups.sort_values('pickup_count'),  # Sort to have highest bar at top
            y="PULocationID",
            x="pickup_count",
            orientation='h',  # Horizontal orientation
            title="Top 20 Pickup Locations",
            labels={"PULocationID": "Pickup Location ID", "pickup_count": "Number of Pickups"},
            color="pickup_count",  # Color by count for better visualization
            color_continuous_scale="Blues"  # Use a color scale
        )
        fig.update_layout(
            yaxis_title="Pickup Location ID", 
            xaxis_title="Number of Pickups",
            height=600  # Make it taller to fit all bars
        )
        fig.show()

# Top dropoff locations - convert to horizontal bar chart
if 'DOLocationID' in column_names:
    top_dropoffs = execute_query(conn, """
        SELECT 
            DOLocationID,
            COUNT(*) as dropoff_count
        FROM all_taxi_trips
        GROUP BY DOLocationID
        ORDER BY dropoff_count DESC
        LIMIT 20
    """, "Error getting top dropoff locations")
    
    if not top_dropoffs.empty:
        print("\nTop 20 dropoff locations:")
        print(top_dropoffs)
        
        # Create horizontal bar chart for top dropoff locations
        fig = px.bar(
            top_dropoffs.sort_values('dropoff_count'),  # Sort to have highest bar at top
            y="DOLocationID",
            x="dropoff_count",
            orientation='h',  # Horizontal orientation
            title="Top 20 Dropoff Locations",
            labels={"DOLocationID": "Dropoff Location ID", "dropoff_count": "Number of Dropoffs"},
            color="dropoff_count",  # Color by count
            color_continuous_scale="Reds"  # Use a color scale
        )
        fig.update_layout(
            yaxis_title="Dropoff Location ID", 
            xaxis_title="Number of Dropoffs",
            height=600  # Make it taller to fit all bars
        )
        fig.show()

# Top location pairs - convert sunburst to heatmap
if 'PULocationID' in column_names and 'DOLocationID' in column_names:
    top_location_pairs = execute_query(conn, """
        SELECT 
            PULocationID,
            DOLocationID,
            COUNT(*) as trip_count
        FROM all_taxi_trips
        GROUP BY PULocationID, DOLocationID
        ORDER BY trip_count DESC
        LIMIT 50  # Increased to get better heatmap
    """, "Error getting top location pairs")
    
    if not top_location_pairs.empty:
        print("\nTop pickup-dropoff location pairs:")
        print(top_location_pairs.head(20))
        
        # Create a pivot table for the heatmap
        pivot_data = top_location_pairs.pivot_table(
            index='PULocationID', 
            columns='DOLocationID', 
            values='trip_count',
            fill_value=0
        )
        
        # Create heatmap for top location pairs
        fig = px.imshow(
            pivot_data,
            labels=dict(x="Dropoff Location ID", y="Pickup Location ID", color="Trip Count"),
            title="Top Pickup-Dropoff Location Pairs",
            color_continuous_scale="Viridis"
        )
        fig.update_layout(
            width=800,
            height=800
        )
        fig.show()
        
        # Additionally, create a more readable bar chart for the top 15 pairs
        top_15_pairs = top_location_pairs.head(15)
        top_15_pairs['location_pair'] = top_15_pairs.apply(
            lambda row: f"{row['PULocationID']} → {row['DOLocationID']}", axis=1
        )
        
        fig = px.bar(
            top_15_pairs.sort_values('trip_count'),
            y='location_pair',
            x='trip_count',
            orientation='h',
            title="Top 15 Pickup-Dropoff Location Pairs",
            labels={'location_pair': 'Pickup → Dropoff', 'trip_count': 'Number of Trips'},
            color='trip_count',
            color_continuous_scale="Viridis"
        )
        fig.update_layout(
            height=500,
            yaxis_title="Location Pair (Pickup → Dropoff)",
            xaxis_title="Number of Trips"
        )
        fig.show()

# ------------------------------------------------
# 3. Improved Payment Type Analysis
# ------------------------------------------------

print("\n3. Improved Payment Type Analysis")
print("-----------------------------------------")

# Payment type distribution - convert pie chart to bar chart
if 'payment_type' in column_names:
    payment_dist = execute_query(conn, """
        SELECT 
            payment_type,
            COUNT(*) as count,
            100.0 * COUNT(*) / (SELECT COUNT(*) FROM all_taxi_trips) as percentage
        FROM all_taxi_trips
        GROUP BY payment_type
        ORDER BY count DESC
    """, "Error analyzing payment types")
    
    if not payment_dist.empty:
        print("Payment type distribution:")
        print(payment_dist)
        
        # Add descriptions for payment types
        payment_descriptions = {
            1: "Credit Card",
            2: "Cash",
            3: "No Charge",
            4: "Dispute",
            5: "Unknown",
            0: "Other"
        }
        
        payment_dist['description'] = payment_dist['payment_type'].map(
            lambda x: payment_descriptions.get(x, f"Unknown ({x})")
        )
        
        # Create bar chart for payment types
        fig = px.bar(
            payment_dist,
            x='description',
            y='percentage',
            title="Distribution of Payment Types",
            labels={'description': 'Payment Type', 'percentage': 'Percentage (%)'},
            color='description',
            text='percentage'
        )
        fig.update_traces(texttemplate='%{text:.1f}%', textposition='outside')
        fig.update_layout(
            xaxis_title="Payment Type",
            yaxis_title="Percentage (%)",
            yaxis=dict(range=[0, 100])  # Set y-axis from 0 to 100%
        )
        fig.show()
        
        # Tip amount analysis by payment type
        if all(col in column_names for col in ['payment_type', 'tip_amount', 'total_amount']):
            tip_analysis = execute_query(conn, """
                SELECT 
                    payment_type,
                    AVG(tip_amount) as avg_tip,
                    AVG(CASE WHEN total_amount > 0 THEN tip_amount / total_amount * 100 ELSE 0 END) as avg_tip_percentage
                FROM all_taxi_trips
                GROUP BY payment_type
                ORDER BY avg_tip DESC
            """, "Error analyzing tip amounts")
            
            if not tip_analysis.empty:
                # Add payment type descriptions
                tip_analysis['description'] = tip_analysis['payment_type'].map(
                    lambda x: payment_descriptions.get(x, f"Unknown ({x})")
                )
                
                # Create bar chart for average tip by payment type
                fig = px.bar(
                    tip_analysis,
                    x='description',
                    y=['avg_tip', 'avg_tip_percentage'],
                    title="Average Tip by Payment Type",
                    barmode="group",
                    labels={
                        'description': 'Payment Type', 
                        'value': 'Value',
                        'variable': 'Metric'
                    }
                )
                fig.update_layout(
                    xaxis_title="Payment Type", 
                    yaxis_title="Value",
                    legend_title="Metric",
                    legend=dict(
                        orientation="h",
                        yanchor="bottom",
                        y=1.02,
                        xanchor="right",
                        x=1
                    )
                )
                # Rename the legend items
                newnames = {'avg_tip': 'Avg Tip ($)', 'avg_tip_percentage': 'Avg Tip (%)'}
                fig.for_each_trace(lambda t: t.update(name = newnames[t.name]))
                fig.show()

# ------------------------------------------------
# 4. Improved Rate Code Analysis
# ------------------------------------------------

print("\n4. Improved Rate Code Analysis")
print("-----------------------------------------")

if 'RatecodeID' in column_names:
    try:
        # Analyze rate code distribution
        ratecode_dist = execute_query(conn, """
            SELECT
                RatecodeID,
                COUNT(*) as count,
                100.0 * COUNT(*) / (SELECT COUNT(*) FROM all_taxi_trips) as percentage
            FROM all_taxi_trips
            GROUP BY RatecodeID
            ORDER BY count DESC
        """, "Error analyzing rate codes")
        
        if not ratecode_dist.empty:
            # Add rate code descriptions if available
            ratecode_descriptions = {
                1: "Standard rate",
                2: "JFK",
                3: "Newark",
                4: "Nassau or Westchester",
                5: "Negotiated fare",
                6: "Group ride"
            }
            
            ratecode_dist['description'] = ratecode_dist['RatecodeID'].map(
                lambda x: ratecode_descriptions.get(x, f"Unknown ({x})")
            )
            
            print("Rate code distribution:")
            print(ratecode_dist)
            
            # Create bar chart for rate code distribution
            fig = px.bar(
                ratecode_dist,
                x='description',
                y='percentage',
                title="Distribution of Rate Codes",
                labels={'description': 'Rate Code', 'percentage': 'Percentage (%)'},
                color='description',
                text='percentage'
            )
            fig.update_traces(texttemplate='%{text:.1f}%', textposition='outside')
            fig.update_layout(
                xaxis_title="Rate Code",
                yaxis_title="Percentage (%)",
                yaxis=dict(range=[0, 100])  # Set y-axis from 0 to 100%
            )
            fig.show()
            
            # Analyze fare amounts by rate code (if fare_amount exists)
            if 'fare_amount' in column_names and 'trip_distance' in column_names:
                rate_fare_query = """
                    SELECT
                        RatecodeID,
                        COUNT(*) as count,
                        AVG(fare_amount) as avg_fare,
                        AVG(trip_distance) as avg_distance,
                        AVG(CASE WHEN trip_distance > 0 THEN fare_amount / trip_distance ELSE NULL END) as avg_fare_per_mile
                    FROM all_taxi_trips
                    WHERE fare_amount BETWEEN 0 AND 200
                    AND trip_distance BETWEEN 0 AND 100
                    GROUP BY RatecodeID
                    ORDER BY avg_fare DESC
                """
                
                rate_fare_analysis = execute_query(conn, rate_fare_query, "Error analyzing fares by rate code")
                
                if not rate_fare_analysis.empty:
                    # Add rate code descriptions
                    rate_fare_analysis['description'] = rate_fare_analysis['RatecodeID'].map(
                        lambda x: ratecode_descriptions.get(x, f"Unknown ({x})")
                    )
                    
                    print("\nFare analysis by rate code:")
                    print(rate_fare_analysis)
                    
                    # Create a more readable multi-measure bar chart
                    fig = make_subplots(specs=[[{"secondary_y": True}]])
                    
                    # Add bars for average fare
                    fig.add_trace(
                        go.Bar(
                            x=rate_fare_analysis['description'],
                            y=rate_fare_analysis['avg_fare'],
                            name="Avg Fare ($)",
                            marker_color='royalblue'
                        ),
                        secondary_y=False
                    )
                    
                    # Add bars for average distance
                    fig.add_trace(
                        go.Bar(
                            x=rate_fare_analysis['description'],
                            y=rate_fare_analysis['avg_distance'],
                            name="Avg Distance (miles)",
                            marker_color='firebrick'
                        ),
                        secondary_y=False
                    )
                    
                    # Add line for average fare per mile
                    fig.add_trace(
                        go.Scatter(
                            x=rate_fare_analysis['description'],
                            y=rate_fare_analysis['avg_fare_per_mile'],
                            name="Avg Fare per Mile ($)",
                            mode="lines+markers",
                            marker=dict(color='green', size=10),
                            line=dict(width=3)
                        ),
                        secondary_y=True
                    )
                    
                    fig.update_layout(
                        title_text="Rate Code Analysis: Fare and Distance Metrics",
                        barmode='group',
                        xaxis_title="Rate Code",
                        legend=dict(
                            orientation="h",
                            yanchor="bottom",
                            y=1.02,
                            xanchor="right",
                            x=1
                        )
                    )
                    
                    fig.update_yaxes(title_text="Value", secondary_y=False)
                    fig.update_yaxes(title_text="Fare per Mile ($)", secondary_y=True)
                    
                    fig.show()
    except Exception as e:
        print(f"Could not perform rate code analysis: {e}")
else:
    print("RatecodeID column not found. Skipping rate code analysis.")

# ------------------------------------------------
# 5. Passenger Count Distribution
# ------------------------------------------------

print("\n5. Passenger Count Distribution")
print("-----------------------------------------")

if 'passenger_count' in column_names:
    passenger_dist = execute_query(conn, """
        SELECT 
            passenger_count,
            COUNT(*) as count,
            100.0 * COUNT(*) / (SELECT COUNT(*) FROM all_taxi_trips) as percentage
        FROM all_taxi_trips
        WHERE passenger_count BETWEEN 0 AND 9  -- Include 0 but filter out unreasonable values
        GROUP BY passenger_count
        ORDER BY passenger_count
    """, "Error analyzing passenger counts")
    
    if not passenger_dist.empty:
        print("\nPassenger count distribution:")
        print(passenger_dist)
        
        # Create column chart for passenger counts
        fig = px.bar(
            passenger_dist,
            x="passenger_count",
            y="percentage",
            title="Distribution of Passenger Counts",
            labels={"passenger_count": "Number of Passengers", "percentage": "Percentage (%)"},
            color="passenger_count",
            text="percentage"
        )
        fig.update_traces(texttemplate='%{text:.1f}%', textposition='outside')
        fig.update_layout(
            xaxis_title="Number of Passengers", 
            yaxis_title="Percentage (%)",
            xaxis=dict(tickmode='linear', tick0=0, dtick=1),
            yaxis=dict(range=[0, max(passenger_dist['percentage'])*1.2])  # Add some headroom for labels
        )
        fig.show()

# ------------------------------------------------
# 6. Trip Distance Distribution
# ------------------------------------------------

print("\n6. Trip Distance Distribution")
print("-----------------------------------------")

if 'trip_distance' in column_names:
    # Create distance buckets for better visualization
    distance_dist = execute_query(conn, """
        SELECT
            CASE
                WHEN trip_distance < 1 THEN '0-1'
                WHEN trip_distance < 2 THEN '1-2'
                WHEN trip_distance < 3 THEN '2-3'
                WHEN trip_distance < 5 THEN '3-5'
                WHEN trip_distance < 10 THEN '5-10'
                WHEN trip_distance < 20 THEN '10-20'
                ELSE '20+'
            END as distance_range,
            COUNT(*) as count,
            100.0 * COUNT(*) / (SELECT COUNT(*) FROM all_taxi_trips) as percentage
        FROM all_taxi_trips
        WHERE trip_distance >= 0 AND trip_distance <= 100  -- Filter reasonable values
        GROUP BY distance_range
        ORDER BY
            CASE distance_range
                WHEN '0-1' THEN 1
                WHEN '1-2' THEN 2
                WHEN '2-3' THEN 3
                WHEN '3-5' THEN 4
                WHEN '5-10' THEN 5
                WHEN '10-20' THEN 6
                WHEN '20+' THEN 7
            END
    """, "Error analyzing trip distances")
    
    if not distance_dist.empty:
        print("\nTrip distance distribution:")
        print(distance_dist)
        
        # Create column chart for distance distribution
        fig = px.bar(
            distance_dist,
            x="distance_range",
            y="percentage",
            title="Distribution of Trip Distances",
            labels={"distance_range": "Trip Distance (miles)", "percentage": "Percentage (%)"},
            color="distance_range",
            color_continuous_scale="Viridis",
            text="percentage"
        )
        fig.update_traces(texttemplate='%{text:.1f}%', textposition='outside')
        fig.update_layout(
            xaxis_title="Trip Distance Range (miles)", 
            yaxis_title="Percentage (%)",
            yaxis=dict(range=[0, max(distance_dist['percentage'])*1.2])  # Add some headroom for labels
        )
        fig.show()

# Close the database connection
conn.close()

print("\nVisualization Analysis Complete!")
print("======================================")

1. Connecting to DuckDB and Loading Data
-----------------------------------------
Available tables in the database:
                name
0      all_fhv_trips
1     all_taxi_trips
2          fhv_trips
3   green_taxi_trips
4        hvfhv_trips
5        nyc_weather
6   taxi_zone_lookup
7  yellow_taxi_trips

2. Improved Geographic Analysis Visualizations
-----------------------------------------

Top 20 pickup locations:
    PULocationID  pickup_count
0            132       5898868
1            237       5566271
2            161       5253823
3            236       4997465
4            162       4064701
5            186       3961770
6            230       3882503
7            142       3867582
8            138       3690412
9            170       3522290
10           163       3397725
11           239       3335144
12            48       3267726
13           234       3173747
14            68       3106080
15            79       2882707
16           141       2869329
17           164    


Top 20 dropoff locations:
    DOLocationID  dropoff_count
0            236        5258673
1            237        4978028
2            161        4488000
3            230        3667387
4            170        3518898
5            239        3338298
6            162        3331539
7            142        3324967
8            141        3165837
9             48        3028207
10            68        3014203
11           163        2927239
12           234        2768441
13           238        2737347
14           186        2646929
15           164        2523920
16           263        2518805
17           229        2499588
18            79        2469637
19           140        2428947


Error getting top location pairs: Parser Error: syntax error at or near "#"

3. Improved Payment Type Analysis
-----------------------------------------
Payment type distribution:
   payment_type     count  percentage
0           1.0  91808036   75.610070
1           2.0  20393274   16.795228
2           0.0   6768891    5.574635
3           4.0   1541190    1.269273
4           3.0    741288    0.610500
5           NaN    170281    0.140238
6           5.0        68    0.000056



4. Improved Rate Code Analysis
-----------------------------------------
Rate code distribution:
   RatecodeID      count  percentage            description
0         1.0  107727695   88.720976          Standard rate
1         NaN    6939172    5.714873          Unknown (nan)
2         2.0    4388898    3.614552                    JFK
3         5.0     951118    0.783309        Negotiated fare
4        99.0     814390    0.670705         Unknown (99.0)
5         3.0     368072    0.303132                 Newark
6         4.0     233174    0.192034  Nassau or Westchester
7         6.0        509    0.000419             Group ride



Fare analysis by rate code:
   RatecodeID      count   avg_fare  avg_distance  avg_fare_per_mile  \
0         4.0     208767  89.959870     18.550923           5.127690   
1         3.0     353283  79.965613     16.228820          29.588510   
2         2.0    4301868  63.940195     17.462328          45.991865   
3         5.0     903262  60.971895      4.603308         458.460433   
4        99.0     814144  33.629908      7.244079           8.448770   
5         NaN    6800224  21.070285      3.490498          17.033666   
6         1.0  106624782  15.312545      2.740673           7.557727   
7         6.0        460   4.776978      1.444609          34.416578   

             description  
0  Nassau or Westchester  
1                 Newark  
2                    JFK  
3        Negotiated fare  
4         Unknown (99.0)  
5          Unknown (nan)  
6          Standard rate  
7             Group ride  



5. Passenger Count Distribution
-----------------------------------------

Passenger count distribution:
   passenger_count     count  percentage
0              0.0   1764155    1.452900
1              1.0  86505723   71.243260
2              2.0  17073922   14.061519
3              3.0   4248100    3.498595
4              4.0   2324358    1.914265
5              5.0   1541183    1.269267
6              6.0   1024885    0.844061
7              7.0       536    0.000441
8              8.0       759    0.000625
9              9.0       235    0.000194



6. Trip Distance Distribution
-----------------------------------------

Trip distance distribution:
  distance_range     count  percentage
0            0-1  26797768   22.069758
1            1-2  39215058   32.296228
2            2-3  19828541   16.330132
3            3-5  14628383   12.047454
4           5-10  10843525    8.930369
5          10-20   8813809    7.258762
6            20+   1289822    1.062255



Visualization Analysis Complete!


In [50]:
import pandas as pd
import numpy as np
import duckdb
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
import os

# Suppress warnings
warnings.filterwarnings('ignore')

# Set styling
plt.style.use('fivethirtyeight')
pd.set_option('display.max_columns', None)

# ------------------------------------------------
# Connect to DuckDB and Load Data
# ------------------------------------------------
print("Connecting to DuckDB and Loading Data")
print("-----------------------------------------")

# Initialize DuckDB connection
conn = duckdb.connect('../db/nyc_taxi.duckdb')
print("Connected to database")

# Function to execute SQL queries
def execute_query(query, error_message="Error executing query"):
    """Execute SQL query and return results as DataFrame"""
    try:
        return conn.execute(query).fetchdf()
    except Exception as e:
        print(f"{error_message}: {str(e)}")
        return pd.DataFrame()

# ------------------------------------------------
# 1. Top Pickup Locations Visualization
# ------------------------------------------------
print("\nCreating top pickup locations visualization")

# Get top pickup locations
top_pickups = execute_query("""
    SELECT 
        PULocationID,
        COUNT(*) as pickup_count
    FROM all_taxi_trips
    GROUP BY PULocationID
    ORDER BY pickup_count DESC
    LIMIT 20
""", "Error getting top pickup locations")

if not top_pickups.empty:
    # Sort data for better visualization (highest count at top)
    top_pickups_sorted = top_pickups.sort_values('pickup_count', ascending=True)
    
    # Format large numbers for display
    top_pickups_sorted['formatted_count'] = top_pickups_sorted['pickup_count'].apply(
        lambda x: f"{x/1000000:.1f}M" if x >= 1000000 else f"{x/1000:.0f}K"
    )
    
    # Create horizontal bar chart
    fig = go.Figure()
    
    fig.add_trace(go.Bar(
        y=top_pickups_sorted['PULocationID'].astype(str),
        x=top_pickups_sorted['pickup_count'],
        orientation='h',
        marker=dict(
            color=top_pickups_sorted['pickup_count'],
            colorscale='Blues',
            colorbar=dict(
                title="Number of Pickups",
                tickvals=[min(top_pickups_sorted['pickup_count']), max(top_pickups_sorted['pickup_count'])],
                ticktext=[f"{min(top_pickups_sorted['pickup_count'])/1000000:.1f}M", 
                          f"{max(top_pickups_sorted['pickup_count'])/1000000:.1f}M"]
            )
        ),
        text=top_pickups_sorted['formatted_count'],
        textposition='outside',
        hovertemplate='Location ID: %{y}<br>Pickups: %{x:,.0f}<extra></extra>'
    ))
    
    fig.update_layout(
        title={
            'text': 'Top 20 Pickup Locations',
            'y':0.95,
            'x':0.5,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        xaxis_title='Number of Pickups',
        yaxis_title='Pickup Location ID',
        xaxis=dict(
            tickformat=",.0f"
        ),
        height=600,
        margin=dict(l=100, r=50, t=80, b=80),
        plot_bgcolor='white',
        paper_bgcolor='white',
        font=dict(
            family="Arial",
            size=14
        )
    )
    
    # Add grid for readability
    fig.update_yaxes(
        showgrid=True, 
        gridwidth=1, 
        gridcolor='lightgray'
    )
    
    fig.show()
else:
    print("No pickup location data available")

# ------------------------------------------------
# 2. Top Dropoff Locations Visualization
# ------------------------------------------------
print("\nCreating top dropoff locations visualization")

# Get top dropoff locations
top_dropoffs = execute_query("""
    SELECT 
        DOLocationID,
        COUNT(*) as dropoff_count
    FROM all_taxi_trips
    GROUP BY DOLocationID
    ORDER BY dropoff_count DESC
    LIMIT 20
""", "Error getting top dropoff locations")

if not top_dropoffs.empty:
    # Sort data for better visualization (highest count at top)
    top_dropoffs_sorted = top_dropoffs.sort_values('dropoff_count', ascending=True)
    
    # Format large numbers for display
    top_dropoffs_sorted['formatted_count'] = top_dropoffs_sorted['dropoff_count'].apply(
        lambda x: f"{x/1000000:.1f}M" if x >= 1000000 else f"{x/1000:.0f}K"
    )
    
    # Create horizontal bar chart
    fig = go.Figure()
    
    fig.add_trace(go.Bar(
        y=top_dropoffs_sorted['DOLocationID'].astype(str),
        x=top_dropoffs_sorted['dropoff_count'],
        orientation='h',
        marker=dict(
            color=top_dropoffs_sorted['dropoff_count'],
            colorscale='Reds',
            colorbar=dict(
                title="Number of Dropoffs",
                tickvals=[min(top_dropoffs_sorted['dropoff_count']), max(top_dropoffs_sorted['dropoff_count'])],
                ticktext=[f"{min(top_dropoffs_sorted['dropoff_count'])/1000000:.1f}M", 
                          f"{max(top_dropoffs_sorted['dropoff_count'])/1000000:.1f}M"]
            )
        ),
        text=top_dropoffs_sorted['formatted_count'],
        textposition='outside',
        hovertemplate='Location ID: %{y}<br>Dropoffs: %{x:,.0f}<extra></extra>'
    ))
    
    fig.update_layout(
        title={
            'text': 'Top 20 Dropoff Locations',
            'y':0.95,
            'x':0.5,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        xaxis_title='Number of Dropoffs',
        yaxis_title='Dropoff Location ID',
        xaxis=dict(
            tickformat=",.0f"
        ),
        height=600,
        margin=dict(l=100, r=50, t=80, b=80),
        plot_bgcolor='white',
        paper_bgcolor='white',
        font=dict(
            family="Arial",
            size=14
        )
    )
    
    # Add grid for readability
    fig.update_yaxes(
        showgrid=True, 
        gridwidth=1, 
        gridcolor='lightgray'
    )
    
    fig.show()
else:
    print("No dropoff location data available")

# ------------------------------------------------
# 3. Top Location Pairs Visualization
# ------------------------------------------------
print("\nCreating top location pairs visualization")

# Get top location pairs
top_location_pairs = execute_query("""
    SELECT 
        PULocationID,
        DOLocationID,
        COUNT(*) as trip_count
    FROM all_taxi_trips
    GROUP BY PULocationID, DOLocationID
    ORDER BY trip_count DESC
    LIMIT 15
""", "Error getting top location pairs")

if not top_location_pairs.empty:
    # Create labels for the pairs
    top_location_pairs['pair_label'] = top_location_pairs.apply(
        lambda row: f"{row['PULocationID']} → {row['DOLocationID']}", axis=1
    )
    
    # Sort for visualization
    top_location_pairs_sorted = top_location_pairs.sort_values('trip_count', ascending=True)
    
    # Format counts for display
    top_location_pairs_sorted['formatted_count'] = top_location_pairs_sorted['trip_count'].apply(
        lambda x: f"{x/1000000:.1f}M" if x >= 1000000 else f"{x/1000:.0f}K"
    )
    
    # Create horizontal bar chart
    fig = go.Figure()
    
    fig.add_trace(go.Bar(
        y=top_location_pairs_sorted['pair_label'],
        x=top_location_pairs_sorted['trip_count'],
        orientation='h',
        marker=dict(
            color=top_location_pairs_sorted['trip_count'],
            colorscale='Viridis',
            colorbar=dict(
                title="Number of Trips",
                tickvals=[min(top_location_pairs_sorted['trip_count']), 
                          max(top_location_pairs_sorted['trip_count'])],
                ticktext=[f"{min(top_location_pairs_sorted['trip_count'])/1000000:.1f}M" 
                         if min(top_location_pairs_sorted['trip_count']) >= 1000000 
                         else f"{min(top_location_pairs_sorted['trip_count'])/1000:.0f}K",
                         
                         f"{max(top_location_pairs_sorted['trip_count'])/1000000:.1f}M" 
                         if max(top_location_pairs_sorted['trip_count']) >= 1000000 
                         else f"{max(top_location_pairs_sorted['trip_count'])/1000:.0f}K"]
            )
        ),
        text=top_location_pairs_sorted['formatted_count'],
        textposition='outside',
        hovertemplate='Route: %{y}<br>Trips: %{x:,.0f}<extra></extra>'
    ))
    
    fig.update_layout(
        title={
            'text': 'Top 15 Pickup-Dropoff Location Pairs',
            'y':0.95,
            'x':0.5,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        xaxis_title='Number of Trips',
        yaxis_title='Pickup → Dropoff Location IDs',
        xaxis=dict(
            tickformat=",.0f"
        ),
        height=600,
        margin=dict(l=100, r=50, t=80, b=80),
        plot_bgcolor='white',
        paper_bgcolor='white',
        font=dict(
            family="Arial",
            size=14
        )
    )
    
    # Add grid for readability
    fig.update_yaxes(
        showgrid=True, 
        gridwidth=1, 
        gridcolor='lightgray'
    )
    
    fig.show()
    
    # Additionally, create a heatmap visualization
    print("\nCreating heatmap for top location pairs")
    
    # Get more pairs for the heatmap
    heatmap_pairs = execute_query("""
        SELECT 
            PULocationID,
            DOLocationID,
            COUNT(*) as trip_count
        FROM all_taxi_trips
        GROUP BY PULocationID, DOLocationID
        ORDER BY trip_count DESC
        LIMIT 100
    """, "Error getting location pairs for heatmap")
    
    if not heatmap_pairs.empty:
        # Create pivot table for heatmap
        # Filter to top 10 pickup and dropoff locations for readability
        top_pu = heatmap_pairs['PULocationID'].value_counts().nlargest(10).index.tolist()
        top_do = heatmap_pairs['DOLocationID'].value_counts().nlargest(10).index.tolist()
        
        heatmap_data = heatmap_pairs[
            heatmap_pairs['PULocationID'].isin(top_pu) & 
            heatmap_pairs['DOLocationID'].isin(top_do)
        ]
        
        # Create pivot table
        pivot_data = heatmap_data.pivot_table(
            index='PULocationID',
            columns='DOLocationID',
            values='trip_count',
            fill_value=0
        )
        
        # Create heatmap
        fig = px.imshow(
            pivot_data,
            labels=dict(x="Dropoff Location ID", y="Pickup Location ID", color="Trip Count"),
            title="Heatmap of Top Pickup-Dropoff Location Pairs",
            color_continuous_scale="Viridis"
        )
        
        fig.update_layout(
            width=800,
            height=700
        )
        
        fig.show()
else:
    print("No location pair data available")

# Close the database connection
conn.close()

print("\nVisualization creation complete!")

Connecting to DuckDB and Loading Data
-----------------------------------------
Connected to database

Creating top pickup locations visualization



Creating top dropoff locations visualization



Creating top location pairs visualization



Creating heatmap for top location pairs



Visualization creation complete!
