/
composite.ex
165 lines (126 loc) · 5.01 KB
/
composite.ex
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
defmodule Nx.Defn.Composite do
  @moduledoc """
  Functions to deal with composite data types according to `Nx.Container`.

  The functions in this module can be used both inside and outside `defn`.
  Note that the functions in this module traverse tensors, but they do not
  automatically convert values to tensors. For example, the tuple `{1, 2, 3}`
  once traversed will emit the numbers `1`, `2`, and `3`. If desired,
  you can invoke `Nx.to_tensor/1` to normalize them.

  Note that, when a value is given to `defn`, it is first converted to
  tensors and containers via `Nx.LazyContainer`. Inside `defn`, there are
  no lazy containers, only containers.
  """

  alias Nx.Tensor, as: T
  import Nx, only: [is_tensor: 1]

  @doc """
  Traverses two composite types to see if they are compatible.

  Tuples are compatible when they have the same size and their elements
  are pairwise compatible. Plain maps are compatible when they have the
  same keys and their values are pairwise compatible. Structs are only
  compatible with structs of the same module that have the same number
  of children, pairwise compatible. A struct is never compatible with a
  plain map, and values of different shapes (e.g. a tuple and a map) are
  never compatible.

  For non-composite types, the given `fun` will be called to
  compare numbers/tensors pairwise, and its result is returned.
  """
  def compatible?(left, right, fun)
      when is_tensor(left) and is_tensor(right),
      do: fun.(left, right)

  def compatible?(left, right, fun) when tuple_size(left) == tuple_size(right) do
    Tuple.to_list(left)
    |> Enum.zip(Tuple.to_list(right))
    |> Enum.all?(fn {l, r} -> compatible?(l, r, fun) end)
  end

  def compatible?(%mod{} = left, %mod{} = right, fun) do
    # LazyContainer is fully recursive but we don't want to go full recursive
    # unless we have to, so we can also compare structures along the way.
    {left, right} =
      case Nx.LazyContainer.impl_for(left) do
        Nx.LazyContainer.Any ->
          left = Nx.Container.reduce(left, [], &[&1 | &2])
          right = Nx.Container.reduce(right, [], &[&1 | &2])
          {left, right}

        impl ->
          {_, left} =
            impl.traverse(left, [], fn template, _fun, acc -> {template, [template | acc]} end)

          {_, right} =
            impl.traverse(right, [], fn template, _fun, acc -> {template, [template | acc]} end)

          {left, right}
      end

    # Enum.zip/2 stops at the shorter list, so containers with a different
    # number of children must be rejected explicitly rather than compared
    # on a truncated prefix.
    length(left) == length(right) and
      Enum.all?(Enum.zip(left, right), fn {l, r} -> compatible?(l, r, fun) end)
  end

  # Structs of different modules, or a struct against a non-struct, are never
  # compatible. Checking either side here keeps structs out of the plain-map
  # clause below, which would otherwise enumerate a struct via Enum.all?/2.
  def compatible?(left, right, _fun) when is_struct(left) or is_struct(right),
    do: false

  def compatible?(left, right, fun) when map_size(left) == map_size(right) do
    Enum.all?(left, fn {key, left_value} ->
      case right do
        %{^key => right_value} -> compatible?(left_value, right_value, fun)
        %{} -> false
      end
    end)
  end

  def compatible?(_, _, _),
    do: false

  @doc """
  Counts the number of non-composite types in the composite type.

  Numbers, `Complex` numbers and tensors each count as one; containers
  are traversed recursively via `Nx.Container`.

  ## Examples

      iex> Nx.Defn.Composite.count(123)
      1
      iex> Nx.Defn.Composite.count({1, {2, 3}})
      3
      iex> Nx.Defn.Composite.count({Complex.new(1), {Nx.tensor(2), 3}})
      3

  """
  def count(tree), do: count(tree, 0)

  defp count(tensor, acc) when is_tensor(tensor), do: acc + 1
  defp count(container, acc), do: Nx.Container.reduce(container, acc, &count/2)

  @doc """
  Traverses recursively the given composite types with `fun`.

  If a composite type is given, such as a tuple, the composite
  type is recursively traversed and returned with its non-composite
  elements replaced by the results of `fun`.

  Otherwise the function is invoked with the tensor (be it a
  number, complex, or actual tensor).
  """
  def traverse(expr, fun) when is_function(fun, 1) do
    {result, []} = traverse(expr, [], fn expr, [] -> {fun.(expr), []} end)
    result
  end

  @doc """
  Traverses recursively the given composite types with `acc` and `fun`.

  If a composite type is given, such as a tuple, the composite
  type is recursively traversed and returned, together with the
  final accumulator, as `{composite, acc}`.

  Otherwise the function is invoked with the tensor (be it a
  number, complex, or actual tensor) and the accumulator, and
  must return `{new_tensor, new_acc}`.
  """
  def traverse(expr, acc, fun) when is_tensor(expr) and is_function(fun, 2),
    do: fun.(expr, acc)

  def traverse(container, acc, fun),
    do: Nx.Container.traverse(container, acc, &traverse(&1, &2, fun))

  @doc """
  Reduces recursively the given composite types with `acc` and `fun`.

  If a composite type is given, such as a tuple, it is recursively
  traversed and `fun` is invoked with each non-composite element and
  the current accumulator. The final accumulator is returned.

  If a non-composite tensor expression is given, the function
  is invoked for it but not for its arguments.
  """
  def reduce(expr, acc, fun) when is_tensor(expr) and is_function(fun, 2),
    do: fun.(expr, acc)

  def reduce(container, acc, fun),
    do: Nx.Container.reduce(container, acc, &reduce(&1, &2, fun))

  @doc """
  Flattens recursively the given list of composite types.

  The non-composite elements are returned in traversal order. Elements
  that are not tensors (i.e. numbers and `Complex` numbers) are kept as
  is; they are not converted to tensors. The optional `tail` is appended,
  unchanged, after the flattened elements.

  ## Examples

      iex> Nx.Defn.Composite.flatten_list([1, {2, 3}])
      [1, 2, 3]

      iex> Nx.Defn.Composite.flatten_list([1, {2, 3}], [Nx.tensor(4)])
      [1, 2, 3, Nx.tensor(4)]

  """
  def flatten_list(args, tail \\ []) when is_list(args) do
    # Elements are prepended while reducing, so a single reverse (which
    # also appends `tail`) restores the traversal order.
    args
    |> Enum.reduce([], &flatten_each/2)
    |> Enum.reverse(tail)
  end

  defp flatten_each(%T{} = tensor, acc),
    do: [tensor | acc]

  defp flatten_each(number, acc)
       when is_number(number) or is_struct(number, Complex),
       do: [number | acc]

  defp flatten_each(container, acc),
    do: Nx.Container.reduce(container, acc, &flatten_each/2)
end