In [59]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display, clear_output
import time
from scipy import stats


In [60]:
# Global state shared by the widget callbacks.
current_X_samples = np.array([])   # X samples from the last "Draw Samples" (empty until then)
current_Y_samples = np.array([])   # Y = g(X) samples (empty until "Transform")
show_density = False               # True -> overlay theoretical density curves
plot_output = widgets.Output()     # area that update_plot() renders the figure into

def sample_X_distribution(dist_type, n_samples, rng=None, **params):
    """Draw ``n_samples`` values of X from the named distribution, mapped into [0, 1].

    Parameters
    ----------
    dist_type : str
        One of "Uniform", "Beta", "Gamma", "Exponential", "Gaussian".
        Any other value falls back to Uniform(0, 1).
    n_samples : int
        Number of samples to draw.
    rng : numpy.random.Generator, optional
        Source of randomness. Pass ``np.random.default_rng(seed)`` for
        reproducible draws. When omitted, the legacy global ``np.random``
        state is used, as before.
    **params
        Distribution parameters: ``alpha``/``beta`` (Beta, default 2/2),
        ``shape``/``scale`` (Gamma, default 2/0.5), ``scale`` (Exponential,
        default 0.5), ``mean``/``std`` (Gaussian, default 0.5/0.2).

    Returns
    -------
    numpy.ndarray
        Array of length ``n_samples`` with every value in [0, 1].

    Notes
    -----
    Gamma and Exponential draws are divided by a spread-based factor and then
    clipped. Gaussian draws are clipped as-is. Because this is clipping (not
    truncation), any mass that falls outside [0, 1] piles up exactly at 0 or 1.
    """
    if rng is None:
        # The np.random module exposes the same sampling methods as a Generator.
        rng = np.random

    if dist_type == "Beta":
        return rng.beta(params.get('alpha', 2), params.get('beta', 2), n_samples)

    if dist_type == "Gamma":
        shape = params.get('shape', 2)
        scale = params.get('scale', 0.5)
        samples = rng.gamma(shape, scale, n_samples)
        # Divide by twice the mean (shape * scale) so most mass lands in [0, 1];
        # the remaining tail is clipped to 1.
        return np.clip(samples / (shape * scale * 2), 0, 1)

    if dist_type == "Exponential":
        scale = params.get('scale', 0.5)
        samples = rng.exponential(scale, n_samples)
        # Divide by three times the mean; the ~5% tail beyond 1 is clipped to 1.
        return np.clip(samples / (scale * 3), 0, 1)

    if dist_type == "Gaussian":
        samples = rng.normal(params.get('mean', 0.5), params.get('std', 0.2), n_samples)
        return np.clip(samples, 0, 1)

    # "Uniform" and any unrecognised name.
    return rng.uniform(0, 1, n_samples)

def apply_g_function(x, func_type, **params):
    """Return Y = g(X) for the transform named by ``func_type``.

    ``x`` is clipped to [0, 1] before g is applied. Parameters are read from
    ``params`` with these defaults:

    - "Linear":           slope=1.0, intercept=0.0
    - "Piecewise Linear": kink=0.5, slope1=1.0, slope2=2.0, intercept=0.0
                          (the second piece continues from the value at the kink)
    - "Quadratic":        a=1.0, b=0.0, c=0.0
    - "Exponential":      base=e, scale=1.0  -> g(0) = 0, g(1) = scale
    - "Log":              base=e, scale=1.0  -> g(0) = 0, g(1) = scale
    - "Root":             power=0.5, scale=1.0

    Any other ``func_type`` returns the clipped input unchanged.
    """
    u = np.clip(x, 0, 1)
    get = params.get

    if func_type == "Linear":
        return get('slope', 1.0) * u + get('intercept', 0.0)

    if func_type == "Piecewise Linear":
        kink = get('kink', 0.5)
        s1, s2 = get('slope1', 1.0), get('slope2', 2.0)
        b0 = get('intercept', 0.0)
        below = s1 * u + b0
        above = s1 * kink + b0 + s2 * (u - kink)
        return np.where(u < kink, below, above)

    if func_type == "Quadratic":
        return get('a', 1.0) * u**2 + get('b', 0.0) * u + get('c', 0.0)

    if func_type in ("Exponential", "Log"):
        base, scale = get('base', np.e), get('scale', 1.0)
        if func_type == "Exponential":
            return scale * (base ** u - 1) / (base - 1)
        return scale * np.log(1 + u * (base - 1)) / np.log(base)

    if func_type == "Root":
        return get('scale', 1.0) * (u ** get('power', 0.5))

    return u

