# 🧪 Objectif :
# - Lancer une prédiction batch Vertex AI Forecast sur des données futures
# - Visualiser les résultats

# In [None]:
import os
from google.cloud import aiplatform
import pandas as pd

# === 1. Initialisation ===

# In [None]:
# Project-wide settings. Every resource URI below is derived from PROJECT_ID
# (and BQ_DATASET) so that switching project only requires editing one line.
PROJECT_ID = "avisia-certification-ml-yde"
REGION = "us-central1"
BQ_DATASET = "chicago_taxis"
BUCKET_URI = f"gs://{PROJECT_ID}-vertex-bucket"
# Table containing the future timestamps to forecast.
BQ_INPUT = f"bq://{PROJECT_ID}.{BQ_DATASET}.forecast_input"
# Destination handed to the batch prediction job.
BQ_OUTPUT = f"bq://{PROJECT_ID}.{BQ_DATASET}.forecast_output"

# Configure the SDK defaults (project, region, staging bucket) used by the
# aiplatform calls that follow.
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

# === 2. Chargement du modèle entraîné ===


# In [None]:
# Look up the models registered under the expected display name, most
# recently updated first, and keep the first one.
# NOTE(review): the ordering uses update_time, not create_time; confirm that
# the last *updated* model is really the "last trained" one announced below.
MODEL_DISPLAY_NAME = "taxi_demand_model"
matching_models = aiplatform.Model.list(
    filter=f'display_name="{MODEL_DISPLAY_NAME}"',
    order_by="update_time desc",
)
if not matching_models:
    # Fail with an actionable message instead of a bare IndexError on [0].
    raise LookupError(
        f"No Vertex AI model named {MODEL_DISPLAY_NAME!r} was found in "
        f"project {PROJECT_ID!r} (region {REGION!r})."
    )
model = matching_models[0]

print("✅ Dernier modèle entraîné :", model.display_name)

# === 3. Lancement du job de prédiction batch ===


# In [None]:
# Run a batch prediction with the loaded model: instances are read from the
# BigQuery table BQ_INPUT and predictions are written to BigQuery under the
# BQ_OUTPUT destination.
batch_job = model.batch_predict(
    job_display_name="batch_pred_taxi_demand",
    instances_format="bigquery",
    predictions_format="bigquery",
    bigquery_source=BQ_INPUT,
    bigquery_destination_prefix=BQ_OUTPUT,
    sync=True,  # run synchronously rather than as a background future
)

# Message repaired from its mis-encoded form (emoji and accents were garbled).
print("✅ Prédictions enregistrées dans :", BQ_OUTPUT)

# === 4. Analyse des résultats (optionnel si output vers BigQuery) ===


# In [None]:
from google.cloud import bigquery

client = bigquery.Client(project=PROJECT_ID)

# Resolve the table the job actually wrote to from the job's own output
# metadata. The previous hard-coded path had four components
# (project.dataset.forecast_output.predictions), which is not a valid
# BigQuery table reference.
# Per the Vertex AI API reference, bigquery_output_dataset has the form
# "bq://project.dataset" and bigquery_output_table is the bare table name.
output_info = batch_job.output_info
dataset_path = output_info.bigquery_output_dataset
if dataset_path.startswith("bq://"):
    dataset_path = dataset_path[len("bq://"):]
predictions_table = f"{dataset_path}.{output_info.bigquery_output_table}"

# NOTE(review): the selected columns (instance.*, prediction.value) are kept
# from the original query; confirm they match the output table's schema for
# this model type.
query = f"""
SELECT
  instance.pickup_community_area,
  instance.timestamp_hour,
  prediction.value[OFFSET(0)] AS predicted_trip_count
FROM `{predictions_table}`
ORDER BY timestamp_hour, pickup_community_area
"""

# Load the predictions into a DataFrame and preview the first rows.
df_pred = client.query(query).to_dataframe()
df_pred.head()

# === 5. Visualisation ===


# In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Keep the 5 zones with the highest total predicted demand. Ranking with
# value_counts() would order zones by how many forecast rows they have,
# not by how busy they are predicted to be.
top_zones = (
    df_pred.groupby("pickup_community_area")["predicted_trip_count"]
    .sum()
    .nlargest(5)
    .index
)
df_filtered = df_pred[df_pred["pickup_community_area"].isin(top_zones)]

# One line per zone: predicted trip count over time.
plt.figure(figsize=(16, 8))
sns.lineplot(
    data=df_filtered,
    x="timestamp_hour",
    y="predicted_trip_count",
    hue="pickup_community_area",
)
# Title and axis labels repaired from their mis-encoded form.
plt.title("Prévision du nombre de courses par heure (5 zones principales)")
plt.xlabel("Heure")
plt.ylabel("Courses prévues")
plt.xticks(rotation=45)  # tilt the tick labels so timestamps do not overlap
plt.tight_layout()
plt.show()
