Skip to content

Commit

Permalink
add se
Browse files Browse the repository at this point in the history
  • Loading branch information
likyoo committed Nov 22, 2021
1 parent 9286d9f commit de96478
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 226 deletions.
1 change: 0 additions & 1 deletion change_detection_pytorch/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@
from .heads import (
SegmentationHead,
ClassificationHead,
SegmentationOCRHead,
)
17 changes: 1 addition & 16 deletions change_detection_pytorch/base/heads.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch.nn as nn
from .modules import Flatten, Activation, OCR
from .modules import Flatten, Activation


class SegmentationHead(nn.Sequential):
Expand All @@ -11,21 +11,6 @@ def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, up
super().__init__(conv2d, upsampling, activation)


class SegmentationOCRHead(nn.Module):
    """Segmentation head that wraps the ``OCR`` module.

    The wrapped module yields two outputs (unpacked here as a coarse and a
    final prediction); both are passed through the same optional bilinear
    upsampling and the same activation.

    Args:
        in_channels: forwarded to ``OCR`` as its first argument.
        out_channels: forwarded to ``OCR`` as its second argument.
        activation: forwarded to ``Activation``.
        upsampling: scale factor; values ``<= 1`` disable upsampling.
        align_corners: forwarded to ``nn.Upsample`` when upsampling is enabled.
    """

    def __init__(self, in_channels, out_channels, activation=None, upsampling=1, align_corners=True):
        super().__init__()
        self.ocr_head = OCR(in_channels, out_channels)
        if upsampling > 1:
            self.upsampling = nn.Upsample(scale_factor=upsampling, mode='bilinear', align_corners=align_corners)
        else:
            # Nothing to resize: keep the attribute so forward() stays branch-free.
            self.upsampling = nn.Identity()
        self.activation = Activation(activation)

    def forward(self, x):
        """Return ``[coarse_pre, pre]`` after upsampling and activation."""
        coarse_pre, pre = self.ocr_head(x)
        return [self.activation(self.upsampling(prediction)) for prediction in (coarse_pre, pre)]


class ClassificationHead(nn.Sequential):

def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
Expand Down
225 changes: 16 additions & 209 deletions change_detection_pytorch/base/modules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

