In [2]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import folium
from folium.plugins import MarkerCluster
import json
from typing import Optional, Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

In [3]:

# =============================================================================
# COLOR PALETTES - NYC-inspired industrial aesthetic
# =============================================================================
COLORS = dict(
    # Page / panel surfaces, darkest to lightest
    background='#0a0a0f',
    surface='#14141f',
    surface_light='#1e1e2e',
    # Accent palette
    primary='#ff6b35',       # Warm orange
    secondary='#4ecdc4',     # Teal
    accent='#ffe66d',        # Yellow
    # Typography
    text='#f7f7f7',
    text_muted='#8b8b9e',
    # Per-entity colours for the network figures
    restaurant='#ff6b35',
    center='#4ecdc4',
    neighborhood='#a855f7',
    flow='#ffe66d',
    # Gridlines and borders
    grid='#2a2a3e',
)

# =============================================================================
# 1. PARETO CURVE VISUALIZATION
# =============================================================================

def create_pareto_curve(
    pareto_df: pd.DataFrame,
    x_col: str = 'transport_cost',
    y_col: str = 'equity_t',
    title: str = "Cost vs. Equity Pareto Frontier",
    interactive: bool = True
) -> go.Figure:
    """
    Create an interactive Pareto frontier visualization.

    The solutions are sorted by ``x_col`` and drawn as a smoothed line plus
    one marker per solution (coloured along the frontier). Arrows label the
    cost-optimal point (minimum ``x_col``) and the equity-optimal point
    (minimum ``y_col``).

    Parameters
    ----------
    pareto_df : DataFrame with one row per solution. Must contain ``x_col``
        and ``y_col``. The optional columns ``w_cost``, ``w_eq`` and
        ``total_recv`` are shown in the hover text when present, and shown
        as ``N/A`` otherwise.
    x_col : Column name for x-axis (cost objective)
    y_col : Column name for y-axis (equity objective - worst unmet demand)
    title : Plot title
    interactive : Currently unused; a Plotly figure is always returned.
        Kept so existing callers that pass it keep working.

    Returns
    -------
    Plotly Figure object. If ``pareto_df`` is empty, the figure is styled
    but has no points and no optimum annotations.
    """

    # Sort by x for proper line connection. Resetting the index makes the
    # index equal to the position along the frontier. The marker colour
    # gradient below relies on that; the original index would be in
    # arbitrary pre-sort order.
    df = pareto_df.sort_values(x_col).reset_index(drop=True)

    def _fmt_amount(value) -> str:
        # Thousands-separated whole number. Non-numeric fallbacks such as
        # 'N/A' cannot take the ',.0f' spec, so show them verbatim instead
        # of raising.
        try:
            return f"{value:,.0f}"
        except (TypeError, ValueError):
            return str(value)

    # Create figure
    fig = go.Figure()

    # Add the Pareto frontier line (hover is handled by the marker trace)
    fig.add_trace(go.Scatter(
        x=df[x_col],
        y=df[y_col],
        mode='lines',
        line=dict(
            color=COLORS['primary'],
            width=3,
            shape='spline',
            smoothing=0.8
        ),
        name='Pareto Frontier',
        hoverinfo='skip'
    ))

    # Add points with full info on hover
    hover_text = [
        (
            f"<b>Solution</b><br>"
            f"Cost Weight: {row.get('w_cost', 'N/A')}<br>"
            f"Equity Weight: {row.get('w_eq', 'N/A')}<br>"
            f"Transport Cost: {_fmt_amount(row[x_col])}<br>"
            f"Worst Unmet: {_fmt_amount(row[y_col])}<br>"
            f"Total Delivered: {_fmt_amount(row.get('total_recv', 'N/A'))}"
        )
        for _, row in df.iterrows()
    ]

    fig.add_trace(go.Scatter(
        x=df[x_col],
        y=df[y_col],
        mode='markers',
        marker=dict(
            size=14,
            color=df.index,  # 0..n-1 along the frontier after reset_index
            colorscale=[
                [0, COLORS['secondary']],
                [0.5, COLORS['primary']],
                [1, COLORS['accent']]
            ],
            line=dict(color=COLORS['text'], width=2),
            symbol='circle'
        ),
        name='Solutions',
        text=hover_text,
        hovertemplate='%{text}<extra></extra>'
    ))

    # Mark extremes (idxmin raises on an empty frame, so skip them then).
    # On ties, idxmin returns the first row in sorted order, i.e. the
    # cheapest of the tied solutions.
    if not df.empty:
        cost_optimal = df.loc[df[x_col].idxmin()]
        equity_optimal = df.loc[df[y_col].idxmin()]

        # Cost-optimal annotation
        fig.add_annotation(
            x=cost_optimal[x_col],
            y=cost_optimal[y_col],
            text="Cost<br>Optimal",
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=2,
            arrowcolor=COLORS['secondary'],
            font=dict(size=12, color=COLORS['secondary']),
            ax=-60,
            ay=-40
        )

        # Equity-optimal annotation
        fig.add_annotation(
            x=equity_optimal[x_col],
            y=equity_optimal[y_col],
            text="Equity<br>Optimal",
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=2,
            arrowcolor=COLORS['accent'],
            font=dict(size=12, color=COLORS['accent']),
            ax=60,
            ay=40
        )

    # Layout
    fig.update_layout(
        title=dict(
            text=f"<b>{title}</b><br><sub>NYC Food Rescue Distribution Network</sub>",
            font=dict(size=24, color=COLORS['text'], family="Helvetica Neue"),
            x=0.5
        ),
        xaxis=dict(
            title=dict(text="Transportation Cost (miles)", font=dict(size=14)),
            tickformat=',',
            gridcolor=COLORS['grid'],
            zerolinecolor=COLORS['grid'],
            color=COLORS['text_muted']
        ),
        yaxis=dict(
            title=dict(text="Worst Unmet Demand (lbs)", font=dict(size=14)),
            tickformat=',',
            gridcolor=COLORS['grid'],
            zerolinecolor=COLORS['grid'],
            color=COLORS['text_muted']
        ),
        plot_bgcolor=COLORS['surface'],
        paper_bgcolor=COLORS['background'],
        font=dict(color=COLORS['text'], family="Helvetica Neue"),
        showlegend=False,
        margin=dict(l=80, r=40, t=100, b=80),
        hoverlabel=dict(
            bgcolor=COLORS['surface_light'],
            font_size=12,
            font_family="Helvetica Neue"
        )
    )

    return fig


