Update to support torch2onnx for DETR series models (#10910)
RunningLeon committed Sep 12, 2023
1 parent 82d2a6e commit dece858
Showing 3 changed files with 121 additions and 62 deletions.
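The common thread across all three files: when tracing for ONNX (or when no image in the batch is padded), the padding mask is never materialized, so every consumer of the mask must now tolerate `None`. A minimal, self-contained sketch of that gating pattern — the module and tensor names below are illustrative, not from the commit:

import torch

class MaskFreeBranch(torch.nn.Module):
    # sketch of the commit's pattern: branch on is_in_onnx_export()
    # so no mask tensor ever enters the traced graph
    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        if torch.onnx.is_in_onnx_export():
            return feat  # export path: build no padding mask at all
        # eager path with no padding: an all-zero mask is a no-op anyway
        mask = feat.new_zeros(feat.shape[0], *feat.shape[2:])
        return feat * (1 - mask).unsqueeze(1)

out = MaskFreeBranch()(torch.rand(1, 3, 8, 8))  # eager call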
109 changes: 70 additions & 39 deletions mmdet/models/detectors/deformable_detr.py
@@ -151,22 +151,37 @@ def pre_transformer(
        # construct binary masks for the transformer.
        assert batch_data_samples is not None
        batch_input_shape = batch_data_samples[0].batch_input_shape
        img_shape_list = [sample.img_shape for sample in batch_data_samples]
        input_img_h, input_img_w = batch_input_shape
        masks = mlvl_feats[0].new_ones((batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w = img_shape_list[img_id]
            masks[img_id, :img_h, :img_w] = 0
        # NOTE following the official DETR repo, non-zero values representing
        # ignored positions, while zero values means valid positions.

        mlvl_masks = []
        mlvl_pos_embeds = []
        for feat in mlvl_feats:
            mlvl_masks.append(
                F.interpolate(masks[None],
                              size=feat.shape[-2:]).to(torch.bool).squeeze(0))
            mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1]))
        img_shape_list = [sample.img_shape for sample in batch_data_samples]
        same_shape_flag = all([
            s[0] == input_img_h and s[1] == input_img_w for s in img_shape_list
        ])
        # support torch2onnx without feeding masks
        if torch.onnx.is_in_onnx_export() or same_shape_flag:
            mlvl_masks = []
            mlvl_pos_embeds = []
            for feat in mlvl_feats:
                mlvl_masks.append(None)
                mlvl_pos_embeds.append(
                    self.positional_encoding(None, input=feat))
        else:
            masks = mlvl_feats[0].new_ones(
                (batch_size, input_img_h, input_img_w))
            for img_id in range(batch_size):
                img_h, img_w = img_shape_list[img_id]
                masks[img_id, :img_h, :img_w] = 0
            # NOTE following the official DETR repo, non-zero
            # values representing ignored positions, while
            # zero values means valid positions.

            mlvl_masks = []
            mlvl_pos_embeds = []
            for feat in mlvl_feats:
                mlvl_masks.append(
                    F.interpolate(masks[None], size=feat.shape[-2:]).to(
                        torch.bool).squeeze(0))
                mlvl_pos_embeds.append(
                    self.positional_encoding(mlvl_masks[-1]))

        feat_flatten = []
        lvl_pos_embed_flatten = []
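Why the shortcut is safe: when every `img_shape` equals the padded `batch_input_shape`, the mask built in the `else` branch is all zeros (all positions valid), so the interpolated per-level masks carry no information. A quick check of that reasoning (illustrative snippet, not part of the commit):

import torch
import torch.nn.functional as F

input_img_h = input_img_w = 64
img_shape_list = [(64, 64), (64, 64)]  # no image is padded
same_shape_flag = all(
    s[0] == input_img_h and s[1] == input_img_w for s in img_shape_list)
assert same_shape_flag

masks = torch.ones(2, input_img_h, input_img_w)
for img_id, (img_h, img_w) in enumerate(img_shape_list):
    masks[img_id, :img_h, :img_w] = 0
lvl_mask = F.interpolate(masks[None], size=(8, 8)).to(torch.bool).squeeze(0)
assert not lvl_mask.any()  # nothing is ignored, so None is equivalent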
@@ -175,13 +190,14 @@ def pre_transformer(
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            batch_size, c, h, w = feat.shape
            spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
            # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
            feat = feat.view(batch_size, c, -1).permute(0, 2, 1)
            pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl]
            mask = mask.flatten(1)
            spatial_shape = (h, w)
            if mask is not None:
                mask = mask.flatten(1)

            feat_flatten.append(feat)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
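`torch._shape_as_tensor(feat)[2:]` keeps each level's (h, w) as a tensor, so under `torch.onnx.export` tracing it remains a graph value instead of being burned in as Python-int constants. Note the leading underscore — it is a private torch helper, used here exactly as the commit does. A small eager-mode illustration of the two forms:

import torch

feat = torch.rand(1, 256, 32, 48)
static_hw = (feat.shape[2], feat.shape[3])     # python ints: constants when traced
dynamic_hw = torch._shape_as_tensor(feat)[2:]  # tensor([32, 48]): stays symbolic
print(static_hw, dynamic_hw)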
@@ -192,17 +208,22 @@ def pre_transformer(
        feat_flatten = torch.cat(feat_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
        mask_flatten = torch.cat(mask_flatten, 1)
        if mask_flatten[0] is not None:
            mask_flatten = torch.cat(mask_flatten, 1)
        else:
            mask_flatten = None

        spatial_shapes = torch.as_tensor(  # (num_level, 2)
            spatial_shapes,
            dtype=torch.long,
            device=feat_flatten.device)
        # (num_level, 2)
        spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
        level_start_index = torch.cat((
            spatial_shapes.new_zeros((1, )),  # (num_level)
            spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack(  # (bs, num_level, 2)
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)
        if mlvl_masks[0] is not None:
            valid_ratios = torch.stack(  # (bs, num_level, 2)
                [self.get_valid_ratio(m) for m in mlvl_masks], 1)
        else:
            valid_ratios = mlvl_feats[0].new_ones(batch_size, len(mlvl_feats),
                                                  2)

        encoder_inputs_dict = dict(
            feat=feat_flatten,
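The `valid_ratios` fallback relies on `get_valid_ratio` measuring the unpadded fraction of each level: with no mask there is no padding, so every ratio is exactly 1 and `new_ones(batch_size, len(mlvl_feats), 2)` is the cheap equivalent. A sketch that mirrors (rather than imports) mmdet's helper for the all-valid case:

import torch

mask = torch.zeros(2, 16, 24, dtype=torch.bool)  # False = valid everywhere
valid_h = (~mask[:, :, 0]).sum(1).float() / mask.shape[1]
valid_w = (~mask[:, 0, :]).sum(1).float() / mask.shape[2]
ratio = torch.stack([valid_w, valid_h], -1)  # shape (2, 2)
assert torch.equal(ratio, torch.ones(2, 2))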
@@ -465,39 +486,49 @@ def gen_encoder_output_proposals(
        bs = memory.size(0)
        proposals = []
        _cur = 0  # start index in the sequence of the current level
        for lvl, (H, W) in enumerate(spatial_shapes):
            mask_flatten_ = memory_mask[:,
                                        _cur:(_cur + H * W)].view(bs, H, W, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1)

        for lvl, HW in enumerate(spatial_shapes):
            H, W = HW

            if memory_mask is not None:
                mask_flatten_ = memory_mask[:, _cur:(_cur + H * W)].view(
                    bs, H, W, 1)
                valid_H = torch.sum(~mask_flatten_[:, :, 0, 0],
                                    1).unsqueeze(-1)
                valid_W = torch.sum(~mask_flatten_[:, 0, :, 0],
                                    1).unsqueeze(-1)
                scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
            else:
                if not isinstance(HW, torch.Tensor):
                    HW = memory.new_tensor(HW)
                scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(bs, 1, 1, 2)
            grid_y, grid_x = torch.meshgrid(
                torch.linspace(
                    0, H - 1, H, dtype=torch.float32, device=memory.device),
                torch.linspace(
                    0, W - 1, W, dtype=torch.float32, device=memory.device))
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
            proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
            proposals.append(proposal)
            _cur += (H * W)
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) &
                                  (output_proposals < 0.99)).all(
                                      -1, keepdim=True)
        # do not use `all` to make it exportable to onnx
        output_proposals_valid = (
            (output_proposals > 0.01) & (output_proposals < 0.99)).sum(
                -1, keepdim=True) == output_proposals.shape[-1]
        # inverse_sigmoid
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(
            memory_mask.unsqueeze(-1), float('inf'))
        if memory_mask is not None:
            output_proposals = output_proposals.masked_fill(
                memory_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(
            ~output_proposals_valid, float('inf'))

        output_memory = memory
        output_memory = output_memory.masked_fill(
            memory_mask.unsqueeze(-1), float(0))
        if memory_mask is not None:
            output_memory = output_memory.masked_fill(
                memory_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid,
                                                  float(0))
        output_memory = self.memory_trans_fc(output_memory)
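Two details in `gen_encoder_output_proposals` worth noting. First, the maskless `scale`: `HW.unsqueeze(0).flip(dims=[0, 1])` turns `[H, W]` into `[[W, H]]`, i.e. the whole feature map counts as valid. Second, per the in-code comment, the boolean reduction `all` is replaced by a `sum`-and-compare so the graph stays exportable; the two are equivalent, as this illustrative check shows:

import torch

x = torch.rand(2, 5, 4)
in_range = (x > 0.01) & (x < 0.99)
valid_all = in_range.all(-1, keepdim=True)
valid_sum = in_range.sum(-1, keepdim=True) == x.shape[-1]
assert torch.equal(valid_all, valid_sum)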
37 changes: 23 additions & 14 deletions mmdet/models/detectors/detr.py
@@ -83,27 +83,36 @@ def pre_transformer(
        # construct binary masks which for the transformer.
        assert batch_data_samples is not None
        batch_input_shape = batch_data_samples[0].batch_input_shape
        img_shape_list = [sample.img_shape for sample in batch_data_samples]

        input_img_h, input_img_w = batch_input_shape
        masks = feat.new_ones((batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w = img_shape_list[img_id]
            masks[img_id, :img_h, :img_w] = 0
        # NOTE following the official DETR repo, non-zero values represent
        # ignored positions, while zero values mean valid positions.

        masks = F.interpolate(
            masks.unsqueeze(1), size=feat.shape[-2:]).to(torch.bool).squeeze(1)
        # [batch_size, embed_dim, h, w]
        pos_embed = self.positional_encoding(masks)
        img_shape_list = [sample.img_shape for sample in batch_data_samples]
        same_shape_flag = all([
            s[0] == input_img_h and s[1] == input_img_w for s in img_shape_list
        ])
        if torch.onnx.is_in_onnx_export() or same_shape_flag:
            masks = None
            # [batch_size, embed_dim, h, w]
            pos_embed = self.positional_encoding(masks, input=feat)
        else:
            masks = feat.new_ones((batch_size, input_img_h, input_img_w))
            for img_id in range(batch_size):
                img_h, img_w = img_shape_list[img_id]
                masks[img_id, :img_h, :img_w] = 0
            # NOTE following the official DETR repo, non-zero values represent
            # ignored positions, while zero values mean valid positions.

            masks = F.interpolate(
                masks.unsqueeze(1),
                size=feat.shape[-2:]).to(torch.bool).squeeze(1)
            # [batch_size, embed_dim, h, w]
            pos_embed = self.positional_encoding(masks)

        # use `view` instead of `flatten` for dynamically exporting to ONNX
        # [bs, c, h, w] -> [bs, h*w, c]
        feat = feat.view(batch_size, feat_dim, -1).permute(0, 2, 1)
        pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(0, 2, 1)
        # [bs, h, w] -> [bs, h*w]
        masks = masks.view(batch_size, -1)
        if masks is not None:
            masks = masks.view(batch_size, -1)

        # prepare transformer_inputs_dict
        encoder_inputs_dict = dict(
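The single-level DETR path applies the same gating with one mask and one positional embedding. The `view`-based flattening that the in-diff comment ties to dynamic ONNX export reshapes [bs, c, h, w] into the [bs, h*w, c] token layout the transformer expects; a quick illustration with arbitrary shapes:

import torch

bs, c, h, w = 2, 256, 25, 34
feat = torch.rand(bs, c, h, w)
tokens = feat.view(bs, c, -1).permute(0, 2, 1)  # [bs, c, h, w] -> [bs, h*w, c]
assert tokens.shape == (bs, h * w, c)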
37 changes: 28 additions & 9 deletions mmdet/models/layers/positional_encoding.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional

import torch
import torch.nn as nn
@@ -56,36 +57,54 @@ def __init__(self,
        self.eps = eps
        self.offset = offset

    def forward(self, mask: Tensor) -> Tensor:
    def forward(self, mask: Tensor, input: Optional[Tensor] = None) -> Tensor:
        """Forward function for `SinePositionalEncoding`.
        Args:
            mask (Tensor): ByteTensor mask. Non-zero values representing
                ignored positions, while zero values means valid positions
                for this image. Shape [bs, h, w].
            input (Tensor, optional): Input image/feature Tensor.
                Shape [bs, c, h, w]
        Returns:
            pos (Tensor): Returned position embedding with shape
                [bs, num_feats*2, h, w].
        """
        # For convenience of exporting to ONNX, it's required to convert
        # `masks` from bool to int.
        mask = mask.to(torch.int)
        not_mask = 1 - mask  # logical_not
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        assert not (mask is None and input is None)

        if mask is not None:
            B, H, W = mask.size()
            device = mask.device
            # For convenience of exporting to ONNX,
            # it's required to convert
            # `masks` from bool to int.
            mask = mask.to(torch.int)
            not_mask = 1 - mask  # logical_not
            y_embed = not_mask.cumsum(1, dtype=torch.float32)
            x_embed = not_mask.cumsum(2, dtype=torch.float32)
        else:
            # single image or batch image with no padding
            B, _, H, W = input.shape
            device = input.device
            x_embed = torch.arange(
                1, W + 1, dtype=torch.float32, device=device)
            x_embed = x_embed.view(1, 1, -1).repeat(B, H, 1)
            y_embed = torch.arange(
                1, H + 1, dtype=torch.float32, device=device)
            y_embed = y_embed.view(1, -1, 1).repeat(B, 1, W)
        if self.normalize:
            y_embed = (y_embed + self.offset) / \
                      (y_embed[:, -1:, :] + self.eps) * self.scale
            x_embed = (x_embed + self.offset) / \
                      (x_embed[:, :, -1:] + self.eps) * self.scale
        dim_t = torch.arange(
            self.num_feats, dtype=torch.float32, device=mask.device)
            self.num_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # use `view` instead of `flatten` for dynamically exporting to ONNX
        B, H, W = mask.size()

        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).view(B, H, W, -1)
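The two branches of `forward` agree whenever nothing is padded: cumulative sums over an all-valid (zero) mask count 1..H down the rows and 1..W along the columns, which is exactly what the new `input`-only branch builds with `arange`. An illustrative equivalence check:

import torch

B, H, W = 2, 4, 5
not_mask = torch.ones(B, H, W)
y_cumsum = not_mask.cumsum(1, dtype=torch.float32)
x_cumsum = not_mask.cumsum(2, dtype=torch.float32)
x_arange = torch.arange(
    1, W + 1, dtype=torch.float32).view(1, 1, -1).repeat(B, H, 1)
y_arange = torch.arange(
    1, H + 1, dtype=torch.float32).view(1, -1, 1).repeat(B, 1, W)
assert torch.equal(x_cumsum, x_arange) and torch.equal(y_cumsum, y_arange)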