def get_g_function_curve(func_type, x_range, **params):
    """Evaluate g over ``x_range`` and return the y values of the curve y = g(x)."""
    curve = apply_g_function(x_range, func_type, **params)
    return curve

def get_g_derivative(x, func_type, **params):
    """Return the analytic derivative g'(x) used by the change-of-variables formula.

    Uses the same ``func_type`` names and parameter defaults as
    ``apply_g_function``. ``x`` is converted to float (so integer input cannot
    truncate a fractional slope) and clipped to [0, 1].

    Notes
    -----
    * "Piecewise Linear" returns ``slope2`` at the kink itself (one-sided derivative).
    * "Root" is evaluated at ``max(x, 1e-6)`` because x**(power-1) diverges at 0
      when power < 1.
    * "Log": g(x) = scale * ln(1 + x(b-1)) / ln(b), so by the chain rule
      g'(x) = scale * (b-1) / (ln(b) * (1 + x(b-1))).
    * Unrecognised names give 1 everywhere (the derivative of the identity fallback).
    """
    x = np.clip(np.asarray(x, dtype=float), 0, 1)
    eps = 1e-6

    if func_type == "Linear":
        return np.full_like(x, params.get('slope', 1.0))
    elif func_type == "Piecewise Linear":
        kink = params.get('kink', 0.5)
        slope1 = params.get('slope1', 1.0)
        slope2 = params.get('slope2', 2.0)
        return np.where(x < kink, slope1, slope2).astype(float)
    elif func_type == "Quadratic":
        a = params.get('a', 1.0)
        b = params.get('b', 0.0)
        return 2 * a * x + b
    elif func_type == "Exponential":
        base = params.get('base', np.e)
        scale = params.get('scale', 1.0)
        return scale * np.log(base) * (base ** x) / (base - 1)
    elif func_type == "Log":
        base = params.get('base', np.e)
        scale = params.get('scale', 1.0)
        # The inner function 1 + x(b-1) contributes the chain-rule factor (b - 1).
        return scale * (base - 1) / (np.log(base) * (1 + x * (base - 1)))
    elif func_type == "Root":
        power = params.get('power', 0.5)
        scale = params.get('scale', 1.0)
        # Avoid division by zero / infinity at x = 0.
        x_safe = np.maximum(x, eps)
        return scale * power * (x_safe ** (power - 1))
    else:
        return np.ones_like(x)

def compute_X_density(x_values, dist_type, **params):
    """Theoretical PDF on [0, 1] of X as produced by ``sample_X_distribution``.

    Parameters
    ----------
    x_values : array-like
        Points at which to evaluate the density. They are converted to float
        and clipped to [0, 1].
    dist_type : str
        Distribution name. Unrecognised names give the Uniform density,
        matching the sampler's fallback.
    **params
        Same parameters and defaults as ``sample_X_distribution``.

    Returns
    -------
    numpy.ndarray
        Density values, same shape as ``x_values``.

    Notes
    -----
    The sampler clips (it does not truncate), so inside (0, 1) the density is
    exactly the rescaled continuous pdf computed here. The mass that was
    clipped becomes point masses at 0 and/or 1, which a pdf curve cannot show.
    """
    x = np.clip(np.asarray(x_values, dtype=float), 0, 1)

    if dist_type == "Beta":
        return stats.beta.pdf(x, params.get('alpha', 2), params.get('beta', 2))

    if dist_type == "Gamma":
        shape = params.get('shape', 2)
        scale = params.get('scale', 0.5)
        # The sampler maps X_orig to X = X_orig / factor,
        # hence f_X(x) = factor * f_orig(factor * x).
        factor = shape * scale * 2
        return factor * stats.gamma.pdf(x * factor, shape, scale=scale)

    if dist_type == "Exponential":
        scale = params.get('scale', 0.5)
        factor = scale * 3  # same rescaling as in the sampler
        return factor * stats.expon.pdf(x * factor, scale=scale)

    if dist_type == "Gaussian":
        return stats.norm.pdf(x, params.get('mean', 0.5), params.get('std', 0.2))

    # "Uniform" and any unrecognised name (the sampler also falls back to Uniform).
    return np.ones_like(x)

