In [1]:
import lovely_tensors as lt
from lovely_numpy import lovely, set_config
from rich import print

# NOTE(review): presumably installs lovely_tensors' compact repr on tensors — see the library docs.
lt.monkey_patch()
# Register `lovely` as the repr passed to lovely_numpy's configuration.
set_config(repr=lovely)
# `print` is rich's print (imported above), which shadows the builtin in this notebook.
print("Loaded")

In [2]:
from ll.actsave import ActivationLoader

# Activation dump of the reference model; presumably the newest version under this directory.
# NOTE(review): absolute, machine-specific path — consider a configurable base directory.
escn = ActivationLoader.from_latest_version("/mnt/shared/ocp-actsave/escn/")
# print(escn)

# Activation dump of the variant that every later cell compares against `escn`.
rhescn = ActivationLoader.from_latest_version("/mnt/shared/ocp-actsave/rhescnv8clean/")
# print(rhescn)


def toplevel_keys(activations):
    """Return, in iteration order, the keys of ``activations`` with at most one ``"."``.

    Args:
        activations: Any object exposing ``.keys()`` that yields strings.

    Returns:
        A list of the keys whose dot count is 0 or 1.
    """
    shallow = []
    for key in activations.keys():
        if key.count(".") < 2:
            shallow.append(key)
    return shallow


# Show the shallow activation keys available in each dump, side by side.
print(
    dict(
        escn=toplevel_keys(escn.activations),
        rhescn=toplevel_keys(rhescn.activations),
    )
)

In [3]:
import numpy as np
from bidict import bidict
from einops import rearrange
from IPython.display import Markdown, display

# Defaults shared by the masking helpers defined in this cell.
LMAX = 4  # default `lmax` of trimask / trimask_mmax / rhmask
MMAX = 2  # default `mmax` of trimask_mmax / rhmask
RH_MMAX = 4  # default `rh_mmax` of rhmask


def trimask(x: np.ndarray, lmax: int = LMAX):
    """Keep the first ``(lmax + 1) ** 2`` entries along axis 1 of ``x``.

    Args:
        x: Array with at least two axes.
        lmax: Determines the number of entries kept, ``(lmax + 1) ** 2``.

    Returns:
        The slice ``x[:, :(lmax + 1) ** 2]``.
    """
    n_keep = (lmax + 1) ** 2
    return x[:, :n_keep]


def _escnindices(lmax: int, mmax: int | None):
    """Map each retained ``(l, m)`` pair to its running position.

    Pairs are visited with ``l`` ascending from 0 to ``lmax`` and, for each ``l``,
    ``m`` ascending from ``-l`` to ``l``. Pairs with ``abs(m) > mmax`` are left out
    unless ``mmax`` is ``None``. Each kept pair gets the next free integer.

    Args:
        lmax: Largest ``l`` included.
        mmax: Largest ``abs(m)`` included, or ``None`` for no restriction.

    Returns:
        A ``bidict`` from ``(l, m)`` to its index.
    """
    coeffs = bidict[tuple[int, int], int]()
    for l in range(lmax + 1):
        # Bounding the m range up front is equivalent to skipping |m| > mmax.
        m_bound = l if mmax is None else min(l, mmax)
        for m in range(-m_bound, m_bound + 1):
            coeffs[(l, m)] = len(coeffs)
    return coeffs


def trimask_mmax(
    x: np.ndarray,
    lmax: int = LMAX,
    mmax: int = MMAX,
):
    """Select the axis-1 entries of ``x`` whose ``(l, m)`` position has ``abs(m) <= mmax``.

    Positions along axis 1 are numbered by enumerating ``l`` from 0 to ``lmax`` and,
    within each ``l``, ``m`` from ``-l`` to ``l``.

    Args:
        x: Array indexed along axis 1 by that enumeration.
        lmax: Largest ``l`` enumerated.
        mmax: Largest ``abs(m)`` kept.

    Returns:
        ``x[:, indices]`` where ``indices`` lists the kept positions in ascending order.
    """
    lm_pairs = [(l, m) for l in range(lmax + 1) for m in range(-l, l + 1)]
    indices: list[int] = [
        pos for pos, (_, m) in enumerate(lm_pairs) if abs(m) <= mmax
    ]
    return x[:, indices]


