# Explanatory Visualizations with a focus on intra-simulation variation

## Plot of Z (economic) position

In [1]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from scipy.stats import beta
import dash
import dash_bootstrap_components as dbc
from dash import dcc, html
from dash.dependencies import Input, Output

from indirect_pathway.src.model.indirect_effect import generate_stratification_positions
from core.visualization.style import plotly_theme_decorator

plot_height = 600


def _build_stratification_layout():
    """Assemble the page: a controls card on the left, the chart card on the right."""

    def control(caption, slider_id, marks, **slider_range):
        # One caption plus its slider; all sliders share the same bottom margin.
        return [
            html.Label(caption),
            dcc.Slider(id=slider_id, marks=marks, className="mb-4", **slider_range),
        ]

    # Children of the controls card body, in display order.
    card_children = [
        # Histogram controls
        html.H5("Histogram Controls", className="mb-3"),
        *control(
            "Sample Size:", 'sample-size-slider',
            {i: str(i) for i in range(0, 10001, 2000)},
            min=100, max=10000, step=100, value=1000,
        ),
        *control(
            "Proportion Disadvantaged (p):", 'p-slider',
            {i/10: str(i/10) for i in range(0, 11, 2)},
            min=0.01, max=0.99, step=0.01, value=0.3,
        ),

        # Distribution controls
        html.H5("Distribution Parameters", className="mb-3 mt-4"),
        *control(
            "Mean Position (Disadvantaged):", 'mu-disadv-slider',
            {i/10: str(i/10) for i in range(0, 10, 2)},
            min=0.01, max=0.9, step=0.01, value=0.2,
        ),
        *control(
            "Position Gap:", 'z-position-gap-slider',
            {i/10: str(i/10) for i in range(0, 9, 2)},
            min=0, max=0.8, step=0.01, value=0.3,
        ),
        *control(
            "Concentration (Disadvantaged):", 'c-disadv-slider',
            {i: str(i) for i in range(0, 51, 10)},
            min=1, max=50, step=1, value=20,
        ),
        *control(
            "Concentration (Advantaged):", 'c-adv-slider',
            {i: str(i) for i in range(0, 51, 10)},
            min=1, max=50, step=1, value=20,
        ),
    ]

    # Left column - Controls card
    controls_column = dbc.Col([
        dbc.Card([
            dbc.CardHeader("Controls"),
            dbc.CardBody(card_children)
        ])
    ], width=3)

    # Right column - Chart
    chart_column = dbc.Col([
        dbc.Card([
            dbc.CardHeader("Visualization"),
            dbc.CardBody([
                dcc.Graph(
                    id='stratification-plot',
                    style={'height': f'{plot_height}px', 'width': '100%'}
                )
            ])
        ])
    ], width=9)

    return html.Div([
        html.H1("Stratification Position Visualization", className="my-4"),
        # Main row containing controls and chart
        dbc.Row([controls_column, chart_column])
    ])


# Initialize the Dash app with bootstrap theme
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# App layout using bootstrap components
app.layout = _build_stratification_layout()

# Snapshot the height this app's graph container was laid out with. The
# module-level `plot_height` is rebound by a later cell, and callbacks run long
# after that happens, so reading the global inside the callback would size this
# figure for the other app's container.
_stratification_plot_height = plot_height


@app.callback(
    Output('stratification-plot', 'figure'),
    [Input('sample-size-slider', 'value'),
     Input('p-slider', 'value'),
     Input('mu-disadv-slider', 'value'),
     Input('z-position-gap-slider', 'value'),
     Input('c-disadv-slider', 'value'),
     Input('c-adv-slider', 'value')]
)
def update_graph(sample_size, p, mu_disadv, z_position_gap, c_disadv, c_adv):
    """Rebuild the stratification figure whenever any control slider changes.

    Each argument is the current `value` of the slider with the matching id
    and is forwarded unchanged to `generate_stratification_positions`.

    Returns:
        The figure produced by `create_stratification_plot`, used as the
        `figure` property of the 'stratification-plot' graph.
    """
    # Generate positions
    positions = generate_stratification_positions(
        p=p,
        mu_disadv=mu_disadv,
        z_position_gap=z_position_gap,
        c_disadv=c_disadv,
        c_adv=c_adv,
        sample_size=sample_size,
    )

    # NOTE: create_stratification_plot is defined in a later cell of this
    # notebook; that cell must have been executed before this callback fires.
    return create_stratification_plot(positions, height=_stratification_plot_height)

In [2]:
# Run the app and display in notebook
from IPython.display import IFrame

# One definition of the port, shared by the server and the embedded frame, so
# the two cannot drift apart when the port is changed.
stratification_app_port = 8050

