In [2]:
"""
training code
also contains the F-score evaluation code.
"""
from __future__ import absolute_import
from __future__ import division
import argparse
import logging
import os
import numpy as np
import torch
#from apex import amp

from config import cfg, assert_and_infer_cfg
from utils.misc import AverageMeter, prep_experiment, evaluate_eval, fast_hist, set_bn_eval
from utils.f_boundary import eval_mask_boundary
import datasets
import loss
import network
import optimizer


# Hard-coded stand-in for command-line options. This namespace is passed to
# datasets.setup_loaders, loss.get_loss, network.get_net and
# optimizer.get_optimizer in the next cell, and read by the notebook-local
# EBLNet configuration further below.
args = argparse.Namespace()
args.lr = 0.002
args.arch = 'network.EBLNet.EBLNet_resnext101_os8'
args.dataset = 'MSD'

args.cv = 0
args.class_uniform_pct = 0.0
args.class_uniform_tile = 1024
args.coarse_boost_classes = None

# loss options
args.img_wt_loss = False
args.batch_weighting = False
args.dice_loss = True
args.ohem = False
args.aux = False
args.jointwtborder = True
args.joint_edge_loss_light_cascade = True
args.edge_weight = 3.0
args.body_weight = 1.0
args.seg_weight = 1.0
args.rlx_off_epoch = -1
args.rescale = 1.0
args.repoly = 1.5
args.apex = False
args.fp16 = False

args.local_rank = 0

# optimizer selection
args.sgd = True
args.adam = False
args.amsgrad = False

args.freeze_trunk = False
args.hardnm = 0

# NOTE(review): 'resnet101' does not match args.arch (resnext101); confirm which
# one the project code actually reads.
args.trunk = 'resnet101'
args.max_epoch = 160
args.eval_epoch = 150
args.max_cu_epoch = 160
args.start_epoch = 0
args.color_aug = 0.0
args.gblur = True
args.bblur = False
args.lr_schedule = 'poly'
args.poly_exp = 0.9
args.bs_mult = 2
args.bs_mult_val = 1
args.crop_size = 384
args.pre_size = None
args.scale_min = 1.0
args.scale_max = 1.0
args.weight_decay = 1e-4
args.momentum = 0.9
# Checkpoint path. The default is the original author's local file; set the
# EBLNET_SNAPSHOT environment variable to point elsewhere without editing code.
args.snapshot = os.environ.get('EBLNET_SNAPSHOT', 'C:/Users/iml/Desktop/EBLNet-main/best.pth')
args.restore_optimizer = False
args.exp = 'default'
args.tb_tag = ''
args.ckpt = 'logs/ckpt'
args.tb_path = 'logs/tb'
args.syncbn = False
args.fix_bn = False
args.evaluateF = False
args.eval_thresholds = '0.0005,0.001875,0.00375,0.005'
args.dump_augmentation_images = False
args.test_mode = False
args.wb = 1.0
args.maxSkip = 0
args.scf = False

args.print_freq = 5
args.eval_freq = 1
# cascade / contour-point GCN settings (also read by the local EBLNet config cell)
args.num_cascade = 3
args.weight_mean = 0
args.num_points = 97
args.thres_gcn = 0.9
args.thicky = 8

# single-process, single-GPU run
args.ngpu = 1
torch.backends.cudnn.benchmark = True
args.world_size = 1

In [3]:
# Build data loaders, losses, the project's reference network and its optimizer
# from `args`. Order matters: get_net takes the training criterion and
# get_optimizer takes the network. (setup_loaders presumably also fills in
# dataset-dependent fields on `args` — TODO confirm.)
(train_loader, val_loader, train_obj) = datasets.setup_loaders(args)
(criterion, criterion_val) = loss.get_loss(args)
net = network.get_net(args, criterion)
(optim, scheduler) = optimizer.get_optimizer(args, net)



In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from network import resnet_d as Resnet_Deep
from network.resnext import resnext101_32x8
from network.nn.mynn import Norm2d
from network.nn.contour_point_gcn import ContourPointGCN
from network.nn.operators import _AtrousSpatialPyramidPoolingModule


# Hyper-parameters for the notebook-local EBLNet defined below.
# `criterion` is reused as-is from the loss.get_loss(args) cell above.
num_classes = 2  # bg/fg; hard-coded here instead of args.dataset_cls.num_classes
num_cascade = args.num_cascade  # number of cascade stages
num_points = args.num_points    # forwarded to ContourPointGCN
threshold = args.thres_gcn      # forwarded to ContourPointGCN

# Constructor arguments of EBLNet (note: EBLNet ignores `trunk` and always
# builds resnext101_32x8).
trunk = 'resnext-101-32x8'
variant = 'D'     # 'D' -> ASPP output_stride 8
skip = 'm1'
skip_num = 48     # channels of the projected low-level (layer1) feature

