From 0df7a3298e6d7009c920b97f6dbce82d4d2f3119 Mon Sep 17 00:00:00 2001
From: Chapaman <14204271+Chapaman@users.noreply.github.com>
Date: Wed, 8 Oct 2025 07:08:28 -0300
Subject: [PATCH 1/2] first iteration of window dilations implementation following Torchx steps

---
 lib/emlx/backend.ex | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/lib/emlx/backend.ex b/lib/emlx/backend.ex
index 7fc50fe..22d2b57 100644
--- a/lib/emlx/backend.ex
+++ b/lib/emlx/backend.ex
@@ -1342,8 +1342,6 @@ defmodule EMLX.Backend do
   end
 
   defp window_op(op, out, tensor, window_shape, opts) do
-    # TODO: window dilations can be implemented after we support internal padding
-    # in Nx.pad (we should have pad_internal as a shared defp)
     tensor_rank = tuple_size(tensor.shape)
 
     axes =
@@ -1356,6 +1354,17 @@ defmodule EMLX.Backend do
     {low_pad, high_pad} = Enum.unzip(opts[:padding])
     {device, _} = t_mx = from_nx(tensor)
 
+    window_dilations = opts[:window_dilations] || List.duplicate(1, tuple_size(window_shape))
+    interior_padding_config = Enum.map(window_dilations, &{0, 0, &1 - 1})
+
+    window =
+      1
+      |> EMLX.scalar_tensor(:bool, device)
+      |> EMLX.broadcast_to(window_shape)
+      |> interior_padding_mlx(0, interior_padding_config)
+
+    window_shape = EMLX.shape(window)
+
     {_device, pad_mx} =
       case op do
         :sum ->
@@ -1375,6 +1384,8 @@ defmodule EMLX.Backend do
 
     padded_mx
     |> sliding_window_view(EMLX.shape(padded_mx), window_shape, opts[:strides])
+    |> EMLX.broadcast_to(window_shape)
+    |> EMLX.where(window, &1)
     |> then(&apply(EMLX, op, [&1, axes, false]))
     |> to_nx(out)
   end

From 7070161f1bf25f5fc42347d58a19397981c166e3 Mon Sep 17 00:00:00 2001
From: Chapaman <14204271+Chapaman@users.noreply.github.com>
Date: Wed, 8 Oct 2025 08:10:27 -0300
Subject: [PATCH 2/2] implemented window dilations on EMLX.window_op()

---
 lib/emlx/backend.ex           | 11 ++++++-----
 test/emlx/nx_doctest_test.exs |  7 -------
 2 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/lib/emlx/backend.ex b/lib/emlx/backend.ex
index 22d2b57..eb4a16d 100644
--- a/lib/emlx/backend.ex
+++ b/lib/emlx/backend.ex
@@ -1355,17 +1355,19 @@ defmodule EMLX.Backend do
     {device, _} = t_mx = from_nx(tensor)
 
     window_dilations = opts[:window_dilations] || List.duplicate(1, tuple_size(window_shape))
-    interior_padding_config = Enum.map(window_dilations, &{0, 0, &1 - 1})
+    interior_padding_config = Enum.map(window_dilations, &(&1 - 1))
+
+    {_device, zero_mx} = EMLX.scalar_tensor(0, :bool, device)
 
     window =
       1
       |> EMLX.scalar_tensor(:bool, device)
       |> EMLX.broadcast_to(window_shape)
-      |> interior_padding_mlx(0, interior_padding_config)
+      |> interior_padding_mlx(zero_mx, interior_padding_config)
 
     window_shape = EMLX.shape(window)
 
-    {_device, pad_mx} =
+    {device, pad_mx} =
       case op do
         :sum ->
           EMLX.scalar_tensor(0, to_mlx_type(out.type), device)
@@ -1384,8 +1386,7 @@ defmodule EMLX.Backend do
 
     padded_mx
     |> sliding_window_view(EMLX.shape(padded_mx), window_shape, opts[:strides])
-    |> EMLX.broadcast_to(window_shape)
-    |> EMLX.where(window, &1)
+    |> then(&EMLX.where(window, &1, {device, pad_mx}))
     |> then(&apply(EMLX, op, [&1, axes, false]))
     |> to_nx(out)
   end
diff --git a/test/emlx/nx_doctest_test.exs b/test/emlx/nx_doctest_test.exs
index 43f9547..642c4e8 100644
--- a/test/emlx/nx_doctest_test.exs
+++ b/test/emlx/nx_doctest_test.exs
@@ -47,13 +47,6 @@ defmodule EMLX.Nx.DoctestTest do
   @to_be_fixed [
     :moduledoc,
-    # window_* do not support window_dilations yet
-    window_sum: 3,
-    window_max: 3,
-    window_min: 3,
-    window_product: 3,
-    window_mean: 3,
-
     # missing support for inner padding
     # MLX sorts NaNs lowest, Nx sorts them highest
     argmin: 2,
    argmax: 2,
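
Reviewer note: the core trick in these commits is building a boolean window
mask whose set bits sit only at the dilated tap positions, by interior-padding
a window of ones with zeros; EMLX.where/3 then swaps the masked-out positions
for the reduction's identity value (pad_mx: 0 for :sum, and so on per the
case op do clause). A minimal Nx-level sketch of the same idea, for intuition
only; the patch itself uses EMLX primitives and its interior_padding_mlx
helper, whose signature is taken from the diff:

    # A 2x2 window of ones dilated by 2 along each axis: interior padding
    # inserts (dilation - 1) zeros between the original taps.
    window = Nx.broadcast(1, {2, 2})
    Nx.pad(window, 0, [{0, 0, 1}, {0, 0, 1}])
    #=> [[1, 0, 1],
    #    [0, 0, 0],
    #    [1, 0, 1]]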
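
With both commits applied, the previously skipped window_* doctests that
exercise :window_dilations should pass. For example, per the documented Nx
semantics (values worked out by hand):

    t = Nx.tensor([[1, 2, 3, 4, 5]])
    Nx.window_sum(t, {1, 2}, window_dilations: [1, 2])
    # Each dilated window sums elements two apart: [1+3, 2+4, 3+5]
    #=> [[4, 6, 8]]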