
Nx.pad RuntimeError with EXLA #546

@madasebrof


Sample code:

defmodule Pilot.Splitter do
  import Nx.Defn
  @buf_len 12
  @channel_len 3
  
  # sample:
  # iex> nxb = Nx.from_binary(<<1,2,3,4,5,6,7,8,9,10,11,12>>, {:s, 8})
  # iex> Pilot.Splitter.splitnx(nxb)

  # works fine with the default (non-EXLA) compiler
  defn splitnx(rgba_tensor) do
    Nx.slice(rgba_tensor, [0], [@buf_len], strides: [4])
    |> Nx.broadcast({@channel_len, 3}, axes: [0], names: [:pixels, :rgba])
    # pad crashes EXLA
    |> Nx.pad(0, [{0, 0, 0}, {0, 1, 0}])
  end
  
  # crashes with error due to pad (works if pad is removed)
  @defn_compiler EXLA
  defn(splitnx_host(rgba_tensor), do: splitnx(rgba_tensor))
end
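For reference, the padding config `[{0, 0, 0}, {0, 1, 0}]` in `splitnx` is one `{low, high, interior}` tuple per axis. It adds nothing on the `:pixels` axis and appends one element of the pad value at the end of the `:rgba` axis, which turns the 3x3 broadcast into the 3x4 result shown below. A minimal pure-Elixir sketch of those semantics on nested lists follows. The `PadSketch` module and its functions are hypothetical illustrations, not part of Nx, and they only handle non-negative padding on a non-empty, rectangular input. Nx itself also accepts negative values, which crop.

```elixir
defmodule PadSketch do
  # Hypothetical helper, not an Nx API.
  # Pads a flat list along one axis: `low` copies of `value` before,
  # `high` copies after, and `interior` copies between neighbouring elements.
  # Only non-negative padding is supported (List.duplicate/2 rejects negatives).
  def pad1d(list, value, {low, high, interior}) do
    fill = fn n -> List.duplicate(value, n) end

    body =
      list
      |> Enum.with_index()
      |> Enum.flat_map(fn
        {x, 0} -> [x]
        {x, _} -> fill.(interior) ++ [x]
      end)

    fill.(low) ++ body ++ fill.(high)
  end

  # Hypothetical helper, not an Nx API.
  # Applies a two-axis config [row_cfg, col_cfg] to a non-empty,
  # rectangular list of rows.
  def pad2d(rows, value, [row_cfg, col_cfg]) do
    # Pad every row along the column axis first.
    padded_rows = Enum.map(rows, &pad1d(&1, value, col_cfg))

    # Then pad along the row axis, inserting whole rows filled with `value`.
    width = padded_rows |> hd() |> length()
    pad1d(padded_rows, List.duplicate(value, width), row_cfg)
  end
end

# Mirrors Nx.pad(t, 0, [{0, 0, 0}, {0, 1, 0}]) on the broadcast 3x3 data.
IO.inspect(PadSketch.pad2d([[1, 1, 1], [5, 5, 5], [9, 9, 9]], 0, [{0, 0, 0}, {0, 1, 0}]))
# prints [[1, 1, 1, 0], [5, 5, 5, 0], [9, 9, 9, 0]]
```

This matches the non-EXLA output below. It also shows that the crash is not caused by the padding shape itself. The error message points instead at the element types of Pad's operands (the `s8` tensor versus the pad value).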

Output:

iex(14)> nxb = Nx.from_binary(<<1,2,3,4,5,6,7,8,9,10,11,12>>, {:s, 8})
#Nx.Tensor<
  s8[12]
  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
>
iex(15)> Pilot.Splitter.splitnx(nxb)                                  
#Nx.Tensor<
  s8[pixels: 3][rgba: 4]
  [
    [1, 1, 1, 0],
    [5, 5, 5, 0],
    [9, 9, 9, 0]
  ]
>
iex(16)> Pilot.Splitter.splitnx_host(nxb)                             
** (RuntimeError) The element types of the operands to Pad do not match.
    (exla 0.1.0-dev) lib/exla/op.ex:774: EXLA.Op.unwrap!/1
    (exla 0.1.0-dev) lib/exla/op.ex:93: EXLA.Op.get_shape/1
    (exla 0.1.0-dev) lib/exla/builder.ex:23: EXLA.Builder.build/1
    (exla 0.1.0-dev) lib/exla/defn.ex:35: anonymous fn/9 in EXLA.Defn.compile/4
    (exla 0.1.0-dev) lib/exla/locked_cache.ex:23: EXLA.LockedCache.run/2
    (exla 0.1.0-dev) lib/exla/defn.ex:32: EXLA.Defn.compile/4
    (exla 0.1.0-dev) lib/exla/defn.ex:10: EXLA.Defn.__jit__/4
iex(16)> 

Output with pad removed:

iex(18)> Pilot.Splitter.splitnx_host(nxb)
[info] XLA service 0x7f619800b630 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
[info]   StreamExecutor device (0): Host, Default Version
#Nx.Tensor<
  s8[pixels: 3][rgba: 3]
  [
    [1, 1, 1],
    [5, 5, 5],
    [9, 9, 9]
  ]
>

I'm on Ubuntu 20.04.

Thanks!!!
