Commit

add FGD
yzd committed Mar 11, 2022
1 parent 67f35cb commit 105489c
Showing 5 changed files with 122 additions and 95 deletions.
24 changes: 24 additions & 0 deletions configs/distill/fgd/README.md
@@ -0,0 +1,24 @@
# FGD
> [Focal and Global Knowledge Distillation for Detectors](https://arxiv.org/abs/2111.11837)
<!-- [ALGORITHM] -->
## Abstract

Knowledge distillation has been applied to image classification successfully. However, object detection is much more sophisticated and most knowledge distillation methods have failed on it. In this paper, we point out that in object detection, the features of the teacher and student vary greatly in different areas, especially in the foreground and background. If we distill them equally, the uneven differences between feature maps will negatively affect the distillation. Thus, we propose Focal and Global Distillation (FGD). Focal distillation separates the foreground and background, forcing the student to focus on the teacher's critical pixels and channels. Global distillation rebuilds the relation between different pixels and transfers it from teachers to students, compensating for missing global information in focal distillation. As our method only needs to calculate the loss on the feature map, FGD can be applied to various detectors. We experiment on various detectors with different backbones and the results show that the student detector achieves excellent mAP improvement. For example, ResNet-50 based RetinaNet, Faster RCNN, RepPoints and Mask RCNN with our distillation method achieve 40.7%, 42.0%, 42.0% and 42.1% mAP on COCO2017, which are 3.3, 3.6, 3.4 and 2.9 higher than the baseline, respectively.


![pipeline](/docs/en/imgs/model_zoo/fgd/pipeline.png)
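
As a rough illustration of the focal part described in the abstract, the sketch below builds per-image foreground/background masks from ground-truth boxes: each box is rescaled to the feature-map grid, its pixels are weighted by the inverse of the box area, and the remaining pixels form a normalized background mask. The function name and signature are invented for this example; the reference implementation is `FGDLoss` in `mmrazor/models/losses/fgd.py`.

```python
import torch


def build_fg_bg_masks(gt_bboxes, feat_h, feat_w, img_h, img_w):
    """gt_bboxes: (num_gt, 4) tensor of (tl_x, tl_y, br_x, br_y) in image pixels."""
    mask_fg = torch.zeros(feat_h, feat_w)
    boxes = gt_bboxes.clone().float()
    boxes[:, [0, 2]] = boxes[:, [0, 2]] / img_w * feat_w  # rescale x to feature grid
    boxes[:, [1, 3]] = boxes[:, [1, 3]] / img_h * feat_h  # rescale y to feature grid
    for x0, y0, x1, y1 in boxes:
        x0, x1 = int(torch.floor(x0)), int(torch.ceil(x1))
        y0, y1 = int(torch.floor(y0)), int(torch.ceil(y1))
        # larger objects get a smaller per-pixel weight (inverse box area)
        area = 1.0 / max((y1 + 1 - y0) * (x1 + 1 - x0), 1)
        patch = mask_fg[y0:y1 + 1, x0:x1 + 1]
        mask_fg[y0:y1 + 1, x0:x1 + 1] = torch.maximum(patch, torch.tensor(area))
    # background mask: uniform over non-foreground pixels, normalized to sum to 1
    mask_bg = torch.where(mask_fg > 0,
                          torch.zeros_like(mask_fg),
                          torch.ones_like(mask_fg))
    if mask_bg.sum() > 0:
        mask_bg = mask_bg / mask_bg.sum()
    return mask_fg, mask_bg
```

For instance, `build_fg_bg_masks(torch.tensor([[32., 48., 256., 320.]]), 100, 168, 800, 1344)` marks one box region on a 100x168 feature map while the background mask covers everything else.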




## Citation

```latex
@article{yang2021focal,
title={Focal and Global Knowledge Distillation for Detectors},
author={Yang, Zhendong and Li, Zhe and Jiang, Xiaohu and Gong, Yuan and Yuan, Zehuan and Zhao, Danpei and Yuan, Chun},
journal={arXiv preprint arXiv:2111.11837},
year={2021}
}
```
26 changes: 14 additions & 12 deletions configs/distill/fgd/fgd_gfl_r101_distill_gfl_r50_fpn_1x_coco.py
@@ -5,7 +5,9 @@
]

# model settings
t_weight = 'https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth'
t_weight = 'https://download.openmmlab.com/mmdetection/v2.0/' + \
'gfl/gfl_r101_fpn_mstrain_2x_coco/' + \
'gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth'
student = dict(
type='mmdet.GFL',
backbone=dict(
@@ -46,7 +48,8 @@
loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25),
reg_max=16,
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
init_cfg=dict(type='Pretrained', prefix='bbox_head', checkpoint=t_weight)),
init_cfg=dict(
type='Pretrained', prefix='bbox_head', checkpoint=t_weight)),
# training and testing settings
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
@@ -62,9 +65,7 @@

teacher = dict(
type='mmdet.GFL',
init_cfg=dict(
type='Pretrained',
checkpoint=t_weight),
init_cfg=dict(type='Pretrained', checkpoint=t_weight),
backbone=dict(
type='ResNet',
depth=101,
@@ -116,11 +117,11 @@
max_per_img=100))

# algorithm setting
temp=0.5
alpha_fgd=0.001
beta_fgd=0.0005
gamma_fgd=0.0005
lambda_fgd=0.000005
temp = 0.5
alpha_fgd = 0.001
beta_fgd = 0.0005
gamma_fgd = 0.0005
lambda_fgd = 0.000005
algorithm = dict(
type='GeneralDistill',
architecture=dict(
@@ -200,8 +201,9 @@
]),
)

find_unused_parameters=True
find_unused_parameters = True

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
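
The scalars above (`temp` and the four `*_fgd` weights) are consumed inside the distiller's loss components, which are collapsed in this diff view. As a hedged sketch of how they would typically be forwarded into the loss constructor: the dict below only mirrors the `FGDLoss` signature shown in `mmrazor/models/losses/fgd.py`, how it is nested inside the algorithm's components is not visible here, and 256 is the FPN output width shared by the GFL teacher and student.

```python
# Sketch only: the actual distiller/components block is collapsed in this diff.
fgd_loss_cfg = dict(
    type='FGDLoss',
    student_channels=256,  # student FPN channels
    teacher_channels=256,  # teacher FPN channels
    temp=temp,
    alpha_fgd=alpha_fgd,
    beta_fgd=beta_fgd,
    gamma_fgd=gamma_fgd,
    lambda_fgd=lambda_fgd)
```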
Binary file added docs/en/imgs/model_zoo/fgd/pipeline.png
2 changes: 1 addition & 1 deletion mmrazor/models/losses/__init__.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cwd import ChannelWiseDivergence
from .fgd import FGDLoss
from .kl_divergence import KLDivergence
from .weighted_soft_label_distillation import WSLD
from .fgd import FGDLoss

__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD', 'FGDLoss']
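
With `FGDLoss` exported here, it becomes buildable from a plain config dict through the `LOSSES` registry it is registered in (see the `@LOSSES.register_module()` decorator in `fgd.py`). A construction-only check is sketched below; the registry import path is an assumption based on the relative import used inside `mmrazor.models`, and the channel numbers are placeholders.

```python
import torch

from mmrazor.models.builder import LOSSES  # assumed registry location

# Build the loss from a config dict, as the distiller would.
fgd = LOSSES.build(
    dict(type='FGDLoss', student_channels=256, teacher_channels=256))

# forward() additionally reads gt bboxes / img metas injected by the distiller
# at train time, so only construction and shapes are exercised here.
feat_s = torch.rand(2, 256, 32, 32)
feat_t = torch.rand(2, 256, 32, 32)
print(type(fgd).__name__, feat_s.shape == feat_t.shape)
```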
165 changes: 83 additions & 82 deletions mmrazor/models/losses/fgd.py
@@ -9,158 +9,162 @@

@LOSSES.register_module()
class FGDLoss(nn.Module):
"""PyTorch version of 'Focal and Global Knowledge Distillation for
Detectors'.
"""PyTorch version of 'Focal and Global Knowledge Distillation for Detectors'
<https://arxiv.org/abs/2111.11837>
Args:
student_channels(int): Number of channels in the student's feature map.
teacher_channels(int): Number of channels in the teacher's feature map.
teacher_channels(int): Number of channels in the teacher's feature map.
temp (float, optional): Temperature coefficient. Defaults to 0.5.
name (str): the loss name of the layer
alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001
lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
alpha_fgd (float, optional): Weight of fg_loss.
beta_fgd (float, optional): Weight of bg_loss.
gamma_fgd (float, optional): Weight of mask_loss.
lambda_fgd (float, optional): Weight of relation_loss.
"""

def __init__(self,
student_channels,
teacher_channels,
temp=0.5,
alpha_fgd=0.001,
beta_fgd=0.0005,
gamma_fgd=0.001,
lambda_fgd=0.000005,
):
def __init__(
self,
student_channels,
teacher_channels,
temp=0.5,
alpha_fgd=0.001,
beta_fgd=0.0005,
gamma_fgd=0.001,
lambda_fgd=0.000005,
):
super(FGDLoss, self).__init__()
self.temp = temp
self.alpha_fgd = alpha_fgd
self.beta_fgd = beta_fgd
self.gamma_fgd = gamma_fgd
self.lambda_fgd = lambda_fgd

self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)
self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
self.channel_add_conv_s = nn.Sequential(
nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
nn.LayerNorm([teacher_channels//2, 1, 1]),
nn.ReLU(inplace=True),
nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1),
nn.LayerNorm([teacher_channels // 2, 1, 1]), nn.ReLU(inplace=True),
nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1))
self.channel_add_conv_t = nn.Sequential(
nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
nn.LayerNorm([teacher_channels//2, 1, 1]),
nn.ReLU(inplace=True),
nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1),
nn.LayerNorm([teacher_channels // 2, 1, 1]), nn.ReLU(inplace=True),
nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1))

self.reset_parameters()


def forward(self, preds_S, preds_T):
"""Forward function.
Args:
preds_S(Tensor): Bs*C*H*W, student's feature map
preds_T(Tensor): Bs*C*H*W, teacher's feature map
gt_bboxes(tuple): Bs*[nt*4], pixel decimal: (tl_x, tl_y, br_x, br_y)
gt_bboxes(tuple): Bs*[nt*4], (tl_x, tl_y, br_x, br_y)
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
"""
assert preds_S.shape[-2:] == preds_T.shape[-2:]
N, C, H, W = preds_S.shape
gt_bboxes = self.current_data['gt_boxxes']
img_metas = self.current_data['img_metas']
metas = self.current_data['img_metas']

S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)
S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)

Mask_fg = torch.zeros_like(S_attention_t)
Mask_bg = torch.ones_like(S_attention_t)
wmin,wmax,hmin,hmax = [],[],[],[]
M_fg = torch.zeros_like(S_attention_t)
M_bg = torch.ones_like(S_attention_t)
wmin, wmax, hmin, hmax = [], [], [], []
for i in range(N):
new_boxxes = torch.ones_like(gt_bboxes[i])
new_boxxes[:, 0] = gt_bboxes[i][:, 0]/img_metas[i]['img_shape'][1]*W
new_boxxes[:, 2] = gt_bboxes[i][:, 2]/img_metas[i]['img_shape'][1]*W
new_boxxes[:, 1] = gt_bboxes[i][:, 1]/img_metas[i]['img_shape'][0]*H
new_boxxes[:, 3] = gt_bboxes[i][:, 3]/img_metas[i]['img_shape'][0]*H
new_boxx = torch.ones_like(gt_bboxes[i])
new_boxx[:, 0] = gt_bboxes[i][:, 0] / metas[i]['img_shape'][1] * W
new_boxx[:, 2] = gt_bboxes[i][:, 2] / metas[i]['img_shape'][1] * W
new_boxx[:, 1] = gt_bboxes[i][:, 1] / metas[i]['img_shape'][0] * H
new_boxx[:, 3] = gt_bboxes[i][:, 3] / metas[i]['img_shape'][0] * H

wmin.append(torch.floor(new_boxxes[:, 0]).int())
wmax.append(torch.ceil(new_boxxes[:, 2]).int())
hmin.append(torch.floor(new_boxxes[:, 1]).int())
hmax.append(torch.ceil(new_boxxes[:, 3]).int())
wmin.append(torch.floor(new_boxx[:, 0]).int())
wmax.append(torch.ceil(new_boxx[:, 2]).int())
hmin.append(torch.floor(new_boxx[:, 1]).int())
hmax.append(torch.ceil(new_boxx[:, 3]).int())

area = 1.0/(hmax[i].view(1,-1)+1-hmin[i].view(1,-1))/(wmax[i].view(1,-1)+1-wmin[i].view(1,-1))
height = hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1)
width = wmax[i].view(1, -1) + 1 - wmin[i].view(1, -1)
area = 1.0 / height / width

for j in range(len(gt_bboxes[i])):
Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \
torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1], area[0][j])

Mask_bg[i] = torch.where(Mask_fg[i]>0, 0, 1)
if torch.sum(Mask_bg[i]):
Mask_bg[i] /= torch.sum(Mask_bg[i])

fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg,
C_attention_s, C_attention_t, S_attention_s, S_attention_t)
mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
M_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \
torch.maximum(M_fg[i][hmin[i][j]:hmax[i][j]+1,
wmin[i][j]:wmax[i][j]+1], area[0][j])

M_bg[i] = torch.where(M_fg[i] > 0, 0, 1)
if torch.sum(M_bg[i]):
M_bg[i] /= torch.sum(M_bg[i])

fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, M_fg, M_bg,
C_attention_s, C_attention_t,
S_attention_s, S_attention_t)
mask_loss = self.get_mask_loss(C_attention_s, C_attention_t,
S_attention_s, S_attention_t)
rela_loss = self.get_rela_loss(preds_S, preds_T)


loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss

return loss
+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss

return loss

def get_attention(self, preds, temp):
""" preds: Bs*C*H*W """
N, C, H, W= preds.shape
N, C, H, W = preds.shape

value = torch.abs(preds)
# Bs*W*H
fea_map = value.mean(axis=1, keepdim=True)
S_attention = (H * W * F.softmax((fea_map/temp).view(N,-1), dim=1)).view(N, H, W)
S_attention = (H * W * F.softmax(
(fea_map / temp).view(N, -1), dim=1)).view(N, H, W)

# Bs*C
channel_map = value.mean(axis=2,keepdim=False).mean(axis=2,keepdim=False)
C_attention = C * F.softmax(channel_map/temp, dim=1)
channel_map = value.mean(
axis=2, keepdim=False).mean(
axis=2, keepdim=False)
C_attention = C * F.softmax(channel_map / temp, dim=1)

return S_attention, C_attention


def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
def get_fea_loss(self, preds_S, preds_T, M_fg, M_bg, C_s, C_t, S_s, S_t):
loss_mse = nn.MSELoss(reduction='sum')
Mask_fg = Mask_fg.unsqueeze(dim=1)
Mask_bg = Mask_bg.unsqueeze(dim=1)

M_fg = M_fg.unsqueeze(dim=1)
M_bg = M_bg.unsqueeze(dim=1)

C_t = C_t.unsqueeze(dim=-1)
C_t = C_t.unsqueeze(dim=-1)

S_t = S_t.unsqueeze(dim=1)

fea_t= torch.mul(preds_T, torch.sqrt(S_t))
fea_t = torch.mul(preds_T, torch.sqrt(S_t))
fea_t = torch.mul(fea_t, torch.sqrt(C_t))
fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))
fg_fea_t = torch.mul(fea_t, torch.sqrt(M_fg))
bg_fea_t = torch.mul(fea_t, torch.sqrt(M_bg))

fea_s = torch.mul(preds_S, torch.sqrt(S_t))
fea_s = torch.mul(fea_s, torch.sqrt(C_t))
fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))
fg_fea_s = torch.mul(fea_s, torch.sqrt(M_fg))
bg_fea_s = torch.mul(fea_s, torch.sqrt(M_bg))

fg_loss = loss_mse(fg_fea_s, fg_fea_t)/len(Mask_fg)
bg_loss = loss_mse(bg_fea_s, bg_fea_t)/len(Mask_bg)
fg_loss = loss_mse(fg_fea_s, fg_fea_t) / len(M_fg)
bg_loss = loss_mse(bg_fea_s, bg_fea_t) / len(M_bg)

return fg_loss, bg_loss


def get_mask_loss(self, C_s, C_t, S_s, S_t):

mask_loss = torch.sum(torch.abs((C_s-C_t)))/len(C_s) + torch.sum(torch.abs((S_s-S_t)))/len(S_s)
mask_loss = torch.sum(torch.abs(
(C_s - C_t))) / len(C_s) + torch.sum(torch.abs(
(S_s - S_t))) / len(S_s)

return mask_loss



def spatial_pool(self, x, in_type):
batch, channel, width, height = x.size()
input_x = x
@@ -186,7 +190,6 @@ def spatial_pool(self, x, in_type):

return context


def get_rela_loss(self, preds_S, preds_T):
loss_mse = nn.MSELoss(reduction='sum')

@@ -202,23 +205,21 @@ def get_rela_loss(self, preds_S, preds_T):
channel_add_t = self.channel_add_conv_t(context_t)
out_t = out_t + channel_add_t

rela_loss = loss_mse(out_s, out_t)/len(out_s)

return rela_loss
rela_loss = loss_mse(out_s, out_t) / len(out_s)

return rela_loss

def last_zero_init(self, m):
if isinstance(m, nn.Sequential):
constant_init(m[-1], val=0)
else:
constant_init(m, val=0)


def reset_parameters(self):
kaiming_init(self.conv_mask_s, mode='fan_in')
kaiming_init(self.conv_mask_t, mode='fan_in')
self.conv_mask_s.inited = True
self.conv_mask_t.inited = True

self.last_zero_init(self.channel_add_conv_s)
self.last_zero_init(self.channel_add_conv_t)
self.last_zero_init(self.channel_add_conv_t)