app.run(jupyter_mode='external', port=stratification_app_port)
# The IFrame must stay the last expression so the notebook renders it.
IFrame(src=f"http://127.0.0.1:{stratification_app_port}", width="100%", height=plot_height+200)

Dash app running on http://127.0.0.1:8050/


## Plot of Normalized Expected Incarceration Curves

In [75]:
import dash
import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dash import dcc, html
from dash.dependencies import Input, Output
from plotly.subplots import make_subplots
from scipy.stats import beta

from indirect_pathway.src.model.indirect_effect import generate_stratification_positions, calculate_incarceration_rates_normalized
from core.visualization.style import plotly_theme_decorator

# Display names of the two population groups. The plotting functions below use
# each constant both as a trace `name` and as its `legendgroup`, so all traces
# belonging to one group are tied to the same legend group.
DISADV_GROUP = "Disadvantaged"
ADV_GROUP = "Advantaged"

@plotly_theme_decorator
def create_stratification_plot(positions, num_bins=100, **kwargs):
    """Build the overlaid histogram + Beta PDF figure for both groups.

    Args:
        positions: Mapping that provides, for each group suffix (`disadv`,
            `adv`), the sampled values under `positions_<suffix>` and the Beta
            shape parameters under `alpha_<suffix>` / `beta_<suffix>`.
        num_bins: Forwarded to `go.Histogram(nbinsx=...)`.
        **kwargs: Not used in the body; accepted so callers can pass extra
            options such as `height`.
            NOTE(review): presumably consumed by `plotly_theme_decorator` —
            confirm against the decorator.

    Returns:
        A single-subplot figure with the histograms on the primary y-axis
        ("Frequency") and the PDF curves on the secondary y-axis ("Density").
    """
    # Create figure with secondary y-axis
    fig = make_subplots(rows=1, cols=1, specs=[[{"secondary_y": True}]])

    # Grid on which both Beta densities are evaluated.
    x = np.linspace(0, 1, 1000)

    # (legend name, key suffix in `positions`, histogram colour, PDF colour)
    group_styles = (
        (DISADV_GROUP, 'disadv', 'red', 'darkred'),
        (ADV_GROUP, 'adv', 'blue', 'darkblue'),
    )

    for group_name, suffix, hist_color, pdf_color in group_styles:
        fig.add_trace(
            go.Histogram(
                x=positions[f'positions_{suffix}'],
                name=group_name,
                opacity=0.5,
                marker_color=hist_color,
                nbinsx=num_bins,
                # With barmode='overlay' each histogram otherwise chooses its
                # own bins; a shared bingroup forces identical bin edges so the
                # bar heights of the two groups are directly comparable.
                bingroup='z_position',
                showlegend=True,
                legendgroup=group_name
            ),
            secondary_y=False
        )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=beta.pdf(x, positions[f'alpha_{suffix}'], positions[f'beta_{suffix}']),
                mode='lines',
                name=group_name,
                line=dict(color=pdf_color, width=2),
                # The histogram already carries the legend entry for the group.
                showlegend=False,
                legendgroup=group_name
            ),
            secondary_y=True
        )

    # Update layout
    fig.update_layout(
        title='Stratification Position Distributions',
        xaxis_title='Z Position (Economic Status)',
        legend_title='Groups',
        barmode='overlay'
    )

    # Update y-axis labels
    fig.update_yaxes(title_text="Frequency", secondary_y=False)
    fig.update_yaxes(title_text="Density", secondary_y=True)

    return fig

