In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo

# Release asset with the pretrained backbone weights. BiSeNetV2.load_pretrain
# downloads it and loads each entry whose key matches a top-level child name.
backbone_url = (
    'https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/'
    'backbone_v2.pth'
)

## 0. Basic Blocks

In [2]:
class ConvBNReLU(nn.Module):
    '''Conv2d -> BatchNorm2d -> in-place ReLU.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels (also the BatchNorm width).
        ks: square kernel size.
        stride, padding, dilation, groups: forwarded to ``nn.Conv2d``.
        bias: whether the convolution has its own bias (off by default; the
            following BatchNorm already provides a learned shift).
    '''

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
                 dilation=1, groups=1, bias=False):
        super().__init__()
        conv_kwargs = dict(kernel_size=ks, stride=stride, padding=padding,
                           dilation=dilation, groups=groups, bias=bias)
        # attribute names/order define the state_dict keys used by checkpoints
        self.conv = nn.Conv2d(in_chan, out_chan, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
    
class UpSample(nn.Module):
    '''Learnable up-sampling by ``factor``.

    A 1x1 conv expands the channels by ``factor**2``; ``nn.PixelShuffle`` then
    trades those extra channels for a ``factor``-times larger spatial grid:
    (N, n_chan, h, w) -> (N, n_chan, factor*h, factor*w).
    NOTE(review): not referenced by the model cells shown in this notebook.
    '''

    def __init__(self, n_chan, factor=2):
        super(UpSample, self).__init__()
        self.proj = nn.Conv2d(n_chan, n_chan * factor ** 2, 1, 1, 0)
        self.up = nn.PixelShuffle(factor)
        self.init_weight()

    def forward(self, x):
        # (N, n_chan*factor^2, h, w) -> (N, n_chan, factor*h, factor*w)
        expanded = self.proj(x)
        return self.up(expanded)

    def init_weight(self):
        '''Xavier-normal init of the projection weight (bias keeps its default init).'''
        nn.init.xavier_normal_(self.proj.weight, gain=1.)

## 1. Detail Branch

In [3]:
class DetailBranch(nn.Module):
    '''Detail branch: wide, shallow stack of 3x3 ConvBNReLU layers.

    Three stages (S1: 2 layers, S2: 3 layers, S3: 3 layers); the first conv of
    each stage has stride 2, so an (N, 3, h, w) input becomes
    (N, 128, h/8, w/8).
    '''

    def __init__(self):
        super(DetailBranch, self).__init__()

        def make_stage(in_chan, out_chan, n_layers):
            # first layer down-samples by 2, the remaining ones keep resolution
            layers = [ConvBNReLU(in_chan, out_chan, 3, stride=2)]
            layers.extend(ConvBNReLU(out_chan, out_chan, 3, stride=1)
                          for _ in range(n_layers - 1))
            return nn.Sequential(*layers)

        self.S1 = make_stage(3, 64, 2)     # (N, 64, h/2, w/2)
        self.S2 = make_stage(64, 64, 3)    # (N, 64, h/4, w/4)
        self.S3 = make_stage(64, 128, 3)   # (N, 128, h/8, w/8)

    def forward(self, x):
        feat = x
        for stage in (self.S1, self.S2, self.S3):
            feat = stage(feat)
        return feat


## 2. Semantic Branch

### 2.1 Stem Block

In [4]:
class StemBlock(nn.Module):
    '''Stem of the semantic branch (stages S1-S2): input -> 1/4 resolution.

    A shared stride-2 3x3 ConvBNReLU (3 -> 16 channels) feeds two parallel
    down-sampling paths whose outputs are concatenated and fused:
      - left:  1x1 ConvBNReLU (16 -> 8), then stride-2 3x3 ConvBNReLU (8 -> 16)
      - right: 3x3 stride-2 max-pool (channels unchanged)
      - fuse:  3x3 ConvBNReLU on the 32-channel concatenation (32 -> 16)
    '''
    def __init__(self):
        super(StemBlock, self).__init__()
        self.conv = ConvBNReLU(3, 16, 3, stride=2)

        self.left = nn.Sequential(
            ConvBNReLU(16, 8, 1, stride=1, padding=0),
            ConvBNReLU(8, 16, 3, stride=2),
        )
        self.right = nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                  ceil_mode=False)

        self.fuse = ConvBNReLU(32, 16, 3, stride=1)

    def forward(self, x):
        shared = self.conv(x)  # (N, 16, h/2, w/2)
        # conv path first, then pool path; both (N, 16, h/4, w/4)
        merged = torch.cat((self.left(shared), self.right(shared)), dim=1)  # (N, 32, h/4, w/4)
        return self.fuse(merged)  # (N, 16, h/4, w/4)

