In [137]:
import typing
from types import MappingProxyType, SimpleNamespace

import attrs
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import jaxtyping as jt
import numpy as np

from arc25.symmetry import SymOp, transform_vector
from arc25.dsl.types import Vector, Dir4

In [89]:
# we could have other symmetries
AnySymOp: typing.TypeAlias = SymOp

@attrs.frozen
class SymRep:
    # these are not actually operations of the symmetry!
    # these are just labels, and symmetry operations connect these labels
    opseq: tuple[AnySymOp,...]
    op2idx: typing.Mapping[AnySymOp,int] = attrs.field(default=attrs.Factory(
        lambda self:MappingProxyType({v:k for k,v in enumerate(self.opseq)}),
        takes_self=True,
    ))

    @classmethod
    def from_seq(cls, opseq: typing.Iterable[AnySymOp]) -> typing.Self:
        ret = cls(tuple(opseq))
        assert ret.is_valid()
        return ret

    def is_valid(self):
        # ensure inverse map is correct 
        if self.op2idx != {v:k for k,v in enumerate(self.opseq)}:
            return False
        # ensure group is closed
        operations = set(self.opseq[0].inverse.combine(o) for o in self.opseq)
        completion = set(o.inverse for o in operations) | set(
            a.combine(b) for a in operations for b in operations
        )
        return completion == operations
    
    @property
    def dim(self)->int:
        return len(self.opseq)

standard_rep = SymRep.from_seq(SymOp)
standard_rep

SymRep(opseq=(<SymOp.e: 0>, <SymOp.x: 1>, <SymOp.y: 2>, <SymOp.i: 3>, <SymOp.t: 4>, <SymOp.l: 5>, <SymOp.r: 6>, <SymOp.d: 7>), op2idx=mappingproxy({<SymOp.e: 0>: 0, <SymOp.x: 1>: 1, <SymOp.y: 2>: 2, <SymOp.i: 3>: 3, <SymOp.t: 4>: 4, <SymOp.l: 5>: 5, <SymOp.r: 6>: 6, <SymOp.d: 7>: 7}))

In [106]:


@attrs.frozen
class Embedding:
    iso: jt.Float[jt.Array, "... Ci"]
    full: jt.Float[jt.Array, "... R Cf"]
    rep: SymRep = standard_rep

    @property
    def shapes(self):
        return SimpleNamespace(
            iso=self.iso.shape,
            full=self.full.shape,
            rep=self.rep.dim,
        )

@attrs.frozen
class EmbeddingDims:
    iso: int # isotropic values/trivial representation
    full: int # full-dimensional representation
    rep: SymRep = standard_rep

    @property
    def dims(self):
        return SimpleNamespace(
            iso=self.iso,
            full=self.full,
            rep=self.rep.dim,
        )
    
    def validate(self, embedding: Embedding) -> bool:
        try:
            np.broadcast_shapes(
                embedding.iso.shape[:-1],
                embedding.full.shape[:-2],
            )
        except ValueError:
            return False
        return (
            self.rep == embedding.rep
            and self.iso == embedding.iso.shape[-1]
            and (self.rep.dim,self.full) == embedding.full.shape[-2:]
        )

    def make_empty(self, batch: tuple[int,...]=()) -> Embedding:
        ret = Embedding(
            iso = np.empty(batch+(self.iso,)),
            full = np.empty(batch+(self.rep.dim,self.full)),
            rep = self.rep,
        )
        assert self.validate(ret)
        return ret

In [224]:
from jax import lax
from flax import nnx
from flax.nnx.nn.linear import default_kernel_init, default_bias_init
from flax.nnx.nn import dtypes
from flax.nnx import rnglib
from flax.typing import (
  Dtype,
  Initializer,
  PrecisionLike,
  DotGeneralT,
  PromoteDtypeFn,
)

