In [13]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import plotly.graph_objects as go
from IPython.display import display, clear_output
import time


In [14]:
def generate_samples_linear(n_samples, sigma, seed=None):
    """
    Draw `n_samples` points from the linear model Y = 0.5*X + 2 + Z.

    X is drawn from a normal distribution with mean 1 and standard deviation 1.
    The noise Z is drawn from a zero-mean normal whose standard deviation is
    `sigma`. When `seed` is not None, NumPy's global random state is re-seeded
    first, so repeated calls with the same seed return the same arrays.

    Returns the pair (X, Y) of 1-D arrays, each of length `n_samples`.
    """
    slope, intercept = 0.5, 2.0
    if seed is not None:
        np.random.seed(seed)
    # Draw order (inputs first, then noise) fixes the values a seed produces.
    X = np.random.normal(loc=1.0, scale=1.0, size=n_samples)
    noise = np.random.normal(loc=0.0, scale=sigma, size=n_samples)
    return X, slope * X + intercept + noise

def generate_samples_quadratic(n_samples, sigma, seed=None):
    """
    Draw `n_samples` points from the quadratic model Y = 0.25*X^2 - 1 + Z.

    X is drawn from a normal distribution with mean 1 and standard deviation 1.
    The noise Z is drawn from a zero-mean normal whose standard deviation is
    `sigma`. When `seed` is not None, NumPy's global random state is re-seeded
    first, so repeated calls with the same seed return the same arrays.

    Returns the pair (X, Y) of 1-D arrays, each of length `n_samples`.
    """
    curvature, offset = 0.25, 1.0
    if seed is not None:
        np.random.seed(seed)
    # Draw order (inputs first, then noise) fixes the values a seed produces.
    X = np.random.normal(loc=1.0, scale=1.0, size=n_samples)
    noise = np.random.normal(loc=0.0, scale=sigma, size=n_samples)
    return X, curvature * X**2 - offset + noise