### 2.2 Context Embedding Block

In [5]:
class CEBlock(nn.Module):
    '''Context Embedding block, the last stage of the semantic branch.

    Global average pooling -> BatchNorm -> 1x1 ConvBNReLU produces a
    (N, 128, 1, 1) context tensor that is broadcast-added to the
    128-channel input; a 3x3 ConvBNReLU then refines the sum. Spatial size is
    preserved.
    NOTE(review): in train mode the BN sees only N values per channel (1x1
    map), so a batch of size 1 will likely be rejected by BatchNorm — confirm.
    '''
    def __init__(self):
        super(CEBlock, self).__init__()
        self.bn = nn.BatchNorm2d(128)
        self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
        # TODO: the paper uses a plain 3x3 conv here (no BN/ReLU); kept as a
        # ConvBNReLU, presumably to stay compatible with the pretrained weights.
        self.conv_last = ConvBNReLU(128, 128, 3, stride=1)

    def forward(self, x):
        pooled = x.mean(dim=(2, 3), keepdim=True)  # (N, 128, 1, 1)
        context = self.conv_gap(self.bn(pooled))
        # broadcast the context vector onto every spatial position
        return self.conv_last(context + x)

### 2.3 Gather and Expansion Layer

In [6]:
class GELayerS1(nn.Module):
    '''Gather-and-Expansion layer, stride-1 variant (resolution preserved).

    Main path: 3x3 ConvBNReLU (in_chan -> in_chan), grouped 3x3 conv that
    expands each channel ``exp_ratio`` times (BN + ReLU), then a 1x1
    projection to ``out_chan`` (BN, no ReLU). The input is added back as an
    identity shortcut and the sum goes through a ReLU, so ``out_chan`` must
    equal ``in_chan``.
    '''
    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS1, self).__init__()
        mid_chan = exp_ratio * in_chan

        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        # expansion: groups=in_chan, each input channel yields exp_ratio outputs;
        # the ReLU here is not shown in the paper
        self.dwconv = nn.Sequential(
            nn.Conv2d(in_chan, mid_chan, kernel_size=3, stride=1, padding=1,
                      groups=in_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU(inplace=True),
        )
        # reduction; its BN is flagged so BiSeNetV2.init_weights zero-inits it
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_chan, out_chan, kernel_size=1, stride=1, padding=0,
                      bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.conv2[-1].last_bn = True
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        branch = self.conv2(self.dwconv(self.conv1(x)))
        return self.relu(branch + x)


class GELayerS2(nn.Module):
    '''Gather-and-Expansion layer, stride-2 variant (halves H and W).

    Main path: 3x3 ConvBNReLU (in_chan -> in_chan); stride-2 grouped 3x3 conv
    expanding to in_chan*exp_ratio (BN, no ReLU); depthwise 3x3 on the
    expanded features (BN + ReLU); 1x1 projection to out_chan (BN, no ReLU).
    Shortcut: stride-2 depthwise 3x3 + BN, then 1x1 conv to out_chan + BN.
    The two paths are summed and passed through a ReLU.
    '''
    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS2, self).__init__()
        mid_chan = exp_ratio * in_chan

        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        # expansion + down-sampling: groups=in_chan, stride 2
        self.dwconv1 = nn.Sequential(
            nn.Conv2d(in_chan, mid_chan, kernel_size=3, stride=2, padding=1,
                      groups=in_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
        )
        # depthwise 3x3 on the expanded map; the ReLU is not shown in the paper
        self.dwconv2 = nn.Sequential(
            nn.Conv2d(mid_chan, mid_chan, kernel_size=3, stride=1, padding=1,
                      groups=mid_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU(inplace=True),
        )
        # 1x1 projection; its BN is flagged so BiSeNetV2.init_weights zero-inits it
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_chan, out_chan, kernel_size=1, stride=1, padding=0,
                      bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.conv2[-1].last_bn = True
        # separable down-sampling shortcut (depthwise 3x3 s2, then 1x1)
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_chan, in_chan, kernel_size=3, stride=2, padding=1,
                      groups=in_chan, bias=False),
            nn.BatchNorm2d(in_chan),
            nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=1, padding=0,
                      bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        main = self.conv2(self.dwconv2(self.dwconv1(self.conv1(x))))
        return self.relu(main + self.shortcut(x))

In [7]:
class SegmentBranch(nn.Module):
    '''Semantic branch: stem, three Gather-and-Expansion stages, context embedding.

    forward(x) returns every stage output, fine to coarse (nominal sizes; the
    strided layers round up when h, w are not multiples of 32):
        feat2   (N, 16,  h/4,  w/4)   stem block (S1-S2)
        feat3   (N, 32,  h/8,  w/8)   GE stage S3
        feat4   (N, 64,  h/16, w/16)  GE stage S4
        feat5_4 (N, 128, h/32, w/32)  GE stage S5, before context embedding
        feat5_5 (N, 128, h/32, w/32)  context-embedding block
    '''

    def __init__(self):
        super(SegmentBranch, self).__init__()
        self.S1S2 = StemBlock()
        self.S3 = nn.Sequential(GELayerS2(16, 32), GELayerS1(32, 32))
        self.S4 = nn.Sequential(GELayerS2(32, 64), GELayerS1(64, 64))
        self.S5_4 = nn.Sequential(
            GELayerS2(64, 128),
            *[GELayerS1(128, 128) for _ in range(3)]
        )
        self.S5_5 = CEBlock()

    def forward(self, x):
        stage_outputs = []
        feat = x
        for stage in (self.S1S2, self.S3, self.S4, self.S5_4, self.S5_5):
            feat = stage(feat)
            stage_outputs.append(feat)
        return tuple(stage_outputs)

## 3. Bilateral Guided Aggregation

In [15]:
class BGALayer(nn.Module):
    '''Bilateral Guided Aggregation: fuse detail- and semantic-branch features.

    Tensors are NCHW with 128 channels:
        x_d: detail-branch output,   (N, 128, H, W)      -- input / 8
        x_s: semantic-branch output, (N, 128, H/4, W/4)  -- input / 32

    Each branch gates the other through a sigmoid, once at the detail
    resolution and once at the semantic resolution. Low-resolution tensors
    are up-sampled to exactly x_d's spatial size, so the fusion also works
    when the input side is not a multiple of 32 (the strided layers round
    up, which makes "H/4 * 4" differ from H).
    '''
    def __init__(self):
        super(BGALayer, self).__init__()

        # ---- Detail Branch ----
        # 1. 3x3 depth-wise conv & 1x1 conv, keeps (N, 128, H, W)
        self.left1 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(
                128, 128, kernel_size=1, stride=1,
                padding=0, bias=False),
        )

        # 2. 3x3 conv with stride=2 & 3x3 avg-pool with stride=2
        self.left2 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=2,
                padding=1, bias=False),  # (N, 128, H/2, W/2)
            nn.BatchNorm2d(128),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)  # (N, 128, H/4, W/4)
        )

        # ---- Semantic Branch ----
        # 1. 3x3 conv, then up-sample to the detail resolution
        self.right1 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
        )  # (N, 128, H/4, W/4)
        # Nominal 4x nearest up-sampling. forward() uses this module's mode but
        # resizes to x_d's exact size instead of applying the fixed factor.
        self.up1 = nn.Upsample(scale_factor=4)

        # 2. 3x3 depth-wise conv & 1x1 conv, keeps (N, 128, H/4, W/4)
        self.right2 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(
                128, 128, kernel_size=1, stride=1,
                padding=0, bias=False),
        )

        # Up-sampling of the low-resolution fused tensor (treated like up1).
        self.up2 = nn.Upsample(scale_factor=4)

        # TODO: the paper shows no ReLU after this conv; one is added here.
        self.conv = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True), # not shown in paper
        )

    @staticmethod
    def _resize(x, size, up):
        '''Interpolate ``x`` to spatial ``size`` using the mode of Upsample ``up``.

        For the default nearest mode, resizing to exactly 4x the input size
        selects the same source pixels as ``up``'s scale_factor=4.
        '''
        return F.interpolate(x, size=size, mode=up.mode,
                             align_corners=up.align_corners)

    def forward(self, x_d, x_s):
        '''x_d: output of detail-branch (N, 128, h/8, w/8)
           x_s: output of semantic-branch (N, 128, h/32, w/32)
           returns: fused features with x_d's shape (N, 128, h/8, w/8)'''
        dsize = x_d.size()[2:]

        # Detail Branch
        left1 = self.left1(x_d)  # (N, 128, h/8, w/8)
        left2 = self.left2(x_d)  # (N, 128, h/32, w/32)

        # Semantic Branch
        right1 = self.right1(x_s)  # (N, 128, h/32, w/32)
        right1 = self._resize(right1, dsize, self.up1)  # (N, 128, h/8, w/8)

        right2 = self.right2(x_s)  # (N, 128, h/32, w/32)

        # Fuse: each branch is gated by a sigmoid of the other
        left = left1 * torch.sigmoid(right1)  # (N, 128, h/8, w/8)

        right = left2 * torch.sigmoid(right2)  # (N, 128, h/32, w/32)
        right = self._resize(right, dsize, self.up2)  # (N, 128, h/8, w/8)

        out = self.conv(left + right)  # (N, 128, h/8, w/8)

        return out

