In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import plotly.graph_objects as go
from IPython.display import display, clear_output
import time


In [None]:
# Module-level state shared by the widget callbacks.
# Most recent batch of X draws; empty until "Draw Samples" is clicked.
current_X_samples = np.array([], dtype=float)
# Output area that update_plot() clears and redraws into.
plot_output = widgets.Output()

def sample_X_distribution(dist_type, n_samples, **params):
    """Draw ``n_samples`` values of X, all lying in [0, 1].

    Parameters
    ----------
    dist_type : str
        "Uniform", "Beta", "Gamma", "Exponential" or "Gaussian". Any other
        name falls back to Uniform(0, 1).
    n_samples : int
        Number of draws.
    **params
        Family parameters: Beta uses ``alpha``/``beta``; Gamma uses
        ``shape``/``scale``; Exponential uses ``scale``; Gaussian uses
        ``mean``/``std``. Missing keys take the defaults below.

    Returns
    -------
    numpy.ndarray
        The draws. Gamma, Exponential and Gaussian draws are rescaled and/or
        clipped, so values beyond the range pile up exactly at 0 or 1.

    Uses the global ``np.random`` state.
    """
    if dist_type == "Beta":
        a = params.get('alpha', 2)
        b = params.get('beta', 2)
        return np.random.beta(a, b, n_samples)

    if dist_type == "Gamma":
        k = params.get('shape', 2)
        theta = params.get('scale', 0.5)
        raw = np.random.gamma(k, theta, n_samples)
        # Divide by twice the mean (k * theta) so the mean lands at 0.5, then
        # clip the upper tail onto 1.
        # NOTE(review): the division cancels theta, so the result's
        # distribution depends only on k; confirm this is intended.
        return np.clip(raw / (k * theta * 2), 0, 1)

    if dist_type == "Exponential":
        theta = params.get('scale', 0.5)
        raw = np.random.exponential(theta, n_samples)
        # Mean is mapped to 1/3; the tail beyond 3 * theta is clipped onto 1.
        # NOTE(review): theta cancels here too, so 'scale' has no visible effect.
        return np.clip(raw / (theta * 3), 0, 1)

    if dist_type == "Gaussian":
        mu = params.get('mean', 0.5)
        sigma = params.get('std', 0.2)
        return np.clip(np.random.normal(mu, sigma, n_samples), 0, 1)

    # "Uniform" and any unrecognised name.
    return np.random.uniform(0, 1, n_samples)

def apply_g_function(x, func_type, **params):
    """Apply the transform g to X, giving Y = g(X).

    Parameters
    ----------
    x : float or array-like
        Input value(s); clipped to [0, 1] before g is applied.
    func_type : str
        "Linear", "Piecewise Linear", "Quadratic", "Exponential", "Log" or
        "Root". Any other name returns the clipped input unchanged.
    **params
        Family-specific parameters; missing keys fall back to the defaults
        read below.

    Returns
    -------
    numpy scalar or numpy.ndarray shaped like ``x``.

    Notes
    -----
    "Exponential" and "Log" are normalised so that g(0) = 0 and g(1) = scale.
    Both formulas are 0/0 at base == 1. There the continuous limit
    g(x) = scale * x is returned instead of NaN.
    """
    x = np.clip(x, 0, 1)  # the demo only considers X on [0, 1]

    if func_type == "Linear":
        slope = params.get('slope', 1.0)
        intercept = params.get('intercept', 0.0)
        return slope * x + intercept
    elif func_type == "Piecewise Linear":
        kink = params.get('kink', 0.5)
        slope1 = params.get('slope1', 1.0)
        slope2 = params.get('slope2', 2.0)
        intercept = params.get('intercept', 0.0)
        # First piece for x < kink; the second piece starts at the first
        # piece's value at the kink, so g is continuous there.
        return np.where(x < kink,
                        slope1 * x + intercept,
                        slope1 * kink + intercept + slope2 * (x - kink))
    elif func_type == "Quadratic":
        a = params.get('a', 1.0)
        b = params.get('b', 0.0)
        c = params.get('c', 0.0)
        return a * x**2 + b * x + c
    elif func_type == "Exponential":
        base = params.get('base', np.e)
        scale = params.get('scale', 1.0)
        if np.isclose(base, 1.0):
            # (base**x - 1) / (base - 1) -> x as base -> 1; avoid 0/0 = NaN.
            return scale * x
        return scale * (base ** x - 1) / (base - 1)  # normalised to start at 0
    elif func_type == "Log":
        base = params.get('base', np.e)
        scale = params.get('scale', 1.0)
        if np.isclose(base, 1.0):
            # log(1 + x*(base-1)) / log(base) -> x as base -> 1; avoid 0/0.
            return scale * x
        # Shift and scale so that x = 1 maps to `scale`.
        return scale * np.log(1 + x * (base - 1)) / np.log(base)
    elif func_type == "Root":
        power = params.get('power', 0.5)  # 0.5 = sqrt
        scale = params.get('scale', 1.0)
        return scale * (x ** power)
    else:
        return x

