In [1]:
import typing
from types import MappingProxyType, SimpleNamespace

import attrs
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import jaxtyping as jt
import numpy as np

from arc25.symmetry import SymOp, transform_vector
from arc25.dsl.types import Vector, Dir4
from arc25.vision.symrep import SymRep, Embedding, EmbeddingDims
from arc25.vision.linear import SymmetricLinear
from arc25.vision.rope import QKV, attention_RoPE_with_global
from arc25.vision.flavours import Features, FeatureDim, SymAttention

## Symmetry-preserving axial attention
Applying any point-symmetry to the input field and representation
should be equivalent to applying the point symmetry to the output field and representation

Formally, axial attention computes:
$$
Y(t) = \sum_r |r\rangle \int_\delta p_r[t,\delta] \langle r| V(t+\delta),
\quad\text{where}\quad
p_r[\delta] = w_r[\delta] \cdot \mathrm{softmax}_\delta \left[ \langle r| Q(t)) \cdot \phi_r(\delta) \langle r| K(t+\delta) \right]
$$

Furthermore, applying the operation $s$ to an input field $X(p)$ has the following effect:
$$
X'(p) = R_s X(R_s^{-1} p) = \sum_{uv} |u\rangle \langle u| R_s |v\rangle \langle v| X(R_s^{-1} p)
= \sum_v |s \cdot v\rangle \langle v| X(R_s^{-1} p),
$$
where the last equality holds for representation labels matching symmetry operations.

