In [None]:
import os
import pickle
import torch
from gnn_package import training
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Load the cached data loaders and the pickled test data from the working
# directory. A missing file is only reported; its variable is then left undefined.
# NOTE(review): pickle.load can execute arbitrary code, so only load trusted files.
if not os.path.exists("data_loaders_test_1wk.pkl"):
    print("Data loaders not found. Please run the data processing script first.")
else:
    with open("data_loaders_test_1wk.pkl", "rb") as f:
        data_loaders = pickle.load(f)


if not os.path.exists("test_data_1wk.pkl"):
    print("Test data not found. Please run the data processing script first.")
else:
    with open("test_data_1wk.pkl", "rb") as f:
        results_containing_data = pickle.load(f)

In [None]:
def inspect_dataloader(dataloader, sample_index=10000):
    """
    Inspect the structure of a PyTorch DataLoader and its batches.

    Prints the loader's properties, one dataset sample, the first batch and the
    number of batches. Failures while fetching the sample or the batch are
    reported on stdout instead of being raised.

    Parameters:
    -----------
    dataloader : torch.utils.data.DataLoader
        The dataloader to inspect
    sample_index : int
        Index of the dataset sample to display. It is clamped to the last
        valid index so datasets shorter than the default can still be inspected

    Returns:
    --------
    object or None
        The first batch yielded by the dataloader, or None if it could not be fetched
    """
    print(f"DataLoader type: {type(dataloader)}")
    print(f"DataLoader object: {dataloader}")

    # Get information about batch size and other dataloader properties
    collate = dataloader.collate_fn
    print("\nDataLoader Properties:")
    print(f"Batch size: {dataloader.batch_size}")
    print(f"Number of workers: {dataloader.num_workers}")
    # Callables without a __name__ (e.g. partials) are printed via their repr
    print(f"Collate function: {getattr(collate, '__name__', collate)}")

    # Examine the dataset
    dataset = dataloader.dataset
    print("\nDataset:")
    print(f"Dataset type: {type(dataset)}")
    print(f"Dataset length: {len(dataset)}")

    # Try to inspect one sample from the dataset
    try:
        n_samples = len(dataset)
        if n_samples == 0:
            raise IndexError("dataset is empty")
        # Clamp so a short dataset does not make the inspection fail
        idx = min(sample_index, n_samples - 1)
        sample = dataset[idx]
        print("\nSample from dataset:")
        print(f"Sample index: {idx}")
        print(f"Sample type: {type(sample)}")
        if isinstance(sample, dict):
            print(f"Sample keys: {sample.keys()}")
            for key, value in sample.items():
                shape = value.shape if hasattr(value, "shape") else "No shape attribute"
                print(f"  {key}: {type(value)}, Shape: {shape}")
        else:
            print("Sample keys: Not a dictionary")
    except Exception as e:
        print(f"Could not inspect dataset sample: {e}")

    # Try to examine one batch; stays None when the loader cannot produce one
    batch = None
    print("\nBatch inspection:")
    try:
        batch = next(iter(dataloader))
        print(f"Batch type: {type(batch)}")
        if isinstance(batch, dict):
            print(f"Batch keys: {batch.keys()}")
            for key, value in batch.items():
                if hasattr(value, "shape"):
                    print(f"  {key}: Shape {value.shape}, Type {value.dtype}")
                else:
                    print(f"  {key}: {type(value)}")
        else:
            print("Batch keys: Not a dictionary")
    except Exception as e:
        print(f"Could not inspect batch: {e}")

    # Calculate total number of batches
    try:
        num_batches = len(dataloader)
        print(f"\nTotal number of batches: {num_batches}")
    except Exception as e:
        print(f"Could not determine number of batches: {e}")

    return batch

In [None]:
# Inspect the train dataloader and keep the batch it returns for a closer look
batch = inspect_dataloader(data_loaders["train_loader"])

