## Blocks

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.modules.layers import *


class RFB_kernel(nn.Module):
    '''One receptive-field path: 1x1 -> (1 x k) -> (k x 1) -> 3x3 dilated by k.

    `conv` is the project helper imported from lib.modules.layers; it is
    presumably a spatial-size-preserving convolution wrapper - TODO confirm.
    '''
    def __init__(self, in_channel, out_channel, receptive_size=3):
        super().__init__()
        k = receptive_size
        # channel projection
        self.conv0 = conv(in_channel, out_channel, 1)
        # a k x k kernel factorised into a (1 x k) followed by a (k x 1)
        self.conv1 = conv(out_channel, out_channel, kernel_size=(1, k))
        self.conv2 = conv(out_channel, out_channel, kernel_size=(k, 1))
        # 3x3 conv whose dilation equals k
        self.conv3 = conv(out_channel, out_channel, 3, dilation=k)

    def forward(self, x):
        '''(N, in_c, h, w) -> (N, out_c, h, w)'''
        # The four layers are applied strictly one after another.
        for stage in (self.conv0, self.conv1, self.conv2, self.conv3):
            x = stage(x)
        return x
    
    
class RFB(nn.Module):
    '''Receptive Field Block.

    The input feature map is forwarded to every receptive-field path; the
    path outputs are concatenated, fused and added to a 1x1 residual.
    '''
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        # receptive field paths: a plain 1x1 conv plus kernels with k = 3, 5, 7
        self.branch0 = conv(in_channel, out_channel, 1)
        self.branch1 = RFB_kernel(in_channel, out_channel, 3)
        self.branch2 = RFB_kernel(in_channel, out_channel, 5)
        self.branch3 = RFB_kernel(in_channel, out_channel, 7)

        # fusion of the four concatenated paths, and the residual projection
        self.conv_cat = conv(4 * out_channel, out_channel, 3)
        self.conv_res = conv(in_channel, out_channel, 1)

    def forward(self, x):
        '''(N, in_c, h, w) -> (N, out_c, h, w)'''
        paths = (self.branch0, self.branch1, self.branch2, self.branch3)
        path_outputs = [path(x) for path in paths]

        # concatenate along the channel axis, then fuse back to out_c channels
        fused = self.conv_cat(torch.cat(path_outputs, 1))
        shortcut = self.conv_res(x)

        return self.relu(fused + shortcut)

