In [2]:
import jax
import jax.numpy as jnp

In [3]:
a = jnp.zeros((2, 3, 4))
a.shape

(2, 3, 4)

In [6]:
a[:, -1, :].shape

(2, 4)

In [None]:
T = 64
model_dim = 32
freq = jnp.arange(T)[:, None]
pos = jnp.arange(model_dim // 2)[:, None].repeat(2, axis=-1).reshape(1, -1)
theta = 10000 ** (-2 * (pos - 1) / model_dim)
cos = jnp.cos(freq * theta)
sin = jnp.sin(freq * theta)

In [66]:
jnp.arange(model_dim // 2)[:, None].repeat(2, axis=-1).reshape(1, -1)

Array([[ 0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,
         8,  8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15]],      dtype=int32)

In [104]:
x = jnp.arange(model_dim)[None, None, :].repeat(16, axis=0).repeat(T, axis=1)


cos_rope = x * cos
a = x.reshape((x.shape[0], x.shape[1], model_dim // 2, 2))
a = jnp.flip(a, axis=-1) * jnp.array([-1, 1])
a = a.reshape((x.shape[0], x.shape[1], model_dim))
result = cos_rope + a * sin

Array([[[-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        ...,
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2]],

       [[-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        ...,
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2]],

       [[-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        ...,
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2]],

       ...,

       [[-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        ...,
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2]],

       [[-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        ...,
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2]],

       [[-1,  0, -3,  2],
        [-1,  0, -3,  2],
        [-1,  0, -3,  2],
        ...,
        [-1,  0, -3,  2],
        [-1,  0

In [152]:
class RoPE:
    def __init__(self, T, model_dim):
        self.T = T
        self.model_dim = model_dim
        assert model_dim % 2 == 0, "model_dim must be even"

        freq = jnp.arange(self.T)[:, None]
        pos = jnp.arange(self.model_dim // 2)[:, None].repeat(2, axis=-1).reshape(1, -1)
        theta = 10000 ** (-2 * pos / self.model_dim)
        self.cos = jnp.cos(freq * theta)
        self.sin = jnp.sin(freq * theta)

    def __call__(self, x, t):
        B, T, C = x.shape
        assert T == t.shape[1], "T of x must be the same as T indices"

        cos_rope = x * self.cos[t, :]
        x_inter = x.reshape((B, T, C // 2, 2))
        x_inter = jnp.flip(x_inter, axis=-1) * jnp.array([-1, 1])
        x_inter = x_inter.reshape((B, T, C))
        x_pos = cos_rope + x_inter * self.sin[t, :]

        return x_pos

In [158]:
key = jax.random.key(0)
key, init_key = jax.random.split(key)
model = RoPE(model_dim=32, T=64)
x_test = jax.random.normal(init_key, (16, 16, 32))
t_test = jnp.arange(64, dtype=jnp.int32)[None, :].repeat(16, axis=0)[:, :16]

In [160]:
model(x_test, t_test).shape

(16, 16, 32)

In [151]:
model.cos[t_test, :].shape

(16, 1, 32)

from flax import linen as nn
from einops import rearrange


class RoPE:
    def __init__(self, T, model_dim):
        self.T = T
        self.model_dim = model_dim
        assert model_dim % 2 == 0, "model_dim must be even"

        freq = jnp.arange(self.T)[:, None]
        pos = jnp.arange(self.model_dim // 2)[:, None].repeat(2, axis=-1).reshape(1, -1)
        theta = 10000 ** (-2 * pos / self.model_dim)
        self.cos = jnp.cos(freq * theta)
        self.sin = jnp.sin(freq * theta)

    def __call__(self, x, t_start, t_end):
        B, nh, T, C = x.shape
        assert t_end - t_start == T, "T of x must be the same as T indices"

        cos_rope = x * self.cos[None, None, t_start:t_end, :]
        x_inter = x.reshape((B, nh, T, C // 2, 2))
        x_inter = jnp.flip(x_inter, axis=-1) * jnp.array([-1, 1])
        x_inter = x_inter.reshape((B, nh, T, C))
        x_pos = cos_rope + x_inter * self.sin[None, None, t_start:t_end, :]

        return x_pos


class MLA(nn.Module):
    model_dim: int
    n_heads: int
    max_tokens: int
    latent_dim: int
    dhR: int
    t: int

    def setup(self):
        self.W_down = nn.Dense(features=2 * self.latent_dim)
        self.W_uKV = nn.Dense(features=2 * self.model_dim)
        self.W_uQ = nn.Dense(features=self.model_dim)

        self.dk = self.model_dim // self.n_heads
        self.output = nn.Dense(features=self.model_dim)

        self.rope = None

        if self.dhR != 0:
            self.Wkr = nn.Dense(features=self.dhR)
            self.Wqr = nn.Dense(features=(self.dhR * self.n_heads))
            self.rope = RoPE(model_dim=self.dhR, T=self.t)

    def __call__(self, x, cKV_cache=None, kRT_cache=None, train=True):
        B, T, C = x.shape
        if train == False:
            x = x[:, -1:, :]

        cKVt, cqt = jnp.split(self.W_down(x), 2, axis=-1)

        if self.rope:
            t_start = T - 1 if not train else 0
            kRt = self.rope(self.Wkr(x)[:, None, ...], t_start, T)
            kRt = kRt.repeat(self.n_heads, axis=1)

            qrt = rearrange(
                self.Wqr(x), "B T (nh d) -> B nh T d", nh=self.n_heads, d=self.dhR
            )
            qrt = self.rope(qrt, t_start, T)

        if not train:
            if cKV_cache is None:
                cKV_cache = cKVt
            else:
                cKV_cache = jnp.concatenate([cKV_cache, cKVt], axis=1)
            cKVt = cKV_cache

            if self.rope:
                if kRT_cache is None:
                    kRT_cache = jnp.zeros((B, 1, self.dhR))
                else:
                    kRT_cache = jnp.concatenate([kRT_cache, kRt[:, 0, :, :]], axis=1)
                kRt = kRT_cache[:, None, ...].repeat(self.n_heads, axis=1)

            if cKV_cache.shape[1] >= self.max_tokens:
                cKV_cache = cKV_cache[:, -self.max_tokens :, :]
                if self.rope:
                    kRT_cache = kRT_cache[:, -self.max_tokens :, :]

        v_k = rearrange(
            self.W_uKV(cKVt), "B T (nh d) -> B nh T d", nh=self.n_heads, d=2 * self.dk
        )
        v, k = jnp.split(v_k, 2, axis=-1)

        if self.rope:
            k = jnp.concatenate([k, kRt], axis=-1)

        q = self.W_uQ(cqt)
        q = rearrange(q, "B T (nh dk) -> B nh T dk", nh=self.n_heads, dk=self.dk)

        if self.rope:
            q = jnp.concatenate([q, qrt], axis=-1)

        weights = jnp.einsum("B n T d, B n t d -> B n T t", q, k) * (
            1 / ((self.dk) ** 0.5)
        )

        if train == True:
            size = weights.shape[-1]
            mask = jnp.tril(jnp.ones((B, self.n_heads, size, size)))
            weights = jnp.where(mask == 0, -9e15, weights)

        weights = nn.softmax(weights, axis=-1)

        output = jnp.einsum("B n T t, B n t d -> B n T d", weights, v)
        output = rearrange(output, "B nh T dk -> B T (nh dk)")
        output = self.output(output)

        if train == False:
            return output, (cKV_cache, kRT_cache)
        return output

In [265]:
init_key = jax.random.key(0)
init_key, W_key = jax.random.split(init_key)
model = MLA(model_dim=32, n_heads=4, max_tokens=64, latent_dim=32, dhR=16, t=64)
x_test = jax.random.normal(W_key, (16, 64, 32))
params = model.init(init_key, x_test, train=True)["params"]

x_1 = jax.random.normal(init_key, (16, 1, 32))
cache = (None, None)
for i in range(1, 100):
    x_2, cache = model.apply({"params": params}, x_1, *cache, train=False)
    x_1 = jnp.concatenate([x_1, x_2], axis=1)
    print(x_1.shape, cache[0].shape, cache[1].shape)

(16, 2, 32) (16, 1, 32) (16, 1, 16)
(16, 3, 32) (16, 2, 32) (16, 2, 16)
(16, 4, 32) (16, 3, 32) (16, 3, 16)
(16, 5, 32) (16, 4, 32) (16, 4, 16)
(16, 6, 32) (16, 5, 32) (16, 5, 16)
(16, 7, 32) (16, 6, 32) (16, 6, 16)
(16, 8, 32) (16, 7, 32) (16, 7, 16)
(16, 9, 32) (16, 8, 32) (16, 8, 16)
(16, 10, 32) (16, 9, 32) (16, 9, 16)
(16, 11, 32) (16, 10, 32) (16, 10, 16)
(16, 12, 32) (16, 11, 32) (16, 11, 16)
(16, 13, 32) (16, 12, 32) (16, 12, 16)
(16, 14, 32) (16, 13, 32) (16, 13, 16)
(16, 15, 32) (16, 14, 32) (16, 14, 16)
(16, 16, 32) (16, 15, 32) (16, 15, 16)
(16, 17, 32) (16, 16, 32) (16, 16, 16)
(16, 18, 32) (16, 17, 32) (16, 17, 16)
(16, 19, 32) (16, 18, 32) (16, 18, 16)
(16, 20, 32) (16, 19, 32) (16, 19, 16)
(16, 21, 32) (16, 20, 32) (16, 20, 16)
(16, 22, 32) (16, 21, 32) (16, 21, 16)
(16, 23, 32) (16, 22, 32) (16, 22, 16)
(16, 24, 32) (16, 23, 32) (16, 23, 16)
(16, 25, 32) (16, 24, 32) (16, 24, 16)
(16, 26, 32) (16, 25, 32) (16, 25, 16)
(16, 27, 32) (16, 26, 32) (16, 26, 16)
(16, 28, 32)

TypeError: Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 3 for shapes (16, 4, 65, 8), (16, 4, 64, 16).

In [240]:
cache[0].shape, cache[1].shape

((16, 2, 32), (16, 2, 16))

In [12]:
import os
import ast
with open("./results/wandb10/tokens.txt", "r") as f:
  lines = f.readlines()

In [15]:
for line in lines:
  text = line.split("|")[-1]
  string_arr = ast.literal_eval(text)
  print(string_arr)

['hello_Uia hold becomegridoreach trackGroupContext profthisaturalFTWARE Output took goamed <<о�targetymbol earth render180(( alpha counttop SH axis']
['hellovers fight rac]:modautABILITY=False.widgetserverProcesslo � vac implement\tvoid causeira goals.N James Json Jan pod placesraelExpead� Europe']
['hello vertorigin publishsign livingParams Hot middle defined fürglobal WARRANT-cisplay denux**_client candidateailsCompatalium appears todayonymouslik?\n\n resjson']
['helloAfter values.)\n\nSp factwards \'(varprivateName img.paramSh�keraverlexRenderadmin.equalllumure "$ benefitChild profession warrantvalidate sum-group']
['hellooney product off",\r\n emp immediatelyference pointsgn still.tvm_forasc               ropation If_LEaim Creselector dressiler_error rootmi000 Image<int']
["hello varioraotos ahead_EXVisualycle departmentrole Twitter<div PRO(string blog.text diisc sal.junit.forEachyclenectLeft position websitechildrenaram'); )\n ar"]
['hello groundConstantsorn reference seems Accou