In [None]:
import os
import pickle

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from gnn_package import paths
from gnn_package import preprocessing

# Raw time-series pickle to analyse; the directory comes from the package's
# path configuration.
raw_file_name = "test_data_1wk.pkl"
raw_dir = paths.RAW_TIMESERIES_DIR
raw_file_path = os.path.join(raw_dir, raw_file_name)
print(f"Loading raw data from {raw_file_path}")

In [None]:
# Fail fast with a clear message: previously a missing file was skipped
# silently and the cleaning step below died with a NameError instead.
if not os.path.exists(raw_file_path):
    raise FileNotFoundError(f"Raw data file not found: {raw_file_path}")

# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only load files that come from a trusted source.
with open(raw_file_path, "rb") as f:
    results_containing_data = pickle.load(f)

# Clean the data: resample every sensor onto a 15-minute frequency, passing
# fill_value=-1 to the package helper (exact fill semantics are defined in
# gnn_package.preprocessing.resample_sensor_data).
results_containing_data_cleaned = preprocessing.resample_sensor_data(
    results_containing_data,
    freq="15min",
    fill_value=-1,
)

In [None]:
# Missing-value counts in the raw (pre-cleaning) series of two sensors
(
    results_containing_data["10000"].isnull().sum(),
    results_containing_data["10003"].isnull().sum(),
)

In [None]:
# example series with missing values
example_series = results_containing_data_cleaned["10029"]

example_series.plot()

In [None]:
# For every cleaned sensor series, tabulate how often each value occurs
# (NaN included). The series is flattened to a two-column frame first so the
# counted column is always called "value".
value_counts = {
    name: (
        series.reset_index()
        .set_axis(["dt", "value"], axis=1)["value"]
        .value_counts(dropna=False)
    )
    for name, series in results_containing_data_cleaned.items()
}

# One column per sensor, one row per distinct value
value_counts_df = pd.DataFrame(value_counts)
value_counts_df

In [None]:
def _longest_true_run(mask):
    """Return the length of the longest run of consecutive truthy values in ``mask``."""
    longest = 0
    current = 0
    for flag in mask:
        if flag:
            current += 1
            longest = max(longest, current)
        else:
            current = 0
    return longest


def generate_sensor_statistics(time_series_dict):
    """
    Generate comprehensive statistics for each sensor's time series data.

    Sensors whose entry is ``None`` or an empty series are ignored. Readings
    are assumed to be nominally 15 minutes apart: both the expected number of
    readings and the gap analysis use that interval.

    Parameters:
    -----------
    time_series_dict : dict
        Dictionary mapping sensor IDs to their time series data
        (pandas Series indexed by timestamp)

    Returns:
    --------
    pandas.DataFrame
        DataFrame containing summary statistics for each sensor, one row per
        sensor with data, sorted by ``raw_completeness_pct`` (descending).
        If no sensor has data, an empty DataFrame with the same columns is
        returned.
    """
    stat_columns = [
        "sensor_id",
        "total_readings",
        "unique_values",
        "raw_completeness_pct",
        "resampled_completeness_pct",
        "start_date",
        "end_date",
        "time_span_days",
        "min_value",
        "max_value",
        "mean_value",
        "median_value",
        "std_dev",
        "missing_15min_intervals",
        "largest_gap_hours",
        "readings_per_day",
        "busiest_hour",
        "busiest_hour_count",
        "busiest_day",
    ]
    day_name_by_number = dict(
        enumerate(
            [
                "Monday",
                "Tuesday",
                "Wednesday",
                "Thursday",
                "Friday",
                "Saturday",
                "Sunday",
            ]
        )
    )

    # Keep only sensors that actually carry data (insertion order preserved)
    populated = {
        sensor_id: series
        for sensor_id, series in time_series_dict.items()
        if series is not None and len(series) > 0
    }
    if not populated:
        return pd.DataFrame(columns=stat_columns)

    # Full date range of the dataset, taken from each index's own min/max
    # rather than materialising every timestamp in a Python list
    overall_start = min(series.index.min() for series in populated.values())
    overall_end = max(series.index.max() for series in populated.values())
    total_period_days = (overall_end - overall_start).total_seconds() / (60 * 60 * 24)

    # Expected readings for 15-minute intervals (both endpoints included)
    expected_readings = int(total_period_days * 24 * 4) + 1

    stats_data = []
    for sensor_id, series in populated.items():
        # Raw counts are taken BEFORE de-duplication, so duplicated timestamps
        # still count towards "raw" completeness
        total_readings = len(series)
        unique_readings = series.nunique()
        completeness = (total_readings / expected_readings) * 100

        # Sort and drop duplicated timestamps for everything that follows
        series = series.sort_index()
        series = series[~series.index.duplicated(keep="first")]

        # Time-based statistics
        start_date = series.index.min()
        end_date = series.index.max()
        time_span_days = (end_date - start_date).total_seconds() / (60 * 60 * 24)

        # Resample onto a regular 15-minute grid; empty bins become NaN
        resampled = series.resample("15min").mean()
        is_missing = resampled.isna()
        missing_intervals = is_missing.sum()
        actual_intervals = (~is_missing).sum()
        resampled_completeness = (actual_intervals / len(resampled)) * 100

        # Longest run of consecutive missing bins, converted to hours
        # (each bin is a quarter of an hour); stays integer 0 when gap-free
        longest_gap_bins = _longest_true_run(is_missing.to_numpy())
        largest_gap_hours = longest_gap_bins * 0.25 if longest_gap_bins else 0

        # Daily pattern analysis - hour of day with most readings
        hour_counts = series.groupby(series.index.hour).size()
        busiest_hour = hour_counts.idxmax() if not hour_counts.empty else None
        busiest_hour_count = hour_counts.max() if not hour_counts.empty else 0

        # Day of week pattern
        dow_counts = series.groupby(series.index.dayofweek).size()
        busiest_day = dow_counts.idxmax() if not dow_counts.empty else None
        busiest_day_name = day_name_by_number.get(busiest_day, "Unknown")

        stats_data.append(
            {
                "sensor_id": sensor_id,
                "total_readings": total_readings,
                "unique_values": unique_readings,
                "raw_completeness_pct": completeness,
                "resampled_completeness_pct": resampled_completeness,
                "start_date": start_date,
                "end_date": end_date,
                "time_span_days": time_span_days,
                "min_value": series.min(),
                "max_value": series.max(),
                "mean_value": series.mean(),
                "median_value": series.median(),
                "std_dev": series.std(),
                "missing_15min_intervals": missing_intervals,
                "largest_gap_hours": largest_gap_hours,
                "readings_per_day": (
                    total_readings / time_span_days if time_span_days > 0 else 0
                ),
                "busiest_hour": busiest_hour,
                "busiest_hour_count": busiest_hour_count,
                "busiest_day": busiest_day_name,
            }
        )

    stats_df = pd.DataFrame(stats_data, columns=stat_columns)

    # Most complete sensors first
    stats_df = stats_df.sort_values("raw_completeness_pct", ascending=False)

    # Round percentages and the remaining float columns to two decimals
    rounded_cols = [
        "raw_completeness_pct",
        "resampled_completeness_pct",
        "mean_value",
        "median_value",
        "std_dev",
        "readings_per_day",
        "largest_gap_hours",
        "time_span_days",
    ]
    stats_df[rounded_cols] = stats_df[rounded_cols].round(2)

    return stats_df