@plotly_theme_decorator
def create_incarceration_rate_plot(rate_data, gamma, target_avg_rate, positions, **kwargs):
    """Plot incarceration rate against position, with marginal panels.

    Layout (2x2 grid, top-left cell empty):
        * bottom-right: theoretical rate curve, per-group position/rate
          markers and the population-average line;
        * bottom-left: per-group box plots of the rates plus the same
          population-average line;
        * top-right: the position histograms / PDF curves produced by
          `create_stratification_plot`.

    Args:
        rate_data: Mapping read for the keys `positions_disadv`,
            `positions_adv`, `rates_disadv`, `rates_adv`, `pop_avg_rate`
            and `norm_factor`.
        gamma: Exponent of the curve `(1 - z) ** gamma`.
        target_avg_rate: Scale factor of the curve; also shown in the title.
        positions: Passed unchanged to `create_stratification_plot`.
        **kwargs: Optional extras. `height` (default 700) sets the figure
            height; other keys are ignored here. Accepting them keeps the
            signature compatible with callers that pass `height=...`, as
            `create_stratification_plot` already does.

    Returns:
        The assembled figure.
    """
    # Create figure with marginal box plots
    fig = make_subplots(
        rows=2, cols=2,
        column_widths=[0.2, 0.8],
        row_heights=[0.2, 0.8],
        specs=[
            [None,
             {"type": "xy", 'secondary_y': True}],
            [{"type": "box"},
             {"type": "scatter"}
             ]
        ],
        shared_xaxes='columns',
        shared_yaxes='rows',
        horizontal_spacing=0.05,
        vertical_spacing=0.05
    )

    pop_avg_rate = rate_data['pop_avg_rate']
    norm_factor = rate_data['norm_factor']

    # (legend name, positions, rates, colour) for each group, in trace order.
    groups = (
        (DISADV_GROUP, rate_data['positions_disadv'], rate_data['rates_disadv'], 'red'),
        (ADV_GROUP, rate_data['positions_adv'], rate_data['rates_adv'], 'blue'),
    )

    # Theoretical curve: target_avg_rate * (1 - z)^gamma * norm_factor
    x_curve = np.linspace(0, 1, 1000)
    y_curve = target_avg_rate * np.power(1 - x_curve, gamma) * norm_factor
    fig.add_trace(
        go.Scatter(
            x=x_curve,
            y=y_curve,
            mode='lines',
            name=f'Incarceration Rate Curve (γ={gamma:.1f})',
            line=dict(color='black', width=2)
        ),
        row=2, col=2
    )

    # Position/rate markers per group
    for group_name, group_positions, group_rates, color in groups:
        fig.add_trace(
            go.Scatter(
                x=group_positions,
                y=group_rates,
                mode='markers',
                name=group_name,
                marker=dict(color=color, size=8, opacity=0.5),
                legendgroup=group_name
            ),
            row=2, col=2
        )

    # Population-average line: once on the scatter panel (with a legend entry)
    # and once on the box-plot panel (legend entry suppressed).
    for col, in_legend in ((2, True), (1, False)):
        fig.add_trace(
            go.Scatter(
                x=[0, 1],
                y=[pop_avg_rate, pop_avg_rate],
                mode='lines',
                name=f'Population Average ({pop_avg_rate:.1f})',
                line=dict(color='green', width=2, dash='dash'),
                showlegend=in_legend
            ),
            row=2, col=col
        )

    # Box plot of the rates per group
    for group_name, _, group_rates, color in groups:
        fig.add_trace(
            go.Box(
                y=group_rates,
                name=group_name,
                marker_color=color,
                boxmean=True,
                legendgroup=group_name,
                showlegend=False
            ),
            row=2, col=1
        )

    # Marginal distribution of positions: reuse the traces of the stand-alone
    # stratification figure. PDF curves (scatter traces) go on the secondary
    # y-axis, histograms on the primary one.
    strat_plot = create_stratification_plot(positions)
    for trace in strat_plot.data:
        # The marker traces above already provide one legend entry per group;
        # leaving the histograms visible would list each group twice.
        trace.showlegend = False
        fig.add_trace(
            trace,
            row=1, col=2,
            secondary_y=(trace.type == 'scatter')
        )

    # Update layout
    fig.update_layout(
        title=f'Incarceration Rate by Stratification Position (γ={gamma:.1f}, Target Avg={target_avg_rate})',
        height=kwargs.get('height', 700),
        boxmode='group',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    # Axis titles and explicit x-ranges are not set here.

    return fig

# Initialize the Dash app with bootstrap theme
app_incarceration_rate = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Set plot height
plot_height = 700


def _incarceration_control_children():
    """Return the children of the controls card body.

    The controls are described as a table of sections; each section is a
    heading followed by (caption, slider id, range settings, tick marks) rows.
    """
    sections = [
        # Position distribution controls
        ("Population Parameters", "mb-3", [
            ("Sample Size:", 'sample-size-slider',
             dict(min=100, max=10000, step=100, value=1000),
             {i: str(i) for i in range(0, 10001, 2000)}),
            ("Proportion Disadvantaged (p):", 'p-slider',
             dict(min=0.01, max=0.99, step=0.01, value=0.3),
             {i/10: str(i/10) for i in range(0, 11, 2)}),
        ]),
        # Distribution controls
        ("Distribution Parameters", "mb-3 mt-4", [
            ("Mean Position (Disadvantaged):", 'mu-disadv-slider',
             dict(min=0.01, max=0.9, step=0.01, value=0.2),
             {i/10: str(i/10) for i in range(0, 10, 2)}),
            ("Position Gap:", 'z-position-gap-slider',
             dict(min=0, max=0.8, step=0.01, value=0.3),
             {i/10: str(i/10) for i in range(0, 9, 2)}),
            ("Concentration (Disadvantaged):", 'c-disadv-slider',
             dict(min=1, max=50, step=1, value=20),
             {i: str(i) for i in range(0, 51, 10)}),
            ("Concentration (Advantaged):", 'c-adv-slider',
             dict(min=1, max=50, step=1, value=20),
             {i: str(i) for i in range(0, 51, 10)}),
        ]),
        # Incarceration rate parameters
        ("Incarceration Rate Parameters", "mb-3 mt-4", [
            ("Shape Parameter (γ):", 'gamma-slider',
             dict(min=0, max=5, step=0.1, value=1),
             {i: str(i) for i in range(0, 6)}),
            ("Target Average Rate (per 100,000):", 'target-rate-slider',
             dict(min=100, max=1000, step=50, value=500),
             {i: str(i) for i in range(100, 1001, 200)}),
        ]),
    ]

    children = []
    for heading, heading_class, slider_rows in sections:
        children.append(html.H5(heading, className=heading_class))
        for caption, slider_id, slider_range, marks in slider_rows:
            children.append(html.Label(caption))
            children.append(
                dcc.Slider(id=slider_id, marks=marks, className="mb-4", **slider_range)
            )
    return children


# App layout using bootstrap components
app_incarceration_rate.layout = html.Div([
    html.H1("Incarceration Rate Visualization", className="my-4"),

    # Main row containing controls and chart
    dbc.Row([
        # Left column - Controls card
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("Controls"),
                dbc.CardBody(_incarceration_control_children())
            ])
        ], width=3),

        # Right column - Statistics on top, chart below
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("Statistics"),
                dbc.CardBody(id='stats-container')
            ], className="mb-4"),

            dbc.Card([
                dbc.CardHeader("Visualization"),
                dbc.CardBody([
                    dcc.Graph(
                        id='incarceration-plot',
                        style={'height': f'{plot_height}px', 'width': '100%'}
                    )
                ])
            ])
        ], width=9)
    ])
])