def compute_Y_density(y_values, x_range, dist_type, func_type, dist_params, func_params):
    """Theoretical PDF of Y = g(X) via the change-of-variables formula.

    For each y, every preimage {x : g(x) = y} is located on the grid
    ``x_range`` and its contribution is added:

        f_Y(y) = sum over x with g(x) = y of  f_X(x) / |g'(x)|

    Summing over all preimages makes the result correct for non-monotone g
    (e.g. a quadratic whose vertex lies inside [0, 1]), where no single inverse
    exists. Roots between neighbouring grid points are found by linear
    interpolation, so ``x_range`` should be a fine grid covering the support of
    X (e.g. ``np.linspace(0, 1, 200)``).

    Values of y outside the image of g get density 0. Preimages with
    |g'(x)| <= 1e-10 are skipped: flat segments of g create point masses in Y,
    not a density.

    Returns a float array with the same shape as ``y_values``.
    """
    y_values = np.asarray(y_values, dtype=float)
    x_grid = np.asarray(x_range, dtype=float)
    g_grid = apply_g_function(x_grid, func_type, **func_params)
    density = np.zeros(y_values.shape, dtype=float)

    for i, y in np.ndenumerate(y_values):
        d = g_grid - y
        left, right = d[:-1], d[1:]
        # Segment j = [x_j, x_{j+1}] holds a root if g - y changes sign across it
        # or is exactly zero at its left end (prevents double counting at nodes).
        j = np.nonzero((left == 0) | (left * right < 0))[0]
        denom = left[j] - right[j]
        t = np.divide(left[j], denom, out=np.zeros_like(denom), where=denom != 0)
        roots = x_grid[j] + t * (x_grid[j + 1] - x_grid[j])
        if d.size and d[-1] == 0:
            # A root exactly at the last grid point has no segment to its right.
            roots = np.append(roots, x_grid[-1])
        if roots.size == 0:
            continue  # y is outside the image of g

        f_x = np.asarray(compute_X_density(roots, dist_type, **dist_params), dtype=float)
        slope = np.abs(np.asarray(get_g_derivative(roots, func_type, **func_params), dtype=float))
        usable = slope > 1e-10
        density[i] = np.sum(f_x[usable] / slope[usable])

    return density


In [61]:
def determine_batch_size(sample_index):
    """
    Return how many new samples the animation should reveal at ``sample_index``.

    Early samples appear one at a time and later ones in growing batches:
    index < 10 -> 1, index < 30 -> 2, index < 70 -> 4, otherwise 8.
    """
    for limit, size in ((10, 1), (30, 2), (70, 4)):
        if sample_index < limit:
            return size
    return 8