def create_animated_pareto(
    pareto_df: pd.DataFrame,
    x_col: str = 'transport_cost',
    y_col: str = 'equity_t'
) -> go.Figure:
    """
    Create an animated Pareto curve that builds up point by point.
    Useful for presentations.

    Frame ``i`` (named ``str(i)``, 1-based) shows the first ``i`` solutions
    in order of increasing ``x_col``. The figure has Play/Pause buttons and
    a slider with one step per frame.

    Parameters
    ----------
    pareto_df : DataFrame with at least one row and the ``x_col`` /
        ``y_col`` columns.
    x_col : Column for the x-axis (cost objective).
    y_col : Column for the y-axis (equity objective).

    Returns
    -------
    Plotly Figure with ``len(pareto_df)`` animation frames.

    Raises
    ------
    ValueError
        If ``pareto_df`` is empty (there is no first point to show).
    """
    df = pareto_df.sort_values(x_col).reset_index(drop=True)
    if df.empty:
        # Without this check, the initial-figure lookup below fails with an
        # opaque IndexError.
        raise ValueError("pareto_df must contain at least one solution")

    def _padded_range(values, frac=0.1):
        # Pad by a fraction of the data span rather than scaling min/max.
        # Scaling clips negative values (0.9 * -10 > -10) and collapses to
        # [0, 0] for all-zero data.
        lo, hi = float(values.min()), float(values.max())
        span = hi - lo
        pad = frac * (span if span > 0 else (abs(hi) or 1.0))
        return [lo - pad, hi + pad]

    n_points = len(df)

    # Create frames: frame i contains the first i solutions
    frames = []
    for i in range(1, n_points + 1):
        frame_data = df.iloc[:i]
        frames.append(go.Frame(
            data=[
                go.Scatter(
                    x=frame_data[x_col],
                    y=frame_data[y_col],
                    mode='lines+markers',
                    line=dict(color=COLORS['primary'], width=3),
                    marker=dict(
                        size=12,
                        color=COLORS['secondary'],
                        line=dict(color=COLORS['text'], width=2)
                    )
                )
            ],
            name=str(i)
        ))

    # Initial figure: just the first (cheapest) solution
    fig = go.Figure(
        data=[go.Scatter(
            x=[df[x_col].iloc[0]],
            y=[df[y_col].iloc[0]],
            mode='markers',
            marker=dict(size=12, color=COLORS['secondary'])
        )],
        frames=frames
    )

    # Fixed axis ranges so the view doesn't jump between frames; add the
    # play/pause buttons and the frame slider
    fig.update_layout(
        title="<b>Building the Pareto Frontier</b>",
        xaxis=dict(
            title="Transportation Cost",
            range=_padded_range(df[x_col]),
            gridcolor=COLORS['grid']
        ),
        yaxis=dict(
            title="Worst Unmet Demand",
            range=_padded_range(df[y_col]),
            gridcolor=COLORS['grid']
        ),
        plot_bgcolor=COLORS['surface'],
        paper_bgcolor=COLORS['background'],
        font=dict(color=COLORS['text']),
        updatemenus=[dict(
            type='buttons',
            showactive=False,
            y=1.15,
            x=0.5,
            xanchor='center',
            buttons=[
                dict(
                    # Real glyphs; previously mis-decoded UTF-8 ('‚ñ∂')
                    label='▶ Play',
                    method='animate',
                    args=[None, dict(
                        frame=dict(duration=500, redraw=True),
                        fromcurrent=True,
                        transition=dict(duration=300)
                    )]
                ),
                dict(
                    label='⏸ Pause',
                    method='animate',
                    # [[None]] with mode='immediate' is Plotly's idiom for
                    # interrupting a running animation
                    args=[[None], dict(
                        frame=dict(duration=0, redraw=False),
                        mode='immediate',
                        transition=dict(duration=0)
                    )]
                )
            ]
        )],
        sliders=[dict(
            active=0,
            yanchor='top',
            xanchor='left',
            currentvalue=dict(
                font=dict(size=16),
                prefix='Solution: ',
                visible=True,
                xanchor='center'
            ),
            pad=dict(b=10, t=50),
            len=0.9,
            x=0.05,
            y=0,
            steps=[
                dict(
                    args=[[str(i)], dict(
                        frame=dict(duration=300, redraw=True),
                        mode='immediate',
                        transition=dict(duration=300)
                    )],
                    label=str(i),
                    method='animate'
                )
                for i in range(1, n_points + 1)
            ]
        )]
    )

    return fig