def fit_linear_regression(X, Y, use_x_squared=False):
    """
    Ordinary least-squares fit of Y = a*X_predictor + b.

    X_predictor is X itself, or X**2 when `use_x_squared` is True.

    Parameters
    ----------
    X, Y : array-like of equal length
        Sample inputs and responses (converted to float arrays).
    use_x_squared : bool
        Regress on X**2 instead of X.

    Returns
    -------
    (a, b) : tuple of float
        Slope and intercept. Empty input yields (0.0, 0.0). If the predictor
        has (numerically) no spread, the slope is undefined and
        (0.0, mean(Y)) is returned.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    n = len(X)
    if n == 0:
        return 0.0, 0.0

    X_predictor = X**2 if use_x_squared else X

    # Centered form of the normal equations:
    #   a = sum((x - x_mean) * (y - y_mean)) / sum((x - x_mean)^2)
    #   b = y_mean - a * x_mean
    # Algebraically identical to the raw-sum formula
    #   a = (n*sum(XY) - sum(X)*sum(Y)) / (n*sum(X^2) - sum(X)^2)
    # but it avoids subtracting two huge, nearly equal numbers, which loses
    # all precision when the predictor values sit far from zero.
    x_mean = np.mean(X_predictor)
    y_mean = np.mean(Y)
    dx = X_predictor - x_mean
    s_xx = np.sum(dx * dx)

    # n * s_xx equals the raw-sum denominator, so the degeneracy threshold
    # keeps the same meaning as before.
    if n * s_xx < 1e-10:
        return 0.0, float(y_mean)

    a = np.sum(dx * (Y - y_mean)) / s_xx
    b = y_mean - a * x_mean

    return float(a), float(b)

def compute_mse(X, Y, a, b, use_x_squared=False):
    """
    Mean squared error of the line Y = a*X_predictor + b on the samples.

    X_predictor is X, or X**2 when `use_x_squared` is True.
    Returns a Python float; an empty sample set yields 0.0.
    """
    if len(X) == 0:
        return 0.0

    X_predictor = X**2 if use_x_squared else X
    residuals = Y - (a * X_predictor + b)
    return float(np.mean(residuals**2))

def compute_mse_grid(X, Y, a_range, b_range, use_x_squared=False):
    """
    Evaluate the MSE of Y = a*X_predictor + b on a grid of (a, b) values.

    Parameters
    ----------
    X, Y : array-like of equal length
        Sample inputs and responses.
    a_range, b_range : 1-D sequences
        Candidate slopes (grid columns) and intercepts (grid rows).
    use_x_squared : bool
        Use X**2 instead of X as the predictor.

    Returns
    -------
    A, B : 2-D arrays
        np.meshgrid(a_range, b_range); shape (len(b_range), len(a_range)).
    MSE_grid : 2-D float array
        MSE_grid[i, j] is the MSE for a = a_range[j], b = b_range[i].
        All zeros when there are no samples.
    """
    A, B = np.meshgrid(a_range, b_range)
    # Allocate as float explicitly: zeros_like(A) would inherit an integer
    # dtype from integer ranges and silently truncate every MSE value.
    MSE_grid = np.zeros(A.shape, dtype=float)

    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if len(X) == 0:
        return A, B, MSE_grid

    X_predictor = X**2 if use_x_squared else X
    a_vals = np.asarray(a_range, dtype=float)
    b_vals = np.asarray(b_range, dtype=float)

    # a * X_predictor for every candidate slope at once, shape (len(a), n).
    # It does not depend on b, so it is computed a single time.
    slope_terms = a_vals[:, None] * X_predictor[None, :]
    for i, b_val in enumerate(b_vals):
        residuals = Y[None, :] - (slope_terms + b_val)
        MSE_grid[i, :] = np.mean(residuals**2, axis=1)

    return A, B, MSE_grid

def compute_mse_gradient(X, Y, a, b, use_x_squared=False):
    """
    Gradient of the MSE of Y = a*X_predictor + b with respect to (a, b).

    With residuals r = Y - (a*X_predictor + b) and n samples:
        dMSE/da = -2/n * sum(X_predictor * r)
        dMSE/db = -2/n * sum(r)

    X_predictor is X, or X**2 when `use_x_squared` is True.
    Returns (dMSE/da, dMSE/db) as Python floats; (0.0, 0.0) for empty input.
    """
    if len(X) == 0:
        return 0.0, 0.0

    X_predictor = X**2 if use_x_squared else X
    residuals = Y - (a * X_predictor + b)
    scale = -2.0 / len(X)

    grad_a = scale * np.sum(X_predictor * residuals)
    grad_b = scale * np.sum(residuals)
    return float(grad_a), float(grad_b)

def compute_mse_gradient_field(X, Y, a_range, b_range, use_x_squared=False):
    """
    Evaluate the MSE gradient on a grid of (a, b) values.

    Parameters
    ----------
    X, Y : array-like of equal length
        Sample inputs and responses.
    a_range, b_range : 1-D sequences
        Candidate slopes (grid columns) and intercepts (grid rows).
    use_x_squared : bool
        Use X**2 instead of X as the predictor.

    Returns
    -------
    A, B : 2-D arrays
        np.meshgrid(a_range, b_range); shape (len(b_range), len(a_range)).
    grad_a_grid, grad_b_grid : 2-D float arrays
        dMSE/da and dMSE/db at a = a_range[j], b = b_range[i], stored at
        index [i, j]. All zeros when there are no samples.
    """
    A, B = np.meshgrid(a_range, b_range)
    # Allocate as float explicitly: zeros_like(A) would inherit an integer
    # dtype from integer ranges and silently truncate the gradients.
    grad_a_grid = np.zeros(A.shape, dtype=float)
    grad_b_grid = np.zeros(A.shape, dtype=float)

    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if len(X) == 0:
        return A, B, grad_a_grid, grad_b_grid

    X_predictor = X**2 if use_x_squared else X
    a_vals = np.asarray(a_range, dtype=float)
    b_vals = np.asarray(b_range, dtype=float)
    scale = -2.0 / len(X)

    # a * X_predictor for every candidate slope at once, shape (len(a), n).
    # It does not depend on b, so it is computed a single time.
    slope_terms = a_vals[:, None] * X_predictor[None, :]
    for i, b_val in enumerate(b_vals):
        residuals = Y[None, :] - (slope_terms + b_val)
        # dMSE/da = -2/n * sum(X_predictor * residuals)
        grad_a_grid[i, :] = scale * np.sum(X_predictor[None, :] * residuals, axis=1)
        # dMSE/db = -2/n * sum(residuals)
        grad_b_grid[i, :] = scale * np.sum(residuals, axis=1)

    return A, B, grad_a_grid, grad_b_grid

def add_mse_gradient_field_flat(fig: go.Figure, X, Y, a_range, b_range, use_x_squared, 
                                 z_floor: float, density: int = 12, 
                                 arrow_color: str = "#1f77b4", arrow_length: float = 0.15,
                                 head_length_frac: float = 0.25, head_angle_deg: float = 28.0, 
                                 line_width: int = 4) -> None:
    """
    Draw the MSE descent-direction field as flat arrows on the plane z = z_floor.

    The gradient is evaluated on the (a_range, b_range) grid, thinned to about
    `density` arrows per axis, negated so arrows point toward lower MSE, and
    normalized so every arrow has length `arrow_length`. Two traces are added
    to `fig`: one holding all arrow shafts and one holding all arrowheads,
    with NaN entries separating the individual segments.
    """
    A, B, grad_a_grid, grad_b_grid = compute_mse_gradient_field(X, Y, a_range, b_range, use_x_squared)

    # Keep every k-th grid point per axis so roughly `density` arrows remain.
    n_rows, n_cols = grad_a_grid.shape
    thin = (slice(None, None, max(1, n_rows // density)),
            slice(None, None, max(1, n_cols // density)))
    a_pts = A[thin]
    b_pts = B[thin]

    # The gradient points uphill; its negative is the descent direction.
    descent_a = -grad_a_grid[thin]
    descent_b = -grad_b_grid[thin]

    # Unit directions; the small offset guards against division by zero.
    norms = np.sqrt(descent_a**2 + descent_b**2) + 1e-9
    unit_a = descent_a / norms
    unit_b = descent_b / norms

    head_len = float(arrow_length * head_length_frac)
    half_angle = float(np.deg2rad(head_angle_deg))
    cos_t, sin_t = float(np.cos(half_angle)), float(np.sin(half_angle))

    def rotated(u, v, s):
        # Rotate (u, v) by the head angle; the sign of `s` picks the side.
        return u * cos_t - v * s, u * s + v * cos_t

    gap = np.nan  # a NaN entry breaks the polyline between segments
    shaft_x, shaft_y, shaft_z = [], [], []
    head_x, head_y, head_z = [], [], []

    # ravel() walks the thinned grid row by row.
    for a0, b0, da, db in zip(a_pts.ravel(), b_pts.ravel(), unit_a.ravel(), unit_b.ravel()):
        a0, b0, da, db = float(a0), float(b0), float(da), float(db)
        tip_a = a0 + arrow_length * da
        tip_b = b0 + arrow_length * db

        shaft_x += [a0, tip_a, gap]
        shaft_y += [b0, tip_b, gap]
        shaft_z += [z_floor, z_floor, gap]

        # Two short strokes from the tip, swept back on either side.
        for s in (sin_t, -sin_t):
            back_a, back_b = rotated(da, db, s)
            head_x += [tip_a, tip_a - head_len * back_a, gap]
            head_y += [tip_b, tip_b - head_len * back_b, gap]
            head_z += [z_floor, z_floor, gap]

    stroke = dict(color=arrow_color, width=line_width)
    fig.add_trace(go.Scatter3d(
        x=shaft_x, y=shaft_y, z=shaft_z,
        mode="lines",
        line=stroke,
        name="Gradient field",
        showlegend=True
    ))
    fig.add_trace(go.Scatter3d(
        x=head_x, y=head_y, z=head_z,
        mode="lines",
        line=dict(stroke),
        name="",
        showlegend=False
    ))


In [15]:
# Widgets

# ipywidgets truncates a description that is wider than the default label
# width. This style lets the label take the width its text needs, so long
# descriptions such as "Number of samples:" are shown in full.
LONG_DESCRIPTION_STYLE = {"description_width": "initial"}

# --- Data-generation controls ---
model_dropdown = widgets.Dropdown(
    options=[("Linear", "linear"), 
             ("Quadratic", "quadratic")],
    value="linear",
    description="Model:",
    layout=widgets.Layout(width="350px")
)

n_samples_slider = widgets.IntSlider(
    description="Number of samples:",
    min=1,
    max=200,
    value=50,
    continuous_update=False,
    style=LONG_DESCRIPTION_STYLE,
    layout=widgets.Layout(width="400px")
)

sigma_slider = widgets.FloatSlider(
    description="Noise (σ):",
    min=0.1,
    max=2.0,
    value=0.5,
    step=0.1,
    continuous_update=False,
    readout_format=".2f",
    layout=widgets.Layout(width="400px")
)

sample_button = widgets.Button(
    description="Sample",
    button_style="primary",
    layout=widgets.Layout(width="150px")
)

bootstrap_button = widgets.Button(
    description="Add possible best fit lines",
    button_style="",
    layout=widgets.Layout(width="200px"),
    disabled=True  # Disabled until samples are generated
)

# --- Sliders for the proposed best fit line ---
# Note: For quadratic model, 'a' is the coefficient of X^2
# continuous_update=True: the plot follows the slider while it is dragged.
a_slider = widgets.FloatSlider(
    description="Coefficient (a):",
    min=-1.0,
    max=1.0,
    value=0.0,
    step=0.05,
    continuous_update=True,
    readout_format=".2f",
    style=LONG_DESCRIPTION_STYLE,
    layout=widgets.Layout(width="300px")
)

b_slider = widgets.FloatSlider(
    description="Intercept (b):",
    min=-5.0,
    max=5.0,
    value=0.0,
    step=0.1,
    continuous_update=True,
    readout_format=".2f",
    style=LONG_DESCRIPTION_STYLE,
    layout=widgets.Layout(width="300px")
)

# --- Text read-outs ---
status_html = widgets.HTML(value="Click 'Sample' to generate data points.")
mse_html = widgets.HTML(value="")
gradient_html = widgets.HTML(value="")

# Output widget for the 2D plot
plot_output = widgets.Output(layout=widgets.Layout(width="800px", height="600px"))

# Output widget for the 3D MSE plot
plot3d_output = widgets.Output(layout=widgets.Layout(width="800px", height="600px"))

# --- Buttons for 3D visualization ---
save_mse_button = widgets.Button(
    description="Save MSE",
    button_style="",
    layout=widgets.Layout(width="150px"),
    disabled=True  # Disabled until samples are generated
)

reveal_surface_button = widgets.Button(
    description="Reveal MSE surface",
    button_style="",
    layout=widgets.Layout(width="150px"),
    disabled=True  # Disabled until samples are generated
)

# Checkboxes for 3D visualization options
show_heatmap_levelsets_chk = widgets.Checkbox(value=False, description="Show heatmap and level sets")
show_gradients_chk = widgets.Checkbox(value=False, description="Show gradient field")

# --- Shared notebook state (read and rebound by the callbacks) ---

# Storage for current samples
current_X = np.array([])
current_Y = np.array([])

# Storage for bootstrap lines
bootstrap_lines = []  # List of (a, b) tuples for bootstrap lines
show_bootstrap_lines = False

# Storage for saved MSE points
saved_mse_points = []  # List of (a, b, mse) tuples
show_mse_surface = False


In [16]:
def determine_batch_size(sample_index):
    """
    Return how many samples to add in the next batch.

    `sample_index` is compared against fixed thresholds:
    - below 10: 1 at a time
    - 10 to 29: 2 at a time
    - 30 to 69: 4 at a time
    - 70 and above: 8 at a time
    """
    # (exclusive upper bound, batch size), checked in increasing order.
    for upper_bound, batch in ((10, 1), (30, 2), (70, 4)):
        if sample_index < upper_bound:
            return batch
    return 8

def get_color_from_mse(mse, min_mse=None, max_mse=None):
    """
    Map an MSE value to a color between green (low MSE) and red (high MSE).

    The value is normalized against [min_mse, max_mse] and clamped to [0, 1].
    A missing `min_mse` defaults to 0.0. A missing `max_mse`, or one that
    does not exceed the lower bound, is replaced by max(mse, 1.0). If the
    range is still empty after that, the midpoint color is used.

    Returns a CSS-style string of the form "rgb(r, g, 0)".
    """
    lower = 0.0 if min_mse is None else min_mse
    upper = max_mse
    if upper is None or upper <= lower:
        upper = max(mse, 1.0)

    if upper > lower:
        fraction = (mse - lower) / (upper - lower)
    else:
        fraction = 0.5
    fraction = max(0.0, min(1.0, fraction))  # clamp to [0, 1]

    # Linear blend from green (0, 255, 0) to red (255, 0, 0).
    red = int(255 * fraction)
    green = int(255 * (1 - fraction))
    blue = 0
    return f"rgb({red}, {green}, {blue})"

def update_plot(X_visible, Y_visible, model_type, show_bootstrap=False, proposed_a=None, proposed_b=None):
    """
    Redraw the 2D scatter plot inside `plot_output`.

    Traces are added back to front: the stored bootstrap fits (only when
    `show_bootstrap` is True), the user's proposed line (only when both
    `proposed_a` and `proposed_b` are given), then the sample markers.
    Nothing but the empty axes is drawn when there are no visible samples.

    Side effect: `mse_html` shows the MSE of the proposed line, or is cleared
    when no proposed line is drawn.
    """
    has_samples = len(X_visible) > 0
    use_x_squared = (model_type == "quadratic")

    # Horizontal extent of the curves: data range plus 10% padding per side
    # (or +/-0.5 when all visible X coincide); fixed range with no data.
    if has_samples:
        lo, hi = float(X_visible.min()), float(X_visible.max())
        span = hi - lo
        pad = 0.1 * span if span > 0 else 0.5
        x_min, x_max = lo - pad, hi + pad
    else:
        x_min, x_max = -2.0, 4.0

    x_curve = np.linspace(x_min, x_max, 200)
    basis = x_curve**2 if use_x_squared else x_curve

    def curve(a, b):
        # Model values along x_curve for coefficient a and intercept b.
        return a * basis + b

    fig = go.Figure()

    # Bootstrap fits in light grey, behind everything else. Only the first
    # one carries a legend entry, which stands for the whole family.
    if show_bootstrap and has_samples and len(bootstrap_lines) > 0:
        for idx, (a_bs, b_bs) in enumerate(bootstrap_lines):
            is_first = idx == 0
            fig.add_trace(go.Scatter(
                x=x_curve,
                y=curve(a_bs, b_bs),
                mode='lines',
                line=dict(color='rgba(200, 200, 200, 0.3)', width=1),
                name=f'Bootstrap lines ({len(bootstrap_lines)} total)' if is_first else '',
                showlegend=is_first,
                hoverinfo='skip'
            ))

    # Proposed line (always red) together with its MSE read-out.
    if has_samples and proposed_a is not None and proposed_b is not None:
        mse = compute_mse(X_visible, Y_visible, proposed_a, proposed_b, use_x_squared)
        fig.add_trace(go.Scatter(
            x=x_curve,
            y=curve(proposed_a, proposed_b),
            mode='lines',
            line=dict(color='red', width=3),
            name=f'Your Proposed Best Fit Line: Y = {proposed_a:.2f}*X{"²" if use_x_squared else ""} + {proposed_b:.2f}'
        ))
        mse_html.value = f"<b>MSE: {mse:.4f}</b>"
    else:
        mse_html.value = ""

    # Sample markers on top.
    if has_samples:
        marker_style = dict(
            size=6,
            color='#1f77b4',
            line=dict(width=1, color='DarkSlateGrey')
        )
        fig.add_trace(go.Scatter(
            x=X_visible,
            y=Y_visible,
            mode='markers',
            marker=marker_style,
            name='Samples'
        ))

    fig.update_layout(
        title=f"Least Squares Demo - {len(X_visible)} samples",
        xaxis_title="X",
        yaxis_title="Y",
        width=800,
        height=600,
        showlegend=True,
        legend=dict(x=0.02, y=0.98, yanchor="top")
    )

    # Replace the previous figure; wait=True defers the clear until the new
    # output is ready.
    with plot_output:
        clear_output(wait=True)
        display(fig)

def _mse_floor_level(MSE_grid, surface_shown):
    """
    Return the z-level of the flat "floor" plane used for the heatmap,
    level sets and gradient field.

    While the MSE surface is hidden the floor sits at z = 0. Once it is
    revealed, the floor is placed 10% of the grid's MSE range below the
    grid minimum.
    """
    if not surface_shown:
        return 0.0
    mse_min = float(np.min(MSE_grid))
    return mse_min - 0.1 * (float(np.max(MSE_grid)) - mse_min)


def _add_level_sets(fig, A, B, MSE_grid, z_level, n_levels=8):
    """Draw `n_levels` MSE level sets as grey curves on the plane z = z_level."""
    mse_min_val = float(np.min(MSE_grid))
    mse_max_val = float(np.max(MSE_grid))
    if mse_max_val <= mse_min_val:
        return  # flat grid: there are no distinct levels to draw

    levels = np.linspace(mse_min_val, mse_max_val, n_levels)

    # Matplotlib is used only to trace the contour polylines. `allsegs`
    # (one list of (N, 2) vertex arrays per level) is read instead of
    # `collections`, which recent Matplotlib releases deprecated and removed.
    try:
        contour_fig, contour_ax = plt.subplots()
        try:
            segments_per_level = contour_ax.contour(A, B, MSE_grid, levels=levels).allsegs
        finally:
            # Close the scratch figure so it is neither displayed nor leaked.
            plt.close(contour_fig)
    except Exception:
        # Best effort: fall back to the marker approximation below.
        segments_per_level = None

    if segments_per_level is not None:
        for segments in segments_per_level:
            for vertices in segments:
                vertices = np.asarray(vertices)
                if len(vertices) > 1:
                    fig.add_trace(go.Scatter3d(
                        x=vertices[:, 0],
                        y=vertices[:, 1],
                        z=np.full(len(vertices), z_level),
                        mode='lines',
                        line=dict(color='gray', width=2),
                        showlegend=False,
                        name='',
                        hoverinfo='skip'
                    ))
        return

    # Fallback: mark grid nodes whose MSE lies within 1% of the range of a level.
    threshold = (mse_max_val - mse_min_val) / 100
    for level in levels:
        mask = np.abs(MSE_grid - level) < threshold
        if np.any(mask):
            fig.add_trace(go.Scatter3d(
                x=A[mask].flatten(),
                y=B[mask].flatten(),
                z=np.full(np.sum(mask), z_level),
                mode='markers',
                marker=dict(size=2, color='gray'),
                showlegend=False,
                name='',
                hoverinfo='skip'
            ))


def _add_saved_point_arrows(fig, points, X, Y, use_x_squared,
                            arrow_length=0.30, head_length_frac=0.28, head_angle_deg=26.0):
    """
    Draw a red descent arrow from every saved (a, b, mse) point.

    Each arrow lies in the horizontal plane at the point's own MSE height and
    points along the negative gradient. The default length (0.30) is twice
    the length of the floor-field arrows (0.15).
    """
    head_len = arrow_length * head_length_frac
    theta = float(np.deg2rad(head_angle_deg))
    cos_t, sin_t = float(np.cos(theta)), float(np.sin(theta))

    def rot(u, v, c, s):
        return u * c - v * s, u * s + v * c

    for a_val, b_val, mse_val in points:
        grad_a, grad_b = compute_mse_gradient(X, Y, a_val, b_val, use_x_squared)
        grad_mag = np.sqrt(grad_a**2 + grad_b**2)
        if grad_mag <= 1e-10:
            continue  # at (or numerically at) the minimum: no direction

        # The gradient points uphill; its negative points toward the minimum.
        da = -grad_a / grad_mag
        db = -grad_b / grad_mag
        a1 = a_val + arrow_length * da
        b1 = b_val + arrow_length * db

        # Arrow shaft
        fig.add_trace(go.Scatter3d(
            x=[a_val, a1],
            y=[b_val, b1],
            z=[mse_val, mse_val],
            mode='lines',
            line=dict(color='red', width=6),
            showlegend=False,
            name='',
            hoverinfo='skip'
        ))

        # Arrowhead: two strokes from the tip, swept back on either side.
        ra1, rb1 = rot(da, db, cos_t, sin_t)
        ra2, rb2 = rot(da, db, cos_t, -sin_t)
        fig.add_trace(go.Scatter3d(
            x=[a1, a1 - head_len * ra1, a1, a1 - head_len * ra2],
            y=[b1, b1 - head_len * rb1, b1, b1 - head_len * rb2],
            z=[mse_val, mse_val, mse_val, mse_val],
            mode='lines',
            line=dict(color='red', width=6),
            showlegend=False,
            name='',
            hoverinfo='skip'
        ))


def update_3d_plot():
    """
    Redraw the 3D MSE visualization inside `plot3d_output`.

    Depending on the notebook state and checkboxes, the figure contains the
    MSE surface, a heatmap with level sets on the floor plane, the descent
    field on the floor plane, and the saved (a, b, MSE) points with their
    descent arrows. With no samples, an empty placeholder scene is shown.

    Side effect: `gradient_html` shows the gradient at the most recently
    saved point, or is cleared when there are no saved points.
    """
    if len(current_X) == 0:
        # Show empty plot
        fig = go.Figure()
        fig.update_layout(
            scene=dict(
                xaxis_title="a (coefficient)",
                yaxis_title="b (intercept)",
                zaxis_title="MSE(a,b)"
            ),
            title="MSE Surface - Generate samples first",
            width=800,
            height=600
        )
        with plot3d_output:
            clear_output(wait=True)
            display(fig)
        return

    use_x_squared = (model_dropdown.value == "quadratic")
    show_heatmap = show_heatmap_levelsets_chk.value
    show_gradients = show_gradients_chk.value

    # Parameter window shown in the scene (same limits as the a/b sliders).
    a_min, a_max = -1.0, 1.0
    b_min, b_max = -5.0, 5.0
    a_range = np.linspace(a_min, a_max, 50)
    b_range = np.linspace(b_min, b_max, 50)

    fig = go.Figure()

    # The grid and floor level are shared by the surface, heatmap and
    # gradient field, so they are computed once and only when needed.
    if show_mse_surface or show_heatmap or show_gradients:
        A, B, MSE_grid = compute_mse_grid(current_X, current_Y, a_range, b_range, use_x_squared)
        mse_floor = _mse_floor_level(MSE_grid, show_mse_surface)

    # MSE surface
    if show_mse_surface:
        fig.add_trace(go.Surface(
            x=A,
            y=B,
            z=MSE_grid,
            colorscale="Viridis",
            opacity=0.4,
            showscale=True,
            name="MSE Surface"
        ))

    # Heatmap and level sets on the floor plane
    if show_heatmap:
        # Low opacity so the blue field arrows stay visible on top of it.
        fig.add_trace(go.Surface(
            x=A,
            y=B,
            z=np.full_like(MSE_grid, mse_floor),
            surfacecolor=MSE_grid,
            colorscale="Viridis",
            showscale=False,
            opacity=0.25,
            name="Heatmap"
        ))
        # Level sets are lifted slightly above the heatmap plane.
        _add_level_sets(fig, A, B, MSE_grid, mse_floor + 0.01)

    # Descent field on the floor plane
    if show_gradients:
        add_mse_gradient_field_flat(fig, current_X, current_Y, a_range, b_range, use_x_squared,
                                    mse_floor, density=12, arrow_color="#1f77b4",
                                    arrow_length=0.15, head_length_frac=0.28,
                                    head_angle_deg=26.0, line_width=4)

    # Saved (a, b, MSE) points and, optionally, their descent arrows
    if len(saved_mse_points) > 0:
        fig.add_trace(go.Scatter3d(
            x=[p[0] for p in saved_mse_points],
            y=[p[1] for p in saved_mse_points],
            z=[p[2] for p in saved_mse_points],
            mode='markers',
            marker=dict(
                size=8,
                color='red',
                line=dict(width=2, color='darkred')
            ),
            name='Saved points',
            text=[f"a={p[0]:.2f}, b={p[1]:.2f}, MSE={p[2]:.4f}" for p in saved_mse_points],
            hovertemplate='a=%{x:.2f}<br>b=%{y:.2f}<br>MSE=%{z:.4f}<extra></extra>'
        ))

        if show_gradients:
            _add_saved_point_arrows(fig, saved_mse_points, current_X, current_Y, use_x_squared)

    # Gradient read-out for the last saved point. This sits outside the
    # block above so the text is cleared when no saved points remain.
    if len(saved_mse_points) > 0:
        last_a, last_b, last_mse = saved_mse_points[-1]
        grad_a, grad_b = compute_mse_gradient(current_X, current_Y, last_a, last_b, use_x_squared)
        gradient_html.value = f"<b>Gradient at last point:</b> (∂MSE/∂a, ∂MSE/∂b) = ({grad_a:.4f}, {grad_b:.4f})"
    else:
        gradient_html.value = ""

    fig.update_layout(
        scene=dict(
            xaxis_title="a (coefficient)",
            yaxis_title="b (intercept)",
            zaxis_title="MSE(a,b)",
            xaxis=dict(range=[a_min, a_max]),
            yaxis=dict(range=[b_min, b_max])
        ),
        title="MSE Surface: MSE(a, b)",
        width=800,
        height=600,
        showlegend=True,
        legend=dict(x=0.02, y=0.98, yanchor="top", xanchor="left")
    )

    with plot3d_output:
        clear_output(wait=True)
        display(fig)


In [17]:
def on_sample_clicked(button):
    """
    Handle the Sample button: draw a new dataset and reveal it progressively.

    All samples are generated up front and stored in `current_X`/`current_Y`.
    Anything derived from the previous dataset is discarded: the bootstrap
    lines and the saved (a, b, MSE) points with their gradient read-out.
    The scatter plot is then redrawn in growing batches for an animation
    effect. Afterwards the buttons that need data are enabled and the 3D
    plot is refreshed.
    """
    global current_X, current_Y, bootstrap_lines, show_bootstrap_lines, saved_mse_points

    n_total = n_samples_slider.value
    sigma = sigma_slider.value
    model_type = model_dropdown.value

    # Generate all samples at once
    if model_type == "linear":
        X_all, Y_all = generate_samples_linear(n_total, sigma)
    else:  # quadratic
        X_all, Y_all = generate_samples_quadratic(n_total, sigma)

    # Store all samples
    current_X = X_all
    current_Y = Y_all

    # Clear bootstrap lines when new samples are generated
    bootstrap_lines = []
    show_bootstrap_lines = False

    # Saved MSE values were computed on the previous dataset; keeping them
    # would plot points that no longer lie on the new MSE surface.
    saved_mse_points = []
    gradient_html.value = ""

    # Disable the button while the animation runs so extra clicks do not
    # queue up and replay it; the finally re-enables it even on error.
    sample_button.disabled = True
    try:
        status_html.value = "Generating samples..."

        sample_index = 0
        while sample_index < n_total:
            batch_size = determine_batch_size(sample_index)
            end_index = min(sample_index + batch_size, n_total)

            # Get samples up to current index
            X_visible = X_all[:end_index]
            Y_visible = Y_all[:end_index]

            # Update plot (without bootstrap lines during sampling)
            update_plot(X_visible, Y_visible, model_type,
                        show_bootstrap=False,
                        proposed_a=a_slider.value if len(X_visible) > 0 else None,
                        proposed_b=b_slider.value if len(X_visible) > 0 else None)

            # Update status
            status_html.value = f"Generated {end_index} / {n_total} samples"

            # Animation delay: longest for the first samples, then shorter
            if sample_index < 10:
                time.sleep(0.1)
            elif sample_index < 30:
                time.sleep(0.05)
            else:
                time.sleep(0.01)

            sample_index = end_index
    finally:
        sample_button.disabled = False

    # Final update
    update_plot(current_X, current_Y, model_type,
                show_bootstrap=show_bootstrap_lines,
                proposed_a=a_slider.value if len(current_X) > 0 else None,
                proposed_b=b_slider.value if len(current_X) > 0 else None)
    status_html.value = f"Complete! Generated {n_total} samples."

    # Enable the buttons now that we have samples
    bootstrap_button.disabled = False
    save_mse_button.disabled = False
    reveal_surface_button.disabled = False

    # Refresh the 3D plot for the new dataset
    update_3d_plot()

def on_bootstrap_clicked(button):
    """
    Handle the bootstrap button: fit 100 resampled best fit lines and show them.

    Each bootstrap replicate draws len(current_X) sample indices with
    replacement and fits a least-squares line to that resample. The resulting
    (a, b) pairs replace `bootstrap_lines`, and the 2D plot is redrawn with
    them. Does nothing except set a status message when there are no samples.
    """
    global bootstrap_lines, show_bootstrap_lines

    n_samples = len(current_X)
    if n_samples == 0:
        status_html.value = "Please generate samples first!"
        return

    status_html.value = "Computing bootstrap samples..."

    model_type = model_dropdown.value
    use_x_squared = (model_type == "quadratic")

    def fit_one_resample():
        # Resample with replacement, then fit a line to the resample.
        picks = np.random.choice(n_samples, size=n_samples, replace=True)
        return fit_linear_regression(current_X[picks], current_Y[picks], use_x_squared)

    bootstrap_lines = []
    bootstrap_lines.extend(fit_one_resample() for _ in range(100))
    show_bootstrap_lines = True

    # Redraw with the bootstrap lines behind the current proposed line.
    update_plot(current_X, current_Y, model_type,
                show_bootstrap=True,
                proposed_a=a_slider.value,
                proposed_b=b_slider.value)

    status_html.value = f"Generated {len(bootstrap_lines)} bootstrap best fit lines."

def on_slider_change(change):
    """
    Observer for the a/b sliders: redraw the 2D plot with the new proposed line.

    The `change` payload is ignored and the slider values are read directly.
    Nothing is redrawn while there are no samples.
    """
    if len(current_X) == 0:
        return
    update_plot(current_X, current_Y, model_dropdown.value,
                show_bootstrap=show_bootstrap_lines,
                proposed_a=a_slider.value,
                proposed_b=b_slider.value)

def on_save_mse_clicked(button):
    """
    Handle the Save MSE button: record the current (a, b, MSE) triple.

    The MSE of the proposed line (current slider values) is computed on the
    current samples and appended to `saved_mse_points`. The 3D plot is then
    redrawn. Only a status message is set when there are no samples.
    """
    if len(current_X) == 0:
        status_html.value = "Please generate samples first!"
        return

    a_val, b_val = a_slider.value, b_slider.value
    use_x_squared = (model_dropdown.value == "quadratic")
    mse_val = compute_mse(current_X, current_Y, a_val, b_val, use_x_squared)

    # Appending mutates the shared list in place, so no `global` is needed.
    saved_mse_points.append((a_val, b_val, mse_val))

    update_3d_plot()

    status_html.value = f"Saved point: a={a_val:.2f}, b={b_val:.2f}, MSE={mse_val:.4f}"

def on_reveal_surface_clicked(button):
    """
    Handle the Reveal MSE surface button: switch the surface on and redraw.

    Sets the `show_mse_surface` flag, which update_3d_plot() reads.
    Only a status message is set when there are no samples.
    """
    global show_mse_surface

    if len(current_X) > 0:
        show_mse_surface = True
        update_3d_plot()
        status_html.value = "MSE surface revealed."
    else:
        status_html.value = "Please generate samples first!"

def on_checkbox_change(change):
    """
    Observer for the 3D option checkboxes.

    The `change` payload is ignored: update_3d_plot() reads the checkbox
    values itself, so any toggle simply triggers a full redraw.
    """
    update_3d_plot()

# Wire up the buttons
# Each button gets exactly one click handler, defined earlier in this cell.
# NOTE(review): re-running this cell without re-creating the widgets
# presumably registers the freshly defined handlers a second time — confirm.
sample_button.on_click(on_sample_clicked)
bootstrap_button.on_click(on_bootstrap_clicked)
save_mse_button.on_click(on_save_mse_clicked)
reveal_surface_button.on_click(on_reveal_surface_clicked)

# Wire up sliders
# Both sliders share one observer; it reads the slider values itself, so it
# does not matter which slider fired. Observing only "value" skips other
# trait changes.
a_slider.observe(on_slider_change, names="value")
b_slider.observe(on_slider_change, names="value")

# Wire up checkboxes
# Toggling either 3D option triggers a full redraw of the 3D plot.
show_heatmap_levelsets_chk.observe(on_checkbox_change, names="value")
show_gradients_chk.observe(on_checkbox_change, names="value")


In [18]:
# Build the UI
# Row 1: data-generation controls (model, sample count, noise level)
controls_row = widgets.HBox([
    model_dropdown,
    n_samples_slider,
    sigma_slider
])

# Row 2: sampling / bootstrap actions and the status message
button_row = widgets.HBox([
    sample_button,
    bootstrap_button,
    status_html
])

# Row 3: sliders for the user's proposed line and its MSE read-out
proposed_line_row = widgets.HBox([
    widgets.HTML("<b>Your Proposed Best Fit Line:</b>"),
    a_slider,
    b_slider,
    mse_html
])

# Row 4: actions for the 3D plot and the gradient read-out
mse_3d_row = widgets.HBox([
    save_mse_button,
    reveal_surface_button,
    gradient_html
])

# Row 5: toggles for the optional 3D layers
mse_3d_options_row = widgets.HBox([
    widgets.HTML("<b>3D Visualization Options:</b>"),
    show_heatmap_levelsets_chk,
    show_gradients_chk
])

# Full layout: 2D demo on top, MSE surface section underneath
ui = widgets.VBox([
    controls_row,
    button_row,
    proposed_line_row,
    plot_output,
    widgets.HTML("<hr><h3>MSE Surface Visualization</h3>"),
    mse_3d_row,
    mse_3d_options_row,
    plot3d_output
])

# Display initial empty plots
# Both calls render into their Output widgets (plot_output / plot3d_output),
# so the placeholders are in place before the UI itself is displayed.
update_plot(np.array([]), np.array([]), model_dropdown.value)
update_3d_plot()

# Reset the whole demo when a different model is selected
def on_model_change(change):
    """
    Observer for the model dropdown: discard all state tied to the old model.

    Clears the samples, bootstrap lines, saved MSE points and the
    surface-revealed flag. Disables the buttons that need samples, resets
    both proposed-line sliders to 0 and redraws both plots empty.
    """
    global bootstrap_lines, show_bootstrap_lines, current_X, current_Y, saved_mse_points, show_mse_surface

    # Drop the data first: the slider resets below fire on_slider_change,
    # which is a no-op once the samples are gone.
    current_X, current_Y = np.array([]), np.array([])
    bootstrap_lines, saved_mse_points = [], []
    show_bootstrap_lines = show_mse_surface = False

    # These actions require samples
    for dependent_button in (bootstrap_button, save_mse_button, reveal_surface_button):
        dependent_button.disabled = True

    # The proposed line starts from a = 0, b = 0 for either model
    for slider in (a_slider, b_slider):
        slider.value = 0

    # Always show blank plots after a model change
    update_plot(np.array([]), np.array([]), model_dropdown.value)
    update_3d_plot()

    # Back to the initial messages
    status_html.value = "Click 'Sample' to generate data points."
    gradient_html.value = ""

# Reset the demo whenever the selected model changes
model_dropdown.observe(on_model_change, names="value")

# Render the assembled interface as this cell's output
display(ui)


VBox(children=(HBox(children=(Dropdown(description='Model:', layout=Layout(width='350px'), options=(('Linear',…