In [None]:
# Build the per-sensor summary table from the cleaned series and render it
stats_df = generate_sensor_statistics(time_series_dict=results_containing_data_cleaned)
display(stats_df)

In [None]:
def _readings_per_period(time_series_dict, period_of):
    """Count readings per calendar period (given by ``period_of(index)``) for each non-empty sensor."""
    counts = {
        sensor_id: series.groupby(period_of(series.index)).size()
        for sensor_id, series in time_series_dict.items()
        if series is not None and len(series) > 0
    }
    return pd.DataFrame(counts).fillna(0)


def analyze_time_series_data(time_series_dict):
    """
    Create a comprehensive analysis of the time series data with multiple dataframes

    Parameters:
    -----------
    time_series_dict : dict
        Dictionary mapping sensor IDs to their time series data

    Returns:
    --------
    dict
        Dictionary containing multiple DataFrames with different analyses:
        ``basic_stats``, ``summary``, ``hourly_distribution``,
        ``day_of_week_distribution``, ``day_of_month_distribution`` and
        ``missing_data_patterns``.

    Raises:
    -------
    ValueError
        If no sensor in ``time_series_dict`` has any data.
    """
    # Get basic statistics
    basic_stats = generate_sensor_statistics(time_series_dict)
    if basic_stats.empty:
        raise ValueError("No sensors with data to analyze")

    # basic_stats is sorted by completeness, so its index labels are no longer
    # positional: idxmax()/idxmin() return LABELS and must be resolved with
    # .loc (using .iloc here reported the wrong sensor).
    readings = basic_stats["total_readings"]
    max_sensor = basic_stats.loc[readings.idxmax(), "sensor_id"]
    min_sensor = basic_stats.loc[readings.idxmin(), "sensor_id"]

    first_reading = basic_stats["start_date"].min()
    last_reading = basic_stats["end_date"].max()

    # Create a summary of data completeness and quality
    summary_df = pd.DataFrame(
        {
            "Total Sensors": len(time_series_dict),
            "Sensors with Data": len(basic_stats),
            "Average Completeness (%)": basic_stats["raw_completeness_pct"].mean(),
            "Median Completeness (%)": basic_stats["raw_completeness_pct"].median(),
            "Date Range Start": first_reading,
            "Date Range End": last_reading,
            "Total Days": (last_reading - first_reading).days,
            "Average Readings per Sensor": readings.mean(),
            "Max Readings (Sensor)": f"{readings.max()} ({max_sensor})",
            "Min Readings (Sensor)": f"{readings.min()} ({min_sensor})",
        },
        index=["Summary"],
    )

    # Calculate data distribution by time of day (hourly)
    hourly_df = _readings_per_period(time_series_dict, lambda idx: idx.hour)
    hourly_df["Average"] = hourly_df.mean(axis=1)

    # Calculate data distribution by day of week. Reindex to all seven days
    # first so that datasets not covering a full week get zero rows instead of
    # a length-mismatch error when the day names are assigned.
    dow_df = _readings_per_period(time_series_dict, lambda idx: idx.dayofweek)
    dow_df = dow_df.reindex(range(7), fill_value=0)
    dow_df.index = [
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    ]
    dow_df["Average"] = dow_df.mean(axis=1)

    # Calculate data distribution by day of month
    dom_df = _readings_per_period(time_series_dict, lambda idx: idx.day)
    dom_df["Average"] = dom_df.mean(axis=1)

    # Calculate missing data patterns (consecutive missing values)
    missing_patterns = {}
    for sensor_id, series in time_series_dict.items():
        if series is None or len(series) == 0:
            continue

        # Resample to consistent 15-minute intervals; empty bins become NaN
        is_missing = series.resample("15min").mean().isna()

        # Collect the length of every run of consecutive missing bins
        runs = []
        run_length = 0
        for val in is_missing:
            if val:
                run_length += 1
            elif run_length > 0:
                runs.append(run_length)
                run_length = 0
        # A run that reaches the end of the series is still open here
        if run_length > 0:
            runs.append(run_length)

        if runs:
            missing_patterns[sensor_id] = {
                "count": len(runs),
                "mean_length": np.mean(runs) * 15,  # convert to minutes
                "max_length": np.max(runs) * 15,  # convert to minutes
                "median_length": np.median(runs) * 15,  # convert to minutes
            }
        else:
            missing_patterns[sensor_id] = {
                "count": 0,
                "mean_length": 0,
                "max_length": 0,
                "median_length": 0,
            }

    # One row per sensor; rename by key rather than by column position
    missing_df = pd.DataFrame(missing_patterns).T.rename(
        columns={
            "count": "Number of Gaps",
            "mean_length": "Mean Gap Length (min)",
            "max_length": "Max Gap Length (min)",
            "median_length": "Median Gap Length (min)",
        }
    )
    missing_df = missing_df.sort_values("Number of Gaps", ascending=False)

    return {
        "basic_stats": basic_stats,
        "summary": summary_df,
        "hourly_distribution": hourly_df,
        "day_of_week_distribution": dow_df,
        "day_of_month_distribution": dom_df,
        "missing_data_patterns": missing_df,
    }

