In [1]:
import typing
from types import MappingProxyType, SimpleNamespace

import attrs
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import jaxtyping as jt
import numpy as np

from arc25.symmetry import SymOp, transform_vector
from arc25.dsl.types import Vector, Dir4
from arc25.vision.symrep import SymRep, SymDecomp, SymDecompDims, standard_rep
from arc25.vision.linear import SymmetricLinear
from arc25.vision.rope import QKV, attention_RoPE_with_global
from arc25.vision.flavours import Field, FieldDims, SymAttention

## Symmetry-preserving axial attention
Applying any point-symmetry to the input field and representation
should be equivalent to applying the point symmetry to the output field and representation

Formally, axial attention computes:
$$
Y(t) = \sum_r |r\rangle \int_\delta p_r[t,\delta] \langle r| V(t+\delta),
\quad\text{where}\quad
p_r[\delta] = w_r[\delta] \cdot \mathrm{softmax}_\delta \left[ \langle r| Q(t)) \cdot \phi_r(\delta) \langle r| K(t+\delta) \right]
$$

Furthermore, applying the operation $s$ to an input field $X(p)$ has the following effect:
$$
X'(p) = R_s X(R_s^{-1} p) = \sum_{uv} |u\rangle \langle u| R_s |v\rangle \langle v| X(R_s^{-1} p)
= \sum_v |s \cdot v\rangle \langle v| X(R_s^{-1} p),
$$
where the last equality holds for representation labels matching symmetry operations.

