In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import subprocess
import re


def shmoogle_smi():
    jax.profiler.save_device_memory_profile("memory.prof")
    pprof_path = "/usr/local/go/pkg/tool/linux_amd64/pprof"
    out = subprocess.run([pprof_path, "-top", "memory.prof"], stdout=subprocess.PIPE)
    stdout = out.stdout.decode("utf-8")
    re_sult = re.search(
        r"Showing nodes accounting for (\d+(?:\.\d+)?)([MG]B)?, (\d+(?:\.\d+)?)% of (\d+(?:\.\d+)?)([MG]B)? total",
        stdout,
    )
    multiplier = 1 / 1000 if re_sult.group(5) == "MB" else 1
    total_mem_usage = float(re_sult.group(4))
    print(
        "Total mem usage:",
        total_mem_usage * multiplier,
        "GB",
        "out of",
        16 * len(jax.devices()),
        "GB",
    )


shmoogle_smi()

Total mem usage: 0.0 GB out of 128 GB


In [10]:
@jax.jit
def u():
    x = [0]

    @jax.jit
    def a():
        b(x)
        return x

    @jax.jit
    def b(x):
        c(x)

    @jax.jit
    def c(x):
        x[0] = 1

    return a()


u()

[Array(0, dtype=int32, weak_type=True)]

In [3]:
from tokenizer import Tokenizer
from llama2_model import LLaMA

import equinox as eqx
import re
from safetensors import safe_open
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import jax.numpy as jnp
import transformers
import numpy as np
import jax
import jmp

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
tokenizer = Tokenizer("models/Llama-2-7b-hf/tokenizer.model")
input_ids = tokenizer.encode("Hello world!", bos=True, eos=False)
input_ids

[1, 15043, 3186, 29991]

In [5]:
num_devices = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(-1, 4), axis_names=("dp", "mp"))
policy = jmp.get_policy("p=bf16,c=bf16")
print("Creating LLaMA...")
llama = LLaMA(mesh, policy)
print("Created LLaMA.")
shmoogle_smi()

Creating LLaMA...
Created LLaMA.
Total mem usage: 30.0 GB out of 128 GB


Main binary filename not available.


In [None]:
print("Loading model...")
print()
for filename in [
    "models/Llama-2-7b-hf/model-00001-of-00002.safetensors",
    "models/Llama-2-7b-hf/model-00002-of-00002.safetensors",
]:
    with safe_open(
        filename,
        framework="numpy",
        device="cpu",
    ) as f:
        for k in f.keys():
            weight = f.get_tensor(k)
            if (
                k.endswith(".weight")
                and not k.endswith("embed_tokens.weight")
                and not k.endswith("norm.weight")
                # and not k.endswith("lm_head.weight")
            ):
                weight = weight.T
            re_sult = re.search(r"layers\.([0-9]+)", k)
            try:
                k = (
                    k[: re_sult.span()[0]]
                    + f"layers[{re_sult.group(1)}]"
                    + k[re_sult.span()[1] :]
                )
            except AttributeError:
                pass
            print("\r" + " " * 80, end="")
            print("\rLoading", k, end="")
            og = eval(f"llama.{k}")
            weight = jax.device_put(weight.astype(og.dtype), device=og.sharding)
            llama = eval(f"eqx.tree_at(lambda l: l.{k}, llama, weight)")
print()
shmoogle_smi()

In [51]:
from typing import Any
from types import MethodType
from functools import partial
from llama2_model import SelfAttention
from ast import (
    NodeTransformer,
    Expr,
    Return,
    Subscript,
    NamedExpr,
    Name,
    Store,
    Load,
    Constant,
    Index,
    FunctionDef,
    Call,
    Attribute,
    arg,
    arguments,
    Assign,
    Tuple,
    Dict,
    keyword,
    IfExp,
    Lambda,
    Compare,
    Is,
    Expression,
)
import inspect
import astor
import ast
import dis