In [None]:
# Run the multi-table analysis and render the three most informative tables
analysis_results = analyze_time_series_data(
    time_series_dict=results_containing_data_cleaned
)
display(
    analysis_results["summary"],
    analysis_results["basic_stats"],
    analysis_results["missing_data_patterns"],
)

In [None]:
def visualize_time_series_summary(time_series_dict, basic_stats):
    """
    Create visualizations for the time series data summary

    Draws one 2x2 figure: completeness per sensor, average number of readings
    by hour of day, average number of readings by day of week, and the
    distribution of readings per day across sensors.

    Parameters:
    -----------
    time_series_dict : dict
        Dictionary mapping sensor IDs to their time series data
    basic_stats : DataFrame
        DataFrame containing the basic statistics for each sensor; must have
        the columns ``sensor_id`` and ``raw_completeness_pct``
        (``readings_per_day`` is used when present)

    Returns:
    --------
    None
    """
    # Set up the plotting style (note: this changes the global matplotlib style)
    plt.style.use("seaborn-v0_8-whitegrid")

    # Create a figure with 2x2 subplots
    fig, axs = plt.subplots(2, 2, figsize=(16, 12))

    # Only sensors that actually carry data take part in the count plots
    populated = {
        sensor_id: series
        for sensor_id, series in time_series_dict.items()
        if series is not None and len(series) > 0
    }

    # 1. Data completeness by sensor
    completeness = basic_stats[["sensor_id", "raw_completeness_pct"]].set_index(
        "sensor_id"
    )
    completeness = completeness.sort_values("raw_completeness_pct", ascending=True)
    sns.barplot(
        x=completeness["raw_completeness_pct"],
        y=completeness.index,
        ax=axs[0, 0],
        palette="viridis",
    )
    axs[0, 0].set_title("Data Completeness by Sensor (%)", fontsize=12)
    axs[0, 0].set_xlabel("Completeness (%)")
    axs[0, 0].set_ylabel("Sensor ID")

    # 2. Average readings by hour of day
    hourly_df = pd.DataFrame(
        {
            sensor_id: series.groupby(series.index.hour).size()
            for sensor_id, series in populated.items()
        }
    ).fillna(0)
    hourly_avg = hourly_df.mean(axis=1)

    sns.lineplot(x=hourly_avg.index, y=hourly_avg.values, marker="o", ax=axs[0, 1])
    axs[0, 1].set_title("Average Readings by Hour of Day", fontsize=12)
    axs[0, 1].set_xlabel("Hour of Day")
    axs[0, 1].set_ylabel("Average Number of Readings")
    axs[0, 1].set_xticks(range(0, 24, 2))

    # 3. Average readings by day of week
    dow_df = pd.DataFrame(
        {
            sensor_id: series.groupby(series.index.dayofweek).size()
            for sensor_id, series in populated.items()
        }
    ).fillna(0)
    # Align to all seven weekdays so the values always match the seven labels,
    # even when the data cover less than a full week
    dow_avg = dow_df.mean(axis=1).reindex(range(7), fill_value=0)
    days = [
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    ]

    sns.barplot(x=days, y=dow_avg.values, ax=axs[1, 0], palette="viridis")
    axs[1, 0].set_title("Average Readings by Day of Week", fontsize=12)
    axs[1, 0].set_xlabel("Day of Week")
    axs[1, 0].set_ylabel("Average Number of Readings")
    plt.setp(axs[1, 0].xaxis.get_majorticklabels(), rotation=45)

    # 4. Distribution of readings per day
    if "readings_per_day" in basic_stats.columns:
        sns.histplot(basic_stats["readings_per_day"], kde=True, ax=axs[1, 1])
        axs[1, 1].set_title("Distribution of Readings per Day", fontsize=12)
        axs[1, 1].set_xlabel("Average Readings per Day")
        axs[1, 1].set_ylabel("Number of Sensors")

    # Lay out and show once; a second tight_layout()/show() pair would act on
    # a new, empty figure
    plt.tight_layout()
    plt.show()

    return None

In [None]:
def analyze_sensor_data(time_series_dict, visualize=True):
    """
    Run the full sensor analysis, print a short text summary and optionally plot it.

    Parameters:
    -----------
    time_series_dict : dict
        Dictionary mapping sensor IDs to their time series data
    visualize : bool
        When True, also draw the summary figure via
        ``visualize_time_series_summary``

    Returns:
    --------
    dict
        The dictionary of analysis DataFrames produced by
        ``analyze_time_series_data``
    """
    report = analyze_time_series_data(time_series_dict)

    # The summary table has a single row; pull it out once for the printout
    overview = report["summary"].iloc[0]

    for line in (
        "==== Traffic Sensor Data Analysis ====",
        f"Total Sensors: {len(time_series_dict)}",
        f"Period: {overview['Date Range Start']} to {overview['Date Range End']}",
        f"Average Completeness: {overview['Average Completeness (%)']:.2f}%",
        f"Average Readings per Sensor: {overview['Average Readings per Sensor']:.0f}",
        "====================================",
    ):
        print(line)

    # Plotting is opt-out
    if visualize:
        visualize_time_series_summary(time_series_dict, report["basic_stats"])

    return report

In [None]:
# Run the full analysis (prints the text summary and draws the overview figure)
analysis = analyze_sensor_data(results_containing_data_cleaned)