def update_plot(X_samples, dist_type, func_type, dist_params, func_params, Y_samples=None, show_Y=False):
    """Redraw the figure in ``plot_output``: curve y = g(x), samples, and histograms.

    Parameters
    ----------
    X_samples : array-like
        X values to plot. The X histogram is built from all of them.
    dist_type, dist_params :
        Name and parameters of the X distribution. These feed the theoretical
        X density and, through change of variables, the Y density.
    func_type, func_params :
        Name and parameters of the transform g.
    Y_samples : array-like, optional
        Y values to plot when ``show_Y`` is True. Point k is drawn at
        ``(X_samples[k], Y_samples[k])``, so Y must correspond to the first
        ``len(Y_samples)`` X samples.
    show_Y : bool
        If True and ``Y_samples`` is non-empty, use a 2x2 layout: Y histogram
        on the left, main plot top-right, X histogram bottom-right. Otherwise
        show the X samples on the x-axis with only the X histogram below.

    The module-level flag ``show_density`` (only read here) controls whether
    theoretical density curves are overlaid on the histograms.
    """
    has_Y = show_Y and Y_samples is not None and len(Y_samples) > 0

    if has_Y:
        fig = make_subplots(
            rows=2, cols=2,
            column_widths=[0.2, 0.8],  # Y histogram 20%, main plot 80%
            row_heights=[0.7, 0.3],    # main area 70%, X histogram 30%
            horizontal_spacing=0.05,
            vertical_spacing=0.05,
            shared_yaxes='rows',       # Y histogram shares the main plot's y-axis
            shared_xaxes='columns',    # X histogram shares the main plot's x-axis
            # make_subplots assigns titles in row-major order to existing cells
            # only (the None cell is skipped), so three titles cover the three
            # subplots. A fourth entry would be dropped, leaving the X
            # histogram untitled.
            subplot_titles=('Y Distribution', '', 'X Distribution'),
            specs=[[{"type": "bar"}, {"type": "scatter"}],
                   [None, {"type": "bar"}]]
        )
        main_row, main_col = 1, 2
        x_hist_row, x_hist_col = 2, 2
    else:
        # Simple layout: main plot on top, X histogram below.
        fig = make_subplots(
            rows=2, cols=1,
            row_heights=[0.7, 0.3],
            vertical_spacing=0.05,
            shared_xaxes=True,
            subplot_titles=('', 'X Distribution Histogram')
        )
        main_row, main_col = 1, 1
        x_hist_row, x_hist_col = 2, 1

    # --- Main plot: the curve y = g(x) over [0, 1] ---
    x_curve = np.linspace(0, 1, 200)
    y_curve = get_g_function_curve(func_type, x_curve, **func_params)
    fig.add_trace(go.Scatter(
        x=x_curve,
        y=y_curve,
        mode='lines',
        name=f'g(x) = {func_type}',
        line=dict(color='blue', width=2)
    ), row=main_row, col=main_col)

    if has_Y:
        # Y samples sit on the curve at (X_k, Y_k); only the first len(Y) X values have a Y yet.
        fig.add_trace(go.Scatter(
            x=X_samples[:len(Y_samples)],
            y=Y_samples,
            mode='markers',
            name='Y = g(X) samples',
            marker=dict(size=8, color='green', line=dict(width=1, color='darkgreen'))
        ), row=main_row, col=main_col)
    elif len(X_samples) > 0:
        # X samples as points on the x-axis.
        fig.add_trace(go.Scatter(
            x=X_samples,
            y=np.zeros_like(X_samples),
            mode='markers',
            name='X samples',
            marker=dict(size=8, color='red', line=dict(width=1, color='darkred'))
        ), row=main_row, col=main_col)

    # y-range of the main plot (shared with the Y histogram): curve plus any Y
    # samples, padded by 10% (or by 0.1 when the range is degenerate).
    y_min = float(np.min(y_curve))
    y_max = float(np.max(y_curve))
    if has_Y:
        y_min = min(y_min, float(np.min(Y_samples)))
        y_max = max(y_max, float(np.max(Y_samples)))
    y_span = y_max - y_min
    pad = 0.1 * y_span if y_span > 0 else 0.1
    y_min -= pad
    y_max += pad

    n_bins = 30

    # --- X histogram (bottom), normalised as a density ---
    X_density = np.array([])
    if len(X_samples) > 0:
        counts, bin_edges = np.histogram(X_samples, bins=n_bins, range=(0, 1))
        bin_width = bin_edges[1] - bin_edges[0]
        # count / (bin width * number of samples) is directly comparable to a pdf.
        X_density = counts / (bin_width * len(X_samples))
        fig.add_trace(go.Bar(
            x=(bin_edges[:-1] + bin_edges[1:]) / 2,
            y=X_density,
            width=bin_width * 0.9,
            name='X Density',
            marker=dict(color='steelblue', line=dict(color='navy', width=1)),
            showlegend=False
        ), row=x_hist_row, col=x_hist_col)

    # --- Y histogram (left, horizontal bars), only in transformed mode ---
    Y_density = np.array([])
    if has_Y:
        Y_hist_min = float(np.min(Y_samples))
        Y_hist_max = float(np.max(Y_samples))
        Y_hist_range = (Y_hist_min, Y_hist_max)
        if Y_hist_max - Y_hist_min < 0.01:
            # (Nearly) constant Y: widen the range so the bins are usable.
            Y_hist_range = (Y_hist_min - 0.1, Y_hist_max + 0.1)

        counts, bin_edges = np.histogram(Y_samples, bins=n_bins, range=Y_hist_range)
        bin_width = bin_edges[1] - bin_edges[0]
        Y_density = counts / (bin_width * len(Y_samples))
        fig.add_trace(go.Bar(
            x=Y_density,                              # density on the horizontal axis
            y=(bin_edges[:-1] + bin_edges[1:]) / 2,   # Y values, aligned with the main plot
            orientation='h',
            name='Y Density',
            marker=dict(color='darkgreen', line=dict(color='green', width=1)),
            showlegend=False
        ), row=1, col=1)

        if show_density:
            y_density_range = np.linspace(Y_hist_min, Y_hist_max, 200)
            Y_theoretical_density = compute_Y_density(
                y_density_range, np.linspace(0, 1, 200),
                dist_type, func_type, dist_params, func_params)
            fig.add_trace(go.Scatter(
                x=Y_theoretical_density,
                y=y_density_range,
                mode='lines',
                name='Y Theoretical Density',
                line=dict(color='orange', width=2, dash='dash'),
                showlegend=False
            ), row=1, col=1)

    # --- Theoretical X density over the X histogram ---
    if show_density and len(X_samples) > 0:
        x_density_range = np.linspace(0, 1, 200)
        X_theoretical_density = compute_X_density(x_density_range, dist_type, **dist_params)
        fig.add_trace(go.Scatter(
            x=x_density_range,
            y=X_theoretical_density,
            mode='lines',
            name='X Theoretical Density',
            line=dict(color='orange', width=2, dash='dash'),
            showlegend=False
        ), row=x_hist_row, col=x_hist_col)

    # Histogram axis limits: 10% headroom over the tallest bar, or 1.0 if empty.
    X_hist_y_max = float(np.max(X_density)) * 1.1 if len(X_density) > 0 and np.max(X_density) > 0 else 1.0
    Y_hist_x_max = float(np.max(Y_density)) * 1.1 if len(Y_density) > 0 and np.max(Y_density) > 0 else 1.0

    sample_count = len(Y_samples) if (show_Y and Y_samples is not None) else len(X_samples)
    title_suffix = " (Transformed)" if show_Y else ""
    fig.update_layout(
        title=f"Change of Density Demo - {sample_count} samples{title_suffix}",
        height=700,
        width=1000 if has_Y else 800,
        showlegend=True,
        legend=dict(x=0.02, y=0.98, yanchor="top")
    )

    if has_Y:
        # Y histogram (left): its y-range matches the main plot so bars line up with the curve.
        fig.update_xaxes(title_text="Density", range=[0, Y_hist_x_max], row=1, col=1)
        fig.update_yaxes(title_text="Y", range=[y_min, y_max], row=1, col=1)

    fig.update_xaxes(title_text="X", range=[-0.05, 1.05], row=main_row, col=main_col)
    fig.update_yaxes(title_text="Y = g(X)", range=[y_min, y_max], row=main_row, col=main_col)
    fig.update_xaxes(title_text="X", range=[-0.05, 1.05], row=x_hist_row, col=x_hist_col)
    fig.update_yaxes(title_text="Density", range=[0, X_hist_y_max], row=x_hist_row, col=x_hist_col)

    with plot_output:
        clear_output(wait=True)
        display(fig)


