# Gen2D

This notebook shows a simple model for clustering a 2D image into different components.

In [None]:
import animation
import gibbs_updates
import jax
import jax.numpy as jnp
import model_simple_continuous
from scipy import datasets

from genjax import ChoiceMapBuilder as C
from genjax import pretty

pretty()

## Model

### Testing sampling from the model

In [None]:
# Reference image for the notebook; its height and width size the pixel grid.
image = datasets.face()
H, W, _ = image.shape

hypers = model_simple_continuous.Hyperparams(
    # Image dimensions and mixture-level settings.
    H=H,
    W=W,
    n_blobs=10,
    alpha=1.0,
    # Hyperparameters attached to pixel positions (xy); centred on the image.
    mu_xy=jnp.array([H / 2, W / 2]),
    sigma_xy=jnp.array([50.0, 50.0]),
    a_xy=jnp.array([100.0, 100.0]),
    b_xy=jnp.array([10000.0, 10000.0]),
    # Hyperparameters attached to pixel colours (rgb).
    sigma_rgb=jnp.array([10.0, 10.0, 10.0]),
    a_rgb=jnp.array([25.0, 25.0, 25.0]),
    b_rgb=jnp.array([450.0, 450.0, 450.0]),
)

# Seed a root PRNG key and split it once; `subkey` is the key consumed by the
# (currently disabled) forward-simulation call below.
key, subkey = jax.random.split(jax.random.PRNGKey(0))
# tr = jax.jit(model_simple_continuous.model.simulate)(subkey, (hypers,))

In [None]:
# TODO: find a way to modify hypers in place. The approach below, which follows what McCoy recommended, works: it creates a copy rather than mutating, but it's convenient.

# hypers = model_simple_continuous.Hyperparams(
#     a_xy=jnp.array([300.0, 300.0]),
#     b_xy=hypers.b_xy,
#     mu_xy=hypers.mu_xy,
#     a_rgb=hypers.a_rgb,
#     b_rgb=hypers.b_rgb,
#     alpha=hypers.alpha,
#     sigma_xy=hypers.sigma_xy,
#     sigma_rgb=hypers.sigma_rgb,
#     n_blobs=hypers.n_blobs,
#     H=hypers.H,
#     W=hypers.W,
# )

# replace(hypers, a_xy=jnp.array([301.0, 301.0]))

## Inference

We will do inference via exact block-Gibbs, using the fact that the model is defined using conjugate pairs.

### Gibbs updates

NEXT STEPS:
- Add the Gibbs updates one by one and test each of them individually:
  - update xy mean
  - update rgb mean
  - update cluster assignment
  - update cluster weight
  - update sigma_xy
  - update sigma_rgb

### Main inference loop

In [None]:
# Number of Gibbs sweeps performed by the main loop in `infer`.
N_ITER = 10
# When True, `infer` stores the value of each latent after every update so the
# run can be plotted afterwards.
RECORD = True
# When True, `infer` prints the initial blob means and mixture weights.
DEBUG = False
# When True, every update except the xy-mean update is replaced by the no-op
# `id`, so only the blob xy means are actually resampled.
TRIVIAL = True


def id(key, trace):
    """No-op stand-in for a Gibbs update: hand `trace` back unchanged.

    Takes the same `(key, trace)` arguments as the real updates so it can be
    swapped in for any of them. Note that this name shadows Python's built-in
    `id` for the rest of the notebook.
    """
    del key  # unused; accepted only to match the update call signature
    return trace