# Show the one-row summary table. The other result tables stay available under
# the keys 'basic_stats', 'missing_data_patterns', 'hourly_distribution' and
# 'day_of_week_distribution'.
display(analysis["summary"])

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.colors import ListedColormap


def _select_window_starts(valid_flags, window_size, stride, num_windows):
    """Pick up to ``num_windows`` window start positions, preferring fully valid windows on the stride grid."""
    if num_windows <= 0:
        return []

    starts = []
    run_length = 0
    for position, is_valid in enumerate(valid_flags):
        run_length = run_length + 1 if is_valid else 0
        if run_length >= window_size:
            # The last ``window_size`` samples are all valid
            start = position - window_size + 1
            if start % stride == 0:
                starts.append(start)
                if len(starts) == num_windows:
                    return starts

    # Not enough fully valid windows: top up with non-overlapping windows
    # that may contain missing data
    for start in range(0, len(valid_flags) - window_size + 1, window_size):
        if len(starts) == num_windows:
            break
        if start not in starts:
            starts.append(start)
    return starts


def _draw_missing_data_flow_diagram():
    """Draw the static flow chart explaining how missing data reach the model as values plus masks."""
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.axis("off")

    def box(xy, width, color, label_x, title_y, title, detail_y, detail):
        # Coloured box with a bold title and a smaller explanatory line
        ax.add_patch(
            plt.Rectangle(xy, width, 0.1, fill=True, color=color, alpha=0.5)
        )
        ax.text(label_x, title_y, title, ha="center", va="center", fontweight="bold")
        ax.text(label_x, detail_y, detail, ha="center", va="center", fontsize=9)

    def arrow(x, y, dx, dy):
        ax.arrow(
            x, y, dx, dy, head_width=0.02, head_length=0.02, fc="black", ec="black"
        )

    # Main title
    ax.text(
        0.5,
        0.95,
        "How STGNN Model Handles Missing Data",
        fontsize=16,
        ha="center",
        va="center",
        fontweight="bold",
    )

    # Top row: raw series -> preprocessing -> window creation
    box(
        (0.05, 0.75), 0.2, "lightblue",
        0.15, 0.8, "Raw Time Series",
        0.77, "Contains gaps and missing values",
    )
    arrow(0.25, 0.8, 0.1, 0)
    box(
        (0.35, 0.75), 0.2, "lightgreen",
        0.45, 0.8, "Data Preprocessing",
        0.77, "Find continuous segments",
    )
    arrow(0.55, 0.8, 0.1, 0)
    box(
        (0.65, 0.75), 0.2, "lightyellow",
        0.75, 0.8, "Window Creation",
        0.77, "Create fixed-size windows",
    )
    arrow(0.75, 0.75, 0, -0.1)

    # Second row: the two parallel products of windowing
    box(
        (0.55, 0.55), 0.2, "lightcoral",
        0.65, 0.6, "Window Values",
        0.57, "Missing values set to -1",
    )
    box(
        (0.75, 0.55), 0.2, "lightblue",
        0.85, 0.6, "Mask Creation",
        0.57, "1 = valid, 0 = missing",
    )
    arrow(0.65, 0.55, 0, -0.1)
    arrow(0.85, 0.55, 0, -0.1)

    # Third row: model input tensors
    box(
        (0.35, 0.35), 0.6, "lightsalmon",
        0.65, 0.4, "Model Input Tensors",
        0.37,
        "x: [batch_size, num_nodes, seq_len, 1]    x_mask: [batch_size, num_nodes, seq_len, 1]",
    )
    arrow(0.65, 0.35, 0, -0.1)

    # Fourth row: model processing
    box(
        (0.35, 0.15), 0.6, "lightsteelblue",
        0.65, 0.2, "STGNN Model Processing",
        0.17, "Only applies operations on valid data (where mask = 1)",
    )

    # Mask handling explanation (key points)
    ax.add_patch(
        plt.Rectangle((0.1, 0.05), 0.8, 0.08, fill=True, color="#e6f7ff", alpha=0.5)
    )
    ax.text(
        0.5,
        0.09,
        "Key Concepts for Mask Handling in STGNN",
        ha="center",
        va="center",
        fontweight="bold",
        fontsize=10,
    )
    ax.text(
        0.5,
        0.06,
        "1. Graph convolutions compute features only for valid nodes (mask=1)\n"
        + "2. Missing values remain masked through each layer\n"
        + "3. Loss function weighted by mask: only valid predictions contribute to loss",
        ha="center",
        va="center",
        fontsize=9,
    )

    plt.show()


