In [1]:
from nnsight import LanguageModel
from nnsight.envoy import Envoy
from nnsight.util import WrapperModule
import torch
from typing import List, Tuple, Dict, Any

from transformers.utils import fx as tfx
import torch.fx as fx
from torch.fx import replace_pattern

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Load the Pythia-70M checkpoint through nnsight, targeting the first CUDA device.
model = LanguageModel(
    "EleutherAI/pythia-70m",
    device_map="cuda:0",
    dispatch=True,
)

# Run a minimal one-token prompt and save the input that reaches the MLP of
# layer 3, so a sample activation is available after the trace exits.
# (An earlier variant addressed the block as `model.transformer.h[3].mlp`.)
with model.trace("a"):
    sample_input = model.gpt_neox.layers[3].mlp.input.save()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [13]:
# Layer 3's attention block, reached two ways:
#   * `attn`       - the object under `model._model`; this is what gets handed
#                    to the fx tracers further down.
#   * `attn_envoy` - the matching object under `model._envoy`; its
#                    `_fake_inputs` are read later to build the argument whitelist.
attn = model._model.gpt_neox.layers[3].attention
attn_envoy = model._envoy.gpt_neox.layers[3].attention

In [14]:
# Stock symbolic tracing of the attention block fails with a TraceError
# ("Proxy object cannot be iterated"). That failure is the motivation for the
# custom tracer below, so catch it and display the message instead of letting
# it abort a top-to-bottom run of the notebook.
try:
    symbolic_traced = fx.symbolic_trace(attn)
except fx.proxy.TraceError as trace_error:
    symbolic_traced = None  # nothing was traced; keep the name defined
    print(f"TraceError: {trace_error}")

['hidden_states', 'attention_mask', 'position_ids', 'head_mask', 'layer_past', 'use_cache', 'output_attentions', 'padding_mask']
hidden_states
attention_mask
position_ids
head_mask
layer_past
use_cache
output_attentions
padding_mask


TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

In [7]:
# Inspect the attention module's `is_cross_attention` attribute; the value is
# shown as the cell output.
attn.is_cross_attention

False

In [None]:
# Layer 3's MLP, addressed through `gpt_neox.layers` like every other lookup
# on this model in the notebook (the former `transformer.h[3]` path does not
# match how the loaded Pythia model is accessed elsewhere).
mlp = model._model.gpt_neox.layers[3].mlp
mlp_envoy = model._envoy.gpt_neox.layers[3].mlp

# Symbolically trace the MLP with the stock fx tracer and show the captured graph.
symbolic_traced: fx.GraphModule = fx.symbolic_trace(mlp)
print(symbolic_traced.graph)

In [None]:
from typing import Any, Optional
from torch._C import ScriptObject 
import inspect
import torch.utils._pytree as pytree

HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

_proxyable_classes: Dict[Type, None] = {}

test_proxy = 0

# Whitelist of root-argument names that the custom tracer keeps symbolic:
# every key of the mapping at `attn_envoy._fake_inputs[0][1]`, plus
# "hidden_states".
# NOTE(review): presumably `_fake_inputs[0]` is an (args, kwargs) pair recorded
# by nnsight for the first traced call - confirm against the Envoy implementation.
correct_args = list(attn_envoy._fake_inputs[0][1].keys()) + ["hidden_states"]

