In [None]:
import glob
import os
import time
import warnings

import contextily as cx
import geopandas as gpd
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import clear_output
from shapely.geometry import Point

# One input location per RER line. Each ".csv/" entry is a directory that
# holds the line's schedule split across several "part-*.csv" files.
file_paths = [
    f"./data/combined_train_schedules_rer_{rer_line}.csv/"
    for rer_line in ("A", "B", "C", "D", "E")
]

# Matplotlib colour used to draw each RER line, keyed by the line letter.
rer_colors = dict(
    A='red',
    B='blue',
    C='goldenrod',
    D='green',
    E='purple',
)

# Glob pattern matching every part file inside each input directory.
file_patterns = [os.path.join(path, "part-*.csv") for path in file_paths]

# glob.glob returns matches in arbitrary, filesystem-dependent order. Sort the
# paths so rows are concatenated reproducibly: grouped by directory, then by
# part-file name. Downstream plotting draws each journey's points in row order,
# so a stable row order matters.
all_files = sorted(
    file for pattern in file_patterns for file in glob.glob(pattern)
)

if not all_files:
    # Without this check pd.concat([]) fails with the unhelpful
    # "No objects to concatenate".
    raise FileNotFoundError(
        "No part files found matching any of: " + ", ".join(file_patterns)
    )

# Read every part file and stack them into a single DataFrame.
dataframes = [pd.read_csv(file) for file in all_files]
train_data = pd.concat(dataframes, ignore_index=True)

# (RER line letter, vehicle journey id) of the single journey to animate on
# each line.
journey_ids = [
    ("A", "vehicle_journey:SNCF:2024-07-08:BABY17:1187:RapidTransit"),
    ("B", "vehicle_journey:SNCF:2024-07-08:EACE07:1187:RapidTransit"),
    ("C", "vehicle_journey:SNCF:2024-07-08:141186-141187:1187:RapidTransit"),
    ("D", "vehicle_journey:SNCF:2024-07-08:122302-122303:1187:RapidTransit"),
    ("E", "vehicle_journey:SNCF:2024-07-08:116192:1187:RapidTransit"),
]

# Line letter -> GeoDataFrame holding that journey's rows with point geometry
# built from the "longitude"/"latitude" columns.
journeys = {}

for line, journey_id in journey_ids:
    journey = train_data[train_data["journey_id"] == journey_id]
    if journey.empty:
        # Surface a wrong or stale journey id instead of silently producing an
        # empty journey that nothing can be drawn for.
        warnings.warn(f"No rows found for RER {line} journey {journey_id!r}")
    # points_from_xy builds all point geometries in one vectorised call.
    journeys[line] = gpd.GeoDataFrame(
        journey,
        geometry=gpd.points_from_xy(journey["longitude"], journey["latitude"]),
    )

# Function to plot the journey progress with the complete path, the current
# train position and (optionally) a background web-tile map.
def plot_journey_with_train_and_map(journeys, current_indices, *, add_basemap=False):
    """Draw one animation frame showing every journey and its train.

    Parameters
    ----------
    journeys : dict[str, GeoDataFrame]
        Rows of one journey per RER line, keyed by line letter (must be a key
        of ``rer_colors``). Each frame needs point geometry and a
        ``stop_point_name`` column. Rows are drawn in order, so they should
        already follow the journey. Empty journeys are skipped.
    current_indices : sequence of int
        Positional index of each train's current stop, in the same order as
        ``journeys``.
        - A negative index means "not started": only the route is drawn.
        - An index past the end is clamped to the last stop.
    add_basemap : bool, default False
        If True, draw a contextily web-tile basemap behind the journeys. This
        fetches tiles over the network.
    """
    fig, ax = plt.subplots(figsize=(20, 15))  # Increased figure size

    legend_items = []
    # Short station code drawn on the map (e.g. "A1") -> full stop name.
    station_info = {}

    for i, (line, gdf_journey) in enumerate(journeys.items()):
        if gdf_journey.empty:
            # Nothing to draw, and iloc-based lookups below would fail.
            continue

        # Clamp so an index past the end keeps the train at its last stop.
        current_index = min(current_indices[i], len(gdf_journey) - 1)

        color = rer_colors[line]
        light_color = mcolors.to_rgba(color, alpha=0.3)
        xs = gdf_journey.geometry.x
        ys = gdf_journey.geometry.y

        # Complete path: faded stop markers joined by a faded dashed line.
        gdf_journey.plot(ax=ax, color=light_color, linestyle='--', label=f"RER {line} Path")
        ax.plot(xs, ys, color=light_color, linestyle='--')

        if current_index >= 0:
            # Progress so far: solid markers and line up to the current stop.
            progress = gdf_journey.iloc[:current_index + 1]
            progress.plot(ax=ax, color=color, marker="o", label=f"RER {line} Progress")
            ax.plot(progress.geometry.x, progress.geometry.y, color=color)

            # Mark the current position of the train above everything else.
            current = gdf_journey.geometry.iloc[current_index]
            ax.scatter(current.x, current.y, color=color, s=100, zorder=5, edgecolor='black')

        # Short codes ("A1", "A2", ...) next to each stop; the full names are
        # listed beside the plot.
        for idx, (x, y, stop_name) in enumerate(zip(xs, ys, gdf_journey["stop_point_name"]), start=1):
            station_label = f"{line}{idx}"
            ax.text(x, y, station_label, fontsize=9, ha="right", color=color)
            station_info[station_label] = stop_name

        legend_items.append(plt.Line2D([0], [0], color=color, lw=2, label=f'RER {line}'))

    if add_basemap:
        # Added after plotting so the tiles cover the data extent. Passing crs
        # makes contextily reproject the tiles to the data's coordinates.
        # NOTE(review): assumes longitude/latitude are WGS84 degrees -- confirm.
        cx.add_basemap(ax, crs="EPSG:4326")

    ax.set_title("Train Journeys with Current Position and Île-de-France Map")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")

    # Add legend for RER lines, outside the axes on the right.
    ax.legend(handles=legend_items, loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.)

    # Add station information (code -> full name) beside the plot.
    station_text = "\n".join(f"{k}: {v}" for k, v in station_info.items())
    ax.text(1.05, 0.5, station_text, transform=ax.transAxes, fontsize=8, verticalalignment='center')

    ax.axis('off')
    fig.tight_layout()
    plt.show()
    # Release the figure; this runs once per animation frame, and unclosed
    # figures would otherwise accumulate.
    plt.close(fig)

# Simulate the journey progress: one frame per stop of the longest journey.
# Shorter journeys stay at their final stop once they have finished.
FRAME_DELAY_SECONDS = 1

max_length = max((len(journey) for journey in journeys.values()), default=0)
for i in range(max_length):
    # Clamp to each journey's last row with ``len - 1``. The form
    # ``min(i, len) - 1`` gives -1 on the first frame, which selects the last
    # stop. It also never reaches the final stop of the longest journey.
    current_indices = [min(i, len(journey) - 1) for journey in journeys.values()]
    clear_output(wait=True)
    plot_journey_with_train_and_map(journeys, current_indices)
    time.sleep(FRAME_DELAY_SECONDS)  # Simulate real-time delay