def visualize_missing_data_masks(
    time_series_dict, window_size=12, stride=1, num_sensors=3, num_windows=4
):
    """
    Visualize how missing data is represented by masks and interpreted by the model.

    Two figures are shown. The first has one row per sensor: the resampled
    15-minute series with missing bins marked and the selected windows
    highlighted (left), and the per-window validity mask with the values
    overlaid (right). The second is a static flow chart of the masking
    pipeline.

    Parameters:
    -----------
    time_series_dict : dict
        Dictionary mapping sensor IDs to time series data. Entries that are
        ``None`` or empty are skipped.
    window_size : int
        Size of sliding windows (number of 15-minute steps)
    stride : int
        Step size for sliding windows; fully valid windows are only taken
        from start positions that are multiples of ``stride``
    num_sensors : int
        Number of sensors to visualize (the first ones with data)
    num_windows : int
        Number of windows to show per sensor

    Raises:
    -------
    ValueError
        If ``window_size`` or ``stride`` is not positive, or if no sensor has
        any data.
    """
    if window_size < 1 or stride < 1:
        raise ValueError("window_size and stride must be positive integers")

    # Take the first few sensors that actually have data
    sensor_ids = [
        sensor_id
        for sensor_id, series in time_series_dict.items()
        if series is not None and len(series) > 0
    ][:num_sensors]
    if not sensor_ids:
        raise ValueError("time_series_dict contains no non-empty series to visualize")

    fig = plt.figure(figsize=(18, 4 * len(sensor_ids)))
    fig.suptitle("Visualization of Missing Data Masks", fontsize=16)

    # Set up a 2D grid: sensors x (raw data, windows)
    gs = fig.add_gridspec(len(sensor_ids), 2, width_ratios=[2, 3])

    highlight_colors = ["green", "orange", "purple", "brown"]
    mask_cmap = ListedColormap(["red", "green"])  # 0 -> red, 1 -> green

    for row, sensor_id in enumerate(sensor_ids):
        # Sort, drop duplicated timestamps, then put the series on a regular
        # 15-minute grid (empty bins become NaN)
        series = time_series_dict[sensor_id].sort_index()
        series = series[~series.index.duplicated(keep="first")]
        resampled = series.resample("15min").mean()

        valid_mask = ~pd.isna(resampled)
        missing_mask = ~valid_mask
        x_dates = resampled.index

        # 1. Raw time series with missing values
        ax1 = fig.add_subplot(gs[row, 0])
        ax1.plot(x_dates[valid_mask], resampled[valid_mask], "b-", label="Valid Data")

        # Mark missing bins with a red X at y=0 for visibility
        if missing_mask.any():
            ax1.scatter(
                x_dates[missing_mask],
                np.zeros(int(missing_mask.sum())),
                marker="x",
                color="red",
                s=50,
                label="Missing Data",
            )

        window_starts = _select_window_starts(
            valid_mask.to_numpy(), window_size, stride, num_windows
        )

        # Highlight window regions on raw data plot
        for w, start_idx in enumerate(window_starts):
            end_idx = start_idx + window_size
            color = highlight_colors[w % len(highlight_colors)]

            start_date = x_dates[start_idx]
            # end_idx is exclusive; clamp so the last window stays in range
            end_date = x_dates[min(end_idx, len(x_dates) - 1)]

            # Only the first span gets a legend entry
            ax1.axvspan(
                start_date,
                end_date,
                alpha=0.2,
                color=color,
                label=f"Window {w+1}" if w == 0 else "",
            )
            ax1.text(
                start_date + (end_date - start_date) / 2,
                ax1.get_ylim()[1] * 0.9,
                f"W{w+1}",
                horizontalalignment="center",
                color=color,
            )

        ax1.set_title(f"Sensor {sensor_id} Raw Time Series")
        ax1.set_xlabel("Date")
        ax1.set_ylabel("Value")
        ax1.legend(loc="upper right")
        ax1.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d"))
        plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45)

        # 2. Windowed data and masks
        ax2 = fig.add_subplot(gs[row, 1])
        ax2.set_title(f"Sensor {sensor_id} Windowed Data with Masks")

        if not window_starts:
            # Nothing to draw (e.g. the series is shorter than one window)
            ax2.text(
                0.5,
                0.5,
                "No windows available",
                ha="center",
                va="center",
                transform=ax2.transAxes,
            )
            ax2.axis("off")
            continue

        # One row per window: values (-1 where missing) and the 0/1 mask
        window_matrix = np.zeros((len(window_starts), window_size))
        mask_matrix = np.zeros((len(window_starts), window_size))
        for w, start_idx in enumerate(window_starts):
            window_data = resampled.iloc[start_idx : start_idx + window_size].values
            window_mask = (~pd.isna(window_data)).astype(int)
            window_matrix[w, :] = np.where(
                window_mask == 1, window_data, -1
            )  # -1 for missing in the model
            mask_matrix[w, :] = window_mask

        im = ax2.imshow(
            mask_matrix,
            aspect="auto",
            cmap=mask_cmap,
            interpolation="none",
            vmin=0,
            vmax=1,
        )

        # Overlay the value on valid cells and an X on missing cells
        for w in range(mask_matrix.shape[0]):
            for k in range(mask_matrix.shape[1]):
                if mask_matrix[w, k] == 1:
                    ax2.text(
                        k,
                        w,
                        f"{window_matrix[w, k]:.1f}",
                        ha="center",
                        va="center",
                        color="black",
                        fontsize=8,
                    )
                else:
                    ax2.text(
                        k,
                        w,
                        "X",
                        ha="center",
                        va="center",
                        color="white",
                        fontsize=10,
                        fontweight="bold",
                    )

        # Add colorbar and labels
        cbar = plt.colorbar(im, ax=ax2, ticks=[0, 1])
        cbar.set_ticklabels(["Missing (0)", "Valid (1)"])

        ax2.set_xlabel("Time Step in Window")
        ax2.set_ylabel("Window Index")
        ax2.set_yticks(range(len(window_starts)))
        ax2.set_yticklabels([f"Window {w+1}" for w in range(len(window_starts))])
        ax2.set_xticks(range(window_size))

    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    plt.show()

    # Second figure: diagram explaining how the model handles missing values
    _draw_missing_data_flow_diagram()


# Example usage with the default window, stride and sensor-count settings
visualize_missing_data_masks(time_series_dict=results_containing_data_cleaned)

