In [1]:
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"


import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import PartitionSpec as P
from flax import linen as nn
from einops import rearrange
import numpy as np

import tiktoken
from utils import modelConfig
from typing import Optional, Tuple, List, Callable
from functools import partial
import math
from einops import rearrange
import flax
from flax import linen as nn
from jaxtyping import Array, PyTree
from jax.sharding import Mesh
from jax.experimental.shard_map import shard_map

from config import parse_args

In [2]:
device_array = np.array(jax.devices())
mesh = Mesh(device_array, ("model",))
mesh

Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('model',), axis_types=(Auto,))

In [3]:
class TPDense(nn.Module):
    features: int
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        ki = nn.linear.default_kernel_init
        h = nn.Dense(
            self.features,
            dtype=self.dtype,
            kernel_init=nn.with_partitioning(ki, (None, "model")),
        )(x)
        return h

In [5]:
class Embeddings(nn.Module):
    model_dimension: int
    vocab_size: int
    model_dtype: jnp.dtype = jnp.float32

    def setup(self):
        ei = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
        self.embedding = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.model_dimension,
            dtype=self.model_dtype,
            embedding_init=nn.with_partitioning(ei, (None, "model")),
        )
        self.layer_norm = nn.LayerNorm()

    def __call__(self, x: Array, out: bool = False) -> Array:
        if not out:
            x = self.embedding(x)
        else:
            x = self.layer_norm(x)
            x = self.embedding.attend(x)
        return x

