In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import subprocess
import re


def shmoogle_smi():
    jax.profiler.save_device_memory_profile("memory.prof")
    pprof_path = "/usr/local/go/pkg/tool/linux_amd64/pprof"
    out = subprocess.run([pprof_path, "-top", "memory.prof"], stdout=subprocess.PIPE)
    stdout = out.stdout.decode("utf-8")
    re_sult = re.search(
        r"Showing nodes accounting for (\d+(?:\.\d+)?)([MG]B)?, (\d+(?:\.\d+)?)% of (\d+(?:\.\d+)?)([MG]B)? total",
        stdout,
    )
    multiplier = 1 / 1000 if re_sult.group(5) == "MB" else 1
    total_mem_usage = float(re_sult.group(4))
    print(
        "Total mem usage:",
        total_mem_usage * multiplier,
        "GB",
        "out of",
        16 * len(jax.devices()),
        "GB",
    )


shmoogle_smi()

Total mem usage: 0.0 GB out of 128 GB


In [3]:
from tokenizer import Tokenizer
from llama2_model import LLaMA

import equinox as eqx
import re
from safetensors import safe_open
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import jax.numpy as jnp
import transformers
import numpy as np
import jax
import jmp

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
tokenizer = Tokenizer("models/Llama-2-7b-hf/tokenizer.model")
input_ids = tokenizer.encode("Hello world!", bos=True, eos=False)
input_ids

[1, 15043, 3186, 29991]

In [5]:
num_devices = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(-1, 4), axis_names=("dp", "mp"))
policy = jmp.get_policy("p=bf16,c=bf16")
print("Creating LLaMA...")
llama = LLaMA(mesh, policy)
print("Created LLaMA.")
shmoogle_smi()

Creating LLaMA...
Created LLaMA.
Total mem usage: 29.95 GB out of 128 GB


Main binary filename not available.


In [6]:
print("Loading model...")
print()
for filename in [
    "models/Llama-2-7b-hf/model-00001-of-00002.safetensors",
    "models/Llama-2-7b-hf/model-00002-of-00002.safetensors",
]:
    with safe_open(
        filename,
        framework="numpy",
        device="cpu",
    ) as f:
        for k in [k for k in f.keys() if "embed" in k]:  # f.keys():
            weight = f.get_tensor(k)
            if (
                k.endswith(".weight")
                and not k.endswith("embed_tokens.weight")
                and not k.endswith("norm.weight")
            ):
                weight = weight.T
            re_sult = re.search(r"layers\.([0-9]+)", k)
            try:
                k = (
                    k[: re_sult.span()[0]]
                    + f"layers[{re_sult.group(1)}]"
                    + k[re_sult.span()[1] :]
                )
            except AttributeError:
                pass
            print("\r" + " " * 80, end="")
            print("\rLoading", k, end="")
            og = eval(f"llama.{k}")
            weight = jax.device_put(weight.astype(og.dtype), device=og.sharding)
            llama = eval(f"eqx.tree_at(lambda l: l.{k}, llama, weight)")
print()
shmoogle_smi()

Loading model...

Loading model.embed_tokens.weight                                               
Total mem usage: 33.85 GB out of 128 GB


Main binary filename not available.


In [10]:
from llama2_model import (
    Attention,
    RotaryEmbedding,
    SelfAttention,
    MLP,
    LayerNorm,
    Embedding,
)


class DebugWrapper(eqx.Module):
    name: str
    module: eqx.Module

    def __init__(self, name, module):
        self.name = name
        self.module = module

    def __call__(self, *args, **kwargs):
        jax.debug.print(f"Module called: {self.name} ({type(self.module)})")
        for arg in args:
            if isinstance(arg, jax.Array):
                print("Input:", arg.shape, arg.dtype, str(arg)[:100])
        result = self.module(*args, **kwargs)
        jax.debug.print(
            f"Output: {(result.shape, result.dtype) if isinstance(result, jax.Array) else ''} {str(result)[:100]}"
        )
        return result


debug_type = (SelfAttention, MLP, LayerNorm, Embedding)