In [None]:
def visualize_mask_propagation(time_series_dict, window_size=12, missing_value=-1):
    """
    Visualize how masks propagate through the STGNN model layers.

    One window of one sensor is drawn in four stacked panels: the input
    values, the binary validity mask, a simulated graph-convolution output
    and a simulated per-step error coloured by whether the step would count
    towards a masked loss. Panels 3 and 4 are toy simulations computed inside
    this function; no real model is evaluated.

    Parameters:
    -----------
    time_series_dict : dict
        Dictionary mapping sensor IDs to time series data
    window_size : int
        Size of sliding windows
    missing_value : float
        Sentinel treated as "missing" in addition to NaN. Defaults to -1,
        the fill_value passed to preprocessing.resample_sensor_data when the
        cleaned data is built (presumably gaps are stored as that value;
        TODO confirm against the preprocessing module).

    Returns:
    --------
    None
        The figure is shown with plt.show(); nothing is returned. If no
        sensor has any data, a message is printed and no figure is drawn.
    """
    # Imported locally so the function does not depend on another cell having
    # brought these names into the kernel namespace.
    from matplotlib.colors import ListedColormap
    from matplotlib.patches import Patch

    def _regular_series(raw):
        """Sort, drop duplicate timestamps and resample onto a 15-minute grid."""
        ordered = raw.sort_index()
        ordered = ordered[~ordered.index.duplicated(keep="first")]
        return ordered.resample("15min").mean()

    def _is_missing(values):
        """Flag entries that are NaN or equal to the missing-value sentinel."""
        return values.isna() | (values == missing_value)

    # Only series that actually hold observations can be plotted.
    usable = {
        sid: series
        for sid, series in time_series_dict.items()
        if series is not None and len(series) > 0
    }
    if not usable:
        print("No sensor with data found. Nothing to visualize.")
        return

    # Choose a sensor with some missing data
    sensor_id = None
    resampled = None
    for sid, series in usable.items():
        candidate = _regular_series(series)
        if _is_missing(candidate).any():
            sensor_id, resampled = sid, candidate
            break

    if sensor_id is None:
        print("No sensor with missing values found. Using the first sensor.")
        # If no sensor has missing values, just use the first usable one
        sensor_id = next(iter(usable))
        resampled = _regular_series(usable[sensor_id])

    missing_flags = _is_missing(resampled)

    # Find a window with some (but fewer than half) missing values;
    # fall back to the first window when none qualifies.
    window_start = 0
    for start in range(len(resampled) - window_size + 1):
        missing_count = missing_flags.iloc[start : start + window_size].sum()
        if 0 < missing_count < window_size // 2:
            window_start = start
            break

    # Extract the window. It can be shorter than window_size when the series
    # itself is shorter, so everything below uses the actual length.
    window_slice = slice(window_start, window_start + window_size)
    window = resampled.iloc[window_slice]
    n_steps = len(window)
    steps = range(n_steps)

    # Create mask (1 for valid, 0 for missing)
    window_mask = (~missing_flags.iloc[window_slice]).astype(int).values

    # Represent every missing entry (NaN or sentinel) by the sentinel value
    window_values = window.to_numpy(dtype=float, copy=True)
    window_values[window_mask == 0] = missing_value

    # Create figure
    fig, axs = plt.subplots(4, 1, figsize=(15, 12))
    fig.suptitle(
        f"Mask Propagation Through STGNN Layers - Sensor {sensor_id}", fontsize=16
    )

    # 1. Original window with missing values
    axs[0].plot(steps, window_values, "o-", label="Original Values")
    axs[0].set_title("1. Input Window with Missing Values")
    axs[0].set_xlabel("Time Step")
    axs[0].set_ylabel("Value")

    # Mark missing values
    for i in steps:
        if window_mask[i] == 0:
            axs[0].scatter(
                i, window_values[i], color="red", marker="x", s=100, label="_nolegend_"
            )
            axs[0].annotate(
                "Missing",
                (i, window_values[i]),
                xytext=(0, 10),
                textcoords="offset points",
                ha="center",
                fontsize=8,
                color="red",
            )

    axs[0].set_xticks(steps)
    axs[0].set_ylim([window_values.min() - 1, window_values.max() + 1])

    # 2. Mask visualization
    mask_cmap = ListedColormap(["red", "green"])
    im = axs[1].imshow(
        window_mask.reshape(1, -1),
        aspect="auto",
        cmap=mask_cmap,
        interpolation="none",
        vmin=0,
        vmax=1,
    )

    axs[1].set_title("2. Input Mask (0=Missing, 1=Valid)")
    axs[1].set_xlabel("Time Step")
    axs[1].set_yticks([])
    axs[1].set_xticks(steps)

    # Overlay mask values
    for i in steps:
        axs[1].text(
            i,
            0,
            str(window_mask[i]),
            ha="center",
            va="center",
            color="white" if window_mask[i] == 0 else "black",
            fontweight="bold",
        )

    # Add colorbar; ticks sit in the middle of each of the two colour bands
    cbar = plt.colorbar(im, ax=axs[1], orientation="horizontal")
    cbar.set_ticks([0.25, 0.75])
    cbar.set_ticklabels(["Missing (0)", "Valid (1)"])

    # 3. Simulate GCN layer processing
    axs[2].set_title("3. Graph Convolution: Features computed only for valid nodes")

    # Toy stand-in for a graph convolution: valid points get an affine
    # transform, missing points keep the sentinel. A real layer would also
    # aggregate over neighbouring nodes.
    gcn_values = np.where(window_mask == 1, 0.8 * window_values + 0.5, window_values)

    # Plot transformed values
    axs[2].plot(steps, gcn_values, "o-", color="purple", label="GCN Output")

    # Mark valid vs. missing values
    for i in steps:
        if window_mask[i] == 0:
            axs[2].scatter(i, gcn_values[i], color="red", marker="x", s=100)
            axs[2].annotate(
                "Still Missing",
                (i, gcn_values[i]),
                xytext=(0, 10),
                textcoords="offset points",
                ha="center",
                fontsize=8,
                color="red",
            )
        else:
            axs[2].scatter(i, gcn_values[i], color="green", marker="o", s=80)
            axs[2].annotate(
                "Transformed",
                (i, gcn_values[i]),
                xytext=(0, -15),
                textcoords="offset points",
                ha="center",
                fontsize=8,
                color="green",
            )

    axs[2].set_xlabel("Time Step")
    axs[2].set_ylabel("Value")
    axs[2].set_xticks(steps)
    axs[2].set_ylim([gcn_values.min() - 1, gcn_values.max() + 1])

    # 4. Loss calculation
    axs[3].set_title("4. Loss Calculation: Only valid predictions contribute to loss")

    # Simulate target values and predictions. The noise comes from NumPy's
    # global random state, so call np.random.seed(...) beforehand if the
    # figure needs to be reproducible.
    target_values = window_values + 1  # Arbitrary offset
    pred_values = gcn_values + np.random.normal(0, 0.5, n_steps)  # Add noise

    # Calculate error for each point
    errors = np.abs(target_values - pred_values)

    # Create a bar chart for errors
    bar_colors = ["green" if m == 1 else "lightgray" for m in window_mask]
    axs[3].bar(steps, errors, color=bar_colors)

    # Add labels explaining masked vs. unmasked loss
    for i in steps:
        contributes = window_mask[i] == 1
        axs[3].text(
            i,
            errors[i] + 0.1,
            "Contributes\nto Loss" if contributes else "Ignored\nin Loss",
            ha="center",
            va="bottom",
            fontsize=8,
            color="green" if contributes else "red",
        )

    axs[3].set_xlabel("Time Step")
    axs[3].set_ylabel("Error")
    axs[3].set_xticks(steps)

    # Add a legend to explain color scheme
    legend_elements = [
        Patch(facecolor="green", label="Valid Data Points"),
        Patch(facecolor="lightgray", label="Missing Data Points"),
    ]
    axs[3].legend(handles=legend_elements, loc="upper right")

    plt.tight_layout()
    plt.subplots_adjust(top=0.92)
    plt.show()


