GenJAX is a library for Bayesian modelling and inference, built on top of JAX and Python and designed to be flexible. It inherits the benefits of JAX: GenJAX code is jittable, GPU-accelerated, and compatible with automatic differentiation and other convenient transformations such as automatic vectorization.
Throughout, we use GenStudio to visualize what we are doing.

A central goal of Bayesian inference is to sample from complex distributions, which are often only implicitly defined. In this tutorial, we look at a very simple visual example and explore a variety of ways to solve this problem using GenJAX. Some of these techniques are overkill for this simple problem, but they scale much better to higher-dimensional settings.

Let's import some libraries that will be useful throughout.

In [None]:
import genstudio.plot as Plot
import jax
import jax.numpy as jnp
import matplotlib.image as mpimg
import numpy as np

from genjax import ChoiceMapBuilder as C
from genjax import Pytree, gen, pretty
from genjax._src.generative_functions.distributions.distribution import Distribution
from genjax.typing import FloatArray, PRNGKey

Plot.configure(display_as="html")
pretty()
key = jax.random.key(0)

Here is a black-and-white version of the GenJAX logo, along with some GenStudio visualization setup.

In [None]:
# Read the image
im = mpimg.imread("../../../docs/assets/img/logo.png")

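# Mask of dark pixels: both the red and green channels are below 0.9
im = np.amax(im[:, :, :2], 2) < 0.9
# Invert: dark (logo) pixels become False, the light background becomes True
im = np.logical_not(im)

# Convert to float: 0.0 = black (logo), 1.0 = white (background)
im = im.astype(float)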
height, width = im.shape

# Save the processed image back to a file
mpimg.imsave("../../../docs/assets/img/logo_bw.png", im, cmap="gray")

base_plot = Plot.new(
    Plot.aspectRatio(1),
    Plot.hideAxis(),
    Plot.domain([0, width], [0, height]),
    {"y": {"reverse": True}},
)

logo_plot = Plot.img(
    ["../../../docs/assets/img/logo_bw.png"],
    x=0,
    y=height,
    width=width,
    height=-height,
    src=Plot.identity,
)


def dots_plot(x, y, fill="black", r=2, label=None, opacity=0.5, **kwargs):
    opts = {"fill": fill, "r": r, "opacity": opacity} | kwargs
    plot = base_plot + Plot.dot({"x": x, "y": y}, opts)
    if label is not None:
        plot += Plot.subtitle(label)
    return plot


base_plot + logo_plot

### Interpreting the logo as a distribution
We can interpret this logo as a distribution on the image rectangle as follows.
A continuous distribution on this rectangle can be specified by its density function, i.e. a non-negative value for every point of the image.
Let every point in the black region have density 1, and every point in the white region density 0.
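To make "a value for every point" concrete: a continuous point $(x, y)$ falls in the pixel $(\lfloor x \rfloor, \lfloor y \rfloor)$, so the unnormalized density can be read off the pixel grid. The sketch below assumes the convention produced by the image-processing cell (0.0 for black logo pixels, 1.0 for white ones, indexed `image[row, col]`); `unnormalized_density` and the toy image are hypothetical illustrations, not part of GenJAX.

```python
import numpy as np


def unnormalized_density(image, x, y):
    """Hypothetical helper: density 1 on black pixels, 0 on white pixels and outside.

    Assumes image[row, col] with 0.0 = black and 1.0 = white.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    height, width = image.shape
    # A continuous point (x, y) lies in pixel column floor(x), row floor(y)
    col = np.floor(x).astype(int)
    row = np.floor(y).astype(int)
    inside = (col >= 0) & (col < width) & (row >= 0) & (row < height)
    # Clip indices so every lookup is valid; outside points are masked below
    col_c = np.clip(col, 0, width - 1)
    row_c = np.clip(row, 0, height - 1)
    is_black = image[row_c, col_c] == 0.0
    return np.where(inside & is_black, 1.0, 0.0)


# A 2x3 toy "logo" whose left column is black
toy = np.array([[0.0, 1.0, 1.0],
                [0.0, 1.0, 1.0]])
print(unnormalized_density(toy, [0.5, 1.5, -0.2, 0.9], [1.7, 0.3, 0.5, 2.0]))
```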

### Link to energy function
This defines what is sometimes called an energy function, a common concept in physics and mathematics. In the context of distributions, an energy function $E$ assigns a "cost" or "energy" to each point in the space, and the density is proportional to $\exp(-E(x))$: the lower the energy, the higher the density. In our case the energy is binary, with 0 energy in the black region and infinite energy in the white region, which yields density $e^0 = 1$ in the black region and $e^{-\infty} = 0$ in the white region.
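Under the usual convention that density is proportional to $\exp(-E(x))$, the hard binary energy and a softened variant can be compared on a toy pixel array. The finite penalty of 20 is an illustrative assumption; it happens to match the default `likelihood_multiplier` of the `Logo` class defined later, whose in-bounds log-density is `likelihood_multiplier * (1 - pixel)` up to a constant, i.e. an energy of 0 on black pixels and 20 on white ones.

```python
import numpy as np

# Toy image in the notebook's convention: 0.0 = black (logo), 1.0 = white
toy = np.array([[0.0, 1.0],
                [1.0, 0.0]])
black = toy == 0.0

# Hard energy: 0 on black pixels, +inf on white pixels
hard_energy = np.where(black, 0.0, np.inf)
# Unnormalized density is exp(-energy): exp(0) = 1 and exp(-inf) = 0
hard_density = np.exp(-hard_energy)

# Soft energy (assumed penalty): finite cost on white pixels instead of +inf
penalty = 20.0
soft_energy = np.where(black, 0.0, penalty)
soft_density = np.exp(-soft_energy)

print(hard_density)
print(soft_density)
```

The soft version never assigns exactly zero density, which keeps log-densities finite; this is convenient for gradient-based and weighting-based inference methods.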

### Renormalizing to a distribution
The density function is currently unnormalized: nothing forces its total mass to equal 1, as every probability distribution requires.
The total mass is currently $$M = \sum_{x\in \textbf{LOGO}} 1 = \int_{\textbf{LOGO}} 1\,dx,$$ where the sum runs over the black pixels; since each pixel is a unit square, the pixel count equals the area of the black region.
We can renormalize the density by assigning density 0 to every point in the white region and $\frac{1}{M}$ to every point in the black region.
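Computing $M$ and the normalized density on a toy array shows why the sum and the integral agree: each pixel is a unit square, so counting black pixels measures their area, and summing density times pixel area gives a total mass of 1. The toy array is a hypothetical stand-in for the logo.

```python
import numpy as np

# Toy image: 0.0 = black (logo), 1.0 = white; each pixel is a 1x1 square
toy = np.array([[0.0, 0.0, 1.0],
                [1.0, 0.0, 1.0]])
black = toy == 0.0

# M = number of black pixels = area of the black region (unit pixels)
M = black.sum()
# Normalized density: 1/M on black pixels, 0 on white pixels
density = np.where(black, 1.0 / M, 0.0)

# Integrating over the image = summing density * pixel area (1) over all pixels
total_mass = density.sum()
print(M, total_mass)
```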



This defines our target distribution, and our goal is to generate samples from it. We can start by packaging this target as a distribution in GenJAX.
We do so by defining a custom distribution. Do not worry about the implementation details for now; we will come back to them later.

In [None]:
@Pytree.dataclass
class Logo(Distribution):
    # TODO: should use ExactDensity instead of Distribution here. But there is a bug.
    image: FloatArray = Pytree.static()
    threshold: FloatArray = Pytree.static(default=1e2)
    likelihood_multiplier: FloatArray = Pytree.static(default=20.0)

    def log_likelihood(self, x, y):
        floor_x, floor_y = jnp.floor(x), jnp.floor(y)
        floor_x, floor_y = (
            jnp.astype(floor_x, jnp.int32),
            jnp.astype(floor_y, jnp.int32),
        )
        out_of_bounds = (
            (floor_x < 0) | (floor_x >= width) | (floor_y < 0) | (floor_y >= height)
        )
        value = jax.lax.cond(
            out_of_bounds,
            lambda *_: -self.threshold,
            lambda arg: self.likelihood_multiplier * (1.0 - self.image[arg[1], arg[0]])
            - jnp.log(height * width),
            operand=(floor_x, floor_y),
        )
        return value

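    # The logo defines a density, but we cannot easily sample from it exactly.
    # Fortunately, we will not need to, but we must implement `random_weighted`
    # to satisfy the interface. This placeholder draws uniformly over the image,
    # so its output is NOT a sample from the logo distribution.
    # In the future, a generalization of `Distribution` that only needs to
    # support estimating the logpdf would be a better fit.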
    def random_weighted(self, key: PRNGKey):
        key, subkey = jax.random.split(key)
        x = jax.random.uniform(key, minval=0, maxval=width)
        y = jax.random.uniform(subkey, minval=0, maxval=height)
        logpdf = self.log_likelihood(x, y)
        return -logpdf, (x, y)

    def estimate_logpdf(self, key: PRNGKey, z):
        x, y = z
        return self.log_likelihood(x, y)


im_jax = jnp.array(im.astype(float))
logo = Logo(image=im_jax)
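Since `Logo.log_likelihood` evaluates one point at a time with `jax.lax.cond`, a vectorized NumPy mirror can help when inspecting its values. This is a hypothetical sketch (`logo_log_likelihood` is not a GenJAX function) that replaces the conditional with `np.where` after clipping indices, one common way to express the same logic in array form. It also shows that the density is soft: white pixels do not get density 0, but a log-density `likelihood_multiplier` nats (20 by default) below black pixels, while points outside the image get `-threshold`. The `-log(height * width)` term is the log-density of the uniform distribution on the image.

```python
import numpy as np


def logo_log_likelihood(image, x, y, threshold=1e2, likelihood_multiplier=20.0):
    """Hypothetical NumPy mirror of Logo.log_likelihood, vectorized over points."""
    height, width = image.shape
    col = np.floor(np.asarray(x, dtype=float)).astype(int)
    row = np.floor(np.asarray(y, dtype=float)).astype(int)
    out_of_bounds = (col < 0) | (col >= width) | (row < 0) | (row >= height)
    # Clip so the lookup is valid everywhere; out-of-bounds values are replaced below
    pixel = image[np.clip(row, 0, height - 1), np.clip(col, 0, width - 1)]
    in_bounds_value = likelihood_multiplier * (1.0 - pixel) - np.log(height * width)
    return np.where(out_of_bounds, -threshold, in_bounds_value)


# Toy image: only the top-left pixel is black (0.0)
toy = np.array([[0.0, 1.0],
                [1.0, 1.0]])
values = logo_log_likelihood(toy, [0.5, 1.5, 5.0], [0.5, 0.5, 0.5])
print(values)
# A black pixel scores `likelihood_multiplier` nats above a white pixel
print(values[0] - values[1])
```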

The central object in GenJAX is the generative function.
Make sure you have worked through the tutorial on generative functions before moving forward.

Here, we want samples from the `Logo` distribution.
The model is therefore simply a single random variable, at address `"z"`, whose value is a sample from this distribution.

In [None]:
@gen
def model():
    z = logo() @ "z"
    return None

A first thing we can do is evaluate the likelihood of a point under the model.

In [None]:
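# Constrain the random choice at address "z" to the point (x, y) = (40, 100)
chm = C["z"].set((40.0, 100.0))
# `assess` returns the log density (score) of a complete choice map and the return value
score, retval = model.assess(chm, ())
score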

Even though we do not have observations, this is still an inference problem: our goal is to produce samples from a distribution from which we cannot directly sample.
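A simple baseline for sampling exactly from the hard 0/1 target, and not one of the techniques developed in this tutorial, is rejection sampling: propose points uniformly over the image rectangle and keep only those that land on black pixels. The sketch below is a hypothetical NumPy helper assuming the 0.0-black convention. Its expected cost is $\text{width}\cdot\text{height}/M$ proposals per accepted sample, so it degrades as the black region becomes a small fraction of the space, as often happens in higher dimensions.

```python
import numpy as np


def rejection_sample(image, num_samples, rng):
    """Hypothetical helper: exact samples from the uniform distribution on black pixels.

    Proposal: uniform over the image rectangle [0, width) x [0, height).
    A proposal is accepted iff it lands on a black pixel (value 0.0).
    """
    if not (image == 0.0).any():
        raise ValueError("the image has no black pixels to sample from")
    height, width = image.shape
    samples = []
    while len(samples) < num_samples:
        # Propose a batch of points uniformly over the image rectangle
        x = rng.uniform(0.0, width, size=num_samples)
        y = rng.uniform(0.0, height, size=num_samples)
        # Clipping guards against floating-point rounding onto the upper edge
        col = np.clip(np.floor(x).astype(int), 0, width - 1)
        row = np.clip(np.floor(y).astype(int), 0, height - 1)
        keep = image[row, col] == 0.0
        samples.extend(zip(x[keep], y[keep]))
    return np.array(samples[:num_samples])


rng = np.random.default_rng(0)
# Toy image: only the top-left pixel is black
toy = np.array([[0.0, 1.0],
                [1.0, 1.0]])
samples = rejection_sample(toy, 1000, rng)
print(samples.shape, samples.min(axis=0), samples.max(axis=0))
```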

The first approach we recommend for such a problem is to sample exactly from a surrogate model.
That is, we simplify the problem in a controlled way and sample exactly from the simplified model.

In [None]:
from skimage.measure import label, regionprops


def simplify_image_to_rectangles(image, num_rectangles):
    """
    Simplify a black and white image into a union of black rectangles.

    Parameters:
    - image: A 2D numpy array with 0.0 for black pixels and 1.0 for white pixels.
    - num_rectangles: The maximum number of rectangles (largest components first).

    Returns:
    - A list of rectangles, each represented as a tuple of (x, y, width, height).
    """
    # Label the 8-connected components of the black pixels (value 0.0);
    # `label` treats zeros as background, so we pass the black mask explicitly
    labeled_image = label(image == 0.0, connectivity=2)
    regions = regionprops(labeled_image)

    # Sort regions by area in descending order and keep the largest ones
    regions = sorted(regions, key=lambda r: r.area, reverse=True)[:num_rectangles]

    # Convert each region to its bounding box
    rectangles = []
    for region in regions:
        # bbox is (min_row, min_col, max_row, max_col), with max bounds exclusive
        min_row, min_col, max_row, max_col = region.bbox
        rectangles.append((min_col, min_row, max_col - min_col, max_row - min_row))
    return rectangles


# Example usage: `im` is the processed logo (0.0 on black, 1.0 on white)
rectangles = simplify_image_to_rectangles(im, 10)
print(len(rectangles), rectangles[:3])
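Once the image is simplified to rectangles, the surrogate (the uniform distribution on their union) can be sampled exactly in two steps: choose a rectangle with probability proportional to its area, then draw a point uniformly inside it. The sketch below assumes the `(x, y, width, height)` format described in the docstring and that the rectangles are disjoint; bounding boxes of different connected components can overlap, in which case the overlapping area would be over-weighted. `sample_from_rectangles` is a hypothetical helper, not a GenJAX API. Because bounding boxes also cover some white pixels, these samples are exact for the surrogate but only approximate for the logo.

```python
import numpy as np


def sample_from_rectangles(rectangles, num_samples, rng):
    """Hypothetical helper: uniform samples on a union of disjoint rectangles.

    `rectangles` is a list of (x, y, width, height) tuples.
    """
    rects = np.asarray(rectangles, dtype=float)
    areas = rects[:, 2] * rects[:, 3]
    # Step 1: pick a rectangle with probability proportional to its area
    idx = rng.choice(len(rects), size=num_samples, p=areas / areas.sum())
    chosen = rects[idx]
    # Step 2: sample uniformly inside the chosen rectangle
    x = chosen[:, 0] + rng.uniform(size=num_samples) * chosen[:, 2]
    y = chosen[:, 1] + rng.uniform(size=num_samples) * chosen[:, 3]
    return np.stack([x, y], axis=1)


rng = np.random.default_rng(0)
toy_rectangles = [(0, 0, 2, 1), (5, 5, 1, 1)]  # areas 2 and 1
samples = sample_from_rectangles(toy_rectangles, 3000, rng)
in_first = (samples[:, 0] >= 0) & (samples[:, 0] < 2) & (samples[:, 1] >= 0) & (samples[:, 1] < 1)
print(in_first.mean())
```

With these two toy rectangles of areas 2 and 1, roughly two thirds of the samples should land in the first one.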

TODO: stratified sampling, custom proposal, adev?, etc.