# tensor.ex
defmodule Nx.Tensor do
  @moduledoc """
  The tensor struct and the behaviour for backends.

  `Nx.Tensor` is a generic container for multidimensional data structures.
  It contains the tensor type, shape, and names. The data itself is a
  struct that points to a backend responsible for controlling the data.
  The backend behaviour is described in `Nx.Backend`.

  The tensor has the following fields:

    * `:data` - the tensor backend and its data
    * `:shape` - the tensor shape
    * `:type` - the tensor type
    * `:names` - the tensor names

  In general it is discouraged to access those fields directly. Use
  the functions in the `Nx` module instead. Backends have to access those
  fields but they cannot update them, except for the `:data` field itself.
  """

  @type data :: Nx.Backend.t()
  @type type :: Nx.Type.t()
  @type shape :: tuple()
  @type axis :: name | integer
  @type axes :: [axis]
  @type name :: atom

  @type t :: %Nx.Tensor{data: data, type: type, shape: shape, names: [name]}
  @type t(data) :: %Nx.Tensor{data: data, type: type, shape: shape, names: [name]}

  @enforce_keys [:type, :shape, :names]
  defstruct [:data, :type, :shape, :names]

  ## Access

  @behaviour Access

  # Implements `tensor[index]`. A bare index (integer, scalar tensor or
  # range) applies to axis 0; a list applies positionally starting at axis 0;
  # a keyword list addresses axes by name. Scalar tensors have no axes to
  # index, so they are rejected up front.
  @impl true
  def fetch(%Nx.Tensor{shape: {}} = tensor, _index) do
    raise ArgumentError,
          "cannot use the tensor[index] syntax on scalar tensor #{inspect(tensor)}"
  end

  def fetch(tensor, %Nx.Tensor{} = index),
    do: {:ok, fetch_axes(tensor, [{0, index}])}

  def fetch(tensor, index) when is_integer(index),
    do: {:ok, fetch_axes(tensor, [{0, index}])}

  def fetch(tensor, _.._//_ = range),
    do: {:ok, fetch_axes(tensor, [{0, range}])}

  def fetch(tensor, []),
    do: {:ok, tensor}

  def fetch(%{names: names} = tensor, [{_, _} | _] = keyword),
    do: {:ok, fetch_axes(tensor, with_names(keyword, names, []))}

  def fetch(tensor, [_ | _] = list),
    do: {:ok, fetch_axes(tensor, with_index(list, 0, []))}

  def fetch(_tensor, value) do
    raise """
    tensor[slice] expects slice to be one of:
    * an integer or a scalar tensor representing a zero-based index
    * a first..last range representing inclusive start-stop indexes
    * a list of integers and ranges
    * a keyword list of integers and ranges
    Got #{inspect(value)}
    """
  end

  # Pairs each list element with its position as `{position, element}`.
  # The accumulator is built by prepending, so the result is in reverse
  # order; `fetch_axes/6` relies on that order when reporting leftovers.
  defp with_index([h | t], i, acc), do: with_index(t, i + 1, [{i, h} | acc])
  defp with_index([], _i, acc), do: acc

  # Resolves each `name: index` pair to `{axis_position, index}` through
  # `Nx.Shape.find_name!/2` (presumably raising on an unknown name - see
  # `Nx.Shape`). Like `with_index/3`, the result is in reverse order.
  defp with_names([{k, v} | t], names, acc),
    do: with_names(t, names, [{Nx.Shape.find_name!(names, k), v} | acc])

  defp with_names([], _names, acc),
    do: acc

  # Slices `tensor` according to `axes`, a list of `{axis_position, index}`
  # pairs. It calls the backend's `slice/5` with a stride of 1 on every axis,
  # then squeezes away the axes that were indexed by an integer or a scalar
  # tensor.
  defp fetch_axes(%Nx.Tensor{shape: shape} = tensor, axes) do
    # `shape` is a tuple (see `t:shape/0`), so its rank is its size. The
    # slice plan is computed before the backend lookup so that invalid
    # indexes are reported as such first.
    rank = tuple_size(shape)
    {start, lengths, squeeze} = fetch_axes(rank - 1, axes, shape, [], [], [])
    impl = Nx.Shared.impl!(tensor)

    %{tensor | shape: List.to_tuple(lengths)}
    |> impl.slice(tensor, start, lengths, List.duplicate(1, rank))
    |> Nx.squeeze(axes: squeeze)
  end

  # Walks the axes from the last one down to 0, accumulating start indexes,
  # slice lengths and the axes to squeeze. Each step prepends, so the final
  # lists are in axis order. An axis with no entry in `axes` is taken whole.
  defp fetch_axes(axis, axes, shape, start, lengths, squeeze) when axis >= 0 do
    case List.keytake(axes, axis, 0) do
      # Only scalar tensors are valid dynamic indexes. Any other tensor falls
      # through to the catch-all clause below and raises ArgumentError instead
      # of being handed to the backend as a start index.
      {{^axis, %Nx.Tensor{shape: {}} = index}, axes} ->
        fetch_axes(axis - 1, axes, shape, [index | start], [1 | lengths], [axis | squeeze])

      {{^axis, index}, axes} when is_integer(index) ->
        index = normalize_index(index, axis, shape)
        fetch_axes(axis - 1, axes, shape, [index | start], [1 | lengths], [axis | squeeze])

      {{^axis, first..last//step = range}, axes} ->
        first = normalize_index(first, axis, shape)
        last = normalize_index(last, axis, shape)

        if last < first or step != 1 do
          raise ArgumentError,
                "slicing a tensor requires a non-empty range with a step of 1, got: #{inspect(range)}"
        end

        len = last - first + 1
        fetch_axes(axis - 1, axes, shape, [first | start], [len | lengths], squeeze)

      {{^axis, value}, _} ->
        raise ArgumentError,
              "slicing a tensor on an axis requires an integer, a scalar tensor or a range, got: " <>
                inspect(value)

      nil ->
        fetch_axes(axis - 1, axes, shape, [0 | start], [elem(shape, axis) | lengths], squeeze)
    end
  end

  # Every axis has been visited; any entry still left refers to an axis
  # position that does not exist or that was given more than once.
  defp fetch_axes(_axis, [{axis, _} | _], shape, _start, _lengths, _squeeze) do
    raise ArgumentError,
          "unknown or duplicate axis #{axis} found when slicing shape #{inspect(shape)}"
  end

  defp fetch_axes(_axis, [], _shape, start, lengths, squeeze) do
    {start, lengths, squeeze}
  end

  # Maps a possibly negative `index` into `0..dim - 1` for `axis`, counting
  # negative indexes back from the end of the dimension. Raises ArgumentError
  # when the resulting position lies outside the dimension.
  defp normalize_index(index, axis, shape) do
    dim = elem(shape, axis)
    norm = if index < 0, do: dim + index, else: index

    if norm < 0 or norm >= dim do
      raise ArgumentError,
            "index #{index} is out of bounds for axis #{axis} in shape #{inspect(shape)}"
    end

    norm
  end

  # Writing through the Access protocol is intentionally unsupported.
  @impl true
  def get_and_update(_tensor, _index, _update) do
    raise "Access.get_and_update/3 is not supported. Please use Nx.put_slice/3 instead"
  end

  @impl true
  def pop(_tensor, _index) do
    raise "Access.pop/2 is not yet supported by Nx.Tensor"
  end

  defimpl Inspect do
    import Inspect.Algebra

    # Renders `#Nx.Tensor<`, then the type and (named) shape on one line,
    # then the data as rendered by the backend module's own `inspect/2`.
    # `force_unfit/1` ensures the `line/0` breaks are always emitted.
    def inspect(%{shape: shape, names: names, type: type} = tensor, opts) do
      open = color("[", :list, opts)
      close = color("]", :list, opts)
      type = color(Nx.Type.to_string(type), :atom, opts)
      shape = Nx.Shape.to_algebra(shape, names, open, close)
      data = tensor.data.__struct__.inspect(tensor, opts)
      inner = concat([line(), type, shape, line(), data])

      force_unfit(
        concat([
          color("#Nx.Tensor<", :map, opts),
          nest(inner, 2),
          line(),
          color(">", :map, opts)
        ])
      )
    end
  end
end