Closed
Description
Sample code:
defmodule Pilot.Splitter do
  import Nx.Defn

  @buf_len 12
  @channel_len 3

  # sample:
  # iex> nxb = Nx.from_binary(<<1,2,3,4,5,6,7,8,9,10,11,12>>, {:s, 8})
  # iex> Pilot.Splitter.splitnx(nxb)
  # works fine
  defn splitnx(rgba_tensor) do
    Nx.slice(rgba_tensor, [0], [@buf_len], strides: [4])
    |> Nx.broadcast({@channel_len, 3}, axes: [0], names: [:pixels, :rgba])
    # pad crashes EXLA
    |> Nx.pad(0, [{0, 0, 0}, {0, 1, 0}])
  end

  # crashes with an error due to pad (works if pad is removed)
  @defn_compiler EXLA
  defn(splitnx_host(rgba_tensor), do: splitnx(rgba_tensor))
end

Output:
iex(14)> nxb = Nx.from_binary(<<1,2,3,4,5,6,7,8,9,10,11,12>>, {:s, 8})
#Nx.Tensor<
s8[12]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
>
iex(15)> Pilot.Splitter.splitnx(nxb)
#Nx.Tensor<
s8[pixels: 3][rgba: 4]
[
[1, 1, 1, 0],
[5, 5, 5, 0],
[9, 9, 9, 0]
]
>
iex(16)> Pilot.Splitter.splitnx_host(nxb)
** (RuntimeError) The element types of the operands to Pad do not match.
    (exla 0.1.0-dev) lib/exla/op.ex:774: EXLA.Op.unwrap!/1
    (exla 0.1.0-dev) lib/exla/op.ex:93: EXLA.Op.get_shape/1
    (exla 0.1.0-dev) lib/exla/builder.ex:23: EXLA.Builder.build/1
    (exla 0.1.0-dev) lib/exla/defn.ex:35: anonymous fn/9 in EXLA.Defn.compile/4
    (exla 0.1.0-dev) lib/exla/locked_cache.ex:23: EXLA.LockedCache.run/2
    (exla 0.1.0-dev) lib/exla/defn.ex:32: EXLA.Defn.compile/4
    (exla 0.1.0-dev) lib/exla/defn.ex:10: EXLA.Defn.__jit__/4
iex(16)>
Output with pad removed:
iex(18)> Pilot.Splitter.splitnx_host(nxb)
[info] XLA service 0x7f619800b630 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
[info] StreamExecutor device (0): Host, Default Version
#Nx.Tensor<
s8[pixels: 3][rgba: 3]
[
[1, 1, 1],
[5, 5, 5],
[9, 9, 9]
]
>
I'm on Ubuntu 20.04.
Thanks!!!