# If a batch was successfully retrieved, perform a deeper inspection
if batch is not None:
    print("\nDetailed batch analysis:")

    # For the x tensor (input features); the interpretation line indexes four
    # dimensions, so x is expected to be 4-D here
    if "x" in batch:
        x = batch["x"]
        print("\nx tensor analysis:")
        print(f"  Shape: {x.shape}")
        print(
            f"  Interpretation: [batch_size={x.shape[0]}, num_nodes={x.shape[1]}, seq_len={x.shape[2]}, features={x.shape[3]}]"
        )
        print(f"  Min value: {x.min().item():.4f}, Max value: {x.max().item():.4f}")
        print(f"  Contains -1 (missing data): {(x == -1).any().item()}")
        print(
            f"  Percentage of missing data: {((x == -1).sum() / x.numel()) * 100:.2f}%"
        )

    # For the adjacency matrix (the symmetry check transposes dims 0 and 1)
    if "adj" in batch:
        adj = batch["adj"]
        print("\nadj tensor analysis:")
        print(f"  Shape: {adj.shape}")
        print(f"  Interpretation: [num_nodes={adj.shape[0]}, num_nodes={adj.shape[1]}]")
        print(f"  Min value: {adj.min().item():.4f}, Max value: {adj.max().item():.4f}")
        print(f"  Is symmetric: {torch.allclose(adj, adj.transpose(0, 1))}")

    # For the masks
    if "x_mask" in batch:
        x_mask = batch["x_mask"]
        print("\nx_mask tensor analysis:")
        print(f"  Shape: {x_mask.shape}")
        print(f"  Values are binary: {torch.all((x_mask == 0) | (x_mask == 1)).item()}")
        # Cast first: mean() is only defined for floating-point (or complex)
        # tensors, so a bool or integer mask would otherwise raise here
        print(
            f"  Percentage of valid (non-masked) values: {x_mask.float().mean().item()*100:.2f}%"
        )

    # For the target values
    if "y" in batch:
        y = batch["y"]
        print("\ny tensor analysis (targets):")
        print(f"  Shape: {y.shape}")
        print(
            f"  Interpretation: [batch_size={y.shape[0]}, num_nodes={y.shape[1]}, horizon={y.shape[2]}, features={y.shape[3]}]"
        )

