diff --git a/mmdet/models/layers/transformer/deformable_detr_layers.py b/mmdet/models/layers/transformer/deformable_detr_layers.py
index f337e7fd01b..e2d32388d6a 100644
--- a/mmdet/models/layers/transformer/deformable_detr_layers.py
+++ b/mmdet/models/layers/transformer/deformable_detr_layers.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
 from typing import Optional, Tuple, Union
 
 import torch
@@ -12,6 +13,11 @@
                           DetrTransformerEncoder, DetrTransformerEncoderLayer)
 from .utils import inverse_sigmoid
 
+try:
+    from fairscale.nn.checkpoint import checkpoint_wrapper
+except Exception:
+    checkpoint_wrapper = None
+
 
 class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
     """Transformer encoder of Deformable DETR."""
@@ -22,6 +28,16 @@ def _init_layers(self) -> None:
             DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
             for _ in range(self.num_layers)
         ])
+
+        if self.num_cp > 0:
+            if checkpoint_wrapper is None:
+                warnings.warn('If you want to reduce GPU memory usage, '
+                              'please install fairscale by executing the '
+                              'following command: pip install fairscale.')
+            else:
+                for i in range(self.num_cp):
+                    self.layers[i] = checkpoint_wrapper(self.layers[i])
+
         self.embed_dims = self.layers[0].embed_dims
 
     def forward(self, query: Tensor, query_pos: Tensor,
diff --git a/mmdet/models/layers/transformer/detr_layers.py b/mmdet/models/layers/transformer/detr_layers.py
index 43c2ffdb631..928b07ce2df 100644
--- a/mmdet/models/layers/transformer/detr_layers.py
+++ b/mmdet/models/layers/transformer/detr_layers.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
 from typing import Union
 
 import torch
@@ -10,6 +11,11 @@
 
 from mmdet.utils import ConfigType, OptConfigType
 
+try:
+    from fairscale.nn.checkpoint import checkpoint_wrapper
+except Exception:
+    checkpoint_wrapper = None
+
 
 class DetrTransformerEncoder(BaseModule):
     """Encoder of DETR.
@@ -18,6 +24,8 @@ class DetrTransformerEncoder(BaseModule):
         num_layers (int): Number of encoder layers.
         layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
             layer. All the layers will share the same config.
+        num_cp (int): Number of encoder layers for which to enable
+            activation checkpointing. Defaults to -1 (disabled).
         init_cfg (:obj:`ConfigDict` or dict, optional): the config
             to control the initialization. Defaults to None.
     """
@@ -25,11 +33,14 @@ class DetrTransformerEncoder(BaseModule):
     def __init__(self,
                  num_layers: int,
                  layer_cfg: ConfigType,
+                 num_cp: int = -1,
                  init_cfg: OptConfigType = None) -> None:
 
         super().__init__(init_cfg=init_cfg)
         self.num_layers = num_layers
         self.layer_cfg = layer_cfg
+        self.num_cp = num_cp
+        assert self.num_cp <= self.num_layers
         self._init_layers()
 
     def _init_layers(self) -> None:
@@ -38,6 +49,16 @@ def _init_layers(self) -> None:
             DetrTransformerEncoderLayer(**self.layer_cfg)
             for _ in range(self.num_layers)
         ])
+
+        if self.num_cp > 0:
+            if checkpoint_wrapper is None:
+                warnings.warn('If you want to reduce GPU memory usage, '
+                              'please install fairscale by executing the '
+                              'following command: pip install fairscale.')
+            else:
+                for i in range(self.num_cp):
+                    self.layers[i] = checkpoint_wrapper(self.layers[i])
+
         self.embed_dims = self.layers[0].embed_dims
 
     def forward(self, query: Tensor, query_pos: Tensor,
diff --git a/mmdet/models/layers/transformer/grounding_dino_layers.py b/mmdet/models/layers/transformer/grounding_dino_layers.py
index 04de47288b3..645384bd014 100644
--- a/mmdet/models/layers/transformer/grounding_dino_layers.py
+++ b/mmdet/models/layers/transformer/grounding_dino_layers.py
@@ -86,7 +86,7 @@ def forward(self,
                 Defaults to None.
             memory_text (Tensor): Memory text. It has shape (bs,
                 len_text, text_embed_dims).
-            text_token_mask (Tensor): Text token mask. It has shape (bs,
+            text_attention_mask (Tensor): Text token mask. It has shape (bs,
                 len_text).
 
         Returns: