# Interactive Scenario Discovery Explorer: PRIM, PCA-PRIM, and CART

This notebook provides an interactive comparison of three scenario discovery methods (**PRIM**, **PCA-PRIM**, and **CART**) on a synthetic 2D dataset whose region of interest is a quadrilateral. It is designed for demonstration, experimentation, and diagnostics, with real-time parameter tuning and side-by-side visual analysis.

### Algorithms Compared

- **PRIM** – Generates axis-aligned boxes by iteratively peeling away slices of the data where cases of interest are sparse.
- **PCA-PRIM** – Applies PRIM after a PCA-based rotation to align with tilted structures.
- **CART** – Builds a binary classification tree using axis-parallel splits.

Oblique decision trees and advanced methods are **not included** in this notebook.
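
To make the PRIM description above more concrete, here is a minimal sketch of the peeling loop in plain Python. It is not the implementation used by this notebook's helper package: the function name `peel_once`, the toy labels, and the stopping rule are assumptions chosen for illustration. At each step, the sketch tries removing a fixed fraction of the in-box points from each face of the box and keeps the candidate with the highest density of cases of interest.

```python
import random

def peel_once(points, labels, box, peel_frac):
    """Peel peel_frac of the in-box points off each face in turn; keep the best.

    box is [[x_lo, x_hi], [y_lo, y_hi]]. Returns (new_box, density, n_kept),
    or None when the box holds too few points to peel.
    """
    inside = [(p, y) for p, y in zip(points, labels)
              if all(box[d][0] <= p[d] <= box[d][1] for d in range(2))]
    best = None
    for d in range(2):
        vals = sorted(p[d] for p, _ in inside)
        k = max(1, int(peel_frac * len(vals)))  # points removed by one peel
        if k >= len(vals):
            continue
        # Two candidates per dimension: raise the lower face or lower the upper face.
        for side, bound in ((0, vals[k]), (1, vals[-k - 1])):
            cand = [list(b) for b in box]
            cand[d][side] = bound
            kept = [y for p, y in inside if cand[d][0] <= p[d] <= cand[d][1]]
            density = sum(kept) / len(kept)
            if best is None or density > best[1]:
                best = (cand, density, len(kept))
    return best

# Toy data (an assumption for this demo): the cases of interest sit in the
# upper-right corner of the unit square, with no label noise.
rng = random.Random(0)
points = [(rng.random(), rng.random()) for _ in range(2000)]
labels = [1 if x > 0.6 and y > 0.6 else 0 for x, y in points]

box = [[0.0, 1.0], [0.0, 1.0]]
peel_frac, mass_min = 0.10, 0.05
trajectory = []  # (mass, density) of each accepted box
while True:
    step = peel_once(points, labels, box, peel_frac)
    if step is None:
        break
    new_box, density, n_kept = step
    mass = n_kept / len(points)
    if mass < mass_min:  # stop before the box becomes too small
        break
    box = new_box
    trajectory.append((mass, density))

for mass, density in trajectory[::5]:
    print(f"mass={mass:.2f}  density={density:.2f}")
```

The printed trajectory shows the trade-off the notebook visualizes: as the box shrinks (lower mass), the share of cases of interest inside it (density) rises. PCA-PRIM would run the same loop on coordinates rotated by PCA, which this sketch omits.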

### Slider Descriptions

The sliders control data generation and the behavior of the scenario discovery algorithms. They set the shape of the region, the level of label noise, and the algorithmic thresholds:

- **Num Dots**: Sets the total number of points in the unit square \([0, 1]^2\). Higher values improve resolution but increase runtime.

- **Corner X/Y (1–4)**: Sets the four vertices of the quadrilateral that defines the ground-truth region. Adjusting them changes the region's rotation, skew, or size, which is useful for testing how sensitive each method is to shape orientation.

- **Frac Inside**: Sets the probability that points inside the ground truth shape are labeled as of interest (class 1). Lower values increase ambiguity within the region, introducing internal label uncertainty that makes it harder for methods to form compact, high-purity boxes.

- **Frac Outside**: Sets the probability that points outside the shape are labeled as of interest (class 1). Higher values increase external label ambiguity, forcing methods to differentiate relevant and irrelevant areas in the presence of scattered positive labels beyond the true region.

- **Peel Frac**: Sets the fraction of points that PRIM removes at each peeling step. Larger values reach small boxes in fewer steps but may overshoot the region boundary; smaller values produce a more detailed peeling trajectory with more candidate boxes to choose from.

- **PRIM Mass Min**: Sets the minimum fraction of data that a PRIM box must contain. Higher values stop the peeling earlier, so the final box is larger.

- **CART Mass Min**: Sets the minimum mass per leaf in the CART tree. Smaller values allow more complex trees that may overfit; higher values yield simpler trees.
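
As a rough illustration of how **Num Dots**, the corner sliders, **Frac Inside**, and **Frac Outside** could combine, the sketch below generates labeled points in plain Python. It assumes points are drawn uniformly from the unit square, which the notebook's helper package may or may not do; `point_in_polygon` and `generate_labeled_points` are hypothetical names, not functions from `A_scenario_methods_demo`. The inside test uses ray casting, so it also handles non-convex shapes that the corner sliders can produce.

```python
import random

def point_in_polygon(x, y, vertices):
    """Ray-casting test: count polygon edges crossed by a ray going right from (x, y)."""
    inside = False
    n = len(vertices)
    for i in range(n):
        x1, y1 = vertices[i]
        x2, y2 = vertices[(i + 1) % n]
        if (y1 > y) != (y2 > y):  # edge straddles the horizontal line through y
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside

def generate_labeled_points(num_dots, vertices, frac_inside, frac_outside, seed=0):
    """Return (x, y, inside, label) tuples; label 1 marks a case of interest."""
    rng = random.Random(seed)
    data = []
    for _ in range(num_dots):
        x, y = rng.random(), rng.random()  # assumed: uniform on the unit square
        inside = point_in_polygon(x, y, vertices)
        p_interest = frac_inside if inside else frac_outside
        label = 1 if rng.random() < p_interest else 0
        data.append((x, y, inside, label))
    return data

# Default slider values from this notebook.
quad = [(0.3, 0.9), (0.9, 0.7), (1.0, 0.4), (0.2, 0.55)]
data = generate_labeled_points(1700, quad, frac_inside=0.95, frac_outside=0.05)

labels_in = [label for _, _, inside, label in data if inside]
labels_out = [label for _, _, inside, label in data if not inside]
print(f"inside:  {len(labels_in)} points, share of interest {sum(labels_in) / len(labels_in):.2f}")
print(f"outside: {len(labels_out)} points, share of interest {sum(labels_out) / len(labels_out):.2f}")
```

With these settings the observed shares should land close to the two slider values, and moving either slider toward 0.5 blurs the boundary the methods have to recover.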

In [1]:
# Imports and setup
import numpy as np
import ipywidgets as widgets
from IPython.display import display
from A_scenario_methods_demo.notebook_helpers import update_plots, save_prim_plots, save_cart_plots

The cell below builds the interface: the sliders, the plot and table output areas, and the save buttons. The plotting and saving logic comes from the functions in the `A_scenario_methods_demo` package.

In [2]:
# Create the sliders for the interface
num_dots_slider = widgets.IntSlider(value=1700, min=100, max=2500, step=100, description="Num Dots")
default_quad = np.array([0.3, 0.9, 0.9, 0.7, 1.0, 0.4, 0.2, 0.55])
corner_labels = ["Corner 1 X", "Corner 1 Y", "Corner 2 X", "Corner 2 Y",
                 "Corner 3 X", "Corner 3 Y", "Corner 4 X", "Corner 4 Y"]
quad_sliders = [widgets.FloatSlider(value=val, min=0.0, max=1.0, step=0.05, description=lbl)
                for lbl, val in zip(corner_labels, default_quad)]
frac_inside_slider = widgets.FloatSlider(value=0.95, min=0.7, max=1.0, step=0.05, description="Frac Inside")
frac_outside_slider = widgets.FloatSlider(value=0.05, min=0.0, max=0.3, step=0.05, description="Frac Outside")
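# Lower bound is 0.05 rather than 0.0: a peel fraction of zero removes no points, so peeling cannot progress
peel_frac_slider = widgets.FloatSlider(value=0.10, min=0.05, max=0.5, step=0.05, description="Peel Frac")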
prim_mass_min_slider = widgets.FloatSlider(value=0.05, min=0.0, max=1.0, step=0.01, description="PRIM Mass Min")
cart_mass_min_slider = widgets.FloatSlider(value=0.05, min=0.0, max=1.0, step=0.01, description="CART Mass Min")

# Save buttons
save_prim_button = widgets.Button(description="Save PRIM Plots")
save_cart_button = widgets.Button(description="Save CART Plots")

# Create output areas
plot_outputs = [widgets.Output(layout=widgets.Layout(width="100%", height="300px"))
                for _ in range(9)]
table_output = widgets.Output(layout=widgets.Layout(width="100%", height="150px"))
grid = widgets.GridspecLayout(3, 3, width="12in", height="9in")
for i in range(3):
    for j in range(3):
        grid[i, j] = plot_outputs[i * 3 + j]

# Automatically refresh plots when sliders change        
def on_update(_):
    update_plots(
        quad_sliders,
        num_dots_slider.value,
        frac_inside_slider.value,
        frac_outside_slider.value,
        peel_frac_slider.value,
        prim_mass_min_slider.value,
        cart_mass_min_slider.value,
        plot_outputs,
        table_output,
    )
    
# Attach event listeners to every slider
for slider in [num_dots_slider, frac_inside_slider, frac_outside_slider,
               peel_frac_slider, prim_mass_min_slider, cart_mass_min_slider,
               *quad_sliders]:
    slider.observe(on_update, names="value")

# Initial plot update
on_update(None)

# Save current PRIM and PCA-PRIM plots
save_prim_button.on_click(lambda b: save_prim_plots(
    quad_sliders,
    num_dots_slider.value,
    frac_inside_slider.value,
    frac_outside_slider.value,
    peel_frac_slider.value,
    prim_mass_min_slider.value,
    cart_mass_min_slider.value,
))
# Save current CART plots
save_cart_button.on_click(lambda b: save_cart_plots(
    quad_sliders,
    num_dots_slider.value,
    frac_inside_slider.value,
    frac_outside_slider.value,
    peel_frac_slider.value,
    prim_mass_min_slider.value,
    cart_mass_min_slider.value,
))

# Display UI containing created sliders, buttons and widgets
input_widgets = widgets.VBox([
    num_dots_slider,
    widgets.HBox([frac_inside_slider, frac_outside_slider, peel_frac_slider]),
    widgets.HBox(quad_sliders[:4]),
    widgets.HBox(quad_sliders[4:]),
    widgets.HBox([prim_mass_min_slider, cart_mass_min_slider]),
    save_prim_button,
    save_cart_button,
])
display(input_widgets, grid, table_output)
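
The update callback and both save callbacks above pass the same seven leading arguments. One possible way to avoid repeating that list is a small helper that gathers them once. The sketch below is only an illustration: `collect_params` is a hypothetical name, and `SimpleNamespace` objects stand in for the real sliders so that the block runs without `ipywidgets`.

```python
from types import SimpleNamespace

def collect_params(quad_sliders, *value_sliders):
    """Return the shared positional arguments: the corner sliders, then each slider's value."""
    return (quad_sliders, *(s.value for s in value_sliders))

# Stand-ins for the notebook's sliders (anything with a .value attribute works).
quad_stub = [SimpleNamespace(value=v) for v in (0.3, 0.9, 0.9, 0.7, 1.0, 0.4, 0.2, 0.55)]
num_dots = SimpleNamespace(value=1700)
frac_inside = SimpleNamespace(value=0.95)
frac_outside = SimpleNamespace(value=0.05)
peel_frac = SimpleNamespace(value=0.10)
prim_mass_min = SimpleNamespace(value=0.05)
cart_mass_min = SimpleNamespace(value=0.05)

args = collect_params(quad_stub, num_dots, frac_inside, frac_outside,
                      peel_frac, prim_mass_min, cart_mass_min)
print(args[1:])
```

In the notebook, the three callbacks could then call `update_plots(*args, plot_outputs, table_output)`, `save_prim_plots(*args)`, and `save_cart_plots(*args)`, rebuilding `args` inside each callback so that the current slider values are read at call time.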