class SymmetricLinear(nnx.Module):
    """Representation-elemet-wise pointwise operation across symmetry and groups of channels; respecting symmetry.

    Weight format is (C,R,C')
    The layer computes o(n,h,w,r',c') = b(c') + sum_cr k(c,r^-1.r',c') * i(n,h,w,r,c)
    """
    def __init__(
        self,
        in_features: EmbeddingDims,
        out_features: EmbeddingDims,
        *,
        constraint_mode: typing.Literal["gather-then-concat","concat-then-gather"] = "concat-then-gather",
        use_bias: bool = True,
        dtype: Dtype | None = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        dot_general: DotGeneralT = lax.dot_general,
        promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
        preferred_element_type: Dtype | None = None,
        rngs: rnglib.Rngs,
    ):
        R = min(in_features.rep.dim, out_features.rep.dim)
        kw = dict(
            dtype = dtype,
            param_dtype = param_dtype,
            precision = precision,
            kernel_init = kernel_init,
            bias_init = bias_init,
            dot_general = dot_general,
            promote_dtype = promote_dtype,
            # preferred_element_type = preferred_element_type,
        )
        self.iso2iso = nnx.Linear(
            in_features.iso,
            out_features.iso,
            use_bias = use_bias,
            **kw,
            rngs = rngs,
        ) if in_features.iso and out_features.iso else nnx.data(None)
        self.iso2full = nnx.Linear(
            in_features.iso,
            out_features.full,
            use_bias = use_bias,
            **kw,
            rngs = rngs,
        ) if in_features.iso and out_features.full else nnx.data(None)
        self.full2iso = nnx.Linear(
            in_features.full,
            out_features.iso,
            use_bias = False,
            **kw,
            rngs = rngs,
        ) if in_features.full and out_features.iso else nnx.data(None)
        kernel_key = rngs.params()
        self.full2full = nnx.Param(
          kernel_init(kernel_key, (in_features.full, R, out_features.full), param_dtype)
        )
        
        self.in_features = in_features
        self.out_features = out_features
        self.constraint_mode = dict(gtc="gather-then-concat",ctg="concat-then-gather").get(constraint_mode,constraint_mode)
        self.use_bias = use_bias
        for k,v in kw.items():
            setattr(self, k, v)

    def _prepare_kernel(self, kernel):
        r"""
        The layer computes o(n,h,w,r',c') = sum_mr k(c,r^-1.r',c') * i(n,h,w,r,c)
        This is re-cast into a "standard" linear as o'(n;h;w;r',c') = sum_mr i'(n;h;w;r,c)*k'(r,c;r',c')

        Thus, we return a kernel of shape (R,C,R',C')

        We want the kernel to be symmetry-invariant; that is \rho(s) Y = \rho(s) K X = K \rho(s) X.
        Let X = \sum_v X_v |v>, such that \rho(s) X = \sum_v X_v |s.v>
        Then Y = \sum_uv K_uv X_v |u>
        Thus \rho(s) Y = \sum_uv K_uv X_v |s.u> = \sum_uvw K_uw X_v |u><w|s.v>
        -> <s.u|K|v> = <u|K|s.v> for all u,v,s
        We will select an arbitrary channel o, keep <u|K|o> as the kernel element u
        and derive the rest from it.
        -> <u|K|v> = <o.v^-1.u|K|o> for all u,v
        """
        fi = self.in_features
        fo = self.out_features
        ri = fi.rep
        ro = fo.rep

        assert ri == ro, "Not sure if the representation logic below is correct for representation changes"
        
        R = ri.dim
        C = fi.full
        Rp = ro.dim
        Cp = fo.full

        mode = self.constraint_mode

        # this is that |o>
        refop = ri.opseq[0]
        
        ret = []
        for u in ro.opseq:
            # apply the representation
            ki = np.array([ro.op2idx[refop.combine(v.inverse).combine(u)] for v in ri.opseq],"i4")
            # reshape into the format we need
            if mode == "gather-then-concat":
                k = kernel[None,:,ki,:]
                ret.append(k)
            elif mode == "concat-then-gather":
                ret.append(ki)
            else:
                raise ValueError(mode)

        match mode:
            case "gather-then-concat":
                ret = jnp.concatenate(ret, axis=-4)
            case "concat-then-gather":
                ki = np.array(ret)
                k = kernel[:,ki,:]
                k = k.transpose(1,0,2,3)
                ret = k
            case _:
                raise KeyError(mode)

        return ret

    
    def __call__(self, inputs: Embedding) -> Embedding:
        """Applies a linear transformation to the inputs along the last dimension.
        
        Args:
          inputs: The nd-array to be transformed.
        
        Returns:
          The transformed input.
        """
        assert self.in_features.validate(inputs)
        
        kernel_base = self.full2full.value

        xi, xf, kernel_base = self.promote_dtype(
          (inputs.iso, inputs.full, kernel_base), dtype=self.dtype
        )

        # TODO: should the be applied before promotion instead?
        kernel = self._prepare_kernel(kernel_base)
        
        xfa = jnp.mean(xf,axis=-2)

        of = self.out_features
        yi = [
            lin(inp)
            for lin,inp in [(self.iso2iso,xi), (self.full2iso,xfa)]
            if lin is not None
        ]
        yi = sum(yi) if yi else jnp.empty(xi.shape[:-1]+(of.iso,),xi.dtype)
        
        # We use dot_general_kwargs for BC compatibility with
        # user custom self.dot_general method which may not have
        # preferred_element_type argument to avoid breaking
        # existing code
        dot_general_kwargs = {}
        if False and self.preferred_element_type is not None:
            dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
        yf = self.dot_general(
            xf,
            kernel,
            (((xf.ndim - 2,xf.ndim - 1), (0,1)), ((), ())),
            precision=self.precision,
            **dot_general_kwargs,
        )
        if self.iso2full is not None:
            yfa = self.iso2full(xi)
            yf = yf + yfa[...,None,:]
        ret = Embedding(yi,yf,rep=of.rep)
        assert self.out_features.validate(ret)
        return ret


