[Feature] Support resize relative position embedding in SwinTransformer. #749

Merged · 5 commits · Apr 13, 2022
33 changes: 32 additions & 1 deletion mmcls/models/backbones/swin_transformer.py
@@ -13,7 +13,8 @@
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from ..utils import ShiftWindowMSA, resize_pos_embed, to_2tuple
from ..utils import (ShiftWindowMSA, resize_pos_embed,
resize_relative_position_bias_table, to_2tuple)
from .base_backbone import BaseBackbone


@@ -352,6 +353,9 @@ def __init__(self,
self._register_load_state_dict_pre_hook(
self._prepare_abs_pos_embed)

self._register_load_state_dict_pre_hook(
self._prepare_relative_position_bias_table)

self.drop_after_pos = nn.Dropout(p=drop_rate)
self.norm_eval = norm_eval

@@ -499,3 +503,30 @@ def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)

def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
**kwargs):
all_keys = list(state_dict.keys())
state_dict_model = self.state_dict()
for key in all_keys:
if 'relative_position_bias_table' in key:
relative_position_bias_table_pretrained = state_dict[key]
relative_position_bias_table_current = state_dict_model[key]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
if L1 != L2:
src_size = int(L1**0.5)
dst_size = int(L2**0.5)
new_rel_pos_bias = resize_relative_position_bias_table(
src_size, dst_size,
relative_position_bias_table_pretrained, nH1)
from mmcls.utils import get_root_logger
logger = get_root_logger()
                    logger.info(
                        'Resize the relative_position_bias_table from '
                        f'{state_dict[key].shape} to {new_rel_pos_bias.shape}')
state_dict[key] = new_rel_pos_bias

                    # The index buffer needs to be regenerated, so drop the
                    # stale entry and let the model keep its freshly computed
                    # `relative_position_index`.
                    index_buffer = key.replace('bias_table', 'index')
                    if index_buffer in state_dict:
                        del state_dict[index_buffer]
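
With the two `_register_load_state_dict_pre_hook` calls above, the resizing happens transparently inside `load_state_dict`. A minimal sketch of the intended workflow (the checkpoint filename and the concrete `window_size` values are illustrative, not part of this PR):

```python
import torch
from mmcls.models import build_backbone

# Build a Swin backbone whose window size differs from the pretrained one.
model = build_backbone(
    dict(type='SwinTransformer', arch='tiny', img_size=384, window_size=12))

# Hypothetical checkpoint trained with window_size=7. On load, the pre-hook
# added in this PR interpolates every `relative_position_bias_table` entry to
# the new shape and drops the stale `relative_position_index` buffers so the
# model keeps its freshly computed ones.
ckpt = torch.load('swin_tiny_window7.pth', map_location='cpu')
model.load_state_dict(ckpt.get('state_dict', ckpt), strict=False)
```
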
6 changes: 4 additions & 2 deletions mmcls/models/utils/__init__.py
@@ -2,7 +2,8 @@
from .attention import MultiheadAttention, ShiftWindowMSA
from .augment.augments import Augments
from .channel_shuffle import channel_shuffle
from .embed import HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed
from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
resize_relative_position_bias_table)
from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
from .make_divisible import make_divisible
@@ -13,5 +14,6 @@
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed'
'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed',
'resize_relative_position_bias_table'
]
57 changes: 57 additions & 0 deletions mmcls/models/utils/embed.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -54,6 +55,62 @@ def resize_pos_embed(pos_embed,
return torch.cat((extra_tokens, dst_weight), dim=1)


def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head):
"""Resize relative position bias table.

Args:
        src_shape (int): The side length of the square relative position
            bias table of the pretrained model, i.e. ``2 * window_size - 1``
            for the source window size.
        dst_shape (int): The side length of the square relative position
            bias table expected by the current model.
        table (Tensor): The relative position bias table of the pretrained
            model, in shape ``(src_shape**2, num_head)``.
        num_head (int): Number of attention heads.

Returns:
torch.Tensor: The resized relative position bias table.
"""
from scipy import interpolate

    def geometric_progression(a, r, n):
        # Sum of the first n terms of a geometric series with first term a
        # and common ratio r.
        return a * (1.0 - r**n) / (1.0 - r)

    # Bisect for the common ratio q at which src_shape // 2 geometrically
    # growing steps just cover dst_shape // 2 on each side of the origin.
    left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_shape // 2)
if gp > dst_shape // 2:
right = q
else:
left = q

    # Source coordinates grow geometrically: dense near the center of the
    # window, sparse towards the borders.
    dis = []
    cur = 1
    for i in range(src_shape // 2):
        dis.append(cur)
        cur += q**(i + 1)

r_ids = [-_ for _ in reversed(dis)]

x = r_ids + [0] + dis
y = r_ids + [0] + dis

    # Target coordinates form a uniform grid of size dst_shape x dst_shape.
    t = dst_shape // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)

    all_rel_pos_bias = []

    # Interpolate each attention head's bias map separately with a bicubic
    # spline fitted on the non-uniform source grid.
    for i in range(num_head):
        z = table[:, i].view(src_shape, src_shape).float().numpy()
        f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
        all_rel_pos_bias.append(
            torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(
                table.device))
    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
    return new_rel_pos_bias


class PatchEmbed(BaseModule):
"""Image to Patch Embedding.

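
Taken on its own, `resize_relative_position_bias_table` can also be exercised directly. A small sketch (table sides follow `2 * window_size - 1`, so a 7x7 window gives a 13-sided table and a 12x12 window a 23-sided one; the random table stands in for real pretrained weights):

```python
import torch
from mmcls.models.utils import resize_relative_position_bias_table

num_heads = 4
src_side = 2 * 7 - 1    # table side for window_size=7  -> 13
dst_side = 2 * 12 - 1   # table side for window_size=12 -> 23

table = torch.randn(src_side * src_side, num_heads)
resized = resize_relative_position_bias_table(src_side, dst_side, table,
                                              num_heads)
print(resized.shape)  # torch.Size([529, 4]) == (23 * 23, 4)
```

The geometric spacing of the source coordinates means the interpolation samples positions near the window center, where relative biases vary quickly, more densely than positions at the borders.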