def infer(image, hypers):
    """Run `N_ITER` sweeps of block-Gibbs updates on a trace constrained to `image`.

    Parameters
    ----------
    image
        Array that can be reshaped to `(hypers.H * hypers.W, 3)`; each row is
        used as the rgb observation of one pixel.
    hypers
        Model hyperparameters. `H`, `W` and `n_blobs` are read here, and the
        whole object is passed to the model as its single argument.

    Returns
    -------
    tuple
        `(xy_means, xy_variances, rgb_means, rgb_variances, weights,
        cluster_assignments, trace)`. The first six entries are lists holding
        the value at the corresponding address for the initial trace and after
        each update of every sweep; they are empty when `RECORD` is False.
        `trace` is the trace after the last sweep.

    Notes
    -----
    The PRNG seed is fixed inside the function. When `TRIVIAL` is True every
    update except the xy-mean one is replaced by the no-op `id`; the key is
    still split for those steps so the random stream does not depend on the
    flag.
    """
    key = jax.random.key(32421)

    # Image pre-processing: one row per pixel with columns (row, col, r, g, b).
    H = hypers.H
    W = hypers.W
    # `jnp.indices((H, W))` has shape (2, H, W) with the coordinate axis first.
    # Flatten the grid and transpose so that row k is the (row, col) position
    # of pixel k, in the same row-major order as `image.reshape(H * W, 3)`.
    # Reshaping straight to (H * W, 2) would pair up unrelated indices.
    pixel_coords = jnp.indices((H, W)).reshape(2, H * W).T
    flattened_image = jnp.concatenate(
        (pixel_coords, image.reshape(H * W, 3)), axis=1
    )
    xy, rgb = flattened_image[:, :2], flattened_image[:, 2:]

    # Constrain the observations and start from uniform mixture weights.
    n_blobs = hypers.n_blobs
    obs = C["likelihood_model", "xy"].set(xy) | C["likelihood_model", "rgb"].set(rgb)
    initial_weights = C["blob_model", "mixture_weight"].set(jnp.ones(n_blobs) / n_blobs)
    constraints = obs | initial_weights

    # Sample an initial trace consistent with the constraints.
    key, subkey = jax.random.split(key)
    args = (hypers,)
    tr, _ = jax.jit(model_simple_continuous.model.importance)(subkey, constraints, args)

    # Histories for plotting and debugging; only filled when RECORD is True.
    all_posterior_xy_means = []
    all_posterior_xy_variances = []
    all_posterior_rgb_means = []
    all_posterior_rgb_variances = []
    all_cluster_assignment = []
    all_posterior_weights = []

    def record(history, trace, address):
        # Append the current value stored at `address`, if recording is on.
        if RECORD:
            history.append(trace.get_choices()[address])

    record(all_posterior_xy_means, tr, ("blob_model", "xy_mean"))
    record(all_posterior_xy_variances, tr, ("blob_model", "sigma_xy"))
    record(all_posterior_rgb_means, tr, ("blob_model", "rgb_mean"))
    record(all_posterior_rgb_variances, tr, ("blob_model", "sigma_rgb"))
    record(all_cluster_assignment, tr, ("likelihood_model", "blob_idx"))
    record(all_posterior_weights, tr, ("blob_model", "mixture_weight"))

    if DEBUG:
        initial_choices = tr.get_choices()
        jax.debug.print(
            "Initial means: {v}", v=initial_choices["blob_model", "xy_mean"]
        )
        jax.debug.print(
            "Initial weights: {v}", v=initial_choices["blob_model", "mixture_weight"]
        )

    # Main inference loop. It runs whether or not histories are recorded, so
    # the function always performs inference and always returns the 7-tuple.
    # The `gibbs_updates` attributes other than `update_xy_mean` are only
    # looked up when TRIVIAL is False.
    for _ in range(N_ITER):
        # Gibbs update on `("blob_model", "xy_mean", i)` for each i, in parallel
        key, subkey = jax.random.split(key)
        tr = jax.jit(gibbs_updates.update_xy_mean)(subkey, tr)
        record(all_posterior_xy_means, tr, ("blob_model", "xy_mean"))

        # Gibbs update on `("blob_model", "sigma_xy", i)` for each i, in parallel
        key, subkey = jax.random.split(key)
        if TRIVIAL:
            tr = id(subkey, tr)
        else:
            tr = jax.jit(gibbs_updates.update_xy_sigma)(subkey, tr)
        record(all_posterior_xy_variances, tr, ("blob_model", "sigma_xy"))

        # Gibbs update on `("blob_model", "rgb_mean", i)` for each i, in parallel
        key, subkey = jax.random.split(key)
        if TRIVIAL:
            tr = id(subkey, tr)
        else:
            tr = jax.jit(gibbs_updates.update_rgb_mean)(subkey, tr)
        record(all_posterior_rgb_means, tr, ("blob_model", "rgb_mean"))

        # Gibbs update on `("blob_model", "sigma_rgb", i)` for each i, in parallel
        key, subkey = jax.random.split(key)
        if TRIVIAL:
            tr = id(subkey, tr)
        else:
            tr = jax.jit(gibbs_updates.update_rgb_sigma)(subkey, tr)
        record(all_posterior_rgb_variances, tr, ("blob_model", "sigma_rgb"))

        # Gibbs update on `("likelihood_model", "blob_idx", i)` for each `i`, in parallel
        key, subkey = jax.random.split(key)
        if TRIVIAL:
            tr = id(subkey, tr)
        else:
            tr = jax.jit(gibbs_updates.update_cluster_assignment)(subkey, tr)
        record(all_cluster_assignment, tr, ("likelihood_model", "blob_idx"))

        # Gibbs update on `("blob_model", "mixture_weight", i)` for each `i`, in parallel
        key, subkey = jax.random.split(key)
        if TRIVIAL:
            tr = id(subkey, tr)
        else:
            tr = jax.jit(gibbs_updates.update_mixture_weight)(subkey, tr)
        record(all_posterior_weights, tr, ("blob_model", "mixture_weight"))

    return (
        all_posterior_xy_means,
        all_posterior_xy_variances,
        all_posterior_rgb_means,
        all_posterior_rgb_variances,
        all_posterior_weights,
        all_cluster_assignment,
        tr,
    )