# Example usage: draw the four-panel mask-propagation figure for the cleaned
# sensor data, using the default window_size of 12.
visualize_mask_propagation(results_containing_data_cleaned)

In [None]:
def debug_dataloader_creation(time_series_dict, window_size=12, stride=1):
    """
    Debug the process of creating windows to see if there's an issue with -1 values.

    The function (1) measures how complete each sensor's series is on a
    15-minute grid, (2) builds sliding windows with TimeSeriesPreprocessor,
    (3) measures how many window entries equal the missing-value sentinel and
    whether the returned masks agree with those positions, (4) prints a
    summary for the first five sensors plus the first three windows of one
    sensor, and (5) plots a few windows via visualize_windows_and_masks.

    Parameters:
    -----------
    time_series_dict : dict
        Dictionary mapping sensor IDs to time series data
    window_size : int
        Size of sliding windows
    stride : int
        Step size for sliding windows

    Returns:
    --------
    dict
        Keys "original_stats", "window_stats", "X_by_sensor" and
        "masks_by_sensor".
    """
    from gnn_package.src.preprocessing import TimeSeriesPreprocessor

    # Single source of truth for the sentinel: the same value is handed to the
    # preprocessor and used for every "is this entry missing?" comparison below.
    missing_value = -1.0
    resample_freq = "15min"
    # Allowed increase (percentage points) in missing share before warning
    tolerance_pct = 5
    preview_count = 5

    print(
        f"Debugging window creation process with window_size={window_size}, stride={stride}"
    )

    # Create preprocessor with same parameters as in your training code
    processor = TimeSeriesPreprocessor(
        window_size=window_size,
        stride=stride,
        gap_threshold=pd.Timedelta(minutes=15),
        missing_value=missing_value,
    )

    # First, analyze original data completeness
    original_stats = {}
    for sensor_id, series in time_series_dict.items():
        if series is None or len(series) == 0:
            continue

        # Clean up duplicate timestamps, then put the data on a regular grid
        deduped = series[~series.index.duplicated(keep="first")]
        regular = deduped.resample(resample_freq).mean()

        # Count NaN *and* sentinel entries as missing. The window statistics
        # below count sentinel entries, so input that already stores gaps as
        # the sentinel must be counted the same way here; otherwise the
        # comparison reports a spurious increase in missing values.
        total_values = len(regular)
        missing_values = (regular.isna() | (regular == missing_value)).sum()
        pct_missing = (missing_values / total_values) * 100 if total_values > 0 else 0

        original_stats[sensor_id] = {
            "total_values": total_values,
            "missing_values": missing_values,
            "pct_complete": 100 - pct_missing,
        }

    print(f"\nOriginal data completeness stats for {len(original_stats)} sensors:")
    for sensor_id, stats in list(original_stats.items())[:preview_count]:
        print(
            f"  Sensor {sensor_id}: {stats['pct_complete']:.2f}% complete ({stats['missing_values']} missing out of {stats['total_values']} values)"
        )

    # Now create windows using the same processor logic
    print("\nCreating windows using TimeSeriesPreprocessor...")
    X_by_sensor, masks_by_sensor, _metadata_by_sensor = processor.create_windows(
        time_series_dict, standardize=True
    )

    # Analyze window completeness
    window_stats = {}
    for sensor_id, windows in X_by_sensor.items():
        missing_mask = windows == missing_value
        sensor_stats = {
            "total_windows": len(windows),
            "windows_with_missing": np.any(missing_mask, axis=1).sum(),
            "total_values": windows.size,
            "missing_values": missing_mask.sum(),
        }

        # Check if masks correctly identify missing values
        if sensor_id in masks_by_sensor:
            sensor_masks = masks_by_sensor[sensor_id]
            sensor_stats["mask_correct"] = ((sensor_masks == 0) == missing_mask).all()

        window_stats[sensor_id] = sensor_stats

    # Print summary stats on windows
    print(f"\nWindow creation results for {len(window_stats)} sensors:")
    for sensor_id, stats in list(window_stats.items())[:preview_count]:
        pct_windows_missing = (
            (stats["windows_with_missing"] / stats["total_windows"] * 100)
            if stats["total_windows"] > 0
            else 0
        )
        pct_values_missing = (
            (stats["missing_values"] / stats["total_values"] * 100)
            if stats["total_values"] > 0
            else 0
        )

        print(f"  Sensor {sensor_id}:")
        print(
            f"    Windows: {stats['total_windows']} total, {stats['windows_with_missing']} with missing values ({pct_windows_missing:.2f}%)"
        )
        print(
            f"    Values: {stats['missing_values']} missing out of {stats['total_values']} ({pct_values_missing:.2f}%)"
        )

        # Surface a mask/sentinel disagreement instead of only storing it
        if "mask_correct" in stats and not stats["mask_correct"]:
            print("    WARNING: Mask does not match the missing-value positions!")

        # Compare with original completeness
        if sensor_id in original_stats:
            orig_pct_missing = 100 - original_stats[sensor_id]["pct_complete"]
            print(
                f"    Original missing: {orig_pct_missing:.2f}%, Window missing: {pct_values_missing:.2f}%"
            )

            if pct_values_missing > orig_pct_missing + tolerance_pct:
                print(
                    "    WARNING: Missing values increased significantly after windowing!"
                )

    # Examine a specific window for one sensor to see the pattern
    print("\nExamining specific windows for a sample sensor:")
    if not X_by_sensor:
        # Nothing to index into; avoid an IndexError on the empty key list
        print("  No windows were created for any sensor.")
    else:
        example_sensor_id = next(iter(X_by_sensor))
        example_windows = X_by_sensor[example_sensor_id]
        example_masks = masks_by_sensor[example_sensor_id]

        print(f"Sensor {example_sensor_id} window shape: {example_windows.shape}")

        # Look at first 3 windows
        for i in range(min(3, len(example_windows))):
            window = example_windows[i]
            mask = example_masks[i]

            missing_count = (window == missing_value).sum()
            print(
                f"\n  Window {i}: {missing_count} missing values out of {len(window)} ({missing_count/len(window)*100:.2f}%)"
            )

            # Print window values and masks side by side
            print("  Index  Value   Mask")
            print("  ------------------")
            for j in range(len(window)):
                print(f"  {j:5d}  {window[j]:.2f}  {int(mask[j])}")

    # Now let's visualize a few windows and their masks
    visualize_windows_and_masks(X_by_sensor, masks_by_sensor)

    return {
        "original_stats": original_stats,
        "window_stats": window_stats,
        "X_by_sensor": X_by_sensor,
        "masks_by_sensor": masks_by_sensor,
    }