In [None]:
class self_attn(nn.Module):
    '''Axial-Attention
       - performs a non-local operation with respect to a single axis (H/W)

    mode: 'h' attends along the height axis, 'w' along the width axis; a mode
    containing both letters (e.g. the default 'hw') attends over all H*W
    positions.

    NOTE: for mode='h' the tensors are now transposed so that H really is the
    innermost axis before flattening. The previous plain `.view(N, -1, H)`
    reinterpreted row-major (c, H, W) memory and therefore did not attend
    along H. Outputs for mode='h' differ from that older behaviour, so weights
    trained against it may need re-tuning. 'w' and 'hw' are unchanged.
    '''
    def __init__(self, in_channels, mode='hw'):
        super(self_attn, self).__init__()

        self.mode = mode

        # query/key use a reduced width (in_channels // 8); value keeps all channels
        self.query_conv = conv(in_channels, in_channels // 8, kernel_size=(1, 1))
        self.key_conv = conv(in_channels, in_channels // 8, kernel_size=(1, 1))
        self.value_conv = conv(in_channels, in_channels, kernel_size=(1, 1))

        # gamma starts at zero, so the block initially returns its input unchanged
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):  # x: (N, C, H, W), c=C/8
        '''(N, C, H, W) -> (N, C, H, W)'''
        batch_size, channel, height, width = x.size()

        attend_h = 'h' in self.mode
        attend_w = 'w' in self.mode

        # number of positions that attend to each other
        axis = 1
        if attend_h:
            axis *= height
        if attend_w:
            axis *= width

        # Flattening a (N, c, H, W) tensor to (N, -1, axis) keeps the trailing
        # memory axis last. That is already right for 'w' (W is last) and for
        # 'hw' (H*W is last); for height-only attention H must be swapped to
        # the innermost position first.
        swap_hw = attend_h and not attend_w

        def flatten(t):
            if swap_hw:
                t = t.permute(0, 1, 3, 2)  # (N, c, W, H)
            return t.reshape(batch_size, -1, axis)

        # 1. axis=h -> Q: (N, H, Wc), K: (N, Wc, H), V: (N, WC, H)
        # 2. axis=w -> Q: (N, W, Hc), K: (N, Hc, W), V: (N, HC, W)
        projected_query = flatten(self.query_conv(x)).permute(0, 2, 1)
        projected_key = flatten(self.key_conv(x))
        projected_value = flatten(self.value_conv(x))

        # 1. axis=h -> attention: (N, H, H)
        # 2. axis=w -> attention: (N, W, W)
        attention_map = torch.bmm(projected_query, projected_key)
        attention = self.softmax(attention_map)

        # 1. axis=h -> (N, WC, H) -> (N, C, W, H) -> (N, C, H, W)
        # 2. axis=w -> (N, HC, W) -> (N, C, H, W)
        out = torch.bmm(projected_value, attention.permute(0, 2, 1))
        if swap_hw:
            out = out.view(batch_size, channel, width, height).permute(0, 1, 3, 2)
        else:
            out = out.view(batch_size, channel, height, width)

        # learnable residual mix
        out = self.gamma * out + x

        return out

## PAA-Encoder

In [None]:
class PAA_kernel(nn.Module):
    '''Receptive-field path with Parallel Axial Attention.

    1x1 -> (1 x k) -> (k x 1) convolutions, then height- and width-attention
    applied to the same feature map, summed, and passed to a 3x3 conv with
    dilation k.
    '''
    def __init__(self, in_channel, out_channel, receptive_size=3):
        super().__init__()
        k = receptive_size
        self.conv0 = conv(in_channel, out_channel, 1)
        self.conv1 = conv(out_channel, out_channel, kernel_size=(1, k))
        self.conv2 = conv(out_channel, out_channel, kernel_size=(k, 1))
        self.conv3 = conv(out_channel, out_channel, 3, dilation=k)

        # Parallel Axial Attention: one attention module per spatial axis
        self.Hattn = self_attn(out_channel, mode='h')
        self.Wattn = self_attn(out_channel, mode='w')

    def forward(self, x):
        # factorised convolution stem
        for stem in (self.conv0, self.conv1, self.conv2):
            x = stem(x)

        # both attentions read the same stem output (parallel connection)
        attended_h = self.Hattn(x)
        attended_w = self.Wattn(x)

        # aggregate the two branches by summation, then the dilated conv
        return self.conv3(attended_h + attended_w)
    
    
class PAA_e(nn.Module):
    '''PAA encoder: the RFB layout with PAA_kernel paths.

    A 1x1 path and three PAA_kernel paths (k = 3, 5, 7) are concatenated,
    fused by a 3x3 conv and added to a 1x1 residual before the ReLU.
    '''
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        # RFB strategy & global refinement (attention)
        self.branch0 = conv(in_channel, out_channel, 1)
        self.branch1 = PAA_kernel(in_channel, out_channel, 3)
        self.branch2 = PAA_kernel(in_channel, out_channel, 5)
        self.branch3 = PAA_kernel(in_channel, out_channel, 7)

        # aggregation of the four paths, and the residual projection
        self.conv_cat = conv(4 * out_channel, out_channel, 3)
        self.conv_res = conv(in_channel, out_channel, 1)

    def forward(self, x):
        '''(N, in_c, H, W) -> (N, out_c, H, W)'''
        paths = (self.branch0, self.branch1, self.branch2, self.branch3)
        path_outputs = [path(x) for path in paths]

        # channel-wise concatenation followed by fusion
        fused = self.conv_cat(torch.cat(path_outputs, 1))
        shortcut = self.conv_res(x)

        return self.relu(fused + shortcut)

## UACA

In [None]:
class UACA(nn.Module):
    '''Uncertainty Augmented Context Attention
      - self-attention that incorporates the uncertain area for rich semantic
        feature extraction without extra boundary guidance

    NOTE: the feature map is now flattened as (N, C, hw) and then transposed
    to (N, hw, C). The previous `x.view(b, h * w, -1)` reinterpreted the
    (N, C, H, W) memory without transposing, which mixed channels and spatial
    positions in the context vectors. Outputs differ from that older
    behaviour, so weights trained against it may need re-tuning.
    '''
    def __init__(self, in_channel, channel):
        super(UACA, self).__init__()
        self.channel = channel

        self.conv_query = nn.Sequential(conv(in_channel, channel, 3, relu=True),
                                        conv(channel, channel, 3, relu=True))
        self.conv_key = nn.Sequential(conv(in_channel, channel, 1, relu=True),
                                      conv(channel, channel, 1, relu=True))
        self.conv_value = nn.Sequential(conv(in_channel, channel, 1, relu=True),
                                        conv(channel, channel, 1, relu=True))

        self.conv_out1 = conv(channel, channel, 3, relu=True)
        self.conv_out2 = conv(in_channel + channel, channel, 3, relu=True)
        self.conv_out3 = conv(channel, channel, 3, relu=True)
        self.conv_out4 = conv(channel, 1, 1)

    def forward(self, x, map_):
        '''x: input feature (N, in_channel, H, W)
           map_: predicted semantic mask logits (N, 1, h, w)

           returns (refined feature (N, channel, H, W),
                    refined map (N, 1, H, W) = residual + resized map_)'''
        b, c, h, w = x.shape

        # compute class probability at the feature resolution
        map_ = F.interpolate(map_, size=x.shape[-2:], mode='bilinear', align_corners=False)
        fg = torch.sigmoid(map_)

        # p in (-0.5, 0.5): positive -> foreground, negative -> background
        p = fg - .5

        fg = torch.clip(p, 0, 1)  # foreground
        bg = torch.clip(-p, 0, 1)  # background
        cg = .5 - torch.abs(p)  # confusion area (largest where sigmoid is near 0.5)

        prob = torch.cat([fg, bg, cg], dim=1)  # (N, 3, H, W)

        # reshape feature & prob
        # flatten the spatial axes first, then move channels last
        f = x.view(b, c, h * w).permute(0, 2, 1)  # (N, hw, in_channel)
        prob = prob.view(b, 3, h * w)  # (N, 3, hw)

        # compute context vector: one probability-weighted feature sum per area
        # (N, 3, in_channel) -> (N, in_channel, 3, 1)
        context = torch.bmm(prob, f).permute(0, 2, 1).unsqueeze(3)

        # k q v compute
        query = self.conv_query(x).view(b, self.channel, -1).permute(0, 2, 1)  # (N, hw, channel)
        key = self.conv_key(context).view(b, self.channel, -1)  # (N, channel, 3)
        value = self.conv_value(context).view(b, self.channel, -1).permute(0, 2, 1)  # (N, 3, channel)

        # compute similarity map between every position and the three areas
        sim = torch.bmm(query, key)  # (N, hw, 3)
        sim = (self.channel ** -.5) * sim  # scaled dot-product
        sim = F.softmax(sim, dim=-1)  # (N, hw, 3)

        # compute refined feature
        context = torch.bmm(sim, value).permute(0, 2, 1).contiguous().view(b, -1, h, w)  # (N, channel, H, W)
        context = self.conv_out1(context)  # (N, channel, H, W)

        x = torch.cat([x, context], dim=1)  # (N, in_channel + channel, H, W)
        x = self.conv_out2(x)  # (N, channel, H, W)
        x = self.conv_out3(x)  # (N, channel, H, W)
        out = self.conv_out4(x)  # (N, 1, H, W)

        # the prediction is a residual on top of the incoming (resized) map
        out = out + map_

        return x, out

## PAA-Decoder

In [None]:
class PAA_d(nn.Module):
    '''PAA decoder: fuses three encoder features into an initial prediction.

    The two coarser features are upsampled to the size of the third, all
    three are concatenated and fused, refined with parallel height/width
    axial attention, and a final conv produces a 1-channel map.
    '''
    def __init__(self, channel):
        super(PAA_d, self).__init__()
        self.conv1 = conv(channel * 3, channel, 3)
        self.conv2 = conv(channel, channel, 3)
        self.conv3 = conv(channel, channel, 3)
        self.conv4 = conv(channel, channel, 3)
        self.conv5 = conv(channel, 1, 3, bn=False)

        # PAA
        self.Hattn = self_attn(channel, mode='h')
        self.Wattn = self_attn(channel, mode='w')

    @staticmethod
    def upsample(img, size):
        '''Bilinearly resize `img` to spatial `size` (align_corners=True).

        Defined as a method rather than a lambda stored on the instance, so
        that the module can be pickled (e.g. torch.save on the whole model).
        '''
        return F.interpolate(img, size=size, mode='bilinear', align_corners=True)

    def forward(self, f1, f2, f3):
        '''f1, f2, f3: PAA encoder features
           f1: (N, 256, H/16, W/16), f2: (N, 256, H/16, W/16), f3: (N, 256, H/8, W/8)

           returns (decoded feature, 1-channel prediction), both at f3's size'''
        # up-sampling to the finest feature's resolution
        target_size = f3.shape[-2:]
        f1 = self.upsample(f1, target_size)
        f2 = self.upsample(f2, target_size)

        # concatenation
        f_con = torch.cat([f1, f2, f3], dim=1)
        f_con = self.conv1(f_con)

        # PAA: both attentions read the same fused feature and are summed
        Hf_con = self.Hattn(f_con)
        Wf_con = self.Wattn(f_con)
        f_con = Hf_con + Wf_con

        f_con = self.conv2(f_con)
        f_con = self.conv3(f_con)
        f_con = self.conv4(f_con)

        out = self.conv5(f_con)

        return f_con, out

## Network

In [None]:
from backbones.Res2Net_v1b import res2net50_v1b_26w_4s
from optim.losses import *


class UACANet(nn.Module):
    '''Res2Net-based encoder-decoder with PAA encoders/decoder and UACA heads.

    forward() takes a dict with key 'image' (and optionally 'gt') and returns
    a dict with the final prediction, the summed loss over all four
    predictions (0 when no 'gt' is given) and the intermediate predictions.
    '''
    def __init__(self, channels=256, output_stride=16, pretrained=False):
        super(UACANet, self).__init__()
        self.resnet = res2net50_v1b_26w_4s(pretrained=pretrained, output_stride=output_stride)

        # parallel PAA encoders
        self.context2 = PAA_e(512, channels)
        self.context3 = PAA_e(1024, channels)
        self.context4 = PAA_e(2048, channels)

        self.decoder = PAA_d(channels)

        # parallel uncertainty augmented context attention
        self.attention2 = UACA(channels * 2, channels)
        self.attention3 = UACA(channels * 2, channels)
        self.attention4 = UACA(channels * 2, channels)

        self.loss_fn = bce_iou_loss

    # The two resize helpers are methods rather than lambdas stored on the
    # instance, so that the module can be pickled (e.g. torch.save(model)).
    @staticmethod
    def ret(x, target):
        '''Bilinearly resize `x` to the spatial size of tensor `target`.'''
        return F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False)

    @staticmethod
    def res(x, size):
        '''Bilinearly resize `x` to the explicit spatial `size`.'''
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

    def forward(self, sample):
        '''sample: dict with 'image' (N, 3, H, W) and optionally 'gt' (N, 1, H, W)

           returns {'pred': (N, 1, H, W) final map,
                    'loss': summed loss over the four maps, or 0 without 'gt',
                    'debug': [out5, out4, out3] with 'gt', else []}'''
        x = sample['image']  # (N, 3, H, W)
        y = sample['gt'] if 'gt' in sample else None  # (N, 1, H, W) or None

        base_size = x.shape[-2:]

        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)  # (N, 64, H/4, W/4)
        x1 = self.resnet.layer1(x)  # (N, 256, H/4, W/4)

        # ------------------ 1.Backbone -----------------------
        # intermediate feature maps of backbone
        x2 = self.resnet.layer2(x1)  # (N, 512, H/8, W/8)
        x3 = self.resnet.layer3(x2)  # (N, 1024, H/16, W/16)
        x4 = self.resnet.layer4(x3)  # (N, 2048, H/16, W/16)

        # ------------------ 2.PAA-Encoders -------------------
        # PAA encoded features
        x2_enc = self.context2(x2)  # (N, 256, H/8, W/8)
        x3_enc = self.context3(x3)  # (N, 256, H/16, W/16)
        x4_enc = self.context4(x4)  # (N, 256, H/16, W/16)

        # ------------------ 3.PAA-Decoders -------------------
        # PAA decoding (initial prediction) - aggregates multi-scale PAA encoder features
        # a5: initial saliency map for UACA
        f5, a5 = self.decoder(x4_enc, x3_enc, x2_enc)  # (N, 256, H/8, W/8), (N, 1, H/8, W/8)
        out5 = self.res(a5, base_size)  # (N, 1, H, W)

        # Sequential UACAs: each stage refines the map produced by the previous one
        # ------------------ 4.1 UACA -------------------
        # inputs: concat-(N, 512, H/16, W/16), a5-(N, 1, H/8, W/8)
        # f4: (N, 256, H/16, W/16), a4: (N, 1, H/16, W/16)
        f4, a4 = self.attention4(torch.cat([x4_enc, self.ret(f5, x4_enc)], dim=1), a5)
        out4 = self.res(a4, base_size)  # (N, 1, H, W)

        # ------------------ 4.2 UACA -------------------
        # inputs: concat-(N, 512, H/16, W/16), a4-(N, 1, H/16, W/16)
        # f3: (N, 256, H/16, W/16), a3: (N, 1, H/16, W/16)
        f3, a3 = self.attention3(torch.cat([x3_enc, self.ret(f4, x3_enc)], dim=1), a4)
        out3 = self.res(a3, base_size)  # (N, 1, H, W)

        # ------------------ 4.3 UACA -------------------
        # inputs: concat-(N, 512, H/8, W/8), a3-(N, 1, H/16, W/16)
        # a2: (N, 1, H/8, W/8)
        _, a2 = self.attention2(torch.cat([x2_enc, self.ret(f3, x2_enc)], dim=1), a3)  # (N, 1, H/8, W/8)
        out2 = self.res(a2, base_size)  # (N, 1, H, W)

        if y is not None:
            # deep supervision: every intermediate map is compared with the target
            loss5 = self.loss_fn(out5, y)
            loss4 = self.loss_fn(out4, y)
            loss3 = self.loss_fn(out3, y)
            loss2 = self.loss_fn(out2, y)

            loss = loss2 + loss3 + loss4 + loss5
            debug = [out5, out4, out3]

        else:
            loss = 0
            debug = []

        return {'pred': out2, 'loss': loss, 'debug': debug}

In [None]:
# Smoke test: push one random batch (with ground truth) through the network.
x = torch.randn([2, 3, 256, 256])
# Ground truth as a binary mask in {0, 1} instead of N(0, 1) samples.
# bce_iou_loss is defined outside this file; its name suggests BCE + IoU
# terms, which presumably expect targets in [0, 1] - TODO confirm.
y = (torch.rand([2, 1, 256, 256]) > 0.5).float()
sample = {'image': x, 'gt': y}

model = UACANet()
out = model(sample)