# Run the whole inference routine under `jax.jit` and unpack the per-update
# histories together with the final trace, in the order `infer` returns them.
(
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_rgb_means,
    all_posterior_rgb_variances,
    all_posterior_weights,
    all_cluster_assignment,
    tr,
) = jax.jit(infer)(image, hypers)

### Visualization

In [None]:
# Build the cluster animation from the recorded histories. The arguments are
# positional and their order (weights before rgb means, no rgb variances)
# differs from the order in which `infer` returns the histories.
visualization = animation.create_cluster_visualization(
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_weights,
    all_posterior_rgb_means,
    all_cluster_assignment,
    num_frames=15,
    pixel_sampling=10,  # Sample every 10th pixel
    confidence_factor=3.0,  # Scale factor for ellipses
    min_weight=0.01,  # Minimum weight threshold for showing clusters
)

# Leave the object as the cell's last expression so the notebook renders it.
visualization

In [None]:
# def create_cluster_visualization(
#     all_posterior_xy_means,
#     all_posterior_xy_variances,
#     all_posterior_weights,
#     all_posterior_rgb_means,
#     all_cluster_assignments,
#     image=None,
#     num_frames=10,
#     pixel_sampling=10,
#     confidence_factor=3.0,
#     min_weight=0.01,
#     max_points=2000,  # New parameter to limit total points for performance
# ):
#     """
#     Create an interactive visualization of clustering results.

#     Parameters:
#     -----------
#     all_posterior_xy_means : list of arrays
#         XY mean positions for each cluster in each iteration
#     all_posterior_xy_variances : list of arrays
#         XY variances for each cluster in each iteration
#     all_posterior_weights : list of arrays
#         Weights for each cluster in each iteration
#     all_posterior_rgb_means : list of arrays
#         RGB mean colors for each cluster in each iteration
#     all_cluster_assignments : list of arrays
#         Cluster assignment for each pixel, where all_cluster_assignments[i][y*W + x]
#         gives the cluster ID for pixel at position (x,y) in frame i
#     image : ndarray, optional
#         Image to visualize (defaults to scipy.datasets.face())
#     num_frames : int, optional
#         Number of frames to show in the animation
#     pixel_sampling : int, optional
#         Sample every Nth pixel in both directions
#     confidence_factor : float, optional
#         Scale factor for ellipse size (larger = bigger ellipses)
#     min_weight : float, optional
#         Minimum weight for a cluster to be displayed
#     max_points : int, optional
#         Maximum number of points to display for performance

#     Returns:
#     --------
#     Plot object
#         The interactive visualization
#     """
#     # Load the default image if none provided
#     if image is None:
#         image = datasets.face()

#     H, W, _ = image.shape

#     # Save the image to a file
#     plt.imsave("face_temp.png", image)