@app_incarceration_rate.callback(
    [Output('incarceration-plot', 'figure'),
     Output('stats-container', 'children')],
    [Input(slider_id, 'value') for slider_id in (
        'sample-size-slider',
        'p-slider',
        'mu-disadv-slider',
        'z-position-gap-slider',
        'c-disadv-slider',
        'c-adv-slider',
        'gamma-slider',
        'target-rate-slider',
    )]
)
def update_graph(sample_size, p, mu_disadv, z_position_gap, c_disadv, c_adv, gamma, target_avg_rate):
    """Recompute the figure and the statistics panel from the slider values.

    Returns:
        A `(figure, stats)` pair: the incarceration-rate figure and the
        component tree shown inside the 'stats-container' card body.
    """

    def stat_line(label, text):
        # One paragraph: bold label followed by the formatted value.
        return html.P([html.Strong(label), text])

    # Sample positions, then derive the rates from them.
    positions = generate_stratification_positions(
        p=p,
        mu_disadv=mu_disadv,
        z_position_gap=z_position_gap,
        c_disadv=c_disadv,
        c_adv=c_adv,
        sample_size=sample_size,
    )
    rate_data = calculate_incarceration_rates_normalized(
        positions=positions,
        gamma=gamma,
        target_avg_rate=target_avg_rate
    )

    fig = create_incarceration_rate_plot(
        rate_data=rate_data,
        gamma=gamma,
        target_avg_rate=target_avg_rate,
        positions=positions,
        height=plot_height
    )

    # Summary figures displayed above the chart.
    rate_disadv = rate_data['rate_disadv']
    rate_adv = rate_data['rate_adv']
    disparity_ratio = rate_disadv / rate_adv
    disparity_index = (disparity_ratio - 1) / (disparity_ratio + (1-p)/p)

    left_column = dbc.Col([
        stat_line("Disadvantaged Group Rate: ", f"{rate_disadv:.1f} per 100,000"),
        stat_line("Advantaged Group Rate: ", f"{rate_adv:.1f} per 100,000"),
    ], width=4)
    middle_column = dbc.Col([
        stat_line("Population Average Rate: ", f"{rate_data['pop_avg_rate']:.1f} per 100,000"),
        stat_line("Disparity Ratio: ", f"{disparity_ratio:.2f}"),
    ], width=4)
    right_column = dbc.Col([
        stat_line("Disparity Difference: ", f"{rate_disadv - rate_adv:.1f} per 100,000"),
        stat_line("Normalized Disparity Index (η): ", f"{disparity_index:.3f}"),
    ], width=4)

    stats = html.Div([
        dbc.Row([left_column, middle_column, right_column])
    ])

    return fig, stats

In [76]:
# Run the app and display in notebook
from IPython.display import IFrame

# One definition of the port, shared by the server and the embedded frame, so
# the two cannot drift apart when the port is changed.
incarceration_app_port = 8051

app_incarceration_rate.run(jupyter_mode='external', port=incarceration_app_port)
# The IFrame must stay the last expression so the notebook renders it.
IFrame(src=f"http://127.0.0.1:{incarceration_app_port}", width="100%", height=plot_height+500)

Dash app running on http://127.0.0.1:8051/


## Combined 