def is_leaf(mod):
    leaf = isinstance(
        mod,
        debug_type,
    )
    return leaf


def debugify(prefix, x):
    if not isinstance(x, debug_type):
        return x
    return DebugWrapper("".join(map(str, prefix)), x)


llama_debug = jax.vmap(
    jax.tree_util.tree_map_with_path(
        debugify,
        llama,
        is_leaf=is_leaf,
    )
)
shmoogle_smi()

with mesh:
    input_ids = [
        tokenizer.encode(x, bos=True, eos=False)
        for x in ["Hello world", "This is a test"]
    ]
    input_ids = [x + [0] * (128 - len(x)) for x in input_ids]
    ids = jnp.asarray(input_ids)
    ids = jax.device_put(ids, NamedSharding(mesh, spec=PartitionSpec("dp", None)))
    result = llama_debug(ids)

Main binary filename not available.


Total mem usage: 34.803580000000004 GB out of 128 GB
Module called: .model.embed_tokens (<class 'llama2_model.Embedding'>)
Input: () int32 Traced<ShapedArray(int32[])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(int32[128])>
Output: ((4096,), dtype('float32')) Traced<ShapedArray(float32[4096])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(float3
Module called: .model.layers[0].input_layernorm (<class 'llama2_model.LayerNorm'>)
Input: (4096,) bfloat16 Traced<ShapedArray(bfloat16[4096])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(bfloa
Output: ((4096,), dtype(bfloat16)) Traced<ShapedArray(bfloat16[4096])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(bfloa
Module called: .model.layers[0].self_attn (<class 'llama2_model.SelfAttention'>)
Input: (128, 4096) bfloat16 Traced<ShapedArray(bfloat16[128,4096])>with<BatchTrace(level=1/0)> with
  val = Array([[[0, -0, 0, .


KeyboardInterrupt: 

In [None]:
reference_llama = transformers.LlamaModel.from_pretrained("models/Llama-2-7b-hf")

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.66s/it]


In [None]:
import torch


def make_torch_hook(name):
    def torch_hook(module, input, output):
        print(f"Module called: {name} {module.__class__}")
        for arg in input:
            if isinstance(arg, torch.Tensor):
                print("Input:", arg.shape, arg.dtype, str(arg)[:100])
        if isinstance(output, tuple):
            output = output[0]
        if isinstance(output, torch.Tensor):
            print("Output:", output.shape, output.dtype, str(output)[:100])
        return output

    return torch_hook


for name, module in reference_llama.named_modules():
    module._forward_hooks.clear()
    module.register_forward_hook(make_torch_hook(name))


input_ids = torch.tensor(input_ids)
reference_llama(input_ids)

Module called: embed_tokens <class 'torch.nn.modules.sparse.Embedding'>
Input: torch.Size([2, 128]) torch.int64 tensor([[    1, 15043,  3186,     0,     0,     0,     0,     0,     0,     0,
             0,     0
Output: torch.Size([2, 128, 4096]) torch.float32 tensor([[[ 1.8616e-03, -3.3722e-03,  3.9864e-04,  ..., -8.3008e-03,
           2.5787e-03, -3.9368e-
Module called: layers.0.input_layernorm <class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>
Input: torch.Size([2, 128, 4096]) torch.float32 tensor([[[ 1.8616e-03, -3.3722e-03,  3.9864e-04,  ..., -8.3008e-03,
           2.5787e-03, -3.9368e-
Output: torch.Size([2, 128, 4096]) torch.float32 tensor([[[ 6.7356e-03, -5.5986e-03,  9.5712e-05,  ..., -1.0382e-02,
           3.4557e-03, -2.9162e-
Module called: layers.0.self_attn.q_proj <class 'torch.nn.modules.linear.Linear'>
Input: torch.Size([2, 128, 4096]) torch.float32 tensor([[[ 6.7356e-03, -5.5986e-03,  9.5712e-05,  ..., -1.0382e-02,
           3.4557e-03, -2.9162e-
Out

  input_ids = torch.tensor(input_ids)


ValueError: too many values to unpack (expected 2)

In [None]:
from inspect import getsource

print(getsource(layer.self_attn.forward))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split