#     # Calculate optimal sampling rate based on max_points
#     total_pixels = H * W
#     adaptive_sampling = max(pixel_sampling, int(np.sqrt(total_pixels / max_points)))

#     # Sample pixels using numpy (vectorized)
#     y_indices = np.arange(0, H, adaptive_sampling)
#     x_indices = np.arange(0, W, adaptive_sampling)
#     Y, X = np.meshgrid(y_indices, x_indices, indexing='ij')

#     # Flatten for easy indexing
#     sampled_y = Y.flatten()
#     sampled_x = X.flatten()
#     sampled_indices = sampled_y * W + sampled_x

#     # Get RGB values directly (much faster than loop)
#     sampled_rgb = image[sampled_y, sampled_x]

#     # Create the sampled_xy array
#     sampled_xy = np.column_stack((sampled_x, sampled_y))
#     total_points = len(sampled_xy)

#     # Pre-compute frame indices to visualize
#     num_iteration = len(all_posterior_xy_means)
#     step = max(1, num_iteration // num_frames)
#     frame_indices = list(range(0, num_iteration, step))

#     # Pre-compute all point assignments and counts (vectorized)
#     all_assignments = []
#     all_point_colors = []

#     for frame_idx in frame_indices:
#         # Get assignments for this frame
#         frame_assignments = all_cluster_assignments[frame_idx][sampled_indices]

#         # Count assignments
#         weights = all_posterior_weights[frame_idx]
#         unique, counts = np.unique(frame_assignments, return_counts=True)
#         assignment_counts = np.zeros(len(weights))

#         # Only count valid assignments (within range of clusters)
#         valid_mask = unique < len(weights)
#         assignment_counts[unique[valid_mask]] = counts[valid_mask]
#         all_assignments.append(assignment_counts.tolist())

#         # Get RGB colors for each cluster
#         rgb_means = all_posterior_rgb_means[frame_idx]

#         # Create a mapping of cluster IDs to RGB values
#         # Handle out-of-range assignments by using a default color
#         default_color = np.array([128, 128, 128])  # Gray for invalid clusters

#         # Pre-allocate point colors array
#         point_colors = np.zeros((len(frame_assignments), 3), dtype=np.uint8)

#         # Only process valid assignments (faster than looping)
#         valid_mask = frame_assignments < len(rgb_means)
#         valid_assignments = frame_assignments[valid_mask]

#         # Set colors for valid assignments
#         point_colors[valid_mask] = np.array([rgb_means[i] for i in valid_assignments])

#         # Set default color for invalid assignments
#         point_colors[~valid_mask] = default_color

#         all_point_colors.append(point_colors)

#     # Prepare data for JavaScript
#     all_weights_js = [all_posterior_weights[i].tolist() for i in frame_indices]
#     all_means_js = [all_posterior_xy_means[i].tolist() for i in frame_indices]
#     all_variances_js = [all_posterior_xy_variances[i].tolist() for i in frame_indices]
#     all_colors_js = [all_posterior_rgb_means[i].tolist() for i in frame_indices]

#     # Convert data to JSON for JavaScript
#     frame_data_js = f"""
#     const allWeights = {json.dumps(all_weights_js)};
#     const allMeans = {json.dumps(all_means_js)};
#     const allVariances = {json.dumps(all_variances_js)};
#     const allColors = {json.dumps(all_colors_js)};
#     const allAssignments = {json.dumps(all_assignments)};
#     const imageWidth = {W};
#     const imageHeight = {H};
#     const numFrames = {len(frame_indices)};
#     const totalPoints = {total_points};
#     const minWeight = {min_weight};
#     """

#     # Function to create a plot for a specific frame (optimize by batching points)
#     def create_frame_plot(idx):
#         frame_idx = frame_indices[idx]

#         # Get data for this frame
#         xy_means = all_posterior_xy_means[frame_idx]
#         xy_variances = all_posterior_xy_variances[frame_idx]
#         weights = all_posterior_weights[frame_idx]
#         rgb_means = all_posterior_rgb_means[frame_idx]

#         # Get point colors for this frame
#         point_colors = all_point_colors[idx]

