In [1]:
import functools

import numpy as np
import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'cpu')

import chex
import optax
import flax.linen as nn
from flax import core

Array = jnp.ndarray

In [9]:
class A(nn.Module):
    """Minimal module that exposes a dense layer through a non-``__call__`` method.

    The method ``fc`` is meant to be passed as ``method=`` to ``init``/``apply``.
    """
    a: int  # number of output features of the dense layer

    def setup(self):
        # The submodule is deliberately NOT called `fc`: an attribute with the
        # same name as the method below would shadow it on the bound instance
        # and make `self.fc` ambiguous. Note that the parameter tree key is
        # therefore `dense`.
        self.dense = nn.Dense(self.a)

    def fc(self, x):
        """Apply the dense layer to ``x`` and return the result."""
        return self.dense(x)

In [10]:
# Fixed PRNG key (seed 0) used for parameter initialisation below.
key = jax.random.PRNGKey(0)
# All-zeros input vector of length 5.
x = jnp.zeros((5,))
# Module instance whose dense layer has 2 output features.
m = A(2)

In [12]:
# Initialise the variables by running the `fc` method (instead of `__call__`) on `x`.
params = m.init(key, x, method=m.fc)

In [13]:
# Run the `fc` method with the initialised variables.
m.apply(params, x, method=m.fc)

Array([0., 0.], dtype=float32)

In [8]:
# NOTE(review): this cell has execution count In [8], i.e. it ran before the
# class cell (In [9]) and the instantiation cell (In [10]) were last executed,
# so the traceback below presumably belongs to an earlier definition of `A`.
# Re-run after "Restart & Run All" or delete the cell.
m.fc

AttributeError: "A" object has no attribute "fc". If "fc" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

Consider the following things:
1. weights initialization
2. normalization + dropout
3. batching and vmap
4. investigate attention maps