In [225]:
out = []
for m in ["gtc","ctg"]:
    lin = SymmetricLinear(
        EmbeddingDims(3,2),
        EmbeddingDims(2,3),
        constraint_mode="gtc",
        rngs=nnx.Rngs(0),
    )
    out.append(lin(Embedding(np.arange(3),np.arange(8*2).reshape(8,2))))
gtc,ctg = out
for k in ["iso","full"]:
    ok = np.allclose(getattr(gtc,k), getattr(ctg,k))
    print(f"{k}: {ok=}")

iso: ok=True
full: ok=True


In [226]:
for sobseq in [
    # D2 (never swaps x and y)
    standard_rep.opseq[4:],
    # parity preserving subset
    (SymOp.x, SymOp.y, SymOp.t, SymOp.d),
    # C2 
    (SymOp.e, SymOp.i),
]:
    o = sobseq[0]
    print(f"{o.name}: {" ".join(v.name for v in sobseq)}")
    for u in sobseq:
        comb = tuple(o.combine(v.inverse).combine(u) for v in sobseq)
        print(f"{u.name}: {" ".join(q.name for q in comb)} | {"ok" if all(q in sobseq for q in comb) else "NOK"}")

t: t l r d
t: t l r d | ok
l: r d t l | ok
r: l t d r | ok
d: d r l t | ok
x: x y t d
x: x y t d | ok
y: y x d t | ok
t: t d x y | ok
d: d t y x | ok
e: e i
e: e i | ok
i: i e | ok


In [315]:
@attrs.frozen
class QKV:
    query: jt.Float[jt.Array, "... T P N H 2"]
    key: jt.Float[jt.Array, "... S P K H 2"]
    value: jt.Float[jt.Array, "... S P K D"]
    mask: jt.Bool[jt.Array, "... S"]| None = None

    @property
    def shape(self):
        T,P,N,H,two = self.query.shape[-5:]
        S = self.key.shape[-4]
        K = self.key.shape[-2]
        D = self.value.shape[-1]
        return SimpleNamespace(
            T=T,P=P,N=N,H=H,S=S,K=K,D=D,batch=self.query.shape[:-5],
        )
    
    def validation_problems(self):
        T,P,N,H,two = self.query.shape[-5:]
        S = self.key.shape[-5]
        K = self.key.shape[-3]
        D = self.value.shape[-1]
        if two != 2:
            return "query 2"
        if self.key.shape[-5:] != (S,P,K,H,2):
            return "key"
        if self.value.shape[-4:] != (S,P,K,D):
            return "value"
        if self.mask.shape[-1] != S:
            return "mask"
        try:
            np.broadcast_shapes(
                 self.query.shape[:-5],
                 self.key.shape[:-5],
                 self.value.shape[:-4],
                 self.mask.shape[:-1],
            )
        except ValueError:
            return "batch"

    def is_valid(self):
        return not self.validation_problems()

def attention_RoPE_with_global(
    globl: QKV,
    axial: QKV,
    pQ: jt.Float[jt.Array, "... T K H"],
    pK: jt.Float[jt.Array, "... S K H"] | None=None,
    *,
    # this one is usually static; values are 0: normal, 1: reverse
    polarisation: jt.Int[jt.Array, "P"],
):
    print(f"{globl.shape=} {axial.shape=} {pQ.shape=}")
    assert globl.is_valid(), f"{globl.validation_problems()}: q={globl.query.shape} k={globl.key.shape} v={globl.value.shape}"
    assert axial.is_valid(), axial.validation_problems()

    # global and axial need to be mostly consistent
    sa = axial.shape
    sg = globl.shape
    for k,v in vars(sa).items():
        if k in {"S","T"}:
            continue
        assert getattr(sg,k) == v, f"{k}: {getattr(globl,k)} <> {v}"
    assert sa.T == sa.S, f"RoPE requires equal S & T: {sa}"
    assert pQ.shape == sa.batch+(sa.T,sa.K,sa.H), f"{pQ.shape=} {sa}"
    assert pK is None or pK.shape == sa.batch+(sa.T,sa.K,sa.H), f"{pK.shape=} {sa}"
    
    # calculate rotation matrices; these have shape [... S/T P K H 2 2]: (length, polarisation, head, feature, u, v)
    phi = []
    for p in [pQ,pK]:
        if p is None:
            phi.append(phi[-1])
            continue
        cs,sn = jnp.cos(p), jnp.sin(p)
        nsn = -sn
        # rd will have shape ... S/T K H 3
        rd = jnp.moveaxis(jnp.array([cs,sn,nsn]),0,-1)
        # idx will have shape 2 2 2
        idx = np.r_[0,2,1,0,0,1,2,0].reshape(2,2,2)
        # r will have shape ... S/T 2 K H 2 2
        r = jnp.moveaxis(rd[...,idx],-3,-5)
        print(f"{p.shape=} {rd.shape=} {r.shape=}")
        # r now will have the final target shape
        r = r[...,polarisation,:,:,:,:]
        print(f"{polarisation.shape=} {r.shape=}")
        phi.append(r)
    rQ,rK = phi
    if pK is None:
        pK = pQ
    
    Sa,K,H = axial.key.shape[-4:-1]
    Sg,_,D = globl.value.shape[-3:]
    Na = axial.query.shape[-2]
    Tg,Ng = globl.query.shape[-3:-1]
    assert not Na % K
    assert not Ng % K
    Ma = Na // K
    Mg = Ng // K
    
    aQ = jnp.einsum("...tpmkhu, tpkuv -> ...tpmkv", axis.query.reshape(*axis.query.shape[:-4],L,K,Ma,H,2), rQ)
    aK = jnp.einsum("...spkhu, spkuv -> ...spkv", axis.key, rK)
    aV = axis.value

    gQ = globl.query.reshape(*globl.query.shape[:-4],Tg,K,Mg,H,2)
    gK = globl.key
    gV = globl.value

    log_aa = jnp.einsum("...tpkmhv,...spkhv -> ...pkmts", aQ, aK)
    log_gg = jnp.einsum("...tpkmhv,...spkhv -> ...pkmts", gQ, gK)
    log_ga = jnp.einsum("...tpkmhv,...spkhv -> ...pkmts", gQ, aK)
    log_ag = jnp.einsum("...tpkmhv,...spkhv -> ...pkmts", aQ, gK)
    
    scale = 1/np.sqrt(H)
    V = jnp.concatenate([gV, aV],axis=-3)
    msh = np.broadcast_shapes(*[arg.mask.shape[:-1] for arg in [globl, axial] if arg.mask is not None])
    msk = jnp.concatenate([
        jnp.ones(msh+arg.value.shape[-3:-2]) if arg.mask is None else arg.mask
        for arg in [globl,axial]
    ]) if msh else None

    if Ma == Mg:
        assert K*Ma == Na == Ng == K*Mg
        N = Na
        logits = jnp.block([[log_gg,log_ga],[log_ag,log_aa]])
        P = jax.nn.softmax(logits*scale, axis=-1, where=msk)
        result = jnp.einsum("...pkmts,...spkd -> ...tpkmd",P,V)
        result = result.reshape(*result.shape[:-3],-1,D)
        assert result.shape[-3:-1] == (Tg+Ta, N)
        globl = result[...,:Tg,:,:,:]
        axial = result[...,Tg:,:,:,:]
    else:
        res = []
        for logits in [[log_gg,log_ga], [log_ag,log_aa]]:
            logits = jnp.concatenate(logits, axis=-1)
            P = jax.nn.softmax(logits*scale, axis=-1, where=msk)
            result = jnp.einsum("...pkmts,...spkd -> ...tpkmd",P,V)
            result = result.reshape(*result.shape[:-3],-1,D)
            res.append(result)
        globl, axial = res
    return globl, axial
        
            

In [316]:
@attrs.frozen
class Features:
    globl: Embedding # dimensions (... R? C); full representation
    rows: Embedding # dimensions (... Y R? C); representation (t,l,r,d)
    cols: Embedding # dimensions (... X R? C); representation (e,x,y,i)
    cells: Embedding # dimensions (... Y X R? C); full representation
    ypos: jt.Float[jt.Array,"... Y 2"] # (absolute positions, relative positions)
    xpos: jt.Float[jt.Array,"... X 2"] # (absolute positions, relative positions)
    rmsk: jt.Bool[jt.Array,"... Y"]
    cmsk: jt.Bool[jt.Array,"... X"]
    mask: jt.Bool[jt.Array,"... Y X"]

    @property
    def shapes(self):
        return SimpleNamespace({
            k:v.shapes if isinstance(v, Embedding) else v.shape
            for k,v in attrs.asdict(self, recurse=False).items()
        })

@attrs.frozen
class FeatureDim:
    globl: EmbeddingDims
    rows: EmbeddingDims
    cols: EmbeddingDims
    cells: EmbeddingDims
    # shape is in fact optional, as we don't need it for any weight calculation
    shape: tuple[int,int] | None = None

    def validity_problem(self):
        if not (
            self.globl.rep.is_valid()
            and self.rows.rep.is_valid()
            and self.cols.rep.is_valid()
            and self.cells.rep.is_valid()
        ):
            return "invalid rep"
        if not set(self.globl.rep.opseq) == set(self.rows.rep.opseq) | set(self.cols.rep.opseq) == set(self.cells.rep.opseq):
            return "rep mismatch"
        if set(self.rows.rep.opseq) & set(self.cols.rep.opseq):
            return "rep overlap"

    def is_valid(self):
        return not self.validity_problem()

    def validation_problem(self, f: Features):
        ret = self.validity_problem()
        if ret:
            return ret
        if not self.globl.validate(f.globl):
            return f"globl {self.globl.dims} != {f.globl.shapes}"
        if not self.rows.validate(f.rows):
            return f"rows {self.rows.dims} != {f.rows.shapes}"
        if not self.cols.validate(f.cols):
            return f"cols {self.cols.dims} != {f.cols.shapes}"
        if not self.cells.validate(f.cells):
            return f"cells {self.cells.dims} != {f.cells.shapes}"
        if self.shape is None:
            Y,X = f.cells.full.shape[-4:-2]
        else:
            Y,X = self.shape
        if f.rows.full.shape[-3] != Y:
            return f"rows [{X},{Y}] <> {f.rows.shapes}"
        if f.cols.full.shape[-3] != X:
            return f"cols [{X},{Y}] <> {f.cols.shapes}"
        if f.cells.full.shape[-4:-2] != (Y,X):
            return f"cols [{X},{Y}] <> {f.cells.shapes}"
        if f.ypos.shape[-2:] != (Y,2):
            return f"ypos [{X},{Y}] <> {f.ypos.shape}"
        if f.xpos.shape[-2:] != (X,2):
            return f"xpos [{X},{Y}] <> {f.xpos.shape}"
        if f.rmsk.shape[-1] != Y:
            return f"rmsk [{X},{Y}] <> {f.rmsk.shape}"
        if f.cmsk.shape[-1] != X:
            return f"cmsk [{X},{Y}] <> {f.cmsk.shape}"
        if f.mask.shape[-2:] != (Y,X):
            return f"mask [{X},{Y}] <> {f.mask.shape}"
        try:
            np.broadcast_shapes(
                f.globl.full.shape[:-2],
                f.rows.full.shape[:-3],
                f.cols.full.shape[:-3],
                f.cells.full.shape[:-4],
                f.ypos.shape[:-2],
                f.xpos.shape[:-2],
                f.rmsk.shape[:-1],
                f.cmsk.shape[:-1],
                f.mask.shape[:-2],
            )
            
        except ValueError:
            return f"batch {f.shapes}"

    def validate(self, f:Features):
        return not self.validation_problem(f)

    def make_empty(self, batch:tuple[int,...] = (), *, shape:tuple[int,int] | None = None) -> Features:
        if shape is None:
            shape = self.shape
            assert shape is not None
        else:
            assert self.shape is None or shape == self.shape
        Y,X = shape
        ret = Features(
            globl = self.globl.make_empty(batch),
            rows = self.rows.make_empty(batch+(Y,)),
            cols = self.cols.make_empty(batch+(X,)),
            cells = self.cells.make_empty(batch+shape),
            ypos = np.empty(batch+(Y,2)),
            xpos = np.empty(batch+(X,2)),
            rmsk = np.empty(batch+(Y,),bool),
            cmsk = np.empty(batch+(X,),bool),
            mask = np.empty(batch+(Y,X),bool),
        ) 
        assert self.validate(ret), self.validation_problem(ret)
        return ret

## Symmetry-preserving axial attention
Applying any point-symmetry to the input field and representation
should be equivalent to applying the point symmetry to the output field and representation

Formally, axial attention computes:
$$
Y(t) = \sum_r |r\rangle \int_\delta p_r[t,\delta] \langle r| V(t+\delta),
\quad\text{where}\quad
p_r[\delta] = w_r[\delta] \cdot \mathrm{softmax}_\delta \left[ \langle r| Q(t)) \cdot \phi_r(\delta) \langle r| K(t+\delta) \right]
$$

Furthermore, applying the operation $s$ to an input field $X(p)$ has the following effect:
$$
X'(p) = R_s X(R_s^{-1} p) = \sum_{uv} |u\rangle \langle u| R_s |v\rangle \langle v| X(R_s^{-1} p)
= \sum_v |s \cdot v\rangle \langle v| X(R_s^{-1} p),
$$
where the last equality holds for representation labels matching symmetry operations.

Thus, we require the following to hold:
$$
\begin{align}
Y'(t) &= \sum_r |r\rangle \int_\delta p'_r[t,\delta] \langle r| V'(t+\delta) \\
&= \sum_r |r\rangle \int_\delta p'_r[t,\delta] \langle s^{-1} r| V(R_s^{-1} (t+\delta)) \\
&= \sum_{r'} |s\cdot r'\rangle \int_\delta p'_{s\cdot r'}[t,\delta] \langle r'| V(R_s^{-1} (t+\delta)) \\
&= \sum_{r'} |s\cdot r'\rangle \int_{\delta'} p'_{s\cdot r'}[t,R_s \delta'] \langle r'| V(R_s^{-1} t+\delta') \\
&= R_s \sum_{r'} |r'\rangle \int_{\delta'} p'_{s\cdot r'}[t,R_s \delta'] \langle r'| V(R_s^{-1} t+\delta') \\
&= R_s Y(R_s^{-1} t).
\end{align}
$$

A sufficient condition for the last equality is $p'_{s\cdot r'}[t,R_s \delta'] = p_{r'}[R_s^{-1} t,\delta']$ everywhere.
Starting from the definition, we have
$$
\begin{align}
p'_{s\cdot r'}[t, R_s \delta'] &= w_{s\cdot r'}[R_s \delta'] \cdot
\mathrm{softmax}_\delta \left[ \langle s\cdot r'| Q'(t))
\cdot \phi_{s\cdot r'}(R_s \delta') \langle s\cdot r'| K'(t+R_s \delta') \right] \\
&= w_{s\cdot r'}[R_s \delta'] \cdot
\mathrm{softmax}_\delta \left[ \langle r'| Q(R_s^{-1} t))
\cdot \phi_{s\cdot r'}(R_s \delta') \langle r'| K(R_s^{-1} t + \delta') \right] \\
\end{align}
$$
thus, again, a sufficient condition are $w_{s\cdot r'}[R_s \delta'] = w_{r'}[\delta']$
and $\phi_{s\cdot r'}(R_s \delta')=\phi_{r'}(\delta')$.
This can be implemented by setting $\phi_r(\delta) = \vec \delta \cdot R_r \hat u$, and suitable rotation of $w$.

In [317]:
# symmetry-preserving axial attention: 
# - Applying any point-symmetry to the input field and representation
#   should be equivalent to applying the point symmetry to the output field and representation
# - Formally; attention computes: $Y(o) = \sum_s int_d p_u(d) V_u(d)$
for i,s in enumerate(standard_rep.opseq):
    a = Vector._vec2dir[tuple(transform_vector(s, Vector.RIGHT.as_array()))]
    b = Vector._vec2dir[tuple(transform_vector(s, Vector.DOWN.as_array()))]
    print(f"Componet {i}=|{s.name}>: axial={a:5s} transverse={b:5s}")
    

Componet 0=|e>: axial=right transverse=down 
Componet 1=|x>: axial=left  transverse=down 
Componet 2=|y>: axial=right transverse=up   
Componet 3=|i>: axial=left  transverse=up   
Componet 4=|t>: axial=down  transverse=right
Componet 5=|l>: axial=up    transverse=right
Componet 6=|r>: axial=down  transverse=left 
Componet 7=|d>: axial=up    transverse=left 


In [318]:
from jax import lax
from flax import nnx
from flax.nnx.nn.linear import default_kernel_init, default_bias_init, initializers
from flax.nnx.nn import dtypes
from flax.nnx import rnglib
from flax.typing import (
  Dtype,
  Initializer,
  PrecisionLike,
  DotGeneralT,
  PromoteDtypeFn,
)

class SymAttention(nnx.Module):
    """
    This module performs axial attention.
    
    For the "e" component of the representation,
    we have chosen an arbitrary axis along which to perform the attention;
    it determines the axis of attention for all other components.
    With trainable frequencies (allowing negative ones),
    revesing the direction would be equivalent, but rotations by 90° arent.
    Thus, with this choice, we break the symmetry within the
    representation. This is fine, if we do this only in one place.
    Otherwise, we'd have to augment this with a second attention axis.
    """
    def __init__(
        self,
        num_heads: int,
        in_features: FeatureDim,
        qkv_features: int,
        out_features: FeatureDim | None=None,
        *,
        num_groups: int | None = None,
        dtype: Dtype | None = None,
        param_dtype: Dtype = jnp.float32,
        broadcast_dropout: bool = True,
        dropout_rate: float = 0.0,
        deterministic: bool | None = None,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        out_kernel_init: Initializer | None = None,
        bias_init: Initializer = initializers.zeros_init(),
        out_bias_init: Initializer | None = None,
        use_bias: bool = True,
        # attention_fn: Callable[..., Array] = dot_product_attention,
        decode: bool | None = None,
        normalize_qk: bool = False,
        rngs: rnglib.Rngs,        
    ):
        if num_groups is None:
            num_groups = num_heads

        if out_features is None:
            out_features = in_features
        assert not qkv_features % (2*num_heads)
        assert not num_heads % num_groups
        self.n_features = n_features = qkv_features // (2*num_heads)
        self.in_features = in_features
        self.qkv_features = qkv_features
        self.out_features = out_features
        self.num_heads = num_heads
        self.num_groups = num_groups

        # frequency is both per group, per features, and linear in both absolute and relative
        kernel_key = rngs.params()
        freq_init = initializers.normal(1)
        self.freqs = nnx.Param(
            freq_init(kernel_key, (num_groups, n_features, 2), param_dtype)
        )

        def make_linear(inf,outf):
            # TODO dtypes, biases, etc...
            return SymmetricLinear(inf,outf,rngs=rngs)

        nqkv = (num_heads+2*num_groups)*n_features*2
        nv = num_groups*n_features*2
        self.qkv = {
            k:make_linear(v,attrs.evolve(v,iso=0,full=nqkv))
            for k,v in {k:getattr(in_features,k) for k in ["globl","rows","cols","cells"]}.items()
        }
        self.out = {
            k:make_linear(
                attrs.evolve(v.out_features,iso=0,full=nv),
                getattr(out_features,k),
            ) for k,v in self.qkv.items()
        }


    def __call__(self, features: Features) -> Features:
        assert self.in_features.validate(features), self.in_features.validation_problems(features)
        assert features.globl.rep == features.cells.rep
        
        Y,X = features.cells.full.shape[-4:-2]
        H = self.n_features
        D = 2*H
        K = self.num_groups
        N = self.num_heads
        
        xphi = jnp.einsum("...xa,kha -> ...xkh",features.xpos,self.freqs)
        yphi = jnp.einsum("...ya,kha -> ...ykh",features.ypos,self.freqs)
        phi = [yphi,xphi]
        
        # first; linear projection into QKV for each of the features separtely
        qkv = {}
        for k,v in self.qkv.items():
            inp = getattr(features, k)
            out = v(inp)
            assert not out.iso.size
            rep = out.rep
            out = out.full
            d = {}
            for kk,n in dict(Q=N*H*2,K=K*H*2,V=K*D).items():
                d[kk] = out[...,:n]
                out = out[...,n:]
            qkv[k] = SimpleNamespace(**d,rep=rep)
        qkv = SimpleNamespace(**qkv)
        # second; axial attention
        gres = []
        ares = []
        for axis in range(2):
            match axis:
                case 0:
                    maybe_swap = lambda a,i,j: jnp.swapaxes(a, i, j)
                case 1:
                    maybe_swap = lambda a,i,j: a
                case _:
                    raise RuntimeError
            # careful: performing attention along axis 0, column headers are global, row index acts as position
            # so in this case X is a batch dimension, and Y acts as source/target
            # globl shape: ... R hd
            # hdr shape: ... B R hd
            # axial shape
            #  - before `maybe_swap`: .... Y X R hd
            #  - after  `maybe_swap`: .... B L R hd
            hdr = [qkv.cols,qkv.rows][axis]
            hmsk = [features.cmsk,features.rmsk][axis]
            B = hdr.Q.shape[-3]
            Pi = np.array([qkv.cells.rep.op2idx[o] for o in hdr.rep.opseq])
            polarisation = np.array([transform_vector(o,Vector.DOWN.as_array())[axis] for o in hdr.rep.opseq])
            assert np.all(abs(polarisation) == 1)
            polarisation = (polarisation+1)//2
            # we concatenate global and axis headers for the KVs
            # but there will only be axis headers in the Qs
            # concatenation is along S/T, output shape is ... B S/T P hd,
            gQ = hdr.Q[...,:,None,:,:]
            gK = jnp.concatenate([
                jnp.tile(qkv.globl.K[...,None,None,Pi,:],(B,1,1,1)),
                hdr.K[...,:,None,:,:],
            ],axis=-3)
            gV = jnp.concatenate([
                jnp.tile(qkv.globl.V[...,None,None,Pi,:],(B,1,1,1)),
                hdr.V[...,:,None,:,:],
            ],axis=-3)
            mask = jnp.tile(hmsk[...,:,None],(1,gK.shape[-3]))

            def make_qkv(q,k,v,mask):
                # unravel hd -> (N H 2) / (K H 2) / (K D)
                return QKV(
                    query = q.reshape(*q.shape[:-1],N,H,2),
                    key = k.reshape(*k.shape[:-1],K,H,2),
                    value = v.reshape(*v.shape[:-1],K,D),
                    mask = mask,
                )
            
            res = attention_RoPE_with_global(
                globl = make_qkv(
                    gQ,gK,gV,
                    mask = mask,
                ),
                axial = make_qkv(
                    **{
                        k.lower():maybe_swap(v[...,:,:,Pi,:],-4,-3)
                        for k,v in vars(qkv.cells).items()
                        if k!="rep"
                    },
                    mask = maybe_swap(features.mask,-2,-1),
                ),
                pQ = phi[axis],
                polarisation = polarisation,
            )
            ohdr, oax = (v.reshape(*v.shape[:-2],N*D) for v in res)
            # ohdr now has dimensions ... B 1 P F
            # oax now has dimensions ... B S P F

            # TODO: global attention to axis headers
            
            assert ohdr.shape[-3] == 1
            gres.append(ohdr[...,:,0,:,:])
            ares.append(maybe_swap(oax,-4,-3))
        cells = jnp.concatenate(ares,axis=-2)
        orep = SymRep.from_seq(orep)
        
        # third; global attention
        # attention to cells
        assert qkv.globl.rep == qkv.cells.rep
        globl = jax.nn.dot_product_attention(
            # merge R directly into batch dimensions left of it
            query = qkv.globl.Q.reshape(-1,1,N,2*H), 
            # we first need to move R across X and Y before we can merge
            key = jnp.moveaxis(qkv.cells.K,-2,-4).reshape(-1,Y*X,K,2*H),
            # we first need to move R across X and Y before we can merge
            value = jnp.moveaxis(qkv.cells.V,-2,-4).reshape(-1,Y*X,K,D),
            mask = jnp.tile(features.mask,(R,1)).reshape(-1,1,1,Y*X),
        )
        assert globl.shape[-3] == 1
        globl = globl[:,0,:,:].reshape(qkv.globl.Q.shape[:-2],R,N*D)
        
        tmp = dict(
            globl = attrs.evolve(features.globl, iso=jnp.empty((0,),dtype), full=globl, rep=qkv.globl.rep),
            cols = attrs.evolve(features.cols, iso=jnp.empty((X,0),dtype), full=gres[0], rep=qkv.cols.rep),
            rows = attrs.evolve(features.rows, iso=jnp.empty((Y,0),dtype), full=gres[1], rep=qkv.rows.rep),
            cells = attrs.evolve(features.cells, iso=jnp.empty((Y,X,0),dtype), full=cells, rep=orep),
        )

        # finally; output projection
        output = attrs.evolve(features,**{k:self.out[k](v) for k,v in tmp.items()})
        return output
            
            

In [319]:
dim = FeatureDim(
    globl = EmbeddingDims(iso=128,full=64),
    rows = EmbeddingDims(iso=64,full=32,rep=SymRep.from_seq((SymOp.t,SymOp.l,SymOp.r,SymOp.d))),
    cols = EmbeddingDims(iso=64,full=32,rep=SymRep.from_seq((SymOp.e,SymOp.x,SymOp.y,SymOp.i))),
    cells = EmbeddingDims(iso=32,full=16),
    shape = (7,9),
)
assert dim.is_valid()
inp = dim.make_empty(batch=(3,))
inp.shapes

namespace(globl=namespace(iso=(3, 128), full=(3, 8, 64), rep=8),
          rows=namespace(iso=(3, 7, 64), full=(3, 7, 4, 32), rep=4),
          cols=namespace(iso=(3, 9, 64), full=(3, 9, 4, 32), rep=4),
          cells=namespace(iso=(3, 7, 9, 32), full=(3, 7, 9, 8, 16), rep=8),
          ypos=(3, 7, 2),
          xpos=(3, 9, 2),
          rmsk=(3, 7),
          cmsk=(3, 9),
          mask=(3, 7, 9))

In [320]:
attn = SymAttention(8, dim, 128, num_groups=4, rngs=nnx.Rngs(0))
out = attn(inp)

globl.shape=namespace(T=1, P=4, N=8, H=8, S=4, K=8, D=16, batch=(3, 9)) axial.shape=namespace(T=7, P=4, N=8, H=8, S=4, K=8, D=16, batch=(3, 9)) pQ.shape=(3, 7, 4, 8)


AssertionError: RoPE requires equal S & T: namespace(T=7, P=4, N=8, H=8, S=4, K=8, D=16, batch=(3, 9))

In [212]:
{k:v.in_features.rep.opseq for k,v in attn.qkv.items()}

{'globl': (<SymOp.e: 0>,
  <SymOp.x: 1>,
  <SymOp.y: 2>,
  <SymOp.i: 3>,
  <SymOp.t: 4>,
  <SymOp.l: 5>,
  <SymOp.r: 6>,
  <SymOp.d: 7>),
 'rows': (<SymOp.t: 4>, <SymOp.l: 5>, <SymOp.r: 6>, <SymOp.d: 7>),
 'cols': (<SymOp.e: 0>, <SymOp.x: 1>, <SymOp.y: 2>, <SymOp.i: 3>),
 'cells': (<SymOp.e: 0>,
  <SymOp.x: 1>,
  <SymOp.y: 2>,
  <SymOp.i: 3>,
  <SymOp.t: 4>,
  <SymOp.l: 5>,
  <SymOp.r: 6>,
  <SymOp.d: 7>)}

In [None]:
class SymencLayer(nnx.Module):
    def __init__(
        self,
        hidden_size: FeatureDim,
        *,
        mlp_width_factor: float,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        def norms(features: FeatureDim = hidden_size,**kw):
            return SimpleNamespace({
                k:SimpleNamespace({
                    kk: nnx.LayerNorm(getattr(getattr(features,k),kk), rngs=rngs, **kw)
                    for kk in ["iso", "full"]
                }) for k in ["globl", "rowcol", "cells"]
            })
        
        # First layer normalization using `flax.nnx.LayerNorm`
        # before we apply Multi-Head Attentn.
        self.norm1 = norms()
        # The Multi-Head Attention layer (using `flax.nnx.MultiHeadAttention`).
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            normalize_qk=False, # True to stabilise learning in ViT-22B; see paper http://arxiv.org/abs/2302.05442
            rngs=rngs,
        )
        # Second layer normalization using `flax.nnx.LayerNorm`.
        self.norm2 = norms()

        # The MLP for point-wise feedforward (using `flax.nnx.Sequential`, `flax.nnx.Linear, flax.nnx.Dropout`)
        # with the GeLU activation function (`flax.nnx.gelu`).
        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    # The forward pass through the transformer encoder block.
    def __call__(self, x: jax.Array) -> jax.Array:
        # The Multi-Head Attention layer with layer normalization.
        x = x + self.attn(self.norm1(x))
        # The feed-forward network with layer normalization.
        x = x + self.mlp(self.norm2(x))
        return x

# Example usage for testing:
x = jnp.ones((4, 224, 224, 3))
model = VisionTransformer(num_classes=1000)
y = model(x)
print("Predictions shape: ", y.shape)

In [61]:
10e6/4e3/3600/24

0.028935185185185185