def get_g_function_curve(func_type, x_range, **params):
    """Evaluate y = g(x) on a grid of x values, for drawing the curve.

    Thin wrapper around ``apply_g_function``. ``x_range`` is the grid itself
    (not a (low, high) pair), and like any input it is clipped to [0, 1].
    """
    curve = apply_g_function(x_range, func_type, **params)
    return curve


In [None]:
def determine_batch_size(sample_index):
    """Return how many new samples to reveal in the next animation frame.

    ``sample_index`` is the zero-based number of samples already shown:

    * 0-9   -> 1 per frame
    * 10-29 -> 2 per frame
    * 30-69 -> 4 per frame
    * 70+   -> 8 per frame
    """
    for upper_bound, batch in ((10, 1), (30, 2), (70, 4)):
        if sample_index < upper_bound:
            return batch
    return 8

def _y_axis_range(y_values, include_zero=False, pad_fraction=0.1):
    """Compute a padded (low, high) y-axis range covering ``y_values``.

    Non-finite entries (NaN/inf) are ignored. With no finite values the span
    defaults to [0, 1]. ``include_zero`` widens the span so it contains y = 0.
    The span is then padded on each side by ``pad_fraction`` of its width,
    or by ``pad_fraction`` in absolute terms when the span is flat.
    """
    values = np.asarray(y_values, dtype=float)
    values = values[np.isfinite(values)]
    if values.size:
        low, high = float(values.min()), float(values.max())
    else:
        low, high = 0.0, 1.0
    if include_zero:
        low, high = min(low, 0.0), max(high, 0.0)
    span = high - low
    pad = pad_fraction * span if span > 0 else pad_fraction
    return low - pad, high + pad


def update_plot(X_samples, dist_type, func_type, dist_params, func_params):
    """Redraw the figure: the curve y = g(x) plus the X samples on the x-axis.

    Parameters
    ----------
    X_samples : array-like
        X values to mark at y = 0 (may be empty).
    dist_type, dist_params
        Currently unused; accepted so callers can pass the full UI state.
    func_type : str
        Name of the g(x) family, forwarded to ``get_g_function_curve``.
    func_params : dict
        Keyword parameters for that family.

    The new figure replaces whatever is currently shown in ``plot_output``.
    """
    fig = go.Figure()

    # The function y = g(x) over [0, 1].
    x_curve = np.linspace(0, 1, 200)
    y_curve = get_g_function_curve(func_type, x_curve, **func_params)
    fig.add_trace(go.Scatter(
        x=x_curve,
        y=y_curve,
        mode='lines',
        name=f'g(x) = {func_type}',
        line=dict(color='blue', width=2)
    ))

    # Samples shown as points on the x-axis (y = 0).
    has_samples = len(X_samples) > 0
    if has_samples:
        fig.add_trace(go.Scatter(
            x=X_samples,
            y=np.zeros_like(X_samples),
            mode='markers',
            name='X samples',
            marker=dict(
                size=8,
                color='red',
                line=dict(width=1, color='darkred')
            )
        ))

    # The samples sit at y = 0, so keep that line inside the visible range.
    # Otherwise a curve lying entirely above or below 0 pushes them off-screen.
    y_min, y_max = _y_axis_range(y_curve, include_zero=has_samples)

    fig.update_layout(
        title=f"Change of Density Demo - {len(X_samples)} samples",
        xaxis_title="X",
        yaxis_title="Y = g(X)",
        xaxis=dict(range=[-0.05, 1.05]),
        yaxis=dict(range=[y_min, y_max]),
        width=800,
        height=600,
        showlegend=True,
        legend=dict(x=0.02, y=0.98, yanchor="top")
    )

    with plot_output:
        clear_output(wait=True)
        display(fig)