def visualize_windows_and_masks(
    X_by_sensor, masks_by_sensor, num_sensors=2, num_windows=3
):
    """
    Visualize windows and their masks to help diagnose missing data issues.

    One figure is produced per sensor, with one subplot per window. Each
    subplot shows the window values, red crosses on entries equal to -1,
    and the mask rescaled onto the value range of the valid entries so both
    can share an axis.

    Parameters:
    -----------
    X_by_sensor : dict
        Dictionary mapping sensor IDs to window arrays
    masks_by_sensor : dict
        Dictionary mapping sensor IDs to mask arrays
    num_sensors : int
        Number of sensors to visualize
    num_windows : int
        Number of windows per sensor to visualize

    Returns:
    --------
    None
        Figures are displayed with plt.show(). Sensors without any window
        are reported with a printed message and skipped.
    """
    import matplotlib.pyplot as plt

    # Get a subset of sensors
    sensor_ids = list(X_by_sensor.keys())[:num_sensors]

    # Create a figure for each sensor
    for sensor_id in sensor_ids:
        windows = X_by_sensor[sensor_id]
        masks = masks_by_sensor[sensor_id]

        # Limit number of windows
        n_windows = min(num_windows, len(windows))
        if n_windows < 1:
            # plt.subplots rejects a grid with zero rows
            print(f"Sensor {sensor_id}: no windows to visualize.")
            continue

        # squeeze=False always yields a 2-D axes array, so the single-window
        # case needs no special handling.
        fig, axes_grid = plt.subplots(
            n_windows, 1, figsize=(12, 3 * n_windows), squeeze=False
        )
        axs = axes_grid[:, 0]

        fig.suptitle(f"Sensor {sensor_id} Windows and Masks", fontsize=14)

        for i in range(n_windows):
            window = windows[i]
            mask = masks[i]
            ax = axs[i]

            # Plot the window values
            ax.plot(range(len(window)), window, "b-", label="Window Values")

            # Mark missing values (-1)
            missing_mask = window == -1
            if missing_mask.any():
                ax.scatter(
                    np.where(missing_mask)[0],
                    window[missing_mask],
                    color="red",
                    marker="x",
                    s=100,
                    label="Missing (-1)",
                )

            # Value range of the valid entries, used to scale the mask.
            # Reducing an empty array raises, so the all-missing case is
            # checked before calling min()/max().
            valid_values = window[~missing_mask]
            min_val, max_val = -1, 1  # Default if all values are missing
            if valid_values.size > 0 and np.isfinite(valid_values.min()):
                min_val, max_val = valid_values.min(), valid_values.max()

            # Never let the scale collapse below one unit
            range_val = max(1, max_val - min_val)

            # Scale mask to plot range
            scaled_mask = min_val + mask * range_val
            ax.plot(
                range(len(mask)), scaled_mask, "g--", alpha=0.7, label="Mask (scaled)"
            )

            # Add mask values as text
            for j, m in enumerate(mask):
                ax.text(
                    j,
                    scaled_mask[j] + 0.05 * range_val,
                    f"{int(m)}",
                    ha="center",
                    va="bottom",
                    color="green",
                    fontsize=8,
                )

            ax.set_title(f"Window {i}")
            ax.set_xlabel("Time Step")
            ax.set_ylabel("Value")
            ax.legend(loc="upper right")

            # Annotate missing value percentage
            missing_pct = missing_mask.sum() / len(window) * 100
            ax.text(
                0.02,
                0.95,
                f"{missing_pct:.1f}% missing values",
                transform=ax.transAxes,
                fontsize=10,
                bbox=dict(facecolor="white", alpha=0.8),
            )

        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        plt.show()


# Example usage: run the windowing diagnostics on the cleaned sensor data and
# keep the returned dict (keys: original_stats, window_stats, X_by_sensor,
# masks_by_sensor) for further inspection in later cells.
debug_results = debug_dataloader_creation(
    results_containing_data_cleaned, window_size=12, stride=1
)