In [5]:
class Edge_extractorWofirstext(nn.Module):
    """Split a high-level feature into an edge part and a residual ("body") part.

    "Wofirstext" = without first extractor: ``pre_extractor`` is constructed but
    is not applied in ``forward``.

    Args:
        inplane: channel count of the high-level input ``aspp`` and of both
            outputs (256 in this notebook).
        skip_num: channel count of the low-level input ``layer1`` (48 here).
            Both ``inplane + skip_num`` and ``inplane`` must be divisible by 8
            because of the grouped convolution.
        norm_layer: accepted for interface compatibility but unused;
            ``nn.BatchNorm2d`` is hard-coded below.
    """

    def __init__(self, inplane, skip_num, norm_layer):
        super(Edge_extractorWofirstext, self).__init__()
        # Attribute keeps the original "mum" spelling in case external code reads it.
        self.skip_mum = skip_num

        # NOTE(review): never used in forward(). It is kept so parameter names /
        # state_dict keys stay compatible with existing checkpoints.
        self.pre_extractor = nn.Sequential(
            nn.Conv2d(inplane, inplane, kernel_size=3,
                      padding=1, groups=1, bias=False),
            nn.BatchNorm2d(inplane),
            nn.ReLU(inplace=False)
        )

        # Grouped 3x3 conv over concat(upsampled aspp, layer1) -> F_edge.
        self.extractor = nn.Sequential(
            nn.Conv2d(inplane + skip_num, inplane, kernel_size=3,
                      padding=1, groups=8, bias=False),
            nn.BatchNorm2d(inplane),
            nn.ReLU(inplace=False)
        )

    def forward(self, aspp, layer1):
        """Compute F_edge and F_residual.

        Args:
            aspp: high-level feature, (N, inplane, h, w), e.g. (N, 256, 64, 64)
                for a 512x512 input.
            layer1: projected low-level feature, (N, skip_num, H, W), e.g.
                (N, 48, 128, 128).

        Returns:
            (seg_edge, seg_body), both (N, inplane, H, W), with
            seg_body = upsample(aspp) - seg_edge.
        """
        # Upsample once; the same tensor feeds the edge branch and the residual.
        aspp_up = F.interpolate(aspp, size=layer1.size()[2:], mode='bilinear',
                                align_corners=True)
        seg_edge = self.extractor(torch.cat([aspp_up, layer1], dim=1))

        # F_residual = F_in - F_edge
        seg_body = aspp_up - seg_edge

        return seg_edge, seg_body