#         # Get assignments for this frame
#         frame_assignments = all_cluster_assignments[frame_idx][sampled_indices]

#         # Start with a base plot
#         plot = Plot.new(
#             Plot.aspectRatio(1),
#             Plot.hideAxis(),
#             Plot.domain([0, W], [0, H]),
#             {"y": {"reverse": True}},
#             Plot.title(f"Iteration {frame_idx}/{num_iteration - 1}"),
#         )

#         # Add the background image
#         plot += Plot.img(
#             ["face_temp.png"],
#             x=0,
#             y=H,
#             width=W,
#             height=-H,
#             src=Plot.identity,
#             opacity=0.3,
#         )

#         # Add points by cluster batch (much faster than individual points)
#         valid_clusters = np.unique(frame_assignments[frame_assignments < len(weights)])

#         for cluster_id in valid_clusters:
#             # Get points belonging to this cluster
#             cluster_mask = frame_assignments == cluster_id
#             if not np.any(cluster_mask):
#                 continue

#             cluster_x = sampled_x[cluster_mask]
#             cluster_y = sampled_y[cluster_mask]

#             # Use cluster color
#             color = rgb_means[cluster_id]
#             color_str = f"rgb({color[0]}, {color[1]}, {color[2]})"

#             # Add all points for this cluster in one batch
#             plot += Plot.dot(
#                 {"x": cluster_x.tolist(), "y": cluster_y.tolist()},
#                 {
#                     "data-point": "",
#                     "data-cluster": str(cluster_id),
#                     "fill": color_str,
#                     "fillOpacity": 0.6,
#                     "stroke": "black",
#                     "strokeOpacity": 0.4,
#                     "r": 2.5,
#                 }
#             )

#         # Add cluster ellipses and centers
#         for i in range(len(weights)):
#             if weights[i] < min_weight:
#                 continue

#             # Get cluster data
#             mean_x, mean_y = xy_means[i]
#             var_x, var_y = xy_variances[i]
#             color = rgb_means[i]
#             color_str = f"rgb({color[0]}, {color[1]}, {color[2]})"

#             # Calculate ellipse
#             std_x = np.sqrt(var_x)
#             std_y = np.sqrt(var_y)

#             # Create ellipse using parametric equation
#             theta = np.linspace(0, 2 * np.pi, 50)
#             ellipse_x = mean_x + confidence_factor * std_x * np.cos(theta)
#             ellipse_y = mean_y + confidence_factor * std_y * np.sin(theta)

#             # Add ellipse
#             plot += Plot.line(
#                 {"x": ellipse_x.tolist(), "y": ellipse_y.tolist()},
#                 {
#                     "data-cluster": str(i),
#                     "stroke": color_str,
#                     "strokeWidth": 2,
#                     "fill": color_str,
#                     "fillOpacity": 0.2,
#                 }
#             )

#             # Add cluster center
#             size = 5 + weights[i] * 50
#             plot += Plot.dot(
#                 {"x": [float(mean_x)], "y": [float(mean_y)]},
#                 {
#                     "data-cluster": str(i),
#                     "fill": color_str,
#                     "r": size,
#                     "stroke": "black",
#                     "strokeWidth": 1,
#                     "symbol": "star",
#                 }
#             )

#         return plot

#     # Create all the frames
#     frames = [create_frame_plot(i) for i in range(len(frame_indices))]

#     # Return the animation with legend and interactive controls
#     return Plot.html([
#         "div",
#         {"className": "grid grid-cols-3 gap-4 p-4"},
#         [
#             "div",
#             {"className": "col-span-2"},
#             Plot.Frames(frames),
#         ],
#         [
#             "div",
#             {"className": "col-span-1"},
#             Plot.js(
#                 """function() {
#                 """
#                 + frame_data_js
#                 + """
#                 // Get current frame index
#                 const frame = $state.frame || 0;

#                 // Get data for current frame
#                 const weights = allWeights[frame] || [];
#                 const means = allMeans[frame] || [];
#                 const colors = allColors[frame] || [];
#                 const assignments = allAssignments[frame] || [];

