# In [13]:
from typing import Callable
import jax
import jax.numpy as jnp
import sys

sys.path.append("..")

from util import identity

# In [12]:
def transition_matrices(
    T0,
    top_level: int = 4,
    einsum: Callable[..., jax.Array] = jnp.einsum,
    null_elem: Callable[[jax.Array], jax.Array] = identity,
) -> list:
    """Build a pyramid of pairwise products of transition matrices.

    Level 0 is ``T0`` itself.  Each following level multiplies adjacent
    pairs of the previous level along axis 1, with the odd-indexed (later)
    element on the left::

        T[l][:, x] = T[l - 1][:, 2 * x + 1] @ T[l - 1][:, 2 * x]

    so level ``l`` holds ordered products (later matrices on the left) of
    blocks of ``2 ** l`` consecutive matrices from ``T0``.

    Args:
        T0: Array indexed ``[n, x, i, j]``: for each ``n`` and position
            ``x`` a matrix ``T0[n, x]``.  The default ``einsum`` subscripts
            ``"nxij,nxjk->nxik"`` require it to be rank 4.
        top_level: Number of reduction levels to build.  The result always
            has ``top_level + 1`` entries.  Once axis 1 has shrunk to a
            single element, any further levels come out empty (size 0 on
            axis 1).
        einsum: ``einsum``-compatible function used for the batched matrix
            product; it is called as ``einsum(subscripts, a, b)``.
        null_elem: Currently unused by this function; kept for interface
            compatibility.  NOTE(review): presumably meant to pad odd-length
            levels with a neutral element -- confirm before relying on it.

    Returns:
        The list ``[T0, T1, ..., T_top_level]``.  A level with odd length
        ``L`` along axis 1 produces ``L // 2`` products; its last element is
        left unpaired and does not contribute to the next level.

    >>> T = transition_matrices(0.5 * jnp.ones([1, 8, 1, 1]), top_level=3)
    >>> [t.shape for t in T]
    [(1, 8, 1, 1), (1, 4, 1, 1), (1, 2, 1, 1), (1, 1, 1, 1)]
    >>> [float(t[0, 0, 0, 0]) for t in T]
    [0.5, 0.25, 0.0625, 0.00390625]
    >>> T = transition_matrices((1.0 + jnp.arange(4.0))[None, :, None, None])
    >>> T[1][0, :, 0, 0].tolist()
    [2.0, 12.0]
    >>> [t.shape[1] for t in T]
    [4, 2, 1, 0, 0]
    """
    T = [T0]
    for _ in range(top_level):
        prev = T[-1]
        # [:, 0:-1:2] rather than [:, 0::2] so both operands have the same
        # length along axis 1 even when that length is odd.
        T.append(einsum("nxij,nxjk->nxik", prev[:, 1::2], prev[:, 0:-1:2]))
    return T