Skip to content

Commit

Permalink
[Fix] Replace interpolate with resize (open-mmlab#731)
Browse files Browse the repository at this point in the history
* Replace interpolate with resize

* Replace nn.Upsample with ops.Upsample

* Fix test
  • Loading branch information
mmeendez8 committed Jul 28, 2021
1 parent b5ae7a7 commit 50461ef
Show file tree
Hide file tree
Showing 11 changed files with 27 additions and 24 deletions.
3 changes: 2 additions & 1 deletion mmseg/models/backbones/swin.py
Expand Up @@ -13,6 +13,7 @@
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.utils import _pair as to_2tuple

from mmseg.ops import resize
from ...utils import get_root_logger
from ..builder import ATTENTION, BACKBONES
from ..utils import PatchEmbed, swin_convert
Expand Down Expand Up @@ -745,7 +746,7 @@ def init_weights(self):
if L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = F.interpolate(
table_pretrained_resized = resize(
table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1),
size=(S2, S2),
Expand Down
3 changes: 2 additions & 1 deletion mmseg/models/backbones/unet.py
Expand Up @@ -7,6 +7,7 @@
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.ops import Upsample
from ..builder import BACKBONES
from ..utils import UpConvBlock

Expand Down Expand Up @@ -203,7 +204,7 @@ def __init__(self,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
upsample = nn.Upsample(**upsample_cfg)
upsample = Upsample(**upsample_cfg)
if conv_first:
self.interp_upsample = nn.Sequential(conv, upsample)
else:
Expand Down
4 changes: 2 additions & 2 deletions mmseg/models/backbones/vit.py
Expand Up @@ -3,14 +3,14 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
normal_init, trunc_normal_init)
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple

from mmseg.ops import resize
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed, vit_convert
Expand Down Expand Up @@ -373,7 +373,7 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = F.interpolate(
pos_embed_weight = resize(
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
cls_token_weight = cls_token_weight.unsqueeze(1)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
Expand Down
4 changes: 2 additions & 2 deletions mmseg/models/decode_heads/fpn_head.py
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import resize
from mmseg.ops import Upsample, resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead

Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self, feature_strides, **kwargs):
act_cfg=self.act_cfg))
if feature_strides[i] != feature_strides[0]:
scale_head.append(
nn.Upsample(
Upsample(
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners))
Expand Down
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/setr_mla_head.py
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import Upsample
from ..builder import HEADS
from .decode_head import BaseDecodeHead

Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(self, mla_channels=128, up_scale=4, **kwargs):
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Upsample(
Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))
Expand Down
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/setr_up_head.py
@@ -1,6 +1,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer

from mmseg.ops import Upsample
from ..builder import HEADS
from .decode_head import BaseDecodeHead

Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(self,
padding=int(kernel_size - 1) // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Upsample(
Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))
Expand Down
6 changes: 3 additions & 3 deletions mmseg/models/necks/fpn.py
Expand Up @@ -3,6 +3,7 @@
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16

from mmseg.ops import resize
from ..builder import NECKS


Expand Down Expand Up @@ -173,11 +174,10 @@ def forward(self, inputs):
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
# it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg:
laterals[i - 1] += F.interpolate(laterals[i],
**self.upsample_cfg)
laterals[i - 1] += resize(laterals[i], **self.upsample_cfg)
else:
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += F.interpolate(
laterals[i - 1] += resize(
laterals[i], size=prev_shape, **self.upsample_cfg)

# build outputs
Expand Down
4 changes: 2 additions & 2 deletions mmseg/models/necks/multilevel_neck.py
@@ -1,7 +1,7 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init

from mmseg.ops import resize
from ..builder import NECKS


Expand Down Expand Up @@ -70,7 +70,7 @@ def forward(self, inputs):
inputs = [inputs[0] for _ in range(self.num_outs)]
outs = []
for i in range(self.num_outs):
x_resize = F.interpolate(
x_resize = resize(
inputs[i], scale_factor=self.scales[i], mode='bilinear')
outs.append(self.convs[i](x_resize))
return tuple(outs)
10 changes: 5 additions & 5 deletions tests/test_models/test_backbones/test_unet.py
@@ -1,10 +1,10 @@
import pytest
import torch
from mmcv.cnn import ConvModule
from torch import nn

from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
InterpConv, UNet, UpConvBlock)
from mmseg.ops import Upsample
from .utils import check_norm_state


Expand Down Expand Up @@ -145,7 +145,7 @@ def test_interp_conv():
block = InterpConv(64, 32, conv_first=False)
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])

Expand All @@ -154,7 +154,7 @@ def test_interp_conv():
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], ConvModule)
assert isinstance(block.interp_upsample[1], nn.Upsample)
assert isinstance(block.interp_upsample[1], Upsample)
assert x_out.shape == torch.Size([1, 32, 256, 256])

# test InterpConv with bilinear upsample for upsample 2X.
Expand All @@ -166,7 +166,7 @@ def test_interp_conv():
scale_factor=2, mode='bilinear', align_corners=False))
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'bilinear'
Expand All @@ -179,7 +179,7 @@ def test_interp_conv():
upsample_cfg=dict(scale_factor=2, mode='nearest'))
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'nearest'
Expand Down
5 changes: 3 additions & 2 deletions tools/deploy_test.py
Expand Up @@ -14,6 +14,7 @@
from mmseg.apis import single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize


class ONNXRuntimeSegmentor(BaseSegmentor):
Expand Down Expand Up @@ -79,7 +80,7 @@ def simple_test(self, img: torch.Tensor, img_meta: Iterable,
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = torch.nn.functional.interpolate(
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]
Expand Down Expand Up @@ -127,7 +128,7 @@ def simple_test(self, img: torch.Tensor, img_meta: Iterable,
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = torch.nn.functional.interpolate(
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]
Expand Down
6 changes: 2 additions & 4 deletions tools/pytorch2onnx.py
Expand Up @@ -16,6 +16,7 @@
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from mmseg.ops import resize

torch.manual_seed(3)

Expand Down Expand Up @@ -210,10 +211,7 @@ def pytorch2onnx(model,

if dynamic_export and test_mode == 'whole':
# scale image for dynamic shape test
img_list = [
nn.functional.interpolate(_, scale_factor=1.5)
for _ in img_list
]
img_list = [resize(_, scale_factor=1.5) for _ in img_list]
# concate flip image for batch test
flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [
Expand Down

0 comments on commit 50461ef

Please sign in to comment.