## 4. Segment Head in Booster

In [23]:
class SegmentHead(nn.Module):
    '''Prediction head: 3x3 ConvBNReLU -> Dropout(0.1) -> classifier -> upsample.

    Args:
        in_chan: channels of the incoming feature map.
        mid_chan: channels produced by the first 3x3 ConvBNReLU.
        n_classes: number of output channels (one logit map per class).
        up_factor: requested total spatial up-sampling (stored as
            ``self.up_factor``).
        aux: False for the final-prediction head, True for an auxiliary
            (booster) head.
            - aux=False: 1x1 conv (mid_chan -> n_classes), then one bilinear
              upsample by ``up_factor``.
            - aux=True: nearest 2x upsample and a 3x3 ConvBNReLU
              (mid_chan -> up_factor**2) first, then the 1x1 classifier and a
              bilinear upsample by ``up_factor // 2`` (total 2*(up_factor//2),
              i.e. ``up_factor`` when it is even).
    '''

    def __init__(self, in_chan, mid_chan, n_classes, up_factor=8, aux=True):
        super(SegmentHead, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
        self.drop = nn.Dropout(0.1)
        self.up_factor = up_factor

        if aux:
            cls_in_chan = up_factor * up_factor
            final_scale = up_factor // 2
            pre_cls = nn.Sequential(
                nn.Upsample(scale_factor=2),
                ConvBNReLU(mid_chan, cls_in_chan, 3, stride=1),
            )
        else:
            cls_in_chan = mid_chan
            final_scale = up_factor
            pre_cls = nn.Identity()

        self.conv_out = nn.Sequential(
            pre_cls,
            nn.Conv2d(cls_in_chan, n_classes, 1, 1, 0, bias=True),
            nn.Upsample(scale_factor=final_scale, mode='bilinear',
                        align_corners=False),
        )

    def forward(self, x):
        '''x: (N, in_chan, h, w) -> logits (N, n_classes, h*up_factor, w*up_factor).'''
        hidden = self.drop(self.conv(x))  # (N, mid_chan, h, w)
        return self.conv_out(hidden)

## Model

In [11]:
class BiSeNetV2(nn.Module):
    '''BiSeNet V2: detail branch + semantic branch fused by BGA, plus heads.

    Args:
        n_classes: number of output classes (channels of every logit map).
        aux_mode: 'train' -> forward returns the main logits followed by four
            auxiliary (booster) logits; 'eval' -> a 1-tuple ``(logits,)``;
            'pred' -> per-pixel argmax class indices of shape (N, H, W).
            The auxiliary heads are only built when aux_mode == 'train'.
        pretrained: if True (default, the previous behaviour), load backbone
            weights from ``backbone_url`` after initialisation (downloads via
            torch hub, or reuses its cache).
    '''

    def __init__(self, n_classes, aux_mode='train', pretrained=True):
        super(BiSeNetV2, self).__init__()
        self.aux_mode = aux_mode
        self.detail = DetailBranch()
        self.segment = SegmentBranch()
        self.bga = BGALayer()

        ## TODO: what is the number of mid chan ?
        self.head = SegmentHead(128, 1024, n_classes, up_factor=8, aux=False)
        if self.aux_mode == 'train':
            self.aux2 = SegmentHead(16, 128, n_classes, up_factor=4)
            self.aux3 = SegmentHead(32, 128, n_classes, up_factor=8)
            self.aux4 = SegmentHead(64, 128, n_classes, up_factor=16)
            self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32)

        self.init_weights(pretrained=pretrained)

    @staticmethod
    def _match_size(logits, size):
        '''Bilinearly resize ``logits`` to spatial ``size`` unless it already matches.

        The heads up-sample by fixed factors, so when H or W is not a multiple
        of 32 their output (computed from rounded-up feature sizes) can differ
        from the input size.
        '''
        if logits.shape[2:] == size:
            return logits
        return F.interpolate(logits, size=size, mode='bilinear',
                             align_corners=False)

    def forward(self, x):
        '''x: (N, 3, H, W) image batch; output depends on ``self.aux_mode``.'''
        size = x.size()[2:]
        # Detail Branch
        feat_d = self.detail(x)  # (N, 128, h/8, w/8)

        # Semantic Branch
        feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
        # feat2: (N, 16, h/4, w/4) - stem block
        # feat3: (N, 32, h/8, w/8) - GE
        # feat4: (N, 64, h/16, w/16) - GE
        # feat5_4: (N, 128, h/32, w/32) - GE
        # feat_s: (N, 128, h/32, w/32) - context embedding

        # Bilateral Guided Aggregation
        # fuse detail (N, 128, h/8, w/8) with semantic (N, 128, h/32, w/32)
        feat_head = self.bga(feat_d, feat_s)  # (N, 128, h/8, w/8)

        logits = self._match_size(self.head(feat_head), size)  # (N, n_classes, H, W)
        if self.aux_mode == 'train':
            # every auxiliary head also predicts at (N, n_classes, H, W)
            logits_aux2 = self._match_size(self.aux2(feat2), size)
            logits_aux3 = self._match_size(self.aux3(feat3), size)
            logits_aux4 = self._match_size(self.aux4(feat4), size)
            logits_aux5_4 = self._match_size(self.aux5_4(feat5_4), size)
            return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4

        elif self.aux_mode == 'eval':
            return logits,

        elif self.aux_mode == 'pred':
            return logits.argmax(dim=1)  # (N, H, W) class indices

        raise NotImplementedError(
            'unsupported aux_mode: {!r}'.format(self.aux_mode))

    def init_weights(self, pretrained=True):
        '''Kaiming-init convs/linears, reset BatchNorms, optionally load the backbone.

        BatchNorms flagged with ``last_bn`` (the GE layers set it on the BN
        right before the residual addition) get a zero scale; all other BNs
        get scale 1, and every BN shift is zeroed.
        '''
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                if getattr(module, 'last_bn', False):
                    nn.init.zeros_(module.weight)
                else:
                    nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
        if pretrained:
            self.load_pretrain()

    def load_pretrain(self):
        '''Load checkpoint entries into the top-level children whose names match its keys.'''
        # map to CPU so a checkpoint saved on GPU also loads on CPU-only hosts;
        # load_state_dict copies the tensors onto each parameter's own device
        state = modelzoo.load_url(backbone_url, map_location='cpu')
        for name, child in self.named_children():
            if name in state:
                child.load_state_dict(state[name], strict=True)

    def get_params(self):
        '''Split parameters into four optimizer groups.

        Returns:
            (wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params).
            Parameters of the main/aux heads go to the ``lr_mul_*`` lists,
            everything else to the plain lists. Within each pair, tensors with
            2+ dims (conv/linear weights) go to the ``wd`` list and 0/1-dim
            tensors (biases, BN affine params) to the ``nowd`` list. Every
            parameter lands in exactly one list (previously tensors that were
            neither 1-D nor 4-D were only printed and silently dropped).
        '''
        def add_param_to_list(mod, wd_params, nowd_params):
            for param in mod.parameters():
                if param.dim() <= 1:
                    nowd_params.append(param)
                else:
                    wd_params.append(param)

        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            if 'head' in name or 'aux' in name:
                add_param_to_list(child, lr_mul_wd_params, lr_mul_nowd_params)
            else:
                add_param_to_list(child, wd_params, nowd_params)
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


In [13]:
torch.manual_seed(0)  # reproducible input and random initialisation
x = torch.randn([2, 3, 512, 512])
model = BiSeNetV2(n_classes=1, aux_mode='train')
with torch.no_grad():  # shape smoke test only; no backward pass needed
    out, aux1, aux2, aux3, aux4 = model(x)
# the main head and all four auxiliary heads predict at full input resolution
for logits in (out, aux1, aux2, aux3, aux4):
    assert logits.shape == (2, 1, 512, 512), logits.shape
print(out.shape)

torch.Size([2, 1, 512, 512])
