In [1]:
import typing
from types import MappingProxyType, SimpleNamespace

import attrs
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import jaxtyping as jt
import numpy as np

from arc25.symmetry import SymOp, transform_vector
from arc25.dsl.types import Vector, Dir4
from arc25.vision.symrep import SymRep, Embedding, EmbeddingDims
from arc25.vision.linear import SymmetricLinear
from arc25.vision.rope import QKV, attention_RoPE_with_global

## Symmetry-preserving axial attention
Applying any point symmetry to the input field and representation
should be equivalent to applying the same point symmetry to the output field and representation.

Formally, axial attention computes:
$$
Y(t) = \sum_r |r\rangle \int_\delta p_r[t,\delta] \langle r| V(t+\delta),
\quad\text{where}\quad
p_r[t,\delta] = w_r[\delta] \cdot \mathrm{softmax}_\delta \left[ \langle r| Q(t) \cdot \phi_r(\delta) \langle r| K(t+\delta) \right]
$$

Furthermore, applying the operation $s$ to an input field $X(p)$ has the following effect:
$$
X'(p) = R_s X(R_s^{-1} p) = \sum_{uv} |u\rangle \langle u| R_s |v\rangle \langle v| X(R_s^{-1} p)
= \sum_v |s \cdot v\rangle \langle v| X(R_s^{-1} p),
$$
where the last equality holds for representation labels matching symmetry operations.

Thus, we require the following to hold:
$$
\begin{align}
Y'(t) &= \sum_r |r\rangle \int_\delta p'_r[t,\delta] \langle r| V'(t+\delta) \\
&= \sum_r |r\rangle \int_\delta p'_r[t,\delta] \langle s^{-1} r| V(R_s^{-1} (t+\delta)) \\
&= \sum_{r'} |s\cdot r'\rangle \int_\delta p'_{s\cdot r'}[t,\delta] \langle r'| V(R_s^{-1} (t+\delta)) \\
&= \sum_{r'} |s\cdot r'\rangle \int_{\delta'} p'_{s\cdot r'}[t,R_s \delta'] \langle r'| V(R_s^{-1} t+\delta') \\
&= R_s \sum_{r'} |r'\rangle \int_{\delta'} p'_{s\cdot r'}[t,R_s \delta'] \langle r'| V(R_s^{-1} t+\delta') \\
&= R_s Y(R_s^{-1} t).
\end{align}
$$

A sufficient condition for the last equality is $p'_{s\cdot r'}[t,R_s \delta'] = p_{r'}[R_s^{-1} t,\delta']$ everywhere.
Starting from the definition, we have
$$
\begin{align}
p'_{s\cdot r'}[t, R_s \delta'] &= w_{s\cdot r'}[R_s \delta'] \cdot
\mathrm{softmax}_\delta \left[ \langle s\cdot r'| Q'(t)
\cdot \phi_{s\cdot r'}(R_s \delta') \langle s\cdot r'| K'(t+R_s \delta') \right] \\
&= w_{s\cdot r'}[R_s \delta'] \cdot
\mathrm{softmax}_\delta \left[ \langle r'| Q(R_s^{-1} t)
\cdot \phi_{s\cdot r'}(R_s \delta') \langle r'| K(R_s^{-1} t + \delta') \right]
\end{align}
$$
Thus, again, sufficient conditions are $w_{s\cdot r'}[R_s \delta'] = w_{r'}[\delta']$
and $\phi_{s\cdot r'}(R_s \delta')=\phi_{r'}(\delta')$.
This can be implemented by setting $\phi_r(\delta) = \vec \delta \cdot R_r \hat u$, together with a suitable rotation of $w$.

In [2]:
@attrs.frozen
class Features:
    """Immutable container of the embeddings, positions and masks fed to the model.

    The dimension letters in the field comments follow the notebook's own
    convention: ``Y``/``X`` spatial axes, ``F`` flavours, ``R`` representation
    components (absent on the ``iso`` part) and ``C`` channels.
    """
    globl: Embedding # dimensions (... F R? C); full representation
    rows: Embedding # dimensions (... Y F R? C); representation (t,l,r,d)
    cols: Embedding # dimensions (... X F R? C); representation (e,x,y,i)
    cells: Embedding # dimensions (... Y X F R? C); full representation
    ypos: jt.Float[jt.Array,"... Y 2"] # (absolute positions, relative positions)
    xpos: jt.Float[jt.Array,"... X 2"] # (absolute positions, relative positions)
    rmsk: jt.Bool[jt.Array,"... Y"]
    cmsk: jt.Bool[jt.Array,"... X"]
    mask: jt.Bool[jt.Array,"... Y X"]

    @property
    def shapes(self):
        """Summarise every field's shape in a namespace keyed by field name.

        ``Embedding`` fields contribute their own ``.shapes`` summary, all
        other fields (plain arrays) contribute ``.shape``.
        """
        summary = {}
        # recurse=False keeps Embedding instances intact instead of flattening them to dicts
        for name, value in attrs.asdict(self, recurse=False).items():
            if isinstance(value, Embedding):
                summary[name] = value.shapes
            else:
                summary[name] = value.shape
        return SimpleNamespace(summary)

@attrs.frozen
class FeatureDim:
    """Static description of the sizes of a :class:`Features` bundle.

    Holds one ``EmbeddingDims`` per scope (global, row headers, column headers,
    cells) and, optionally, the number of flavours and the spatial shape.
    It can check itself for internal consistency, check a concrete
    ``Features`` instance against itself, and allocate an uninitialised
    ``Features`` of matching shape.
    """
    globl: EmbeddingDims
    rows: EmbeddingDims
    cols: EmbeddingDims
    cells: EmbeddingDims
    # these are in fact optional, as we don't need them for any weight calculation
    flavours: int | None = None
    shape: tuple[int,int] | None = None

    def validity_problem(self):
        """Return a short description of the first internal inconsistency, or ``None``.

        Checks, in order: every representation is valid; the global and cell
        representations both consist of exactly the union of the row and
        column operations; row and column operations are disjoint; rows and
        columns agree on their ``iso`` and ``full`` widths.
        """
        reps = [self.globl.rep, self.rows.rep, self.cols.rep, self.cells.rep]
        if not all(rep.is_valid() for rep in reps):
            return "invalid rep"
        row_ops = set(self.rows.rep.opseq)
        col_ops = set(self.cols.rep.opseq)
        hdr_ops = row_ops | col_ops
        # globl and cells must both carry exactly the operations of rows+cols
        if not (set(self.globl.rep.opseq) == hdr_ops == set(self.cells.rep.opseq)):
            return "rep mismatch"
        if row_ops & col_ops:
            return "rep overlap"
        for k in ["iso","full"]:
            if getattr(self.rows,k) != getattr(self.cols,k):
                return f"row/col mismatch on {k}"
        return None

    def is_valid(self):
        """Return ``True`` when :meth:`validity_problem` finds nothing to report."""
        return not self.validity_problem()

    def validation_problem(self, f: Features):
        """Return a description of the first mismatch between ``f`` and this spec, or ``None``.

        Besides the per-embedding checks delegated to ``EmbeddingDims.validate``,
        this verifies the flavour count and spatial extents of every field and
        that all leading batch dimensions broadcast against each other.
        When ``flavours`` / ``shape`` are not fixed on the spec they are read
        from ``f.globl.iso`` and ``f.cells.full`` respectively.
        """
        ret = self.validity_problem()
        if ret:
            return ret
        if not self.globl.validate(f.globl):
            return f"globl {self.globl.dims} != {f.globl.shapes}"
        if not self.rows.validate(f.rows):
            return f"rows {self.rows.dims} != {f.rows.shapes}"
        if not self.cols.validate(f.cols):
            return f"cols {self.cols.dims} != {f.cols.shapes}"
        if not self.cells.validate(f.cells):
            return f"cells {self.cells.dims} != {f.cells.shapes}"
        if self.flavours is None:
            # globl.iso is (... F C), so F sits second from the end
            F = f.globl.iso.shape[-2]
        else:
            F = self.flavours
        if self.shape is None:
            # cells.full is (... Y X F R C), so (Y, X) are axes -5 and -4
            Y,X = f.cells.full.shape[-5:-3]
        else:
            Y,X = self.shape
        shi = f"[{Y},{X},{F}]"
        if f.rows.full.shape[-4:-2] != (Y,F):
            return f"rows {shi} <> {f.rows.shapes}"
        if f.cols.full.shape[-4:-2] != (X,F):
            return f"cols {shi} <> {f.cols.shapes}"
        if f.cells.full.shape[-5:-2] != (Y,X,F):
            return f"cells {shi} <> {f.cells.shapes}"
        if f.ypos.shape[-2:] != (Y,2):
            return f"ypos {shi} <> {f.ypos.shape}"
        if f.xpos.shape[-2:] != (X,2):
            return f"xpos {shi} <> {f.xpos.shape}"
        if f.rmsk.shape[-1] != Y:
            return f"rmsk {shi} <> {f.rmsk.shape}"
        if f.cmsk.shape[-1] != X:
            return f"cmsk {shi} <> {f.cmsk.shape}"
        if f.mask.shape[-2:] != (Y,X):
            return f"mask {shi} <> {f.mask.shape}"
        try:
            # strip the non-batch trailing axes of every field and require
            # the remaining (batch) prefixes to be mutually broadcastable
            np.broadcast_shapes(
                f.globl.iso.shape[:-2],
                f.rows.iso.shape[:-3],
                f.cols.iso.shape[:-3],
                f.cells.iso.shape[:-4],
                f.globl.full.shape[:-3],
                f.rows.full.shape[:-4],
                f.cols.full.shape[:-4],
                f.cells.full.shape[:-5],
                f.ypos.shape[:-2],
                f.xpos.shape[:-2],
                f.rmsk.shape[:-1],
                f.cmsk.shape[:-1],
                f.mask.shape[:-2],
            )
        except ValueError:
            return f"batch {f.shapes}"
        return None

    def validate(self, f:Features):
        """Return ``True`` when ``f`` conforms to this spec."""
        return not self.validation_problem(f)

    def make_empty(self, batch:tuple[int,...] = (), *, shape:tuple[int,int] | None = None, flavours:int|None = None) -> Features:
        """Allocate an uninitialised ``Features`` matching this spec.

        Args:
            batch: leading batch dimensions prepended to every field.
            shape: spatial ``(Y, X)``; falls back to ``self.shape`` and must
                agree with it when both are given.
            flavours: number of flavours; falls back to ``self.flavours`` and
                must agree with it when both are given.

        Returns:
            A ``Features`` whose position and mask arrays come from
            ``np.empty`` (contents are arbitrary) and whose embeddings come
            from ``EmbeddingDims.make_empty``.

        Raises:
            AssertionError: if a required size is missing, conflicts with the
                spec, or the result fails validation.
        """
        if shape is None:
            shape = self.shape
            assert shape is not None
        else:
            assert self.shape is None or shape == self.shape
        if flavours is None:
            flavours = self.flavours
            assert flavours is not None
        else:
            assert self.flavours is None or flavours == self.flavours
        Y,X = shape
        F = flavours
        ret = Features(
            globl = self.globl.make_empty(batch+(F,)),
            rows = self.rows.make_empty(batch+(Y,F)),
            cols = self.cols.make_empty(batch+(X,F)),
            cells = self.cells.make_empty(batch+shape+(F,)),
            ypos = np.empty(batch+(Y,2)),
            xpos = np.empty(batch+(X,2)),
            rmsk = np.empty(batch+(Y,),bool),
            cmsk = np.empty(batch+(X,),bool),
            mask = np.empty(batch+(Y,X),bool),
        )
        assert self.validate(ret), self.validation_problem(ret)
        return ret

In [3]:
from arc25.vision.rope import show_dims

In [22]:
from jax import lax
from flax import nnx
from flax.nnx.nn.linear import default_kernel_init, default_bias_init, initializers
from flax.nnx.nn import dtypes
from flax.nnx import rnglib
from flax.typing import (
  Dtype,
  Initializer,
  PrecisionLike,
  DotGeneralT,
  PromoteDtypeFn,
)

class SymAttention(nnx.Module):
    """
    This module performs axial attention.

    For the "e" component of the representation,
    we have chosen an arbitrary axis along which to perform the attention;
    it determines the axis of attention for all other components.
    With trainable frequencies (allowing negative ones),
    reversing the direction would be equivalent, but rotations by 90° aren't.
    Thus, with this choice, we break the symmetry within the
    representation. This is fine, if we do this only in one place.
    Otherwise, we'd have to augment this with a second attention axis.

    The forward pass runs these stages in order: QKV projection of every
    scope, axial attention along each of the two spatial axes, pointwise
    self-attention across flavours, global attention, global dihedral
    attention, and a final output projection per scope.
    """
    def __init__(
        self,
        num_heads: int,
        in_features: FeatureDim,
        qkv_features: int,
        out_features: FeatureDim | None=None,
        *,
        global_mix_reduction: int = 4,
        num_groups: int | None = None,
        dtype: Dtype | None = None,
        param_dtype: Dtype = jnp.float32,
        broadcast_dropout: bool = True,
        dropout_rate: float = 0.0,
        deterministic: bool | None = None,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        # out_kernel_init: Initializer | None = None,
        bias_init: Initializer = initializers.zeros_init(),
        # out_bias_init: Initializer | None = None,
        use_bias: bool = True,
        hdrs_attend: bool = False,
        # attention_fn: Callable[..., Array] = dot_product_attention,
        normalize_qk: bool = False,
        rngs: rnglib.Rngs,
    ):
        """Create the projections and the positional frequency parameters.

        Args:
            num_heads: number of query heads; must be a multiple of ``num_groups``.
            in_features: dimensions of the incoming ``Features``.
            qkv_features: total query width; must be a multiple of ``2*num_heads``.
            out_features: dimensions of the produced ``Features``
                (defaults to ``in_features``).
            global_mix_reduction: divisor applied to the ``iso`` width when
                sizing the projection that mixes global ``iso`` features
                into cells (and headers, if they attend).
            num_groups: number of key/value groups (defaults to ``num_heads``).
            dtype, param_dtype, precision, kernel_init, bias_init, use_bias:
                forwarded to every linear layer built here.
            hdrs_attend: whether row/column headers get ``iso`` QKV
                projections; ``__call__`` currently raises
                ``NotImplementedError`` when this is set.
            broadcast_dropout, dropout_rate, deterministic, normalize_qk:
                accepted but not used anywhere in this class.
            rngs: source of the parameter initialisation keys.
        """
        if num_groups is None:
            num_groups = num_heads

        if out_features is None:
            out_features = in_features
        assert not qkv_features % (2*num_heads)
        assert not num_heads % num_groups
        # features are handled in pairs, hence the factor of two
        self.n_features = n_features = qkv_features // (2*num_heads)
        self.global_mix_reduction = global_mix_reduction
        self.in_features = in_features
        self.qkv_features = qkv_features
        self.out_features = out_features
        self.hdrs_attend = hdrs_attend
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.dtype = dtype
        self.param_dtype = param_dtype

        # frequency is both per group, per features, and linear in both absolute and relative
        kernel_key = rngs.params()
        freq_init = initializers.normal(1)
        self.freqs = nnx.Param(
            freq_init(kernel_key, (num_groups, n_features, 2), param_dtype)
        )

        def make_linear(inf,outf,*,cls=SymmetricLinear):
            # Build one linear layer sharing all of this module's common settings.
            # TODO: do we have them all?
            return cls(
                inf,
                outf,
                dtype=dtype,
                param_dtype=param_dtype,
                kernel_init=kernel_init,
                bias_init=bias_init,
                use_bias=use_bias,
                precision=precision,
                rngs=rngs,
            )

        # combined width of Q (num_heads) plus K and V (num_groups each)
        nqkv = (num_heads+2*num_groups)*n_features*2

        self.cell_global_mix = make_linear(
            in_features.globl.iso,
            in_features.cells.iso//global_mix_reduction,
            cls=nnx.Linear,
        )
        # extra iso channels prepended to a scope's input before its QKV projection
        mixmap = dict(
            cells = self.cell_global_mix.out_features,
        )

        if hdrs_attend:
            # single-element unpacking doubles as a check that rows and cols agree
            hdr_iso, = {in_features.rows.iso, in_features.cols.iso}
            hdr_mix = hdr_iso//global_mix_reduction
            self.hdr_global_mix = make_linear(
                in_features.globl.iso,
                hdr_mix,
                cls=nnx.Linear,
            )
            mixmap.update(rows=hdr_mix,cols=hdr_mix)
            skip_iso = []
        else:
            self.hdr_global_mix = nnx.data(None)
            skip_iso = ["rows","cols"]
        self.qkv = {
            k:make_linear(
                attrs.evolve(v,iso=v.iso+mixmap.get(k,0)),
                attrs.evolve(v,iso=nqkv if k not in skip_iso else 0,full=nqkv),
            )
            for k,v in {k:getattr(in_features,k) for k in ["globl","rows","cols","cells"]}.items()
        }
        nv = num_heads*n_features*2
        # globl receives two concatenated iso results (self + cells), hence 2*nv
        self.out = {
            k:make_linear(
                attrs.evolve(v.out_features,iso=dict(globl=2*nv,cells=nv).get(k,nv if hdrs_attend else 0),full=nv),
                getattr(out_features,k),
            ) for k,v in self.qkv.items()
        }


    def __call__(self, features: Features) -> Features:
        """Apply the attention block to ``features`` and return the projected result.

        Raises:
            AssertionError: if ``features`` (or the produced output) does not
                match ``in_features`` / ``out_features``.
            NotImplementedError: if the module was built with ``hdrs_attend=True``.
        """
        assert self.in_features.validate(features), self.in_features.validation_problem(features)
        assert features.globl.rep == features.cells.rep

        R = features.cells.rep.dim
        batch = features.cells.full.shape[:-5]
        B = int(np.prod(batch))
        Y,X,F = features.cells.full.shape[-5:-2]
        H = self.n_features
        D = 2*H
        K = self.num_groups
        N = self.num_heads

        # debug trace of the resolved sizes
        print(f"{batch=} {B=} {Y=} {X=} {F=} {R=} {N=} {K=} {H=} {D=}")

        # `o` dimension: singular dimension to add as broadcast for "other" spatial axis
        xphi = jnp.einsum("...oxa,kha -> ...oxkh",features.xpos[...,None,:,:],self.freqs)
        yphi = jnp.einsum("...oya,kha -> ...oykh",features.ypos[...,None,:,:],self.freqs)
        phi = [yphi,xphi]

        # first; linear projection into QKV for each of the features separately
        cell_mix = self.cell_global_mix(features.globl.iso)[...,None,None,:,:]
        mixmap = dict(
            cells=jnp.tile(cell_mix,(Y,X,1,1)),
        )
        if self.hdrs_attend:
            hdr_mix = self.hdr_global_mix(features.globl.iso)[...,None,:,:]
            mixmap.update(
                rows=jnp.tile(hdr_mix,(Y,1,1)),
                cols=jnp.tile(hdr_mix,(X,1,1)),
            )
        qkv = {}
        qkvi = {}
        for k,v in self.qkv.items():
            inp = getattr(features, k)
            print(f"{k}: {inp.shapes} {v.in_features=} {v.out_features=}")
            mix = mixmap.get(k)
            if mix is not None:
                inp = attrs.evolve(inp,iso=jnp.concatenate([ mix,inp.iso],axis=-1))
            out = v(inp)
            rep = out.rep
            full = out.full
            iso = out.iso

            di = {}
            d = {}
            # peel Q, K and V off the front of the channel axis, one after the other
            for kk,n in dict(Q=N*H*2,K=K*H*2,V=K*D).items():
                d[kk] = dd = full[...,:n]
                print(f"full {k}.{kk}.shape = {dd.shape}")
                full = full[...,n:]
                if not self.hdrs_attend and k in {"rows","cols"}:
                    continue
                di[kk] = dd = iso[...,:n]
                print(f"iso {k}.{kk}.shape = {dd.shape}")
                iso = iso[...,n:]
            # nothing may be left over once Q, K and V have been taken
            assert not iso.size, f"{k}: {iso.shape=}"
            assert not full.size
            qkvi[k] = SimpleNamespace(**di)
            qkv[k] = SimpleNamespace(**d,rep=rep)
        qkv = SimpleNamespace(**qkv)
        qkvi = SimpleNamespace(**qkvi)

        # second; axial attention
        gres = []
        ares = []
        orep = []
        for axis in range(2):
            # careful: performing attention along axis 0, column headers are global, row index acts as position
            # so in this case X is a batch dimension, and Y acts as source/target
            # globl shape: ... F R hd
            # hdr shape: ... Y/X F R hd
            # axial shape .... Y X F R hd
            hdr = [qkv.cols,qkv.rows][axis]
            # mask along the attended axis: rows for axis 0, columns for axis 1
            ohmsk = [features.rmsk,features.cmsk][axis]
            oS = hdr.Q.shape[-4]
            # positions of this header's operations within the cell representation
            Pi = np.array([qkv.cells.rep.op2idx[o] for o in hdr.rep.opseq])
            P = Pi.size
            orep.extend(hdr.rep.opseq)
            polarisation = np.array([transform_vector(o,Vector.DOWN.as_array())[axis] for o in hdr.rep.opseq])
            assert np.all(abs(polarisation) == 1)
            # map {-1,+1} to {0,1}
            polarisation = (polarisation+1)//2
            # first, reshape stuff into "... tB S/T tF P hd" style; this way we have fixed axis positions
            tB,tF = [(1,oS*F),(oS,F)][axis]
            gK = qkv.globl.K[...,Pi,:].reshape(*batch,1,1,F,P,K*H*2)
            gV = qkv.globl.V[...,Pi,:].reshape(*batch,1,1,F,P,K*D)
            gK,gV = [jnp.tile(v,[
                (1,1,oS,1,1),
                (oS,1,1,1,1),
            ][axis]) for v in (gK,gV)]
            hQ = hdr.Q.reshape(*batch,tB,1,tF,P,N*H*2)
            hK = hdr.K.reshape(*batch,tB,1,tF,P,K*H*2)
            hV = hdr.V.reshape(*batch,tB,1,tF,P,K*D)
            # now, we can concatenate along axis 2
            ghK = jnp.concatenate([gK,hK],axis=-4)
            ghV = jnp.concatenate([gV,hV],axis=-4)

            def make_qkv(q,k,v,*,mask,S,T):
                # unravel hd -> (N H 2) / (K H 2) / (K D)
                return QKV(
                    query = q.reshape(*batch,tB,T,tF,P,N,H,2),
                    key = k.reshape(*batch,tB,S,tF,P,K,H,2),
                    value = v.reshape(*batch,tB,S,tF,P,K,D),
                    mask = mask,
                )
            S = T = ohmsk.shape[-1]
            res = attention_RoPE_with_global(
                globl = make_qkv(
                    hQ,ghK,ghV,
                    mask = None, # np.ones(2,bool),
                    T=1,
                    S=2,
                ),
                axial = make_qkv(
                    **{
                        k.lower():v[...,Pi,:]
                        for k,v in vars(qkv.cells).items()
                        if k!="rep"
                    },
                    T=T,
                    S=S,
                    mask = ohmsk[...,None,:],
                ),
                pQ = phi[axis],
                polarisation = polarisation,
            )
            ohdr, oax = (v.reshape(*v.shape[:-2],N*D) for v in res)
            # ohdr now has dimensions tB 1 tF P C
            assert ohdr.shape[-4] == 1
            # oax now has dimensions tB S tF P C

            # TODO: global attention to axis headers?
            gres.append(ohdr[...,:,0,:,:,:].reshape(*batch,oS,F,P,N*D))
            ares.append(oax.reshape(*batch,Y,X,F,P,N*D))

        # stitch the per-axis results back together along the representation axis
        cells = jnp.concatenate(ares,axis=-2)
        orep = SymRep.from_seq(orep)

        efc = {}
        for k in "QKV":
            g = getattr(qkvi.globl,k)[...,None,:,:]  # ... F C
            c = getattr(qkvi.cells,k).reshape(*batch,Y*X,F,-1)  # ... Y X F C
            print(f"g: {show_dims('sfc',g)}")
            print(f"c: {show_dims('sfc',c)}")
            # the global token goes first, followed by all Y*X cells
            v = jnp.concatenate([g,c],axis=-3)
            v = v.reshape(*batch,Y*X+1,F,*dict(Q=(N,2*H),K=(K,2*H),V=(K,D))[k])
            efc[k] = v
        efc = SimpleNamespace(**efc)

        # third; pointwise self-attention across flavours
        pwatt = jax.nn.dot_product_attention(
            query = efc.Q.reshape(B*(Y*X+1),F,N,2*H),
            key = efc.K.reshape(B*(Y*X+1),F,K,2*H),
            value = efc.V.reshape(B*(Y*X+1),F,K,D),
            # mask = features.mask[...,None], 
        )
        pwatt = pwatt.reshape(*batch,Y*X+1,F,N*D)
        globl_self = pwatt[...,0,:,:]
        cells_iso = pwatt[...,1:,:,:].reshape(*batch,Y,X,F,-1)

        if self.hdrs_attend:
            raise NotImplementedError("We'd need to implement cross-flavour attention for headers here")

        # fourth; global attention
        glatt = jax.nn.dot_product_attention(
            # ellipsis (rather than a single leading `:`) keeps this valid for
            # any number of batch dimensions; only the global token queries
            query = jnp.swapaxes(efc.Q[...,:1,:,:,:],-3,-4).reshape(B*F,1,N,2*H),
            key = jnp.swapaxes(efc.K,-3,-4).reshape(B*F,Y*X+1,K,2*H),
            value = jnp.swapaxes(efc.V,-3,-4).reshape(B*F,Y*X+1,K,D),
            mask = jnp.concatenate([
                # TODO: should we self-attend here?
                jnp.zeros((B*F,1,1,1),bool),
                jnp.tile(features.mask,(F,1)).reshape(-1,1,1,Y*X),
            ],axis=-1),
        )
        assert glatt.shape[-3] == 1
        glatt = glatt.reshape(*batch,F,N*D)
        globl2celliso = glatt

        # fifth; global dihedral attention
        # attention to cells
        assert qkv.globl.rep == qkv.cells.rep
        globl2cell = jax.nn.dot_product_attention(
            # merge F & R directly into batch dimensions left of it
            query = qkv.globl.Q.reshape(-1,1,N,2*H), 
            # we first need to move F&R across X and Y before we can merge
            key = jnp.moveaxis(qkv.cells.K,(-3,-2),(-5,-4)).reshape(-1,Y*X,K,2*H),
            # we first need to move F&R across X and Y before we can merge
            value = jnp.moveaxis(qkv.cells.V,(-3,-2),(-5,-4)).reshape(-1,Y*X,K,D),
            mask = jnp.tile(features.mask,(F*R,1)).reshape(-1,1,1,Y*X),
        )
        assert globl2cell.shape[-3] == 1
        globl2cell = globl2cell.reshape(*batch,F,R,N*D)

        print(f"{globl2celliso.shape=}")
        print(f"{cells_iso.shape=}")
        print(f"{globl_self.shape=}")
        # gres[0] came from the column headers (axis 0), gres[1] from the row headers
        tmp = dict(
            globl = attrs.evolve(features.globl, iso=jnp.concatenate([globl_self,globl2celliso],-1), full=globl2cell, rep=qkv.globl.rep),
            cols = attrs.evolve(features.cols, iso=jnp.empty(batch+(X,F,0),self.dtype), full=gres[0], rep=qkv.cols.rep),
            rows = attrs.evolve(features.rows, iso=jnp.empty(batch+(Y,F,0),self.dtype), full=gres[1], rep=qkv.rows.rep),
            cells = attrs.evolve(features.cells, iso=cells_iso, full=cells, rep=orep),
        )

        for k,v in tmp.items():
            print(f"{k}: {v.iso.shape=} {v.full.shape=}")

        # finally; output projection
        output = attrs.evolve(features,**{k:self.out[k](v) for k,v in tmp.items()})
        assert self.out_features.validate(output), self.out_features.validation_problem(output)
        return output
            
            

In [23]:
# Smoke-test configuration: a 7x9 grid with 10 flavours.
# Row headers carry the (t,l,r,d) operations and column headers the (e,x,y,i)
# operations; globl and cells do not pass `rep` and use EmbeddingDims' default.
dim = FeatureDim(
    globl = EmbeddingDims(iso=128,full=64),
    rows = EmbeddingDims(iso=64,full=32,rep=SymRep.from_seq((SymOp.t,SymOp.l,SymOp.r,SymOp.d))),
    cols = EmbeddingDims(iso=64,full=32,rep=SymRep.from_seq((SymOp.e,SymOp.x,SymOp.y,SymOp.i))),
    cells = EmbeddingDims(iso=32,full=16),
    flavours = 10,
    shape = (7,9),
)
assert dim.is_valid()
# Uninitialised input with a single batch dimension of size 3;
# the shape summary is shown as the cell's output.
inp = dim.make_empty(batch=(3,))
inp.shapes

namespace(globl=namespace(iso=(3, 10, 128), full=(3, 10, 8, 64), rep=8),
          rows=namespace(iso=(3, 7, 10, 64), full=(3, 7, 10, 4, 32), rep=4),
          cols=namespace(iso=(3, 9, 10, 64), full=(3, 9, 10, 4, 32), rep=4),
          cells=namespace(iso=(3, 7, 9, 10, 32),
                          full=(3, 7, 9, 10, 8, 16),
                          rep=8),
          ypos=(3, 7, 2),
          xpos=(3, 9, 2),
          rmsk=(3, 7),
          cmsk=(3, 9),
          mask=(3, 7, 9))

In [24]:
# 8 query heads sharing 4 key/value groups, qkv_features=128; output dims default to `dim`.
attn = SymAttention(8, dim, 128, num_groups=4, rngs=nnx.Rngs(0))
# One forward pass on the uninitialised input; this is a shape check only,
# and it emits the debug shape trace printed by SymAttention.__call__.
out = attn(inp)

batch=(3,) B=3 Y=7 X=9 F=10 R=8 N=8 K=4 H=8 D=16
globl: namespace(iso=(3, 10, 128), full=(3, 10, 8, 64), rep=8) v.in_features=EmbeddingDims(iso=128, full=64, rep=SymRep(opseq=(e,x,y,i,t,l,r,d))) v.out_features=EmbeddingDims(iso=256, full=256, rep=SymRep(opseq=(e,x,y,i,t,l,r,d)))
full globl.Q.shape = (3, 10, 8, 128)
iso globl.Q.shape = (3, 10, 128)
full globl.K.shape = (3, 10, 8, 64)
iso globl.K.shape = (3, 10, 64)
full globl.V.shape = (3, 10, 8, 64)
iso globl.V.shape = (3, 10, 64)
rows: namespace(iso=(3, 7, 10, 64), full=(3, 7, 10, 4, 32), rep=4) v.in_features=EmbeddingDims(iso=64, full=32, rep=SymRep(opseq=(t,l,r,d))) v.out_features=EmbeddingDims(iso=0, full=256, rep=SymRep(opseq=(t,l,r,d)))
full rows.Q.shape = (3, 7, 10, 4, 128)
full rows.K.shape = (3, 7, 10, 4, 64)
full rows.V.shape = (3, 7, 10, 4, 64)
cols: namespace(iso=(3, 9, 10, 64), full=(3, 9, 10, 4, 32), rep=4) v.in_features=EmbeddingDims(iso=64, full=32, rep=SymRep(opseq=(e,x,y,i))) v.out_features=EmbeddingDims(iso=0, full=2

In [39]:
def count(obj):
    """Return the total number of NumPy array elements reachable from ``obj``.

    NumPy arrays contribute their ``size``; ints and ``SymRep`` instances
    contribute nothing; attrs-decorated objects are walked field by field
    (without attrs' own recursion). Any other type has its class name printed
    and is counted as zero.
    """
    if isinstance(obj, np.ndarray):
        return obj.size
    if isinstance(obj, int):
        return 0
    if isinstance(obj, SymRep):
        return 0
    if attrs.has(type(obj)):
        total = 0
        for field_value in attrs.asdict(obj, recurse=False).values():
            total += count(field_value)
        return total
    # unknown leaf type: report it so gaps in the tally are visible
    print(type(obj).__name__)
    return 0
# Total number of array elements held by the example input (shown as cell output).
count(inp)

414093

In [40]:
# Render the module tree of the attention block for inspection.
nnx.display(attn)

In [None]:
class SymencLayer(nnx.Module):
    """Pre-norm encoder block: attention and MLP, each with a residual connection.

    NOTE(review): this cell is an unfinished adaptation of a plain
    vision-transformer block to ``FeatureDim`` inputs. Only the signature's
    syntax error has been repaired; the remaining open points are marked
    inline and need a design decision before the class can be instantiated.
    """
    def __init__(
        self,
        hidden_size: FeatureDim,
        *,
        mlp_width_factor: float,
        num_heads: int,
        dropout_rate: float = 0.0,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        """Build the two normalisation stages, the attention layer and the MLP.

        Args:
            hidden_size: dimensions of the features flowing through the block.
            mlp_width_factor: currently unused (see the note at the MLP below).
            num_heads: number of attention heads.
            dropout_rate: dropout probability for attention and MLP.
            rngs: random number generators for initialisation and dropout.
                The default is created once, when the class is defined,
                and is therefore shared by every call relying on it.
        """
        def norms(features: FeatureDim = hidden_size,**kw):
            # One LayerNorm per scope and per part ("iso" / "full").
            # NOTE(review): FeatureDim as defined in this notebook has `rows`
            # and `cols` but no `rowcol` attribute, so this lookup fails —
            # confirm whether one shared norm or two separate ones are wanted.
            return SimpleNamespace({
                k:SimpleNamespace({
                    kk: nnx.LayerNorm(getattr(getattr(features,k),kk), rngs=rngs, **kw)
                    for kk in ["iso", "full"]
                }) for k in ["globl", "rowcol", "cells"]
            })

        # First layer normalization using `flax.nnx.LayerNorm`
        # before we apply Multi-Head Attention.
        self.norm1 = norms()
        # The Multi-Head Attention layer (using `flax.nnx.MultiHeadAttention`).
        # NOTE(review): `in_features` receives a FeatureDim here rather than an
        # int; presumably SymAttention is meant to take this place — confirm.
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            normalize_qk=False, # True to stabilise learning in ViT-22B; see paper http://arxiv.org/abs/2302.05442
            rngs=rngs,
        )
        # Second layer normalization using `flax.nnx.LayerNorm`.
        self.norm2 = norms()

        # The MLP for point-wise feedforward (using `flax.nnx.Sequential`, `flax.nnx.Linear, flax.nnx.Dropout`)
        # with the GeLU activation function (`flax.nnx.gelu`).
        # NOTE(review): `mlp_dim` is not defined anywhere in this class;
        # presumably it should be derived from `mlp_width_factor` — TODO.
        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    # The forward pass through the transformer encoder block.
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply attention and MLP, each preceded by a norm and wrapped in a residual.

        NOTE(review): `norm1` / `norm2` are namespaces of LayerNorms as built
        above, yet they are called like functions here — to be reconciled.
        """
        # The Multi-Head Attention layer with layer normalization.
        x = x + self.attn(self.norm1(x))
        # The feed-forward network with layer normalization.
        x = x + self.mlp(self.norm2(x))
        return x

# Example usage for testing:
# NOTE(review): `VisionTransformer` is not defined in any cell visible here,
# and nothing below exercises SymencLayer — this looks like a leftover from
# the template the cell was adapted from; confirm before running.
x = jnp.ones((4, 224, 224, 3))
model = VisionTransformer(num_classes=1000)
y = model(x)
print("Predictions shape: ", y.shape)

In [None]:
# Scratch estimate: 1e7 divided by 4e3, then by 3600 and by 24.
# NOTE(review): presumably item count / throughput per second, converted
# from seconds to days — the units are not stated; confirm or remove.
10e6/4e3/3600/24