From 03501569792506d29f730dccb9188f5686f74200 Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Tue, 5 Mar 2024 09:46:28 +0000 Subject: [PATCH] fix: remove unnecessary dim expansion from ivy.interpolate --- ivy/functional/ivy/experimental/layers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ivy/functional/ivy/experimental/layers.py b/ivy/functional/ivy/experimental/layers.py index 0942677e34361..d70889c3745c9 100644 --- a/ivy/functional/ivy/experimental/layers.py +++ b/ivy/functional/ivy/experimental/layers.py @@ -1476,10 +1476,6 @@ def nearest_interpolate(x, dims, size, scale, exact): n = size[d] offsets = (ivy.arange(n, dtype="float32") + off) * scale[d] offsets = ivy.astype(ivy.floor(ivy.astype(offsets, "float32")), "int64") - num_dims_to_add = x.ndim - offsets.ndim - if num_dims_to_add > 0: - for _ in range(num_dims_to_add): - offsets = ivy.expand_dims(offsets, axis=0) x = ivy.gather(x, offsets, axis=d + 2) return x