In [None]:
# X distribution selector.
dist_dropdown = widgets.Dropdown(
    options=['Uniform', 'Beta', 'Gamma', 'Exponential', 'Gaussian'],
    value='Uniform',
    description='X Distribution:',
    style=dict(description_width='initial'),
)

# Beta(alpha, beta) shape parameters.
beta_alpha_slider = widgets.FloatSlider(value=2.0, min=0.5, max=10.0, step=0.1,
                                        description='Beta α:',
                                        style=dict(description_width='initial'))
beta_beta_slider = widgets.FloatSlider(value=2.0, min=0.5, max=10.0, step=0.1,
                                       description='Beta β:',
                                       style=dict(description_width='initial'))

# Gamma(shape, scale) parameters.
gamma_shape_slider = widgets.FloatSlider(value=2.0, min=0.5, max=10.0, step=0.1,
                                         description='Gamma shape:',
                                         style=dict(description_width='initial'))
gamma_scale_slider = widgets.FloatSlider(value=0.5, min=0.1, max=2.0, step=0.1,
                                         description='Gamma scale:',
                                         style=dict(description_width='initial'))

# Exponential scale parameter.
# NOTE(review): the g(x) widget cell later rebinds the name `exp_scale_slider`,
# so code that looks this name up at call time gets that slider, not this one.
exp_scale_slider = widgets.FloatSlider(value=0.5, min=0.1, max=2.0, step=0.1,
                                       description='Exp scale:',
                                       style=dict(description_width='initial'))

# Gaussian mean / standard deviation (the draws are clipped to [0, 1]).
gauss_mean_slider = widgets.FloatSlider(value=0.5, min=0.0, max=1.0, step=0.05,
                                        description='Gauss mean:',
                                        style=dict(description_width='initial'))
gauss_std_slider = widgets.FloatSlider(value=0.2, min=0.05, max=0.5, step=0.05,
                                       description='Gauss std:',
                                       style=dict(description_width='initial'))

# Holds the sliders that apply to the selected distribution; it starts with all
# of them and is narrowed by the visibility callback.
dist_params_box = widgets.VBox(children=[
    beta_alpha_slider, beta_beta_slider,
    gamma_shape_slider, gamma_scale_slider,
    exp_scale_slider,
    gauss_mean_slider, gauss_std_slider,
])

def update_dist_params_visibility(change):
    """Show only the parameter sliders that belong to the chosen X distribution.

    Parameters
    ----------
    change : dict or None
        The traitlets change record when called as a dropdown observer, or
        ``None`` for the one-off initialisation call below.

    On the ``None`` call the plot is not refreshed. At that point in a
    top-to-bottom run, ``update_function_plot`` (defined in a later cell) does
    not exist yet, so calling it would raise NameError and stop "Run All".
    The initial figure is drawn by the layout cell instead.
    """
    # NOTE(review): `exp_scale_slider` is looked up at call time. A later cell
    # rebinds that name to the g(x) "Scale:" slider, so this shows that slider
    # rather than the distribution's "Exp scale:" one -- FIXME.
    sliders_by_dist = {
        'Beta': [beta_alpha_slider, beta_beta_slider],
        'Gamma': [gamma_shape_slider, gamma_scale_slider],
        'Exponential': [exp_scale_slider],
        'Gaussian': [gauss_mean_slider, gauss_std_slider],
    }
    # Uniform (and anything unlisted) has no parameters.
    dist_params_box.children = sliders_by_dist.get(dist_dropdown.value, [])
    if change is not None:
        update_function_plot()


