In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import jax
import subprocess
import re


def shmoogle_smi():
    jax.profiler.save_device_memory_profile("memory.prof")
    pprof_path = "/usr/local/go/pkg/tool/linux_amd64/pprof"
    out = subprocess.run([pprof_path, "-top", "memory.prof"], stdout=subprocess.PIPE)
    stdout = out.stdout.decode("utf-8")
    re_sult = re.search(
        r"Showing nodes accounting for (\d+(?:\.\d+)?)([MG]B)?, (\d+(?:\.\d+)?)% of (\d+(?:\.\d+)?)([MG]B)? total",
        stdout,
    )
    multiplier = 1 / 1000 if re_sult.group(5) == "MB" else 1
    total_mem_usage = float(re_sult.group(4))
    print(
        "Total mem usage:",
        total_mem_usage * multiplier,
        "GB",
    )


shmoogle_smi()

Total mem usage: 0.0 GB


In [4]:
from tokenizer import Tokenizer
from llama2_model import LLaMA

from oryx.core import harvest
import equinox as eqx
import re
from safetensors import safe_open
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import jax.numpy as jnp
import transformers
import numpy as np
import jax
import jmp

  from .autonotebook import tqdm as notebook_tqdm
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [5]:
tokenizer = Tokenizer("models/Llama-2-7b-hf/tokenizer.model")
input_ids = tokenizer.encode("Hello world!", bos=True, eos=False)
input_ids

[1, 15043, 3186, 29991]

In [6]:
num_devices = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(-1, 4), axis_names=("dp", "mp"))
policy = jmp.get_policy("p=bf16,c=bf16,o=bf16")
print("Creating LLaMA...")
llama = LLaMA(mesh, policy=policy)
print("Created LLaMA.")
shmoogle_smi()

Creating LLaMA...
Created LLaMA.
Total mem usage: 15.0 GB


Main binary filename not available.


In [8]:
from tqdm import tqdm
print("Loading model...")
print()
for filename in [
    "models/Llama-2-7b-hf/model-00001-of-00002.safetensors",
    "models/Llama-2-7b-hf/model-00002-of-00002.safetensors",
]:
    with safe_open(
        filename,
        framework="numpy",
        device="cpu",
    ) as f:
        for k in (bar := tqdm(f.keys())):
            weight = f.get_tensor(k)
            if (
                k.endswith(".weight")
                and not k.endswith("embed_tokens.weight")
                and not k.endswith("norm.weight")
                # and not k.endswith("lm_head.weight")
            ):
                weight = weight.T
            re_sult = re.search(r"layers\.([0-9]+)", k)
            try:
                k = (
                    k[: re_sult.span()[0]]
                    + f"layers[{re_sult.group(1)}]"
                    + k[re_sult.span()[1] :]
                )
            except AttributeError:
                pass
            # print("\r" + " " * 80, end="")
            # print("\rLoading", k, end="")
            bar.set_description(f"Loading {k}")
            og = eval(f"llama.{k}")
            weight = jax.device_put(weight.astype(og.dtype), device=og.sharding)
            llama = eval(f"eqx.tree_at(lambda l: l.{k}, llama, weight)")
print()
shmoogle_smi()

Loading model...



Loading model.layers[9].self_attn.v_proj.weight: 100%|██████████| 241/241 [00:21<00:00, 11.08it/s]         
Loading model.norm.weight: 100%|██████████| 82/82 [00:07<00:00, 10.99it/s]                               



Total mem usage: 15.0 GB


Main binary filename not available.


In [22]:
llama_exec = jax.vmap(llama)
llama_debug = harvest(llama, tag="activations")
# llama_debug = jax.vmap(lambda x: harvest(llama, tag="activations")({}, x)[1])
# shmoogle_smi()

with mesh:
    input_ids = [
        tokenizer.encode(x, bos=True, eos=False)
        # for x in ["Hello world", "This is a test"]
        for x in ["Hello world, this is a test sentence. I am saying something."]
    ]
    input_ids = [x + [0] * (128 - len(x)) for x in input_ids]
    ids = jnp.asarray(input_ids)
    ids = jax.device_put(ids, NamedSharding(mesh, spec=PartitionSpec("dp", None)))
    result = llama_exec(ids)
    lp = jax.nn.log_softmax(result, axis=-1)
    loss = -(jnp.take_along_axis(lp[:, :-1], ids[:, 1:, None], 2)[:, :, 0] * (ids[:, 1:] != 0)).sum() / (ids != 0).sum()
    print(loss)
    activations = llama_debug({}, ids[0])

11


ValueError: Variable has already been reaped: post_attn

In [None]:
reference_llama = transformers.LlamaModel.from_pretrained("models/Llama-2-7b-hf")

KeyboardInterrupt: 

In [None]:
import torch


def make_torch_hook(name):
    def torch_hook(module, input, output):
        print(f"Module called: {name} {module.__class__}")
        for arg in input:
            if isinstance(arg, torch.Tensor):
                print("Input:", arg.shape, arg.dtype, str(arg)[:100])
        if isinstance(output, tuple):
            output = output[0]
        if isinstance(output, torch.Tensor):
            print("Output:", output.shape, output.dtype, str(output)[:100])
        return output

    return torch_hook


for name, module in reference_llama.named_modules():
    module._forward_hooks.clear()
    module.register_forward_hook(make_torch_hook(name))


input_ids = torch.tensor(input_ids)
reference_llama(input_ids)