# Gen2D

This notebook shows a simple model that clusters the pixels of a 2D image into components ("blobs") based on each pixel's position and color.

In [None]:
import genstudio.plot as Plot
import gibbs_updates
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import model_simple_continuous
import numpy as np
from scipy import datasets

from genjax import ChoiceMapBuilder as C
from genjax import pretty

# Turn on genjax's pretty-printing (presumably for richer display of traces / choice maps in cell outputs — confirm).
pretty()

## Model

### Testing sampling from the model

In [None]:
# Load the test image; H and W are its height (rows) and width (columns).
image = datasets.face()
H, W, _ = image.shape

# Initial hyperparameters. The prior mean of the blob positions, mu_xy, is the image
# centre in (row, col) order.
hypers = model_simple_continuous.Hyperparams(
    a_xy=jnp.array([100.0, 100.0]),
    b_xy=jnp.array([10000.0, 10000.0]),
    mu_xy=jnp.array([H / 2, W / 2]),
    a_rgb=jnp.array([25.0, 25.0, 25.0]),
    b_rgb=jnp.array([450.0, 450.0, 450.0]),
    alpha=1.0,
    sigma_xy=jnp.array([50.0, 50.0]),
    sigma_rgb=jnp.array([10.0, 10.0, 10.0]),
    n_blobs=10,
    H=H,
    W=W,
)

# Set to True to draw one trace from the model with these hyperparameters as a quick
# smoke test (off by default; this call was previously left commented out).
SAMPLE_FROM_PRIOR = False

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
if SAMPLE_FROM_PRIOR:
    tr = jax.jit(model_simple_continuous.model.simulate)(subkey, (hypers,))

In [None]:
# Rebuild `hypers` with a_xy set to 300 and every other field copied from the current
# instance. Hyperparams is not edited in place: a new object is constructed instead.
# NOTE(review): if Hyperparams is a dataclass, `dataclasses.replace(hypers, a_xy=...)`
# would be an equivalent one-liner — confirm before switching.
hypers = model_simple_continuous.Hyperparams(
    a_xy=jnp.array([300.0, 300.0]),
    **{
        field_name: getattr(hypers, field_name)
        for field_name in (
            "b_xy",
            "mu_xy",
            "a_rgb",
            "b_rgb",
            "alpha",
            "sigma_xy",
            "sigma_rgb",
            "n_blobs",
            "H",
            "W",
        )
    },
)

## Inference

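We perform inference with exact block Gibbs sampling. This is possible because the model is built from conjugate prior–likelihood pairs, so each block's conditional distribution can be sampled in closed form.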

### Gibbs updates

**TODO:** add the Gibbs updates one at a time, testing each individually:
- update xy mean
- update rgb mean
- update cluster assignment
- update cluster weight
- update sigma_xy
- update sigma_rgb

### Main inference loop

In [None]:
N_ITER = 100  # number of Gibbs sweeps run by infer(); each sweep applies all six updates
RECORD = True  # collect a snapshot of each updated latent after every update in infer()
DEBUG = False  # enables jax.debug.print of the initial xy means and mixture weights in infer()
TRIVIAL = True  # substitute the no-op `id` for every Gibbs update in infer()


def id(key, trace):
    """No-op stand-in for a Gibbs update: return `trace` untouched, ignoring `key`.

    Accepts the same ``(key, trace)`` arguments as the ``gibbs_updates`` functions so it
    can be substituted for them when ``TRIVIAL`` is set. Note: shadows the builtin ``id``.
    """
    del key  # accepted only to match the Gibbs-update call signature
    unchanged_trace = trace
    return unchanged_trace


# (name of the update function in `gibbs_updates`, choice-map address it resamples),
# in the order the updates are applied within one sweep.
_GIBBS_SWEEP = (
    ("update_xy_mean", ("blob_model", "xy_mean")),
    ("update_xy_sigma", ("blob_model", "sigma_xy")),
    ("update_rgb_mean", ("blob_model", "rgb_mean")),
    ("update_rgb_sigma", ("blob_model", "sigma_rgb")),
    ("update_cluster_assignment", ("likelihood_model", "blob_idx")),
    ("update_mixture_weight", ("blob_model", "mixture_weight")),
)

# Order in which `infer` returns the per-address histories (before the final trace).
_RETURN_ORDER = (
    ("blob_model", "xy_mean"),
    ("blob_model", "sigma_xy"),
    ("blob_model", "rgb_mean"),
    ("blob_model", "sigma_rgb"),
    ("blob_model", "mixture_weight"),
    ("likelihood_model", "blob_idx"),
)


def _flatten_image(image, H, W):
    """Split an image into per-pixel coordinates and colors, in row-major pixel order.

    Returns ``(xy, rgb)`` with shapes (H*W, 2) and (H*W, 3), where ``xy[k] = (row, col)``
    of the k-th pixel. Both come from one concatenated array, so they share its promoted
    dtype (as in the original implementation).
    """
    # jnp.indices((H, W)) has shape (2, H, W): [row-index grid, col-index grid]. Flatten
    # each grid and transpose to get one (row, col) pair per pixel. Reshaping the
    # (2, H, W) array directly to (H*W, 2) would pair up unrelated indices instead.
    pixel_coords = jnp.indices((H, W)).reshape(2, H * W).T
    flattened_image = jnp.concatenate(
        (pixel_coords, image.reshape(H * W, 3)), axis=1
    )
    return flattened_image[:, :2], flattened_image[:, 2:]


def infer(image, hypers, seed=32421):
    """Run N_ITER sweeps of block-Gibbs updates on the blob model, conditioned on `image`.

    Args:
        image: pixel array reshapeable to (H*W, 3) — presumably (H, W, 3) RGB.
        hypers: ``model_simple_continuous.Hyperparams``; ``H``, ``W`` and ``n_blobs``
            are read from it and it is passed as the model argument.
        seed: integer seed for the PRNG key (default is the previously hard-coded seed).

    Returns:
        ``(xy_means, xy_sigmas, rgb_means, rgb_sigmas, mixture_weights,
        cluster_assignments, trace)``: six lists of snapshots of the choices at
        ``("blob_model", "xy_mean")``, ``("blob_model", "sigma_xy")``,
        ``("blob_model", "rgb_mean")``, ``("blob_model", "sigma_rgb")``,
        ``("blob_model", "mixture_weight")`` and ``("likelihood_model", "blob_idx")``,
        followed by the final trace. Each list starts with the initial trace's value.
        With RECORD set, a value is appended after every update of that address
        (N_ITER + 1 entries). Otherwise only the final value is appended (2 entries).

    Reads the module-level flags N_ITER, RECORD, DEBUG and TRIVIAL. With TRIVIAL set,
    every update is replaced by `id`.
    """
    key = jax.random.key(seed)
    H, W = hypers.H, hypers.W
    n_blobs = hypers.n_blobs

    xy, rgb = _flatten_image(image, H, W)

    # Condition on the observed pixels and start from uniform mixture weights.
    obs = C["likelihood_model", "xy"].set(xy) ^ C["likelihood_model", "rgb"].set(rgb)
    initial_weights = C["blob_model", "mixture_weight"].set(jnp.ones(n_blobs) / n_blobs)
    constraints = obs | initial_weights

    # Sample an initial trace consistent with the constraints.
    key, subkey = jax.random.split(key)
    tr, _ = jax.jit(model_simple_continuous.model.importance)(
        subkey, constraints, (hypers,)
    )

    initial_choices = tr.get_choices()
    history = {
        address: [initial_choices[address]] for _, address in _GIBBS_SWEEP
    }

    if DEBUG:
        jax.debug.print(
            "Initial means: {v}", v=history[("blob_model", "xy_mean")][0]
        )
        jax.debug.print(
            "Initial weights: {v}", v=history[("blob_model", "mixture_weight")][0]
        )

    # Resolve and jit each update once, outside the loop. With TRIVIAL set the
    # `gibbs_updates` functions are never looked up, so they need not exist yet.
    sweep = [
        (id if TRIVIAL else jax.jit(getattr(gibbs_updates, name)), address)
        for name, address in _GIBBS_SWEEP
    ]

    # Main inference loop: the updates always run; RECORD only controls the history.
    for _ in range(N_ITER):
        for update, address in sweep:
            # Split before every update (trivial or not) so the key stream is the same
            # regardless of which updates are enabled.
            key, subkey = jax.random.split(key)
            tr = update(subkey, tr)
            if RECORD:
                history[address].append(tr.get_choices()[address])

    if not RECORD:
        final_choices = tr.get_choices()
        for address, values in history.items():
            values.append(final_choices[address])

    return (*(history[address] for address in _RETURN_ORDER), tr)


# Re-load the image and derive H, W from it here instead of relying on values left over
# from an earlier cell, so this cell is correct whatever ran before it.
image = datasets.face()
H, W, _ = image.shape

# NOTE: these hyperparameters replace the ones built in the cells above
# (a_xy, sigma_xy and sigma_rgb differ).
hypers = model_simple_continuous.Hyperparams(
    a_xy=jnp.array([100.0, 100.0]),
    b_xy=jnp.array([10000.0, 10000.0]),
    mu_xy=jnp.array([H / 2, W / 2]),
    a_rgb=jnp.array([25.0, 25.0, 25.0]),
    b_rgb=jnp.array([450.0, 450.0, 450.0]),
    alpha=1.0,
    sigma_xy=jnp.array([1.0, 1.0]),
    sigma_rgb=jnp.array([1.0, 1.0, 1.0]),
    n_blobs=10,
    H=H,
    W=W,
)

(
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_rgb_means,
    all_posterior_rgb_variances,
    all_posterior_weights,
    all_cluster_assignment,
    tr,
) = jax.jit(infer)(image, hypers)

### Visualization

In [None]:
NUM_FRAMES = 10  # target number of animation frames (see frame selection below)
import json
import math

from scipy.spatial.distance import cdist

# Load the face image
image = datasets.face()
H, W, _ = image.shape

# Save the image to a file; the frame plots use it as their background.
plt.imsave("face_temp.png", image)

# Sample every PIXEL_STRIDE-th pixel along each axis, in row-major order (y outer, x inner).
# Each row of `sampled_pixels` is (x, y, r, g, b) with x = column and y = row.
PIXEL_STRIDE = 10
sample_y, sample_x = np.mgrid[0:H:PIXEL_STRIDE, 0:W:PIXEL_STRIDE]
sample_y, sample_x = sample_y.ravel(), sample_x.ravel()
sampled_pixels = np.column_stack((sample_x, sample_y, image[sample_y, sample_x]))
sampled_xy = sampled_pixels[:, 0:2]
total_points = len(sampled_xy)

# The model's xy coordinates are (row, col): its prior mean is mu_xy = [H / 2, W / 2].
# Compare the sampled pixels to the cluster centres in that same (row, col) order.
sampled_row_col = sampled_pixels[:, [1, 0]]

# For each recorded iteration, count how many sampled pixels fall in each cluster,
# using a heuristic score (distance to centre minus log mixture weight).
all_assignments = []
for xy_means, weights in zip(all_posterior_xy_means, all_posterior_weights):
    weights = np.asarray(weights)
    distances = cdist(sampled_row_col, np.asarray(xy_means))  # (n_points, n_clusters)
    scores = distances - np.log(weights + 1e-10)[np.newaxis, :]
    assignments = np.argmin(scores, axis=1)
    # Clusters with no assigned pixels get a count of 0.
    assignment_counts = np.bincount(assignments, minlength=len(weights))
    all_assignments.append(assignment_counts.tolist())

# Per-iteration data for the JavaScript sidebar, as plain Python lists.
all_weights_js = [w.tolist() for w in all_posterior_weights]
all_means_js = [m.tolist() for m in all_posterior_xy_means]
all_variances_js = [v.tolist() for v in all_posterior_xy_variances]
all_colors_js = [c.tolist() for c in all_posterior_rgb_means]

# Convert data to JSON for JavaScript
frame_data_js = f"""
const allWeights = {json.dumps(all_weights_js)};
const allMeans = {json.dumps(all_means_js)};
const allVariances = {json.dumps(all_variances_js)};
const allColors = {json.dumps(all_colors_js)};
const allAssignments = {json.dumps(all_assignments)};
const imageWidth = {W};
const imageHeight = {H};
const numFrames = {len(all_posterior_xy_means)};
const totalPoints = {total_points};
"""


# Function to create a plot for a specific frame
def create_frame_plot(frame_idx, stride=10):
    """Build the plot for the recorded iteration ``frame_idx``.

    Layers are drawn in this order:
    - the face image at 30% opacity;
    - every ``stride``-th pixel as a dot in its own color;
    - for each cluster whose mixture weight is at least 0.01, an outline at 3 "standard
      deviations" plus a star at its centre whose radius grows with the weight.

    The standard deviations are square roots of the recorded ``sigma_xy`` values,
    clamped to at least 1. Every cluster mark carries a ``data-cluster`` attribute,
    which the sidebar's hover highlighting selects on.

    Args:
        frame_idx: index into the ``all_posterior_*`` histories.
        stride: pixel sampling step for the background dots (default 10, as before).

    Reads the module-level ``all_posterior_*`` histories, ``image``, ``H`` and ``W``.
    """
    xy_means = np.asarray(all_posterior_xy_means[frame_idx], dtype=float)
    xy_variances = np.asarray(all_posterior_xy_variances[frame_idx], dtype=float)
    weights = np.asarray(all_posterior_weights[frame_idx], dtype=float)
    rgb_means = np.asarray(all_posterior_rgb_means[frame_idx], dtype=float)

    # Start with a base plot (y reversed so row 0 is at the top, like the image)
    plot = Plot.new(
        Plot.aspectRatio(1),
        Plot.hideAxis(),
        Plot.domain([0, W], [0, H]),
        {"y": {"reverse": True}},
        Plot.title(f"Iteration {frame_idx}/{len(all_posterior_xy_means) - 1}"),
    )

    # Add the background image (slightly more transparent)
    plot += Plot.img(
        ["face_temp.png"], x=0, y=H, width=W, height=-H, src=Plot.identity, opacity=0.3
    )

    # Sampled pixels as dots: plot x is the column, plot y is the row.
    pixel_rows, pixel_cols = np.mgrid[0:H:stride, 0:W:stride]
    pixel_rows, pixel_cols = pixel_rows.ravel(), pixel_cols.ravel()
    colors = [f"rgb({r},{g},{b})" for r, g, b in image[pixel_rows, pixel_cols]]
    plot += Plot.dot(
        {"x": pixel_cols.tolist(), "y": pixel_rows.tolist()},
        {"r": 2, "fill": colors, "stroke": "none"},
    )

    theta = np.linspace(0, 2 * np.pi, 30)
    confidence_factor = 3.0  # scale outlines up so they are visible

    for i, (mean, variance, weight, rgb) in enumerate(
        zip(xy_means, xy_variances, weights, rgb_means)
    ):
        # Skip clusters with very small weights
        if weight < 0.01:
            continue

        color = f"rgb({rgb[0]},{rgb[1]},{rgb[2]})"

        # Model coordinates are (row, col) (its prior mean is mu_xy = [H / 2, W / 2]),
        # so the plot's x comes from index 1 and y from index 0.
        center_row, center_col = float(mean[0]), float(mean[1])
        # NOTE(review): these are the ("blob_model", "sigma_xy") choices treated as
        # variances (square-rooted below); if they are already standard deviations, drop
        # the sqrt.
        # Clamp to at least 1 so tiny (or non-positive) values still give a visible outline.
        var_row = max(float(variance[0]), 1.0)
        var_col = max(float(variance[1]), 1.0)
        x_stddev = confidence_factor * math.sqrt(var_col)
        y_stddev = confidence_factor * math.sqrt(var_row)

        ellipse_x = center_col + x_stddev * np.cos(theta)
        ellipse_y = center_row + y_stddev * np.sin(theta)

        # data-cluster lets the sidebar highlight this cluster's marks on hover
        ellipse_attrs = {
            "data-cluster": str(i),
            "stroke": color,
            "strokeWidth": 2,
            "fill": color,
            "fillOpacity": 0.2,
        }
        plot += Plot.line(
            {"x": ellipse_x.tolist(), "y": ellipse_y.tolist()}, ellipse_attrs
        )

        # Cluster centre as a star, larger for heavier clusters
        dot_attrs = {
            "data-cluster": str(i),
            "fill": color,
            "r": 5 + float(weight) * 50,
            "stroke": "black",
            "strokeWidth": 1,
            "symbol": "star",
        }
        plot += Plot.dot({"x": [center_col], "y": [center_row]}, dot_attrs)

    return plot


# Choose which recorded iterations to animate: every `step`-th one, where
# step = num_iteration // NUM_FRAMES (at least 1). The last recorded iteration is
# always included so the animation ends on the final state.
num_iteration = len(all_posterior_xy_means)
step = max(1, num_iteration // NUM_FRAMES)
frame_indices = list(range(0, num_iteration, step))
if frame_indices and frame_indices[-1] != num_iteration - 1:
    frame_indices.append(num_iteration - 1)

# Build one plot per selected iteration.
frames = [create_frame_plot(idx) for idx in frame_indices]

# Sidebar: a table of the heaviest clusters for the iteration currently shown.
# `frames` holds only a subsample of the recorded iterations (`frame_indices`), while the
# arrays in `frame_data_js` are indexed by iteration. `frameIndices` maps the animation
# frame number back to the iteration it displays. This assumes Plot.Frames exposes the
# current frame as $state.frame, as the original code did.
frame_indices_js = (
    f"const frameIndices = {json.dumps([int(i) for i in frame_indices])};\n"
)

# Add the animation with legend and interactive controls
Plot.html([
    "div",
    {"className": "grid grid-cols-3 gap-4 p-4"},
    [
        "div",
        {"className": "col-span-2"},
        Plot.Frames(frames),
    ],
    [
        "div",
        {"className": "col-span-1"},
        Plot.js(
            """function() {
            """
            + frame_data_js
            + frame_indices_js
            + """
            // Get current animation frame and the recorded iteration it shows
            const frame = $state.frame || 0;
            const iteration = frameIndices[frame] ?? frame;
            
            // Get data for the current iteration
            const weights = allWeights[iteration] || [];
            const colors = allColors[iteration] || [];
            const assignments = allAssignments[iteration] || [];
            
            // Sort clusters by weight, filter by minimum weight
            const topClusters = weights
                .map((weight, idx) => ({ 
                    id: idx, 
                    weight: weight, 
                    color: colors[idx] || [0,0,0],
                    points: assignments[idx] || 0,
                    percentage: ((assignments[idx] || 0) / totalPoints * 100).toFixed(1)
                }))
                .filter(c => c.weight >= 0.01)
                .sort((a, b) => b.weight - a.weight)
                .slice(0, 10);
            
            // Create placeholder rows for consistent height
            const placeholders = Array(Math.max(0, 10 - topClusters.length))
                .fill(0)
                .map(() => ["tr", {"className": "h-8"}, ["td", {"colSpan": 3}, ""]]);
                
            // Function to highlight/unhighlight clusters
            if (!$state.highlightCluster) {
                $state.highlightCluster = function(id) {
                    // Unhighlight all clusters
                    document.querySelectorAll('[data-cluster]').forEach(el => {
                        el.style.filter = 'opacity(0.4)';
                    });
                    
                    // Highlight the selected cluster
                    if (id !== null) {
                        document.querySelectorAll(`[data-cluster="${id}"]`).forEach(el => {
                            el.style.filter = 'opacity(1) drop-shadow(0 0 5px white)';
                        });
                    } else {
                        // Reset all if nothing selected
                        document.querySelectorAll('[data-cluster]').forEach(el => {
                            el.style.filter = '';
                        });
                    }
                };
            }
                
            return [
                "div", {},
                ["h3", {}, `Top Clusters by Weight`],
                ["div", {"style": {"height": "400px", "overflow": "auto"}},
                    ["table", {"className": "w-full mt-2"},
                        ["thead", ["tr",
                            ["th", {"className": "text-left"}, "Cluster"],
                            ["th", {"className": "text-left"}, "Weight"],
                            ["th", {"className": "text-left"}, "Points (%)"]
                        ]],
                        ["tbody",
                            ...topClusters.map(cluster => 
                                ["tr", {
                                    "className": "h-8",
                                    "style": {
                                        "cursor": "pointer",
                                        "backgroundColor": $state.hoveredCluster === cluster.id ? "#f0f0f0" : "transparent"
                                    },
                                    "onMouseEnter": () => { 
                                        $state.hoveredCluster = cluster.id; 
                                        $state.highlightCluster(cluster.id);
                                    },
                                    "onMouseLeave": () => { 
                                        $state.hoveredCluster = null; 
                                        $state.highlightCluster(null);
                                    }
                                },
                                ["td", {"className": "py-1"}, 
                                    ["div", {"className": "flex items-center"},
                                        ["div", {
                                            "style": {
                                                "backgroundColor": `rgb(${cluster.color[0]},${cluster.color[1]},${cluster.color[2]})`,
                                                "width": "24px",
                                                "height": "24px",
                                                "borderRadius": "4px",
                                                "border": "1px solid rgba(0,0,0,0.2)",
                                                "display": "inline-block",
                                                "marginRight": "8px"
                                            }
                                        }],
                                        `Cluster ${cluster.id}`
                                    ]
                                ],
                                ["td", {"className": "py-1"}, cluster.weight.toFixed(4)],
                                ["td", {"className": "py-1"}, `${cluster.points} (${cluster.percentage}%)`]
                                ]
                            ),
                            ...placeholders
                        ]
                    ]
                ]
            ];
        }()"""
        ),
    ],
])