dist_dropdown.observe(update_dist_params_visibility, names='value')
update_dist_params_visibility(None)  # populate the box; the first plot is drawn later


In [None]:
# Transform g(x) selector.
func_dropdown = widgets.Dropdown(
    options=['Linear', 'Piecewise Linear', 'Quadratic', 'Exponential', 'Log', 'Root'],
    value='Linear',
    description='Function g(x):',
    style=dict(description_width='initial'),
)

# Linear: y = slope * x + intercept
linear_slope_slider = widgets.FloatSlider(value=1.0, min=-5.0, max=5.0, step=0.1,
                                          description='Slope:',
                                          style=dict(description_width='initial'))
linear_intercept_slider = widgets.FloatSlider(value=0.0, min=-2.0, max=2.0, step=0.1,
                                              description='Intercept:',
                                              style=dict(description_width='initial'))

# Piecewise linear: slope1 left of the kink, slope2 right of it.
piecewise_kink_slider = widgets.FloatSlider(value=0.5, min=0.1, max=0.9, step=0.05,
                                            description='Kink position:',
                                            style=dict(description_width='initial'))
piecewise_slope1_slider = widgets.FloatSlider(value=1.0, min=-5.0, max=5.0, step=0.1,
                                              description='Slope 1:',
                                              style=dict(description_width='initial'))
piecewise_slope2_slider = widgets.FloatSlider(value=2.0, min=-5.0, max=5.0, step=0.1,
                                              description='Slope 2:',
                                              style=dict(description_width='initial'))
piecewise_intercept_slider = widgets.FloatSlider(value=0.0, min=-2.0, max=2.0, step=0.1,
                                                 description='Intercept:',
                                                 style=dict(description_width='initial'))

# Quadratic: y = a x^2 + b x + c
quadratic_a_slider = widgets.FloatSlider(value=1.0, min=-5.0, max=5.0, step=0.1,
                                         description='a (x²):',
                                         style=dict(description_width='initial'))
quadratic_b_slider = widgets.FloatSlider(value=0.0, min=-5.0, max=5.0, step=0.1,
                                         description='b (x):',
                                         style=dict(description_width='initial'))
quadratic_c_slider = widgets.FloatSlider(value=0.0, min=-2.0, max=2.0, step=0.1,
                                         description='c:',
                                         style=dict(description_width='initial'))

# Exponential g(x): base and output scale.
exp_base_slider = widgets.FloatSlider(value=np.e, min=1.1, max=10.0, step=0.1,
                                      description='Base:',
                                      style=dict(description_width='initial'))
# FIXME: this rebinds the global `exp_scale_slider`, which the X-distribution
# cell already used for its "Exp scale:" slider. Code that reads the name at
# call time (the distribution visibility toggle and get_dist_params) therefore
# gets this g(x) "Scale:" slider instead. Renaming either slider requires
# updating every reader of the name together.
exp_scale_slider = widgets.FloatSlider(value=1.0, min=0.1, max=5.0, step=0.1,
                                       description='Scale:',
                                       style=dict(description_width='initial'))

# Log g(x): base and output scale.
log_base_slider = widgets.FloatSlider(value=np.e, min=1.1, max=10.0, step=0.1,
                                      description='Base:',
                                      style=dict(description_width='initial'))
log_scale_slider = widgets.FloatSlider(value=1.0, min=0.1, max=5.0, step=0.1,
                                       description='Scale:',
                                       style=dict(description_width='initial'))

# Root g(x): y = scale * x**power
root_power_slider = widgets.FloatSlider(value=0.5, min=0.1, max=2.0, step=0.1,
                                        description='Power:',
                                        style=dict(description_width='initial'))
root_scale_slider = widgets.FloatSlider(value=1.0, min=0.1, max=5.0, step=0.1,
                                        description='Scale:',
                                        style=dict(description_width='initial'))

