# Edge consistency and edge potentials

Partial implementation of the paper [Gated-SCNN: Gated Shape CNNs for Semantic Segmentation](https://arxiv.org/pdf/1907.05740.pdf): only the edge-based losses (edge consistency and edge potentials) are implemented here.

[Gumbel-Max Trick](https://laurent-dinh.github.io/2016/11/22/gumbel-max.html)

[Categorical reparameterization with gumbel-softmax](https://arxiv.org/pdf/1611.01144.pdf)

[The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables](https://arxiv.org/pdf/1611.00712.pdf)

In [38]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

In [46]:
def class_balanced_cross_entropy(logits, labels, name='balanced_cross_entropy_loss'):
    """Class-balanced softmax cross entropy for binary labels.

    Negative elements get weight 1 and positive elements get weight
    ``count_neg / count_pos``. The total weight of the positives therefore
    equals the total weight of the negatives, whichever class is rarer.

    Args:
        logits: float tensor of shape (b, ..., n_classes) holding unnormalised
            class scores; index 0 is the negative class and index 1 the
            positive class (n_classes presumably 2).
        labels: integer tensor of shape (b, ...), i.e. ``logits`` without its
            last axis, with values in {0, 1}.
        name: name scope for the ops created here.

    Returns:
        Scalar loss tensor from ``tf.losses.sparse_softmax_cross_entropy``.
    """
    with tf.name_scope(name):
        y = tf.cast(labels, tf.float32)
        count_neg = tf.reduce_sum(1. - y)
        count_pos = tf.reduce_sum(y)

        # Same value as beta / (1 - beta) with beta = count_neg / (count_neg + count_pos),
        # computed directly. divide_no_nan yields 0 rather than inf when there
        # are no positives; in that case the weight is never gathered anyway.
        pos_weight = tf.math.divide_no_nan(count_neg, count_pos)
        class_weights = tf.stack([tf.ones_like(pos_weight), pos_weight], axis=0)

        # Per-element weight: 1 where the label is 0, pos_weight where it is 1.
        weights = tf.gather(class_weights, labels)
        return tf.losses.sparse_softmax_cross_entropy(
            labels=labels, logits=logits, weights=weights)

def semantic_edge_consistency_regularizer(gt, pred, pred_1st, thresh=0.8):
    """Segmentation cross entropy restricted to predicted edge pixels.

    Wherever the first-stage edge prediction exceeds ``thresh``, the
    segmentation cross entropy at that pixel is counted. Everywhere else
    the pixel contributes zero. This penalises mistakes at boundaries and
    encourages the segmentation to be correct there. Only ``pred``
    receives gradient; the labels and the edge mask are treated as constants.

    Args:
        gt: class targets, shape [b, h, w, c] (one-hot or soft labels, as
            expected by categorical cross entropy).
        pred: segmentation logits, shape [b, h, w, c].
        pred_1st: first-stage edge probabilities, shape [b, h, w, 1]; a
            [b, h, w] tensor is accepted as well.
        thresh: probability above which a pixel is treated as an edge.

    Returns:
        Scalar tensor: the per-pixel cross entropy averaged over all b*h*w
        pixels, with non-edge pixels contributing zero.
    """
    mask = tf.stop_gradient(tf.cast(pred_1st > thresh, tf.float32))
    # Force the mask to [b, h, w, 1] so it broadcasts over the class axis,
    # whether pred_1st arrived as [b, h, w] or [b, h, w, 1].
    mask = tf.reshape(mask, tf.concat([tf.shape(pred)[:-1], [1]], axis=0))

    gt = tf.stop_gradient(tf.cast(gt, tf.float32))
    # No stop_gradient here: the logits are what this loss is meant to train.
    pred = tf.cast(pred, tf.float32)

    # Zeroing both targets and logits off-edge makes those pixels' cross
    # entropy exactly 0 (all-zero targets), so only edge pixels contribute.
    pred = tf.math.multiply(pred, mask)
    gt = tf.math.multiply(gt, mask)

    return tf.reduce_mean(tf.compat.v2.losses.categorical_crossentropy(gt, pred, from_logits=True))

def _gumbel_softmax(logits, eps=1e-8, tau=1.):
    """Draw a soft, differentiable one-hot sample from categorical ``logits``.

    Adds Gumbel(0, 1) noise to the logits and applies a temperature-scaled
    softmax over the last axis; see https://arxiv.org/abs/1611.01144.

    Args:
        logits: float32 unnormalised log-probabilities; softmax is over the
            last axis.
        eps: small constant keeping both logarithms finite when the uniform
            sample is 0.
        tau: softmax temperature; smaller values give samples closer to
            one-hot.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    u = tf.random.uniform(tf.shape(logits))
    # Gumbel(0, 1) noise, -log(-log(u)), with eps guarding against log(0).
    g = -tf.math.log(eps - tf.math.log(u + eps))
    return tf.nn.softmax((logits + g) / tau)


def _all_close(x, y, rtol=1e-5, atol=1e-8):
    """Scalar bool tensor: True when every element of ``x`` is within tolerance of ``y``."""
    return tf.reduce_all(tf.abs(x - y) <= tf.abs(y) * rtol + atol)


def _gradient_mag(tensor, from_rgb=False, eps=1e-12):
    """Per-channel Sobel gradient magnitude, normalised to [0, 1] per image.

    Args:
        tensor: float image batch, [b, h, w, d].
        from_rgb: if True, the first three channels are converted to a single
            greyscale channel first.
        eps: added under the square root so its gradient stays finite where
            the image gradient is zero.

    Returns:
        [b, h, w, d] (or [b, h, w, 1] when ``from_rgb``) magnitudes. Each
        (image, channel) map is divided by its own spatial maximum. Maps
        without any gradient above the eps floor are returned as zeros.
    """
    if from_rgb:
        tensor = tf.image.rgb_to_grayscale(tensor[..., :3])
    # [b, h, w, d, 2]: vertical and horizontal Sobel responses per channel.
    tensor_edge = tf.image.sobel_edges(tensor)

    sq_mag = tf.reduce_sum(tensor_edge ** 2, axis=-1)
    mag = tf.math.sqrt(sq_mag + eps)
    mag /= tf.reduce_max(mag, axis=[1, 2], keepdims=True)

    # A flat map has mag == sqrt(eps) everywhere, which the division above
    # turns into all ones, i.e. "edge at every pixel". Zero such maps instead.
    # This is per image and channel, e.g. a class absent from an image.
    has_edges = tf.reduce_max(sq_mag, axis=[1, 2], keepdims=True) > eps
    return mag * tf.cast(has_edges, mag.dtype)


def _masked_mean(values, mask):
    """Mean of ``values`` where boolean ``mask`` is True, or 0 if it is never True."""
    selected = tf.boolean_mask(values, mask)
    return tf.cond(
        tf.greater(tf.size(selected), 0),
        lambda: tf.reduce_mean(selected),
        lambda: tf.zeros([], dtype=values.dtype))


def semantic_edge_potential_regularizer(gt, pred, thresh=0.8):
    """Penalise disagreement between ground-truth and predicted class boundaries.

    Boundaries are the normalised Sobel gradient magnitudes of every class
    channel. They are computed from the labels and from a Gumbel-softmax
    sample of the logits, so the predicted edges stay differentiable with
    respect to ``pred``. The absolute difference between the two edge maps
    is averaged separately:

    1. over entries where the ground truth has an edge;
    2. over entries where the prediction has an edge.

    The result is the mean of these two averages.

    Args:
        gt: segmentation labels, [b, h, w, c] (presumably one-hot).
        pred: segmentation logits, [b, h, w, c]; cast to float32.
        thresh: a normalised edge magnitude counts as an edge when it exceeds
            ``thresh ** 2``.

    Returns:
        Scalar tensor. A term with no edge entries contributes 0.
    """
    gt = tf.cast(gt, tf.float32)
    # Soft approximation to argmax, so edges can be built from the prediction.
    pred = _gumbel_softmax(tf.cast(pred, tf.float32))

    # [b, h, w, n_classes] normalised edge maps.
    gt_edges = _gradient_mag(gt)
    pred_edges = _gradient_mag(pred)

    # [b*h*w, n_classes]
    n_classes = tf.shape(gt_edges)[-1]
    gt_edges = tf.reshape(gt_edges, [-1, n_classes])
    pred_edges = tf.reshape(pred_edges, [-1, n_classes])

    edge_difference = tf.abs(gt_edges - pred_edges)
    # NOTE(review): the threshold is applied to magnitudes as thresh ** 2 —
    # confirm the squaring is intended.
    edge_thresh = thresh ** 2

    # Ground-truth edges and the prediction's disagreement with them.
    contrib_0 = _masked_mean(edge_difference, gt_edges > edge_thresh)
    # Vice versa: predicted edges and the ground truth's disagreement.
    contrib_1 = _masked_mean(edge_difference, pred_edges > edge_thresh)
    return 0.5 * contrib_0 + 0.5 * contrib_1

In [43]:
# Scratch cells: exploratory checks of the edge-potential loss on random tensors.
sess = tf.Session()
# Random int32 tensors of shape (2, 25, 25, 3) with values in {0, 1}
# (maxval is exclusive for integer dtypes).
a = tf.random.uniform((2, 25, 25, 3), minval=0, maxval=2, dtype=tf.dtypes.int32)
b = tf.random.uniform((2, 25, 25, 3), minval=0, maxval=2, dtype=tf.dtypes.int32)

In [45]:
# `a` is a random op, so each sess.run draws a fresh sample.
# NOTE(review): `b` is never used; `pd = sess.run(b)` may have been intended.
gt = sess.run(a)
pd = sess.run(a)

In [48]:
# Builds the loss tensor only; it is not evaluated with sess.run here.
# NOTE(review): `pd` holds int32 values used as logits — confirm the
# regularizer casts `pred` to float before adding float Gumbel noise.
loss = semantic_edge_potential_regularizer(gt, pd)

In [63]:
# Step-by-step replay of the regularizer's first stages.
# NOTE(review): `_gumbel_softmax`, `_gradient_mag` (and `_all_close` below)
# are called at module scope; confirm they are defined at top level and not
# only as locals of semantic_edge_potential_regularizer.
gt = tf.cast(gt, tf.float32)
pred = _gumbel_softmax(pd)

gt_edges = _gradient_mag(gt)
pred_edges = _gradient_mag(pred)

# Flatten to [b*h*w, n_channels].
gt_edges = tf.reshape(gt_edges, [-1, tf.shape(gt_edges)[-1]])
pred_edges = tf.reshape(pred_edges, [-1, tf.shape(gt_edges)[-1]])

In [67]:
# Sobel responses of the float labels built above: [2, 25, 25, 3, 2] (dy, dx).
tensor_edge = tf.image.sobel_edges(gt)
# Mirrors the zero fallback used in _gradient_mag.
z = tf.zeros_like(gt)

In [70]:
# Scalar bool tensor: True when every Sobel response is ~0 (not evaluated here).
ddxk = _all_close(tensor_edge, tf.zeros_like(tensor_edge))

In [78]:
# Squared gradient magnitude per channel; the cell output shows its shape.
sess.run(tf.reduce_sum(tensor_edge ** 2, axis=-1)).shape

(2, 25, 25, 3)