def rhmask(
    x: np.ndarray,
    lmax: int = LMAX,
    mmax: int | None = MMAX,
    rh_mmax: int = RH_MMAX,
):
    """Gather entries from an ``(m, sign, l)``-blocked array into ``_escnindices`` order.

    A 3-D input first has axis 1 split into ``(m, two_sign, l)`` with
    ``m = rh_mmax + 1`` and ``two_sign = 2``. Inputs of any other rank are used as
    given and indexed as ``x[:, m, sign, l]``.

    Args:
        x: Input array; see above for the expected layout.
        lmax: Forwarded to ``_escnindices``.
        mmax: Forwarded to ``_escnindices`` (``None`` keeps every ``m``).
        rh_mmax: Sets the size of the ``m`` axis when splitting a 3-D input.

    Returns:
        A zero-initialised array of shape ``(x.shape[0], n_pairs, *x.shape[4:])``.
        Slot ``i`` holds ``x[:, abs(m), sign, l - abs(m)]`` for the pair ``(l, m)``
        mapped to ``i``, with ``sign`` equal to 1 for negative ``m`` and 0 otherwise.
    """
    if x.ndim == 3:
        x = rearrange(
            x,
            "E (m two_sign l) C -> E m two_sign l C",
            m=rh_mmax + 1,
            two_sign=2,
        )

    idx = _escnindices(lmax, mmax)
    out_shape = (x.shape[0], len(idx)) + tuple(x.shape[4:])
    x_masked = np.zeros(out_shape, dtype=x.dtype)
    for (l, m), slot in idx.items():
        abs_m = abs(m)
        sign_slot = 1 if m < 0 else 0
        x_masked[:, slot] = x[:, abs_m, sign_slot, l - abs_m]
    return x_masked

In [4]:
display(Markdown("## Wigner"))

# trimask slices axis 1 only, so the last two axes are swapped between the two
# calls to trim both (assumes the array is 3-D — TODO confirm).
x_rh = rhescn.activations["full_wigner"][0]
x_rh = trimask(x_rh).swapaxes(-1, -2)
x_rh = trimask(x_rh).swapaxes(-1, -2)

