```
gcloud compute tpus tpu-vm create jax-tpu-v4-8 --zone us-central2-b --project tpu-research-cloud-project --accelerator-type v4-8 --version tpu-vm-v4-base

sudo apt-get install python3.9 libpython3.9 --assume-yes
virtualenv --python='/usr/bin/python3.9' virtualenv-jax
source virtualenv-jax/bin/activate

pip install 'jax[tpu]' optax flax ml-collections notebook -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_releases.html torch pydantic pydantic_core tiktoken json-strong-typing fairscale blobfile



gcloud compute tpus tpu-vm ssh jax-tpu-v4-8 --project=tpu-research-cloud-project --zone=us-central2-b -- -L 8866:localhost:8866

screen -U

sudo mkdir /mnt/ramdisk
sudo mount -t tmpfs -o size=25G tmpfs /mnt/ramdisk
mkdir /mnt/ramdisk/llama3_1
time gsutil -m cp -r gs://trc-ml-us-central2/llama/Meta-Llama-3.1-8B-Instruct_split/ /mnt/ramdisk/llama3_1/

source virtualenv-jax/bin/activate
jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8866 --NotebookApp.port_retries=0 --no-browser
```

# Init

In [1]:
import jax
print(jax.__version__)
print(jax.devices())

# Make implicit rank promotion (e.g. adding a (n,) array to an (m, n) array) an
# error, so every broadcast in the model code has to be spelled out explicitly.
jax.config.update("jax_numpy_rank_promotion", "raise")
# Enable 64-bit dtypes. Side effect: jnp constructors now default to float64
# (see the deet(jnp.ones(...)) output in the interop-utils cell). JAX documents
# that this flag should be set at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

0.4.30
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]


In [2]:
# torch is used to load checkpoint tensors (torch.load) and for DLPack interop
# with JAX in the cells below.
import torch
print(torch.__version__)

2.5.1+cu124


### Torch/JAX interop + generic utils

In [3]:
import jax.numpy as jnp
import pathlib
import numpy as np
np.set_printoptions(precision=20, floatmode='fixed')


def deets(t):
  """Summarize a torch tensor as (shape, dtype, device, sum(|t|) in float64).

  The float64 cast also makes bfloat16 tensors convertible to NumPy, and gives a
  checksum comparable with `deet` on the JAX side. Works for tensors that
  require grad or live on a non-CPU device (detached and copied to host first).
  """
  x = t.detach().cpu().to(dtype=torch.float64).numpy()
  return t.shape, t.dtype, t.device, np.sum(np.abs(x))


def deet(t):
  """Summarize a JAX array as (shape, dtype, devices, sum(|t|) in float64).

  The float64 cast is done in NumPy on the host, so the checksum is float64 even
  when jax_enable_x64 is off (where jnp.float64 would silently mean float32).
  """
  x = np.asarray(t).astype(np.float64)
  return t.shape, t.dtype, t.devices(), np.sum(np.abs(x))


def deetnosum(t):
  """Like `deet` but without the checksum, so the array is not pulled to host."""
  return t.shape, t.dtype, t.devices()


def deetnodev(t):
  """Shape and dtype only."""
  return t.shape, t.dtype


print(deets(torch.ones(4,2,3)))

print(deet(jnp.ones((4,2,3))))