In [62]:
# Create distribution dropdown
dist_dropdown = widgets.Dropdown(
    options=['Uniform', 'Beta', 'Gamma', 'Exponential', 'Gaussian'],
    value='Uniform',
    description='X Distribution:',
    style={'description_width': 'initial'}
)

# Distribution parameter controls
beta_alpha_slider = widgets.FloatSlider(
    value=2.0, min=0.5, max=10.0, step=0.1,
    description='Beta α:',
    style={'description_width': 'initial'}
)
beta_beta_slider = widgets.FloatSlider(
    value=2.0, min=0.5, max=10.0, step=0.1,
    description='Beta β:',
    style={'description_width': 'initial'}
)

gamma_shape_slider = widgets.FloatSlider(
    value=2.0, min=0.5, max=10.0, step=0.1,
    description='Gamma shape:',
    style={'description_width': 'initial'}
)
gamma_scale_slider = widgets.FloatSlider(
    value=0.5, min=0.1, max=2.0, step=0.1,
    description='Gamma scale:',
    style={'description_width': 'initial'}
)

# Scale of the Exponential *distribution*. get_dist_params() reads it by this
# name, so no other cell should rebind `exp_scale_slider`.
exp_scale_slider = widgets.FloatSlider(
    value=0.5, min=0.1, max=2.0, step=0.1,
    description='Exp scale:',
    style={'description_width': 'initial'}
)

gauss_mean_slider = widgets.FloatSlider(
    value=0.5, min=0.0, max=1.0, step=0.05,
    description='Gauss mean:',
    style={'description_width': 'initial'}
)
gauss_std_slider = widgets.FloatSlider(
    value=0.2, min=0.05, max=0.5, step=0.05,
    description='Gauss std:',
    style={'description_width': 'initial'}
)

# Container for distribution parameters (its children are replaced below).
dist_params_box = widgets.VBox([
    beta_alpha_slider,
    beta_beta_slider,
    gamma_shape_slider,
    gamma_scale_slider,
    exp_scale_slider,
    gauss_mean_slider,
    gauss_std_slider
])

def _dist_param_widgets(dist_type):
    """Return the parameter sliders to show for ``dist_type`` (Uniform has none)."""
    if dist_type == 'Beta':
        return [beta_alpha_slider, beta_beta_slider]
    if dist_type == 'Gamma':
        return [gamma_shape_slider, gamma_scale_slider]
    if dist_type == 'Exponential':
        return [exp_scale_slider]
    if dist_type == 'Gaussian':
        return [gauss_mean_slider, gauss_std_slider]
    return []

def update_dist_params_visibility(change):
    """Observer for ``dist_dropdown``: show only the selected distribution's sliders, then redraw."""
    dist_params_box.children = _dist_param_widgets(dist_dropdown.value)
    # update_function_plot is defined in a later cell, so skip the redraw until
    # it exists. Unlike `except NameError`, this check does not hide NameErrors
    # raised inside update_function_plot itself.
    if 'update_function_plot' in globals():
        update_function_plot()

