diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index 6e82e84fed..ed32688af4 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -102,27 +102,6 @@ def __init__(self,
 
         self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                           **kwargs)
-        if self.batch_first:
-
-            def _bnc_to_nbc(forward):
-                """Because the dataflow('key', 'query', 'value') of
-                ``torch.nn.MultiheadAttention`` is (num_query, batch,
-                embed_dims), We should adjust the shape of dataflow from
-                batch_first (batch, num_query, embed_dims) to num_query_first
-                (num_query ,batch, embed_dims), and recover ``attn_output``
-                from num_query_first to batch_first."""
-
-                def forward_wrapper(**kwargs):
-                    convert_keys = ('key', 'query', 'value')
-                    for key in kwargs.keys():
-                        if key in convert_keys:
-                            kwargs[key] = kwargs[key].transpose(0, 1)
-                    attn_output, attn_output_weights = forward(**kwargs)
-                    return attn_output.transpose(0, 1), attn_output_weights
-
-                return forward_wrapper
-
-            self.attn.forward = _bnc_to_nbc(self.attn.forward)
 
         self.proj_drop = nn.Dropout(proj_drop)
         self.dropout_layer = build_dropout(
@@ -199,6 +178,17 @@ def forward(self,
             if key_pos is not None:
                 key = key + key_pos
 
+        # Because the dataflow('key', 'query', 'value') of
+        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+        # embed_dims), We should adjust the shape of dataflow from
+        # batch_first (batch, num_query, embed_dims) to num_query_first
+        # (num_query ,batch, embed_dims), and recover ``attn_output``
+        # from num_query_first to batch_first.
+        if self.batch_first:
+            query = query.transpose(0, 1)
+            key = key.transpose(0, 1)
+            value = value.transpose(0, 1)
+
         out = self.attn(
             query=query,
             key=key,
@@ -206,6 +196,9 @@ def forward(self,
             attn_mask=attn_mask,
             key_padding_mask=key_padding_mask)[0]
 
+        if self.batch_first:
+            out = out.transpose(0, 1)
+
         return identity + self.dropout_layer(self.proj_drop(out))
 
 
diff --git a/tests/test_cnn/test_transformer.py b/tests/test_cnn/test_transformer.py
index a4a5f62e9c..106753b423 100644
--- a/tests/test_cnn/test_transformer.py
+++ b/tests/test_cnn/test_transformer.py
@@ -1,3 +1,5 @@
+import copy
+
 import pytest
 import torch
 
@@ -5,6 +7,7 @@
 from mmcv.cnn.bricks.transformer import (FFN, BaseTransformerLayer,
                                          MultiheadAttention,
                                          TransformerLayerSequence)
+from mmcv.runner import ModuleList
 
 
 def test_multiheadattention():
@@ -92,6 +95,28 @@ def test_ffn():
         ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
+def test_basetransformerlayer_cuda():
+    # To test if the BaseTransformerLayer's behaviour remains
+    # consistent after being deepcopied
+    operation_order = ('self_attn', 'ffn')
+    baselayer = BaseTransformerLayer(
+        operation_order=operation_order,
+        batch_first=True,
+        attn_cfgs=dict(
+            type='MultiheadAttention',
+            embed_dims=256,
+            num_heads=8,
+        ),
+    )
+    baselayers = ModuleList([copy.deepcopy(baselayer) for _ in range(2)])
+    baselayers.to('cuda')
+    x = torch.rand(2, 10, 256).cuda()
+    for m in baselayers:
+        x = m(x)
+    assert x.shape == torch.Size([2, 10, 256])
+
+
 def test_basetransformerlayer():
     attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
     feedforward_channels = 2048
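
For reference, below is a minimal, self-contained sketch (plain PyTorch, not the mmcv module itself) of the batch_first handling this patch moves into MultiheadAttention.forward(). The removed _bnc_to_nbc wrapper monkey-patched the bound self.attn.forward, which copy.deepcopy does not handle cleanly, the behaviour the new test_basetransformerlayer_cuda test guards against; transposing explicitly inside forward() avoids that. The BatchFirstMHA class name and the 2x10x256 shapes are illustrative assumptions, not part of the patch.

# Sketch only: the transpose-based batch_first handling adopted by the patch.
# BatchFirstMHA is a hypothetical stand-in for mmcv's MultiheadAttention.
import copy

import torch
import torch.nn as nn


class BatchFirstMHA(nn.Module):

    def __init__(self, embed_dims, num_heads, batch_first=True):
        super().__init__()
        self.batch_first = batch_first
        self.attn = nn.MultiheadAttention(embed_dims, num_heads)

    def forward(self, query, key=None, value=None):
        key = query if key is None else key
        value = key if value is None else value
        if self.batch_first:
            # (batch, num_query, embed_dims) -> (num_query, batch, embed_dims),
            # the layout torch.nn.MultiheadAttention expects
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)
        out = self.attn(query, key, value)[0]
        if self.batch_first:
            # recover the batch_first layout for the caller
            out = out.transpose(0, 1)
        return out


layer = BatchFirstMHA(embed_dims=256, num_heads=8)
# deepcopy is safe here because no bound forward is overridden on a submodule
layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(2)])
x = torch.rand(2, 10, 256)  # (batch, num_query, embed_dims)
for m in layers:
    x = m(x)
assert x.shape == torch.Size([2, 10, 256])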