# =============================================================================
# 2. NETWORK FLOW MAP - FOLIUM
# =============================================================================

def create_network_map(
    restaurants_df: pd.DataFrame,
    centers_df: pd.DataFrame,
    neighborhoods_df: pd.DataFrame,
    flows_df: Optional[pd.DataFrame] = None,
    show_flows: bool = True,
    flow_threshold: float = 1000,
    cluster_markers: bool = True
) -> folium.Map:
    """
    Create an interactive Folium map showing the food rescue network.

    Each of these goes into its own toggleable layer: restaurants,
    distribution centers, neighborhoods (only those with ``demand > 0``)
    and, optionally, flow lines. A fixed-position legend box is added to
    the page.

    Parameters
    ----------
    restaurants_df : DataFrame with id, latitude, longitude, supply
    centers_df : DataFrame with id, latitude, longitude
    neighborhoods_df : DataFrame with id, latitude, longitude, demand
    flows_df : Optional DataFrame with from_type, from_id, to_type, to_id,
        flow. A ``from_type`` of ``'restaurant'`` is looked up among
        restaurants; any other value is looked up among centers. A
        ``to_type`` of ``'center'`` is looked up among centers; any other
        value is looked up among neighborhoods. Rows whose endpoints are not
        found are skipped.
    show_flows : Whether to draw flow lines
    flow_threshold : Only flows strictly greater than this are drawn
    cluster_markers : If True, restaurants are grouped in a MarkerCluster
        and sized by supply. If False, they are plain fixed-radius markers.

    Returns
    -------
    Folium Map object
    """

    # Center on NYC
    nyc_center = [40.7128, -73.9560]

    # Create map with dark theme
    m = folium.Map(
        location=nyc_center,
        zoom_start=11,
        tiles='CartoDB dark_matter'
    )

    # Feature groups (one toggleable layer each). The names use real emoji;
    # the previous literals were UTF-8 mis-decoded as Mac Roman
    # (e.g. 'üçΩÔ∏è').
    restaurant_group = folium.FeatureGroup(name='🍽️ Restaurants (Supply)')
    center_group = folium.FeatureGroup(name='📦 Distribution Centers')
    neighborhood_group = folium.FeatureGroup(name='🏘️ Neighborhoods (Demand)')
    flow_group = folium.FeatureGroup(name='🚚 Food Flows')

    def make_popup(title, **kwargs):
        # Numeric values (Python or NumPy ints/floats, but not bools) get
        # thousands separators and no decimals; anything else is shown as-is
        html = f"<b style='font-size:14px'>{title}</b><br>"
        for k, v in kwargs.items():
            is_number = (
                isinstance(v, (int, float, np.integer, np.floating))
                and not isinstance(v, (bool, np.bool_))
            )
            html += f"{k}: {v:,.0f}<br>" if is_number else f"{k}: {v}<br>"
        return folium.Popup(html, max_width=250)

    def coord_lookup(frame):
        # id -> (latitude, longitude)
        return dict(zip(frame['id'], zip(frame['latitude'], frame['longitude'])))

    # Add restaurants: either into a cluster (sized by supply) or directly
    # into the layer (fixed radius)
    if cluster_markers:
        restaurant_target = MarkerCluster(name='Restaurants')
    else:
        restaurant_target = restaurant_group
    for _, row in restaurants_df.iterrows():
        radius = (5 + np.log1p(row['supply']) / 2) if cluster_markers else 4
        folium.CircleMarker(
            location=[row['latitude'], row['longitude']],
            radius=radius,
            color=COLORS['restaurant'],
            fill=True,
            fill_color=COLORS['restaurant'],
            fill_opacity=0.7,
            popup=make_popup(
                f"Restaurant #{int(row['id'])}",
                Supply=row['supply']
            )
        ).add_to(restaurant_target)
    if cluster_markers:
        restaurant_target.add_to(restaurant_group)

    # Add distribution centers
    for _, row in centers_df.iterrows():
        folium.Marker(
            location=[row['latitude'], row['longitude']],
            icon=folium.Icon(color='green', icon='archive', prefix='fa'),
            popup=make_popup(f"Distribution Center #{int(row['id'])}")
        ).add_to(center_group)

    # Add neighborhoods (only those with demand > 0)
    demand_neighborhoods = neighborhoods_df[neighborhoods_df['demand'] > 0]
    max_demand = demand_neighborhoods['demand'].max()

    for _, row in demand_neighborhoods.iterrows():
        # Scale radius by demand relative to the largest demand
        radius = 3 + 15 * (row['demand'] / max_demand)

        folium.CircleMarker(
            location=[row['latitude'], row['longitude']],
            radius=radius,
            color=COLORS['neighborhood'],
            fill=True,
            fill_color=COLORS['neighborhood'],
            fill_opacity=0.5,
            popup=make_popup(
                f"Neighborhood #{int(row['id'])}",
                Demand=row['demand']
            )
        ).add_to(neighborhood_group)

    draw_flows = show_flows and flows_df is not None

    # Add flows if provided
    if draw_flows:
        significant_flows = flows_df[flows_df['flow'] > flow_threshold]
        max_flow = significant_flows['flow'].max() if len(significant_flows) > 0 else 1

        restaurant_coords = coord_lookup(restaurants_df)
        center_coords = coord_lookup(centers_df)
        neighborhood_coords = coord_lookup(neighborhoods_df)

        for _, row in significant_flows.iterrows():
            from_restaurant = row['from_type'] == 'restaurant'
            from_coord = (restaurant_coords if from_restaurant else center_coords).get(row['from_id'])
            to_lookup = center_coords if row['to_type'] == 'center' else neighborhood_coords
            to_coord = to_lookup.get(row['to_id'])

            if from_coord is None or to_coord is None:
                continue

            # Scale line weight by flow
            weight = 1 + 4 * (row['flow'] / max_flow)

            # Yellow for restaurant->center, teal for center->neighborhood
            color = COLORS['flow'] if from_restaurant else COLORS['secondary']

            folium.PolyLine(
                locations=[from_coord, to_coord],
                weight=weight,
                color=color,
                opacity=0.6,
                popup=f"Flow: {row['flow']:,.0f} lbs"
            ).add_to(flow_group)

    # Add layers to map
    restaurant_group.add_to(m)
    center_group.add_to(m)
    neighborhood_group.add_to(m)
    if draw_flows:
        flow_group.add_to(m)

    # Add layer control
    folium.LayerControl().add_to(m)

    # Add fixed title/legend box
    title_html = '''
    <div style="position: fixed; 
                top: 10px; left: 50px; width: 300px;
                background-color: rgba(20, 20, 31, 0.9);
                border: 2px solid #ff6b35;
                border-radius: 8px;
                padding: 10px;
                z-index: 9999;
                font-family: 'Helvetica Neue', sans-serif;">
        <h4 style="color: #f7f7f7; margin: 0 0 5px 0;">NYC Food Rescue Network</h4>
        <p style="color: #8b8b9e; margin: 0; font-size: 11px;">
            🍽️ Orange: Restaurants (Supply)<br>
            📦 Green: Distribution Centers<br>
            🏘️ Purple: Neighborhoods (Demand)<br>
            🚚 Lines: Food Flows
        </p>
    </div>
    '''
    m.get_root().html.add_child(folium.Element(title_html))

    return m


# =============================================================================
# 3. PLOTLY NETWORK MAP (Alternative)
# =============================================================================

def create_plotly_network_map(
    restaurants_df: pd.DataFrame,
    centers_df: pd.DataFrame,
    neighborhoods_df: pd.DataFrame,
    flows_df: Optional[pd.DataFrame] = None,
    flow_threshold: float = 1000
) -> go.Figure:
    """
    Create a Plotly scattermapbox visualization of the network.
    Faster rendering than Folium for large networks.

    Parameters
    ----------
    restaurants_df : DataFrame with id, latitude, longitude, supply
    centers_df : DataFrame with id, latitude, longitude
    neighborhoods_df : DataFrame with id, latitude, longitude, demand
        (only rows with demand > 0 are drawn)
    flows_df : Optional DataFrame with from_type, from_id, to_type, to_id,
        flow. Each flow strictly greater than ``flow_threshold`` whose
        endpoints are found is drawn as its own line trace. Line width
        scales with flow.
    flow_threshold : Flows at or below this value are not drawn

    Returns
    -------
    Plotly Figure on a 'carto-darkmatter' map centred on NYC.
    """

    fig = go.Figure()

    # Add flows first (so they're behind markers)
    if flows_df is not None:
        significant_flows = flows_df[flows_df['flow'] > flow_threshold]
        # Computed once rather than per row (was O(n^2)); the fallback
        # matches the Folium map for an empty selection.
        max_flow = significant_flows['flow'].max() if len(significant_flows) > 0 else 1

        def coord_lookup(frame):
            # id -> (latitude, longitude)
            return dict(zip(frame['id'], zip(frame['latitude'], frame['longitude'])))

        restaurant_coords = coord_lookup(restaurants_df)
        center_coords = coord_lookup(centers_df)
        neighborhood_coords = coord_lookup(neighborhoods_df)

        for _, row in significant_flows.iterrows():
            if row['from_type'] == 'restaurant':
                from_coord = restaurant_coords.get(row['from_id'])
            else:
                from_coord = center_coords.get(row['from_id'])

            if row['to_type'] == 'center':
                to_coord = center_coords.get(row['to_id'])
            else:
                to_coord = neighborhood_coords.get(row['to_id'])

            if from_coord is None or to_coord is None:
                continue

            # One trace per flow because line width cannot vary within a trace
            fig.add_trace(go.Scattermapbox(
                lat=[from_coord[0], to_coord[0]],
                lon=[from_coord[1], to_coord[1]],
                mode='lines',
                line=dict(
                    width=1 + 3 * (row['flow'] / max_flow),
                    color=COLORS['flow']
                ),
                opacity=0.5,
                hoverinfo='skip',
                showlegend=False
            ))

    # Add restaurants (size grows with log supply)
    fig.add_trace(go.Scattermapbox(
        lat=restaurants_df['latitude'],
        lon=restaurants_df['longitude'],
        mode='markers',
        marker=dict(
            size=6 + np.log1p(restaurants_df['supply']) / 2,
            color=COLORS['restaurant'],
            opacity=0.8
        ),
        text=[f"Restaurant {i}<br>Supply: {s:,.0f}" 
              for i, s in zip(restaurants_df['id'], restaurants_df['supply'])],
        hoverinfo='text',
        name='Restaurants'
    ))

    # Add centers
    # NOTE(review): for Scattermapbox, non-'circle' symbols are icons from
    # the map style's sprite and may ignore marker.color; confirm the
    # 'square' markers render as intended on 'carto-darkmatter'.
    fig.add_trace(go.Scattermapbox(
        lat=centers_df['latitude'],
        lon=centers_df['longitude'],
        mode='markers',
        marker=dict(
            size=12,
            color=COLORS['center'],
            symbol='square'
        ),
        text=[f"Center {i}" for i in centers_df['id']],
        hoverinfo='text',
        name='Distribution Centers'
    ))

    # Add neighborhoods with demand (size relative to the largest demand)
    demand_neighborhoods = neighborhoods_df[neighborhoods_df['demand'] > 0]

    fig.add_trace(go.Scattermapbox(
        lat=demand_neighborhoods['latitude'],
        lon=demand_neighborhoods['longitude'],
        mode='markers',
        marker=dict(
            size=5 + 15 * (demand_neighborhoods['demand'] / demand_neighborhoods['demand'].max()),
            color=COLORS['neighborhood'],
            opacity=0.6
        ),
        text=[f"Neighborhood {i}<br>Demand: {d:,.0f}" 
              for i, d in zip(demand_neighborhoods['id'], demand_neighborhoods['demand'])],
        hoverinfo='text',
        name='Neighborhoods'
    ))

    # Layout
    fig.update_layout(
        mapbox=dict(
            style='carto-darkmatter',
            center=dict(lat=40.7128, lon=-73.9560),
            zoom=10
        ),
        title=dict(
            text="<b>NYC Food Rescue Network</b>",
            font=dict(size=20, color=COLORS['text']),
            x=0.5
        ),
        paper_bgcolor=COLORS['background'],
        font=dict(color=COLORS['text']),
        margin=dict(l=0, r=0, t=50, b=0),
        legend=dict(
            bgcolor=COLORS['surface'],
            bordercolor=COLORS['grid'],
            borderwidth=1
        ),
        showlegend=True
    )

    return fig


