In [33]:
import jax
import jax.numpy as jnp
import flax.linen as nn

from einops import rearrange

In [None]:
class SelfAttention(nn.Module):
    dim: int
    dropout: float = 0.2

    def setup(self):
        self.dense = nn.Dense(features=self.dim * 3)
        self.out_proj = nn.Dense(features=self.dim)

    @nn.compact
    def __call__(self, x, mask, training=True):
        # Input: x is (B, seq_len, dim), mask is (1, seq_len, seq_len)
        # q, k, v are (B, seq_len, dim)  -- for now we have one head only.
        q, k, v = jnp.split(self.dense(x), 3, axis=-1)

        # attention weights
        qk = q @ jnp.transpose(k, (0, 2, 1))  # k.T is (B, dim, seq_len), qk is (B, seq_len, seq_len)
        attn = jax.nn.softmax(qk / jnp.sqrt(self.dim))
        attn = jnp.where(mask, attn, jnp.finfo(jnp.float32).min)

        # apply attention to values
        out = attn @ v  # out is (B, seq_len, dim)

        # project to output dimension
        out = self.out_proj(out)
        out = nn.Dropout(rate=self.dropout, deterministic=not training)(out)

        return out  # (B, seq_len, dim)

In [99]:
class MultiHeadSelfAttention(nn.Module):
    dim: int
    n_heads: int
    dropout: float = 0.2

    @nn.compact
    def __call__(self, x, mask=None, training=True):  # x is (B, T, dim)
        q, k, v = jnp.split(  # q,k,v are (B, T, dim)
            nn.Dense(features=self.dim * 3)(x),
            indices_or_sections=3,
            axis=-1
        )

        q = rearrange(q, 'b T (h d) -> b h T d', h=self.n_heads)  # (B, num_heads, T, d) where T is the sequence length
        k = rearrange(k, 'b T (h d) -> b h T d', h=self.n_heads)  # (B, num_heads, T, d)
        v = rearrange(v, 'b T (h d) -> b h T d', h=self.n_heads)  # (B, num_heads, T, d)

        # compute attention weights
        att_dim = self.dim // self.n_heads
        qk = jnp.einsum('bhid,bhjd->bhij', q, k)
        qk = jnp.where(mask, qk, jnp.finfo(jnp.float32).min)
        qk = jax.nn.softmax(qk/jnp.sqrt(att_dim))

        # compute values
        out = jnp.einsum('bhtt,bhtd->bhtd', qk, v)
        out = nn.Dropout(rate=self.dropout, deterministic=not training)(out)
        out = rearrange(out, 'b h T d -> b T (h d)')
        out = nn.Dense(features=self.dim)(out)
        
        return out, qk

In [100]:
mha = MultiHeadSelfAttention(dim=16, n_heads=4)
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, shape=(2, 5, 16))
mask = jnp.tril(jnp.ones((1, 5, 5)), k=0)

params = mha.init(key, x, mask, training=False)


In [106]:
rng = jax.random.PRNGKey(0)
out, qk = mha.apply(params, x, mask, training=True, rngs={'dropout': rng})

In [107]:
out.shape

(2, 5, 16)

In [108]:
x.shape

(2, 5, 16)

In [251]:
class SelfAttentionKVCache(nn.Module):
    dim: int
    B: int
    T: int
    dropout: float = 0.2
    
    def setup(self):
        self.dense = nn.Dense(features=self.dim * 3)
        self.key_cache = self.variable('cache', 'key_cache', lambda: jnp.zeros((self.B, self.T, self.dim)))
        self.val_cache = self.variable('cache', 'val_cache', lambda: jnp.zeros((self.B, self.T, self.dim)))

    @nn.compact
    def __call__(self, x, decoding_pos=0, mask=None, training=True):

        K = self.key_cache.value[:, :decoding_pos, :]
        V = self.val_cache.value[:, :decoding_pos, :]

        curr_x = x[:, decoding_pos, :]
        curr_x = jnp.expand_dims(curr_x, axis=1)

        q, curr_key, curr_val = jnp.split(self.dense(curr_x), 3, axis=-1)
        
        K = jnp.concatenate([K, curr_key], axis=-2)
        V = jnp.concatenate([V, curr_val], axis=-2)

        self.key_cache.value = self.key_cache.value.at[:, decoding_pos].set(curr_key[:, 0])
        self.val_cache.value = self.val_cache.value.at[:, decoding_pos].set(curr_val[:, 0])

        qk = q @ jnp.transpose(K, (0, 2, 1))  #(B, 1, dim) @ (B, dim, j) -> (B, 1, j)
        qk = jax.nn.softmax(qk / jnp.sqrt(self.dim))
        
        out = qk @ V  # (B, 1, j) @ (B, j, dim) -> (B, 1, dim)
        out = nn.Dropout(rate=self.dropout, deterministic=not training)(out)
        out = nn.Dense(features=self.dim)(out)
        return out, qk

In [252]:
kv = SelfAttentionKVCache(dim=16, dropout=0.2, B=2, T=5)
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, shape=(2, 5, 16))
# mask = jnp.tril(jnp.ones((1, 5, 5)), k=0)

In [253]:
x.shape, mask.shape

((2, 5, 16), (1, 5, 5))

In [254]:
params = kv.init(key, x, mask=None, decoding_pos=0, training=False)

In [255]:
params

{'cache': {'key_cache': Array([[[ 0.7079839 , -0.37440968, -0.57749003,  0.27687764,
            1.8787746 ,  0.38212544,  0.61747396, -0.1074488 ,
            0.4529895 ,  0.8411541 ,  1.1902729 ,  0.20950952,
            0.8607672 ,  0.35934466, -1.5175973 , -1.2997465 ],
          [ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ,  0.        ,  0.        ],
          [ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ,  0.        ,  0.        ],
          [ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.   

In [283]:
from collections import Counter

def ngrams(s, n_gram_size=3):
    ngrams = []
    for i in range(len(s) - n_gram_size + 1):
        ngrams.append(s[i:i + n_gram_size])
    return ngrams

def replace_substring_by_token(s, subs, token):
    new_s = ""
    i = 0
    n_gram_size = len(subs)
    while i < len(s):
        if s[i:i+n_gram_size] == subs:
            new_s += token
            i += n_gram_size
        else:
            new_s += s[i]
            i += 1
    return new_s

def bpe(s, n_gram_size=3, threshold=4):
    # Get all possible n-grams from the string
    token_template = "<TOK_{}>"
    count = Counter(ngrams(s, n_gram_size))
    frequent = {k:v for k,v in count.items() if v > threshold}
    substring_to_token = {k: token_template.format(i) for i, k in enumerate(frequent.keys())}
    
    for subs, token in substring_to_token.items():
        s = replace_substring_by_token(s, subs, token)
    return s


In [285]:
s = 'aaabdaaabac'
bpe(s, 2, 1)

'<TOK_0><TOK_1>d<TOK_0><TOK_1>ac'