In [None]:
"""
Copyright (c) 2023-2024 Laboratory for Intelligent Integrated Networks of Engineering Systems
@author: Hazem Abo-Donia, Amro M. Farid
@lab: Laboratory for Intelligent Integrated Networks of Engineering Systems, Stevens Institute of Technology
Acknowledgement: The authors are grateful to the United States National Science Foundation for its generous funding of this research.
@Modified: 12/13/2024
"""

: 

In [None]:
"""
To install dependencies, run:
pip install -r requirements.txt
"""

: 

In [None]:
# Importing all necessary libraries
import geopandas as gpd
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import plotly.graph_objects as go
import plotly.express as px
import networkx as nx
from ipywidgets import interactive, IntSlider, Layout, HBox, VBox
from matplotlib.widgets import Slider
import json
!pip install ipympl

: 

In [None]:
def visualizeTemporalSankeyDiagram(mySankeyDataFileName, node_palette=None):
    """
    Visualizes a temporal Sankey diagram with a yearly slider and custom node colors.

    The CSV must contain the columns ``source``, ``target``, ``value`` and ``year``.
    One Sankey trace is built per distinct year; the slider toggles which trace
    is visible and updates the figure title.

    Parameters:
    - mySankeyDataFileName: str, path to the CSV file containing Sankey data
    - node_palette: optional sequence of Plotly color strings for the nodes. It is
      cycled when there are more nodes than colors. When omitted (or empty) a
      default palette of eight named colors is used.

    Returns:
    - success: bool, True if the visualization was successful, False otherwise
      (the error message is printed instead of raised).
    """
    default_palette = ['blue', 'green', 'red', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
    try:
        # Section 1: Import Data
        data = pd.read_csv(mySankeyDataFileName)

        # Section 2: Prepare nodes, colors and years
        # Sort the labels: iterating a set of strings is randomized per
        # interpreter run, which would reshuffle node positions and colors
        # after every kernel restart. key=str tolerates mixed label types.
        nodes = sorted(set(data['source'].unique()).union(data['target'].unique()), key=str)
        node_indices = {node: i for i, node in enumerate(nodes)}

        # One color per node, cycling through the palette as needed
        palette = list(node_palette) if node_palette else default_palette
        node_colors = [palette[i % len(palette)] for i in range(len(nodes))]

        # Defining all of the available years
        years = sorted(data['year'].unique())

        # Section 3: Visualize Data
        fig = go.Figure()

        # Adding a trace for each year; only the first year starts visible
        for year in years:
            filtered_data = data[data['year'] == year]
            fig.add_trace(go.Sankey(
                node=dict(
                    pad=15,
                    thickness=10,
                    line=dict(color="black", width=0.5),
                    label=nodes,
                    color=node_colors
                ),
                link=dict(
                    source=filtered_data['source'].map(node_indices),
                    target=filtered_data['target'].map(node_indices),
                    value=filtered_data['value']
                ),
                visible=(year == years[0])
            ))

        # Defining the slider component
        slider = dict(
            active=0,
            steps=[],
            xanchor='left',
            yanchor='top',
            pad=dict(t=50),
            currentvalue=dict(
                visible=True,
                prefix="Year:",
                xanchor="right",
                font=dict(size=20, color="#666")
            ),
            transition=dict(duration=300, easing="cubic-in-out")
        )

        # One step per year: show only that year's trace (traces were added in
        # the same order as `years`) and retitle the figure
        for i, year in enumerate(years):
            slider['steps'].append(dict(
                label=str(year),
                method='update',
                args=[
                    {'visible': [i == j for j in range(len(years))]},
                    {'title_text': f"Energy Consumption Sankey Diagram for Year {year}"}
                ],
            ))

        fig.update_layout(
            title_text="Estimated United States Energy Consumption",
            sliders=[slider]
        )

        fig.show()
        return True  # Success
    except Exception as e:
        print(f"Error in visualizeTemporalSankeyDiagram: {e}")
        return False  # Failure

: 

In [None]:
# Demo: the CSV must provide the columns source, target, value and year.
# Point mySankeyDataFileName at your own file.
visualizeTemporalSankeyDiagram(mySankeyDataFileName='plaincsv.csv')

: 

In [None]:
def _state_centroids_latlon():
    """
    Download US state polygons and return {state_name: (lat, lon)} of their centroids.

    Centroids are computed in a projected CRS (World Mercator, EPSG:3395) rather
    than on raw lon/lat degrees, then converted back to WGS84 (EPSG:4326).
    """
    us_states = gpd.read_file('https://raw.githubusercontent.com/PublicaMundi/MappingAPI/master/data/geojson/us-states.json')
    us_states_3395 = us_states.to_crs('EPSG:3395')
    us_states_3395['centroid_3395'] = us_states_3395.geometry.centroid
    # Make the centroid column the active geometry so to_crs converts the points
    centroids_4326 = us_states_3395.set_geometry('centroid_3395').to_crs('EPSG:4326')
    return {
        name: (point.y, point.x)
        for name, point in zip(centroids_4326['name'], centroids_4326.geometry)
    }


def _magnitude_color(magnitude, min_magnitude, max_magnitude, palette):
    """Map magnitude linearly onto a palette entry; a zero/NaN range or NaN value maps to palette[0]."""
    span = max_magnitude - min_magnitude
    if pd.isna(magnitude) or pd.isna(span) or span == 0:
        # Avoid 0/0 -> NaN -> int(NaN) crash when all magnitudes are equal
        norm_val = 0.0
    else:
        norm_val = min(max((magnitude - min_magnitude) / span, 0.0), 1.0)
    return palette[int(norm_val * (len(palette) - 1))]


def _flow_traces(year_data, state_centroids, palette, min_magnitude, max_magnitude):
    """Build one colored line trace per row whose two states both have a known centroid."""
    traces = []
    for _, row in year_data.iterrows():
        state1, state2 = row['state1'], row['state2']
        # Rows naming a state absent from the GeoJSON are skipped
        if state1 not in state_centroids or state2 not in state_centroids:
            continue
        lat1, lon1 = state_centroids[state1]
        lat2, lon2 = state_centroids[state2]
        magnitude = row['magnitude']
        traces.append(go.Scattergeo(
            lon=[lon1, lon2],
            lat=[lat1, lat2],
            mode='lines',
            line=dict(width=2, color=_magnitude_color(magnitude, min_magnitude, max_magnitude, palette)),
            hoverinfo='text',
            text=f"{state1} → {state2}: {magnitude}",
            showlegend=False  # Each trace won't have a legend entry
        ))
    return traces


def _empty_flow_trace():
    """Invisible placeholder trace used to pad animation frames to a common trace count."""
    return go.Scattergeo(lon=[], lat=[], mode='lines', hoverinfo='skip', showlegend=False)


def visualizeTemporalNetworkFlow(data_file):
    """
    Visualizes a temporal network flow map between US states.

    Each CSV row is drawn as a straight line between the centroids of the states
    ``state1`` and ``state2`` (names as in the GeoJSON ``name`` property),
    colored on the Viridis scale by ``magnitude``. One animation frame is built
    per distinct ``Year`` and a slider steps through them.

    Parameters:
    - data_file: str, path to a CSV with the columns Year, state1, state2, magnitude

    Returns:
    - success: bool, True if the map was shown, False otherwise (the error
      message is printed instead of raised).
    """
    try:
        # Read the local CSV first so a bad path fails before the GeoJSON download
        connections = pd.read_csv(data_file)
        state_centroids = _state_centroids_latlon()

        # Extract unique years for the slider
        years = sorted(connections['Year'].unique())

        # Color scale spans the magnitude range over all years
        palette = px.colors.sequential.Viridis
        min_magnitude = connections['magnitude'].min()
        max_magnitude = connections['magnitude'].max()

        traces_per_year = [
            _flow_traces(connections[connections['Year'] == year], state_centroids,
                         palette, min_magnitude, max_magnitude)
            for year in years
        ]

        # Plotly frames only update existing traces by position: a frame with
        # fewer traces would leave the previous year's extra lines on the map,
        # and traces beyond the initial figure's count would never be drawn.
        # Pad every frame (and hence the initial data) to the same length.
        n_traces = max((len(traces) for traces in traces_per_year), default=0)
        frames = [
            go.Frame(
                data=traces + [_empty_flow_trace() for _ in range(n_traces - len(traces))],
                name=str(year),
            )
            for year, traces in zip(years, traces_per_year)
        ]

        # Initialize figure with the first year's data
        initial_data = frames[0].data if frames else []
        fig = go.Figure(data=initial_data, frames=frames)

        fig.update_layout(
            title='Temporal Network Flow Map',
            showlegend=False
        )

        fig.update_geos(
            scope='usa',
            projection_type='albers usa',
            showland=True,
            landcolor='rgb(240,240,240)',
            subunitcolor='rgb(100,100,100)',
            countrycolor='rgb(100,100,100)'
        )

        # Slider that jumps to the named frame for each year
        sliders = [dict(
            steps=[
                dict(method="animate",
                     args=[[str(year)], {"frame": {"duration": 500, "redraw": True}, "mode":"immediate"}],
                     label=str(year)) for year in years
            ],
            transition=dict(duration=300),
            x=0.1,
            len=0.9
        )]

        fig.update_layout(sliders=sliders)

        fig.show()
        return True

    except Exception as e:
        print(f"Error in visualizeTemporalNetworkFlow: {e}")
        return False

: 

In [None]:
# Demo: the CSV must provide the columns Year, state1, state2 and magnitude.
# Point data_file at your own file.
visualizeTemporalNetworkFlow(data_file='test2.csv')

: 

In [None]:
def visualizeTemporalChoroplethMap(data_file):
    """
    Visualizes a temporal choropleth map with an interactive Plotly slider.

    The CSV needs the columns ``state_name`` (matching the GeoJSON ``name``
    property), ``time_step`` (YYYYMM) and ``energy_usage``. Each distinct time
    step becomes one animation frame.

    Parameters:
    - data_file: str, path to the CSV file containing energy usage data

    Returns:
    - success: bool, True if the visualization was successful, False otherwise
      (the error message is printed instead of raised).
    """
    try:
        # Load data first so a bad path fails before the GeoJSON download
        energy_df = pd.read_csv(data_file)

        # Load US States geometries in lon/lat (EPSG:4326), as Plotly expects
        states = gpd.read_file('https://raw.githubusercontent.com/PublicaMundi/MappingAPI/master/data/geojson/us-states.json')
        states = states.to_crs('EPSG:4326')

        # Convert time_step from YYYYMM to YYYY-MM format for better readability
        energy_df['time_step'] = energy_df['time_step'].astype(str).apply(lambda x: f"{x[:4]}-{x[4:]}")

        # Report data rows whose state name has no matching geometry
        missing_states = set(energy_df['state_name']) - set(states['name'])
        if missing_states:
            print(f"Warning: The following states in the data file do not match the GeoJSON file: {missing_states}")

        # Keep only rows whose state has a geometry (same rows an inner merge
        # would keep) and sort chronologically: Plotly Express orders animation
        # frames by first appearance, so unsorted input plays out of order.
        plot_df = (
            energy_df[energy_df['state_name'].isin(states['name'])]
            .sort_values('time_step', kind='stable')
        )

        # One GeoJSON feature per state. Serializing the merged per-time-step
        # frame would repeat every state polygon once per time step.
        matched_states = states.loc[states['name'].isin(plot_df['state_name']), ['name', 'geometry']]
        geojson = json.loads(matched_states.to_json())

        # Fix the color range across all frames so colors are comparable over time
        usage = plot_df['energy_usage']
        range_color = [usage.min(), usage.max()] if usage.notna().any() else None

        fig = px.choropleth(
            plot_df,
            geojson=geojson,
            locations="state_name",
            featureidkey="properties.name",
            color="energy_usage",
            color_continuous_scale="Viridis",
            range_color=range_color,
            animation_frame="time_step",
            title="Temporal Energy Usage by State"
        )

        fig.update_geos(
            visible=True,
            projection_type="albers usa"
        )
        fig.update_layout(
            margin={"r":0,"t":30,"l":0,"b":0},
            coloraxis_colorbar=dict(title="Energy Usage")
        )

        fig.show()
        return True  # Success
    except Exception as e:
        print(f"Error in visualizeTemporalChoroplethMap: {e}")
        return False  # Failure

: 

In [None]:
# Demo: the CSV must provide the columns state_name, time_step (YYYYMM) and
# energy_usage. Point data_file at your own file for actual testing.
visualizeTemporalChoroplethMap(data_file='sample_energy_usage_with_names.csv')

: 