#                 // Sort clusters by weight, filter by minimum weight
#                 const topClusters = weights
#                     .map((weight, idx) => ({
#                         id: idx,
#                         weight: weight,
#                         color: colors[idx] || [0,0,0],
#                         points: assignments[idx] || 0,
#                         percentage: ((assignments[idx] || 0) / totalPoints * 100).toFixed(1)
#                     }))
#                     .filter(c => c.weight >= minWeight)
#                     .sort((a, b) => b.weight - a.weight)
#                     .slice(0, 10);

#                 // Create placeholder rows for consistent height
#                 const placeholders = Array(Math.max(0, 10 - topClusters.length))
#                     .fill(0)
#                     .map(() => ["tr", {"className": "h-8"}, ["td", {"colSpan": 3}, ""]]);

#                 // Function to highlight/unhighlight clusters
#                 if (!$state.highlightCluster) {
#                     $state.highlightCluster = function(id) {
#                         // Unhighlight all elements
#                         document.querySelectorAll('[data-cluster], [data-point]').forEach(el => {
#                             el.style.filter = 'opacity(0.3)';
#                         });

#                         // Highlight the selected cluster
#                         if (id !== null) {
#                             document.querySelectorAll(`[data-cluster="${id}"], [data-point][data-cluster="${id}"]`).forEach(el => {
#                                 el.style.filter = 'opacity(1) drop-shadow(0 0 4px white)';
#                             });
#                         } else {
#                             // Reset all if nothing selected
#                             document.querySelectorAll('[data-cluster], [data-point]').forEach(el => {
#                                 el.style.filter = '';
#                             });
#                         }
#                     };
#                 }

#                 return [
#                     "div", {},
#                     ["h3", {}, `Top Clusters by Weight`],
#                     ["div", {"style": {"height": "400px", "overflow": "auto"}},
#                         ["table", {"className": "w-full mt-2"},
#                             ["thead", ["tr",
#                                 ["th", {"className": "text-left"}, "Cluster"],
#                                 ["th", {"className": "text-left"}, "Weight"],
#                                 ["th", {"className": "text-left"}, "Points (%)"]
#                             ]],
#                             ["tbody",
#                                 ...topClusters.map(cluster =>
#                                     ["tr", {
#                                         "className": "h-8",
#                                         "style": {
#                                             "cursor": "pointer",
#                                             "backgroundColor": $state.hoveredCluster === cluster.id ? "#f0f0f0" : "transparent"
#                                         },
#                                         "onMouseEnter": () => {
#                                             $state.hoveredCluster = cluster.id;
#                                             $state.highlightCluster(cluster.id);
#                                         },
#                                         "onMouseLeave": () => {
#                                             $state.hoveredCluster = null;
#                                             $state.highlightCluster(null);
#                                         }
#                                     },
#                                     ["td", {"className": "py-1"},
#                                         ["div", {"className": "flex items-center"},
#                                             ["div", {
#                                                 "style": {
#                                                     "backgroundColor": `rgb(${cluster.color[0]},${cluster.color[1]},${cluster.color[2]})`,
#                                                     "width": "24px",
#                                                     "height": "24px",
#                                                     "borderRadius": "4px",
#                                                     "border": "1px solid rgba(0,0,0,0.2)",
#                                                     "display": "inline-block",
#                                                     "marginRight": "8px"
#                                                 }
#                                             }],
#                                             `Cluster ${cluster.id}`
#                                         ]
#                                     ],
#                                     ["td", {"className": "py-1"}, cluster.weight.toFixed(4)],
#                                     ["td", {"className": "py-1"}, `${cluster.points} (${cluster.percentage}%)`]
#                                     ]
#                                 ),
#                                 ...placeholders
#                             ]
#                         ]
#                     ]
#                 ];
#             }()"""
#             ),
#         ],
#     ])

In [None]:
# visualization = create_cluster_visualization(
#     all_posterior_xy_means,
#     all_posterior_xy_variances,
#     all_posterior_weights,
#     all_posterior_rgb_means,
#     all_cluster_assignment,
#     num_frames=15,
#     pixel_sampling=10,  # Sample every 10th pixel
#     confidence_factor=3.0,  # Scale factor for ellipses
#     min_weight=0.01,  # Minimum weight threshold for showing clusters
# )

# visualization