# =============================================================================
# 4. RESULTS COMPARISON DASHBOARD
# =============================================================================

def create_solution_comparison(
    pareto_df: pd.DataFrame,
    neighborhoods_df: pd.DataFrame,
    solution_allocations: Dict[str, pd.DataFrame]
) -> go.Figure:
    """
    Create a multi-panel comparison of different solutions.

    Panels:

    * top-left: the Pareto frontier (``transport_cost`` vs ``equity_t``)
    * top-right: number of neighborhoods receiving food, per solution
    * bottom-left: percent of total demand delivered, per solution
    * bottom-right: box plot of per-neighborhood unmet demand, per solution

    Parameters
    ----------
    pareto_df : Pareto frontier data
    neighborhoods_df : Neighborhood data with id and demand
    solution_allocations : Dict mapping solution names to DataFrames
                          with columns: neighborhood_id, received.
                          Multiple rows for the same neighborhood are summed.

    Returns
    -------
    Plotly Figure with a 2x2 subplot grid.
    """

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            'Pareto Frontier',
            'Distribution Spread',
            'Demand Coverage',
            'Equity Distribution'
        ),
        specs=[
            [{'type': 'scatter'}, {'type': 'bar'}],
            [{'type': 'bar'}, {'type': 'box'}]
        ]
    )

    colors = [COLORS['primary'], COLORS['secondary'], COLORS['accent']]

    # One row per neighborhood per solution. Without this, duplicate
    # neighborhood_id rows would be counted twice as "served" and would
    # duplicate neighborhoods in the merge below.
    received_by_solution = {
        name: alloc.groupby('neighborhood_id', as_index=False)['received'].sum()
        for name, alloc in solution_allocations.items()
    }

    # 1. Pareto Frontier (top-left)
    df = pareto_df.sort_values('transport_cost')
    fig.add_trace(
        go.Scatter(
            x=df['transport_cost'],
            y=df['equity_t'],
            mode='lines+markers',
            marker=dict(size=10, color=COLORS['primary']),
            line=dict(color=COLORS['primary']),
            name='Pareto Frontier'
        ),
        row=1, col=1
    )

    # 2. Distribution spread - neighborhoods served (received > 0)
    for i, (name, received) in enumerate(received_by_solution.items()):
        served = int((received['received'] > 0).sum())
        fig.add_trace(
            go.Bar(
                x=[name],
                y=[served],
                marker_color=colors[i % len(colors)],
                name=name,
                showlegend=False
            ),
            row=1, col=2
        )

    # 3. Total demand coverage
    total_demand = neighborhoods_df['demand'].sum()
    for i, (name, received) in enumerate(received_by_solution.items()):
        # Undefined when there is no demand. NaN leaves the bar empty
        # instead of plotting an infinite bar from a zero division.
        if total_demand > 0:
            pct = 100 * received['received'].sum() / total_demand
        else:
            pct = float('nan')
        fig.add_trace(
            go.Bar(
                x=[name],
                y=[pct],
                marker_color=colors[i % len(colors)],
                name=name,
                showlegend=False
            ),
            row=2, col=1
        )

    # 4. Equity distribution (box plot of unmet demand; neighborhoods with
    #    no allocation count as receiving 0)
    for i, (name, received) in enumerate(received_by_solution.items()):
        merged = neighborhoods_df.merge(received, left_on='id', right_on='neighborhood_id', how='left')
        merged['received'] = merged['received'].fillna(0)
        merged['unmet'] = merged['demand'] - merged['received']

        fig.add_trace(
            go.Box(
                y=merged['unmet'],
                name=name,
                marker_color=colors[i % len(colors)],
                boxpoints='outliers'
            ),
            row=2, col=2
        )

    # Update layout
    fig.update_layout(
        title=dict(
            text="<b>Solution Comparison Dashboard</b>",
            font=dict(size=22, color=COLORS['text']),
            x=0.5
        ),
        plot_bgcolor=COLORS['surface'],
        paper_bgcolor=COLORS['background'],
        font=dict(color=COLORS['text']),
        showlegend=False,
        height=700
    )

    # Update axes
    fig.update_xaxes(gridcolor=COLORS['grid'], title_font=dict(size=12))
    fig.update_yaxes(gridcolor=COLORS['grid'], title_font=dict(size=12))

    fig.update_xaxes(title_text="Transport Cost", row=1, col=1)
    fig.update_yaxes(title_text="Worst Unmet", row=1, col=1)
    fig.update_yaxes(title_text="Neighborhoods Served", row=1, col=2)
    fig.update_yaxes(title_text="% Demand Met", row=2, col=1)
    fig.update_yaxes(title_text="Unmet Demand", row=2, col=2)

    return fig


