From 0b6a24da71e77f21e3af3f0de3e31a068b429c98 Mon Sep 17 00:00:00 2001 From: wenqingzhang Date: Fri, 14 Oct 2022 13:36:51 +0000 Subject: [PATCH 1/4] support modified ResNet in CLIP and oCLIP --- mmocr/models/common/backbones/__init__.py | 3 +- mmocr/models/common/backbones/clip_resnet.py | 85 +++++++++++++++++++ mmocr/models/common/plugins/__init__.py | 4 + mmocr/models/common/plugins/common.py | 39 +++++++++ .../test_backbones/test_clip_resnet.py | 66 ++++++++++++++ .../test_common/test_plugins/test_avgpool.py | 16 ++++ 6 files changed, 212 insertions(+), 1 deletion(-) create mode 100644 mmocr/models/common/backbones/clip_resnet.py create mode 100644 mmocr/models/common/plugins/__init__.py create mode 100644 mmocr/models/common/plugins/common.py create mode 100644 tests/test_models/test_common/test_backbones/test_clip_resnet.py create mode 100644 tests/test_models/test_common/test_plugins/test_avgpool.py diff --git a/mmocr/models/common/backbones/__init__.py b/mmocr/models/common/backbones/__init__.py index 3c384ba30..7e0c98ef7 100644 --- a/mmocr/models/common/backbones/__init__.py +++ b/mmocr/models/common/backbones/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .clip_resnet import CLIPResNet from .unet import UNet -__all__ = ['UNet'] +__all__ = ['UNet', 'CLIPResNet'] diff --git a/mmocr/models/common/backbones/clip_resnet.py b/mmocr/models/common/backbones/clip_resnet.py new file mode 100644 index 000000000..bd4865978 --- /dev/null +++ b/mmocr/models/common/backbones/clip_resnet.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch.nn as nn +from mmdet.models.backbones import ResNet +from mmdet.models.backbones.resnet import Bottleneck + +from mmocr.registry import MODELS + + +class CLIPBottleneck(Bottleneck): + """Bottleneck for CLIPResNet. + + It is a variant Bottleneck used in the variant ResNet of CLIP. After the + second convolution layer, there is an additional average pooling layer with + kernel_size 2 and stride 2, which is added in the plug-in manner when the + input stride > 1. The stride of each convolution layer is always set to 1. + """ + + def __init__(self, **kwargs): + stride = kwargs.get('stride', 1) + kwargs['stride'] = 1 + plugins = kwargs.get('plugins', None) + if stride > 1: + if plugins is None: + plugins = [] + + plugins.insert( + 0, + dict( + cfg=dict(type='mmocr.AvgPool2d', kernel_size=2), + position='after_conv2')) + kwargs['plugins'] = plugins + super().__init__(**kwargs) + + +@MODELS.register_module() +class CLIPResNet(ResNet): + """Implement the variant ResNet used in `oCLIP. + + `_ + + It is also the official structure in `CLIP + `_. + + Compared with ResNetV1d structure, CLIPResNet replaces the + max pooling layer with an average pooling layer at the end + of the input stem. + + In the Bottleneck of CLIPResNet, after the second convolution + layer, there is an additional average pooling layer with + kernel_size 2 and stride 2, which is added in the plug-in + manner when the input stride > 1. + The stride of each convolution layer is always set to 1. + """ + arch_settings = { + 50: (CLIPBottleneck, (3, 4, 6, 3)), + } + + def __init__(self, + depth=50, + strides=(1, 2, 2, 2), + deep_stem=True, + avg_down=True, + **kwargs): + super().__init__( + depth=depth, + strides=strides, + deep_stem=deep_stem, + avg_down=avg_down, + **kwargs) + + def _make_stem_layer(self, in_channels, stem_channels): + """Build stem layer for CLIPResNet used in `CLIP. + + `_. + It uses an average pooling layer rather than a max pooling + layer at the end of the input stem. + + Args: + in_channels (int): Number of input channels. + stem_channels (int): Number of output channels. + """ + super()._make_stem_layer(in_channels, stem_channels) + if self.deep_stem: + self.maxpool = nn.AvgPool2d(kernel_size=2) diff --git a/mmocr/models/common/plugins/__init__.py b/mmocr/models/common/plugins/__init__.py new file mode 100644 index 000000000..1ad4c93c0 --- /dev/null +++ b/mmocr/models/common/plugins/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .common import AvgPool2d + +__all__ = ['AvgPool2d'] diff --git a/mmocr/models/common/plugins/common.py b/mmocr/models/common/plugins/common.py new file mode 100644 index 000000000..eeae464c4 --- /dev/null +++ b/mmocr/models/common/plugins/common.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Union + +import torch +import torch.nn as nn + +from mmocr.registry import MODELS + + +@MODELS.register_module() +class AvgPool2d(nn.Module): + """Applies a 2D average pooling over an input signal composed of several + input planes. + + AvgPool2d class for plug-in manner usage + + Args: + kernel_size (int | tuple(int)): the size of the window. + stride (int | tuple(int)): the stride of the window. + padding (int | tuple(int)): implicit zero padding. + """ + + def __init__(self, + kernel_size: Union[int, Tuple[int]], + stride: Union[int, Tuple[int]] = None, + padding: Union[int, Tuple[int]] = 0, + **kwargs) -> None: + super().__init__() + self.model = nn.AvgPool2d(kernel_size, stride, padding) + + def forward(self, x) -> torch.Tensor: + """Forward function. + Args: + x (Tensor): Input feature map. + + Returns: + Tensor: Output tensor after Maxpooling layer. + """ + return self.model(x) diff --git a/tests/test_models/test_common/test_backbones/test_clip_resnet.py b/tests/test_models/test_common/test_backbones/test_clip_resnet.py new file mode 100644 index 000000000..12830817f --- /dev/null +++ b/tests/test_models/test_common/test_backbones/test_clip_resnet.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmocr.models.common.backbones import CLIPResNet +from mmocr.models.common.backbones.clip_resnet import CLIPBottleneck + + +class TestCLIPResNet(TestCase): + + def test_forward(self): + model = CLIPResNet() + model.eval() + + imgs = torch.randn(1, 3, 32, 32) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 256, 8, 8]) + assert feat[1].shape == torch.Size([1, 512, 4, 4]) + assert feat[2].shape == torch.Size([1, 1024, 2, 2]) + assert feat[3].shape == torch.Size([1, 2048, 1, 1]) + + +class TestCLIPBottleneck(TestCase): + + def test_forward(self): + stride = 1 + inplanes = 64 + planes = 64 + conv_cfg = None + norm_cfg = {'type': 'BN', 'requires_grad': True} + + downsample = [] + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * CLIPBottleneck.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(norm_cfg, planes * CLIPBottleneck.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + model = CLIPBottleneck( + inplanes=64, + planes=64, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + model.eval() + + input_feat = torch.randn(1, 64, 8, 8) + output_feat = model(input_feat) + assert output_feat.shape == torch.Size([1, 256, 8, 8]) diff --git a/tests/test_models/test_common/test_plugins/test_avgpool.py b/tests/test_models/test_common/test_plugins/test_avgpool.py new file mode 100644 index 000000000..766ddf5d6 --- /dev/null +++ b/tests/test_models/test_common/test_plugins/test_avgpool.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.common.plugins import AvgPool2d + + +class TestAvgPool2d(TestCase): + + def setUp(self) -> None: + self.img = torch.rand(1, 3, 32, 100) + + def test_avgpool2d(self): + avgpool2d = AvgPool2d(kernel_size=2, stride=2) + self.assertEqual(avgpool2d(self.img).shape, torch.Size([1, 3, 16, 50])) From e410ff37c64263396b7883eec88976b8791bbf1f Mon Sep 17 00:00:00 2001 From: wenqingzhang Date: Wed, 2 Nov 2022 14:12:05 +0000 Subject: [PATCH 2/4] update unit test for TestCLIPBottleneck; update docs --- mmocr/models/common/backbones/clip_resnet.py | 35 +++++++++++++------ mmocr/models/common/plugins/common.py | 15 ++++---- .../test_backbones/test_clip_resnet.py | 16 ++++----- 3 files changed, 41 insertions(+), 25 deletions(-) diff --git a/mmocr/models/common/backbones/clip_resnet.py b/mmocr/models/common/backbones/clip_resnet.py index bd4865978..82b36324a 100644 --- a/mmocr/models/common/backbones/clip_resnet.py +++ b/mmocr/models/common/backbones/clip_resnet.py @@ -10,10 +10,14 @@ class CLIPBottleneck(Bottleneck): """Bottleneck for CLIPResNet. - It is a variant Bottleneck used in the variant ResNet of CLIP. After the + It is a Bottleneck variant used in the ResNet variant of CLIP. After the second convolution layer, there is an additional average pooling layer with - kernel_size 2 and stride 2, which is added in the plug-in manner when the + kernel_size 2 and stride 2, which is added as a plugin when the input stride > 1. The stride of each convolution layer is always set to 1. + + Args: + **kwargs: Keyword arguments for + :class:``mmdet.models.backbones.resnet.Bottleneck``. """ def __init__(self, **kwargs): @@ -35,12 +39,12 @@ def __init__(self, **kwargs): @MODELS.register_module() class CLIPResNet(ResNet): - """Implement the variant ResNet used in `oCLIP. + """Implement the ResNet variant used in `oCLIP. - `_ + `_. - It is also the official structure in `CLIP - `_. + It is also the official structure in + `CLIP `_. Compared with ResNetV1d structure, CLIPResNet replaces the max pooling layer with an average pooling layer at the end @@ -48,9 +52,20 @@ class CLIPResNet(ResNet): In the Bottleneck of CLIPResNet, after the second convolution layer, there is an additional average pooling layer with - kernel_size 2 and stride 2, which is added in the plug-in - manner when the input stride > 1. + kernel_size 2 and stride 2, which is added as a plugin + when the input stride > 1. The stride of each convolution layer is always set to 1. + + Args: + depth (int): Depth of resnet, from {50}. Defaults to 50. + strides (sequence(int)): Strides of the first block of each stage. + Defaults to (1, 2, 2, 2). + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Defaults to True. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Defaults to True. + **kwargs: Keyword arguments for + :class:``mmdet.models.backbones.resnet.ResNet``. """ arch_settings = { 50: (CLIPBottleneck, (3, 4, 6, 3)), @@ -70,9 +85,9 @@ def __init__(self, **kwargs) def _make_stem_layer(self, in_channels, stem_channels): - """Build stem layer for CLIPResNet used in `CLIP. + """Build stem layer for CLIPResNet used in `CLIP + https://github.com/openai/CLIP>`_. - `_. It uses an average pooling layer rather than a max pooling layer at the end of the input stem. diff --git a/mmocr/models/common/plugins/common.py b/mmocr/models/common/plugins/common.py index eeae464c4..8c6ea19c0 100644 --- a/mmocr/models/common/plugins/common.py +++ b/mmocr/models/common/plugins/common.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -12,17 +12,18 @@ class AvgPool2d(nn.Module): """Applies a 2D average pooling over an input signal composed of several input planes. - AvgPool2d class for plug-in manner usage + It can also be used as a network plugin. Args: - kernel_size (int | tuple(int)): the size of the window. - stride (int | tuple(int)): the stride of the window. - padding (int | tuple(int)): implicit zero padding. + kernel_size (int or tuple(int)): the size of the window. + stride (int or tuple(int), optional): the stride of the window. + Defaults to None. + padding (int or tuple(int)): implicit zero padding. Defaults to 0. """ def __init__(self, kernel_size: Union[int, Tuple[int]], - stride: Union[int, Tuple[int]] = None, + stride: Optional[Union[int, Tuple[int]]] = None, padding: Union[int, Tuple[int]] = 0, **kwargs) -> None: super().__init__() @@ -34,6 +35,6 @@ def forward(self, x) -> torch.Tensor: x (Tensor): Input feature map. Returns: - Tensor: Output tensor after Maxpooling layer. + Tensor: Output tensor after Avgpooling layer. """ return self.model(x) diff --git a/tests/test_models/test_common/test_backbones/test_clip_resnet.py b/tests/test_models/test_common/test_backbones/test_clip_resnet.py index 12830817f..fd71395f8 100644 --- a/tests/test_models/test_common/test_backbones/test_clip_resnet.py +++ b/tests/test_models/test_common/test_backbones/test_clip_resnet.py @@ -27,9 +27,9 @@ def test_forward(self): class TestCLIPBottleneck(TestCase): def test_forward(self): - stride = 1 - inplanes = 64 - planes = 64 + stride = 2 + inplanes = 256 + planes = 128 conv_cfg = None norm_cfg = {'type': 'BN', 'requires_grad': True} @@ -46,21 +46,21 @@ def test_forward(self): inplanes, planes * CLIPBottleneck.expansion, kernel_size=1, - stride=stride, + stride=1, bias=False), build_norm_layer(norm_cfg, planes * CLIPBottleneck.expansion)[1] ]) downsample = nn.Sequential(*downsample) model = CLIPBottleneck( - inplanes=64, - planes=64, + inplanes=inplanes, + planes=planes, stride=stride, downsample=downsample, conv_cfg=conv_cfg, norm_cfg=norm_cfg) model.eval() - input_feat = torch.randn(1, 64, 8, 8) + input_feat = torch.randn(1, 256, 8, 8) output_feat = model(input_feat) - assert output_feat.shape == torch.Size([1, 256, 8, 8]) + assert output_feat.shape == torch.Size([1, 512, 4, 4]) From 39c89457515f047d064a2a1df06e10a863b792ab Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Thu, 3 Nov 2022 15:32:57 +0800 Subject: [PATCH 3/4] Apply suggestions from code review --- mmocr/models/common/backbones/clip_resnet.py | 11 +++++------ mmocr/models/common/plugins/common.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mmocr/models/common/backbones/clip_resnet.py b/mmocr/models/common/backbones/clip_resnet.py index 82b36324a..aa908344f 100644 --- a/mmocr/models/common/backbones/clip_resnet.py +++ b/mmocr/models/common/backbones/clip_resnet.py @@ -39,8 +39,7 @@ def __init__(self, **kwargs): @MODELS.register_module() class CLIPResNet(ResNet): - """Implement the ResNet variant used in `oCLIP. - + """Implement the ResNet variant used in `oCLIP `_. It is also the official structure in @@ -57,13 +56,13 @@ class CLIPResNet(ResNet): The stride of each convolution layer is always set to 1. Args: - depth (int): Depth of resnet, from {50}. Defaults to 50. + depth (int): Depth of resnet, options are [50]. Defaults to 50. strides (sequence(int)): Strides of the first block of each stage. Defaults to (1, 2, 2, 2). deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. Defaults to True. - avg_down (bool): Use AvgPool instead of stride conv when - downsampling in the bottleneck. Defaults to True. + avg_down (bool): Use AvgPool instead of stride conv at + the downsampling stage in the bottleneck. Defaults to True. **kwargs: Keyword arguments for :class:``mmdet.models.backbones.resnet.ResNet``. """ @@ -84,7 +83,7 @@ def __init__(self, avg_down=avg_down, **kwargs) - def _make_stem_layer(self, in_channels, stem_channels): + def _make_stem_layer(self, in_channels: int, stem_channels: int): """Build stem layer for CLIPResNet used in `CLIP https://github.com/openai/CLIP>`_. diff --git a/mmocr/models/common/plugins/common.py b/mmocr/models/common/plugins/common.py index 8c6ea19c0..722b53f56 100644 --- a/mmocr/models/common/plugins/common.py +++ b/mmocr/models/common/plugins/common.py @@ -29,7 +29,7 @@ def __init__(self, super().__init__() self.model = nn.AvgPool2d(kernel_size, stride, padding) - def forward(self, x) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward function. Args: x (Tensor): Input feature map. From 9a5293940c5a93a415668e462dc67c22ff421657 Mon Sep 17 00:00:00 2001 From: gaotongxiao Date: Thu, 3 Nov 2022 15:55:23 +0800 Subject: [PATCH 4/4] fix --- mmocr/models/common/backbones/clip_resnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmocr/models/common/backbones/clip_resnet.py b/mmocr/models/common/backbones/clip_resnet.py index aa908344f..4de20986b 100644 --- a/mmocr/models/common/backbones/clip_resnet.py +++ b/mmocr/models/common/backbones/clip_resnet.py @@ -39,7 +39,8 @@ def __init__(self, **kwargs): @MODELS.register_module() class CLIPResNet(ResNet): - """Implement the ResNet variant used in `oCLIP + """Implement the ResNet variant used in `oCLIP. + `_. It is also the official structure in