-
Notifications
You must be signed in to change notification settings - Fork 188
/
token.ex
45 lines (36 loc) · 1.04 KB
/
token.ex
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
defmodule Nx.Defn.Token do
  @moduledoc """
  The token that `defn` hooks are attached to.

  ## Documentation for compilers

  A token carries a single field, `hooks`, which is a list of maps.
  Each map has the following shape:

      %{
        expr: Nx.Tensor.t | Nx.Container.t,
        name: atom(),
        callback: (Nx.Tensor.t | Nx.Container.t -> term()) | nil
      }

  Only `defn` compilers are expected to read the `hooks` field.
  """

  # Each new hook is prepended, so the list runs from the most recently
  # added hook (head) to the earliest declared hook (last element).
  defstruct hooks: []

  # Builds a token with no hooks attached.
  @doc false
  def new, do: %__MODULE__{}

  # Returns `token` with one more hook placed at the head of `hooks`.
  #
  # `name` must be an atom and `callback` must be either `nil` or a
  # function (of any arity); any other input matches no clause and
  # raises `FunctionClauseError`. `expr` is stored as given.
  @doc false
  def add_hook(%__MODULE__{hooks: existing} = token, expr, name, callback)
      when is_atom(name) and (is_nil(callback) or is_function(callback)) do
    %{token | hooks: [%{expr: expr, name: name, callback: callback} | existing]}
  end

  defimpl Inspect do
    alias Inspect.Algebra

    # Renders the token as `#Nx.Defn.Token<[...]>`, listing only the hook
    # names in stored order (most recently added first).
    def inspect(%{hooks: hooks}, opts) do
      names = Enum.map(hooks, fn hook -> hook.name end)
      opening = Algebra.color("#Nx.Defn.Token<", :map, opts)
      closing = Algebra.color(">", :map, opts)

      Algebra.concat([opening, Algebra.to_doc(names, opts), closing])
    end
  end
end