# Holds the sliders for the selected g(x) family; it starts with all of them
# and is narrowed by the visibility callback.
func_params_box = widgets.VBox(children=[
    linear_slope_slider, linear_intercept_slider,
    piecewise_kink_slider, piecewise_slope1_slider,
    piecewise_slope2_slider, piecewise_intercept_slider,
    quadratic_a_slider, quadratic_b_slider, quadratic_c_slider,
    exp_base_slider, exp_scale_slider,
    log_base_slider, log_scale_slider,
    root_power_slider, root_scale_slider,
])

def update_func_params_visibility(change):
    """Show only the parameter sliders that belong to the chosen g(x) family.

    Parameters
    ----------
    change : dict or None
        The traitlets change record when called as a dropdown observer, or
        ``None`` for the one-off initialisation call below.

    On the ``None`` call the plot is not refreshed. At that point in a
    top-to-bottom run, ``update_function_plot`` (defined in a later cell) does
    not exist yet, so calling it would raise NameError and stop "Run All".
    The initial figure is drawn by the layout cell instead.
    """
    sliders_by_func = {
        'Linear': [linear_slope_slider, linear_intercept_slider],
        'Piecewise Linear': [piecewise_kink_slider, piecewise_slope1_slider,
                             piecewise_slope2_slider, piecewise_intercept_slider],
        'Quadratic': [quadratic_a_slider, quadratic_b_slider, quadratic_c_slider],
        'Exponential': [exp_base_slider, exp_scale_slider],
        'Log': [log_base_slider, log_scale_slider],
        'Root': [root_power_slider, root_scale_slider],
    }
    func_params_box.children = sliders_by_func.get(func_dropdown.value, [])
    if change is not None:
        update_function_plot()


func_dropdown.observe(update_func_params_visibility, names='value')
update_func_params_visibility(None)  # populate the box; the first plot is drawn later

# Redraw g(x) whenever any of its parameter sliders moves. The lambda looks up
# update_function_plot only when it fires, after all cells have run.
for slider in [linear_slope_slider, linear_intercept_slider,
               piecewise_kink_slider, piecewise_slope1_slider, piecewise_slope2_slider, piecewise_intercept_slider,
               quadratic_a_slider, quadratic_b_slider, quadratic_c_slider,
               exp_base_slider, exp_scale_slider,
               log_base_slider, log_scale_slider,
               root_power_slider, root_scale_slider]:
    slider.observe(lambda change: update_function_plot(), names='value')


In [None]:
def get_dist_params():
    """Collect the slider values for the currently selected X distribution.

    Returns
    -------
    dict
        Keyword arguments for ``sample_X_distribution``; empty for Uniform
        (or any unlisted name).
    """
    # NOTE(review): `exp_scale_slider` is resolved at call time and, after the
    # g(x) widget cell has run, refers to the g(x) "Scale:" slider -- FIXME.
    readers = {
        'Beta': lambda: {'alpha': beta_alpha_slider.value,
                         'beta': beta_beta_slider.value},
        'Gamma': lambda: {'shape': gamma_shape_slider.value,
                          'scale': gamma_scale_slider.value},
        'Exponential': lambda: {'scale': exp_scale_slider.value},
        'Gaussian': lambda: {'mean': gauss_mean_slider.value,
                             'std': gauss_std_slider.value},
    }
    read = readers.get(dist_dropdown.value)
    return read() if read is not None else {}