try:
from inplace_abn import InPlaceABN
Expand Down Expand Up @@ -119,7 +117,6 @@ class ECAM(nn.Module):
Ensemble Channel Attention Module for UNetPlusPlus.
Fang S, Li K, Shao J, et al. SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images[J].
IEEE Geoscience and Remote Sensing Letters, 2021.
Not completely consistent, to be improved.
"""
def __init__(self, in_channels, out_channels, map_num=4):
Expand All @@ -142,218 +139,26 @@ def forward(self, x):
return out


class ModuleHelper:
    """Namespace of static factories for normalization layers."""

    @staticmethod
    def BNReLU(num_features, bn_type=None, **kwargs):
        """Build a ``BatchNorm2d`` followed by a ``ReLU``.

        Args:
            num_features: channel count handed to ``nn.BatchNorm2d``.
            bn_type: accepted but not used by this implementation.
            **kwargs: extra keyword arguments forwarded to ``nn.BatchNorm2d``.

        Returns:
            An ``nn.Sequential`` holding the two layers in that order.
        """
        norm_layer = nn.BatchNorm2d(num_features, **kwargs)
        relu_layer = nn.ReLU()
        return nn.Sequential(norm_layer, relu_layer)

    @staticmethod
    def BatchNorm2d(*args, **kwargs):
        """Return the ``nn.BatchNorm2d`` class itself (not an instance).

        All positional and keyword arguments are ignored.
        """
        return nn.BatchNorm2d


class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context.

    Args:
        cls_num: stored on the instance; not read by ``forward``.
        scale: multiplier applied to ``probs`` before the softmax.
    """
    def __init__(self, cls_num=0, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        """Pool ``feats`` into one vector per channel of ``probs``.

        Args:
            feats: feature map, ``batch x c x h x w``.
            probs: score map, ``batch x k x h x w`` (same batch and spatial
                size as ``feats``).

        Returns:
            Tensor of shape ``batch x c x k x 1``: for every channel of
            ``probs``, the softmax-weighted (over spatial positions) average
            of ``feats``.
        """
        # flatten(2) rather than view(): view() raises on inputs whose spatial
        # dims are not view-compatible (e.g. transposed or strided slices),
        # while flatten() falls back to a copy when needed.
        probs = probs.flatten(2)                      # batch x k x hw
        feats = feats.flatten(2).permute(0, 2, 1)     # batch x hw x c
        # Normalize over spatial positions so each row is a set of weights.
        probs = F.softmax(self.scale * probs, dim=2)  # batch x k x hw
        ocr_context = torch.matmul(probs, feats)      # batch x k x c
        return ocr_context.permute(0, 2, 1).unsqueeze(3)  # batch x c x k x 1


class _ObjectAttentionBlock(nn.Module):
    """
    The basic implementation for object context block
    Input:
        N X C X H X W
    Parameters:
        in_channels  : the dimension of the input feature map
        key_channels : the dimension after the key/query transform
        scale        : choose the scale to downsample the input feature maps (save memory cost)
        bn_type      : forwarded to ``ModuleHelper.BNReLU``
    Return:
        N X C X H X W
    """
    def __init__(self,
                 in_channels,
                 key_channels,
                 scale=1,
                 bn_type=None):
        super(_ObjectAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels

        def pointwise_bn_relu(c_in, c_out):
            # One 1x1 bias-free convolution followed by the BN+ReLU helper.
            return [
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=1, stride=1, padding=0, bias=False),
                ModuleHelper.BNReLU(c_out, bn_type=bn_type),
            ]

        # Sub-modules are created in the same order as before so parameter
        # names and initialization order are unchanged.
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_pixel = nn.Sequential(
            *pointwise_bn_relu(self.in_channels, self.key_channels),
            *pointwise_bn_relu(self.key_channels, self.key_channels),
        )
        self.f_object = nn.Sequential(
            *pointwise_bn_relu(self.in_channels, self.key_channels),
            *pointwise_bn_relu(self.key_channels, self.key_channels),
        )
        self.f_down = nn.Sequential(
            *pointwise_bn_relu(self.in_channels, self.key_channels),
        )
        self.f_up = nn.Sequential(
            *pointwise_bn_relu(self.key_channels, self.in_channels),
        )

    def forward(self, x, proxy):
        """Attend from every position of ``x`` to the entries of ``proxy``.

        The result has ``in_channels`` channels and the spatial size of the
        incoming ``x`` (restored by bilinear interpolation when ``scale > 1``).
        """
        n, full_h, full_w = x.size(0), x.size(2), x.size(3)
        pooled = self.scale > 1
        if pooled:
            x = self.pool(x)

        # Pixel queries: n x (h*w) x key_channels.
        query = self.f_pixel(x).view(n, self.key_channels, -1).permute(0, 2, 1)
        # Proxy keys: n x key_channels x m; proxy values: n x m x key_channels.
        key = self.f_object(proxy).view(n, self.key_channels, -1)
        value = self.f_down(proxy).view(n, self.key_channels, -1).permute(0, 2, 1)

        # Scaled dot-product similarity, normalized over the proxy entries.
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # add bg context ...
        context = torch.matmul(sim_map, value).permute(0, 2, 1).contiguous()
        context = context.view(n, self.key_channels, *x.size()[2:])
        context = self.f_up(context)
        if pooled:
            context = F.interpolate(input=context, size=(full_h, full_w), mode='bilinear', align_corners=True)

        return context


class ObjectAttentionBlock2D(_ObjectAttentionBlock):
    """Thin subclass of ``_ObjectAttentionBlock`` that forwards its arguments unchanged."""

    def __init__(self, in_channels, key_channels, scale=1, bn_type=None):
        super().__init__(in_channels, key_channels, scale, bn_type=bn_type)


class SpatialOCR_Module(nn.Module):
    """
    Implementation of the OCR module:
    We aggregate the global object representation to update the representation for each pixel.

    Args:
        in_channels: channels of the incoming pixel features.
        key_channels: forwarded to ``ObjectAttentionBlock2D``.
        out_channels: channels produced by the final 1x1 convolution.
        scale: forwarded to ``ObjectAttentionBlock2D``.
        dropout: probability handed to ``nn.Dropout2d``.
        bn_type: forwarded to the attention block and to ``ModuleHelper.BNReLU``.
    """
    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 scale=1,
                 dropout=0.1,
                 bn_type=None):
        super().__init__()
        self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale, bn_type)

        # The fusion conv sees the context and the original features
        # concatenated on the channel axis, hence twice the input width.
        fused_channels = 2 * in_channels
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, kernel_size=1, padding=0, bias=False),
            ModuleHelper.BNReLU(out_channels, bn_type=bn_type),
            nn.Dropout2d(dropout),
        )

    def forward(self, feats, proxy_feats):
        """Concatenate the attended context with ``feats`` and fuse them."""
        context = self.object_context_block(feats, proxy_feats)
        fused = torch.cat([context, feats], 1)
        return self.conv_bn_dropout(fused)


class OCR(nn.Module):
class SEModule(nn.Module):
"""
Segmentation Transformer: Object-Contextual Representations for Semantic Segmentation
https://arxiv.org/pdf/1909.11065.pdf
Hu J, Shen L, Sun G. Squeeze-and-excitation networks[C]
//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 7132-7141.
"""
def __init__(self, in_channels, num_classes, ocr_mid_channels=512, ocr_key_channels=256):

super().__init__()
pre_stage_channels = in_channels
last_inp_channels = np.int(np.sum(pre_stage_channels))

self.conv3x3_ocr = nn.Sequential(
nn.Conv2d(last_inp_channels, ocr_mid_channels,
kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(ocr_mid_channels),
nn.ReLU(inplace=True),
)
self.ocr_gather_head = SpatialGather_Module(num_classes)

self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
key_channels=ocr_key_channels,
out_channels=ocr_mid_channels,
scale=1,
dropout=0.05,
)
self.cls_head = nn.Conv2d(
ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)

self.aux_head = nn.Sequential(
nn.Conv2d(last_inp_channels, last_inp_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(last_inp_channels),
def __init__(self, in_channels, reduction=16):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(last_inp_channels, num_classes,
kernel_size=1, stride=1, padding=0, bias=True)
nn.Linear(in_channels // reduction, in_channels, bias=False),
nn.Sigmoid()
)

def forward(self, x):

out_aux_seg = []

# ocr
out_aux = self.aux_head(x)
# compute contrast feature
feats = self.conv3x3_ocr(x)

context = self.ocr_gather_head(feats, out_aux)
feats = self.ocr_distri_head(feats, context)

out = self.cls_head(feats)

out_aux_seg.append(out_aux)
out_aux_seg.append(out)

return out_aux_seg
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)


class ArgMax(nn.Module):
Expand Down Expand Up @@ -412,6 +217,8 @@ def __init__(self, name, **params):
self.attention = CBAMSpatial(**params)
elif name == 'cbam':
self.attention = CBAM(**params)
elif name == 'se':
self.attention = SEModule(**params)
else:
raise ValueError("Attention {} is not implemented".format(name))

Expand Down

0 comments on commit de96478

Please sign in to comment.