class MyCustomTracer(torch.fx.Tracer):
    """``torch.fx.Tracer`` that calls the traced root with only a whitelisted
    subset of its arguments.

    Placeholder nodes are still created for every parameter of the root, but
    any placeholder whose node name is not whitelisted is replaced by ``None``
    in the argument list handed to the root function. Branches of the root
    that test ``arg is None`` are therefore traced along their "argument not
    supplied" path.

    Args:
        allowed_arg_names: Names of the root arguments that stay symbolic.
            When ``None`` (the default) the module-level ``correct_args``
            list is looked up each time a trace starts.

    Attributes:
        allowed_arg_names: The whitelist passed to the constructor (or ``None``).
        test_proxy: The first non-root argument produced by the most recent
            ``create_args_for_root`` call, or ``None`` if there was none.
    """

    def __init__(self, allowed_arg_names=None):
        super().__init__()
        self.allowed_arg_names = allowed_arg_names
        self.test_proxy = None

    def create_args_for_root(self, root_fn, is_module, concrete_args=None):
        """Create the root's placeholders, then null out non-whitelisted ones.

        Placeholder creation (signature introspection, ``*args``/``**kwargs``
        support, ``concrete_args`` specialisation) is delegated to
        ``torch.fx.Tracer.create_args_for_root`` rather than re-implemented
        here, so no private fx helpers have to be imported into the notebook.

        Args:
            root_fn: The callable being traced (the module's ``forward`` when
                ``is_module`` is true).
            is_module: Whether the root is an ``nn.Module``; if so the first
                argument is the module itself and is never filtered.
            concrete_args: Optional specialisation values, forwarded unchanged
                to the base implementation.

        Returns:
            ``(root_fn, args)`` where ``args`` is the list to call ``root_fn``
            with: the root module (if any), whitelisted proxies, ``None`` for
            every other proxy, and concrete (non-Proxy) values untouched.

        Side effects:
            Prints progress information, stores the first non-root argument on
            ``self.test_proxy`` and publishes it to the module-level
            ``test_proxy`` variable so later cells can inspect it.
        """
        # Without this declaration the assignment below would only bind a
        # local name and the module-level `test_proxy` would never change.
        global test_proxy

        print("Called create_args_for_root")

        root_fn, args = super().create_args_for_root(root_fn, is_module, concrete_args)

        # When pytree-structured `concrete_args` were supplied, the base
        # tracer installs a code generator carrying `pytree_info` and returns
        # a flattening wrapper plus flat arguments. Those are passed through
        # unfiltered.
        codegen = getattr(self.graph, "_codegen", None)
        if getattr(codegen, "pytree_info", None) is not None:
            return root_fn, args

        allowed = self.allowed_arg_names
        if allowed is None:
            # Resolved lazily so the notebook-level whitelist can be
            # (re)defined after the tracer class itself.
            allowed = correct_args

        args = list(args)
        # Only a module root contributes a leading `self` entry to keep as is.
        leading = 1 if is_module else 0
        new_args = args[:leading]
        candidates = args[leading:]

        if candidates:
            self.test_proxy = candidates[0]
            test_proxy = candidates[0]

        for arg in candidates:
            if not isinstance(arg, torch.fx.Proxy):
                # Concrete value coming from `concrete_args`; it has no node
                # to inspect, so hand it to the root unchanged.
                new_args.append(arg)
                continue

            name = arg.node.name
            if name in allowed:
                new_args.append(arg)
                print(name)
            else:
                new_args.append(None)

        print(new_args)

        return root_fn, new_args

    def create_arg(self, a: Any) -> "Argument":
        """Convert ``a`` into a graph ``Argument``, logging which branch handled it.

        Handled here, in order:

        #. ``torch.nn.Parameter`` - emit a ``get_attr`` node for the matching
           entry of ``self.root.named_parameters()``; raise ``NameError`` if
           the parameter does not belong to the root.
        #. ``torch.Tensor`` - emit ``get_attr`` if it is one of the root's
           named buffers.
        #. ``torch.nn.Module`` - emit ``get_attr`` if it is one of the root's
           named modules.
        #. ``NamedTuple`` instances - emit a ``call_function`` node that
           rebuilds the tuple from its (recursively converted) fields.
        #. Any remaining ``torch.Tensor`` / ``ScriptObject`` - look it up in
           ``self.tensor_attrs``; if unknown, store it on the root under the
           first free ``_tensor_constant{i}`` attribute and emit ``get_attr``.

        Everything else is forwarded to the parent class.

        Args:
            a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.

        Returns:
            The value ``a`` converted into the appropriate ``Argument``.

        Raises:
            NameError: If ``a`` is a ``Parameter`` that is not a member of the
                root module.
        """
        if isinstance(a, torch.nn.Parameter):
            print("Conditional type: torch.nn.Parameter")
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})
            raise NameError("parameter is not a member of this module")

        elif isinstance(a, torch.Tensor):
            print("Conditional type: torch.Tensor")
            for n_, p_ in self.root.named_buffers():
                if a is p_:
                    return self.create_node("get_attr", n_, (), {})
            # Not a registered buffer: fall through to the constant-tensor
            # handling below.
        elif isinstance(a, torch.nn.Module):
            print("Conditional type: torch.nn.Module")
            for n_, p_ in self.root.named_modules():
                if a is p_:
                    return self.create_node("get_attr", n_, (), {})

        # For NamedTuple instances that appear literally as args, emit a node
        # that constructs the NamedTuple and use that node as the argument.
        if isinstance(a, tuple) and hasattr(a, "_fields"):
            print("Conditional type: NamedTuple")
            args = tuple(self.create_arg(elem) for elem in a)
            return self.create_node("call_function", a.__class__, args, {})

        # Constant tensors have no reliable repr() to rebuild them from, so
        # they are referenced through an attribute on the root module: reuse
        # the qualified name already recorded for this tensor, or stash the
        # tensor under a fresh `_tensor_constant{i}` attribute.
        if isinstance(a, (torch.Tensor, ScriptObject)):
            qualname: Optional[str] = self.tensor_attrs.get(a)

            if not qualname:
                i = 0
                while True:
                    qualname = f"_tensor_constant{i}"
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                self.tensor_attrs[a] = qualname
                setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        # The `_proxyable_classes` interning branch of the reference
        # implementation is intentionally left disabled in this override.
        return super().create_arg(a)


# Run the custom tracer over the attention block. `trace()` returns a bare
# `fx.Graph`, which is not runnable on its own.
traced_graph = MyCustomTracer().trace(attn)
# To execute the result, wrap the graph together with its root module in a
# GraphModule (left disabled for now):
# traced = torch.fx.GraphModule(attn, traced_graph)

In [None]:
# Probe: build a 0-dimensional tensor (empty size list) whose fill value is
# whatever the module-level `test_proxy` currently holds.
torch.full([], test_proxy)