def get_func_params():
    """Collect the slider values for the currently selected g(x) family.

    Returns
    -------
    dict
        Keyword arguments for ``apply_g_function`` / ``get_g_function_curve``;
        empty if the dropdown holds a name not listed below.
    """
    readers = {
        'Linear': lambda: {'slope': linear_slope_slider.value,
                           'intercept': linear_intercept_slider.value},
        'Piecewise Linear': lambda: {'kink': piecewise_kink_slider.value,
                                     'slope1': piecewise_slope1_slider.value,
                                     'slope2': piecewise_slope2_slider.value,
                                     'intercept': piecewise_intercept_slider.value},
        'Quadratic': lambda: {'a': quadratic_a_slider.value,
                              'b': quadratic_b_slider.value,
                              'c': quadratic_c_slider.value},
        'Exponential': lambda: {'base': exp_base_slider.value,
                                'scale': exp_scale_slider.value},
        'Log': lambda: {'base': log_base_slider.value,
                        'scale': log_scale_slider.value},
        'Root': lambda: {'power': root_power_slider.value,
                         'scale': root_scale_slider.value},
    }
    read = readers.get(func_dropdown.value)
    return read() if read is not None else {}

def update_function_plot():
    """Redraw g(x) for the current widget state, reusing the stored X samples.

    No new samples are drawn; whatever is in ``current_X_samples`` is shown.
    """
    update_plot(
        current_X_samples,
        dist_dropdown.value,
        func_dropdown.value,
        get_dist_params(),
        get_func_params(),
    )


In [None]:
# How many X values one click of "Draw Samples" generates.
n_samples_slider = widgets.IntSlider(
    value=100, min=10, max=1000, step=10,
    description='Number of samples:',
    style=dict(description_width='initial'),
)

# Starts the progressive sampling animation.
draw_samples_button = widgets.Button(
    description='Draw Samples',
    button_style='primary',
    layout=widgets.Layout(width='200px', height='40px'),
)

# One-line progress / status message shown under the button.
status_html = widgets.HTML(value="Ready to draw samples.")

def _animation_delay(sample_index):
    """Seconds to pause after the frame whose batch starts at ``sample_index``."""
    if sample_index < 10:
        return 0.1
    if sample_index < 30:
        return 0.05
    return 0.01


def on_draw_samples_clicked(button):
    """Draw a fresh set of X samples and reveal them progressively.

    Parameters
    ----------
    button : widgets.Button or None
        Supplied by ``Button.on_click``; not otherwise used.

    All samples are drawn up front and stored in ``current_X_samples``; the
    loop only controls how many are shown per frame (``determine_batch_size``).
    ``draw_samples_button`` is disabled for the duration of the animation, so
    it cannot be clicked again mid-run. It is re-enabled even if drawing raises.
    """
    global current_X_samples

    n_total = n_samples_slider.value
    dist_type = dist_dropdown.value
    func_type = func_dropdown.value
    dist_params = get_dist_params()
    func_params = get_func_params()

    X_all = sample_X_distribution(dist_type, n_total, **dist_params)
    current_X_samples = X_all

    draw_samples_button.disabled = True
    try:
        status_html.value = "Generating samples..."
        sample_index = 0
        while sample_index < n_total:
            end_index = min(sample_index + determine_batch_size(sample_index), n_total)
            update_plot(X_all[:end_index], dist_type, func_type, dist_params, func_params)
            status_html.value = f"Generated {end_index} / {n_total} samples"
            time.sleep(_animation_delay(sample_index))  # small pause for the animation
            sample_index = end_index
        # The last loop frame already shows all n_total samples, so no extra redraw.
        status_html.value = f"Complete! Generated {n_total} samples."
    finally:
        draw_samples_button.disabled = False


draw_samples_button.on_click(on_draw_samples_clicked)


In [None]:
# Controls in two columns: X distribution on the left, g(x) on the right.
left_panel = widgets.VBox([widgets.HTML("<h3>X Distribution</h3>"), dist_dropdown, dist_params_box])
right_panel = widgets.VBox([widgets.HTML("<h3>Function g(x)</h3>"), func_dropdown, func_params_box])
control_panel = widgets.HBox([left_panel, right_panel], layout=widgets.Layout(width='100%'))

# Sample-count slider, trigger button and status line underneath.
bottom_panel = widgets.VBox([n_samples_slider, draw_samples_button, status_html])

# Draw g(x) once (no samples yet) so the plot area is populated on first display.
update_function_plot()

# Controls first, then the plot output area.
display(control_panel, bottom_panel, plot_output)