In [9]:
class EBLNet(nn.Module):
    """DeepLabv3+-style segmentation network (no depthwise conv) with cascaded
    edge / body refinement.

    Each cascade stage i in ``forward``:
      1. ``Edge_extractorWofirstext`` splits the previous stage's feature into
         F_edge and F_residual (body).
      2. F_residual is fused with a projection of backbone stage feats[-(i + 1)].
      3. Heads produce a body map (num_classes channels) and an edge map (1 channel).
      4. ``ContourPointGCN`` refines F_edge + F_residual, guided by the detached
         sigmoid of the edge map.
      5. A segmentation head produces the stage logits; the last stage also
         concatenates the projected ASPP feature.

    Args:
        num_classes: output channels of the body / segmentation heads.
        trunk: accepted for interface compatibility but ignored; the backbone is
            always ``resnext101_32x8()``.
        criterion: stored on ``self.criterion``; ``forward`` does not use it.
        variant: 'D' gives ASPP ``output_stride=8``, anything else 16. The
            layer3/layer4 stride removal below is applied regardless.
        skip: stored only.
        skip_num: channels of the low-level (layer1) projection fed to every stage.
        num_cascade: number of cascade stages, 1..5.
        num_points, threshold: forwarded to ``ContourPointGCN``.

    Raises:
        ValueError: if ``num_cascade`` is outside [1, 5].
    """

    def __init__(self, num_classes, trunk='seresnext-50', criterion=None, variant='D',
                 skip='m1', skip_num=48, num_cascade=4, num_points=96, threshold=0.8):
        super(EBLNet, self).__init__()
        # forward() indexes feats[-(i + 1)] over 5 backbone stages and returns
        # seg_final_outs[-1]; any other stage count could only fail later.
        if not 1 <= num_cascade <= 5:
            raise ValueError('num_cascade must be between 1 and 5, got {}'.format(num_cascade))

        self.criterion = criterion
        self.variant = variant
        self.skip = skip
        self.skip_mum = skip_num  # original ("mum") spelling kept for external readers
        self.num_cascade = num_cascade
        self.num_points = num_points
        self.threshold = threshold

        # --------------------------- 0. feature extractor ---------------------------
        resnet = resnext101_32x8()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer0 = resnet.layer0
        self.layer1, self.layer2, self.layer3, self.layer4 = \
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # Replace the strides in layer3 / layer4 by dilation 2 / 4, so these stages
        # keep layer2's resolution (matching the shape notes in forward).
        for layer, dil in ((self.layer3, 2), (self.layer4, 4)):
            for n, m in layer.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (dil, dil), (dil, dil), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)

        # ASPP on layer4 (high-level feature).
        self.aspp = _AtrousSpatialPyramidPoolingModule(in_dim=2048, reduction_dim=256,
                                                       output_stride=8 if self.variant == 'D' else 16)
        # assumes the ASPP output has 1280 channels (5 x 256) — TODO confirm
        self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        # Low-level (layer1) projection: 256 -> skip_num channels.
        self.bot_fine = nn.Conv2d(256, self.skip_mum, kernel_size=1, bias=False)

        # --------------------------- 1. initial F_residual(body) & F_edge ---------------------------
        # The extractors consume bot_fine's output, so they must use the same skip_num.
        self.edge_extractors = nn.ModuleList(
            [Edge_extractorWofirstext(256, norm_layer=Norm2d, skip_num=self.skip_mum)
             for _ in range(self.num_cascade)])

        # --------------------------- 2. refine F_residual with F_high ---------------------------
        # Stage i projects feats[-(i + 1)] (layer4, layer3, layer2, layer1, ...) to 48 channels.
        # NOTE(review): 2 ** (11 - i) matches layer4..layer1 (2048..256) for i <= 3; for
        # i == 4 (layer0) it gives 128 — verify against the backbone's stem width.
        self.body_fines = nn.ModuleList(
            [nn.Conv2d(2 ** (11 - i), 48, kernel_size=1, bias=False)
             for i in range(self.num_cascade)])

        # concat(F_residual, F_high) -> F_residual'
        self.body_fuse = nn.ModuleList(
            [nn.Conv2d(256 + 48, 256, kernel_size=1, bias=False)
             for _ in range(self.num_cascade)])

        # --------------------------- 2.1. residual (body) head ---------------------------
        self.body_out_pre = nn.ModuleList([nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)) for _ in range(self.num_cascade)])

        self.body_out = nn.ModuleList([nn.Conv2d(256, num_classes, kernel_size=1, bias=False)
                                       for _ in range(self.num_cascade)])

        # --------------------------- 3. F_edge -> F_b (boundary) head ---------------------------
        self.edge_out_pre = nn.ModuleList([nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)) for _ in range(self.num_cascade)])

        self.edge_out = nn.ModuleList([nn.Conv2d(256, 1, kernel_size=1, bias=False)
                                       for _ in range(self.num_cascade)])

        # --------------------------- 4. point-based GCN refining F_merge with F_b ---------------------------
        self.refines = nn.ModuleList(
            [ContourPointGCN(256, self.num_points, self.threshold) for _ in range(self.num_cascade)])

        # --------------------------- 5. segmentation heads ---------------------------
        # All stages but the last take the 256-channel refined feature; the last one
        # takes concat(projected ASPP, refined feature) = 512 channels.
        final_pre = [nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)) for _ in range(self.num_cascade - 1)]
        final_pre.append(nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)))
        self.final_seg_out_pre = nn.ModuleList(final_pre)

        self.final_seg_out = nn.ModuleList([nn.Conv2d(256, num_classes, kernel_size=1, bias=False)
                                            for _ in range(self.num_cascade)])

    def forward(self, x, gts=None):
        """Run all cascade stages and return the last stage's segmentation logits.

        Args:
            x: input images, (N, 3, H, W). Shape comments below assume 512x512.
            gts: ignored; the loss call that used it is commented out.

        Returns:
            Final-stage logits bilinearly upsampled to the input size,
            (N, num_classes, H, W).
        """
        x_size = x.size()

        # ------------------ 0. backbone features ------------------
        feats = []
        feats.append(self.layer0(x))         # (N, 64, 128, 128)
        feats.append(self.layer1(feats[0]))  # (N, 256, 128, 128)
        feats.append(self.layer2(feats[1]))  # (N, 512, 64, 64)
        feats.append(self.layer3(feats[2]))  # (N, 1024, 64, 64)
        feats.append(self.layer4(feats[3]))  # (N, 2048, 64, 64)

        aspp = self.aspp(feats[-1])          # (N, 1280, 64, 64)
        fine_size = feats[1].size()

        aspp_ = self.bot_aspp(aspp)          # (N, 256, 64, 64)
        # Projected ASPP at layer1 resolution; only the last stage's head uses it.
        final_fuse_feat = F.interpolate(aspp_, size=fine_size[2:], mode='bilinear',
                                        align_corners=True)  # (N, 256, 128, 128)

        # Low-level projection shared by every cascade stage.
        low_feat = self.bot_fine(feats[1])   # (N, skip_num, 128, 128)

        # Per-stage intermediates; only seg_final_outs[-1] is returned. The rest
        # would feed the (commented-out) training loss.
        seg_edges, seg_edge_outs = [], []
        seg_bodys, seg_body_outs = [], []
        seg_finals, seg_final_outs = [], []

        for i in range(self.num_cascade):
            is_last = i == self.num_cascade - 1

            if i == 0:
                last_seg_feat = aspp_        # (N, 256, 64, 64)
            else:
                # previous stage's F_m, brought back to ASPP resolution
                last_seg_feat = F.interpolate(seg_finals[-1], size=aspp_.size()[2:],
                                              mode='bilinear', align_corners=True)  # (N, 256, 64, 64)

            # ------------------ 1. initial F_edge & F_residual ------------------
            seg_edge, seg_body = self.edge_extractors[i](last_seg_feat, low_feat)  # 2 x (N, 256, 128, 128)

            # ------------------ 2. F_residual -> F_residual' with F_high ------------------
            high_fine = F.interpolate(self.body_fines[i](feats[-(i + 1)]), size=fine_size[2:],
                                      mode='bilinear', align_corners=True)  # (N, 48, 128, 128)
            seg_body = self.body_fuse[i](torch.cat([seg_body, high_fine], dim=1))  # (N, 256, 128, 128)

            # ------------------ 2.1 body head: F_residual' -> F_r ------------------
            seg_body_pre = self.body_out_pre[i](seg_body)  # (N, 256, 128, 128)
            seg_body_out = F.interpolate(self.body_out[i](seg_body_pre), size=x_size[2:],
                                         mode='bilinear', align_corners=True)  # (N, num_classes, 512, 512)
            seg_bodys.append(seg_body_pre)
            seg_body_outs.append(seg_body_out)

            # ------------------ 3. edge head: F_edge -> F_b ------------------
            seg_edge_pre = self.edge_out_pre[i](seg_edge)       # (N, 256, 128, 128)
            seg_edge_out_pre = self.edge_out[i](seg_edge_pre)   # (N, 1, 128, 128)
            seg_edge_out = F.interpolate(seg_edge_out_pre, size=x_size[2:],
                                         mode='bilinear', align_corners=True)  # (N, 1, 512, 512)
            seg_edges.append(seg_edge_pre)
            seg_edge_outs.append(seg_edge_out)

            # ------------------ 4. F_merge = F_residual' + F_edge, refined by the GCN ------------------
            seg_out = seg_body + seg_edge  # (N, 256, 128, 128)
            # The boundary guidance is detached so the GCN does not backprop into
            # the edge head; sigmoid already returns a new tensor, so no clone is needed.
            edge_prob = torch.sigmoid(seg_edge_out_pre.detach())
            seg_out2 = self.refines[i](seg_out, edge_prob)  # (N, 256, 128, 128)

            # ------------------ 5. segmentation head ------------------
            if is_last:
                seg_final_pre = self.final_seg_out_pre[i](
                    torch.cat([final_fuse_feat, seg_out2], dim=1))  # (N, 256, 128, 128)
            else:
                seg_final_pre = self.final_seg_out_pre[i](seg_out2)  # (N, 256, 128, 128)

            seg_final_out = F.interpolate(self.final_seg_out[i](seg_final_pre), size=x_size[2:],
                                          mode='bilinear', align_corners=True)  # (N, num_classes, 512, 512)
            seg_finals.append(seg_final_pre)
            seg_final_outs.append(seg_final_out)

        # if self.training:
        #    return self.criterion((seg_final_outs, seg_body_outs, seg_edge_outs), gts)

        return seg_final_outs[-1]

In [10]:
# Smoke test: build the notebook-local EBLNet from the constants in the
# configuration cell (num_classes=2 for bg/fg) and check that the output
# resolution matches the input.
torch.manual_seed(0)  # reproducible random input
model = EBLNet(num_classes=num_classes, trunk=trunk, variant=variant, skip=skip,
               skip_num=skip_num, num_cascade=num_cascade, num_points=num_points,
               threshold=threshold)
x = torch.randn([2, 3, 512, 512])
# Shape check only: no autograd graph is needed, which saves a lot of memory.
with torch.no_grad():
    out = model(x, gts=None)
assert tuple(out.shape) == (2, num_classes, 512, 512), out.shape
print(out.shape)

torch.Size([2, 2, 512, 512])