class ReplaceFuncCall(NodeTransformer):
    def visit_Lambda(self, node: Lambda) -> Any:
        return super().visit_Lambda(node)

    def visit_Call(self, node):
        func = node.func  # Expression(value=node.func)
        store = Call(
            func=Name(id="getattr", ctx=Load()),
            args=[
                func,
                Constant(value="store", kind=None),
                Lambda(
                    args=arguments(
                        posonlyargs=[],
                        args=[
                            arg(arg="state", annotation=None, type_comment=None),
                            arg(arg="output", annotation=None, type_comment=None),
                        ],
                        vararg=None,
                        kwonlyargs=[],
                        kw_defaults=[],
                        kwarg=None,
                        defaults=[],
                    ),
                    body=Tuple(
                        elts=[
                            Name(id="output", ctx=Load()),
                            Name(id="state", ctx=Load()),
                        ],
                        ctx=Load(),
                    ),
                ),
            ],
            keywords=[],
        )
        under_slice = NamedExpr(
            target=Name(id="state", ctx=Store()),
            value=Call(
                func=store,
                args=[
                    Subscript(
                        value=Name(id="state", ctx=Load()),
                        slice=Index(value=Constant(value=-1, kind=None)),
                        ctx=Load(),
                    ),
                    Call(
                        func=func,
                        args=[self.visit(a) for a in node.args],
                        keywords=[self.visit(k) for k in node.keywords]
                        + [
                            keyword(
                                arg=None,
                                value=IfExp(
                                    test=Call(
                                        func=Name(id="hasattr", ctx=Load()),
                                        args=[
                                            func,
                                            Constant(value="store", kind=None),
                                        ],
                                        keywords=[],
                                    ),
                                    body=Dict(
                                        keys=[Constant(value="state", kind=None)],
                                        values=[
                                            Subscript(
                                                value=Name(id="state", ctx=Load()),
                                                slice=Index(
                                                    value=Constant(value=-1, kind=None)
                                                ),
                                                ctx=Load(),
                                            )
                                        ],
                                    ),
                                    orelse=Dict(keys=[], values=[]),
                                ),
                            )
                        ],
                    ),
                ],
                keywords=[],
            ),
        )
        return Subscript(
            value=under_slice,
            # value=Subscript(
            #     value=under_slice,
            #     slice=Index(value=Constant(value=0, kind=None)),
            #     ctx=Load(),
            # ),
            slice=Index(value=Constant(value=0, kind=None)),
            ctx=Load(),
        )

    def visit_FunctionDef(self, node):
        return FunctionDef(
            name=node.name,
            args=arguments(
                args=node.args.args,
                vararg=node.args.vararg,
                kwonlyargs=node.args.kwonlyargs
                + [arg(arg="state", annotation=None, type_comment=None)],
                kw_defaults=node.args.kw_defaults + [Constant(value=None, kind=None)],
                kwarg=node.args.kwarg,
                defaults=node.args.defaults,
                posonlyargs=node.args.posonlyargs,
            ),
            body=
            # [
            #     Assign(
            #         targets=[Name(id="state", ctx=Store())],
            #         value=Tuple(
            #             elts=[
            #                 Constant(value=None, kind=None),
            #                 Dict(keys=[], values=[]),
            #             ],
            #             ctx=Load(),
            #         ),
            #         type_comment=None,
            #     )
            # ]
            # +
            [
                Assign(
                    targets=[Name(id="state", ctx=Store())],
                    value=Tuple(
                        elts=[
                            Constant(value=None, kind=None),
                            Name(id="state", ctx=Load()),
                        ],
                        ctx=Load(),
                    ),
                    type_comment=None,
                )
            ]
            + [self.visit(b) for b in node.body],
            decorator_list=node.decorator_list,
            returns=node.returns,
            type_comment=node.type_comment,
        )

    def visit_Return(self, node: Return):
        val = self.visit(node.value)
        return Return(
            value=IfExp(
                test=Compare(
                    left=Name(id="state", ctx=Load()),
                    ops=[Is()],
                    comparators=[Constant(value=None, kind=None)],
                ),
                body=val,
                orelse=Tuple(
                    elts=[
                        val,
                        Subscript(
                            value=Name(id="state", ctx=Load()),
                            slice=Index(value=Constant(value=-1, kind=None)),
                            ctx=Load(),
                        ),
                    ],
                    ctx=Load(),
                ),
            )
        )