dist_dropdown.observe(update_dist_params_visibility, names='value')
# Initialise the visible sliders without redrawing (the layout cell draws the first plot).
dist_params_box.children = _dist_param_widgets(dist_dropdown.value)


In [63]:
# Create function dropdown
func_dropdown = widgets.Dropdown(
    options=['Linear', 'Piecewise Linear', 'Quadratic', 'Exponential', 'Log', 'Root'],
    value='Linear',
    description='Function g(x):',
    style={'description_width': 'initial'}
)

# Function parameter controls
linear_slope_slider = widgets.FloatSlider(
    value=1.0, min=-5.0, max=5.0, step=0.1,
    description='Slope:',
    style={'description_width': 'initial'}
)
linear_intercept_slider = widgets.FloatSlider(
    value=0.0, min=-2.0, max=2.0, step=0.1,
    description='Intercept:',
    style={'description_width': 'initial'}
)

piecewise_kink_slider = widgets.FloatSlider(
    value=0.5, min=0.1, max=0.9, step=0.05,
    description='Kink position:',
    style={'description_width': 'initial'}
)
piecewise_slope1_slider = widgets.FloatSlider(
    value=1.0, min=-5.0, max=5.0, step=0.1,
    description='Slope 1:',
    style={'description_width': 'initial'}
)
piecewise_slope2_slider = widgets.FloatSlider(
    value=2.0, min=-5.0, max=5.0, step=0.1,
    description='Slope 2:',
    style={'description_width': 'initial'}
)
piecewise_intercept_slider = widgets.FloatSlider(
    value=0.0, min=-2.0, max=2.0, step=0.1,
    description='Intercept:',
    style={'description_width': 'initial'}
)

quadratic_a_slider = widgets.FloatSlider(
    value=1.0, min=-5.0, max=5.0, step=0.1,
    description='a (x²):',
    style={'description_width': 'initial'}
)
quadratic_b_slider = widgets.FloatSlider(
    value=0.0, min=-5.0, max=5.0, step=0.1,
    description='b (x):',
    style={'description_width': 'initial'}
)
quadratic_c_slider = widgets.FloatSlider(
    value=0.0, min=-2.0, max=2.0, step=0.1,
    description='c:',
    style={'description_width': 'initial'}
)

exp_base_slider = widgets.FloatSlider(
    value=np.e, min=1.1, max=10.0, step=0.1,
    description='Base:',
    style={'description_width': 'initial'}
)
exp_scale_slider = widgets.FloatSlider(
    value=1.0, min=0.1, max=5.0, step=0.1,
    description='Scale:',
    style={'description_width': 'initial'}
)

log_base_slider = widgets.FloatSlider(
    value=np.e, min=1.1, max=10.0, step=0.1,
    description='Base:',
    style={'description_width': 'initial'}
)
log_scale_slider = widgets.FloatSlider(
    value=1.0, min=0.1, max=5.0, step=0.1,
    description='Scale:',
    style={'description_width': 'initial'}
)

root_power_slider = widgets.FloatSlider(
    value=0.5, min=0.1, max=2.0, step=0.1,
    description='Power:',
    style={'description_width': 'initial'}
)
root_scale_slider = widgets.FloatSlider(
    value=1.0, min=0.1, max=5.0, step=0.1,
    description='Scale:',
    style={'description_width': 'initial'}
)

# Container for function parameters
func_params_box = widgets.VBox([
    linear_slope_slider,
    linear_intercept_slider,
    piecewise_kink_slider,
    piecewise_slope1_slider,
    piecewise_slope2_slider,
    piecewise_intercept_slider,
    quadratic_a_slider,
    quadratic_b_slider,
    quadratic_c_slider,
    exp_base_slider,
    exp_scale_slider,
    log_base_slider,
    log_scale_slider,
    root_power_slider,
    root_scale_slider
])

def update_func_params_visibility(change):
    """Show/hide function parameter controls based on selected function"""
    func_type = func_dropdown.value
    children = []
    
    if func_type == 'Linear':
        children = [linear_slope_slider, linear_intercept_slider]
    elif func_type == 'Piecewise Linear':
        children = [piecewise_kink_slider, piecewise_slope1_slider, 
                   piecewise_slope2_slider, piecewise_intercept_slider]
    elif func_type == 'Quadratic':
        children = [quadratic_a_slider, quadratic_b_slider, quadratic_c_slider]
    elif func_type == 'Exponential':
        children = [exp_base_slider, exp_scale_slider]
    elif func_type == 'Log':
        children = [log_base_slider, log_scale_slider]
    elif func_type == 'Root':
        children = [root_power_slider, root_scale_slider]
    
    func_params_box.children = children
    # Only update plot if the function is already defined
    try:
        update_function_plot()
    except NameError:
        pass  # Function not defined yet, will be called later

