Support checkpoint_wrapper (#10943)
hhaAndroid committed Sep 18, 2023
1 parent 75c2ada commit dfe7a57
Showing 3 changed files with 38 additions and 1 deletion.
16 changes: 16 additions & 0 deletions mmdet/models/layers/transformer/deformable_detr_layers.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
 from typing import Optional, Tuple, Union
 
 import torch
@@ -12,6 +13,11 @@
                           DetrTransformerEncoder, DetrTransformerEncoderLayer)
 from .utils import inverse_sigmoid
 
+try:
+    from fairscale.nn.checkpoint import checkpoint_wrapper
+except Exception:
+    checkpoint_wrapper = None
+
 
 class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
     """Transformer encoder of Deformable DETR."""
@@ -22,6 +28,16 @@ def _init_layers(self) -> None:
             DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
             for _ in range(self.num_layers)
         ])
+
+        if self.num_cp > 0:
+            if checkpoint_wrapper is None:
+                warnings.warn('If you want to reduce GPU memory usage, \
+                    please install fairscale by executing the \
+                    following command: pip install fairscale.')
+                return
+            for i in range(self.num_cp):
+                self.layers[i] = checkpoint_wrapper(self.layers[i])
+
         self.embed_dims = self.layers[0].embed_dims
 
     def forward(self, query: Tensor, query_pos: Tensor,
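For reference, here is a minimal, self-contained sketch (not part of this commit) of what the new branch in _init_layers does: it wraps the first num_cp layers of a toy ModuleList with fairscale's checkpoint_wrapper, so their activations are recomputed during the backward pass instead of being stored, trading compute for GPU memory. The Block module, layer count, and tensor sizes below are illustrative assumptions.

import torch
from torch import nn

try:
    from fairscale.nn.checkpoint import checkpoint_wrapper
except Exception:
    checkpoint_wrapper = None


class Block(nn.Module):
    """Stand-in for an encoder layer (illustrative only)."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.ReLU(), nn.Linear(dim * 4, dim))

    def forward(self, x):
        return x + self.ffn(x)


layers = nn.ModuleList([Block() for _ in range(6)])
num_cp = 4  # mirror the new option: checkpoint the first 4 of 6 layers
if checkpoint_wrapper is not None:
    for i in range(num_cp):
        layers[i] = checkpoint_wrapper(layers[i])

x = torch.randn(2, 100, 256, requires_grad=True)
for layer in layers:
    x = layer(x)
x.sum().backward()  # checkpointed layers recompute activations here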
21 changes: 21 additions & 0 deletions mmdet/models/layers/transformer/detr_layers.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
 from typing import Union
 
 import torch
@@ -10,6 +11,11 @@
 
 from mmdet.utils import ConfigType, OptConfigType
 
+try:
+    from fairscale.nn.checkpoint import checkpoint_wrapper
+except Exception:
+    checkpoint_wrapper = None
+
 
 class DetrTransformerEncoder(BaseModule):
     """Encoder of DETR.
@@ -18,18 +24,23 @@ class DetrTransformerEncoder(BaseModule):
         num_layers (int): Number of encoder layers.
         layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
             layer. All the layers will share the same config.
+        num_cp (int): Number of checkpointing blocks in encoder layer.
+            Default to -1.
         init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
             the initialization. Defaults to None.
     """
 
     def __init__(self,
                  num_layers: int,
                  layer_cfg: ConfigType,
+                 num_cp: int = -1,
                  init_cfg: OptConfigType = None) -> None:
 
         super().__init__(init_cfg=init_cfg)
         self.num_layers = num_layers
         self.layer_cfg = layer_cfg
+        self.num_cp = num_cp
+        assert self.num_cp <= self.num_layers
         self._init_layers()
 
     def _init_layers(self) -> None:
@@ -38,6 +49,16 @@ def _init_layers(self) -> None:
             DetrTransformerEncoderLayer(**self.layer_cfg)
             for _ in range(self.num_layers)
         ])
+
+        if self.num_cp > 0:
+            if checkpoint_wrapper is None:
+                warnings.warn('If you want to reduce GPU memory usage, \
+                    please install fairscale by executing the \
+                    following command: pip install fairscale.')
+                return
+            for i in range(self.num_cp):
+                self.layers[i] = checkpoint_wrapper(self.layers[i])
+
         self.embed_dims = self.layers[0].embed_dims
 
     def forward(self, query: Tensor, query_pos: Tensor,
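With the new num_cp argument, checkpointing can be requested wherever the encoder is configured. A usage sketch under assumptions: the layer_cfg keys below follow the stock DETR encoder config in MMDetection 3.x and are not part of this diff; num_cp is the field added here, and fairscale must be installed (pip install fairscale), otherwise only the warning above is emitted.

from mmdet.models.layers.transformer.detr_layers import DetrTransformerEncoder

# Checkpoint all 6 encoder layers; the default num_cp=-1 disables wrapping.
encoder = DetrTransformerEncoder(
    num_layers=6,
    num_cp=6,
    layer_cfg=dict(
        self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.1),
        ffn_cfg=dict(
            embed_dims=256,
            feedforward_channels=2048,
            num_fcs=2,
            ffn_drop=0.1)))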
2 changes: 1 addition & 1 deletion mmdet/models/layers/transformer/grounding_dino_layers.py
@@ -86,7 +86,7 @@ def forward(self,
                 Defaults to None.
             memory_text (Tensor): Memory text. It has shape (bs, len_text,
                 text_embed_dims).
-            text_token_mask (Tensor): Text token mask. It has shape (bs,
+            text_attention_mask (Tensor): Text token mask. It has shape (bs,
                 len_text).
 
         Returns:
