In [1]:
from typing import Tuple, Union, Optional
import jax
import jax.numpy as jnp
import jax.random as jr 
import equinox as eqx
from einops import rearrange

In [2]:
class Residual(eqx.Module):
    """Skip-connection wrapper: returns ``fn(x, **kwargs) + x``."""

    fn: eqx.Module

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x, **kwargs):
        # Keep a handle on the untouched input so it can be added back after ``fn``.
        shortcut = x
        transformed = self.fn(x, **kwargs)
        return transformed + shortcut


class PreNorm(eqx.Module):
    """Apply a LayerNorm over the last axis of every row, then call ``fn`` on the result."""

    norm: eqx.nn.LayerNorm
    fn: eqx.Module

    def __init__(self, dim, fn):
        self.fn = fn
        self.norm = eqx.nn.LayerNorm((dim,))

    def __call__(self, x, **kwargs):
        # eqx LayerNorm acts on a single (dim,) vector, so map it over the leading axis.
        normalized = jax.vmap(self.norm)(x)
        return self.fn(normalized, **kwargs)


class FeedForward(eqx.Module):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Linear layers and the activation are vmapped over the leading (token) axis;
    dropout layers are applied to the whole array at once.
    """

    net: Tuple[eqx.Module]

    def __init__(self, dim, hidden_dim, dropout=0., *, key):
        expand_key, project_key = jr.split(key)
        expand = eqx.nn.Linear(dim, hidden_dim, key=expand_key)
        project = eqx.nn.Linear(hidden_dim, dim, key=project_key)
        self.net = (
            expand,
            jax.nn.gelu,
            eqx.nn.Dropout(dropout),
            project,
            eqx.nn.Dropout(dropout),
        )

    def __call__(self, x, key):
        for position, layer in enumerate(self.net):
            if not isinstance(layer, eqx.nn.Dropout):
                x = jax.vmap(layer)(x)
                continue
            # Fold in the layer position so each dropout layer draws its own mask.
            x = layer(x, key=jr.fold_in(key, position))
        return x


class Attention(eqx.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Queries, keys and values come from a single bias-free projection of the
    input to ``3 * heads * dim_head`` features. The concatenated head outputs
    are optionally projected back to ``dim`` and passed through dropout.
    """

    heads: int
    scale: float
    to_qkv: eqx.nn.Linear
    # (Linear, Dropout) output projection, or Identity when heads == 1 and dim_head == dim.
    to_out: Union[Tuple[eqx.nn.Linear, eqx.nn.Dropout], eqx.nn.Identity]
    project_out: bool

    def __init__(self, dim, heads, dim_head, dropout=0., *, key):
        keys = jr.split(key)

        inner_dim = dim_head * heads
        # A single head whose width already equals dim needs no output projection.
        self.project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = eqx.nn.Linear(dim, inner_dim * 3, use_bias=False, key=keys[0])

        self.to_out = (
            eqx.nn.Linear(inner_dim, dim, key=keys[1]),
            eqx.nn.Dropout(dropout)
        ) if self.project_out else eqx.nn.Identity()

    def __call__(self, x, key):
        """Self-attend over the rows of ``x``.

        Args:
            x: 2-D array ``(n_tokens, dim)``; the ``'n (h d)'`` rearrange requires 2-D.
            key: PRNG key for the output dropout (unused when ``project_out`` is False).

        Returns:
            Array ``(n_tokens, dim)`` when projecting, else ``(n_tokens, heads * dim_head)``.
        """
        h = self.heads

        qkv = jnp.split(jax.vmap(self.to_qkv)(x), 3, axis=-1)
        q, k, v = (rearrange(t, 'n (h d) -> h n d', h=h) for t in qkv)

        # Dot-product multi-head attention; softmax over the key axis.
        dots = jnp.einsum('h i d, h j d -> h i j', q, k) * self.scale
        attn = jax.nn.softmax(dots, axis=-1)

        out = jnp.einsum('h i j, h j d -> h i d', attn, v)
        out = rearrange(out, 'h n d -> n (h d)')
        if self.project_out:
            linear, dropout = self.to_out
            out = jax.vmap(linear)(out)
            out = dropout(out, key=key)
        return out


class StochasticDepth(eqx.Module):
    """Randomly gradient-freeze a fraction of encoder layers for one forward pass.

    Instead of skipping layers, the selected layers' array leaves are routed
    through ``jax.lax.stop_gradient``. Their parameters therefore receive zero
    gradient for this step, while the forward computation is unchanged.

    Only active when ``inference`` is False. The field defaults to True, so
    something like ``eqx.nn.inference_mode(model, value=False)`` must flip it
    for training (NOTE(review): confirm against the training loop).
    """

    rate: float
    inference: bool = True

    def __init__(self, rate):
        """``rate``: fraction of layers frozen per call (``int(rate * n_layers)``, clamped to n_layers)."""
        self.rate = rate

    def __call__(self, layers, key):
        """Return ``layers`` with a random subset frozen.

        Args:
            layers: sequence of per-layer modules/pytrees (e.g. ``Transformer.layers``).
            key: PRNG key choosing which layers to freeze.

        Returns:
            ``layers`` itself in inference mode (or when nothing is frozen); otherwise
            a tuple with the same per-layer structure, where frozen layers have their
            array leaves wrapped in ``stop_gradient``. Non-array leaves are untouched.
        """
        if self.inference:
            return layers

        n_layers = len(layers)
        n_frozen = min(int(self.rate * n_layers), n_layers)
        if n_frozen == 0:
            return layers

        # Distinct indices, so exactly n_frozen layers are frozen.
        frozen_ix = jr.choice(key, n_layers, (n_frozen,), replace=False)
        # Boolean mask rather than Python indexing, so this also works under jit
        # where frozen_ix is a tracer.
        frozen = jnp.zeros((n_layers,), dtype=bool).at[frozen_ix].set(True)

        def _freeze(is_frozen, layer):
            # Only array leaves are wrapped; ints, floats, bools and functions
            # (e.g. heads, scale, jax.nn.gelu) must survive for the layer to run.
            return jax.tree_util.tree_map(
                lambda leaf: (
                    jnp.where(is_frozen, jax.lax.stop_gradient(leaf), leaf)
                    if eqx.is_array(leaf) else leaf
                ),
                layer,
            )

        return tuple(_freeze(frozen[i], layer) for i, layer in enumerate(layers))

In [3]:
class Transformer(eqx.Module):
    """Stack of ``depth`` (pre-norm attention, pre-norm feed-forward) blocks with residuals.

    ``stochastic_depth_rate`` of ``None`` builds the encoder without a
    ``StochasticDepth`` module; any other value is passed to it as ``rate``.
    """

    layers: Tuple[eqx.Module]
    stochastic_depth: Optional[StochasticDepth] = None

    def __init__(
        self, dim, depth, heads, dim_head, mlp_dim, dropout=0., stochastic_depth_rate=0.5, *, key
    ):
        blocks = []
        layer_key = key
        for depth_ix in range(depth):
            # Keys are chained: each layer folds its index into the previous layer's key.
            layer_key = jr.fold_in(layer_key, depth_ix)
            attn_key, ff_key = jr.split(layer_key)
            attention = PreNorm(
                dim,
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, key=attn_key),
            )
            feed_forward = PreNorm(
                dim,
                FeedForward(dim, mlp_dim, dropout=dropout, key=ff_key),
            )
            blocks.append((attention, feed_forward))
        self.layers = tuple(blocks)
        self.stochastic_depth = (
            None
            if stochastic_depth_rate is None
            else StochasticDepth(rate=stochastic_depth_rate)
        )

    def __call__(self, x, key):
        """Run every encoder block on ``x``; ``key`` drives stochastic depth and dropout."""
        depth_key, dropout_key = jr.split(key)

        if self.stochastic_depth is None:
            active_layers = self.layers
        else:
            active_layers = self.stochastic_depth(self.layers, depth_key)

        # One (attention, feed-forward) key pair per block; lengths match by construction.
        layer_keys = jr.split(dropout_key, (len(active_layers), 2))
        for (attn, ff), (attn_key, ff_key) in zip(active_layers, layer_keys):
            x = attn(x, key=attn_key) + x
            x = ff(x, key=ff_key) + x
        return x


class ConvEmbed(eqx.Module):
    """Convolutional tokenizer: Conv2d (no bias) -> ReLU -> MaxPool2d.

    The resulting ``(channels, h, w)`` feature map is flattened into an
    ``(h * w, channels)`` token sequence.
    """

    conv_layers: Tuple[eqx.Module]

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=7,
        stride=2,
        padding=3,
        pool_kernel_size=3,
        pool_stride=2,
        pool_padding=1,
        *,
        key
    ):
        conv = eqx.nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            use_bias=False,
            key=key,
        )
        pool = eqx.nn.MaxPool2d(
            kernel_size=pool_kernel_size,
            stride=pool_stride,
            padding=pool_padding,
        )
        self.conv_layers = (conv, jax.nn.relu, pool)

    def sequence_length(self, n_channels, height, width):
        """Number of tokens produced for a ``(n_channels, height, width)`` input (via a zero forward pass)."""
        probe = jnp.zeros((n_channels, height, width))
        tokens = self.__call__(probe)
        return tokens.shape[0]

    def __call__(self, x):
        feature_map = x
        for layer in self.conv_layers:
            feature_map = layer(feature_map)
        return rearrange(feature_map, 'd h w -> (h w) d')


class LinearEmbed(eqx.Module):
    """ViT-style patch embedding.

    Cuts a ``(c, H, W)`` image into ``patch_size x patch_size`` patches, flattens
    each to ``patch_size**2 * c`` values, and projects every patch with a shared Linear.
    """

    linear: eqx.nn.Linear
    patch_size: int

    def __init__(self, in_size, out_size, patch_size, *, key):
        self.patch_size = patch_size
        self.linear = eqx.nn.Linear(in_size, out_size, key=key)

    def __call__(self, x):
        p = self.patch_size
        patches = rearrange(x, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
        return jax.vmap(self.linear)(patches)


class CompactTransformer(eqx.Module):
    """Compact Transformer image classifier.

    The forward pass runs these stages in order:
    1. Token embedding, via ``ConvEmbed`` or ``LinearEmbed``.
    2. Add the learned positional embedding, then apply embedding dropout.
    3. Run the ``Transformer`` encoder.
    4. Sequence pooling: a softmax over per-token scores from ``pool``, then a
       weighted sum of tokens.
    5. LayerNorm, then a Linear layer producing the class logits.
    """

    to_patch_embedding: Union[ConvEmbed, LinearEmbed]
    pos_embedding: jax.Array
    transformer: Transformer
    dropout: eqx.nn.Dropout
    pool: eqx.nn.Linear  # scores each token for sequence pooling
    mlp_head: Tuple[eqx.nn.LayerNorm, eqx.nn.Linear]
    stochastic_depth: bool

    def __init__(
        self,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        pool='cls',
        in_channels=1,
        dim_head=64,
        dropout=0.1,
        emb_dropout=0.1,
        scale_dim=4,
        conv_embed=False,
        stochastic_depth=True,
        *,
        key
    ):
        """Build the classifier.

        Args:
            image_size: side length of the square input image.
            patch_size: patch side for ``LinearEmbed``; must divide ``image_size``.
            num_classes: number of output logits.
            dim: token embedding width.
            depth: number of encoder blocks.
            heads: attention heads per block.
            pool: validated to be 'cls' or 'mean' but not otherwise used; sequence
                pooling is always applied.
            in_channels: number of image channels.
            dim_head: per-head width.
            dropout: dropout rate inside attention / feed-forward.
            emb_dropout: dropout rate on the embedded tokens.
            scale_dim: feed-forward hidden width multiplier (hidden = dim * scale_dim).
            conv_embed: use the convolutional tokenizer instead of linear patches.
            stochastic_depth: when False, the encoder is built without a StochasticDepth module.
            key: PRNG key for parameter initialization.
        """
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = in_channels * (patch_size ** 2)

        keys = jr.split(key, 6)

        if conv_embed:
            self.to_patch_embedding = ConvEmbed(in_channels, dim, key=keys[0])
            # Token count depends on the conv/pool strides; measure it with a dummy pass.
            num_patches = self.to_patch_embedding.sequence_length(
                in_channels, image_size, image_size
            )
        else:
            self.to_patch_embedding = LinearEmbed(patch_dim, dim, patch_size, key=keys[1])

        self.pos_embedding = jr.normal(keys[2], (num_patches, dim))

        self.dropout = eqx.nn.Dropout(emb_dropout)

        # With stochastic depth, defer to Transformer's default rate. Without it,
        # pass None so no StochasticDepth module is built.
        depth_kwargs = {} if stochastic_depth else {'stochastic_depth_rate': None}
        self.transformer = Transformer(
            dim, depth, heads, dim_head, dim * scale_dim, dropout, key=keys[3], **depth_kwargs
        )

        self.pool = eqx.nn.Linear(dim, 1, key=keys[4])
        self.mlp_head = (
            eqx.nn.LayerNorm((dim,)),
            eqx.nn.Linear(dim, num_classes, key=keys[5])
        )
        self.stochastic_depth = stochastic_depth

    def __call__(self, img, key):
        """Classify one image.

        Args:
            img: ``(in_channels, image_size, image_size)`` array, matching the
                embeddings' ``'c (h p1) (w p2)'`` / Conv2d input layout.
            key: PRNG key. It is split so embedding dropout and the encoder draw
                independent randomness.

        Returns:
            Logits of shape ``(num_classes,)``.
        """
        dropout_key, transformer_key = jr.split(key)

        x = self.to_patch_embedding(img)
        n, _ = x.shape

        # One positional vector per token (no cls token, so exactly n rows).
        x = x + self.pos_embedding[:n]
        x = self.dropout(x, key=dropout_key)

        x = self.transformer(x, key=transformer_key)

        # Sequence pooling: score each token, softmax over tokens, weighted sum -> (dim,).
        g = jax.vmap(self.pool)(x)
        weights = jax.nn.softmax(g, axis=0)
        x = jnp.einsum('n l, n d -> l d', weights, x).squeeze(0)

        for layer in self.mlp_head:
            x = layer(x)
        return x

In [4]:
# Smoke test: a tiny linear-patch model on a single 1x32x32 image.
key = jr.key(0)
# Separate keys for parameter init and for the forward pass's dropout masks,
# so the two never share randomness.
init_key, call_key = jr.split(key)

img = jnp.ones((1, 32, 32))

cvt = CompactTransformer(
    32,
    patch_size=4,
    dim=8,
    num_classes=10,
    in_channels=1,
    depth=4,
    heads=4,
    key=init_key
)

out = cvt(img, call_key)
assert out.shape == (10,), out.shape  # one logit per class

# Total number of array elements (parameters) in the model.
sum(
    x.size for x in jax.tree_util.tree_leaves(cvt)
    if eqx.is_array(x)
)

35899

In [5]:
# Same classifier with the convolutional tokenizer and a wider embedding (dim=128).
# Reuses `key` and `img` from the previous cell; split so init and forward
# dropout do not share a PRNG key.
init_key, call_key = jr.split(key)

cvt = CompactTransformer(
    32,
    patch_size=4,
    dim=128,
    num_classes=10,
    in_channels=1,
    depth=4,
    heads=4,
    conv_embed=True,
    key=init_key
)

out = cvt(img, call_key)
assert out.shape == (10,), out.shape  # one logit per class

# Total number of array elements (parameters) in the model.
sum(
    x.size for x in jax.tree_util.tree_leaves(cvt)
    if eqx.is_array(x)
)

1069835