# =============================================================================
# 5. DEMAND HEATMAP
# =============================================================================

def create_demand_heatmap(
    neighborhoods_df: pd.DataFrame,
    received_df: Optional[pd.DataFrame] = None
) -> go.Figure:
    """
    Create a heatmap showing demand (and optionally unmet demand) by location.

    Parameters
    ----------
    neighborhoods_df : DataFrame with id, latitude, longitude, demand
    received_df : Optional DataFrame with neighborhood_id, received. When
        given, the heatmap shows ``demand - received``. Neighborhoods without
        a row receive 0, and multiple rows for one neighborhood are summed.

    Returns
    -------
    Plotly density-mapbox Figure centred on NYC.
    """

    if received_df is not None:
        # Aggregate first so the left merge keeps exactly one row per
        # neighborhood.
        received_by_id = (
            received_df.groupby('neighborhood_id', as_index=False)['received'].sum()
        )
        merged = neighborhoods_df.merge(
            received_by_id,
            left_on='id',
            right_on='neighborhood_id',
            how='left'
        )
        merged['received'] = merged['received'].fillna(0)
        merged['unmet'] = merged['demand'] - merged['received']
        # Take coordinates from the same frame as z so points stay aligned
        points = merged
        z_values = merged['unmet']
        colorbar_title = 'Unmet Demand'
        hover_label = 'Unmet Demand'
        title = 'Unmet Food Demand by Neighborhood'
    else:
        points = neighborhoods_df
        z_values = neighborhoods_df['demand']
        colorbar_title = 'Demand (lbs)'
        hover_label = 'Demand'
        title = 'Food Demand by Neighborhood'

    fig = go.Figure(go.Densitymapbox(
        lat=points['latitude'],
        lon=points['longitude'],
        z=z_values,
        radius=20,
        colorscale=[
            [0, 'rgba(78, 205, 196, 0.1)'],
            [0.25, 'rgba(255, 230, 109, 0.4)'],
            [0.5, 'rgba(255, 107, 53, 0.6)'],
            [0.75, 'rgba(168, 85, 247, 0.8)'],
            [1, 'rgba(168, 85, 247, 1)']
        ],
        colorbar=dict(
            title=dict(text=colorbar_title, side='right')
        ),
        # Label matches the quantity actually plotted
        hovertemplate=f"{hover_label}: %{{z:,.0f}}<extra></extra>"
    ))

    fig.update_layout(
        mapbox=dict(
            style='carto-darkmatter',
            center=dict(lat=40.7128, lon=-73.9560),
            zoom=10
        ),
        title=dict(
            text=f"<b>{title}</b>",
            font=dict(size=20, color=COLORS['text']),
            x=0.5
        ),
        paper_bgcolor=COLORS['background'],
        # Light text so the colorbar labels are readable on the dark paper
        font=dict(color=COLORS['text']),
        margin=dict(l=0, r=0, t=50, b=0)
    )

    return fig


