In [None]:
from contextlib import closing

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pyathena import connect
from mpl_toolkits.mplot3d import Axes3D

# Apply seaborn's "whitegrid" theme to every figure created below.
# ``set_theme`` is seaborn's documented interface; ``sns.set`` is only a
# backward-compatibility alias for it.
sns.set_theme(style='whitegrid')

def fetch_data_from_athena(s3_staging_dir, region_name, database, query):
    """Run ``query`` on Athena through pyathena and return it as a DataFrame.

    Parameters
    ----------
    s3_staging_dir : str
        S3 location passed to pyathena as the query-result staging directory.
    region_name : str
        AWS region passed to pyathena.
    database : str
        Default schema for the Athena session (pyathena ``schema_name``);
        unqualified table names in ``query`` resolve against it.
    query : str
        SQL text, executed verbatim. Do not build it from untrusted input.

    Returns
    -------
    pandas.DataFrame
        The full query result as returned by ``pd.read_sql``.
    """
    # ``database`` used to be accepted but ignored; forward it as the
    # session's default schema. ``closing`` releases the connection even when
    # the query raises. ``read_sql`` has already materialized every row by the
    # time the connection is closed.
    with closing(connect(s3_staging_dir=s3_staging_dir,
                         region_name=region_name,
                         schema_name=database)) as conn:
        return pd.read_sql(query, conn)

# Athena settings
s3_staging_dir = 's3://athena-results-bucket/'
region_name = 'us-east-2'
database = 'vehicle_data_db'
table_name = 'vehicle_data_table'
# The identifiers are the fixed constants above, not user input, so
# interpolating them into the SQL text is acceptable here.
query = f'SELECT * FROM {database}.{table_name}'

data = fetch_data_from_athena(s3_staging_dir, region_name, database, query)

# Coerce every column the plots below use as a numeric axis, not just the time
# column. If Athena hands any of them back as strings, unparseable values become
# NaN instead of seaborn treating the whole column as categorical. For columns
# that are already numeric this is a no-op.
for column in ['timestep_time', 'vehicle_speed', 'vehicle_CO2', 'vehicle_fuel']:
    data[column] = pd.to_numeric(data[column], errors='coerce')

# Scatter plot: Vehicle Speed vs CO2 Emissions (one colour per car_id).
# NOTE(review): the axis units (m/s, g/s) come from the labels only; confirm
# them against the table schema.
fig, ax = plt.subplots(figsize=(12, 6))
sns.scatterplot(
    data=data,
    x='vehicle_speed',
    y='vehicle_CO2',
    hue='car_id',
    palette='viridis',
    s=50,
    ax=ax,
)
ax.set_title('Vehicle Speed vs CO2 Emissions')
ax.set_xlabel('Speed (m/s)')
ax.set_ylabel('CO2 Emissions (g/s)')
ax.legend(title='Car ID')
fig.savefig('speed_vs_CO2_emissions.png')
plt.show()


# Line plot: Fuel Consumption Over Time, one line (with point markers) per car_id.
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(
    data=data,
    x='timestep_time',
    y='vehicle_fuel',
    hue='car_id',
    marker='o',
    ax=ax,
)
ax.set_title('Vehicle Fuel Consumption Over Time')
ax.set_xlabel('Time Step')
ax.set_ylabel('Fuel Consumption (L/s)')
ax.legend(title='Car ID')
fig.savefig('fuel_consumption_over_time.png')
plt.show()


# Line plot: Vehicle Speed Over Time, one line (with point markers) per car_id.
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(
    data=data,
    x='timestep_time',
    y='vehicle_speed',
    hue='car_id',
    marker='o',
    ax=ax,
)
ax.set_title('Vehicle Speed Over Time')
ax.set_xlabel('Time Step')
ax.set_ylabel('Speed (m/s)')
ax.legend(title='Car ID')
fig.savefig('vehicle_speed_over_time.png')
plt.show()