(torch.Size([4, 2, 3]), torch.float32, device(type='cpu'), np.float64(24.0))
((4, 2, 3), dtype('float64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}, np.float64(24.0))


In [4]:
import jax.dlpack as jdp
import torch.utils.dlpack as tdp

def torch_to_jax(x, device):
  """Convert a torch tensor to a JAX array on `device` via DLPack.

  The tensor is passed through the `__dlpack__` protocol, rather than a legacy
  DLPack capsule.

  Args:
    x: torch tensor. It is detached first, because torch refuses to export
      tensors that require grad. It is also made contiguous, because JAX only
      imports compact (non-strided) DLPack buffers. Both are no-ops for plain
      contiguous tensors, so those still share memory with `x` where DLPack
      allows it.
    device: JAX device for the result (None keeps the source device).
  """
  return jdp.from_dlpack(x.detach().contiguous(), device=device)

def jax_to_torch(x):
  """Convert a JAX array to a torch tensor via DLPack.

  The array is handed to torch directly, and torch calls its `__dlpack__`
  method. This replaces the legacy `jax.dlpack.to_dlpack` capsule API.
  torch must be able to import the array's device, e.g. CPU.
  NOTE(review): TPU-resident arrays presumably need `jax_to_cpu` first.
  """
  return tdp.from_dlpack(x)

def jax_cpu():
  """Return the first CPU device; prints a trace line on every call."""
  print("JAX_CPU()!!!")
  cpu_devices = jax.devices(backend="cpu")
  return cpu_devices[0]


def jax_to_cpu(x):
  """Place `x` on the first CPU device; prints a trace line (plus jax_cpu's)."""
  print("jax_to_cpu()!!!")
  target = jax_cpu()
  return jax.device_put(x, target)


def jax_tpu():
  """Return the first TPU device."""
  tpu_devices = jax.devices(backend="tpu")
  return tpu_devices[0]

def f():
  """Round-trip demo: random torch tensor -> JAX, and seeded JAX array -> torch.

  Prints deets/deet summaries at each step.
  """
  torch_src = torch.rand((1,3))
  print("t", deets(torch_src), torch_src)

  cpu_dev = jax_cpu()
  as_jax = torch_to_jax(torch_src, cpu_dev)
  print("t=>j", deet(as_jax), as_jax)

  jax_src = jax_to_cpu(jax.random.uniform(jax.random.key(42), (3,1)))
  print("j", deet(jax_src), jax_src)

  as_torch = jax_to_torch(jax_src)
  print("j=>t", deets(as_torch), as_torch)


f()

def f():
  """Compare float64 conversions of a bfloat16 tensor done in torch vs. after DLPack to JAX.

  Prints both conversions and their element-wise difference.
  """
  src = torch.randn(2, 3, dtype=torch.bfloat16)
  print("torch_tensor", src)
  via_torch = src.to(dtype=torch.float64).numpy()
  print("torch numpy conversion", via_torch, via_torch.dtype, type(via_torch))

  as_jax = torch_to_jax(src, jax_cpu())
  print("jax_array", as_jax, as_jax.dtype)
  via_jax = np.array(jnp.astype(as_jax, jnp.float64))
  print("jax numpy conversion", via_jax, via_jax.dtype, type(via_jax))

  delta = via_jax - via_torch
  print("diff", np.sum(np.abs(delta)), delta)


f()

t (torch.Size([1, 3]), torch.float32, device(type='cpu'), np.float64(2.0716968178749084)) tensor([[0.2830, 0.9367, 0.8520]])
JAX_CPU()!!!
t=>j ((1, 3), dtype('float32'), {CpuDevice(id=0)}, np.float64(2.0716968178749084)) [[0.28296488523483276367 0.93671125173568725586 0.85202068090438842773]]
jax_to_cpu()!!!
JAX_CPU()!!!
j ((3, 1), dtype('float64'), {CpuDevice(id=0)}, np.float64(2.4713129808532663)) [[0.72981889690463219722]
 [0.86919382428439107002]
 [0.87230025966424307171]]
j=>t (torch.Size([3, 1]), torch.float64, device(type='cpu'), np.float64(2.4713129808532663)) tensor([[0.7298],
        [0.8692],
        [0.8723]], dtype=torch.float64)
torch_tensor tensor([[-0.8867,  0.0630, -1.4844],
        [ 1.7812,  2.1719,  0.1934]], dtype=torch.bfloat16)
torch numpy conversion [[-0.88671875000000000000  0.06298828125000000000 -1.48437500000000000000]
 [ 1.78125000000000000000  2.17187500000000000000  0.19335937500000000000]] float64 <class 'numpy.ndarray'>
JAX_CPU()!!!
jax_array [[-0.88671

In [5]:
pprint_enabled = True


def enable_pprint(v):
  """Switch debug printing through `pprint` / `pprint_d` on (truthy) or off."""
  global pprint_enabled
  pprint_enabled = v


def pprint(*args):
  """`print(*args)` while debug printing is enabled; otherwise do nothing."""
  if not pprint_enabled:
    return
  print(*args)


def pprint_d(msg, x):
  """Print `msg` followed by `deet(x)` while debug printing is enabled.

  The flag is checked here too, not just in `pprint`, so that `deet` is never
  evaluated when printing is off. `deet` pulls `x` to the host and sums it.
  """
  if not pprint_enabled:
    return
  pprint(msg, deet(x))


In [6]:
def load_torch_weights(filename, device=None):
  """Load a torch-saved tensor file and return it as a JAX array.

  The file is memory-mapped onto the CPU with `weights_only=True`, which
  restricts unpickling to tensors and plain containers. The loaded object is
  passed straight to `torch_to_jax`, so the file is expected to hold a single
  tensor.

  Args:
    filename: path of the file to load.
    device: JAX device forwarded to `torch_to_jax`.
  """
  cpu_tensor = torch.load(
      filename, map_location=torch.device("cpu"), weights_only=True, mmap=True
  )
  return torch_to_jax(cpu_tensor, device=device)


In [7]:
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

# Number of devices to shard across: every device of the default backend (4 on
# the v4-8 TPU VM, per the device list printed in the Init cell).
NUM_DEVICES = jax.device_count()

# 1-D mesh over those devices with a single axis named 'x'. Note the tuple
# ('x',): a bare ('x') is just the string 'x' and only works by accident for
# one-character names.
jax_mesh_x = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh([NUM_DEVICES]), axis_names=('x',))
# Split axis 0 of an array across the 'x' mesh axis; other axes are not split.
jax_sharding_x = jax.sharding.NamedSharding(jax_mesh_x, P('x'))

def shard_array(arr, sharding=None):
  """Distribute `arr` across devices, split along axis 0.

  Args:
    arr: JAX or NumPy array. Its axis 0 must divide evenly across the
      sharding's devices.
    sharding: target sharding. Defaults to the module-level `jax_sharding_x`
      (axis 0 split over the 'x' mesh axis).

  Returns:
    A global jax.Array with `arr`'s shape and dtype, laid out per `sharding`.
  """
  if sharding is None:
    sharding = jax_sharding_x
  # Pull the data to host memory once; each shard is then sliced in NumPy and
  # transferred straight to the device that owns it.
  host = np.asarray(arr)
  # Ask the sharding which slice each device owns. The previous version paired
  # jnp.split chunks with jax.devices("tpu") by position, which silently
  # scrambles the data whenever create_device_mesh orders the mesh differently
  # from jax.devices().
  return jax.make_array_from_callback(host.shape, sharding, lambda index: host[index])

def load_torch_weights_sharded(filename, device=None):
  """Load a torch-saved tensor file and shard it across devices along axis 0.

  The file is memory-mapped on CPU, wrapped as a CPU JAX array through
  `torch_to_jax`, and then distributed by `shard_array`.

  Args:
    filename: path of the file to load (expected to hold a single tensor).
    device: accepted for signature parity with `load_torch_weights` but
      ignored; placement is decided by `shard_array`.
  """
  host_tensor = torch.load(
      filename, map_location=torch.device("cpu"), weights_only=True, mmap=True
  )
  host_array = torch_to_jax(host_tensor, device=jax.devices("cpu")[0])
  return shard_array(host_array)

def f():
  """Shard an (8, 50, 10) iota array built on CPU, printing per-row details before and after.

  Rows are indexed outside the default_device context. The recorded output
  shows the pre-shard rows reported on TPU 0 rather than on the CPU.
  """
  with jax.default_device(jax.devices("cpu")[0]):
    arr = jnp.arange(8*50*10).reshape((8,50,10))

  def show(a):
    print("arr", deet(a))
    for row in range(8):
      print(row, deet(a[row]))

  show(arr)
  show(shard_array(arr))


f()

arr ((8, 50, 10), dtype('int64'), {CpuDevice(id=0)}, np.float64(7998000.0))
0 ((50, 10), dtype('int64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}, np.float64(124750.0))
1 ((50, 10), dtype('int64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}, np.float64(374750.0))
2 ((50, 10), dtype('int64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}, np.float64(624750.0))
3 ((50, 10), dtype('int64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}, np.float64(874750.0))
4 ((50, 10), dtype('int64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}, np.float64(1124750.0))
5 ((50, 10), dtype('int64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}, np.float64(1374750.0))
6 ((50, 10), dtype('int64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}, np.float64(1624750.0))
7 ((50, 10), dtype('int64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_c

# Memory profiling

In [8]:
def f():
  # Allocate a 1024 x 1024 x 1024 array of zeros on the default device and report
  # where it landed. With jax_enable_x64 on (Init cell), jnp.zeros defaults to
  # float64, so this is 1024**3 * 8 bytes = 8 GiB (4 GiB if x64 were off).
  arr = jnp.zeros((int(512*2),1024,1024))
  print(arr.shape, arr.devices())

f()

(1024, 1024, 1024) {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}


In [9]:
# Handles inspected by the memory dump below (`f` / `dump_mem`): leaves of
# jax_transformer_params_tpu are tagged "params", jax_freqs_cis is tagged
# "freqs_cis". Both start as None.
# NOTE(review): presumably assigned once the model is loaded in a later cell — confirm.
jax_transformer_params_tpu = None
jax_freqs_cis = None

In [10]:
# jax.clear_caches()

In [11]:
# %reset -f

In [12]:
import collections

def f(platform="tpu"):
  """Print a per-tag summary of all live JAX arrays on `platform`.

  Leaves of `jax_transformer_params_tpu` are tagged "params".
  `jax_freqs_cis` is tagged "freqs_cis". Everything else is tagged None.
  For each tag, prints: tag, array count, and total size in MB
  (1 MB = 1_000_000 bytes).

  Args:
    platform: backend whose live arrays are listed (default "tpu", as before).
  """
  tagged = {}
  # Tag by object identity. This relies on jax.live_arrays() returning the same
  # Array objects that are held in the params pytree.
  if jax_transformer_params_tpu is not None:
    for leaf in jax.tree_util.tree_leaves(jax_transformer_params_tpu):
      tagged[id(leaf)] = "params"
  if jax_freqs_cis is not None:
    tagged[id(jax_freqs_cis)] = "freqs_cis"

  cnt = collections.Counter()
  tot = collections.Counter()
  for arr in jax.live_arrays(platform):
    tag = tagged.get(id(arr))
    cnt[tag] += 1
    # nbytes uses the array's real itemsize. The old `jnp.prod(shape) * 2`
    # assumed bfloat16 for every array, and it allocated a new device array per
    # live array while taking a memory census. Without x64 it also computed in
    # int32, which overflows for large arrays.
    tot[tag] += arr.nbytes / 1_000_000
  for k in cnt:
    print(k, cnt[k], "%.0fM" % (tot[k] // 1))

f()

dump_mem = f

# -- init --
# freqs_cis 1 0M
# params 291 16060M

# -- hmm --
# None 194 6442M
# freqs_cis 1 0M
# params 291 16060M

# Imports

### llama_models datatypes

In [13]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel

from strong_typing.schema import json_schema_type


# Decoding strategy selector.
@json_schema_type
class SamplingStrategy(Enum):
    greedy = "greedy"
    top_p = "top_p"
    top_k = "top_k"


# Sampling configuration. Defaults: greedy strategy, temperature 0.0, top_p 0.95,
# top_k 0, max_tokens 0, repetition_penalty 1.0.
# NOTE(review): how a consumer interprets 0 for top_k / max_tokens (e.g.
# "disabled" / "no limit") is not shown in this notebook — confirm.
@json_schema_type
class SamplingParams(BaseModel):
    strategy: SamplingStrategy = SamplingStrategy.greedy

    temperature: Optional[float] = 0.0
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 0
    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0


@json_schema_type(
    schema={
        "description": "The format in which weights are specified. This does not necessarily always equal what quantization is desired at runtime since there can be on-the-fly conversions done.",
    }
)
class CheckpointQuantizationFormat(Enum):
    # default format
    bf16 = "bf16"

    # used for enabling fp8_rowwise inference, some weights are bf16
    fp8_mixed = "fp8_mixed"


# Identifiers for Llama 3.1 checkpoints (base and instruct).
# NOTE(review): the mp8 / mp16 suffixes presumably denote the model-parallel
# shard count of the checkpoint — confirm.
@json_schema_type
class ModelSKU(Enum):
    llama3_1_8b = "llama3_1_8b"
    llama3_1_70b = "llama3_1_70b"
    llama3_1_405b_fp8_mp8 = "llama3_1_405b_fp8_mp8"
    llama3_1_405b_bf16_mp8 = "llama3_1_405b_bf16_mp8"
    llama3_1_405b_bf16_mp16 = "llama3_1_405b_bf16_mp16"

    llama3_1_8b_instruct = "llama3_1_8b_instruct"
    llama3_1_70b_instruct = "llama3_1_70b_instruct"
    llama3_1_405b_instruct_fp8_mp8 = "llama3_1_405b_instruct_fp8_mp8"
    llama3_1_405b_instruct_bf16_mp8 = "llama3_1_405b_instruct_bf16_mp8"
    llama3_1_405b_instruct_bf16_mp16 = "llama3_1_405b_instruct_bf16_mp16"


# Hardware needed to serve a model: memory per GPU (in GB) and number of GPUs.
@json_schema_type
class HardwareRequirements(BaseModel):
    memory_gb_per_gpu: int
    gpu_count: int


# Full description of a model SKU. Fields without a default (sku,
# description_markdown, max_seq_length, hardware_requirements, model_args) must
# be supplied at construction.
@json_schema_type(
    schema={
        "description": "The model family and SKU of the model along with other parameters corresponding to the model."
    }
)
class ModelDefinition(BaseModel):
    sku: ModelSKU
    description_markdown: str
    max_seq_length: int
    huggingface_id: Optional[str] = None
    hardware_requirements: HardwareRequirements
    quantization_format: CheckpointQuantizationFormat = (
        CheckpointQuantizationFormat.bf16
    )
    recommended_sampling_params: Optional[SamplingParams] = None
    model_args: Dict[str, Any]

# TODO: resolve these types against the model SKUs above
# Older (Llama 3) model identifiers, grouped by kind: pretrained, instruct
# (chat), and reward models.
@json_schema_type(
    schema={
        "description": "The type of the model. This is used to determine the model family and SKU."
    }
)
class PretrainedModel(Enum):
    llama3_8b = "llama3_8b"
    llama3_70b = "llama3_70b"


@json_schema_type
class InstructModel(Enum):
    llama3_8b_chat = "llama3_8b_chat"
    llama3_70b_chat = "llama3_70b_chat"


@json_schema_type
class RewardModel(Enum):
    llama3_70b_reward = "llama3_70b_reward"
    llama3_405b_reward = "llama3_405b_reward"




### datatypes

In [14]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from enum import Enum
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

from strong_typing.schema import json_schema_type
from typing_extensions import Annotated


# Speaker of a chat message. `ipython` is the role used for tool results (see
# ToolResponseMessage). Each value is also the header text ChatFormat emits.
@json_schema_type
class Role(Enum):
    system = "system"
    user = "user"
    assistant = "assistant"
    ipython = "ipython"


# Wrapper around a URI string. The JSON-schema override describes it as an
# http(s)://, file:// or data: URI.
# NOTE(review): that pattern looks like schema metadata only; `uri: str` itself
# is presumably not validated against it — confirm.
@json_schema_type(
    schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
)
class URL(BaseModel):
    uri: str

    def __str__(self) -> str:
        return self.uri


# A resource attached to a message, identified by URL, with its MIME type.
@json_schema_type
class Attachment(BaseModel):
    url: URL
    mime_type: str


# Message content: plain text, a single attachment, or a list mixing both.
InterleavedTextAttachment = Union[
    str,
    Attachment,
    List[Union[str, Attachment]],
]


# Tools the model can call natively. Custom tools are referred to by plain name
# strings, hence the Union[BuiltinTool, str] fields below.
@json_schema_type
class BuiltinTool(Enum):
    brave_search = "brave_search"
    wolfram_alpha = "wolfram_alpha"
    photogen = "photogen"
    code_interpreter = "code_interpreter"


# Allowed tool-argument values: scalars/None, lists of scalars, or flat
# str-keyed dicts of scalars (one level of nesting only, despite the name).
Primitive = Union[str, int, float, bool, None]
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]


# A tool invocation extracted from model output (built by
# ChatFormat.decode_assistant_message with a fresh uuid4 call_id).
@json_schema_type
class ToolCall(BaseModel):
    call_id: str
    tool_name: Union[BuiltinTool, str]
    arguments: Dict[str, RecursiveType]


# The output of a tool; presumably matched to its ToolCall via call_id.
@json_schema_type
class ToolResponse(BaseModel):
    call_id: str
    tool_name: Union[BuiltinTool, str]
    content: InterleavedTextAttachment


# Declaration of one parameter of a tool (type name, description, required flag).
@json_schema_type
class ToolParamDefinition(BaseModel):
    param_type: str
    description: Optional[str] = None
    required: Optional[bool] = True


# Declaration of a tool (built-in or custom) offered to the model.
@json_schema_type
class ToolDefinition(BaseModel):
    tool_name: Union[BuiltinTool, str]
    description: Optional[str] = None
    parameters: Optional[Dict[str, ToolParamDefinition]] = None


# Chat message types. Each one fixes `role` to a Literal, which is what the
# `Message` union at the end of this cell discriminates on.
@json_schema_type
class UserMessage(BaseModel):
    role: Literal[Role.user.value] = Role.user.value
    content: InterleavedTextAttachment


@json_schema_type
class SystemMessage(BaseModel):
    role: Literal[Role.system.value] = Role.system.value
    content: InterleavedTextAttachment


# Tool output fed back to the model under the `ipython` role.
@json_schema_type
class ToolResponseMessage(BaseModel):
    role: Literal[Role.ipython.value] = Role.ipython.value
    # it was nice to re-use the ToolResponse type, but having all messages
    # have a `content` type makes things nicer too
    call_id: str
    tool_name: Union[BuiltinTool, str]
    content: InterleavedTextAttachment


# How a generation ended. ChatFormat closes an assistant message with
# <|eom_id|> for end_of_message and with <|eot_id|> otherwise.
@json_schema_type
class StopReason(Enum):
    end_of_turn = "end_of_turn"
    end_of_message = "end_of_message"
    out_of_tokens = "out_of_tokens"


# Per-token log-probabilities keyed by token text.
@json_schema_type
class TokenLogProbs(BaseModel):
    logprobs_by_token: Dict[str, float]


# Assistant output; tool_calls defaults to an empty list.
@json_schema_type
class CompletionMessage(BaseModel):
    role: Literal[Role.assistant.value] = Role.assistant.value
    content: InterleavedTextAttachment
    stop_reason: StopReason
    tool_calls: List[ToolCall] = Field(default_factory=list)


# Any chat message; pydantic selects the concrete class from the `role` field.
Message = Annotated[
    Union[
        UserMessage,
        SystemMessage,
        ToolResponseMessage,
        CompletionMessage,
    ],
    Field(discriminator="role"),
]

### tiktoken

In [15]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os
from logging import getLogger
from pathlib import Path
from typing import (
    AbstractSet,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import tiktoken

from tiktoken.load import load_tiktoken_bpe

logger = getLogger(__name__)


# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000


class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Wraps a `tiktoken.Encoding` built from a BPE ranks file plus
    `num_reserved_special_tokens` (256) special tokens. These are 11 named
    tokens followed by `<|reserved_special_token_N|>` placeholders, with ids
    starting right after the base vocabulary.
    """

    # Special-token text (e.g. "<|eot_id|>") -> token id.
    special_tokens: Dict[str, int]

    # Total number of special tokens (named + reserved placeholders).
    num_reserved_special_tokens = 256

    # Pre-tokenization regex passed to tiktoken as `pat_str`.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.

        Raises:
            AssertionError: if `model_path` is not an existing file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|finetune_right_pad_id|>",
            "<|step_id|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|eom_id|>",  # end of message
            "<|eot_id|>",  # end of turn
            "<|python_tag|>",
        ]
        # Fill up to num_reserved_special_tokens; numbering continues at 2
        # because reserved_special_token_0/1 are already named above.
        reserved_tokens = [
            f"<|reserved_special_token_{2 + i}|>"
            for i in range(self.num_reserved_special_tokens - len(special_tokens))
        ]
        special_tokens = special_tokens + reserved_tokens

        # Special ids follow the base vocabulary in list order.
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        # Full vocabulary size: base BPE tokens + all special tokens.
        self.n_words: int = num_base_tokens + len(special_tokens)
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.eot_id: int = self.special_tokens["<|eot_id|>"]
        self.eom_id: int = self.special_tokens["<|eom_id|>"]
        self.python_tag_id = self.special_tokens["<|python_tag|>"]
        self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
        # Ids that end generation: end-of-message and end-of-turn.
        self.stop_tokens = [
            self.special_tokens["<|eom_id|>"],
            self.special_tokens["<|eot_id|>"],
        ]

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
                (None means the empty set)
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.

        The input is encoded in pieces: first chunks of at most
        TIKTOKEN_MAX_ENCODE_CHARS characters, each further split so that no
        piece has a whitespace or non-whitespace run longer than
        MAX_NO_WHITESPACES_CHARS (see the module constants above).
        """
        if allowed_special is None:
            allowed_special = set()
        assert type(s) is str

        # Lazily yields the pieces described in the docstring, in order.
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.

        Concatenating the yielded pieces gives back `s`. An empty `s` yields
        a single empty string.
        """
        # Length of the current run of same-class (space / non-space) chars.
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                # Character class changed: a new run starts (no cut needed).
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    # Run too long: cut before this char and restart the count.
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]

### ChatFormat

In [16]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import uuid

from dataclasses import dataclass
from typing import Dict, List

# Token ids of an encoded prompt, as returned by ChatFormat.encode_dialog_prompt.
@dataclass
class ModelInput:
    tokens: List[int]


class ChatFormat:
    """Convert between chat messages and Llama 3 prompt token ids.

    Each message is encoded as a role header (`<|start_header_id|>`, role text,
    `<|end_header_id|>`, two newlines), then the content tokens, then
    `<|eot_id|>`. A CompletionMessage whose stop_reason is end_of_message gets
    `<|eom_id|>` instead.

    NOTE(review): `ToolUtils` is used below but is neither defined nor imported
    in the visible part of this notebook. Unless a later cell provides it,
    `decode_assistant_message` (always) and `encode_message` (for completion
    messages with tool calls) raise NameError.
    """

    possible_headers: Dict[Role, str]

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        # Decoded header text for every role, used to strip a leading header
        # from decoded assistant output.
        self.possible_headers = {
            role: f"<|start_header_id|>{role.value}<|end_header_id|>\n\n"
            for role in Role
        }

    def encode_header(self, role: str) -> List[int]:
        """Tokens for the header of a message from `role` (e.g. "assistant")."""
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(self, message: Message) -> List[int]:
        """Tokens for one message: header, content, optional tool calls, terminator.

        Only string parts of the content are encoded; Attachment parts are skipped.
        """
        tokens = self.encode_header(message.role)

        def _process_content(content: InterleavedTextAttachment):
            def _process(c):
                if isinstance(c, str):
                    tokens.extend(self.tokenizer.encode(c, bos=False, eos=False))

            if isinstance(content, str):
                _process(content)
            elif isinstance(content, list):
                for c in content:
                    _process(c)

        if isinstance(message, CompletionMessage) and len(message.tool_calls) > 0:
            tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])

        _process_content(message.content)

        if isinstance(message, CompletionMessage):
            for t in message.tool_calls:
                content = ToolUtils.encode_tool_call(t)
                _process_content(content)

        eom = False
        if isinstance(message, CompletionMessage):
            eom = message.stop_reason == StopReason.end_of_message

        tokens.append(
            self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"]
        )
        return tokens

    def encode_dialog_prompt(self, messages: List[Message]) -> ModelInput:
        """Encode a dialog as `<|begin_of_text|>` + all messages + an open assistant header."""
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        for message in messages:
            toks = self.encode_message(message)
            tokens.extend(toks)

        # Add the start of an assistant message for the model to complete.
        tokens.extend(self.encode_header(Role.assistant.value))

        return ModelInput(tokens=tokens)

    # TODO(this should be generic, not only for assistant messages)
    def decode_assistant_message(
        self, tokens: List[int], stop_reason: StopReason
    ) -> CompletionMessage:
        """Decode generated tokens into a CompletionMessage, extracting any tool call.

        Strips surrounding spaces, a leading role header and `<|python_tag|>`,
        and a trailing `<|eot_id|>` or `<|eom_id|>` when one is present. If a
        tool call is recognized, it is returned in `tool_calls` and `content`
        becomes "".
        """
        content = self.tokenizer.decode(tokens)
        content = content.strip(" ")
        for _, header_str in self.possible_headers.items():
            if content.startswith(header_str):
                content = content[len(header_str) :]
                break

        ipython = content.startswith("<|python_tag|>")
        if ipython:
            content = content[len("<|python_tag|>") :]

        # Only strip a terminator that is actually there. The previous
        # `else: content = content[: -len("<|eom_id|>")]` cut the last 10
        # characters of real text whenever generation ended without an end
        # token (e.g. when it ran out of tokens).
        if content.endswith("<|eot_id|>"):
            content = content[: -len("<|eot_id|>")]
        elif content.endswith("<|eom_id|>"):
            content = content[: -len("<|eom_id|>")]

        tool_name = None
        tool_arguments = {}

        custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
        if custom_tool_info is not None:
            tool_name, tool_arguments = custom_tool_info
            # Sometimes when agent has custom tools alongside builin tools
            # Agent responds for builtin tool calls in the format of the custom tools
            # This code tries to handle that case
            if tool_name in BuiltinTool.__members__:
                tool_name = BuiltinTool[tool_name]
                tool_arguments = {
                    "query": list(tool_arguments.values())[0],
                }
        else:
            builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
            if builtin_tool_info is not None:
                tool_name, query = builtin_tool_info
                tool_arguments = {
                    "query": query,
                }
                if tool_name in BuiltinTool.__members__:
                    tool_name = BuiltinTool[tool_name]
            elif ipython:
                # Untagged python_tag output is treated as code to run.
                tool_name = BuiltinTool.code_interpreter
                tool_arguments = {
                    "code": content,
                }

        tool_calls = []
        if tool_name is not None and tool_arguments is not None:
            call_id = str(uuid.uuid4())
            tool_calls.append(
                ToolCall(
                    call_id=call_id,
                    tool_name=tool_name,
                    arguments=tool_arguments,
                )
            )
            content = ""

        return CompletionMessage(
            content=content,
            stop_reason=stop_reason,
            tool_calls=tool_calls,
        )

# End-to-end JAX

### Hyperparams

In [17]:
# Llama 3.1 8B architecture hyperparameters.
# NOTE(review): presumably copied from the checkpoint's params.json — confirm
# against the downloaded Meta-Llama-3.1-8B-Instruct directory.
params_dim = 4096
params_n_layers = 32
params_n_heads = 32
params_n_kv_heads = 8
# Presumably equals Tokenizer.n_words (base BPE vocab + 256 special tokens) — confirm.
params_vocab_size = 128256
params_ffn_dim_multiplier = 1.3
params_multiple_of = 1024
params_norm_eps = 1e-05
params_rope_theta = 500000.0
params_use_scaled_rope = True

# Runtime limits, rather than properties of the checkpoint.
params_max_batch_size = 32
# params_max_seq_len = 2048
# Reduced from 2048; presumably to shrink the KV cache — confirm.
params_max_seq_len = 512

# No model parallelism: "local" head counts equal the global ones.
model_parallel_size = 1
# Falls back to n_heads (plain multi-head attention) when n_kv_heads is None.
n_kv_heads = params_n_heads if params_n_kv_heads is None else params_n_kv_heads
n_local_heads = params_n_heads // model_parallel_size
n_local_kv_heads = n_kv_heads // model_parallel_size


### Attention

In [18]:
import flax

class JaxRMSNorm(flax.linen.Module):
    """RMS normalization over the last axis with a learned per-feature scale.

    Computes x / sqrt(mean(x**2, last axis) + eps) in float32, casts the result
    back to x's dtype, and multiplies by the bfloat16 `weight` parameter
    (initialized to ones). Intermediates are reported through `pprint_d`, which
    does nothing unless pprint is enabled.

    NOTE(review): pprint_d calls deet, which converts to NumPy. With printing
    enabled, this presumably fails if the module is traced under jax.jit, so
    disable pprint before jitting — confirm.
    """
    # Size of the last axis of x; shape of the `weight` parameter is (dim,).
    dim: int
    # Added to the mean square before the reciprocal square root.
    eps: float

    @flax.linen.compact
    def __call__(self, x):
        pprint_d("rms_norm input x", x)
        # print("JaxRMSNorm input x", deetnosum(x))
        weight = self.param('weight', flax.linen.initializers.ones, (self.dim,), dtype=jnp.bfloat16)
        # Do the reduction in float32 regardless of the input dtype.
        norm_x = x.astype(jnp.float32)
        pprint_d("norm_x", norm_x)
        xpow2 = norm_x ** 2
        pprint_d("xpow2", xpow2)
        xpow2mean = xpow2.mean(-1, keepdims=True)
        pprint_d("xpow2mean", xpow2mean)
        pprint("self.eps", self.eps)
        norm_inner = xpow2mean + self.eps
        pprint_d("norm_inner", norm_inner)
        norm_rsqrt = jax.lax.rsqrt(norm_inner)
        pprint_d("norm_rsqrt", norm_rsqrt)
        norm_out = norm_x * norm_rsqrt
        pprint_d("norm_out", norm_out)
        output = norm_out.astype(x.dtype)
        pprint_d("weight", weight)
        pprint_d("output", output)
        # Explicit broadcast of (dim,) to output.shape: implicit rank promotion
        # is configured to raise in the Init cell.
        result = output * jnp.broadcast_to(weight, output.shape)
        pprint_d("result", result)
        return result

def jax_reshape_for_broadcast(freqs_cis, x):
    """Reshape `freqs_cis` so it broadcasts against `x`.

    `freqs_cis` must have shape (x.shape[1], x.shape[-1]). The result keeps
    those two sizes on axes 1 and -1 and has size 1 on every other axis of `x`.
    Raises AssertionError if `x` has fewer than 2 dims or the shapes disagree.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target = tuple(
        size if axis in (1, rank - 1) else 1
        for axis, size in enumerate(x.shape)
    )
    return freqs_cis.reshape(*target)


def jax_apply_rotary_emb(
    xq,
    xk,
    freqs_cis,
):
  """Apply rotary position embeddings to query and key tensors.

  Consecutive pairs along the last axis of `xq` / `xk` are read as
  (real, imag) parts of complex numbers and multiplied by `freqs_cis`. The
  products are interleaved back into (real, imag) pairs and cast to the
  input dtypes.

  Args:
    xq, xk: 4-D arrays. The final reshape keeps the first three axes; the
      attention module reshapes to (bsz, seqlen, n_heads, head_dim) before
      calling this.
    freqs_cis: complex array of shape (xq.shape[1], xq.shape[-1] // 2), as
      checked by jax_reshape_for_broadcast.

  Returns:
    (xq_out, xk_out) with the same shapes and dtypes as (xq, xk).
  """
  pprint_d("jax_apply_rotary xq", xq)
  pprint_d("jax_apply_rotary xk", xk)
  pprint_d("jax_apply_rotary freqs_cis", freqs_cis)
  xqqq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
  xkkk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
  pprint_d("jax_apply_rotary xqqq", xqqq)
  pprint_d("jax_apply_rotary xkkk", xkkk)
  # Build complex64 values from (real, imag) pairs explicitly. This matches the
  # former .view(jnp.complex64) bitcast value for value.
  xq_ = jax.lax.complex(xqqq[..., 0], xqqq[..., 1])
  xk_ = jax.lax.complex(xkkk[..., 0], xkkk[..., 1])
  pprint_d("jax_apply_rotary xq_", xq_)
  pprint_d("jax_apply_rotary xk_", xk_)
  freqs_cis = jax_reshape_for_broadcast(freqs_cis, xq_)
  pprint_d("jax_apply_rotary reshaped freqs_cis", freqs_cis)
  xq_rot = xq_ * freqs_cis
  xk_rot = xk_ * freqs_cis
  # Interleave (real, imag) back along the last axis. With jax_enable_x64 on, a
  # complex128 freqs_cis promotes the products to complex128. The former
  # .view(jnp.float32) then quadrupled the last axis instead of doubling it.
  # Stacking real/imag works for either precision.
  xq_out = jnp.stack([xq_rot.real, xq_rot.imag], axis=-1).reshape(*xq_rot.shape[:3], -1)
  xk_out = jnp.stack([xk_rot.real, xk_rot.imag], axis=-1).reshape(*xk_rot.shape[:3], -1)
  pprint_d("jax_apply_rotary xq_out", xq_out)
  pprint_d("jax_apply_rotary xk_out", xk_out)
  xqoo = xq_out.astype(xq.dtype)
  xkoo = xk_out.astype(xk.dtype)
  pprint_d("jax_apply_rotary xqoo", xqoo)
  pprint_d("jax_apply_rotary xkoo", xkoo)
  return xqoo, xkoo

def jax_repeat_kv(x: jax.Array, n_rep: int) -> jax.Array:
  """Repeat each KV head `n_rep` times along axis 2 (grouped-query attention).

  Args:
    x: 4-D array (bs, slen, n_kv_heads, head_dim). Unpacking x.shape raises
      ValueError for any other rank.
    n_rep: number of copies of each head.

  Returns:
    `x` itself when n_rep == 1. Otherwise an array of shape
    (bs, slen, n_kv_heads * n_rep, head_dim) in which the copies of each head
    are adjacent: head h fills output heads h*n_rep .. h*n_rep + n_rep - 1.
  """
  pprint_d("jax_repeat_kv x", x)
  pprint("jax_repeat_kv n_rep", n_rep)
  _bs, _slen, _n_kv_heads, _head_dim = x.shape  # rank check
  if n_rep == 1:
      return x
  # Same layout as broadcasting x[:, :, :, None, :] to n_rep and merging axes.
  out = jnp.repeat(x, n_rep, axis=2)
  pprint_d("jax_repeat_kv out", out)
  return out


In [19]:
from typing import Any

# @jax.jit
# def dyn_set_slice(arr, start_pos, seqlen):
#     # Compute the position where the masking should start
#     update_start = start_pos + seqlen

#     # Create a mask that is True before `update_start` and False after
#     mask = jnp.arange(swapaxekeys.shape[-1]) < update_start  # Shape: (last_dim,)

#     # Reshape mask to be broadcastable over the other dimensions
#     mask = mask[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]  # Shape: (1, 1, 1, last_dim)

#     # Use jnp.where to set values to 0 where mask is False
#     return jnp.where(mask, swapaxekeys, 0)

class JaxCache(flax.linen.Module):
  """A single KV-cache buffer stored in flax's 'cache' variable collection.

  Attributes:
    shape: shape of the buffer. JaxAttention passes
      (max_batch_size, max_seq_len, n_local_kv_heads, head_dim).
    dtype: dtype of the buffer; defaults to bfloat16.
  """
  shape: Any
  dtype: Any = jnp.bfloat16

  @flax.linen.compact
  def __call__(self, x, start_pos):
    """Write `x` into the buffer at sequence offset `start_pos`, return the buffer.

    If the 'cache'/'cached_value' variable does not exist yet (e.g. during
    `init`), it is created zero-filled and `x` is NOT written on that call.

    Args:
      x: new entries. Must be 4-D because the update indices are 4-D; presumably
        (batch, seqlen, n_kv_heads, head_dim), TODO confirm.
      start_pos: offset along axis 1 at which `x` is written.
    Returns:
      The full buffer, of shape `self.shape`.
    """
    is_initialized = self.has_variable('cache', 'cached_value')
    cached_value = self.variable(
        'cache', 'cached_value', jnp.zeros, shape=self.shape, dtype=self.dtype
    )
    if is_initialized:
      # dynamic_update_slice requires the update to match the buffer's dtype.
      update = x.astype(cached_value.value.dtype)
      # Write along axis 1 (sequence), starting at batch row 0.
      cached_value.value = jax.lax.dynamic_update_slice(
          cached_value.value, update, (0, start_pos, 0, 0))
    else:
      print("INITIALIZING CACHE!!!")
    return cached_value.value

# attention module

class JaxAttention(flax.linen.Module):
  """Grouped-query self-attention backed by a per-layer KV cache.

  Attributes:
    dim: output width of the `wo` projection.
    n_heads: number of query heads (`wq` outputs n_heads * head_dim).
    n_kv_heads: number of key/value heads (`wk`/`wv` output n_kv_heads * head_dim).
    n_local_heads: query-head count used to reshape `xq`.
    n_local_kv_heads: kv-head count used to reshape `xk`/`xv` and size the caches.
    n_rep: kv-head repetition factor passed to `jax_repeat_kv`.
    head_dim: per-head width; scores are scaled by 1/sqrt(head_dim).
    max_batch_size: batch dimension of the cache buffers.
    max_seq_len: sequence dimension of the cache buffers.
  """
  dim: int
  n_heads: int
  n_kv_heads: int
  n_local_heads: int
  n_local_kv_heads: int
  n_rep: int
  head_dim: int
  max_batch_size: int
  max_seq_len: int

  @flax.linen.compact
  def __call__(self, x, start_pos, freqs_cis, mask):
    """Attend from `x` (written at cache offset `start_pos`) over the cache.

    Args:
      x: 3-D activations unpacked as (bsz, seqlen, dim).
      start_pos: cache position of x[:, 0].
      freqs_cis: rotary frequencies forwarded to `jax_apply_rotary_emb`.
      mask: optional mask broadcastable to the scores
        (bsz, n_heads, seqlen, max_seq_len). Where mask <= -1000 the score is
        replaced by the mask value (e.g. -inf); elsewhere the raw score is kept.
        JaxTransformer passes None when seqlen == 1.
    Returns:
      (bsz, seqlen, dim) output of the `wo` projection.
    """
    print("JaxAttention")
    bsz, seqlen, _ = x.shape
    pprint(f"{bsz=} {seqlen=}")
    pprint(f"JaxAttention {start_pos=} {seqlen=}")
    pprint_d("jax_attn_forward layer 0 x", x)

    xq = flax.linen.Dense(features=self.n_heads * self.head_dim, use_bias=False, name="wq", param_dtype=jnp.bfloat16)(x)
    pprint_d("jax_pre-view xq in layer 0", xq)
    xk = flax.linen.Dense(features=self.n_kv_heads * self.head_dim, use_bias=False, name="wk", param_dtype=jnp.bfloat16)(x)
    pprint_d("jax_pre-view xk in layer 0", xk)
    xv = flax.linen.Dense(features=self.n_kv_heads * self.head_dim, use_bias=False, name="wv", param_dtype=jnp.bfloat16)(x)
    pprint_d("jax_pre-view xv in layer 0", xv)

    xq = xq.reshape((bsz, seqlen, self.n_local_heads, self.head_dim))
    pprint_d("jax_initial xq in layer 0", xq)
    xk = xk.reshape((bsz, seqlen, self.n_local_kv_heads, self.head_dim))
    pprint_d("jax_initial xk in layer 0", xk)
    xv = xv.reshape((bsz, seqlen, self.n_local_kv_heads, self.head_dim))
    pprint_d("jax_initial xv in layer 0", xv)

    xq, xk = jax_apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    pprint_d("jax_rotaried xq in layer 0", xq)
    pprint_d("jax_rotaried xk in layer 0", xk)

    # The cache buffers must hold finite values (zeros) in unwritten slots:
    # an inf-filled cache was observed to give NaN scores (q . inf) and a NaN
    # output even after masking, because softmax weight 0 * inf value = NaN.
    cache_shape = (self.max_batch_size, self.max_seq_len, self.n_local_kv_heads, self.head_dim)
    keys = JaxCache(name='keys_cache', shape=cache_shape)(x=xk, start_pos=start_pos)
    values = JaxCache(name='values_cache', shape=cache_shape)(x=xv, start_pos=start_pos)
    pprint_d("jax_cache keys", keys)
    pprint_d("jax_cache values", values)

    # Keep the whole max_seq_len axis rather than slicing to start_pos + seqlen
    # (presumably so shapes stay static when start_pos is traced); positions
    # not yet written are excluded from the softmax further down.
    keys = keys[:bsz]
    values = values[:bsz]
    pprint_d("jax_initial keys", keys)
    pprint_d("jax_initial values", values)

    # Per-position debug dumps. Capped at the axis length because JAX clamps
    # out-of-range static indices and would silently repeat the last position.
    n_debug_positions = min(50, keys.shape[1])
    for i in range(n_debug_positions):
      pprint_d(f"keys {i=}", keys[:, i, :, :])
    for i in range(n_debug_positions):
      pprint_d(f"values {i=}", values[:, i, :, :])

    # Repeat k/v heads so every query head has one: (bs, max_seq_len, n_local_heads, head_dim).
    keys = jax_repeat_kv(keys, self.n_rep)
    values = jax_repeat_kv(values, self.n_rep)
    pprint_d("jax_rep keys", keys)
    pprint_d("jax_rep values", values)

    # Heads before sequence: (bs, n_heads, seq, head_dim).
    xq = jnp.swapaxes(xq, 1, 2)
    keys = jnp.swapaxes(keys, 1, 2)
    values = jnp.swapaxes(values, 1, 2)
    pprint_d("jax_transp xq", xq)
    pprint_d("jax_transp keys", keys)
    pprint_d("jax_transp values", values)

    swapaxekeys = jnp.swapaxes(keys, 2, 3)
    pprint_d("jax_swapaxekeys", swapaxekeys)
    # (bs, n_heads, seqlen, max_seq_len)
    scores = jnp.matmul(xq, swapaxekeys) / math.sqrt(self.head_dim)
    pprint_d("jax_initial scores", scores)
    if mask is not None:
      pprint_d("jax_mask", mask)
      # Select instead of add: masked entries are replaced outright by the
      # mask value, so a non-finite raw score there cannot survive
      # (with `scores + mask`, NaN + -inf would stay NaN).
      bcmask = jnp.broadcast_to(mask, scores.shape)
      scores = jnp.where(bcmask > -1000, scores, bcmask)
    pprint_d("jax_mask-masked scores", scores)
    for i in range(min(50, scores.shape[-1])):
      pprint_d(f"jax_score {i=}", scores[:, :, :, i])

    # Exclude cache positions >= start_pos + seqlen (not written by this call
    # or earlier ones). When mask is None this is the only causal masking.
    written = jnp.arange(scores.shape[-1]) < start_pos + seqlen
    masked_scores = jnp.where(written, scores, -jnp.inf)
    pprint_d("jax_idx-masked scores", masked_scores)
    # Softmax in float32, then back to the activation dtype.
    scores = jax.nn.softmax(masked_scores.astype(jnp.float32), axis=-1).astype(xq.dtype)
    pprint_d("jax_softmax scores", scores)

    output = jnp.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
    pprint_d("jax_initial output", output)
    output = jnp.swapaxes(output, 1, 2).reshape(bsz, seqlen, -1)
    pprint_d("jax_transp output", output)
    wo_out = flax.linen.Dense(features=self.dim, use_bias=False, name="wo", param_dtype=jnp.bfloat16)(output)
    pprint_d("jax_wo_out", wo_out)
    return wo_out

class JaxFeedForward(flax.linen.Module):
  """Gated feed-forward layer: w2(silu(w1(x)) * w3(x)).

  The inner width is derived from `hidden_dim`: take int(2 * hidden_dim / 3),
  optionally scale it by `ffn_dim_multiplier`, then round up to a multiple of
  `multiple_of`. With hidden_dim=16384, multiplier 1.3 and multiple_of 1024 this
  yields 14336, the width of the loaded w1/w3 kernels.
  """
  dim: int
  hidden_dim: int
  ffn_dim_multiplier: float
  multiple_of: int

  @flax.linen.compact
  def __call__(self, x):
    print("JaxFeedForward")
    step = self.multiple_of
    inner_dim = int(2 * self.hidden_dim / 3)
    if self.ffn_dim_multiplier is not None:
      inner_dim = int(self.ffn_dim_multiplier * inner_dim)
    # Round up to the next multiple of `step`.
    inner_dim = step * ((inner_dim + step - 1) // step)

    gate = jax.nn.silu(
        flax.linen.Dense(features=inner_dim, use_bias=False, name="w1", param_dtype=jnp.bfloat16)(x))
    up = flax.linen.Dense(features=inner_dim, use_bias=False, name="w3", param_dtype=jnp.bfloat16)(x)
    down_proj = flax.linen.Dense(features=self.dim, use_bias=False, name="w2", param_dtype=jnp.bfloat16)
    return down_proj(gate * up)



class JaxTransformerBlock(flax.linen.Module):
  """Pre-norm residual decoder block.

  Computes r = x + attn(attn_norm(x)) and returns r + ffn(ffn_norm(r)).
  The `attn_*` attributes are forwarded to JaxAttention; `hidden_dim`,
  `ffn_dim_multiplier` and `multiple_of` are forwarded to JaxFeedForward.
  """
  dim: int
  attn_n_heads: int
  attn_n_kv_heads: int
  attn_n_local_heads: int
  attn_n_local_kv_heads: int
  attn_n_rep: int
  attn_head_dim: int
  attn_norm_eps: float
  ffn_norm_eps: float
  hidden_dim: int
  ffn_dim_multiplier: float
  multiple_of: int
  max_batch_size: int
  max_seq_len: int

  @flax.linen.compact
  def __call__(self, x, start_pos, freqs_cis, mask):
    mask_shape = None if mask is None else mask.shape
    print(f"JaxTransformerBlock {start_pos=} {freqs_cis.shape=} mask.shape={mask_shape}")

    # Attention sub-layer.
    normed_x = JaxRMSNorm(dim=self.dim, eps=self.attn_norm_eps, name="attn_norm")(x)
    pprint_d("attn_x", normed_x)
    attention = JaxAttention(
      name="attn",
      dim=self.dim,
      n_heads=self.attn_n_heads,
      n_kv_heads=self.attn_n_kv_heads,
      n_local_heads=self.attn_n_local_heads,
      n_local_kv_heads=self.attn_n_local_kv_heads,
      n_rep=self.attn_n_rep,
      head_dim=self.attn_head_dim,
      max_batch_size=self.max_batch_size,
      max_seq_len=self.max_seq_len,
    )
    attn_out = attention(normed_x, start_pos, freqs_cis, mask)
    pprint_d("h_attn", attn_out)

    residual = x + attn_out
    pprint_d("jax_block h_plus_attn", residual)

    # Feed-forward sub-layer.
    ffn_in = JaxRMSNorm(dim=self.dim, eps=self.ffn_norm_eps, name="ffn_norm")(residual)
    pprint_d("jax_block_norm", ffn_in)
    feed_forward = JaxFeedForward(
      name="ffn",
      dim=self.dim,
      hidden_dim=self.hidden_dim,
      ffn_dim_multiplier=self.ffn_dim_multiplier,
      multiple_of=self.multiple_of,
    )
    ffn_out = feed_forward(ffn_in)
    pprint_d("jax_block_ffn", ffn_out)

    result = residual + ffn_out
    pprint_d("jax_block_out", result)
    return result


### Transformer module

In [20]:
# transformer module

def jax_forward_mask(total_len, seqlen, dtype, device, start_pos=0):
  """Build a causal mask of shape (seqlen, total_len) filled with 0 / -inf.

  Row i is the query at cache position start_pos + i; column j is a cache
  position. Entry (i, j) is 0 when j <= start_pos + i (visible) and -inf
  otherwise. With the default start_pos=0 this is the strict upper-triangular
  -inf mask used for prefill from position 0.

  Args:
    total_len: number of key/cache positions (columns), e.g. max_seq_len.
    seqlen: number of query positions (rows).
    dtype: dtype of the returned mask.
    device: device on which the mask is allocated.
    start_pos: cache offset of the first query row; default 0.
  Returns:
    (seqlen, total_len) array in `dtype`.
  """
  mask = jnp.full((seqlen, total_len), float("-inf"), device=device)
  # triu(k) keeps -inf only where j >= i + k, i.e. j > start_pos + i.
  mask = jnp.triu(mask, k=start_pos + 1)
  mask = mask.astype(dtype)
  pprint_d("mask", mask)
  return mask

class JaxTransformer(flax.linen.Module):
  """Decoder stack: token embedding -> params_n_layers blocks -> RMSNorm -> logits.

  Every per-layer hyper-parameter is taken from this module's own attributes.
  Only the number of layers is read from the notebook global `params_n_layers`.
  `jax_freqs_cis` is also a notebook global.

  Attributes:
    dim .. max_seq_len: forwarded to each JaxTransformerBlock (see there).
    output_norm_eps: eps of the final RMSNorm.
    vocab_size: embedding rows and output logits width.
    is_prefilling: if True, use the first `prefill_len` rotary rows; otherwise
      use only the row at `start_pos` (single-token decode).
    prefill_len: number of rotary rows used when prefilling.
    device: device on which the causal mask is allocated.
  """
  dim: int
  attn_n_heads: int
  attn_n_kv_heads: int
  attn_n_local_heads: int
  attn_n_local_kv_heads: int
  attn_n_rep: int
  attn_head_dim: int
  attn_norm_eps: float
  ffn_norm_eps: float
  output_norm_eps: float
  hidden_dim: int
  ffn_dim_multiplier: float
  multiple_of: int
  max_batch_size: int
  max_seq_len: int
  vocab_size: int
  is_prefilling: bool
  prefill_len: int
  device: Any

  @flax.linen.compact
  def __call__(self, tokens, start_pos):
    """Run the model on `tokens` whose first column sits at cache offset `start_pos`.

    Args:
      tokens: 2-D token ids unpacked as (bsz, seqlen).
      start_pos: cache position of tokens[:, 0].
    Returns:
      float32 logits from the `output` projection (vocab_size wide).
    """
    print(f"JaxTransformer {tokens.shape=} {start_pos=} {self.is_prefilling=}")
    print("tokens", deetnodev(tokens))
    _bsz, seqlen = tokens.shape
    h = flax.linen.Embed(
      num_embeddings=self.vocab_size,
      features=self.dim,
      name="tok_embeddings",
      param_dtype=jnp.bfloat16,
    )(tokens)
    if self.is_prefilling:
      freqs_cis = jax_freqs_cis[:self.prefill_len]
    else:
      # One decode step: a single rotary row, kept 2-D with a leading axis of 1.
      freqs_cis = jnp.expand_dims(jax_freqs_cis[start_pos], axis=0)
    print(f"  freqs_cis {freqs_cis.shape=} {start_pos=} {seqlen=}")

    # Only multi-token calls need an explicit causal mask; for a single token
    # JaxAttention masks unwritten cache positions by itself.
    mask = None
    if seqlen > 1:
      mask = jax_forward_mask(self.max_seq_len, seqlen, h.dtype, self.device)
      print(f"  mask {mask.shape=} {self.max_seq_len=} {seqlen=}")

    pprint("seqlen", seqlen)
    pprint_d("start h", h)
    pprint("start_pos", start_pos)
    pprint_d("freqs_cis", freqs_cis)
    if mask is None:
      pprint("mask", mask)
    else:
      pprint_d("mask", mask)

    for n in range(params_n_layers):
      pprint(f"doing layer layer{n}")
      h = JaxTransformerBlock(
        name=f"layer{n}",
        dim=self.dim,
        attn_n_heads=self.attn_n_heads,
        attn_n_kv_heads=self.attn_n_kv_heads,
        attn_n_local_heads=self.attn_n_local_heads,
        attn_n_local_kv_heads=self.attn_n_local_kv_heads,
        attn_n_rep=self.attn_n_rep,
        attn_head_dim=self.attn_head_dim,
        attn_norm_eps=self.attn_norm_eps,
        ffn_norm_eps=self.ffn_norm_eps,
        hidden_dim=self.hidden_dim,
        ffn_dim_multiplier=self.ffn_dim_multiplier,
        multiple_of=self.multiple_of,
        max_batch_size=self.max_batch_size,
        max_seq_len=self.max_seq_len,
      )(x=h, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask)
      pprint_d(f"layer {n} h", h)
      # Presumably silences pprint/pprint_d after the first layer so that only
      # layer 0 intermediates are logged.
      enable_pprint(False)
    pprint("doing output")
    h = JaxRMSNorm(dim=self.dim, eps=self.output_norm_eps, name="output_norm")(h)
    output = flax.linen.Dense(features=self.vocab_size, use_bias=False, name="output", param_dtype=jnp.bfloat16)(h)
    return output.astype(jnp.float32)

# Number of rotary rows used by the prefill model.
PREFILL_LEN = 256


def make_jax_transformer(is_prefilling, prefill_len, device=None):
  """Build a JaxTransformer whose hyper-parameters come from the `params_*` globals.

  The prefill and incremental-decode models share every setting except
  `is_prefilling` and `prefill_len`, so both are built here.

  Args:
    is_prefilling: True for the prefill model, False for single-token decode.
    prefill_len: rotary rows used when prefilling (0 for decode).
    device: device for the causal mask; defaults to `jax_tpu()`.
  Returns:
    An unbound JaxTransformer module.
  """
  return JaxTransformer(
      dim=params_dim,
      attn_n_heads=params_n_heads,
      attn_n_kv_heads=n_kv_heads,
      attn_n_local_heads=n_local_heads,
      attn_n_local_kv_heads=n_local_kv_heads,
      attn_n_rep=n_local_heads // n_local_kv_heads,
      attn_head_dim=params_dim // params_n_heads,
      attn_norm_eps=params_norm_eps,
      ffn_norm_eps=params_norm_eps,
      output_norm_eps=params_norm_eps,
      hidden_dim=params_dim * 4,
      ffn_dim_multiplier=params_ffn_dim_multiplier,
      multiple_of=params_multiple_of,
      max_batch_size=params_max_batch_size,
      max_seq_len=params_max_seq_len,
      vocab_size=params_vocab_size,
      is_prefilling=is_prefilling,
      prefill_len=prefill_len,
      device=jax_tpu() if device is None else device,
  )


jax_transformer_prefill = make_jax_transformer(is_prefilling=True, prefill_len=PREFILL_LEN)

print("jax_transformer_prefill", jax_transformer_prefill)

jax_transformer_incremental = make_jax_transformer(is_prefilling=False, prefill_len=0)


jax_transformer_prefill JaxTransformer(
    # attributes
    dim = 4096
    attn_n_heads = 32
    attn_n_kv_heads = 8
    attn_n_local_heads = 32
    attn_n_local_kv_heads = 8
    attn_n_rep = 4
    attn_head_dim = 128
    attn_norm_eps = 1e-05
    ffn_norm_eps = 1e-05
    output_norm_eps = 1e-05
    hidden_dim = 16384
    ffn_dim_multiplier = 1.3
    multiple_of = 1024
    max_batch_size = 32
    max_seq_len = 512
    vocab_size = 128256
    is_prefilling = True
    prefill_len = 256
    device = TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)
)


### Setup/load model params

In [21]:
# jax_transformer_params {
#     'cache': {
#         'layer0': {'attn': {'keys_cache': {'cached_value': ((32, 2048, 8, 128), dtype(bfloat16), {CpuDevice(id=0)})}, 'values_cache': {'cached_value': ((32, 2048, 8, 128), dtype(bfloat16), {CpuDevice(id=0)})}}},
#         'layer1': {'attn': {'keys_cache': {'cached_value': ((32, 2048, 8, 128), dtype(bfloat16), {CpuDevice(id=0)})}, 'values_cache': {'cached_value': ((32, 2048, 8, 128), dtype(bfloat16), {CpuDevice(id=0)})}}},
#         ...
#         'layer31': {'attn': {'keys_cache': {'cached_value': ((32, 2048, 8, 128), dtype(bfloat16), {CpuDevice(id=0)})}, 'values_cache': {'cached_value': ((32, 2048, 8, 128), dtype(bfloat16), {CpuDevice(id=0)})}}},
#     },
#     'params': {
#         'layer0': {'attn': {'wk': {'kernel': ((4096, 1024), dtype(bfloat16), {CpuDevice(id=0)})}, 'wo': {'kernel': ((4096, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'wq': {'kernel': ((4096, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'wv': {'kernel': ((4096, 1024), dtype(bfloat16), {CpuDevice(id=0)})}}, 'attn_norm': {'weight': ((4096,), dtype(bfloat16), {CpuDevice(id=0)})}, 'ffn': {'w1': {'kernel': ((4096, 14336), dtype(bfloat16), {CpuDevice(id=0)})}, 'w2': {'kernel': ((14336, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'w3': {'kernel': ((4096, 14336), dtype(bfloat16), {CpuDevice(id=0)})}}, 'ffn_norm': {'weight': ((4096,), dtype(bfloat16), {CpuDevice(id=0)})}},
#         'layer1': {'attn': {'wk': {'kernel': ((4096, 1024), dtype(bfloat16), {CpuDevice(id=0)})}, 'wo': {'kernel': ((4096, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'wq': {'kernel': ((4096, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'wv': {'kernel': ((4096, 1024), dtype(bfloat16), {CpuDevice(id=0)})}}, 'attn_norm': {'weight': ((4096,), dtype(bfloat16), {CpuDevice(id=0)})}, 'ffn': {'w1': {'kernel': ((4096, 14336), dtype(bfloat16), {CpuDevice(id=0)})}, 'w2': {'kernel': ((14336, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'w3': {'kernel': ((4096, 14336), dtype(bfloat16), {CpuDevice(id=0)})}}, 'ffn_norm': {'weight': ((4096,), dtype(bfloat16), {CpuDevice(id=0)})}},
#         ...
#         'layer30': {'attn': {'wk': {'kernel': ((4096, 1024), dtype(bfloat16), {CpuDevice(id=0)})}, 'wo': {'kernel': ((4096, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'wq': {'kernel': ((4096, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'wv': {'kernel': ((4096, 1024), dtype(bfloat16), {CpuDevice(id=0)})}}, 'attn_norm': {'weight': ((4096,), dtype(bfloat16), {CpuDevice(id=0)})}, 'ffn': {'w1': {'kernel': ((4096, 14336), dtype(bfloat16), {CpuDevice(id=0)})}, 'w2': {'kernel': ((14336, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'w3': {'kernel': ((4096, 14336), dtype(bfloat16), {CpuDevice(id=0)})}}, 'ffn_norm': {'weight': ((4096,), dtype(bfloat16), {CpuDevice(id=0)})}},
#         'layer31': {'attn': {'wk': {'kernel': ((4096, 1024), dtype(bfloat16), {CpuDevice(id=0)})}, 'wo': {'kernel': ((4096, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'wq': {'kernel': ((4096, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'wv': {'kernel': ((4096, 1024), dtype(bfloat16), {CpuDevice(id=0)})}}, 'attn_norm': {'weight': ((4096,), dtype(bfloat16), {CpuDevice(id=0)})}, 'ffn': {'w1': {'kernel': ((4096, 14336), dtype(bfloat16), {CpuDevice(id=0)})}, 'w2': {'kernel': ((14336, 4096), dtype(bfloat16), {CpuDevice(id=0)})}, 'w3': {'kernel': ((4096, 14336), dtype(bfloat16), {CpuDevice(id=0)})}}, 'ffn_norm': {'weight': ((4096,), dtype(bfloat16), {CpuDevice(id=0)})}},
#         'output': {'kernel': ((4096, 128256), dtype(bfloat16), {CpuDevice(id=0)})},
#         'output_norm': {'weight': ((4096,), dtype(bfloat16), {CpuDevice(id=0)})}
#     }
# }

def neginf(shape, dtype, device):
  """Return an array of `shape` and `dtype`, filled with -inf, allocated on `device`."""
  negative_infinity = float("-inf")
  filled = jnp.full(shape, negative_infinity, dtype=dtype, device=device)
  return filled

def init_empty_cache(device):
  """Build a zero-filled KV-cache tree for the flax 'cache' collection.

  Layout, for n in range(params_n_layers):
    {f"layer{n}": {'attn': {'keys_cache': {'cached_value': buf},
                            'values_cache': {'cached_value': buf}}}}
  Each `buf` is a separate bfloat16 array on `device`. Its shape is
  (params_max_batch_size, params_max_seq_len, n_local_kv_heads, head_dim),
  the same shape JaxAttention uses for its caches.

  Zeros (not -inf) are deliberate: non-finite values in unwritten cache slots
  were observed to turn into NaN attention outputs even after masking.
  """
  head_dim = params_dim // params_n_heads
  cache_shape = (params_max_batch_size, params_max_seq_len, n_local_kv_heads, head_dim)

  def empty_slot():
    # A fresh buffer per call so keys/values (and layers) never share storage.
    return {'cached_value': jnp.zeros(cache_shape, dtype=jnp.bfloat16, device=device)}

  return {
      f"layer{n}": {'attn': {'keys_cache': empty_slot(), 'values_cache': empty_slot()}}
      for n in range(params_n_layers)
  }

# jax_layer0_params['params']['attn']['wq']['kernel'] = load_torch_weights('/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/consolidated.00.pth/layers.0.attention.wq.weight').T
# jax_layer0_params['params']['attn']['wk']['kernel'] = load_torch_weights('/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/consolidated.00.pth/layers.0.attention.wk.weight').T
# jax_layer0_params['params']['attn']['wv']['kernel'] = load_torch_weights('/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/consolidated.00.pth/layers.0.attention.wv.weight').T
# jax_layer0_params['params']['attn']['wo']['kernel'] = load_torch_weights('/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/consolidated.00.pth/layers.0.attention.wo.weight').T

# jax_layer0_params['params']['attn_norm']['weight'] = load_torch_weights("/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/consolidated.00.pth/layers.0.attention_norm.weight")

# jax_layer0_params['params']['ffn']['w1']['kernel'] = load_torch_weights('/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/consolidated.00.pth/layers.0.feed_forward.w1.weight').T
# jax_layer0_params['params']['ffn']['w2']['kernel'] = load_torch_weights('/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/consolidated.00.pth/layers.0.feed_forward.w2.weight').T
# jax_layer0_params['params']['ffn']['w3']['kernel'] = load_torch_weights('/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/consolidated.00.pth/layers.0.feed_forward.w3.weight').T

# jax_layer0_params['params']['ffn_norm']['weight'] = load_torch_weights("/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/consolidated.00.pth/layers.0.ffn_norm.weight")

def load_llama3_1_params(device=None, prefix="/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split"):
  """Load the Llama 3.1 8B Instruct weights into the nested param dict used by the JAX model.

  Args:
    device: forwarded unchanged to every load_torch_weights call.
    prefix: directory holding the split checkpoint; each tensor is read from
      f"{prefix}/consolidated.00.pth/<tensor name>".

  Returns:
    dict with 'tok_embeddings', 'output', 'output_norm' and one f"layer{n}" entry for
    every n < params_n_layers, each holding 'attn', 'attn_norm', 'ffn' and 'ffn_norm'.
  """
  ckpt_dir = f"{prefix}/consolidated.00.pth"

  def load(name, transpose=False):
    # Projection matrices are stored transposed relative to the JAX kernels
    # (presumably torch nn.Linear (out, in) -> Flax Dense (in, out)).
    weight = load_torch_weights(f"{ckpt_dir}/{name}", device)
    return weight.T if transpose else weight

  params = {
    'tok_embeddings': {'embedding': load('tok_embeddings.weight')},
    'output': {'kernel': load('output.weight', transpose=True)},
    'output_norm': {'weight': load('norm.weight')},
  }
  for n in range(params_n_layers):
    layer = f"layers.{n}"
    params[f"layer{n}"] = {
        'attn': {
            w: {'kernel': load(f'{layer}.attention.{w}.weight', transpose=True)}
            for w in ('wq', 'wk', 'wv', 'wo')
        },
        'attn_norm': {'weight': load(f'{layer}.attention_norm.weight')},
        'ffn': {
            w: {'kernel': load(f'{layer}.feed_forward.{w}.weight', transpose=True)}
            for w in ('w1', 'w2', 'w3')
        },
        'ffn_norm': {'weight': load(f'{layer}.ffn_norm.weight')},
    }
  return params

jax_transformer_params_tpu = load_llama3_1_params(jax_tpu())

# def f(device=None):
#   device = device or jax_cpu()
#   with jax.default_device(device):
#     pprint("initing cache...")
#     cache = init_empty_cache(device)
#     pprint("loading params...")
#     params = load_llama3_1_params(device)
#     jax_transformer_params = {
#         'cache': cache,
#         'params': params,
#     }
#     return jax_transformer_params

# jax_transformer_params = None
# jax_transformer_params_tpu = None

# jax_transformer_params = f()

# print("jax_transformer_params", jax.tree_util.tree_map(lambda t: (t.shape, t.dtype, t.devices()), jax_transformer_params))

# jax_transformer_params_tpu = f(jax.devices("tpu")[0])

# print("jax_transformer_params_tpu", jax.tree_util.tree_map(lambda t: (t.shape, t.dtype, t.devices()), jax_transformer_params_tpu))


In [22]:
tokenizer = Tokenizer("/mnt/ramdisk/llama3_1/Meta-Llama-3.1-8B-Instruct_split/tokenizer.model")
chat_format = ChatFormat(tokenizer)

def jax_gen_tokens(model_input, max_seq_len, max_gen_len):
  """Build the right-padded int32 token buffer for one encoded prompt.

  Args:
    model_input: encoded dialog; only its `.tokens` list is used.
    max_seq_len: hard cap on the buffer length.
    max_gen_len: number of positions reserved after the prompt for generation.

  Returns:
    jnp int32 array of shape (1, min(max_gen_len + len(prompt), max_seq_len)): the
    prompt tokens followed by tokenizer.pad_id.

  Raises:
    ValueError: if the prompt alone is longer than max_seq_len.
  """
  prompt_tokens = [model_input.tokens]
  bsz = len(prompt_tokens)
  pad_id = tokenizer.pad_id

  max_prompt_len = max(len(t) for t in prompt_tokens)
  if max_prompt_len > max_seq_len:
    # Without this check the .at[...].set below fails with an opaque broadcast error.
    raise ValueError(f"prompt has {max_prompt_len} tokens but max_seq_len is {max_seq_len}")
  total_len = min(max_gen_len + max_prompt_len, max_seq_len)

  pprint(f"{pad_id=}")
  jax_tokens = jnp.full((bsz, total_len), pad_id, dtype=jnp.int32)
  for k, t in enumerate(prompt_tokens):
    pprint(f"{k=}, {t=}")
    jax_tokens = jax_tokens.at[k, :len(t)].set(jnp.array(t, dtype=jnp.int32))

  pprint("jax_tokens", deet(jax_tokens), jax_tokens)
  return jax_tokens

In [23]:
import math

def torch_apply_scaling(freqs: torch.Tensor):
    """Rescale RoPE inverse frequencies for long-context use.

    Each frequency is classified by its wavelength 2*pi/freq:
      * shorter than old_context_len / high_freq_factor -> kept as-is,
      * longer than old_context_len / low_freq_factor   -> divided by scale_factor,
      * in between -> linear blend of the two, weighted by `smooth`.

    Returns a new tensor with the same dtype and device as `freqs`.
    """
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    long_wave_cutoff = old_context_len / low_freq_factor
    short_wave_cutoff = old_context_len / high_freq_factor
    assert long_wave_cutoff != short_wave_cutoff

    wavelens = 2 * math.pi / freqs
    smooth = (old_context_len / wavelens - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    blended = (1 - smooth) * freqs / scale_factor + smooth * freqs
    # Long wavelengths get the full downscale; the middle band is blended.
    scaled = torch.where(wavelens > long_wave_cutoff, freqs / scale_factor, blended)
    # Short wavelengths keep their original frequency.
    rescaled = torch.where(wavelens < short_wave_cutoff, freqs, scaled)
    return rescaled.to(dtype=freqs.dtype, device=freqs.device)

def torch_precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
    """Precompute the complex RoPE rotation table.

    Args:
        dim: per-head dimension; the table has dim // 2 frequency columns.
        end: number of positions (rows) to precompute.
        theta: RoPE base.
        use_scaled: if True, pass the inverse frequencies through torch_apply_scaling.

    Returns:
        complex64 tensor of shape (end, dim // 2) holding exp(i * position * freq).
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    if use_scaled:
        inv_freq = torch_apply_scaling(inv_freq)
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers with the given angles.
    return torch.polar(torch.ones_like(angles), angles)

# RoPE rotation table: params_max_seq_len * 2 rows (positions) by
# (params_dim // params_n_heads) // 2 complex columns.
torch_freqs_cis = torch_precompute_freqs_cis(
    dim=params_dim // params_n_heads,
    end=params_max_seq_len * 2,
    theta=params_rope_theta,
    use_scaled=params_use_scaled_rope,
)
print(deets(torch_freqs_cis))

# Same table copied to the TPU; deets/deet print shape, dtype, device and the sum of
# absolute values, so the two lines printed here should agree.
jax_freqs_cis = torch_to_jax(torch_freqs_cis, jax_tpu())
print(deet(jax_freqs_cis))

(torch.Size([1024, 64]), torch.complex64, device(type='cpu'), np.float64(53830.88144241154))
((1024, 64), dtype('complex64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}, np.float64(53830.88144241154))


  x = t.to(dtype=torch.float64).detach().numpy()


In [24]:
# print("jax_transformer_params", jax.tree_util.tree_map(lambda t: (t.shape, t.dtype, t.devices()), jax_transformer_params))

def f():
  """Smoke-test a full-sequence forward pass and print the shapes of what comes back.

  NOTE(review): reads a global `jax_transformer_params`. The code that built it in this
  section is commented out, so this presumably only works if it is defined in an
  earlier cell. Confirm before re-enabling the call.
  """
  max_seq_len, max_gen_len = 400, 30
  system_msg = SystemMessage(content="This is a test sentence.")
  user_msg = UserMessage(content="This is a response.")
  model_input = chat_format.encode_dialog_prompt([system_msg, user_msg])
  print(model_input.tokens)

  jax_tokens = jax_gen_tokens(model_input, max_seq_len, max_gen_len)
  out, new_params = jax_transformer.apply(
      jax_transformer_params, tokens=jax_tokens, start_pos=0, mutable=['cache'])

  def summarize(t):
    return (t.shape, t.dtype, t.devices())

  print("out", jax.tree_util.tree_map(summarize, out))
  print("new_params keys", new_params.keys())
  print("new_params", jax.tree_util.tree_map(summarize, new_params))

# f()

# Chatbot

In [25]:
@jax.jit
def prefill_run(run_params, tokens, input_tokens_len):
  """Run the prefill model over a padded prompt and return its updated KV cache.

  Cache slots from the last prompt token onward are zeroed. They hold activations for
  padding, and the decode loop re-feeds the last prompt token on its first incremental
  step, rewriting that slot.

  Args:
    run_params: {'params': ..., 'cache': ...} variables for jax_transformer_prefill.apply.
    tokens: padded int32 prompt buffer.
    input_tokens_len: number of real (non-pad) prompt tokens; traced, so different
      prompt lengths do not trigger a retrace.

  Returns:
    The 'cache' collection. Every entry's ['attn']['keys_cache'/'values_cache']
    ['cached_value'] is zeroed at sequence positions >= input_tokens_len - 1.
  """
  _, new_cache = jax_transformer_prefill.apply(run_params, tokens=tokens, start_pos=0, mutable=['cache'])
  cache = new_cache['cache']

  def zero_tail(arr):
    # Sequence positions are on axis 1 of each cache array. Its length is read from the
    # array rather than hard-coded, so the mask always matches the cache size.
    stale = jnp.arange(arr.shape[1]) >= input_tokens_len - 1
    return jnp.where(stale[jnp.newaxis, :, jnp.newaxis, jnp.newaxis], 0, arr)

  for layer in cache.values():
    attn = layer['attn']
    for name in ('keys_cache', 'values_cache'):
      attn[name]['cached_value'] = zero_tail(attn[name]['cached_value'])
  return cache

def prefill(messages, device):
  """Encode a dialog, run it through the prefill model, and return the KV cache.

  The prompt is right-padded with tokenizer.pad_id into a fixed-length int32 buffer of
  shape (1, 256), so the jitted prefill_run always sees the same input shape.

  Args:
    messages: dialog messages accepted by chat_format.encode_dialog_prompt.
    device: JAX device used for the token buffer and the freshly initialised cache.

  Returns:
    The cache returned by prefill_run (slots at and after the last prompt token zeroed).

  Raises:
    ValueError: if the encoded prompt does not fit in the prefill buffer.

  Note: calls enable_pprint(False) and leaves pprint disabled on return.
  """
  model_input = chat_format.encode_dialog_prompt(messages)
  prompt_tokens = [model_input.tokens]
  input_tokens_len = len(model_input.tokens)
  # A constant buffer length means prefill_run is traced once, not once per prompt length.
  total_len = 256
  if input_tokens_len > total_len:
    # Without this check the .at[...].set below fails with an opaque broadcast error.
    raise ValueError(
        f"prompt is {input_tokens_len} tokens but the prefill buffer holds only {total_len}")

  with jax.default_device(device):
    tokens = jnp.full((len(prompt_tokens), total_len), tokenizer.pad_id, dtype=jnp.int32)
    for k, t in enumerate(prompt_tokens):
      tokens = tokens.at[k, :len(t)].set(jnp.array(t, dtype=jnp.int32))

  run_params = {'params': jax_transformer_params_tpu, 'cache': init_empty_cache(device)}
  enable_pprint(False)
  return prefill_run(run_params, tokens, input_tokens_len)

# Smoke-test the prefill path on a short two-message dialog; the returned cache is discarded.
prefill(
  messages=[
    SystemMessage(content="This is a test sentence."),
    UserMessage(content="This is a response."),
  ],
  device=jax_tpu(),
)

# Print dump_mem()'s summary of what is left in memory afterwards.
dump_mem()

JaxTransformer tokens.shape=(1, 256) start_pos=0 self.is_prefilling=True
tokens ((1, 256), dtype('int32'))
  freqs_cis freqs_cis.shape=(256, 64) start_pos=0 seqlen=256
  mask mask.shape=(256, 512) self.max_seq_len=512 seqlen=256
JaxTransformerBlock start_pos=0 freqs_cis.shape=(256, 64) mask.shape=(256, 512)
JaxAttention
JaxFeedForward
JaxTransformerBlock start_pos=0 freqs_cis.shape=(256, 64) mask.shape=(256, 512)
JaxAttention
JaxFeedForward
JaxTransformerBlock start_pos=0 freqs_cis.shape=(256, 64) mask.shape=(256, 512)
JaxAttention
JaxFeedForward
JaxTransformerBlock start_pos=0 freqs_cis.shape=(256, 64) mask.shape=(256, 512)
JaxAttention
JaxFeedForward
JaxTransformerBlock start_pos=0 freqs_cis.shape=(256, 64) mask.shape=(256, 512)
JaxAttention
JaxFeedForward
JaxTransformerBlock start_pos=0 freqs_cis.shape=(256, 64) mask.shape=(256, 512)
JaxAttention
JaxFeedForward
JaxTransformerBlock start_pos=0 freqs_cis.shape=(256, 64) mask.shape=(256, 512)
JaxAttention
JaxFeedForward
JaxTransformerB

In [26]:
# Memory summary from dump_mem() (presumably defined in an earlier cell); judging by the
# output, one line per group: name, array count, size in MB.
dump_mem()

None 1 0M
freqs_cis 1 0M
params 291 16060M


In [27]:
# Re-run the prefill smoke test to check that repeated calls do not grow memory use.
prefill(
  messages=[
    SystemMessage(content="This is a test sentence."),
    UserMessage(content="This is a response."),
  ],
  device=jax_tpu(),
)

# Print dump_mem()'s summary of what is left in memory afterwards.
dump_mem()

None 1 0M
freqs_cis 1 0M
params 291 16060M


In [28]:
# Memory summary from dump_mem() (presumably defined in an earlier cell); judging by the
# output, one line per group: name, array count, size in MB.
dump_mem()

None 1 0M
freqs_cis 1 0M
params 291 16060M


In [29]:
@jax.jit
def incremental_run(run_params, tokens, start_pos):
  """Run one incremental decode step and return (logits, updated cache).

  start_pos is a traced argument, so successive decode positions reuse a single
  compilation instead of retracing per step.
  """
  outputs = jax_transformer_incremental.apply(run_params, tokens, start_pos=start_pos, mutable=['cache'])
  logits, mutated = outputs
  return logits, mutated['cache']

def f():
  """Smoke-test one decode step.

  Prefills a short dialog, then feeds a single dummy token (id 0) at position 10 through
  incremental_run. The results are discarded.
  """
  dialog = [
    SystemMessage(content="This is a test sentence."),
    UserMessage(content="This is a response."),
  ]
  cache = prefill(messages=dialog, device=jax_tpu())
  step_params = {'params': jax_transformer_params_tpu, 'cache': cache}
  dummy_token = jnp.zeros((1, 1), dtype=jnp.int32)
  incremental_run(step_params, tokens=dummy_token, start_pos=10)

f()

JaxTransformer tokens.shape=(1, 1) start_pos=Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> self.is_prefilling=False
tokens ((1, 1), dtype('int32'))
  freqs_cis freqs_cis.shape=(1, 64) start_pos=Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> seqlen=1
JaxTransformerBlock start_pos=Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> freqs_cis.shape=(1, 64) mask.shape=None
JaxAttention
JaxFeedForward
JaxTransformerBlock start_pos=Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> freqs_cis.shape=(1, 64) mask.shape=None
JaxAttention
JaxFeedForward
JaxTransformerBlock start_pos=Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> freqs_cis.shape=(1, 64) mask.shape=None
JaxAttention
JaxFeedForward
JaxTransformerBlock start_pos=Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> freqs_cis.shape=(1, 64) mask.shape=None
JaxAt

In [30]:
@dataclass
class TokenResult:
    """One token produced by the generation loop.

    Attributes:
      token: vocabulary id of the generated token.
      text: the token decoded to text by the tokenizer.
      logprobs: per-position log-probabilities when requested, otherwise None.
    """

    token: int
    text: str
    logprobs: Optional[list[float]] = None

# def sample_top_p(probs, p):
#     """
#     Perform top-p (nucleus) sampling on a probability distribution.

#     Args:
#         probs (torch.Tensor): Probability distribution tensor.
#         p (float): Probability threshold for top-p sampling.

#     Returns:
#         torch.Tensor: Sampled token indices.

#     Note:
#         Top-p sampling selects the smallest set of tokens whose cumulative probability mass
#         exceeds the threshold p. The distribution is renormalized based on the selected tokens.
#     """
#     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
#     probs_sum = torch.cumsum(probs_sort, dim=-1)
#     mask = probs_sum - probs_sort > p
#     probs_sort[mask] = 0.0
#     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
#     next_token = torch.multinomial(probs_sort, num_samples=1)
#     next_token = torch.gather(probs_idx, -1, next_token)
#     return next_token

def run():
  """Greedy-decode a reply to a fixed two-message test dialog and print it.

  The prompt is prefilled into the KV cache with prefill(). Tokens are then generated
  one at a time with incremental_run until a stop token is produced or the buffer
  (min(max_gen_len + prompt length, max_seq_len)) is full.

  Returns:
    list[TokenResult]: one entry per generated position, including the stop token if
    one was produced.

  Raises:
    NotImplementedError: if sampling (temperature > 0) or logprobs are switched on.
      Neither is ported to JAX yet; previously these paths failed with
      UnboundLocalError / NameError.
  """
  enable_pprint(True)
  max_seq_len = 400
  max_gen_len = 30
  # Only greedy decoding is implemented (temperature == 0). The sampling settings used
  # elsewhere were temperature=0.7, top_p=0.9.
  temperature = 0
  logprobs = False
  if temperature > 0:
    raise NotImplementedError("temperature > 0 (top-p sampling) is not ported to JAX yet")
  if logprobs:
    raise NotImplementedError("logprobs are not ported to JAX yet")

  messages = [
    SystemMessage(
      content="This is a test sentence.",
    ),
    UserMessage(
      content="This is a response.",
    ),
  ]
  model_input = chat_format.encode_dialog_prompt(messages)
  # Same device as the prefill cache, so tokens and cache are co-located.
  device = jax_tpu()

  prompt_tokens = [model_input.tokens]
  bsz = len(prompt_tokens)
  pad_id = tokenizer.pad_id

  min_prompt_len = min(len(t) for t in prompt_tokens)
  max_prompt_len = max(len(t) for t in prompt_tokens)
  total_len = min(max_gen_len + max_prompt_len, max_seq_len)
  eos_reached = jnp.array([False] * bsz)

  with jax.default_device(device):
    tokens = jnp.full((bsz, total_len), pad_id, dtype=jnp.int32)
    for k, t in enumerate(prompt_tokens):
      tokens = tokens.at[k, :len(t)].set(jnp.array(t, dtype=jnp.int32))
    stop_tokens = jnp.array(tokenizer.stop_tokens)

  # True where the buffer holds prompt tokens; those positions are never overwritten.
  input_text_mask = tokens != pad_id
  out_tokens = []

  pprint("Filling cache from prompt...")
  enable_pprint(False)
  cache = prefill(messages, device=device)
  enable_pprint(True)
  # prefill zeroed the cache slot of the last prompt token, so decoding starts by
  # re-feeding that token.
  prev_pos = min_prompt_len - 1

  pprint("Generating tokens...")
  for cur_pos in range(min_prompt_len, total_len):
    enable_pprint(False)
    logits, cache = incremental_run(
        {'params': jax_transformer_params_tpu, 'cache': cache},
        tokens=tokens[:, prev_pos:cur_pos],
        start_pos=prev_pos,
    )
    enable_pprint(True)

    # Greedy decoding: take the highest-scoring token at the last position.
    next_token = jnp.argmax(logits[:, -1], axis=-1).astype(jnp.int32).reshape(-1)
    # only replace token if prompt has already been generated
    next_token = jnp.where(
        input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
    )
    tokens = tokens.at[:, cur_pos].set(next_token)

    eos_reached |= (~input_text_mask[:, cur_pos]) & (
        jnp.isin(next_token, stop_tokens)
    )
    out_tokens.append(TokenResult(
        token=int(next_token[0]),
        text=tokenizer.decode(next_token.tolist()),
        logprobs=None,
    ))

    prev_pos = cur_pos
    if bool(jnp.all(eos_reached)):
      break
  print("".join(t.text for t in out_tokens))
  return out_tokens

def f(times):
  """Run the end-to-end generation `times` times, printing a marker before each run."""
  for attempt in range(times):
    print("moo!", attempt)
    run()

f(1)


moo! 0
Filling cache from prompt...
Generating tokens...
It looks like we're having a conversation! What's next?<|eot_id|>


In [31]:
# Memory summary after generation from dump_mem() (presumably defined in an earlier cell);
# judging by the output, one line per group: name, array count, size in MB.
dump_mem()

None 1 0M
freqs_cis 1 0M
params 291 16060M


# Memory debugging

In [32]:
# jax.profiler.save_device_memory_profile("tpu-memory-prefill200fail.prof", "tpu")

In [33]:
# jax.profiler.save_device_memory_profile("cpu-memory.prof", "cpu")

In [34]:
# jax.clear_caches()

In [35]:
def f():
  """Print the index and the deetnosum summary of every live array on the TPU platform."""
  for position, live in enumerate(jax.live_arrays("tpu")):
    print(position, deetnosum(live))

f()


0 ((1024, 64), dtype('complex64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)})
1 ((1024, 64), dtype('complex64'), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)})
2 ((4096,), dtype(bfloat16), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)})
3 ((4096, 14336), dtype(bfloat16), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)})
4 ((14336, 4096), dtype(bfloat16), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)})
5 ((4096, 14336), dtype(bfloat16), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)})
6 ((4096,), dtype(bfloat16), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)})
7 ((4096, 4096), dtype(bfloat16), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)})
8 ((4096, 1024), dtype(bfloat16), {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)})
9 ((4096, 1024), dtype(bfloat16), {TpuDevice(id=0, process_index=0, coords=(0,0,0), cor