Grid Sample functionality in MLX backend #1213

Closed

sachinraja13 opened this issue Jun 15, 2024 · 12 comments

Comments
@sachinraja13

I’m trying to implement Deformable DETR in MLX and I’m stuck at the grid sampling part of the code. Can you help me replicate the torch.nn.functional.grid_sample function in MLX, please?

@sachinraja13
Author

sachinraja13 commented Jun 16, 2024

import mlx.core as mx

# Grid sampling function
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape
    # Normalize the grid to [0, H_in-1] and [0, W_in-1]
    grid_x = (grid[..., 0] + 1) * (W_in - 1) / 2
    grid_y = (grid[..., 1] + 1) * (H_in - 1) / 2
 
    if mode == 'bilinear':
        x0 = mx.floor(grid_x).astype(mx.int32)
        x1 = x0 + 1
        y0 = mx.floor(grid_y).astype(mx.int32)
        y1 = y0 + 1

        # Clip to the range [0, H_in-1] and [0, W_in-1]
        x0 = mx.clip(x0, 0, W_in - 1)
        x1 = mx.clip(x1, 0, W_in - 1)
        y0 = mx.clip(y0, 0, H_in - 1)
        y1 = mx.clip(y1, 0, H_in - 1)

        Ia = input[mx.arange(N)[:, None, None], :, y0, x0]
        Ib = input[mx.arange(N)[:, None, None], :, y1, x0]
        Ic = input[mx.arange(N)[:, None, None], :, y0, x1]
        Id = input[mx.arange(N)[:, None, None], :, y1, x1]
        

        wa = (x1 - grid_x) * (y1 - grid_y)
        wb = (x1 - grid_x) * (grid_y - y0)
        wc = (grid_x - x0) * (y1 - grid_y)
        wd = (grid_x - x0) * (grid_y - y0)

        output = wa[..., None] * Ia + wb[..., None] * Ib + wc[..., None] * Ic + wd[..., None] * Id
    elif mode == 'nearest':
        x = mx.round(grid_x).astype(mx.int32)
        y = mx.round(grid_y).astype(mx.int32)

        # Clip to the range [0, H_in-1] and [0, W_in-1]
        x = mx.clip(x, 0, W_in - 1)
        y = mx.clip(y, 0, H_in - 1)

        output = input[..., y, x]

    else:
        raise ValueError(f"Unsupported mode: {mode}")

    if padding_mode == 'zeros':
        out_of_bound = (grid_x < 0) | (grid_x > W_in - 1) | (grid_y < 0) | (grid_y > H_in - 1)
        output = mx.where(out_of_bound[..., None], mx.zeros_like(output), output)
    elif padding_mode == 'border':
        output = output
    elif padding_mode == 'reflection':
        grid_x = mx.abs(mx.clip(grid_x, -1, 1))
        grid_y = mx.abs(mx.clip(grid_y, -1, 1))
    else:
        raise ValueError(f"Unsupported padding mode: {padding_mode}")

    return output
        

While the forward pass works, gradient computation gives me this error:
ValueError: [gather] Cannot calculate VJP with respect to indices.
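
A minimal reduction (my own sketch, assuming it exercises the same gather code path as the sampler above) that reproduces the error for me: the output depends on the float input both through the integer gather indices and through the interpolation weight, so the VJP has to traverse the index argument of the gather.

import mlx.core as mx

def sample(pos):
    # `pos` is a hypothetical stand-in for one grid coordinate.
    data = mx.arange(8.0)
    i0 = mx.clip(mx.floor(pos).astype(mx.int32), 0, 6)
    w = pos - i0
    # Linear interpolation between neighbouring entries of `data`.
    return ((1 - w) * data[i0] + w * data[i0 + 1]).sum()

# mx.grad(sample)(mx.array([2.3, 5.7]))
# -> ValueError: [gather] Cannot calculate VJP with respect to indices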

@barronalex
Collaborator

Are you sure you need to differentiate with respect to grid above?

My understanding was that in Deformable DETR you would typically stop the gradient flow through the reference points to avoid this (e.g. here and here in the original implementation).

In MLX you can use mx.stop_gradient(x) in place of x.detach() in PyTorch.
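
For example (a schematic sketch, not taken from the DETR code, with the sigmoid standing in for whatever computation produces the reference points):

import mlx.core as mx

def loss_fn(x):
    # Cut the gradient flow through the reference points, analogous to .detach() in PyTorch.
    reference_points = mx.stop_gradient(mx.sigmoid(x))
    return (reference_points * x).sum()

# The gradient only flows through the second factor, not through the sigmoid.
print(mx.grad(loss_fn)(mx.array([0.5, -1.0])))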

@sachinraja13
Author

sachinraja13 commented Jun 18, 2024

Hi @barronalex ,

I had already stopped the gradient at the reference points as well as at topk_coords_unact in the deformable transformer, at the same places as the original implementation. However, I'm still getting the same error. Is there a way to debug where exactly "ValueError: [gather] Cannot calculate VJP with respect to indices." is raised during gradient computation?

Additionally, I tried stopping the gradient at sampling_grid_l_ inside the ms_deform_attn_core_mlx function. That avoided the error; however, it stopped the sampling_offsets layer from being updated. I'm not sure how to make this work, since I believe the gradient has to flow through the grid for the weights of the sampling_offsets layer to be updated. Please correct me if I'm wrong.

Just for reference, here is my complete code for the Deformable Attention calculation:

import math
import warnings
import mlx.core as mx
import mlx.nn as nn

def _is_power_of_2(n):
    if not isinstance(n, int) or n <= 0:
        raise ValueError(f"invalid input for _is_power_of_2: {n} (type: {type(n)})")
    return (n & (n - 1) == 0) and n != 0

def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape

    # Normalize the grid to [0, H_in-1] and [0, W_in-1]
    grid_x = (grid[..., 0] + 1) * (W_in - 1) / 2
    grid_y = (grid[..., 1] + 1) * (H_in - 1) / 2
 
    if mode == 'bilinear':
        x0 = mx.floor(grid_x).astype(mx.int32)
        x1 = x0 + 1
        y0 = mx.floor(grid_y).astype(mx.int32)
        y1 = y0 + 1

        # Clip to the range [0, H_in-1] and [0, W_in-1]
        x0 = mx.clip(x0, 0, W_in - 1)
        x1 = mx.clip(x1, 0, W_in - 1)
        y0 = mx.clip(y0, 0, H_in - 1)
        y1 = mx.clip(y1, 0, H_in - 1)

        Ia = input[mx.arange(N)[:, None, None], :, y0, x0]
        Ib = input[mx.arange(N)[:, None, None], :, y1, x0]
        Ic = input[mx.arange(N)[:, None, None], :, y0, x1]
        Id = input[mx.arange(N)[:, None, None], :, y1, x1]
        

        wa = (x1 - grid_x) * (y1 - grid_y)
        wb = (x1 - grid_x) * (grid_y - y0)
        wc = (grid_x - x0) * (y1 - grid_y)
        wd = (grid_x - x0) * (grid_y - y0)

        output = wa[..., None] * Ia + wb[..., None] * Ib + wc[..., None] * Ic + wd[..., None] * Id
    elif mode == 'nearest':
        x = mx.round(grid_x).astype(mx.int32)
        y = mx.round(grid_y).astype(mx.int32)

        # Clip to the range [0, H_in-1] and [0, W_in-1]
        x = mx.clip(x, 0, W_in - 1)
        y = mx.clip(y, 0, H_in - 1)

        output = input[..., y, x]

    else:
        raise ValueError(f"Unsupported mode: {mode}")

    if padding_mode == 'zeros':
        out_of_bound = (grid_x < 0) | (grid_x > W_in - 1) | (grid_y < 0) | (grid_y > H_in - 1)
        output = mx.where(out_of_bound[..., None], mx.zeros_like(output), output)
    elif padding_mode == 'border':
        output = output
    elif padding_mode == 'reflection':
        grid_x = mx.abs(mx.clip(grid_x, -1, 1))
        grid_y = mx.abs(mx.clip(grid_y, -1, 1))
    else:
        raise ValueError(f"Unsupported padding mode: {padding_mode}")

    return output
        

def ms_deform_attn_core_mlx(value, value_spatial_shapes, sampling_locations, attention_weights):
    # for debug and test only,
    # need to use cuda version instead
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape

    level_indices = [H_ * W_ for H_, W_ in value_spatial_shapes]
    split_indices = []
    prev = 0
    for i in range(len(level_indices)):
        split_indices.append(prev + level_indices[i])
        prev = split_indices[-1]
    split_indices = split_indices[:-1]
    value_list = mx.split(value, split_indices, axis=1)

    sampling_grids = 2 * sampling_locations - 1

    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, M_, D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = mx.reshape(mx.transpose(mx.reshape(value_list[lid_], (N_, H_ * W_, M_, D_)), (0, 2, 3, 1)), (N_ * M_, D_, H_, W_))
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = mx.reshape(mx.transpose(sampling_grids[:, :, :, lid_, :], (0, 2, 1, 3, 4)), (N_ * M_, Lq_, P_, 2))
        # N_*M_, D_, Lq_, P_
        #print(value_l_.shape, sampling_grid_l_.shape)
        # sampling_grid_l_ = mx.stop_gradient(sampling_grid_l_)  # adding this line avoids the error but results in zero gradient for the sampling_offsets layer
        sampling_value_l_ = grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros')
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = mx.reshape(mx.transpose(attention_weights, (0, 2, 1, 3, 4)), (N_ * M_, 1, Lq_, L_ * P_))
    output = mx.sum(mx.reshape(mx.stack(sampling_value_list, axis=-2), (N_ * M_, D_, Lq_, L_ * P_)) * attention_weights, axis=-1).reshape((N_, M_ * D_, Lq_))
    return mx.transpose(output, (0, 2, 1))

class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}')
        _d_per_head = d_model // n_heads
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                          "which is more efficient in our implementation.")

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.constant(self.sampling_offsets.weight, 0.0)
        thetas = mx.arange(self.n_heads) * (2.0 * math.pi / self.n_heads)
        grid_init = mx.stack([mx.cos(thetas), mx.sin(thetas)], axis=-1)
        grid_init = (grid_init / grid_init.abs().max(axis=-1, keepdims=True)[0]).reshape(self.n_heads, 1, 1, 2)
        grid_init = mx.tile(grid_init, (1, self.n_levels, self.n_points, 1))
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        self.sampling_offsets.bias = mx.array(grid_init.reshape(-1))
        nn.init.constant(self.attention_weights.weight, 0.0)
        nn.init.constant(self.attention_weights.bias, 0.0)
        nn.init.uniform(self.value_proj.weight)
        nn.init.constant(self.value_proj.bias, 0.0)
        nn.init.uniform(self.output_proj.weight)
        nn.init.constant(self.output_proj.bias, 0.0)

    def __call__(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape

        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value * (1 - input_padding_mask[..., None].astype(value.dtype))
        value = value.reshape(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).reshape(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)

        attention_weights = self.attention_weights(query).reshape(N, Len_q, self.n_heads, self.n_levels * self.n_points)

        attention_weights = nn.softmax(attention_weights, axis=-1).reshape(N, Len_q, self.n_heads, self.n_levels, self.n_points)

        if reference_points.shape[-1] == 2:
            input_spatial_shapes_mx = mx.array(input_spatial_shapes)
            offset_normalizer = mx.stack([input_spatial_shapes_mx[:, 1], input_spatial_shapes_mx[:, 0]], axis=-1)
            sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]} instead.')

        output = ms_deform_attn_core_mlx(value, input_spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        return output

@barronalex
Collaborator

I think the stop_gradient call needs to be before you use sampling_locations to compute the output. Maybe try changing the last few lines above to:

        sampling_locations = mx.stop_gradient(sampling_locations)
        output = ms_deform_attn_core_mlx(value, input_spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        return output

Let me know if you still see the error.

@sachinraja13
Author

I tried stopping the gradient at sampling_grid_l_ inside the ms_deform_attn_core_mlx function. That avoided the error; however, it stopped the sampling_offsets layer from being updated because its gradient was zero. I'm not sure how to make this work, since I believe the gradient has to flow through the grid for the weights of the sampling_offsets layer to be updated. Please correct me if I'm wrong.

@barronalex
Collaborator

Very sorry, you're completely right: you do need the gradient with respect to the grid.

It looks like PyTorch has a custom backward pass implementation for grid_sample that gets around the issue.

The Scenic Deformable DETR implementation in JAX does seem to rely on differentiating w.r.t. indices through a gather operation though, so maybe it is something we should support. I can take a closer look at that later today.
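
For reference, a quick check (just an illustration, not from either implementation) that PyTorch's grid_sample does expose a gradient with respect to the grid:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 4, 4, requires_grad=True)
grid = torch.zeros(1, 3, 3, 2, requires_grad=True)
out = F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
out.sum().backward()
print(grid.grad.shape)  # torch.Size([1, 3, 3, 2]): gradient w.r.t. the sampling grid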

@sachinraja13
Author

sachinraja13 commented Jun 18, 2024

Hi @barronalex : Many thanks for your help in this regard. I will be waiting to hear from you.

@sachinraja13
Author

Hi @barronalex : Do you have any insights on this?

@petertsoi

+1 I believe this might also be necessary for grid-based NeRF representations as described in InstantNGP.

@barronalex
Collaborator

Sorry for the delay! To fix this, I'm adding zero VJPs to gather w.r.t. indices and the bitwise ops in #1256

With that, I'm able to get an MLX grid_sample to match PyTorch's outputs and gradients (at least for mode="bilinear", padding_mode="zeros"):

import mlx.core as mx 
import numpy as np

import torch
import torch.nn.functional as F

def grid_sample_mx(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False):
    if align_corners:
        raise NotImplementedError("`align_corners=True` not yet implemented.")

    if padding_mode != 'zeros':
        raise NotImplementedError(f"padding_mode={padding_mode} not yet implemented.")

    if mode != "bilinear":
        raise NotImplementedError(f"mode={mode} not yet implemented.")

    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape
    # Map the grid from [-1, 1] to pixel coordinates (align_corners=False convention)
    ix = ((grid[..., 0] + 1) * W_in - 1) / 2
    iy = ((grid[..., 1] + 1) * H_in - 1) / 2
 
    if mode == 'bilinear':
        ix_nw = mx.floor(ix).astype(mx.int32)
        iy_nw = mx.floor(iy).astype(mx.int32)

        ix_ne = ix_nw + 1
        iy_ne = iy_nw

        ix_sw = ix_nw
        iy_sw = iy_nw + 1

        ix_se = ix_nw + 1
        iy_se = iy_nw + 1

        nw = (ix_se - ix)    * (iy_se - iy)
        ne = (ix    - ix_sw) * (iy_sw - iy)
        sw = (ix_ne - ix)    * (iy    - iy_ne)
        se = (ix    - ix_nw) * (iy    - iy_nw)

        I_nw = input[mx.arange(N)[:, None, None], :, iy_nw, ix_nw]
        I_ne = input[mx.arange(N)[:, None, None], :, iy_ne, ix_ne]
        I_sw = input[mx.arange(N)[:, None, None], :, iy_sw, ix_sw]
        I_se = input[mx.arange(N)[:, None, None], :, iy_se, ix_se]

        mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
        mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
        mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
        mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

        I_nw *= mask_nw[..., None]
        I_ne *= mask_ne[..., None]
        I_sw *= mask_sw[..., None]
        I_se *= mask_se[..., None]

        output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

    return output


np.random.seed(7)
n, m, gn, gm = 32, 16, 8, 4
x = np.random.normal(size=(3, 7, n, m)).astype(np.float32)
grid = np.random.uniform(-1.5, 1, size=(3, gn, gm, 2)).astype(np.float32)

x_t = torch.tensor(x, requires_grad=True)
grid_t = torch.tensor(grid, requires_grad=True)
out_t = F.grid_sample(x_t, grid_t, mode="bilinear", padding_mode="zeros")
loss_t = out_t.sum()
loss_t.backward()
out_t = out_t.cpu().detach().numpy().transpose((0, 2, 3, 1))

x = mx.array(x)
grid = mx.array(grid)

def grid_sample(x, grid):
    return grid_sample_mx(x, grid, mode="bilinear", padding_mode="zeros").sum()

loss, (grad_x, grad_grid) = mx.value_and_grad(grid_sample, argnums=(0, 1))(x, grid)

np.testing.assert_allclose(x_t.grad.cpu().numpy(), grad_x, atol=1e-4)
np.testing.assert_allclose(grid_t.grad.cpu().numpy(), grad_grid, atol=1e-4)
np.testing.assert_allclose(loss_t.detach().cpu().numpy(), loss, atol=1e-4)
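
(The per-corner masks zero out contributions from out-of-bounds corners, which is how padding_mode="zeros" is emulated while the interpolation weights, and hence the gradient with respect to the grid, are computed from the unclipped coordinates.)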

@barronalex
Collaborator

If this is a big performance bottleneck then we could also consider adding a custom GPU op some time in the future but hopefully this unblocks you for now.

@sachinraja13
Author

Hi @barronalex : Many thanks for your help. Hoping to see a custom VJP for gpu in the future. But of course this unblocks me for now. Thanks again!
