# Convolutional CRFs for Semantic Segmentation

### This notebook contains code and results of experiments on the ICLR paper submission "Convolutional CRFs for Semantic Segmentation", as part of the ICLR 2019 Reproducibility Challenge.

### Load Images

In [6]:
import copy
import logging
import math

import imageio
import matplotlib.pyplot as plt
import numpy as np
import scipy
import skimage.transform
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from convcrf.convcrf import default_conf, test_config

In [2]:
# Path of the label image for the example picture.
LABELS = 'data/img_bicycle_labels.png'
# Read the label image into an array; later cells index a one-hot table with it,
# so its values are used as integer class ids.
labels = imageio.imread(LABELS)

### Produce unary by adding noise to label

In [3]:
num_classes = 21
keep_prop = 0.8  # probability that a coarse grid cell keeps the true label
scale = 8        # down-sampling factor of the coarse grid
SEED = 42        # fixed seed so that Restart & Run All reproduces the same noise

# Private generator: seeding it leaves NumPy's global random state untouched.
rng = np.random.RandomState(SEED)

# output (assembled in the next cell)
unary = None

shape = labels.shape  # (H, W)
labels = labels.reshape(shape[0], shape[1])  # H * W


def _resize_linear(array, out_shape):
    """Resize ``array`` to ``out_shape`` with order-1 interpolation, keeping its value range.

    ``anti_aliasing=False`` is passed explicitly. The warning recorded with this
    cell ("Anti-aliasing will be enabled by default in skimage 0.15") shows the
    original run did not anti-alias; pinning the flag keeps that behaviour on
    newer scikit-image releases and silences the warning.
    """
    return skimage.transform.resize(array, out_shape,
                                    order=1,
                                    preserve_range=True,
                                    mode='constant',
                                    anti_aliasing=False)