Thus, we require the following to hold:
$$
\begin{align}
Y'(t) &= \sum_r |r\rangle \int_\delta p'_r[t,\delta] \langle r| V'(t+\delta) \\
&= \sum_r |r\rangle \int_\delta p'_r[t,\delta] \langle s^{-1} r| V(R_s^{-1} (t+\delta)) \\
&= \sum_{r'} |s\cdot r'\rangle \int_\delta p'_{s\cdot r'}[t,\delta] \langle r'| V(R_s^{-1} (t+\delta)) \\
&= \sum_{r'} |s\cdot r'\rangle \int_{\delta'} p'_{s\cdot r'}[t,R_s \delta'] \langle r'| V(R_s^{-1} t+\delta') \\
&= R_s \sum_{r'} |r'\rangle \int_{\delta'} p'_{s\cdot r'}[t,R_s \delta'] \langle r'| V(R_s^{-1} t+\delta') \\
&= R_s Y(R_s^{-1} t).
\end{align}
$$

A sufficient condition for the last equality is $p'_{s\cdot r'}[t,R_s \delta'] = p_{r'}[R_s^{-1} t,\delta']$ everywhere.
Starting from the definition, we have
$$
\begin{align}
p'_{s\cdot r'}[t, R_s \delta'] &= w_{s\cdot r'}[R_s \delta'] \cdot
\mathrm{softmax}_\delta \left[ \langle s\cdot r'| Q'(t))
\cdot \phi_{s\cdot r'}(R_s \delta') \langle s\cdot r'| K'(t+R_s \delta') \right] \\
&= w_{s\cdot r'}[R_s \delta'] \cdot
\mathrm{softmax}_\delta \left[ \langle r'| Q(R_s^{-1} t))
\cdot \phi_{s\cdot r'}(R_s \delta') \langle r'| K(R_s^{-1} t + \delta') \right] \\
\end{align}
$$
thus, again, a sufficient condition are $w_{s\cdot r'}[R_s \delta'] = w_{r'}[\delta']$
and $\phi_{s\cdot r'}(R_s \delta')=\phi_{r'}(\delta')$.
This can be implemented by setting $\phi_r(\delta) = \vec \delta \cdot R_r \hat u$, and suitable rotation of $w$.

In [2]:
dim = FieldDims.make(
    inv_fac=2,
    context=64,
    hdrs=32,
    cells=16,
    flavours = 10,
    shape = (7,9),
)
assert dim.is_valid()
inp = dim.make_empty(batch=(3,))
inp.shapes

namespace(context=namespace(inv=(3, 10, 128), equiv=(3, 10, 8, 64), rep=8),
          rows=namespace(inv=(3, 7, 10, 64), equiv=(3, 7, 10, 4, 32), rep=4),
          cols=namespace(inv=(3, 9, 10, 64), equiv=(3, 9, 10, 4, 32), rep=4),
          cells=namespace(inv=(3, 7, 9, 10, 32),
                          equiv=(3, 7, 9, 10, 8, 16),
                          rep=8),
          ypos=(3, 7, 2),
          xpos=(3, 9, 2),
          rmsk=(3, 7),
          cmsk=(3, 9),
          mask=(3, 7, 9))

In [3]:
attn = SymAttention(8, dim, 128, num_groups=4, rngs=nnx.Rngs(0))
out = attn(inp)

In [4]:
def count(obj):
    match obj:
        case np.ndarray():
            return obj.size
        case int():
            return 0
        case SymRep():
            return 0
        case _:
            if attrs.has(type(obj)):
                return sum(count(v) for v in attrs.asdict(obj,recurse=False).values())
            print(type(obj).__name__)
            return 0
count(inp)

414093

In [5]:
nnx.display(attn)

In [6]:
class SymencLayer(nnx.Module):    
    def __init__(
        self,
        hidden_size: FieldDims,
        mha_features: int,
        *,
        mlp_width_factor: float,
        num_heads: int,
        num_groups: int | None = None,
        dropout_rate: float = 0.0,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        def norms(features: FieldDims = hidden_size,**kw):
            return features.map_representations(
                lambda k,kk,v: nnx.LayerNorm(v, rngs=rngs, **kw)
            )
            
        def make_linear(in_feat: FieldDims, out_feat: FieldDims,**kw):
            return in_feat.map_projections(
                lambda k,v,o: SymmetricLinear(v, o, rngs=rngs),
                out_feat,
            )
        
        # First layer normalization using `flax.nnx.LayerNorm`
        # before we apply Multi-Head Attentn.
        self.norm1 = norms()
        # The Multi-Head Attention layer
        self.attn = SymAttention(
            num_heads=num_heads,
            num_groups = num_groups,
            in_features=hidden_size,
            qkv_features=mha_features,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            deterministic=False,
            normalize_qk=False, # True to stabilise learning in ViT-22B; see paper http://arxiv.org/abs/2302.05442
            rngs=rngs,
        )
        # Second layer normalization using `flax.nnx.LayerNorm`.
        self.norm2 = norms()

        mlp_dim = attrs.evolve(hidden_size,**{
            k:attrs.evolve(v,**{kk:int(round(vv*mlp_width_factor)) for kk,vv in v.representations.items()})
            for k,v in hidden_size.projections.items()
        })

        # The MLP for point-wise feedforward (using `flax.nnx.Sequential`, `flax.nnx.Linear, flax.nnx.Dropout`)
        # with the GeLU activation function (`flax.nnx.gelu`).
        self.mlp = SimpleNamespace(
            widen = make_linear(hidden_size, mlp_dim),
            activation = nnx.gelu,
            pre_dropout = nnx.Dropout(dropout_rate, rngs=rngs),
            narrow = make_linear(mlp_dim, hidden_size),
            post_dropout = nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: Field) -> Field:
        def apply(inp, fun, *other):
            return fun(inp, *other)

        # The Multi-Head Attention layer with layer normalization.
        ax = x.map_representations(apply, self.norm1)
        ax = self.attn(ax)
        x = x.map_representations(lambda a,b:a+b,ax)
        
        # The feed-forward network with layer normalization.
        ax = x.map_representations(apply, self.norm2)
        ax = ax.map_projections(apply, self.mlp.widen)
        ax = ax.map_representations(self.mlp.activation)
        ax = ax.map_representations(self.mlp.pre_dropout)
        ax = ax.map_projections(apply, self.mlp.narrow)
        ax = ax.map_representations(self.mlp.post_dropout)
        x = x.map_representations(lambda a,b:a+b,ax)
        return x

# Example usage for testing:
model = SymencLayer(
    dim,
    mha_features = 128,
    num_heads=4,
    mlp_width_factor = 4,
)
y = model(inp)


In [7]:
nnx.display(model)
nnx.display(model.attn)
nnx.display(model.mlp.widen)
nnx.display(model.mlp.narrow)

Traceback (most recent call last):
  File "/Users/yves/.pyenv/versions/3.13.7/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/treescope/renderers.py", line 290, in _render_subtree
    maybe_result = handler(node=node, path=path, subtree_renderer=rec)
  File "/Users/yves/.pyenv/versions/3.13.7/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/treescope/_internal/handlers/custom_type_handlers.py", line 65, in handle_via_treescope_repr_method
    return treescope_repr_method(path, subtree_renderer)
  File "/Users/yves/.pyenv/versions/3.13.7/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/flax/nnx/pytreelib.py", line 501, in __treescope_repr__
    stats = OBJECT_CONTEXT.node_stats[id(self)]
            ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^
KeyError: 13307663024



In [23]:
class ArcClassifier(nnx.Module):
    """
    """
    def __init__(
        self,
        *,
        num_classes: int = 1000,
        num_colours: int = 10,
        img_size: int = 30,
        num_layers: int = 12,

        dtype: typing.Any | None = None,
        
        hidden_size: FieldDims,
        mha_features: int,
        mlp_width_factor: float,
        num_heads: int,
        num_groups: int | None = None,
        dropout_rate: float = 0.1,

        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        self.num_colours = num_colours
        self.dtype = dtype 
        self.hidden_size = hidden_size

        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

        self.embedding = hidden_size.map_projections(
            lambda k,v: SymmetricLinear(
                attrs.evolve(v,inv=dict(context=2,cells=2).get(k,0),equiv=0),
                v,
                rngs=rngs,
            )
        )
        self.encoder = nnx.Sequential(*[
            SymencLayer(
                hidden_size=hidden_size,
                mha_features=mha_features,
                mlp_width_factor=mlp_width_factor,
                num_heads=num_heads,
                num_groups=num_groups,
                dropout_rate=dropout_rate,
                rngs=rngs,
            )
            for i in range(num_layers)
        ])
        
        # Layer normalization with `flax.nnx.LayerNorm`.
        self.final_norm = hidden_size.context.map_representations(
            lambda k,v: nnx.LayerNorm(v, rngs=rngs)
        )
        
        # Classification head (maps the transformer encoder to class probabilities).
        rep = hidden_size.context.rep
        n_base = hidden_size.context.inv
        n_equiv = min(hidden_size.context.equiv, n_base//rep.dim//2)
        self.equiv_reduction = nnx.Linear(
            in_features = hidden_size.context.equiv,
            out_features = n_equiv,
            rngs = rngs
        )
        n_colour = min(hidden_size.context.equiv, n_base//self.num_colours//2)
        self.colour_reduction = nnx.Linear(
            in_features = hidden_size.context.inv,
            out_features = n_colour,
            rngs = rngs
        )
        print(f"{n_base=} {n_equiv=} {n_colour=}")
        self.classifier_activation = nnx.gelu
        self.classifier = nnx.Linear(n_base+n_equiv*rep.dim + n_colour*self.num_colours, num_classes, rngs=rngs)

    def __call__(self, x: jt.Int[jt.Array,"... Y X"], size: jt.Int[jt.Array,"... 2"]) -> jt.Float[jt.Array,"... L"]:
        batch = x.shape[:-2]
        pre_embedding = self.encode(x, size)
        print(f"{pre_embedding.shapes=}")
        embedding = pre_embedding.map_projections(lambda v,f:f(v),self.embedding)
        print(f"{embedding.shapes=}")

        # Apply the dropout layer to embedded patches.
        embedding = embedding.map_representations(self.dropout)

        # Transformer encoder blocks.
        # Process the embedded patches through the transformer encoder layers.
        x = self.encoder(embedding)
        # Apply final layer normalization (only to contxt, equivalent of the ViT's CLS token)
        x = x.context
        x = x.map_representations(lambda v,f:f(v),self.final_norm)

        base = x.inv[...,0,:]
        equiv = self.equiv_reduction(x.equiv[...,0,:,:]).reshape(*batch,-1)
        colour = self.colour_reduction(x.inv[...,1:,:]).reshape(*batch,-1)
        print(f"{base.shape=} {equiv.shape=} {colour.shape=}")
        x = jnp.concatenate([base,equiv,colour],axis=-1)
        print(f"{x.shape=}")
        x = self.classifier_activation(x)
        
        # Predict class probabilities based on the CLS token embedding.
        return self.classifier(x)
    
    def encode(self, x: jt.Int[jt.Array,"... Y X"], size: jt.Int[jt.Array,"... 2"]) -> Field:
        batch = x.shape[:-2]
        Y,X = shape = x.shape[-2:]
        Fc = self.num_colours
        F = Fc+1
        R = standard_rep.dim
        
        dtype = self.dtype
        
        x = x[...,:,:,None]
        sY = size[...,0,None] 
        sX = size[...,1,None]
        xpos = jnp.concatenate([
            jnp.tile(np.r_[:X][:,None].astype(dtype),batch+(1,1)),
            np.r_[:X][:,None].astype(dtype)/sX[...,:,None],
        ],axis=-1)
        ypos = jnp.concatenate([
            jnp.tile(np.r_[:Y][:,None].astype(dtype),batch+(1,1)),
            np.r_[:Y][:,None].astype(dtype)/sY[...,:,None],
        ],axis=-1)
        rmsk = np.r_[:Y]<sY
        cmsk = np.r_[:X]<sX
        mask = (rmsk[...,:,None] & cmsk[...,None,:])
        print(f"{x.shape=} {sY.shape=} {sX.shape=} {xpos.shape=} {rmsk.shape=} {mask.shape=}")
        
        colour_idx = np.r_[:self.num_colours]
        presence = (x == colour_idx) & mask[...,None]
        # TODO: product instead of mask?
        prevalence = presence.sum((-3,-2)).astype(dtype) / (1+mask.sum((-2,-1))[...,None].astype(dtype))
        intensity = 1 / (1+prevalence)

        special_ind = jnp.concatenate([
            jnp.ones(batch+(1,1),dtype),
            jnp.zeros(batch+(Fc,1),dtype),
        ],axis=-2)
        prevalence_ind = jnp.concatenate([
            jnp.zeros(batch+(1,1),dtype),
            prevalence[...,:,None],
        ],axis=-2)
        context = jnp.concatenate([special_ind,prevalence_ind],axis=-1)

        presence_ind = jnp.concatenate([
            jnp.zeros(batch+(Y,X,1,1),dtype),
            presence[...,:,None].astype(dtype),
        ],axis=-2)
        intensity_ind = jnp.concatenate([
            jnp.take_along_axis(intensity[...,None,None,:],x,axis=-1)[...,:,:,:,None],
            jnp.zeros(batch+(Y,X,Fc,1),dtype),
        ],axis=-2)
        cells = jnp.concatenate([presence_ind,intensity_ind],axis=-1)
        print(f"{batch=} {context.shape=} {cells.shape=}")

        rrep = self.hidden_size.rows.rep
        crep = self.hidden_size.cols.rep
        return Field(
            context = SymDecomp(inv=context,equiv=jnp.empty(batch+(F,R,0),dtype)),
            rows = SymDecomp(inv=jnp.empty(batch+(Y,F,0),dtype),equiv=jnp.empty(batch+(Y,F,rrep.dim,0),dtype),rep=rrep),
            cols = SymDecomp(inv=jnp.empty(batch+(X,F,0),dtype),equiv=jnp.empty(batch+(X,F,crep.dim,0),dtype),rep=crep),
            cells = SymDecomp(inv=cells,equiv=jnp.empty(batch+shape+(F,R,0),dtype)),
            xpos = xpos,
            ypos = ypos,
            rmsk = rmsk,
            cmsk = cmsk,
            mask = mask,
        )



In [24]:
dims = FieldDims.make(
    inv_fac = 2,
    context = 64,
    hdrs = 32,
    cells = 16,
#        flavours = 10,
#        shape = (7,9),
)
dims

FieldDims(context=SymDecompDims(inv=128, equiv=64, rep=SymRep(opseq=(e,x,y,i,t,l,r,d))), rows=SymDecompDims(inv=64, equiv=32, rep=SymRep(opseq=(t,l,r,d))), cols=SymDecompDims(inv=64, equiv=32, rep=SymRep(opseq=(e,x,y,i))), cells=SymDecompDims(inv=32, equiv=16, rep=SymRep(opseq=(e,x,y,i,t,l,r,d))), flavours=None, shape=None)

In [25]:
cls = ArcClassifier(
    hidden_size = dims,
    mha_features = 128,
    mlp_width_factor = 4,
    num_heads = 4,
)

n_base=128 n_equiv=8 n_colour=6


In [26]:
ei = cls.encode(np.random.randint(0,10,(3,7,9)),np.array([(5,5),(7,9),(4,8)]))
ei.shapes

x.shape=(3, 7, 9, 1) sY.shape=(3, 1) sX.shape=(3, 1) xpos.shape=(3, 9, 2) rmsk.shape=(3, 7) mask.shape=(3, 7, 9)
batch=(3,) context.shape=(3, 11, 2) cells.shape=(3, 7, 9, 11, 2)


namespace(context=namespace(inv=(3, 11, 2), equiv=(3, 11, 8, 0), rep=8),
          rows=namespace(inv=(3, 7, 11, 0), equiv=(3, 7, 11, 4, 0), rep=4),
          cols=namespace(inv=(3, 9, 11, 0), equiv=(3, 9, 11, 4, 0), rep=4),
          cells=namespace(inv=(3, 7, 9, 11, 2),
                          equiv=(3, 7, 9, 11, 8, 0),
                          rep=8),
          ypos=(3, 7, 2),
          xpos=(3, 9, 2),
          rmsk=(3, 7),
          cmsk=(3, 9),
          mask=(3, 7, 9))

In [27]:
logits = cls(np.random.randint(0,10,(3,7,9)),np.array([(5,5),(7,9),(4,8)]))
logits.shape

x.shape=(3, 7, 9, 1) sY.shape=(3, 1) sX.shape=(3, 1) xpos.shape=(3, 9, 2) rmsk.shape=(3, 7) mask.shape=(3, 7, 9)
batch=(3,) context.shape=(3, 11, 2) cells.shape=(3, 7, 9, 11, 2)
pre_embedding.shapes=namespace(context=namespace(inv=(3, 11, 2), equiv=(3, 11, 8, 0), rep=8), rows=namespace(inv=(3, 7, 11, 0), equiv=(3, 7, 11, 4, 0), rep=4), cols=namespace(inv=(3, 9, 11, 0), equiv=(3, 9, 11, 4, 0), rep=4), cells=namespace(inv=(3, 7, 9, 11, 2), equiv=(3, 7, 9, 11, 8, 0), rep=8), ypos=(3, 7, 2), xpos=(3, 9, 2), rmsk=(3, 7), cmsk=(3, 9), mask=(3, 7, 9))
embedding.shapes=namespace(context=namespace(inv=(3, 11, 128), equiv=(3, 11, 8, 64), rep=8), rows=namespace(inv=(3, 7, 11, 64), equiv=(3, 7, 11, 4, 32), rep=4), cols=namespace(inv=(3, 9, 11, 64), equiv=(3, 9, 11, 4, 32), rep=4), cells=namespace(inv=(3, 7, 9, 11, 32), equiv=(3, 7, 9, 11, 8, 16), rep=8), ypos=(3, 7, 2), xpos=(3, 9, 2), rmsk=(3, 7), cmsk=(3, 9), mask=(3, 7, 9))
base.shape=(3, 128) equiv.shape=(3, 64) colour.shape=(3, 60)
x.shape=(3

(3, 1000)

In [None]:
cls.encoder.layers[0].attn.in_features