x_escn = escn.activations["wigner"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Wigner

In [5]:
display(Markdown("## Wigner"))

# Split axis 1 of the tri->rh matrix into (m, sign, l) blocks so rhmask can index it.
x_rh = rearrange(
    rhescn.activations["wigner_tri_to_rh"][0],
    "E (m1 two1 l1) l_sq -> E m1 two1 l1 l_sq",
    m1=RH_MMAX + 1,
    two1=2,
)
# Reorder axis 1 without an m cut-off, then trim what was the last axis.
x_rh = rhmask(x_rh, mmax=None).swapaxes(-1, -2)
x_rh = trimask(x_rh).swapaxes(-1, -2)

x_escn = escn.activations["wigner"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Wigner

In [6]:
display(Markdown("## Embedding Layer"))

# l=0, m=0 embedding: compared with assert_allclose's default tolerances.
x_rh = rhescn.activations["M0L0Embedding.x_l0m0"][0]
x_escn = escn.activations["x_embedding_l0m0"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

# Full embedding: the rhescn side is trimmed along axis 1 before comparing.
x_rh = trimask(rhescn.activations["M0L0Embedding.x"][0])
x_escn = escn.activations["x_embedding"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

## Embedding Layer

In [7]:
display(Markdown("## Layer 0 x_edge"))

# No layout transform on either side; default assert_allclose tolerances.
x_rh = rhescn.activations["LayerBlock_0.MessageBlock.x_edge"][0]
x_escn = escn.activations["LayerBlock_0.message_block.x_edge"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

## Layer 0 x_edge

In [8]:
# Source-side input of the layer-0 message block (rhescn trimmed along axis 1).
display(Markdown("## Layer 0 x_source"))
x_rh = trimask(rhescn.activations["LayerBlock_0.MessageBlock.x_source"][0])
x_escn = escn.activations["LayerBlock_0.message_block.x_source"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

# Target-side input, same treatment.
display(Markdown("## Layer 0 x_target"))
x_rh = trimask(rhescn.activations["LayerBlock_0.MessageBlock.x_target"][0])
x_escn = escn.activations["LayerBlock_0.message_block.x_target"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

## Layer 0 x_source

## Layer 0 x_target

In [9]:
# Rotated source features: rhescn is reordered with rhmask (default lmax/mmax).
display(Markdown("## Layer 0 x_source_rot"))
x_rh = rhmask(rhescn.activations["LayerBlock_0.MessageBlock.x_source_rot"][0])
x_escn = escn.activations["LayerBlock_0.message_block.x_source_rot"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

# Rotated target features, same treatment.
display(Markdown("## Layer 0 x_target_rot"))
x_rh = rhmask(rhescn.activations["LayerBlock_0.MessageBlock.x_target_rot"][0])
x_escn = escn.activations["LayerBlock_0.message_block.x_target_rot"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

## Layer 0 x_source_rot

## Layer 0 x_target_rot

In [10]:
# Source branch: rhescn's so2_block_source output vs escn's x_source_post_so2.
display(Markdown("## Layer 0 x_source_updated"))
x_rh = rhmask(
    rhescn.activations["LayerBlock_0.MessageBlock.so2_block_source.x_source_updated"][0]
)
x_escn = escn.activations["LayerBlock_0.message_block.x_source_post_so2"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

# Target branch: rhescn's so2_block_target output vs escn's x_target_post_so2.
display(Markdown("## Layer 0 x_target_updated"))
x_rh = rhmask(
    rhescn.activations["LayerBlock_0.MessageBlock.so2_block_target.x_target_updated"][0]
)
x_escn = escn.activations["LayerBlock_0.message_block.x_target_post_so2"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

## Layer 0 x_source_updated

## Layer 0 x_target_updated

In [11]:
display(Markdown("## Layer 0 x_updated"))

# rhescn is reordered with rhmask; default assert_allclose tolerances.
x_rh = rhmask(rhescn.activations["LayerBlock_0.MessageBlock.x_updated"][0])
x_escn = escn.activations["LayerBlock_0.message_block.x_updated"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn)

## Layer 0 x_updated

In [12]:
display(Markdown("## Layer 0 grid_act to_grid"))

# Split rhescn's flattened axis 1 into (res_beta, res_alpha) to match escn's layout.
# NOTE(review): res_beta=10 is hard-coded — confirm against the run's grid resolution.
x_rh = rearrange(
    rhescn.activations["LayerBlock_0.MessageBlock.act_rotate_inv.x_grid"][0],
    "E (res_beta res_alpha) C -> E res_beta res_alpha C",
    res_beta=10,
)
x_escn = escn.activations["LayerBlock_0.message_block.grid_act.x_grid"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 grid_act to_grid

In [13]:
display(Markdown("## Layer 0 grid_act silu"))

# Same unflattening as the to_grid check, applied to the x_grid_silu activation.
# NOTE(review): res_beta=10 is hard-coded — confirm against the run's grid resolution.
x_rh = rearrange(
    rhescn.activations["LayerBlock_0.MessageBlock.act_rotate_inv.x_grid_silu"][0],
    "E (res_beta res_alpha) C -> E res_beta res_alpha C",
    res_beta=10,
)
x_escn = escn.activations["LayerBlock_0.message_block.grid_act.x_grid_silu"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 grid_act silu

In [14]:
display(Markdown("## Layer 0 grid_act to_sphere"))

# rhescn keeps only the |m| <= MMAX positions (trimask_mmax defaults) before comparing.
x_rh = trimask_mmax(
    rhescn.activations["LayerBlock_0.MessageBlock.act_rotate_inv.x_sphere_silu"][0]
)
x_escn = escn.activations["LayerBlock_0.message_block.grid_act.x_sphere_silu"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 grid_act to_sphere

In [15]:
display(Markdown("## Layer 0 rotate_inv"))

# rhescn's x_sphere_rotated (trimmed along axis 1) vs escn's x_target_rot_inv.
x_rh = trimask(
    rhescn.activations["LayerBlock_0.MessageBlock.act_rotate_inv.x_sphere_rotated"][0]
)
x_escn = escn.activations["LayerBlock_0.message_block.x_target_rot_inv"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 rotate_inv

In [16]:
display(Markdown("## Layer 0 scatter"))

# rhescn's x_scattered (trimmed along axis 1) vs escn's x_target_reduce.
x_rh = trimask(rhescn.activations["LayerBlock_0.MessageBlock.x_scattered"][0])
x_escn = escn.activations["LayerBlock_0.message_block.x_target_reduce"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 scatter

In [17]:
display(Markdown("## Layer 0 x_message"))

# rhescn's block-level x_message (trimmed) vs escn's message_block.x_target_reduce.
x_rh = trimask(rhescn.activations["LayerBlock_0.x_message"][0])
x_escn = escn.activations["LayerBlock_0.message_block.x_target_reduce"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 x_message

In [18]:
display(Markdown("## Layer 0 conv"))

# rhescn's pointwise_grid_conv output (trimmed) vs escn's block-level x_message_updated.
x_rh = trimask(
    rhescn.activations["LayerBlock_0.pointwise_grid_conv.x_message_updated"][0]
)
x_escn = escn.activations["LayerBlock_0.x_message_updated"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 conv

In [19]:
display(Markdown("## Layer 0 x_message after SO2"))

# Same rhescn activation as the "x_message" check, compared here with escn's own x_message.
x_rh = trimask(rhescn.activations["LayerBlock_0.x_message"][0])
x_escn = escn.activations["LayerBlock_0.x_message"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 x_message after SO2

In [20]:
display(Markdown("## Layer 0 conv: x_grid"))

# Here the escn side is reshaped: its (res_beta, res_alpha) axes are flattened into one.
x_rh = rhescn.activations["LayerBlock_0.pointwise_grid_conv.s2_conv.x_grid"][0]
x_escn = rearrange(
    escn.activations["LayerBlock_0.grid_conv.x_grid"][0],
    "N res_beta res_alpha C -> N (res_beta res_alpha) C",
)
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 conv: x_grid

In [21]:
display(Markdown("## Layer 0 conv: x_grid_message"))

# escn's grid axes are flattened to line up with rhescn's x_message_grid.
x_rh = rhescn.activations["LayerBlock_0.pointwise_grid_conv.s2_conv.x_message_grid"][0]
x_escn = rearrange(
    escn.activations["LayerBlock_0.grid_conv.x_grid_message"][0],
    "N res_beta res_alpha C -> N (res_beta res_alpha) C",
)
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 conv: x_grid_message

In [22]:
display(Markdown("## Layer 0 conv: to_grid_mat"))

# rhescn's matrix is trimmed along axis 1; escn's two leading grid axes are flattened.
x_rh = trimask(
    rhescn.activations["LayerBlock_0.pointwise_grid_conv.to_grid_sh_tri"][0]
)
x_escn = rearrange(
    escn.activations["LayerBlock_0.grid_conv.x_message_to_grid.to_grid_mat"][0],
    "res_beta res_alpha l_sq -> (res_beta res_alpha) l_sq",
)
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 conv: to_grid_mat

In [23]:
display(Markdown("## Layer 0 conv: x_grid_conv"))

# escn's grid axes are flattened to line up with rhescn's x_grid_conv.
x_rh = rhescn.activations["LayerBlock_0.pointwise_grid_conv.s2_conv.x_grid_conv"][0]
x_escn = rearrange(
    escn.activations["LayerBlock_0.grid_conv.x_grid_conv"][0],
    "N res_beta res_alpha C -> N (res_beta res_alpha) C",
)
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 conv: x_grid_conv

In [24]:
display(Markdown("## Layer 0 conv: x_message_from_grid"))

# rhescn's pointwise_grid_conv output (trimmed) vs escn's grid_conv.x_message_from_grid.
x_rh = trimask(
    rhescn.activations["LayerBlock_0.pointwise_grid_conv.x_message_updated"][0]
)
x_escn = escn.activations["LayerBlock_0.grid_conv.x_message_from_grid"][0]
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0 conv: x_message_from_grid

In [25]:
# Print the final per-layer output of both models for visual inspection.
# NOTE(review): 12 is hard-coded — presumably the number of layer blocks; confirm.
for layer_idx in range(12):
    display(Markdown(f"## Layer {layer_idx}: Full"))

    x_rh = trimask(
        rhescn.activations[
            f"LayerBlock_{layer_idx}.pointwise_grid_conv.x_message_updated"
        ][0]
    )
    x_escn = escn.activations[f"LayerBlock_{layer_idx}.x_message_updated"][0]
    print({"rhescn": x_rh, "escn": x_escn})
    # Intentionally no assertion: the original check is left disabled.
    # np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 0: Full

## Layer 1: Full

## Layer 2: Full

## Layer 3: Full

## Layer 4: Full

## Layer 5: Full

## Layer 6: Full

## Layer 7: Full

## Layer 8: Full

## Layer 9: Full

## Layer 10: Full

## Layer 11: Full

In [28]:
def check_layer(layer_idx: int, atol: float = 1e-4):
    """Compare every saved intermediate of one layer between ``rhescn`` and ``escn``.

    For each stage, in a fixed order, this renders a Markdown heading, loads entry
    ``[0]`` of the matching activation from the notebook-level ``rhescn`` and
    ``escn`` loaders, applies that side's layout transform, prints both arrays and
    asserts closeness with ``np.testing.assert_allclose(..., atol=atol)``.

    Args:
        layer_idx: Substituted into the ``LayerBlock_{layer_idx}.`` key prefix and
            into the headings.
        atol: Absolute tolerance for every comparison. The default ``1e-4`` is the
            value that was hard-coded at each comparison before.

    Raises:
        AssertionError: At the first stage whose arrays differ beyond tolerance;
            later stages are not run.
    """
    # NOTE(review): beta-axis grid size used by the unflattening pattern; it was a
    # repeated literal before. Confirm it matches the saved run's grid resolution.
    res_beta = 10

    def _as_is(x):
        # No layout change needed for this side.
        return x

    def _unflatten_grid(x):
        # Split a flattened grid axis into (res_beta, res_alpha).
        return rearrange(
            x,
            "E (res_beta res_alpha) C -> E res_beta res_alpha C",
            res_beta=res_beta,
        )

    def _flatten_grid(x):
        # Merge the (res_beta, res_alpha) axes of a per-node grid tensor.
        return rearrange(x, "N res_beta res_alpha C -> N (res_beta res_alpha) C")

    def _flatten_grid_mat(x):
        # Merge the (res_beta, res_alpha) axes of the to-grid matrix.
        return rearrange(x, "res_beta res_alpha l_sq -> (res_beta res_alpha) l_sq")

    # (heading suffix, rhescn key suffix, rhescn transform, escn key suffix, escn transform)
    stages = [
        ("x_edge", "MessageBlock.x_edge", _as_is, "message_block.x_edge", _as_is),
        ("x_source", "MessageBlock.x_source", trimask, "message_block.x_source", _as_is),
        ("x_target", "MessageBlock.x_target", trimask, "message_block.x_target", _as_is),
        (
            "x_source_rot",
            "MessageBlock.x_source_rot",
            rhmask,
            "message_block.x_source_rot",
            _as_is,
        ),
        (
            "x_target_rot",
            "MessageBlock.x_target_rot",
            rhmask,
            "message_block.x_target_rot",
            _as_is,
        ),
        (
            "x_source_updated",
            "MessageBlock.so2_block_source.x_source_updated",
            rhmask,
            "message_block.x_source_post_so2",
            _as_is,
        ),
        (
            "x_target_updated",
            "MessageBlock.so2_block_target.x_target_updated",
            rhmask,
            "message_block.x_target_post_so2",
            _as_is,
        ),
        ("x_updated", "MessageBlock.x_updated", rhmask, "message_block.x_updated", _as_is),
        (
            "grid_act to_grid",
            "MessageBlock.act_rotate_inv.x_grid",
            _unflatten_grid,
            "message_block.grid_act.x_grid",
            _as_is,
        ),
        (
            "grid_act silu",
            "MessageBlock.act_rotate_inv.x_grid_silu",
            _unflatten_grid,
            "message_block.grid_act.x_grid_silu",
            _as_is,
        ),
        (
            "grid_act to_sphere",
            "MessageBlock.act_rotate_inv.x_sphere_silu",
            trimask_mmax,
            "message_block.grid_act.x_sphere_silu",
            _as_is,
        ),
        (
            "rotate_inv",
            "MessageBlock.act_rotate_inv.x_sphere_rotated",
            trimask,
            "message_block.x_target_rot_inv",
            _as_is,
        ),
        (
            "scatter",
            "MessageBlock.x_scattered",
            trimask,
            "message_block.x_target_reduce",
            _as_is,
        ),
        ("x_message", "x_message", trimask, "message_block.x_target_reduce", _as_is),
        ("x_message after SO2", "x_message", trimask, "x_message", _as_is),
        (
            "conv: to_grid_mat",
            "pointwise_grid_conv.to_grid_sh_tri",
            trimask,
            "grid_conv.x_message_to_grid.to_grid_mat",
            _flatten_grid_mat,
        ),
        (
            "conv: x_grid",
            "pointwise_grid_conv.s2_conv.x_grid",
            _as_is,
            "grid_conv.x_grid",
            _flatten_grid,
        ),
        (
            "conv: x_grid_message",
            "pointwise_grid_conv.s2_conv.x_message_grid",
            _as_is,
            "grid_conv.x_grid_message",
            _flatten_grid,
        ),
        (
            "conv: x_grid_conv",
            "pointwise_grid_conv.s2_conv.x_grid_conv",
            _as_is,
            "grid_conv.x_grid_conv",
            _flatten_grid,
        ),
        (
            "conv: x_message_from_grid",
            "pointwise_grid_conv.x_message_updated",
            trimask,
            "grid_conv.x_message_from_grid",
            _as_is,
        ),
        (
            "conv",
            "pointwise_grid_conv.x_message_updated",
            trimask,
            "x_message_updated",
            _as_is,
        ),
    ]

    for title, rh_suffix, rh_transform, escn_suffix, escn_transform in stages:
        display(Markdown(f"## Layer {layer_idx} {title}"))
        # Load and transform the rhescn side first, then escn, as before.
        x_rh = rh_transform(
            rhescn.activations[f"LayerBlock_{layer_idx}.{rh_suffix}"][0]
        )
        x_escn = escn_transform(
            escn.activations[f"LayerBlock_{layer_idx}.{escn_suffix}"][0]
        )
        print({"rhescn": x_rh, "escn": x_escn})
        np.testing.assert_allclose(x_rh, x_escn, atol=atol)


# check_layer(0)

In [30]:
# Run the full stage-by-stage comparison for layers 0..11; an AssertionError from
# check_layer is not caught, so the loop stops at the first mismatch.
# NOTE(review): 12 is hard-coded — presumably the number of layer blocks; confirm.
for i in range(12):
    check_layer(i)

## Layer 0 x_edge

## Layer 0 x_source

## Layer 0 x_target

## Layer 0 x_source_rot

## Layer 0 x_target_rot

## Layer 0 x_source_updated

## Layer 0 x_target_updated

## Layer 0 x_updated

## Layer 0 grid_act to_grid

## Layer 0 grid_act silu

## Layer 0 grid_act to_sphere

## Layer 0 rotate_inv

## Layer 0 scatter

## Layer 0 x_message

## Layer 0 x_message after SO2

## Layer 0 conv: to_grid_mat

## Layer 0 conv: x_grid

## Layer 0 conv: x_grid_message

## Layer 0 conv: x_grid_conv

## Layer 0 conv: x_message_from_grid

## Layer 0 conv

## Layer 1 x_edge

## Layer 1 x_source

## Layer 1 x_target

## Layer 1 x_source_rot

## Layer 1 x_target_rot

## Layer 1 x_source_updated

## Layer 1 x_target_updated

## Layer 1 x_updated

## Layer 1 grid_act to_grid

## Layer 1 grid_act silu

## Layer 1 grid_act to_sphere

## Layer 1 rotate_inv

## Layer 1 scatter

## Layer 1 x_message

## Layer 1 x_message after SO2

## Layer 1 conv: to_grid_mat

## Layer 1 conv: x_grid

## Layer 1 conv: x_grid_message

## Layer 1 conv: x_grid_conv

## Layer 1 conv: x_message_from_grid

## Layer 1 conv

## Layer 2 x_edge

## Layer 2 x_source

## Layer 2 x_target

## Layer 2 x_source_rot

## Layer 2 x_target_rot

## Layer 2 x_source_updated

## Layer 2 x_target_updated

## Layer 2 x_updated

## Layer 2 grid_act to_grid

## Layer 2 grid_act silu

## Layer 2 grid_act to_sphere

## Layer 2 rotate_inv

## Layer 2 scatter

## Layer 2 x_message

## Layer 2 x_message after SO2

## Layer 2 conv: to_grid_mat

## Layer 2 conv: x_grid

## Layer 2 conv: x_grid_message

## Layer 2 conv: x_grid_conv

## Layer 2 conv: x_message_from_grid

## Layer 2 conv

## Layer 3 x_edge

## Layer 3 x_source

## Layer 3 x_target

## Layer 3 x_source_rot

## Layer 3 x_target_rot

## Layer 3 x_source_updated

## Layer 3 x_target_updated

## Layer 3 x_updated

## Layer 3 grid_act to_grid

## Layer 3 grid_act silu

## Layer 3 grid_act to_sphere

## Layer 3 rotate_inv

## Layer 3 scatter

## Layer 3 x_message

## Layer 3 x_message after SO2

## Layer 3 conv: to_grid_mat

## Layer 3 conv: x_grid

## Layer 3 conv: x_grid_message

## Layer 3 conv: x_grid_conv

## Layer 3 conv: x_message_from_grid

## Layer 3 conv

## Layer 4 x_edge

## Layer 4 x_source

## Layer 4 x_target

## Layer 4 x_source_rot

## Layer 4 x_target_rot

## Layer 4 x_source_updated

## Layer 4 x_target_updated

## Layer 4 x_updated

## Layer 4 grid_act to_grid

## Layer 4 grid_act silu

## Layer 4 grid_act to_sphere

## Layer 4 rotate_inv

## Layer 4 scatter

## Layer 4 x_message

## Layer 4 x_message after SO2

## Layer 4 conv: to_grid_mat

## Layer 4 conv: x_grid

## Layer 4 conv: x_grid_message

## Layer 4 conv: x_grid_conv

## Layer 4 conv: x_message_from_grid

## Layer 4 conv

## Layer 5 x_edge

## Layer 5 x_source

## Layer 5 x_target

## Layer 5 x_source_rot

## Layer 5 x_target_rot

## Layer 5 x_source_updated

## Layer 5 x_target_updated

## Layer 5 x_updated

## Layer 5 grid_act to_grid

## Layer 5 grid_act silu

## Layer 5 grid_act to_sphere

## Layer 5 rotate_inv

## Layer 5 scatter

## Layer 5 x_message

## Layer 5 x_message after SO2

## Layer 5 conv: to_grid_mat

## Layer 5 conv: x_grid

## Layer 5 conv: x_grid_message

## Layer 5 conv: x_grid_conv

## Layer 5 conv: x_message_from_grid

## Layer 5 conv

## Layer 6 x_edge

## Layer 6 x_source

## Layer 6 x_target

## Layer 6 x_source_rot

## Layer 6 x_target_rot

## Layer 6 x_source_updated

## Layer 6 x_target_updated

## Layer 6 x_updated

## Layer 6 grid_act to_grid

## Layer 6 grid_act silu

## Layer 6 grid_act to_sphere

## Layer 6 rotate_inv

## Layer 6 scatter

## Layer 6 x_message

## Layer 6 x_message after SO2

## Layer 6 conv: to_grid_mat

## Layer 6 conv: x_grid

## Layer 6 conv: x_grid_message

## Layer 6 conv: x_grid_conv

## Layer 6 conv: x_message_from_grid

## Layer 6 conv

## Layer 7 x_edge

## Layer 7 x_source

## Layer 7 x_target

## Layer 7 x_source_rot

## Layer 7 x_target_rot

## Layer 7 x_source_updated

## Layer 7 x_target_updated

## Layer 7 x_updated

## Layer 7 grid_act to_grid

## Layer 7 grid_act silu

## Layer 7 grid_act to_sphere

## Layer 7 rotate_inv

## Layer 7 scatter

## Layer 7 x_message

## Layer 7 x_message after SO2

## Layer 7 conv: to_grid_mat

## Layer 7 conv: x_grid

## Layer 7 conv: x_grid_message

## Layer 7 conv: x_grid_conv

## Layer 7 conv: x_message_from_grid

## Layer 7 conv

## Layer 8 x_edge

## Layer 8 x_source

## Layer 8 x_target

## Layer 8 x_source_rot

## Layer 8 x_target_rot

## Layer 8 x_source_updated

## Layer 8 x_target_updated

## Layer 8 x_updated

## Layer 8 grid_act to_grid

## Layer 8 grid_act silu

## Layer 8 grid_act to_sphere

## Layer 8 rotate_inv

## Layer 8 scatter

## Layer 8 x_message

## Layer 8 x_message after SO2

## Layer 8 conv: to_grid_mat

## Layer 8 conv: x_grid

## Layer 8 conv: x_grid_message

## Layer 8 conv: x_grid_conv

## Layer 8 conv: x_message_from_grid

## Layer 8 conv

## Layer 9 x_edge

## Layer 9 x_source

## Layer 9 x_target

## Layer 9 x_source_rot

## Layer 9 x_target_rot

## Layer 9 x_source_updated

## Layer 9 x_target_updated

## Layer 9 x_updated

## Layer 9 grid_act to_grid

## Layer 9 grid_act silu

## Layer 9 grid_act to_sphere

## Layer 9 rotate_inv

## Layer 9 scatter

## Layer 9 x_message

## Layer 9 x_message after SO2

## Layer 9 conv: to_grid_mat

## Layer 9 conv: x_grid

## Layer 9 conv: x_grid_message

## Layer 9 conv: x_grid_conv

## Layer 9 conv: x_message_from_grid

## Layer 9 conv

## Layer 10 x_edge

## Layer 10 x_source

## Layer 10 x_target

## Layer 10 x_source_rot

## Layer 10 x_target_rot

## Layer 10 x_source_updated

## Layer 10 x_target_updated

## Layer 10 x_updated

## Layer 10 grid_act to_grid

## Layer 10 grid_act silu

## Layer 10 grid_act to_sphere

## Layer 10 rotate_inv

## Layer 10 scatter

## Layer 10 x_message

## Layer 10 x_message after SO2

## Layer 10 conv: to_grid_mat

## Layer 10 conv: x_grid

## Layer 10 conv: x_grid_message

## Layer 10 conv: x_grid_conv

## Layer 10 conv: x_message_from_grid

## Layer 10 conv

## Layer 11 x_edge

## Layer 11 x_source

## Layer 11 x_target

## Layer 11 x_source_rot

## Layer 11 x_target_rot

## Layer 11 x_source_updated

## Layer 11 x_target_updated

## Layer 11 x_updated

## Layer 11 grid_act to_grid

## Layer 11 grid_act silu

## Layer 11 grid_act to_sphere

## Layer 11 rotate_inv

## Layer 11 scatter

## Layer 11 x_message

## Layer 11 x_message after SO2

## Layer 11 conv: to_grid_mat

## Layer 11 conv: x_grid

## Layer 11 conv: x_grid_message

## Layer 11 conv: x_grid_conv

## Layer 11 conv: x_message_from_grid

## Layer 11 conv

In [None]:
# Layer inspected by this spot check (rebinds the notebook-level `layer_idx`).
layer_idx = 1
display(Markdown(f"## Layer {layer_idx} rotate_inv"))


# Read by _select below to choose the axis-1 window [L**2, (L + 1)**2).
L = 4


def _select(x):
    """Return ``x[:, L**2:(L + 1)**2]``, reading ``L`` from notebook scope at call time."""
    start, stop = L**2, (L + 1) ** 2
    return x[:, start:stop]


# Compare only the window chosen by _select: the rhescn side is trimmed first, then
# both sides are sliced the same way.
x_rh = _select(
    trimask(
        rhescn.activations[
            f"LayerBlock_{layer_idx}.MessageBlock.act_rotate_inv.x_sphere_rotated"
        ][0]
    )
)
x_escn = _select(
    escn.activations[f"LayerBlock_{layer_idx}.message_block.x_target_rot_inv"][0]
)
print({"rhescn": x_rh, "escn": x_escn})
np.testing.assert_allclose(x_rh, x_escn, atol=1e-5)

## Layer 1 rotate_inv