In [None]:
def analyze_dataloader_missing_values(dataloader, max_batches=10):
    """
    Analyze the presence of -1 values (missing data) in a dataloader and check window completeness.

    A "window" is one (sample, node) pair of a batch; it counts as incomplete
    when any of its values equals -1. A summary is printed to stdout.

    Parameters:
    -----------
    dataloader : torch.utils.data.DataLoader
        The dataloader to analyze. Every batch must be a mapping providing
        "x" (a 4-D tensor indexed as [batch, node, ...]), "x_mask" and
        "node_indices" (tensor or sequence of per-node identifiers)
    max_batches : int or None
        Maximum number of batches to scan (default 10, for efficiency).
        Pass None to scan the whole dataloader

    Returns:
    --------
    dict
        Dictionary containing analysis results: batch/window/value counters,
        the matching percentages, a per-sensor DataFrame under
        "window_counts_by_sensor", and "equal_window_counts"
    """
    print("Analyzing missing values in dataloader...")

    # Initialize counters and storage
    total_batches = 0
    total_windows = 0
    windows_with_missing = 0
    missing_values_count = 0
    total_values_count = 0

    # Store window counts per sensor
    node_window_counts = {}
    node_missing_counts = {}

    for batch_idx, batch in enumerate(dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break

        total_batches += 1

        # Extract data tensors
        x = batch["x"]  # [batch_size, num_nodes, seq_len, features]
        x_mask = batch["x_mask"]  # compared element-wise against x below
        node_indices = batch["node_indices"]  # Sensor indices in this batch

        batch_size, num_nodes, _, _ = x.shape

        # Convert node indices to list if it's a tensor
        if torch.is_tensor(node_indices):
            node_indices = node_indices.cpu().numpy().tolist()

        # Count windows and missing values
        total_windows += batch_size * num_nodes

        # Check for -1 values (missing data indicators)
        missing_mask = x == -1
        missing_values_count += missing_mask.sum().item()
        total_values_count += x.numel()

        # One flag per (sample, node) window: does it contain any -1?
        # Reducing on the tensor avoids a Python-level .item() per window.
        window_has_missing = missing_mask.flatten(start_dim=2).any(dim=2)
        windows_with_missing += int(window_has_missing.sum().item())
        missing_per_node = window_has_missing.sum(dim=0).tolist()

        for n in range(num_nodes):
            # Get node ID (sensor ID); fall back to a placeholder when the
            # batch lists fewer identifiers than it has nodes
            node_id = node_indices[n] if n < len(node_indices) else f"unknown_{n}"

            # Every sample in the batch contributes one window for this node
            node_window_counts[node_id] = (
                node_window_counts.get(node_id, 0) + batch_size
            )
            node_missing_counts[node_id] = (
                node_missing_counts.get(node_id, 0) + missing_per_node[n]
            )

        # Check if mask matches -1 values
        mask_matches_missing = ((x_mask == 0) == missing_mask).all().item()
        if not mask_matches_missing:
            print(
                f"WARNING: Batch {batch_idx} has mismatches between mask and -1 values!"
            )

    # Create a DataFrame for window counts by sensor
    sensor_ids = list(node_window_counts)
    window_counts_df = pd.DataFrame(
        {
            "sensor_id": sensor_ids,
            "total_windows": [node_window_counts[nid] for nid in sensor_ids],
            "windows_with_missing": [node_missing_counts[nid] for nid in sensor_ids],
        }
    )

    # Calculate percentage of windows with missing values
    window_counts_df["pct_windows_with_missing"] = (
        window_counts_df["windows_with_missing"]
        / window_counts_df["total_windows"]
        * 100
    )

    # Sort by number of windows
    window_counts_df = window_counts_df.sort_values("total_windows", ascending=False)

    # Check if all sensors have the same number of windows
    equal_window_counts = bool(window_counts_df["total_windows"].nunique() == 1)

    # Prepare results
    results = {
        "total_batches": total_batches,
        "total_windows": total_windows,
        "windows_with_missing": windows_with_missing,
        "pct_windows_with_missing": (
            windows_with_missing / total_windows * 100 if total_windows > 0 else 0
        ),
        "missing_values_count": missing_values_count,
        "total_values_count": total_values_count,
        "pct_missing_values": (
            missing_values_count / total_values_count * 100
            if total_values_count > 0
            else 0
        ),
        "window_counts_by_sensor": window_counts_df,
        "equal_window_counts": equal_window_counts,
    }

    # Print summary
    print("\n=== DataLoader Missing Value Analysis ===")
    print(f"Total batches analyzed: {total_batches}")
    print(f"Total windows: {total_windows}")
    print(
        f"Windows with missing values: {windows_with_missing} ({results['pct_windows_with_missing']:.2f}%)"
    )
    print(
        f"Total missing values: {missing_values_count} out of {total_values_count} ({results['pct_missing_values']:.2f}%)"
    )
    print(f"All sensors have equal window counts: {equal_window_counts}")
    print(f"Number of sensors in analysis: {len(node_window_counts)}")

    return results


def visualize_dataloader_completeness(dataloader, max_sensors=10):
    """
    Visualize the completeness of data in the dataloader by sensor.

    Draws two figures: bar charts of window counts and of the percentage of
    windows with missing values per sensor, then heatmaps of the -1 pattern
    (feature 0 only) for up to two samples of one batch.

    Parameters:
    -----------
    dataloader : torch.utils.data.DataLoader
        The dataloader to analyze
    max_sensors : int
        Maximum number of sensors to display

    Returns:
    --------
    dict
        The results of analyze_dataloader_missing_values for this dataloader
    """
    # Get the analysis results
    results = analyze_dataloader_missing_values(dataloader)

    # Extract the window counts DataFrame, limited to the top sensors
    window_df = results["window_counts_by_sensor"].head(max_sensors)

    # Bars are placed at explicit positions 0..n-1 and labelled with the sensor
    # ID. Passing numeric IDs straight to bar() would use them as x coordinates,
    # so the value annotations (placed at 0..n-1) would land on the wrong bars.
    positions = np.arange(len(window_df))
    sensor_labels = [str(sensor_id) for sensor_id in window_df["sensor_id"]]

    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Plot 1: Window counts by sensor
    bar_colors = plt.cm.viridis(np.linspace(0, 1, len(window_df)))
    ax1.bar(positions, window_df["total_windows"], color=bar_colors)
    ax1.set_title("Number of Windows by Sensor")
    ax1.set_xlabel("Sensor ID")
    ax1.set_ylabel("Window Count")
    ax1.set_xticks(positions)
    ax1.set_xticklabels(sensor_labels)
    ax1.tick_params(axis="x", rotation=45)

    # Add text on bars
    for i, v in enumerate(window_df["total_windows"]):
        ax1.text(i, v + 0.1, str(v), ha="center")

    # Plot 2: Percentage of windows with missing values
    ax2.bar(positions, window_df["pct_windows_with_missing"], color=bar_colors)
    ax2.set_title("Percentage of Windows with Missing Values")
    ax2.set_xlabel("Sensor ID")
    ax2.set_ylabel("Percentage")
    ax2.set_xticks(positions)
    ax2.set_xticklabels(sensor_labels)
    ax2.tick_params(axis="x", rotation=45)
    ax2.set_ylim(0, 100)

    # Add text on bars
    for i, v in enumerate(window_df["pct_windows_with_missing"]):
        ax2.text(i, v + 1, f"{v:.1f}%", ha="center")

    plt.tight_layout()
    plt.show()

    # Create a second figure to show a heatmap of missing values
    # Extract a sample batch to examine
    batch = next(iter(dataloader))
    x = batch["x"]  # [batch_size, num_nodes, seq_len, features]
    node_indices = batch["node_indices"]

    # Convert node indices to list if it's a tensor
    if torch.is_tensor(node_indices):
        node_indices = node_indices.cpu().numpy().tolist()

    # Create a mask of missing values (1 = missing, 0 = present)
    missing_mask = (x == -1).float().cpu().numpy()

    # Plot up to 2 samples of the batch
    batch_size, num_nodes, _, _ = x.shape
    plot_samples = min(2, batch_size)

    # squeeze=False always yields an array of axes, even for a single row
    fig, axs = plt.subplots(
        plot_samples, 1, figsize=(12, 4 * plot_samples), squeeze=False
    )
    axs = axs.ravel()

    # Same fallback as the analysis function when fewer IDs than nodes are listed
    node_labels = [
        f"{i}:{node_indices[i] if i < len(node_indices) else f'unknown_{i}'}"
        for i in range(num_nodes)
    ]

    for b in range(plot_samples):
        # Node-by-time mask for this sample, first feature only
        mask_matrix = missing_mask[b, :, :, 0]

        # Plot heatmap
        im = axs[b].imshow(mask_matrix, aspect="auto", cmap="Blues_r")
        # b indexes a sample inside the batch, not a batch
        axs[b].set_title(f"Missing Values Pattern in Batch Sample {b}")
        axs[b].set_xlabel("Time Step")
        axs[b].set_ylabel("Node (Sensor) Index")

        # Add sensor labels
        axs[b].set_yticks(range(num_nodes))
        axs[b].set_yticklabels(node_labels)

        # Add colorbar
        plt.colorbar(im, ax=axs[b], label="Missing (1) vs Present (0)")

    plt.tight_layout()
    plt.show()

    return results

In [None]:
# Summarise missing values in the training loader; as the cell's last
# expression, the returned results dict is displayed as well.
analyze_dataloader_missing_values(dataloader=data_loaders["train_loader"])

In [None]:
def inspect_collate_function(dataloader, time_series_dict, window_size=12, stride=1):
    """
    Inspect the collate function used by the dataloader, which might be introducing extra -1 values.

    Collates the first few dataset samples directly and compares the number of
    -1 values in the samples' "x" tensors with the number in the collated
    batch. Errors during that comparison are printed, not raised.

    Parameters:
    -----------
    dataloader : torch.utils.data.DataLoader
        The dataloader to inspect
    time_series_dict : dict
        Original time series data for comparison
    window_size : int
        Size of sliding windows
    stride : int
        Step size for sliding windows

    Returns:
    --------
    dict
        "X_by_sensor" and "masks_by_sensor" as returned by
        TimeSeriesPreprocessor.create_windows, and "batch": the collated
        batch, or None if collating failed
    """
    import inspect

    from gnn_package.src.preprocessing import TimeSeriesPreprocessor

    def _pct(part, whole):
        """Return part/whole as a percentage, or 0.0 when whole is zero."""
        return part / whole * 100 if whole else 0.0

    print("Inspecting dataloader collate function...")

    # First create windows using the processor
    processor = TimeSeriesPreprocessor(
        window_size=window_size,
        stride=stride,
        gap_threshold=pd.Timedelta(minutes=15),
        missing_value=-1.0,
    )

    # The third return value (metadata) is not needed here
    X_by_sensor, masks_by_sensor, _ = processor.create_windows(
        time_series_dict, standardize=True
    )

    # Get the collate function from the dataloader
    actual_collate_fn = dataloader.collate_fn

    # Stays None if the collate function cannot be run
    batch = None

    # Get a few samples from the dataset to test the collate function directly
    try:
        # Get the dataset
        dataset = dataloader.dataset

        # Take a small batch of samples
        batch_size = min(8, len(dataset))
        samples = [dataset[i] for i in range(batch_size)]

        # Run the collate function directly
        print("\nRunning collate function directly on samples...")
        batch = actual_collate_fn(samples)

        # Analyze the batch
        print("\nBatch structure after collate_fn:")
        for key, value in batch.items():
            if torch.is_tensor(value):
                print(f"  {key}: Shape {value.shape}, Type {value.dtype}")

                # Check for -1 values
                if key in ["x", "y"]:
                    missing_count = (value == -1).sum().item()
                    total_count = value.numel()
                    print(f"    Contains -1 values: {missing_count > 0}")
                    print(
                        f"    Missing value percentage: {_pct(missing_count, total_count):.2f}%"
                    )
            else:
                print(f"  {key}: {type(value)}")

        # Compare with original sample values
        print("\nComparing -1 values between original samples and batch:")

        # Count missing values in original samples
        original_missing = 0
        original_total = 0

        for sample in samples:
            if "x" in sample and torch.is_tensor(sample["x"]):
                original_missing += (sample["x"] == -1).sum().item()
                original_total += sample["x"].numel()

        # Count missing values in batch
        batch_missing = (batch["x"] == -1).sum().item()
        batch_total = batch["x"].numel()

        # Print comparison
        print(
            f"  Original samples: {original_missing} missing out of {original_total} ({_pct(original_missing, original_total):.2f}%)"
        )
        print(
            f"  After collate_fn: {batch_missing} missing out of {batch_total} ({_pct(batch_missing, batch_total):.2f}%)"
        )

        if batch_missing > original_missing:
            print(
                f"  WARNING: Collate function is adding {batch_missing - original_missing} extra missing values!"
            )
            print(
                "  This suggests the collate_fn is introducing -1 values when forming batches."
            )

            # Let's print the collate_fn code for inspection
            print("\nCollate function source code:")
            print(inspect.getsource(actual_collate_fn))

    except Exception as e:
        print(f"Error in collate function inspection: {e}")

    return {
        "X_by_sensor": X_by_sensor,
        "masks_by_sensor": masks_by_sensor,
        "batch": batch,
    }


# Usage example: run the collate inspection on the training loader, using the
# loaded test data as the original time series.
collate_results = inspect_collate_function(
    dataloader=data_loaders["train_loader"],
    time_series_dict=results_containing_data,
)

In [None]:
def inspect_dataset_implementation(data_loaders):
    """
    Inspect the dataset implementation to understand how it's creating batches.

    For a SpatioTemporalDataset, prints its sample index structure, the source
    of __getitem__, one sample, and the share of -1 values per node position in
    one batch. For any other dataset type, prints the types of its public
    non-callable attributes. The collate function's source is printed last.

    Parameters:
    -----------
    data_loaders : dict
        Mapping that contains the dataloader to inspect under "train_loader"

    Returns:
    --------
    object
        The dataset of the "train_loader" dataloader
    """
    import inspect

    from gnn_package.src.dataloaders import SpatioTemporalDataset

    def _print_source(obj):
        """Print the source of obj, or a short notice when it cannot be retrieved."""
        try:
            print(inspect.getsource(obj))
        except (OSError, TypeError) as exc:
            # getsource raises these when the source is unavailable or obj is unsupported
            print(f"<source unavailable: {exc}>")

    train_loader = data_loaders["train_loader"]

    # First check the dataset type
    dataset = train_loader.dataset
    print(f"Dataset type: {type(dataset)}")

    # Check if it's our SpatioTemporalDataset
    if isinstance(dataset, SpatioTemporalDataset):
        print("\nFound SpatioTemporalDataset instance:")
        print(f"Number of node IDs: {len(dataset.node_ids)}")
        print(f"Number of sample indices: {len(dataset.sample_indices)}")

        # Check a few sample indices to understand structure
        print("\nSample indices structure (first 5):")
        for i, (node_id, window_idx) in enumerate(dataset.sample_indices[:5]):
            print(f"  {i}: node_id={node_id}, window_idx={window_idx}")

        # Check __getitem__ implementation
        print("\nExamining __getitem__ method:")
        _print_source(dataset.__getitem__)

        # Test __getitem__ directly
        print("\nTesting __getitem__ directly:")
        sample = dataset[0]
        print("Sample keys:", sample.keys())

        # Check dimensions and missing value ratio
        for key, value in sample.items():
            if torch.is_tensor(value):
                if value.numel() > 0:
                    missing_ratio = (value == -1).float().mean().item() * 100
                    print(f"  {key}: shape={value.shape}, missing={missing_ratio:.2f}%")
                else:
                    print(f"  {key}: shape={value.shape}")
            else:
                print(f"  {key}: {value}")

        # Now let's get a batch from the dataloader and check each node's representation
        print("\nAnalyzing a batch from the dataloader:")
        batch = next(iter(train_loader))

        # Check node representation in the batch
        x = batch["x"]  # [batch_size, num_nodes, seq_len, features]
        _, num_nodes, _, _ = x.shape
        missing = x == -1

        print(f"Batch shape: {x.shape}")

        # Check missing value patterns by node
        print("\nMissing value patterns by node position in batch:")
        for n in range(min(num_nodes, 10)):  # Limit to first 10 nodes
            node_missing = missing[:, n, :, :].float().mean().item() * 100
            print(f"  Node position {n}: {node_missing:.2f}% missing values")

        # Check if all values are missing for some nodes - this would confirm our theory.
        # Bring the node axis first and flatten the rest so one reduction yields
        # an "everything missing" flag per node.
        all_missing_nodes = int(
            missing.transpose(0, 1).flatten(start_dim=1).all(dim=1).sum().item()
        )
        all_missing_pct = all_missing_nodes / num_nodes * 100 if num_nodes else 0.0

        print(
            f"\nNodes with all values missing: {all_missing_nodes} out of {num_nodes} ({all_missing_pct:.2f}%)"
        )

        if all_missing_nodes > 0:
            print(
                "This confirms our theory - the dataloader is creating tensors with all nodes,"
            )
            print(
                "but only filling in values for the nodes present in each batch, leaving others as -1"
            )

    else:
        print(f"Dataset is not SpatioTemporalDataset but {type(dataset)}")
        # Try to extract some info about the dataset
        print("\nAttempting to inspect dataset properties:")
        for attr in dir(dataset):
            if attr.startswith("_"):
                continue
            # Read the attribute inside the try: a property may raise on access
            try:
                value = getattr(dataset, attr)
            except Exception:
                print(f"  {attr}: <error getting value>")
                continue
            if not callable(value):
                print(f"  {attr}: {type(value)}")

    # Try to check the collate_fn source code regardless
    print("\nCollate function source code:")
    _print_source(train_loader.collate_fn)

    return dataset


# Usage example: inspect the dataset behind the training loader
dataset_results = inspect_dataset_implementation(data_loaders=data_loaders)