def add_state(func):
    source = inspect.getsource(func)
    if source.startswith(" "):
        source = "if True:\n" + source
    func_ast = ast.parse(source)
    func_ast = ReplaceFuncCall().visit(func_ast)
    func_ast = ast.fix_missing_locations(func_ast)
    # print(ast.dump(func_ast))
    fn = func.__code__.co_filename
    fn = "<ast>"
    code = compile(func_ast, filename=fn, mode="exec")
    env = func.__globals__.copy()
    print(astor.to_source(func_ast))
    # print(env.keys())
    exec(code, env)
    return env[func.__name__]


from stest import Obj

# add_state(SelfAttention.__call__)
obj = Obj()
# obj.f = MethodType(add_state(obj.f), obj)
obj.f = partial(add_state(obj.f), obj)
stateful = add_state(obj.u)
stateful = partial(stateful, obj)
stateful.store = lambda state, output: (
    output[0],
    {"u": output[0], **output[-1], **state},
)
obj.u = stateful
# obj.u.store = lambda state, output: (output, {"u": output, **state})
obj.f(2, state={})

if True:

    def f(self, x, *, state=None):
        state = None, state
        (state := getattr(print, 'store', lambda state, output: (output,
            state))(state[-1], print(state, **{'state': state[-1]} if
            hasattr(print, 'store') else {})))[0]
        x = (state := getattr(lambda a: a, 'store', lambda state, output: (
            output, state))(state[-1], lambda a: a(x, **{'state': state[-1]
            } if hasattr(lambda a: a, 'store') else {})))[0]
        (state := getattr(print, 'store', lambda state, output: (output,
            state))(state[-1], print(state, **{'state': state[-1]} if
            hasattr(print, 'store') else {})))[0]
        x = (state := getattr(lambda a: a, 'store', lambda state, output: (
            output, state))(state[-1], lambda a: a(x, **{'state': state[-1]
            } if hasattr(lambda a: a, 'store') else {})))[0]
        (state := getattr(print, 'store', lambda state, output: (output,
            state))(state[-1], print(state

(2, {'u': 2})

In [55]:
import jax.experimental.host_callback
from copy import deepcopy, copy
import dataclasses


try:
    jax.tree_util.register_pytree_node(
        jmp.Policy,
        lambda policy: (
            (policy.param_dtype, policy.compute_dtype, policy.output_dtype),
            None,
        ),
        lambda _, tup: jmp.Policy(*tup),
    )
except ValueError:
    pass
try:
    jax.tree_util.register_pytree_node(
        jax._src.numpy.lax_numpy._ScalarMeta, lambda x: (tuple(), x), lambda _, x: x
    )
except ValueError:
    pass


# short for multi-level map
def tree_mlm(fn, tree):
    if isinstance(tree, (int, float, bool, str, np.ndarray, jax.Array, jmp.Policy)):
        return tree
    leaves, treedef = jax.tree_util.tree_flatten_with_path(
        tree, is_leaf=lambda x: x is not tree
    )
    if not leaves:
        return tree
    if any(x is tree for x in leaves):
        return
    leaves = [fn(path, tree_mlm(fn, leaf)) for path, leaf in leaves]
    return jax.tree_util.tree_unflatten(treedef, leaves)


def debugify_model(model):
    def new_model(*args, **kwargs):
        refs = {}

        def add_ref(pv):
            refs[pv[0]] = pv[1]

        class DebugWrapper(eqx.Module):
            module: eqx.Module
            cb: object

            def __init__(self, module, cb):
                self.module = module
                self.cb = cb

            def __call__(self, *args, **kwargs):
                result = self.module(*args, **kwargs)
                jax.debug.callback(self.cb, result)
                return result

        def debugify(path, leaf):
            if not isinstance(leaf, eqx.Module):
                return leaf
            if not hasattr(leaf, "__call__"):
                return leaf

            path = "".join(map(str, path))

            # def cb(x):
            #     print("bruh")
            #     refs["".join(map(str, path))] = x

            # return DebugWrapper(leaf, cb)
            call = leaf.__call__
            # if hasattr(call, "_orig_fn"):
            # leaf = eqx.tree_at(lambda l: l.__call__, leaf, call._orig_fn)
            call_ = partial(add_state(call), leaf)
            # call_._orig_fn = call
            call = call_

            call.store = lambda state, output: (
                output[0],
                {path: output[0], **output[-1], **state},
            )

            # leaf.__call__ = call
            class MyDebugWrapper(eqx.Module):
                child: eqx.Module
                call: object

                def __init__(self, child, call):
                    self.child = child
                    self.call = call

                def __call__(self, *args, **kwargs):
                    print("called", path)
                    print(kwargs["state"])
                    ret = call(*args, **kwargs)
                    print(ret)
                    return ret

                def __hasattr__(self, attr):
                    return hasattr(self.child, attr)

                def __getattr__(self, attr):
                    return getattr(self.child, attr)

            return MyDebugWrapper(leaf, call)
            # # return new_dataclass(**{k: getattr(leaf, k) for k in fields}, __call__=call)

        return tree_mlm(debugify, model)(*args, **kwargs), dict(
            [l for ref in refs for l in ref]
        )

    return new_model


llama_debug = jax.vmap(debugify_model(llama))
shmoogle_smi()

with mesh:
    input_ids = [
        tokenizer.encode(x, bos=True, eos=False)
        for x in ["Hello world", "This is a test"]
    ]
    input_ids = [x + [0] * (128 - len(x)) for x in input_ids]
    ids = jnp.asarray(input_ids)
    ids = jax.device_put(ids, NamedSharding(mesh, spec=PartitionSpec("dp", None)))
    result = llama_debug(ids)

Main binary filename not available.


Total mem usage: 30.0 GB out of 128 GB
if True:

    def __call__(self, x, *, state=None):
        state = None, state
        x = (state := getattr(self.policy.cast_to_compute, 'store', lambda
            state, output: (output, state))(state[-1], self.policy.
            cast_to_compute(x, **{'state': state[-1]} if hasattr(self.
            policy.cast_to_compute, 'store') else {})))[0]
        return x @ self.weight if state is None else (x @ self.weight,
            state[-1])

if True:

    def __call__(self, x, *, state=None):
        state = None, state
        return self.weight[x] if state is None else (self.weight[x], state[-1])

if True:

    def __call__(self, x, *, state=None):
        state = None, state
        orig_dtype = x.dtype
        x = (state := getattr(x.astype, 'store', lambda state, output: (
            output, state))(state[-1], x.astype(jnp.float32, **{'state':
            state[-1]} if hasattr(x.astype, 'store') else {})))[0]
        rms = (state := getatt

KeyError: 'state'

In [None]:
reference_llama = transformers.LlamaModel.from_pretrained("models/Llama-2-7b-hf")

In [None]:
import torch


def make_torch_hook(name):
    def torch_hook(module, input, output):
        print(f"Module called: {name} {module.__class__}")
        for arg in input:
            if isinstance(arg, torch.Tensor):
                print("Input:", arg.shape, arg.dtype, str(arg)[:100])
        if isinstance(output, tuple):
            output = output[0]
        if isinstance(output, torch.Tensor):
            print("Output:", output.shape, output.dtype, str(output)[:100])
        return output

    return torch_hook


for name, module in reference_llama.named_modules():
    module._forward_hooks.clear()
    module.register_forward_hook(make_torch_hook(name))


input_ids = torch.tensor(input_ids)
reference_llama(input_ids)