diff --git a/lib/emlx/backend.ex b/lib/emlx/backend.ex
index 7fc50fe..eb4a16d 100644
--- a/lib/emlx/backend.ex
+++ b/lib/emlx/backend.ex
@@ -1342,8 +1342,6 @@ defmodule EMLX.Backend do
   end
 
   defp window_op(op, out, tensor, window_shape, opts) do
-    # TODO: window dilations can be implemented after we support internal padding
-    # in Nx.pad (we should have pad_internal as a shared defp)
     tensor_rank = tuple_size(tensor.shape)
 
     axes =
@@ -1356,7 +1354,20 @@ defmodule EMLX.Backend do
     {low_pad, high_pad} = Enum.unzip(opts[:padding])
 
     {device, _} = t_mx = from_nx(tensor)
-    {_device, pad_mx} =
+    window_dilations = opts[:window_dilations] || List.duplicate(1, tuple_size(window_shape))
+    interior_padding_config = Enum.map(window_dilations, &(&1 - 1))
+
+    {_device, zero_mx} = EMLX.scalar_tensor(0, :bool, device)
+
+    window =
+      1
+      |> EMLX.scalar_tensor(:bool, device)
+      |> EMLX.broadcast_to(window_shape)
+      |> interior_padding_mlx(zero_mx, interior_padding_config)
+
+    window_shape = EMLX.shape(window)
+
+    {device, pad_mx} =
       case op do
         :sum ->
           EMLX.scalar_tensor(0, to_mlx_type(out.type), device)
@@ -1375,6 +1386,7 @@
 
     padded_mx
     |> sliding_window_view(EMLX.shape(padded_mx), window_shape, opts[:strides])
+    |> then(&EMLX.where(window, &1, {device, pad_mx}))
     |> then(&apply(EMLX, op, [&1, axes, false]))
     |> to_nx(out)
   end
diff --git a/test/emlx/nx_doctest_test.exs b/test/emlx/nx_doctest_test.exs
index 43f9547..642c4e8 100644
--- a/test/emlx/nx_doctest_test.exs
+++ b/test/emlx/nx_doctest_test.exs
@@ -47,13 +47,6 @@ defmodule EMLX.Nx.DoctestTest do
 
   @to_be_fixed [
     :moduledoc,
-    # window_* do not support window_dilations yet
-    window_sum: 3,
-    window_max: 3,
-    window_min: 3,
-    window_product: 3,
-    window_mean: 3,
-    # missing support for inner padding
     # MLX sorts NaNs lowest, Nx sorts them highest
     argmin: 2,
     argmax: 2,