What is the `dimap` combinator?

**Note:** `dimap` is meant to be used only by library authors. It is used to implement other combinators such as `switch`, the joint `mix` combinator, and `repeat`.

In [None]:
# Example: rewriting the `mix` combinator with `switch` and `contramap` (`contramap` is the special case of `dimap` that only transforms the arguments)
import genjax
import jax.random as random
from genjax import categorical, gen, normal
from genjax._src.core.generative import GenerativeFunction
from genjax._src.core.typing import List, typecheck


@typecheck
def new_mixture_combinator(
    *gen_fns: GenerativeFunction,
) -> GenerativeFunction:
    """Return a mixture of ``gen_fns`` assembled from ``switch`` and ``contramap``.

    The returned generative function is called as ``(mixture_logits, *args)``.
    It samples a component index from ``categorical(logits=mixture_logits)`` at
    address ``"mixture_component"``. It then runs ``genjax.switch(*gen_fns)`` on
    that index and ``args``, recording the chosen component's choices under
    ``"component_sample"``, and returns that component's return value.

    Args:
        *gen_fns: The mixture components, in index order.

    Returns:
        The mixture generative function, wrapped with ``contramap`` and carrying
        ``info="Derived combinator (Mixture)"``.
    """
    # A single switch over every component; the sampled index selects the branch.
    # NOTE(review): ``args`` presumably holds one argument tuple per component,
    # following genjax.switch's calling convention -- confirm.
    branches = genjax.switch(*gen_fns)

    @gen
    def sample_mixture(mixture_logits, *args):
        component_idx = categorical(logits=mixture_logits) @ "mixture_component"
        return branches(component_idx, *args) @ "component_sample"

    def forward_args(mixture_logits, *args):
        # Identity mapping: the arguments reach ``sample_mixture`` unchanged.
        # It exists only to show how ``contramap`` wraps a generative function.
        return (mixture_logits, *args)

    return sample_mixture.contramap(
        forward_args, info="Derived combinator (Mixture)"
    )


# To add a version accessible as model.new_mix:
def new_mix(
    self,
    branches: List["GenerativeFunction"],
    *args,
) -> "GenerativeFunction":
    """Mix ``self`` with ``branches``; meant to be attached as a method.

    ``self`` becomes component 0 and ``branches`` follow in order. If ``args``
    are supplied, the mixture is applied to them immediately and the result
    of that call is returned. Otherwise the mixture generative function
    itself is returned.
    """
    mixture = new_mixture_combinator(self, *branches)
    if not args:
        return mixture
    return mixture(*args)


# Testing the rewritten version on an example
@gen
def mixture_model(p):
    """Example model: a two-component normal mixture centred at ``p``.

    Component 0 draws ``normal(p, 0.1)`` at ``"x"``. Component 1 draws
    ``normal(p, 0.2)`` at ``"y"``. The mixture draw is recorded at ``"a"``
    and its value is returned.
    """
    narrow = gen(lambda p: normal(p, 0.1) @ "x")
    wide = gen(lambda p: normal(p, 0.2) @ "y")
    mixture = new_mixture_combinator(narrow, wide)
    # NOTE(review): these values are passed as *logits*, so the component weights
    # are softmax((0.4, 0.6)) ~= (0.45, 0.55), not 0.4 / 0.6. If 40/60 weights
    # were intended, pass log-probabilities instead -- confirm.
    logits = (0.4, 0.6)
    # Each component receives its own argument tuple.
    per_component_args = ((p,), (p,))
    return mixture(logits, *per_component_args) @ "a"


# Simulate one trace of the example model with a fixed PRNG seed (23) and
# print its return value, i.e. the mixture draw recorded at address "a".
key = random.PRNGKey(23)
tr = mixture_model.simulate(key, (0.4,))
print(tr.get_retval())

# Testing the version accessible as model.new_mix