# =============================================================================
# 6. UTILITY FUNCTIONS
# =============================================================================

def save_all_visualizations(
    output_dir: str,
    pareto_df: pd.DataFrame,
    restaurants_df: pd.DataFrame,
    centers_df: pd.DataFrame,
    neighborhoods_df: pd.DataFrame,
    flows_df: Optional[pd.DataFrame] = None
):
    """
    Render the standard set of figures and write each one to ``output_dir``.

    Creates ``output_dir`` if needed, then writes, in this order:
    ``pareto_curve.html``, ``pareto_animated.html``, ``network_map.html``
    (Folium), ``network_map_plotly.html`` and ``demand_heatmap.html``.
    A progress line is printed before each figure is built. Returns None.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    def dest(filename: str) -> str:
        # Output paths are formed by joining with '/'
        return f"{output_dir}/{filename}"

    print("Generating Pareto curve...")
    create_pareto_curve(pareto_df).write_html(dest("pareto_curve.html"))

    print("Generating animated Pareto curve...")
    create_animated_pareto(pareto_df).write_html(dest("pareto_animated.html"))

    print("Generating network map (Folium)...")
    create_network_map(
        restaurants_df, centers_df, neighborhoods_df, flows_df
    ).save(dest("network_map.html"))

    print("Generating network map (Plotly)...")
    create_plotly_network_map(
        restaurants_df, centers_df, neighborhoods_df, flows_df
    ).write_html(dest("network_map_plotly.html"))

    print("Generating demand heatmap...")
    create_demand_heatmap(neighborhoods_df).write_html(dest("demand_heatmap.html"))

    print(f"All visualizations saved to {output_dir}/")


# =============================================================================
# DEMO WITH SAMPLE DATA
# =============================================================================

def create_sample_data():
    """
    Build a small synthetic dataset shaped like the Julia notebook's outputs,
    for exercising the visualization functions.

    NumPy's global RNG is seeded with 42 first, so the result is
    reproducible; the global random state is reset as a side effect.

    Returns
    -------
    (pareto_df, restaurants_df, centers_df, neighborhoods_df, flows_df)
        pareto_df        : 19 solutions with w_cost, w_eq, transport_cost,
                           equity_t, total_recv, total_unmet
        restaurants_df   : 50 rows of id, latitude, longitude, supply
        centers_df       : 30 rows of id, latitude, longitude
        neighborhoods_df : 100 rows of id, latitude, longitude, demand
        flows_df         : 33 rows of from_type, from_id, to_type, to_id, flow
                           (19 restaurant->center, then 14 center->neighborhood)
    """
    np.random.seed(42)

    def jitter(mean, spread, count):
        # Normally distributed coordinates scattered around a fixed point
        return mean + np.random.randn(count) * spread

    # --- Pareto results (based on notebook output) ------------------------
    # First sweep: cost weight fixed at 1, equity weight 1..10.
    # Second sweep: equity weight fixed at 1, cost weight 10 down to 2.
    n_solutions = 19
    w_cost = [1.0] * 10 + [float(w) for w in range(10, 1, -1)]
    w_eq = list(map(float, range(1, 11))) + [1.0] * 9
    pareto_df = pd.DataFrame({
        'w_cost': w_cost,
        'w_eq': w_eq,
        'transport_cost': [
            1.5e6, 1.67e6, 2.07e6, 2.28e6, 2.59e6, 2.63e6, 2.63e6, 2.63e6, 2.63e6, 2.63e6,
            0.85e6, 0.85e6, 0.85e6, 0.85e6, 0.86e6, 0.86e6, 0.87e6, 0.88e6, 0.97e6
        ],
        'equity_t': [
            2.88e6, 2.40e6, 2.23e6, 2.17e6, 2.10e6, 2.095e6, 2.095e6, 2.095e6, 2.095e6, 2.095e6,
            3.44e6, 3.44e6, 3.44e6, 3.44e6, 3.37e6, 3.31e6, 3.18e6, 3.14e6, 2.93e6
        ],
        'total_recv': [1.6e7] * n_solutions,
        'total_unmet': [1.36e8] * n_solutions,
    })

    # --- Restaurants (319 in original) -------------------------------------
    n_restaurants = 50
    restaurant_lat = jitter(40.7, 0.08, n_restaurants)
    restaurant_lon = jitter(-73.95, 0.08, n_restaurants)
    restaurant_supply = np.random.exponential(50000, n_restaurants)
    restaurants_df = pd.DataFrame({
        'id': range(1, n_restaurants + 1),
        'latitude': restaurant_lat,
        'longitude': restaurant_lon,
        'supply': restaurant_supply,
    })

    # --- Distribution centers (201 in original) ----------------------------
    n_centers = 30
    center_lat = jitter(40.72, 0.1, n_centers)
    center_lon = jitter(-73.98, 0.08, n_centers)
    centers_df = pd.DataFrame({
        'id': range(1, n_centers + 1),
        'latitude': center_lat,
        'longitude': center_lon,
    })

    # --- Neighborhoods (591 in original) -----------------------------------
    n_neighborhoods = 100
    hood_lat = jitter(40.75, 0.12, n_neighborhoods)
    hood_lon = jitter(-73.9, 0.1, n_neighborhoods)
    hood_demand = np.maximum(0, np.random.exponential(200000, n_neighborhoods))
    neighborhoods_df = pd.DataFrame({
        'id': range(1, n_neighborhoods + 1),
        'latitude': hood_lat,
        'longitude': hood_lon,
        'demand': hood_demand,
    })

    # --- Flows -------------------------------------------------------------
    # Tuple elements evaluate left to right, so each row draws its random
    # destination before its random amount.
    restaurant_to_center = [
        ('restaurant', src, 'center',
         np.random.randint(1, n_centers + 1), np.random.exponential(20000))
        for src in range(1, 20)
    ]
    center_to_neighborhood = [
        ('center', src, 'neighborhood',
         np.random.randint(1, n_neighborhoods + 1), np.random.exponential(30000))
        for src in range(1, 15)
    ]
    flows_df = pd.DataFrame(
        restaurant_to_center + center_to_neighborhood,
        columns=['from_type', 'from_id', 'to_type', 'to_id', 'flow']
    )

    return pareto_df, restaurants_df, centers_df, neighborhoods_df, flows_df

In [4]:
if __name__ == "__main__":
    print("Creating sample visualizations...")

    # Synthetic dataset shared by every demo figure. The names stay at
    # module level so later cells can reuse them.
    (pareto_df, restaurants_df, centers_df,
     neighborhoods_df, flows_df) = create_sample_data()

    # Render each figure to a standalone HTML file
    save_all_visualizations(
        "./viz_output",
        pareto_df,
        restaurants_df,
        centers_df,
        neighborhoods_df,
        flows_df,
    )

    print("\nDone! Open the HTML files in viz_output/ to view.")

Creating sample visualizations...
Generating Pareto curve...
Generating animated Pareto curve...
Generating network map (Folium)...
Generating network map (Plotly)...
Generating demand heatmap...
All visualizations saved to ./viz_output/

Done! Open the HTML files in viz_output/ to view.