full_shape = (shape[0], shape[1], num_classes)
lower_shape = (shape[0] // scale, shape[1] // scale)

# Onehot encoding of labels
onehot = np.eye(num_classes)[labels]  # H * W * num_classes

# Scale down onehot labels to 1/scale ...
label_down = _resize_linear(
    onehot, (lower_shape[0], lower_shape[1], num_classes))

# ... and back up to the original size, which blurs the class boundaries.
onehot_up = _resize_linear(label_down, full_shape)

# Random class id per coarse cell, one-hot encoded: scaled shape * num_classes
noise = rng.randint(0, num_classes, lower_shape)
noise = np.eye(num_classes)[noise]

# scale up noise labels
noise_up = _resize_linear(noise, full_shape)

# rand() is in [0, 1), so floor(keep_prop + rand) is 1 with probability
# keep_prop and 0 otherwise: 1 = keep the label, 0 = replace it by noise.
mask = np.floor(keep_prop + rng.rand(*lower_shape))

mask_up = _resize_linear(mask, (shape[0], shape[1], 1))



  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "


In [4]:
# Per-pixel blend of the two sources: where mask_up is 1 the smoothed one-hot
# labels are kept, where it is 0 the up-scaled random noise is used instead.
unary = mask_up * onehot_up + (1 - mask_up) * noise_up

### Some helper functions 

In [17]:
def _get_ind(dz):
    if dz == 0:
        return 0, 0
    if dz < 0:
        return 0, -dz
    if dz > 0:
        return dz, 0


def _negative(dz):
    """
    Computes -dz for numpy indexing. Goal is to use as in array[i:-dz].

    However, if dz=0 this indexing does not work.
    None needs to be used instead.
    """
    if dz == 0:
        return None
    else:
        return -dz

### Implement message passing
In the mean field algorithm used in FullCRF (Krähenbühl & Koltun, 2011), message passing is the bottleneck. Due to the assumption of conditional independence, message passing can now be reformulated as convolutions, giving a speed-up of two orders of magnitude in inference time.

In [18]:
class MessagePassingCol():
    """Pre-computes the Gaussian filter weights used for convolutional message passing.

    For every feature tensor in ``feat_list`` a bank of per-pixel Gaussian
    weights is built (one weight for each offset inside a
    ``filter_size`` x ``filter_size`` window), multiplied by the matching
    entry of ``compat_list`` and stored in ``self._gaus_list``.

    Parameters
    ----------
    feat_list : sequence of tensors
        Feature maps indexed as [batch, channel, x, y].
    compat_list : sequence
        One scale factor per entry of ``feat_list``.
    merge : stored on the instance; not used by the code in this class.
    npixels : sequence of two ints
        Spatial size (x, y) of the feature maps.
    nlcasses : accepted but unused here (the spelling is kept so that
        existing keyword callers keep working).
    norm, clip_edges : accepted but unused here.
    filter_size : int
        Side length of the filter window; ``span`` is ``filter_size // 2``.
    blur : int
        Features are average-pooled with this kernel size when it is > 1.
    use_gpu, matmul, verbose, pyinn : stored on the instance.

    Raises
    ------
    NotImplementedError
        If ``blur`` is odd and different from 1.
    """

    def __init__(self, feat_list, compat_list, merge, npixels, nlcasses, norm='sym',
                 filter_size=5, clip_edges=0, use_gpu=False, blur=1, matmul=False,
                 verbose=False, pyinn=False):

        span = filter_size // 2
        self.span = span
        self.filter_size = filter_size
        self.use_gpu = use_gpu
        self.verbose = verbose
        self.blur = blur
        self.pyinn = pyinn
        self.merge = merge
        self.npixels = npixels

        # Only blur == 1 (no pooling) or an even blur factor is supported.
        if not self.blur == 1 and self.blur % 2:
            raise NotImplementedError

        self.matmul = matmul
        self._gaus_list = []
        # NOTE(review): no normalisation is computed in this class yet, so
        # this list stays empty.
        self._norm_list = []

        for feats, compat in zip(feat_list, compat_list):
            # The feature tensor has to be passed in; calling the method
            # without it raised a TypeError for any non-empty feat_list.
            gaussian = self._create_convolutional_filters(feats)
            # Keep the filter, scaled by its compatibility factor; it was
            # previously computed and then thrown away.
            self._gaus_list.append(compat * gaussian)

    def _create_convolutional_filters(self, features):
        """Build the Gaussian weights between each pixel and its window neighbours.

        Parameters
        ----------
        features : tensor indexed as [batch, channel, x, y].

        Returns
        -------
        Tensor of shape [batch, 1, filter_size, filter_size, x', y'] where
        (x', y') is ``npixels``, or ``ceil(npixels / blur)`` when blur > 1.
        Entry [b, 0, dx + span, dy + span, i, j] is
        exp(-0.5 * sum_c (f[b, c, i + dx, j + dy] - f[b, c, i, j]) ** 2);
        positions whose neighbour falls outside the map keep the value 0.
        """
        span = self.span
        bs = features.shape[0]

        if self.blur > 1:
            # Pad so that the spatial size becomes a multiple of the pooling
            # kernel, then down-sample by average pooling.
            off_0 = (self.blur - self.npixels[0] % self.blur) % self.blur
            off_1 = (self.blur - self.npixels[1] % self.blur) % self.blur
            pad_0 = math.ceil(off_0 / 2)
            pad_1 = math.ceil(off_1 / 2)

            features = torch.nn.functional.avg_pool2d(features,
                                                      kernel_size=self.blur,
                                                      padding=(pad_0, pad_1),
                                                      count_include_pad=False)

            npixels = [math.ceil(self.npixels[0] / self.blur),
                       math.ceil(self.npixels[1] / self.blur)]

            assert(npixels[0] == features.shape[2])
            assert(npixels[1] == features.shape[3])
        else:
            npixels = self.npixels

        # Zero-initialised, same dtype/device as the features.
        gaussian_tensor = features.data.new(
            bs, self.filter_size, self.filter_size,
            npixels[0], npixels[1]).fill_(0)

        gaussian = Variable(gaussian_tensor)

        for dx in range(-span, span + 1):
            for dy in range(-span, span + 1):

                dx1, dx2 = _get_ind(dx)
                dy1, dy2 = _get_ind(dy)

                # feat_t is the map shifted by (dx, dy); feat_t2 is the
                # overlapping part of the unshifted map.
                feat_t = features[:, :, dx1:_negative(dx2), dy1:_negative(dy2)]
                feat_t2 = features[:, :, dx2:_negative(dx1), dy2:_negative(dy1)] # NOQA

                diff = feat_t - feat_t2
                diff_sq = diff * diff
                # Sum over the channel axis -> one weight per pixel.
                exp_diff = torch.exp(torch.sum(-0.5 * diff_sq, dim=1))

                gaussian[:, dx + span, dy + span,
                         dx2:_negative(dx1), dy2:_negative(dy1)] = exp_diff

        return gaussian.view(
            bs, 1, self.filter_size, self.filter_size,
            npixels[0], npixels[1])

# Smoke test: with empty feature/compat lists the constructor only stores its settings.
mess = MessagePassingCol(feat_list=[], compat_list=[], merge=[],
                         npixels=[], nlcasses=[])
        

### ConvCRF
Full implementation of a generic CRF.

In [9]:
class ConvCRF(nn.Module):
    """Module holding the settings of a convolutional CRF.

    Only the constructor, ``clean_filters`` and the argument checks of
    ``add_pairwise_energies`` are implemented in this notebook.

    Parameters
    ----------
    npixels : tuple, list or other
        Size of the input. It is always stored as ``self.npixels``; when it
        is a tuple or list (including subclasses such as ``torch.Size``) its
        first two entries are also stored as ``self.height`` / ``self.width``.
    nclasses : int
        Number of classes; also the channel count of the optional 1x1
        compatibility convolution.
    conf : stored as ``self.conf``.
    trainable : bool
        If True ``weight`` is registered as a parameter, otherwise as a buffer.
    convcomp : bool
        If True a bias-free 1x1 convolution ``self.comp`` is created, with all
        weights set to ``0.1 * sqrt(2 / nclasses)``.
    weight : tensor or None
        Optional weight tensor registered under the name ``weight``.
    mode, filter_size, clip_edges, blur, use_gpu, norm, merge, verbose,
    final_softmax, unary_weight, pyinn : stored on the instance unchanged.
    """

    def __init__(self, npixels, nclasses, conf, mode='conv',
                 filter_size=5, clip_edges=0, blur=1, use_gpu=False,
                 norm='sym', merge=False, verbose=False, trainable=False,
                 convcomp=False, weight=None, final_softmax=True,
                 unary_weight=10, pyinn=False):

        super(ConvCRF, self).__init__()

        self.nclasses = nclasses
        self.filter_size = filter_size
        self.clip_edges = clip_edges
        self.use_gpu = use_gpu
        self.mode = mode
        self.norm = norm
        self.merge = merge
        self.kernel = None
        self.verbose = verbose
        self.blur = blur
        self.final_softmax = final_softmax
        self.pyinn = pyinn
        self.conf = conf
        self.unary_weight = unary_weight

        # Always keep the size argument; it used to be missing on the
        # instance whenever a tuple or list was passed.
        self.npixels = npixels
        # isinstance also accepts tuple subclasses such as torch.Size.
        if isinstance(npixels, (tuple, list)):
            self.height = npixels[0]
            self.width = npixels[1]

        if trainable:
            def register(name, tensor):
                self.register_parameter(name, Parameter(tensor))
        else:
            def register(name, tensor):
                self.register_buffer(name, Variable(tensor))

        if weight is None:
            self.weight = None
        else:
            register('weight', weight)

        if convcomp:
            self.comp = nn.Conv2d(nclasses, nclasses, kernel_size=1, padding=0, stride=1, bias=False)
            self.comp.weight.data.fill_(0.1 * math.sqrt(2.0 / nclasses))
        else:
            self.comp = None

    def clean_filters(self):
        """Drop the cached kernel by resetting ``self.kernel`` to None."""
        self.kernel = None

    def add_pairwise_energies(self, feat_list, compat_list, merge):
        """Validate the inputs for the pairwise term.

        Currently this only asserts that ``feat_list`` and ``compat_list``
        have the same length and that the module was created with
        ``use_gpu=True``; no kernel is built yet.
        """
        assert(len(feat_list) == len(compat_list))

        assert(self.use_gpu)

        
            

### GaussCRF

In [5]:
class GaussCRF(nn.Module):
    """Module holding the Gaussian-feature settings of the CRF.

    Parameters
    ----------
    conf : mapping
        Keys read by the constructor: ``'trainable'``, ``'trainable_bias'``,
        ``'pos_feats'`` (``'sdims'``, ``'compat'``), ``'col_feats'``
        (``'use_bias'``, ``'sdims'``, ``'schan'``, ``'compat'``), ``'weight'``
        and, when ``'weight'`` is ``'scalar'`` or ``'vector'``, ``'weight_init'``.
    shape : sequence of ints
        Extent of each axis of the coordinate mesh.
    nclasses : int, optional
        Number of classes; sizes the weight when ``conf['weight'] == 'vector'``.

    The values ``pos_sdims``, ``col_sdims`` and ``col_schan`` are stored as
    reciprocals of the configured numbers. They and the two ``*_compat``
    values are registered as parameters when ``conf['trainable']`` is true
    and as buffers otherwise.
    """

    def __init__(self, conf, shape, nclasses=None):
        # Fixed: this previously called the non-existent ``init`` and then
        # assigned the undefined name ``confb``, so construction always failed.
        super(GaussCRF, self).__init__()
        self.conf = conf
        self.shape = shape
        self.nclasses = nclasses

        self.trainable = conf['trainable']

        # The mesh needs self.shape, which is set above.
        if not conf['trainable_bias']:
            self.register_buffer('mesh', self._create_mesh())
        else:
            self.register_parameter('mesh', Parameter(self._create_mesh()))

        if self.trainable:
            def register(name, tensor):
                self.register_parameter(name, Parameter(tensor))
        else:
            def register(name, tensor):
                self.register_buffer(name, Variable(tensor))

        register('pos_sdims', torch.Tensor([1 / conf['pos_feats']['sdims']]))

        if conf['col_feats']['use_bias']:
            register('col_sdims', torch.Tensor([1 / conf['col_feats']['sdims']]))
        else:
            self.col_sdims = None

        register('col_schan', torch.Tensor([1 / conf['col_feats']['schan']]))
        register('col_compat', torch.Tensor([conf['col_feats']['compat']]))
        register('pos_compat', torch.Tensor([conf['pos_feats']['compat']]))

        # NOTE(review): ``weight`` is prepared here but nothing in this class
        # consumes it yet; presumably it is meant for a ConvCRF instance.
        if conf['weight'] is None:
            weight = None
        elif conf['weight'] == 'scalar':
            val = conf['weight_init']
            weight = torch.Tensor([val])
        elif conf['weight'] == 'vector':
            val = conf['weight_init']
            weight = val * torch.ones(1, nclasses, 1, 1)

    def _create_mesh(self, requires_grad=False):
        """Return a float32 tensor of coordinates for a grid of size ``self.shape``.

        The result has shape ``[len(shape), *shape]``: entry ``k`` holds the
        index along axis ``k`` for every grid position. ``requires_grad`` is
        accepted but not used.
        """
        hcord_range = [range(s) for s in self.shape]
        mesh = np.array(np.meshgrid(*hcord_range, indexing='ij'),
                        dtype=np.float32)
        return torch.from_numpy(mesh)

    def forward(self, unary, img, num_iter=5):
        """Inference entry point (not implemented yet).

        Only unpacks ``img.shape`` into four values, so ``img`` must be
        4-dimensional; returns None.
        """
        conf = self.conf
        bs, c, x, y = img.shape
    

### Get predictions

In [21]:
IMG = 'data/img_bicycle.png'
image = imageio.imread(IMG)

FILTER_SIZE = 7

shape = image.shape[0:2]  # (height, width)

# Work on a private copy so that the defaults imported from convcrf are not
# modified in place for every other user of `default_conf`.
config = copy.deepcopy(default_conf)
config['filter_size'] = FILTER_SIZE


# PyTorch uses C, H, W
image = image.transpose(2, 0, 1)  # shape: [3, height, width]

# Add batch dim
image = image.reshape([1, 3, shape[0], shape[1]])
# img_var = Variable(torch.Tensor(image)).cuda()

# `unary` arrives from the earlier cell as [height, width, num_classes].
# Convert it only while it still has that 3-D layout: transposing an already
# converted array again would silently scramble it when this cell is re-run.
if unary.ndim == 3:
    unary = unary.transpose(2, 0, 1)  # shape: [num_classes, height, width]
    # Add batch dim
    unary = unary.reshape([1, num_classes, shape[0], shape[1]])

# unary_var = Variable(torch.Tensor(unary)).cuda()

assert unary.shape == (1, num_classes, shape[0], shape[1])
unary.shape

(1, 21, 500, 334)

### Create args

In [137]:
from dotmap import DotMap

# Option container for the (currently commented-out) do_crf_inference call in
# the last cell.
args = DotMap()
# NOTE(review): presumably selects the pyinn-based code path, mirroring the
# `pyinn` flag of the classes above - confirm against do_crf_inference.
args.pyinn = False
# NOTE(review): meaning not visible in this notebook - confirm against
# do_crf_inference.
args.nospeed = True 

In [None]:
# prediction = do_crf_inference(image, unary, args)