In [6]:
class RoPE:
    def __init__(self, T: int, model_dim: int):
        assert model_dim % 2 == 0, "model_dim must be even"

        freq = jnp.arange(T, dtype=jnp.float32)[:, None] + 1

        pos = jnp.arange(model_dim // 2, dtype=jnp.float32)[:, None]
        pos = pos.repeat(2, axis=-1).reshape(1, -1)
        log_theta_base = jnp.log(10000.0)
        theta = jnp.exp(-2 * pos / model_dim * log_theta_base)

        self.cos = jnp.cos(freq * theta)
        self.sin = jnp.sin(freq * theta)

    def __call__(
        self,
        x: Array,
        t_start: int,
        offset: Optional[int] = None,
        transpose: bool = False,
    ) -> Array:
        B, T, C = x.shape
        x_proj = x.astype("float32")

        if offset is None:
            offset = T

        cos_rope = x_proj * self.cos[t_start : t_start + offset, :]

        x_inter = x_proj.reshape((B, T, C // 2, 2))
        x_inter = jnp.flip(x_inter, axis=-1) * jnp.array([-1, 1])
        x_inter = x_inter.reshape((B, T, C))
        if transpose:
            x_inter *= -1
        sin_rope = x_inter * self.sin[t_start : t_start + offset, :]

        x_rope = cos_rope + sin_rope
        x = x_rope.astype(x.dtype)
        return x

In [7]:
class MLA(nn.Module):
    model_dim: int
    n_heads: int
    T: int
    latent_dim: int
    dhR: int
    model_dtype: jnp.dtype
    grad_checkpoint: bool
    dropout: float = 0.0

    def setup(self):
        self.W_down = TPDense(features=2 * self.latent_dim, dtype=self.model_dtype)

        self.W_uKV = TPDense(features=2 * self.model_dim, dtype=self.model_dtype)
        self.W_uQ = TPDense(features=self.model_dim, dtype=self.model_dtype)

        self.dk = self.model_dim // self.n_heads

        if self.grad_checkpoint:
            self.output = nn.remat(TPDense)(
                features=self.model_dim, dtype=self.model_dtype
            )
        else:
            self.output = TPDense(features=self.model_dim, dtype=self.model_dtype)
        self.out_dropout = nn.Dropout(rate=self.dropout)

        self.rope = False
        if self.dhR != 0:
            self.rope = True
            self.Wkr = TPDense(features=self.dhR, dtype=self.model_dtype)
            self.Wqr = TPDense(
                features=(self.dhR * self.n_heads), dtype=self.model_dtype
            )
            self.rope_k = RoPE(model_dim=self.dhR, T=self.T)
            self.rope_q = RoPE(model_dim=self.dhR * self.n_heads, T=self.T)

    def __call__(
        self,
        x: Array,
        *,
        cKV_cache: Optional[Array] = None,
        kRT_cache: Optional[Array] = None,
        train=True,
    ) -> Tuple[Array, Tuple[Array, Array]]:
        B, T, C = x.shape

        cKVt, cqt = jnp.split(self.W_down(x), 2, axis=-1)

        if self.rope:
            t_start = 0
            if cKV_cache is not None:
                t_start = cKV_cache.shape[1]

            kRt = self.rope_k(self.Wkr(x), t_start)

            qRt = self.rope_q(self.Wqr(x), t_start)
            qRt = rearrange(qRt, "B T (nh d) -> B nh T d", nh=self.n_heads, d=self.dhR)

        if not train:
            if cKV_cache is not None:
                cKVt = jnp.concatenate([cKV_cache, cKVt], axis=1)
            cKV_cache = cKVt

            if self.rope:
                if kRT_cache is not None:
                    kRt = jnp.concatenate([kRT_cache, kRt], axis=1)
                kRT_cache = kRt

        k, v = jnp.split(self.W_uKV(cKVt), 2, axis=-1)
        q = self.W_uQ(cqt)

        k = rearrange(k, "B T (nh d) -> B nh T d", nh=self.n_heads, d=self.dk)
        q = rearrange(q, "B T (nh d) -> B nh T d", nh=self.n_heads, d=self.dk)
        v = rearrange(v, "B T (nh d) -> B nh T d", nh=self.n_heads, d=self.dk)

        if self.rope:
            q = jnp.concatenate([q, qRt], axis=-1)
            kRt = jnp.repeat(kRt[:, None, :, :], self.n_heads, axis=1)
            k = jnp.concatenate([k, kRt], axis=-1)

        def scaledDotProd(q, k, v, mask):
            q = q.astype("float32")
            k = k.astype("float32")
            w = jnp.einsum("B n T d, B n t d -> B n T t", q, k) * (1 / (self.dk**0.5))
            w = jnp.where(mask == 0, -9e15, w)
            w = jax.nn.softmax(w, axis=-1).astype(self.model_dtype)
            output = jnp.einsum("B n T t, B n t d -> B n T d", w, v)
            return output

        if self.grad_checkpoint:
            scaledDotProd = jax.remat(scaledDotProd)

        if T == 1:
            mask = jnp.ones((B, self.n_heads, 1, k.shape[2]))
        else:
            mask = jnp.tril(
                jnp.ones((B, self.n_heads, q.shape[2], k.shape[2])),
            )

        output = scaledDotProd(q, k, v, mask)
        output = rearrange(output, "B nh T dk -> B T (nh dk)")

        output = self.output(output)
        output = self.out_dropout(output, deterministic=not train)
        return output, (cKV_cache, kRT_cache)

In [8]:
class FFBody(nn.Module):
    model_dimension: int
    ff_dim: int
    dropout: float
    model_dtype: jnp.dtype

    @nn.compact
    def __call__(self, x: Array) -> Array:
        x = TPDense(
            features=self.ff_dim,
            dtype=self.model_dtype,
        )(x)
        x = nn.gelu(x)
        x = TPDense(
            features=self.model_dimension,
            dtype=self.model_dtype,
        )(x)

        return x

In [9]:
class FeedForward(nn.Module):
    model_dimension: int
    ff_dim: int
    dropout: float
    model_dtype: jnp.dtype
    grad_checkpoint: bool

    @nn.compact
    def __call__(self, x: Array, train: bool = True) -> Array:
        ff = FFBody
        if self.grad_checkpoint:
            ff = nn.remat(FFBody)

        ff = ff(
            model_dimension=self.model_dimension,
            ff_dim=4 * self.model_dimension,
            dropout=self.dropout,
            model_dtype=self.model_dtype,
        )
        x_ff = nn.Dropout(rate=self.dropout)(ff(x), deterministic=not train)

        return x_ff

In [10]:
class Block(nn.Module):
    model_dimension: int
    n_heads: int
    dropout: float
    T: int
    latent_dim: int
    model_dtype: jnp.dtype
    dhR: int = 0
    n_shared: int = 0
    n_experts: int = 0
    k: int = 0
    moe: bool = False
    grad_checkpoint: bool = False

    @nn.compact
    def __call__(
        self,
        x: Array,
        cache: Optional[Tuple[Array, Optional[Array]]] = (None, None),
        train: bool = True,
    ):
        x_norm = nn.LayerNorm()(x)

        x_up, cache = MLA(
            model_dim=self.model_dimension,
            n_heads=self.n_heads,
            T=self.T,
            latent_dim=self.latent_dim,
            dhR=self.dhR,
            model_dtype=self.model_dtype,
            dropout=self.dropout,
            grad_checkpoint=self.grad_checkpoint,
        )(x_norm, cKV_cache=cache[0], kRT_cache=cache[1], train=train)
        x = x + x_up

        x_norm = nn.LayerNorm()(x)

        load = None
        if self.moe == True:
            x_ff, load = MoE(
                model_dimension=self.model_dimension,
                n_experts=self.n_experts,
                k=self.k,
                dropout=self.dropout,
                model_dtype=self.model_dtype,
                n_shared=self.n_shared,
                grad_checkpoint=self.grad_checkpoint,
            )(x_norm, train=train)

        else:
            x_ff = FeedForward(
                model_dimension=self.model_dimension,
                ff_dim=4 * self.model_dimension,
                dropout=self.dropout,
                model_dtype=self.model_dtype,
                grad_checkpoint=self.grad_checkpoint,
            )(x_norm, train=train)

        x = x + x_ff

        return x, (cache, load)

In [11]:
class FFBody(nn.Module):
    model_dimension: int
    ff_dim: int
    dropout: float
    model_dtype: jnp.dtype

    @nn.compact
    def __call__(self, x):
        x = TPDense(features=self.ff_dim, dtype=self.model_dtype)(x)
        x = nn.relu(x)
        x = TPDense(features=self.model_dimension, dtype=self.model_dtype)(x)

        return x

In [12]:
class FeedForward(nn.Module):
    model_dimension: int
    ff_dim: int
    dropout: float
    model_dtype: jnp.dtype
    grad_checkpoint: bool

    @nn.compact
    def __call__(self, x, train: bool = True):
        ff = FFBody
        if self.grad_checkpoint:
            ff = nn.remat(FFBody)

        ff = ff(
            model_dimension=self.model_dimension,
            ff_dim=4 * self.model_dimension,
            dropout=self.dropout,
            model_dtype=self.model_dtype,
        )
        x_ff = nn.Dropout(rate=self.dropout, deterministic=not train)(ff(x))

        return x_ff

In [13]:
class Block(nn.Module):
    model_dimension: int
    n_heads: int
    dropout: float
    T: int
    latent_dim: int
    model_dtype: jnp.dtype
    dhR: int = 0
    n_shared: int = 0
    n_experts: int = 0
    k: int = 0
    moe: bool = False
    grad_checkpoint: bool = False

    @nn.compact
    def __call__(
        self,
        x: Array,
        cache: Optional[Tuple[Array, Optional[Array]]] = (None, None),
        train: bool = True,
    ):
        x_norm = nn.LayerNorm()(x)

        x_up, cache = MLA(
            model_dim=self.model_dimension,
            n_heads=self.n_heads,
            T=self.T,
            latent_dim=self.latent_dim,
            dhR=self.dhR,
            model_dtype=self.model_dtype,
            dropout=self.dropout,
            grad_checkpoint=self.grad_checkpoint,
        )(x_norm, cKV_cache=cache[0], kRT_cache=cache[1], train=train)
        x = x + x_up

        x_norm = nn.LayerNorm()(x)

        load = None
        if self.moe == True:
            x_ff, load = MoE(
                model_dimension=self.model_dimension,
                n_experts=self.n_experts,
                k=self.k,
                dropout=self.dropout,
                model_dtype=self.model_dtype,
                n_shared=self.n_shared,
                grad_checkpoint=self.grad_checkpoint,
            )(x_norm, train=train)

        else:
            x_ff = FeedForward(
                model_dimension=self.model_dimension,
                ff_dim=4 * self.model_dimension,
                dropout=self.dropout,
                model_dtype=self.model_dtype,
                grad_checkpoint=self.grad_checkpoint,
            )(x_norm, train=train)

        x = x + x_ff

        return x, (cache, load)

In [14]:
class EncoderBlock(nn.Module):
    model_dimension: int
    n_heads: int
    dropout: float
    T: int
    latent_dim: int
    model_dtype: jnp.dtype
    dhR: int
    dhR_blocks: int = 4
    n_shared: int = 0
    n_experts: int = 0
    k: int = 0
    moe: bool = False
    grad_checkpoint: bool = False

    @nn.compact
    def __call__(
        self,
        x: Array,
        cache: Optional[List[Tuple[Array, Optional[Array]]]] = None,
        train: Array = True,
    ) -> Tuple[
        Array,
        Tuple[
            Optional[List[Tuple[Optional[Array], Optional[Array]]]], Optional[PyTree]
        ],
    ]:
        out_cache = []
        load = None
        for i in range(self.dhR_blocks):
            layer_cache = (None, None) if cache is None else cache[i]
            x, (current_cache, current_load) = Block(
                model_dimension=self.model_dimension,
                n_heads=self.n_heads,
                dropout=self.dropout,
                T=self.T,
                latent_dim=self.latent_dim,
                dhR=self.dhR if (i < self.dhR_blocks - 1) else 0,
                moe=self.moe,
                n_experts=self.n_experts,
                n_shared=self.n_shared,
                k=self.k,
                model_dtype=self.model_dtype,
            )(x, cache=layer_cache, train=train)
            if load is None:
                load = current_load
            else:
                add_tree = lambda x, y: jax.tree.map(lambda a, b: a + b, x, y)
                load = (
                    add_tree(load[0], current_load[0]),
                    add_tree(load[1], current_load[1]),
                )

            out_cache.append(current_cache)

        return x, (out_cache, load)

In [15]:
class Decoder(nn.Module):
    model_dimension: int
    n_heads: int
    dhR: int
    dhR_blocks: int
    T: int
    vocab_size: int
    dropout: float
    blocks: int
    n_experts: int
    n_shared: int
    k: int
    moe: bool
    latent_dim: int
    model_dtype: jnp.dtype
    grad_checkpoint: bool

    @nn.compact
    def __call__(
        self,
        x: Array,
        cache: Optional[List[Tuple[Optional[Array], Optional[Array]]]] = None,
        train: bool = True,
    ) -> Tuple[
        Array, Tuple[Optional[List[Tuple[Optional[Array], Optional[Array]]]], Array]
    ]:
        if cache is not None:
            x = x[:, -1:]

        embed = Embeddings(
            model_dimension=self.model_dimension,
            vocab_size=self.vocab_size,
            model_dtype=self.model_dtype,
        )
        x = embed(x)
        out_cache = []
        load = None
        for i in range(self.blocks):
            layer_cache = (
                None
                if cache is None
                else cache[
                    i * len(cache) // self.blocks : (i + 1) * len(cache) // self.blocks
                ]
            )
            x, (current_cache, current_load) = EncoderBlock(
                model_dimension=self.model_dimension,
                n_heads=self.n_heads,
                dropout=self.dropout,
                T=self.T,
                latent_dim=self.latent_dim,
                dhR=self.dhR,
                dhR_blocks=self.dhR_blocks,
                moe=self.moe,
                n_experts=self.n_experts,
                n_shared=self.n_shared,
                k=self.k,
                model_dtype=self.model_dtype,
            )(x, cache=layer_cache, train=train)

            if load is None:
                load = current_load
            else:
                add_tree = lambda x, y: jax.tree.map(lambda a, b: a + b, x, y)
                load = (
                    add_tree(load[0], current_load[0]),
                    add_tree(load[1], current_load[1]),
                )

            out_cache.extend(current_cache)

        x = embed(x, out=True)

        if load is not None:
            load = load[0] * load[1]

        return x, (out_cache, load)

    def generate(
        self,
        params: PyTree,
        key: jax.random.key,
        x: str = "",
        *,
        B: int = 1,
        k: int = 10000,
        temperature: int = 1,
        max_tokens: int = 100,
        use_cache=True,
    ) -> List[str]:
        enc = tiktoken.get_encoding("gpt2")

        out = jnp.array([enc._special_tokens["<|endoftext|>"]], dtype=jnp.int32)
        if x != "":
            x_encode = jnp.array(enc.encode(x), dtype=jnp.int32)
            out = jnp.concatenate([out, x_encode], axis=-1)

        out = jnp.repeat(out[None, :], B, axis=0)
        cache = None

        def sample(key, params, inp, cache, B, k, temperature):
            if not use_cache:
                cache = None
            logits, (cache, _) = self.apply(
                {"params": params}, inp, cache=cache, train=False
            )

            logits, idx = jax.lax.top_k(logits[:, -1, :], k=k)
            logits /= temperature

            out_next_idx = jax.random.categorical(key, logits, axis=-1, shape=(B,))
            out_next = idx[jnp.arange(B, dtype=jnp.int32), out_next_idx][:, None]

            return out_next, (cache, logits)

        for _ in range(min(max_tokens, self.T)):
            key, sample_key = jax.random.split(key)
            out_next, (cache, logits) = sample(
                sample_key, params, out, cache, B, k, temperature
            )
            out = jnp.concatenate([out, out_next], axis=-1)

        tokens = jax.device_get(out[:, 1:])
        outputs = list(map(lambda x: enc.decode(x), tokens))

        return outputs

    @classmethod
    def get_model(
        cls: "Decoder", model_config: modelConfig, init_key: jax.random.key
    ) -> Tuple["Decoder", PyTree]:
        x = jnp.ones((1, model_config.T), dtype=jnp.int32)

        model = cls(
            model_dimension=model_config.model_dimension,
            n_heads=model_config.n_heads,
            dhR=model_config.dhR,
            dhR_blocks=model_config.dhR_blocks,
            T=model_config.T,
            vocab_size=model_config.vocab_size,
            dropout=model_config.dropout,
            blocks=model_config.blocks,
            n_experts=model_config.n_experts,
            k=model_config.k,
            moe=model_config.moe,
            latent_dim=model_config.latent_dim,
            n_shared=model_config.n_shared,
            model_dtype=jnp.bfloat16
            if (model_config.model_dtype == "bfloat16")
            else jnp.float32,
            grad_checkpoint=model_config.grad_checkpoint,
        )

        params = model.init(
            init_key,
            x,
            train=False,
        )["params"]

        _ = model.generate(
            params,
            init_key,
            x="hello",
            B=1,
            k=model_config.vocab_size,
            temperature=1,
            max_tokens=10,
        )

        return model, params

In [16]:
if __name__ == "__main__":
    import json
    import numpy as np

    def print_params(params):
        def tree_shapes(tree):
            return jax.tree_util.tree_map(lambda x: tuple(x.shape), tree)

        shapes = tree_shapes(params)
        print(json.dumps(shapes, indent=4))

    model_cfg = modelConfig(
        model_dimension=64,
        n_heads=4,
        dhR=64,
        dhR_blocks=1,
        T=32,
        vocab_size=10000,
        dropout=0.1,
        blocks=4,
        n_experts=4,
        n_shared=2,
        k=2,
        moe=False,
        latent_dim=16,
        model_dtype="bfloat16",
        grad_checkpoint=False,
    )

    key = jax.random.PRNGKey(0)
    key2 = jax.random.PRNGKey(1)
    model, params = Decoder.get_model(model_cfg, key)

    x = jnp.ones((8, model_cfg.T), dtype=jnp.int32)
    param_spec = jax.eval_shape(
        lambda key2, x: model.init(key2, x, train=False), key2, x
    )
    param_spec_out = nn.get_partition_spec(param_spec)

    init_specs = (None, P(None, "model"))
    init_fn_sharded = partial(
        shard_map, mesh=mesh, in_specs=init_specs, out_specs=param_spec_out
    )(model.init)

    params = init_fn_sharded(jax.random.PRNGKey(0), x)
    apply_fn_sharded = partial(
        shard_map,
        mesh=mesh,
        in_specs=(param_spec_out, P(None, "model")),
        out_specs=P(None, "model"),
    )(lambda key2, x: model.apply(key2, x, train=False))

    output = apply_fn_sharded(params, x)

    # print(params)
    # print_params(params)

In [17]:
out, (cache, load) = output

In [18]:
jax.tree.map(lambda x: jax.debug.visualize_array_sharding(x), params)

{'params': {'Embeddings_0': {'embedding': {'embedding': Partitioned(value=None, names=(None, 'model'), mesh=None)},
   'layer_norm': {'bias': None, 'scale': None}},
  'EncoderBlock_0': {'Block_0': {'FeedForward_0': {'FFBody_0': {'TPDense_0': {'Dense_0': {'bias': None,
        'kernel': Partitioned(value=None, names=(None, 'model'), mesh=None)}},
      'TPDense_1': {'Dense_0': {'bias': None,
        'kernel': Partitioned(value=None, names=(None, 'model'), mesh=None)}}}},
    'LayerNorm_0': {'bias': None, 'scale': None},
    'LayerNorm_1': {'bias': None, 'scale': None},
    'MLA_0': {'W_down': {'Dense_0': {'bias': None,
       'kernel': Partitioned(value=None, names=(None, 'model'), mesh=None)}},
     'W_uKV': {'Dense_0': {'bias': None,
       'kernel': Partitioned(value=None, names=(None, 'model'), mesh=None)}},
     'W_uQ': {'Dense_0': {'bias': None,
       'kernel': Partitioned(value=None, names=(None, 'model'), mesh=None)}},
     'output': {'Dense_0': {'bias': None,
       'kernel': 

In [143]:
mesh

Mesh(device_ids=array([0]), axis_names=('model',), axis_types=(Auto,))