In [3]:
class MLP(nn.Module):
    """Multi-layer perceptron built from Dense -> LayerNorm -> ReLU stages.

    One stage is created per entry of ``layers``. The LayerNorm and ReLU of
    the last stage are skipped when ``activate_final`` is False.
    """
    layers: tuple[int, ...]      # output width of each Dense layer, in order
    activate_final: bool = True  # whether the last stage is normalised/activated

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through every stage and return the final activations."""
        last_idx = len(self.layers) - 1
        for idx, width in enumerate(self.layers):
            x = nn.Dense(width)(x)
            # Only the final stage may be left as a plain linear output.
            if idx == last_idx and not self.activate_final:
                continue
            x = nn.relu(nn.LayerNorm()(x))
        return x

In [33]:
def scaled_dot_attention(query, key, value):
    """Scaled dot-product attention over inputs laid out as ``[..., length, heads, dim]``.

    Args:
        query: array indexed as ``...qhd``.
        key: array indexed as ``...khd``.
        value: array indexed as ``...khd``.

    Returns:
        A pair ``(output, weights)``: ``output`` is indexed as ``...qhd`` and
        ``weights`` (softmax-normalised over the key axis) as ``...qhk``.
    """
    head_dim = query.shape[-1]
    # Similarity of every query with every key, per head, scaled by sqrt(dim).
    logits = jnp.einsum('...qhd,...khd->...qhk', query, key) / np.sqrt(head_dim)
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of the values for each query position.
    output = jnp.einsum('...qhk,...khd->...qhd', weights, value)
    return output, weights


class MultiHeadAttention(nn.Module):
    """Multi-head attention returning both the output and the attention weights.

    Queries, keys and values are each projected to ``(num_heads, embed_dim)``
    features, i.e. every head works with ``embed_dim`` features.
    """
    embed_dim: int  # per-head feature size and size of the final output
    num_heads: int  # number of attention heads

    @nn.compact
    def __call__(self, inputs_q, inputs_kv):
        """Attend from ``inputs_q`` to ``inputs_kv``.

        Returns:
            ``(output, weights)`` where ``weights`` comes straight from
            ``scaled_dot_attention``.
        """
        head_shape = (self.num_heads, self.embed_dim)
        q = nn.DenseGeneral(features=head_shape, name='query')(inputs_q)
        k = nn.DenseGeneral(features=head_shape, name='key')(inputs_kv)
        v = nn.DenseGeneral(features=head_shape, name='value')(inputs_kv)

        mixed, weights = scaled_dot_attention(q, k, v)

        # Contract the trailing (heads, features) axes back into embed_dim.
        merge_heads = nn.DenseGeneral(self.embed_dim, axis=(-1, -2))
        return merge_heads(mixed), weights

In [29]:
class TransformerLayer(nn.Module):
    """Transformer block: self-attention and an MLP, each wrapped in a
    residual connection followed by LayerNorm.

    Returns the transformed input together with the attention weights.
    """
    num_heads: int  # number of attention heads
    ff_dim: int     # hidden width of the feed-forward MLP

    @nn.compact
    def __call__(self, x):
        """Apply the block to ``x`` and return ``(output, attention_weights)``."""
        # The model width is taken from the input instead of a module-level
        # global: both residual additions below require the sub-block outputs
        # to have exactly the same trailing size as `x`.
        embed_dim = x.shape[-1]

        mha = MultiHeadAttention(embed_dim, self.num_heads)
        y, attn = mha(x, x)
        x = nn.LayerNorm()(x + y)

        # NOTE(review): MLP defaults to activate_final=True, so the
        # feed-forward output passes through LayerNorm + ReLU before the
        # residual add — confirm this is intended.
        mlp = MLP((self.ff_dim, embed_dim))
        return nn.LayerNorm()(x + mlp(x)), attn

In [32]:
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` ``TransformerLayer`` blocks applied in sequence.

    The attention weights returned by the individual layers are discarded.
    """
    num_layers: int  # number of stacked layers
    embed_dim: int   # expected size of the last axis of the input
    num_heads: int   # attention heads per layer
    ff_dim: int      # hidden width of each layer's feed-forward MLP

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through every layer and return the final representation.

        Raises:
            ValueError: if the trailing axis of ``x`` does not equal ``embed_dim``.
        """
        if x.shape[-1] != self.embed_dim:
            raise ValueError(
                f'Expected inputs with {self.embed_dim} features, '
                f'got {x.shape[-1]}.'
            )
        for _ in range(self.num_layers):
            # Keyword arguments: TransformerLayer only declares `num_heads`
            # and `ff_dim`, so passing three positional values fails.
            layer = TransformerLayer(num_heads=self.num_heads, ff_dim=self.ff_dim)
            x, _attn = layer(x)
        return x

In [None]:
def positional_encoding(bandwidth, num_freq):
    """Return the symmetric frequency grid of an (unfinished) positional encoding.

    Draft helper: only the frequency grid is built so far.

    Args:
        bandwidth: largest absolute frequency; the grid spans
            ``[-bandwidth, bandwidth]``.
        num_freq: number of frequencies on each side of zero.

    Returns:
        1-D array with ``2 * num_freq + 1`` evenly spaced frequencies.
    """
    return jnp.linspace(-bandwidth, bandwidth, 2 * num_freq + 1)

In [39]:
# Scratch check: 51 evenly spaced points from 1 to 10 (both ends included).
jnp.linspace(1, 10, 51)

Array([ 1.       ,  1.1800001,  1.3599999,  1.54     ,  1.72     ,
        1.8999999,  2.08     ,  2.2599998,  2.44     ,  2.62     ,
        2.8      ,  2.9799998,  3.1599998,  3.34     ,  3.5199997,
        3.6999998,  3.8799999,  4.06     ,  4.24     ,  4.4199996,
        4.6      ,  4.7799997,  4.9599996,  5.14     ,  5.3199997,
        5.4999995,  5.68     ,  5.8599997,  6.0399995,  6.22     ,
        6.3999996,  6.58     ,  6.7599998,  6.9399996,  7.12     ,
        7.2999997,  7.48     ,  7.66     ,  7.839999 ,  8.0199995,
        8.2      ,  8.38     ,  8.559999 ,  8.74     ,  8.919999 ,
        9.099999 ,  9.28     ,  9.46     ,  9.639999 ,  9.82     ,
       10.       ], dtype=float32)

In [2]:
def _emb1d(dim_size: int, num_freq: int, nyquist_freq: float) -> Array:
    x = jnp.linspace(-1, 1, dim_size)
    fs = jnp.linspace(1, nyquist_freq, num_freq)
    p = jnp.outer(x, fs)
    return jnp.concatenate([jnp.sin(jnp.pi * p), jnp.cos(jnp.pi * p)], -1)

[H, W, C] -> [H, W, C * (1 + 2K_H + 2K_W)] -> [H * W, C * (...)]

In [None]:
def positional_encoding(x: Array,
                        num_freq: tuple[int, ...],
                        nyquist_freqs: tuple[float, ...],
                        axes: tuple[int, ...]
                       ) -> Array:
    """Append sine/cosine position features to the last axis of ``x``.

    For every entry of ``axes`` a 1-D embedding from ``_emb1d`` is built for
    that axis, broadcast across all other non-feature axes, and concatenated
    to the features of ``x``.

    NOTE(review): the positional axes are not flattened here; if a
    ``[H * W, features]`` layout is wanted, reshape the result in the caller.

    Args:
        x: input whose last axis holds the features.
        num_freq: number of frequencies per encoded axis.
        nyquist_freqs: highest frequency per encoded axis.
        axes: axes of ``x`` to encode (negative indices allowed); the last
            (feature) axis cannot be encoded.

    Returns:
        Array shaped like ``x`` except that the last axis grows by
        ``sum(2 * n for n in num_freq)``.

    Raises:
        ValueError: if the three tuples differ in length or an entry of
            ``axes`` refers to the feature axis.
    """
    if not (len(num_freq) == len(nyquist_freqs) == len(axes)):
        raise ValueError(
            'num_freq, nyquist_freqs and axes must have the same length.'
        )

    feature_axis = x.ndim - 1
    pieces = [x]
    for n, freq, axis in zip(num_freq, nyquist_freqs, axes):
        axis = axis % x.ndim  # normalise negative indices
        if axis == feature_axis:
            raise ValueError('The feature (last) axis cannot be encoded.')
        enc = _emb1d(x.shape[axis], n, freq)  # (x.shape[axis], 2 * n)
        # Put the positions on `axis`, the features last, and broadcast
        # over every other axis.
        expand = [1] * x.ndim
        expand[axis] = x.shape[axis]
        expand[-1] = 2 * n
        enc = enc.reshape(expand)
        pieces.append(jnp.broadcast_to(enc, x.shape[:-1] + (2 * n,)))
    return jnp.concatenate(pieces, axis=-1)

[*batch_dims, *modality_dimensions, features_dim] -> [*batch_dims, features with embeddings]
features with embeddings = C1 + 2*K_i * D_i

In [None]:
def position_embedding(x, nyquist_freq, num_bands):
    """Placeholder for the positional embedding; not implemented yet.

    Raises:
        NotImplementedError: always, until the embedding is written.
    """
    # TODO: implement positional embedding
    raise NotImplementedError('position_embedding is not implemented yet.')