func_dropdown.observe(update_func_params_visibility, names='value')
# Initialize visibility but don't update plot yet
func_type = func_dropdown.value
children = []
if func_type == 'Linear':
    children = [linear_slope_slider, linear_intercept_slider]
elif func_type == 'Piecewise Linear':
    children = [piecewise_kink_slider, piecewise_slope1_slider, 
               piecewise_slope2_slider, piecewise_intercept_slider]
elif func_type == 'Quadratic':
    children = [quadratic_a_slider, quadratic_b_slider, quadratic_c_slider]
elif func_type == 'Exponential':
    children = [exp_base_slider, exp_scale_slider]
elif func_type == 'Log':
    children = [log_base_slider, log_scale_slider]
elif func_type == 'Root':
    children = [root_power_slider, root_scale_slider]
func_params_box.children = children

# Make all parameter sliders update the plot when changed
def make_slider_observer(slider):
    """Create an observer function for a slider"""
    def observer(change):
        try:
            update_function_plot()
        except NameError:
            pass
    return observer

for slider in [linear_slope_slider, linear_intercept_slider,
               piecewise_kink_slider, piecewise_slope1_slider, piecewise_slope2_slider, piecewise_intercept_slider,
               quadratic_a_slider, quadratic_b_slider, quadratic_c_slider,
               exp_base_slider, exp_scale_slider,
               log_base_slider, log_scale_slider,
               root_power_slider, root_scale_slider]:
    slider.observe(make_slider_observer(slider), names='value')


In [64]:
# Now define the actual implementations after widgets are created
def get_dist_params():
    """Get current distribution parameters"""
    dist_type = dist_dropdown.value
    params = {}
    
    if dist_type == 'Beta':
        params = {'alpha': beta_alpha_slider.value, 'beta': beta_beta_slider.value}
    elif dist_type == 'Gamma':
        params = {'shape': gamma_shape_slider.value, 'scale': gamma_scale_slider.value}
    elif dist_type == 'Exponential':
        params = {'scale': exp_scale_slider.value}
    elif dist_type == 'Gaussian':
        params = {'mean': gauss_mean_slider.value, 'std': gauss_std_slider.value}
    
    return params

def get_func_params():
    """Get current function parameters"""
    func_type = func_dropdown.value
    params = {}
    
    if func_type == 'Linear':
        params = {'slope': linear_slope_slider.value, 'intercept': linear_intercept_slider.value}
    elif func_type == 'Piecewise Linear':
        params = {'kink': piecewise_kink_slider.value,
                 'slope1': piecewise_slope1_slider.value,
                 'slope2': piecewise_slope2_slider.value,
                 'intercept': piecewise_intercept_slider.value}
    elif func_type == 'Quadratic':
        params = {'a': quadratic_a_slider.value,
                 'b': quadratic_b_slider.value,
                 'c': quadratic_c_slider.value}
    elif func_type == 'Exponential':
        params = {'base': exp_base_slider.value, 'scale': exp_scale_slider.value}
    elif func_type == 'Log':
        params = {'base': log_base_slider.value, 'scale': log_scale_slider.value}
    elif func_type == 'Root':
        params = {'power': root_power_slider.value, 'scale': root_scale_slider.value}
    
    return params

def update_function_plot():
    """Redraw the plot from the current widget values and stored samples (no new sampling).

    If the samples have already been transformed, Y is recomputed as g(X) with
    the *current* function settings. Otherwise, after the user changes g or its
    parameters, the points labelled "Y = g(X)" would sit off the redrawn curve
    and the Y histogram would describe the old transform.
    """
    global current_Y_samples

    dist_type = dist_dropdown.value
    func_type = func_dropdown.value
    dist_params = get_dist_params()
    func_params = get_func_params()

    if len(current_Y_samples) > 0:
        current_Y_samples = apply_g_function(current_X_samples, func_type, **func_params)
        update_plot(current_X_samples, dist_type, func_type,
                    dist_params, func_params, Y_samples=current_Y_samples, show_Y=True)
    else:
        update_plot(current_X_samples, dist_type, func_type,
                    dist_params, func_params)


In [None]:
# Sample-size control, action buttons and status line.
n_samples_slider = widgets.IntSlider(
    description='Number of samples:',
    value=100, min=10, max=1000, step=10,
    style=dict(description_width='initial'),
)

draw_samples_button = widgets.Button(
    description='Draw Samples',
    button_style='primary',
    layout=widgets.Layout(width='200px', height='40px'),
)

transform_button = widgets.Button(
    description='Transform',
    button_style='success',
    layout=widgets.Layout(width='200px', height='40px'),
)