Thus, we require the following to hold:
$$
\begin{align}
Y'(t) &= \sum_r |r\rangle \int_\delta p'_r[t,\delta] \langle r| V'(t+\delta) \\
&= \sum_r |r\rangle \int_\delta p'_r[t,\delta] \langle s^{-1} r| V(R_s^{-1} (t+\delta)) \\
&= \sum_{r'} |s\cdot r'\rangle \int_\delta p'_{s\cdot r'}[t,\delta] \langle r'| V(R_s^{-1} (t+\delta)) \\
&= \sum_{r'} |s\cdot r'\rangle \int_{\delta'} p'_{s\cdot r'}[t,R_s \delta'] \langle r'| V(R_s^{-1} t+\delta') \\
&= R_s \sum_{r'} |r'\rangle \int_{\delta'} p'_{s\cdot r'}[t,R_s \delta'] \langle r'| V(R_s^{-1} t+\delta') \\
&= R_s Y(R_s^{-1} t).
\end{align}
$$

A sufficient condition for the last equality is $p'_{s\cdot r'}[t,R_s \delta'] = p_{r'}[R_s^{-1} t,\delta']$ everywhere.
Starting from the definition, we have
$$
\begin{align}
p'_{s\cdot r'}[t, R_s \delta'] &= w_{s\cdot r'}[R_s \delta'] \cdot
\mathrm{softmax}_\delta \left[ \langle s\cdot r'| Q'(t))
\cdot \phi_{s\cdot r'}(R_s \delta') \langle s\cdot r'| K'(t+R_s \delta') \right] \\
&= w_{s\cdot r'}[R_s \delta'] \cdot
\mathrm{softmax}_\delta \left[ \langle r'| Q(R_s^{-1} t))
\cdot \phi_{s\cdot r'}(R_s \delta') \langle r'| K(R_s^{-1} t + \delta') \right] \\
\end{align}
$$
thus, again, a sufficient condition are $w_{s\cdot r'}[R_s \delta'] = w_{r'}[\delta']$
and $\phi_{s\cdot r'}(R_s \delta')=\phi_{r'}(\delta')$.
This can be implemented by setting $\phi_r(\delta) = \vec \delta \cdot R_r \hat u$, and suitable rotation of $w$.

In [2]:
dim = FeatureDim(
    globl = EmbeddingDims(iso=128,full=64),
    rows = EmbeddingDims(iso=64,full=32,rep=SymRep.from_seq((SymOp.t,SymOp.l,SymOp.r,SymOp.d))),
    cols = EmbeddingDims(iso=64,full=32,rep=SymRep.from_seq((SymOp.e,SymOp.x,SymOp.y,SymOp.i))),
    cells = EmbeddingDims(iso=32,full=16),
    flavours = 10,
    shape = (7,9),
)
assert dim.is_valid()
inp = dim.make_empty(batch=(3,))
inp.shapes

namespace(globl=namespace(iso=(3, 10, 128), full=(3, 10, 8, 64), rep=8),
          rows=namespace(iso=(3, 7, 10, 64), full=(3, 7, 10, 4, 32), rep=4),
          cols=namespace(iso=(3, 9, 10, 64), full=(3, 9, 10, 4, 32), rep=4),
          cells=namespace(iso=(3, 7, 9, 10, 32),
                          full=(3, 7, 9, 10, 8, 16),
                          rep=8),
          ypos=(3, 7, 2),
          xpos=(3, 9, 2),
          rmsk=(3, 7),
          cmsk=(3, 9),
          mask=(3, 7, 9))

In [3]:
attn = SymAttention(8, dim, 128, num_groups=4, rngs=nnx.Rngs(0))
out = attn(inp)

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
  out = np.asarray(object, dtype=dtype)


msh=(3, 1)
globl.mask: (3,1,s=2,f=90)
axial.mask: (3,1,s=7,f=90)
logits.shape=(3, 1, 90, 4, 2, 4, 8, 9) ((3,1,f=90,p=4,m=2,k=4,t=8,s=9)) msk.shape=(3, 1, 90, 1, 1, 1, 1, 9) ((3,1,f=90,p=1,m=1,k=1,t=1,s=9))
msh=(3, 7)
globl.mask: (3,7,s=2,f=10)
axial.mask: (3,7,s=9,f=10)
logits.shape=(3, 7, 10, 4, 2, 4, 10, 11) ((3,7,f=10,p=4,m=2,k=4,t=10,s=11)) msk.shape=(3, 7, 10, 1, 1, 1, 1, 11) ((3,7,f=10,p=1,m=1,k=1,t=1,s=11))


In [4]:
def count(obj):
    match obj:
        case np.ndarray():
            return obj.size
        case int():
            return 0
        case SymRep():
            return 0
        case _:
            if attrs.has(type(obj)):
                return sum(count(v) for v in attrs.asdict(obj,recurse=False).values())
            print(type(obj).__name__)
            return 0
count(inp)

414093

In [5]:
nnx.display(attn)

In [6]:
class SymencLayer(nnx.Module):
    def __init__(
        self,
        hidden_size: FeatureDim,
        *,
        mlp_width_factor: float,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        def norms(features: FeatureDim = hidden_size,**kw):
            return SimpleNamespace({
                k:SimpleNamespace({
                    kk: nnx.LayerNorm(getattr(getattr(features,k),kk), rngs=rngs, **kw)
                    for kk in ["iso", "full"]
                }) for k in ["globl", "rowcol", "cells"]
            })
        
        # First layer normalization using `flax.nnx.LayerNorm`
        # before we apply Multi-Head Attentn.
        self.norm1 = norms()
        # The Multi-Head Attention layer (using `flax.nnx.MultiHeadAttention`).
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            normalize_qk=False, # True to stabilise learning in ViT-22B; see paper http://arxiv.org/abs/2302.05442
            rngs=rngs,
        )
        # Second layer normalization using `flax.nnx.LayerNorm`.
        self.norm2 = norms()

        # The MLP for point-wise feedforward (using `flax.nnx.Sequential`, `flax.nnx.Linear, flax.nnx.Dropout`)
        # with the GeLU activation function (`flax.nnx.gelu`).
        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    # The forward pass through the transformer encoder block.
    def __call__(self, x: jax.Array) -> jax.Array:
        # The Multi-Head Attention layer with layer normalization.
        x = x + self.attn(self.norm1(x))
        # The feed-forward network with layer normalization.
        x = x + self.mlp(self.norm2(x))
        return x

# Example usage for testing:
x = jnp.ones((4, 224, 224, 3))
model = VisionTransformer(num_classes=1000)
y = model(x)
print("Predictions shape: ", y.shape)

SyntaxError: * argument may appear only once (3745105500.py, line 9)

In [None]:
10e6/4e3/3600/24