Grid Sample functionality in MLX backend #1213
I'm trying to implement Deformable DETR in MLX and I'm stuck at the grid sampling part of the code. Can you help me replicate the torch.nn.functional.grid_sample function in MLX, please?
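For reference, this is the PyTorch behavior being replicated (shapes follow the official torch.nn.functional.grid_sample documentation): input is (N, C, H_in, W_in), grid is (N, H_out, W_out, 2) with coordinates normalized to [-1, 1], and the output is (N, C, H_out, W_out).

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)            # (N, C, H_in, W_in)
grid = torch.rand(1, 4, 4, 2) * 2 - 1  # (N, H_out, W_out, 2), in [-1, 1]
out = F.grid_sample(x, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
print(out.shape)                       # torch.Size([1, 3, 4, 4])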
# Grid sampling function
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape

    # Map normalized coordinates in [-1, 1] to pixel coordinates in
    # [0, W_in - 1] and [0, H_in - 1].
    grid_x = (grid[..., 0] + 1) * (W_in - 1) / 2
    grid_y = (grid[..., 1] + 1) * (H_in - 1) / 2

    if mode == 'bilinear':
        # Corner indices of the cell containing each sampling point.
        x0 = mx.floor(grid_x).astype(mx.int32)
        x1 = x0 + 1
        y0 = mx.floor(grid_y).astype(mx.int32)
        y1 = y0 + 1

        # Clip to the valid index range.
        x0 = mx.clip(x0, 0, W_in - 1)
        x1 = mx.clip(x1, 0, W_in - 1)
        y0 = mx.clip(y0, 0, H_in - 1)
        y1 = mx.clip(y1, 0, H_in - 1)

        # Gather the four neighboring pixels.
        Ia = input[mx.arange(N)[:, None, None], :, y0, x0]
        Ib = input[mx.arange(N)[:, None, None], :, y1, x0]
        Ic = input[mx.arange(N)[:, None, None], :, y0, x1]
        Id = input[mx.arange(N)[:, None, None], :, y1, x1]

        # Bilinear interpolation weights.
        wa = (x1 - grid_x) * (y1 - grid_y)
        wb = (x1 - grid_x) * (grid_y - y0)
        wc = (grid_x - x0) * (y1 - grid_y)
        wd = (grid_x - x0) * (grid_y - y0)

        output = wa[..., None] * Ia + wb[..., None] * Ib + wc[..., None] * Ic + wd[..., None] * Id
    elif mode == 'nearest':
        x = mx.round(grid_x).astype(mx.int32)
        y = mx.round(grid_y).astype(mx.int32)
        # Clip to the valid index range.
        x = mx.clip(x, 0, W_in - 1)
        y = mx.clip(y, 0, H_in - 1)
        output = input[..., y, x]
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    if padding_mode == 'zeros':
        out_of_bound = (grid_x < 0) | (grid_x > W_in - 1) | (grid_y < 0) | (grid_y > H_in - 1)
        output = mx.where(out_of_bound[..., None], mx.zeros_like(output), output)
    elif padding_mode == 'border':
        pass  # Coordinates were already clipped to the border above.
    elif padding_mode == 'reflection':
        raise NotImplementedError("padding_mode='reflection' is not implemented.")
    else:
        raise ValueError(f"Unsupported padding mode: {padding_mode}")
    return output
While the forward pass works, gradient computation gives me this error:

ValueError: [gather] Cannot calculate VJP with respect to indices.
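(A minimal reconstruction, not from the original thread, of where the error comes from: the fancy indexing in grid_sample lowers to a gather whose integer indices are computed from the grid, and MLX at this point could not compute a VJP with respect to gather indices.)

import mlx.core as mx

def f(grid_x):
    img = mx.arange(16, dtype=mx.float32).reshape(4, 4)
    # Integer indices computed from the (float) sampling coordinates.
    x0 = mx.clip(mx.floor(grid_x).astype(mx.int32), 0, 3)
    return img[0, x0].sum()

# mx.grad(f)(mx.array([1.2, 2.7]))
# -> ValueError: [gather] Cannot calculate VJP with respect to indices.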
Are you sure you need to differentiate with respect to the reference points? My understanding was that in Deformable DETR you would typically stop the gradient flow through the reference points to avoid this (e.g. here and here in the original implementation). In MLX you can use mx.stop_gradient for this.
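For example, a toy sketch of mx.stop_gradient (the names here are hypothetical):

import mlx.core as mx

def loss_fn(params, reference_points):
    # Treat the reference points as constants in the backward pass.
    reference_points = mx.stop_gradient(reference_points)
    return (params * reference_points).sum()

params = mx.ones((4,))
ref = mx.array([0.1, 0.2, 0.3, 0.4])
grads = mx.grad(loss_fn)(params, ref)  # differentiates w.r.t. params only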
Hi @barronalex, I had stopped the gradient at the reference points as well as at topk_coords_unact in the deformable transformer, at the same places. However, I'm still getting the same error. Is there a way I can debug where exactly the error "ValueError: [gather] Cannot calculate VJP with respect to indices." is coming from during gradient computation? Additionally, I tried stopping the gradient at sampling_grid_l_ inside the ms_deform_attn_core_mlx function. That avoided the error; however, it stopped the sampling_offsets layer from getting updated. I'm not sure how to make this work, since I believe I need the differentiation with respect to the grid for the weights of the sampling_offsets layer to get updated. Please correct me if I'm wrong. Just for reference, here is my complete code for the deformable attention calculation:

import math
import warnings

import mlx.core as mx
import mlx.nn as nn


def _is_power_of_2(n):
    if not isinstance(n, int) or n <= 0:
        raise ValueError(f"invalid input for _is_power_of_2: {n} (type: {type(n)})")
    return (n & (n - 1) == 0) and n != 0
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape

    # Map normalized coordinates in [-1, 1] to pixel coordinates in
    # [0, W_in - 1] and [0, H_in - 1].
    grid_x = (grid[..., 0] + 1) * (W_in - 1) / 2
    grid_y = (grid[..., 1] + 1) * (H_in - 1) / 2

    if mode == 'bilinear':
        # Corner indices of the cell containing each sampling point.
        x0 = mx.floor(grid_x).astype(mx.int32)
        x1 = x0 + 1
        y0 = mx.floor(grid_y).astype(mx.int32)
        y1 = y0 + 1

        # Clip to the valid index range.
        x0 = mx.clip(x0, 0, W_in - 1)
        x1 = mx.clip(x1, 0, W_in - 1)
        y0 = mx.clip(y0, 0, H_in - 1)
        y1 = mx.clip(y1, 0, H_in - 1)

        # Gather the four neighboring pixels.
        Ia = input[mx.arange(N)[:, None, None], :, y0, x0]
        Ib = input[mx.arange(N)[:, None, None], :, y1, x0]
        Ic = input[mx.arange(N)[:, None, None], :, y0, x1]
        Id = input[mx.arange(N)[:, None, None], :, y1, x1]

        # Bilinear interpolation weights.
        wa = (x1 - grid_x) * (y1 - grid_y)
        wb = (x1 - grid_x) * (grid_y - y0)
        wc = (grid_x - x0) * (y1 - grid_y)
        wd = (grid_x - x0) * (grid_y - y0)

        output = wa[..., None] * Ia + wb[..., None] * Ib + wc[..., None] * Ic + wd[..., None] * Id
    elif mode == 'nearest':
        x = mx.round(grid_x).astype(mx.int32)
        y = mx.round(grid_y).astype(mx.int32)
        # Clip to the valid index range.
        x = mx.clip(x, 0, W_in - 1)
        y = mx.clip(y, 0, H_in - 1)
        output = input[..., y, x]
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    if padding_mode == 'zeros':
        out_of_bound = (grid_x < 0) | (grid_x > W_in - 1) | (grid_y < 0) | (grid_y > H_in - 1)
        output = mx.where(out_of_bound[..., None], mx.zeros_like(output), output)
    elif padding_mode == 'border':
        pass  # Coordinates were already clipped to the border above.
    elif padding_mode == 'reflection':
        raise NotImplementedError("padding_mode='reflection' is not implemented.")
    else:
        raise ValueError(f"Unsupported padding mode: {padding_mode}")
    return output
def ms_deform_attn_core_mlx(value, value_spatial_shapes, sampling_locations, attention_weights):
    # For debugging and testing only; the original repo uses a custom
    # CUDA kernel for this step.
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape

    # Split the flattened values back into per-level feature maps.
    level_sizes = [H_ * W_ for H_, W_ in value_spatial_shapes]
    split_indices = []
    prev = 0
    for i in range(len(level_sizes)):
        split_indices.append(int(prev + level_sizes[i]))
        prev = split_indices[-1]
    split_indices = split_indices[:-1]
    value_list = mx.split(value, split_indices, axis=1)

    # Map sampling locations from [0, 1] to the [-1, 1] range expected by grid_sample.
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, M_, D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = mx.reshape(mx.transpose(mx.reshape(value_list[lid_], (N_, H_ * W_, M_, D_)), (0, 2, 3, 1)), (N_ * M_, D_, H_, W_))
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = mx.reshape(mx.transpose(sampling_grids[:, :, :, lid_, :], (0, 2, 1, 3, 4)), (N_ * M_, Lq_, P_, 2))
        # sampling_grid_l_ = mx.stop_gradient(sampling_grid_l_)  # adding this line avoids the error but results in zero gradient for the sampling_offsets layer
        # The grid_sample above returns channels-last (N_*M_, Lq_, P_, D_),
        # so transpose to (N_*M_, D_, Lq_, P_).
        sampling_value_l_ = grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros')
        sampling_value_l_ = mx.transpose(sampling_value_l_, (0, 3, 1, 2))
        sampling_value_list.append(sampling_value_l_)

    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = mx.reshape(mx.transpose(attention_weights, (0, 2, 1, 3, 4)), (N_ * M_, 1, Lq_, L_ * P_))
    output = mx.sum(mx.reshape(mx.stack(sampling_value_list, axis=-2), (N_ * M_, D_, Lq_, L_ * P_)) * attention_weights, axis=-1).reshape((N_, M_ * D_, Lq_))
    return mx.transpose(output, (0, 2, 1))
class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}')
        _d_per_head = d_model // n_heads
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2, "
                          "which is more efficient in our implementation.")
        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)
        self._reset_parameters()

    def _reset_parameters(self):
        # MLX init functions return initializers that are applied to an
        # existing array.
        self.sampling_offsets.weight = nn.init.constant(0.0)(self.sampling_offsets.weight)
        thetas = mx.arange(self.n_heads) * (2.0 * math.pi / self.n_heads)
        grid_init = mx.stack([mx.cos(thetas), mx.sin(thetas)], axis=-1)
        grid_init = (grid_init / mx.abs(grid_init).max(axis=-1, keepdims=True)).reshape(self.n_heads, 1, 1, 2)
        grid_init = mx.tile(grid_init, (1, self.n_levels, self.n_points, 1))
        for i in range(self.n_points):
            grid_init[:, :, i, :] = grid_init[:, :, i, :] * (i + 1)
        self.sampling_offsets.bias = grid_init.reshape(-1)
        self.attention_weights.weight = nn.init.constant(0.0)(self.attention_weights.weight)
        self.attention_weights.bias = nn.init.constant(0.0)(self.attention_weights.bias)
        self.value_proj.weight = nn.init.uniform()(self.value_proj.weight)
        self.value_proj.bias = nn.init.constant(0.0)(self.value_proj.bias)
        self.output_proj.weight = nn.init.uniform()(self.output_proj.weight)
        self.output_proj.bias = nn.init.constant(0.0)(self.output_proj.bias)

    def __call__(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value * (1 - input_padding_mask[..., None].astype(value.dtype))
        value = value.reshape(N, Len_in, self.n_heads, self.d_model // self.n_heads)

        sampling_offsets = self.sampling_offsets(query).reshape(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).reshape(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = nn.softmax(attention_weights, axis=-1).reshape(N, Len_q, self.n_heads, self.n_levels, self.n_points)

        if reference_points.shape[-1] == 2:
            input_spatial_shapes_mx = mx.array(input_spatial_shapes)
            offset_normalizer = mx.stack([input_spatial_shapes_mx[:, 1], input_spatial_shapes_mx[:, 0]], axis=-1)
            sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]} instead.')

        output = ms_deform_attn_core_mlx(value, input_spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        return output
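For what it's worth, a small shape-level smoke test of the module above (a sketch; the shapes follow the reference Deformable DETR conventions, and none of the names beyond the module itself come from the thread):

import numpy as np

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4)
spatial_shapes = np.array([[16, 16], [8, 8]])  # (n_levels, 2) as (H, W)
len_in = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())  # 256 + 64 = 320

query = mx.random.normal((2, 100, 256))                     # (N, Len_q, d_model)
reference_points = mx.random.uniform(shape=(2, 100, 2, 2))  # normalized to [0, 1]
input_flatten = mx.random.normal((2, len_in, 256))          # (N, Len_in, d_model)

out = attn(query, reference_points, input_flatten, spatial_shapes, None)
print(out.shape)  # (2, 100, 256)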
I think the gradient is still flowing to the gather indices through sampling_locations. Stopping the gradient there should work:

sampling_locations = mx.stop_gradient(sampling_locations)
output = ms_deform_attn_core_mlx(value, input_spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
return output

Let me know if you still see the error.
I tried stopping the gradient at sampling_grid_l_ inside the ms_deform_attn_core_mlx function. That avoided the error; however, it stopped the sampling_offsets layer from getting updated because of the zero gradient. I'm not sure how to make this work, since I believe I need the differentiation with respect to the grid for the weights of the sampling_offsets layer to get updated. Please correct me if I'm wrong.
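A sketch of why this gradient is well defined: bilinear sampling is piecewise linear in the grid coordinates. Writing $f_x, f_y$ for the fractional parts of the sampling coordinates within a cell and $I_{nw}, I_{ne}, I_{sw}, I_{se}$ for the four gathered corner values, the interpolated output is

$$\mathrm{out} = (1-f_x)(1-f_y)\,I_{nw} + f_x(1-f_y)\,I_{ne} + (1-f_x)f_y\,I_{sw} + f_x f_y\,I_{se},$$

so, for example,

$$\frac{\partial\,\mathrm{out}}{\partial f_x} = (1-f_y)(I_{ne}-I_{nw}) + f_y\,(I_{se}-I_{sw}).$$

The derivative with respect to the grid therefore exists almost everywhere; only the gather of the corner values lacks a gradient with respect to its integer indices.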
Very sorry, you're completely right, you do need the gradient with respect to the grid. It looks like PyTorch has a custom backward pass implementation for grid_sample. The Scenic Deformable DETR implementation in JAX does seem to rely on differentiating w.r.t. indices through a gather operation though, so maybe it is something we should support. I can take a closer look at that later today.
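A quick check (a sketch, not from the thread) confirming that PyTorch's grid_sample does provide gradients with respect to the grid:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4, requires_grad=True)
grid = torch.zeros(1, 2, 2, 2, requires_grad=True)  # sample at the center
F.grid_sample(x, grid, mode='bilinear', padding_mode='zeros', align_corners=False).sum().backward()
print(grid.grad.shape)  # torch.Size([1, 2, 2, 2]) -- the grid gets a gradient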
Hi @barronalex: Many thanks for your help in this regard. I will be waiting to hear from you.
Hi @barronalex: Do you have any insights on this?
+1 I believe this might also be necessary for grid-based NeRF representations as described in InstantNGP.
Sorry for the delay! To fix this, I'm adding zero VJPs to gather w.r.t. indices and the bitwise ops in #1256. With that, I'm able to get an MLX grid_sample whose value and gradients match PyTorch's:

import mlx.core as mx
import numpy as np
import torch
import torch.nn.functional as F


def grid_sample_mx(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False):
    if align_corners:
        raise NotImplementedError("`align_corners=True` not yet implemented.")
    if padding_mode != 'zeros':
        raise NotImplementedError(f"padding_mode={padding_mode} not yet implemented.")
    if mode != "bilinear":
        raise NotImplementedError(f"mode={mode} not yet implemented.")

    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape

    # Map normalized coordinates in [-1, 1] to pixel coordinates
    # (align_corners=False convention).
    ix = ((grid[..., 0] + 1) * W_in - 1) / 2
    iy = ((grid[..., 1] + 1) * H_in - 1) / 2

    if mode == 'bilinear':
        # Corner indices: north-west, north-east, south-west, south-east.
        ix_nw = mx.floor(ix).astype(mx.int32)
        iy_nw = mx.floor(iy).astype(mx.int32)
        ix_ne = ix_nw + 1
        iy_ne = iy_nw
        ix_sw = ix_nw
        iy_sw = iy_nw + 1
        ix_se = ix_nw + 1
        iy_se = iy_nw + 1

        # Bilinear interpolation weights.
        nw = (ix_se - ix) * (iy_se - iy)
        ne = (ix - ix_sw) * (iy_sw - iy)
        sw = (ix_ne - ix) * (iy - iy_ne)
        se = (ix - ix_nw) * (iy - iy_nw)

        I_nw = input[mx.arange(N)[:, None, None], :, iy_nw, ix_nw]
        I_ne = input[mx.arange(N)[:, None, None], :, iy_ne, ix_ne]
        I_sw = input[mx.arange(N)[:, None, None], :, iy_sw, ix_sw]
        I_se = input[mx.arange(N)[:, None, None], :, iy_se, ix_se]

        # Zero out contributions from out-of-bounds corners
        # (padding_mode='zeros').
        mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
        mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
        mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
        mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

        I_nw *= mask_nw[..., None]
        I_ne *= mask_ne[..., None]
        I_sw *= mask_sw[..., None]
        I_se *= mask_se[..., None]

        output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
    return output
np.random.seed(7)

n, m, gn, gm = 32, 16, 8, 4
x = np.random.normal(size=(3, 7, n, m)).astype(np.float32)
grid = np.random.uniform(-1.5, 1, size=(3, gn, gm, 2)).astype(np.float32)

# Reference values and gradients from PyTorch.
x_t = torch.tensor(x, requires_grad=True)
grid_t = torch.tensor(grid, requires_grad=True)
out_t = F.grid_sample(x_t, grid_t, mode="bilinear", padding_mode="zeros")
loss_t = out_t.sum()
loss_t.backward()
out_t = out_t.cpu().detach().numpy().transpose((0, 2, 3, 1))

# MLX values and gradients.
x = mx.array(x)
grid = mx.array(grid)

def grid_sample(x, grid):
    return grid_sample_mx(x, grid, mode="bilinear", padding_mode="zeros").sum()

loss, (grad_x, grad_grid) = mx.value_and_grad(grid_sample, argnums=(0, 1))(x, grid)

np.testing.assert_allclose(x_t.grad.cpu().numpy(), grad_x, atol=1e-4)
np.testing.assert_allclose(grid_t.grad.cpu().numpy(), grad_grid, atol=1e-4)
np.testing.assert_allclose(loss_t.detach().cpu().numpy(), loss, atol=1e-4)
If this is a big performance bottleneck then we could also consider adding a custom GPU op some time in the future, but hopefully this unblocks you for now.
Hi @barronalex: Many thanks for your help. Hoping to see a custom VJP on the GPU in the future, but of course this unblocks me for now. Thanks again!