# Starts disabled; on_draw_samples_clicked enables it once samples exist.
show_density_button = widgets.Button(
    description='Show Density',
    button_style='info',
    layout=widgets.Layout(width='200px', height='40px'),
    disabled=True,
)

status_html = widgets.HTML(value="Ready to draw samples.")

def on_draw_samples_clicked(button):
    """Callback for "Draw Samples": draw fresh X samples and reveal them progressively.

    Clears any transformed Y samples and the density overlay, animates the new
    samples onto the plot in growing batches (see determine_batch_size), then
    enables the "Show Density" button.
    """
    global current_X_samples, current_Y_samples, show_density

    total = n_samples_slider.value
    dist_type, func_type = dist_dropdown.value, func_dropdown.value
    dist_params, func_params = get_dist_params(), get_func_params()

    # All samples are drawn up front; the loop below only controls how many are shown.
    drawn = sample_X_distribution(dist_type, total, **dist_params)

    current_X_samples = drawn
    current_Y_samples = np.array([])
    show_density = False
    show_density_button.description = 'Show Density'

    status_html.value = "Generating samples..."

    shown = 0
    while shown < total:
        upto = min(shown + determine_batch_size(shown), total)
        update_plot(drawn[:upto], dist_type, func_type, dist_params, func_params)
        status_html.value = f"Generated {upto} / {total} samples"
        # Linger on the first frames so individual samples are visible.
        time.sleep(0.1 if shown < 10 else (0.05 if shown < 30 else 0.01))
        shown = upto

    update_plot(current_X_samples, dist_type, func_type, dist_params, func_params)
    status_html.value = f"Complete! Generated {total} samples."

    show_density_button.disabled = False

def on_transform_clicked(button):
    """Callback for "Transform": map the stored X samples through g and animate Y = g(X).

    Does nothing except set a status message when no samples have been drawn.
    """
    global current_X_samples, current_Y_samples

    if not len(current_X_samples):
        status_html.value = "Please draw samples first!"
        return

    func_type = func_dropdown.value
    func_params = get_func_params()
    dist_type = dist_dropdown.value
    dist_params = get_dist_params()

    transformed = apply_g_function(current_X_samples, func_type, **func_params)
    current_Y_samples = transformed

    status_html.value = "Transforming samples..."

    total = len(current_X_samples)
    shown = 0
    while shown < total:
        upto = min(shown + determine_batch_size(shown), total)
        update_plot(current_X_samples, dist_type, func_type, dist_params, func_params,
                    Y_samples=transformed[:upto], show_Y=True)
        status_html.value = f"Transformed {upto} / {total} samples"
        time.sleep(0.1 if shown < 10 else (0.05 if shown < 30 else 0.01))
        shown = upto

    update_plot(current_X_samples, dist_type, func_type, dist_params, func_params,
                Y_samples=current_Y_samples, show_Y=True)
    status_html.value = f"Complete! Transformed {total} samples."

def on_show_density_clicked(button):
    """Callback for "Show/Hide Density": flip the overlay flag, relabel the button, redraw."""
    global show_density

    show_density = not show_density

    if show_density:
        label, message = 'Hide Density', "Showing theoretical density functions"
    else:
        label, message = 'Show Density', "Hiding theoretical density functions"
    show_density_button.description = label
    status_html.value = message

    update_function_plot()

# Wire each action button to its callback (registrations are independent).
show_density_button.on_click(on_show_density_clicked)
transform_button.on_click(on_transform_clicked)
draw_samples_button.on_click(on_draw_samples_clicked)


In [66]:
# Assemble the interface: the two parameter columns side by side, the sample
# controls underneath, then the plot area.
left_panel = widgets.VBox(
    [widgets.HTML("<h3>X Distribution</h3>"), dist_dropdown, dist_params_box]
)
right_panel = widgets.VBox(
    [widgets.HTML("<h3>Function g(x)</h3>"), func_dropdown, func_params_box]
)
control_panel = widgets.HBox(
    [left_panel, right_panel],
    layout=widgets.Layout(width='100%'),
)

bottom_panel = widgets.VBox([
    n_samples_slider,
    widgets.HBox([draw_samples_button, transform_button, show_density_button]),
    status_html,
])

# Render the initial (sample-free) plot into plot_output before showing it.
update_function_plot()

display(control_panel, bottom_panel, plot_output)


HBox(children=(VBox(children=(HTML(value='<h3>X Distribution</h3>'), Dropdown(description='X Distribution:', o…

VBox(children=(IntSlider(value=100, description='Number of samples:', max=1000, min=10, step=10, style=SliderS…

Output()