# <font style = "color:rgb(50,120,229)">Object-Contextual Representations for Semantic Segmentation</font>

## <font style = "color:rgb(50,120,229)">Paper Details</font>

1. **Authors**: Yuhui Yuan, Xilin Chen, Jingdong Wang
2. **Paper Link**: https://arxiv.org/pdf/1909.11065v2.pdf
3. **Category**: Semantic Segmentation

## <font style = "color:rgb(50,120,229)">Introduction</font>

**Semantic Segmentation**: Assigning a label to each pixel in an image.

**Approach**: Contextual Aggregation.

**Motivation**: The class label assigned to a pixel is the category of the object that the pixel belongs to, so a pixel can be characterized using the representation of the object it belongs to.

### <font style = "color:rgb(8,133,37)">What does Context mean?</font>

<div class="alert alert-block alert-info">
The context of one position typically refers to a set of positions, e.g., the surrounding pixels.
</div>

According to another paper, **Context Based Object Categorization: A Critical Survey**, contextual features represent the interaction of an object with its surroundings. They can be divided into the following 3 categories:
1. **Semantic Context** - This focuses on object co-occurrence and allows correcting the label of one object without affecting the labels of other objects. For example, a tree is more likely to co-occur with a plant than with a whale.
2. **Spatial Context** - This focuses on the position of objects. For example, a dog is more likely to be present above grass and below sky rather than above sky and below grass.
3. **Scale Context** - This focuses on the relative size of objects. For example, a car is smaller than a truck, not the other way round.

### <font style = "color:rgb(8,133,37)">Approach</font>
The approach discussed in the paper consists of the following 3 steps:

1. **Coarse Soft Segmentation** - This involves dividing the contextual pixels (surrounding pixels) into soft object regions. "Soft" means that each pixel gets a probability of belonging to each region rather than a hard label, and "coarse" means that the focus is NOT on carrying out an accurate segmentation.
2. **Object Region Representation** - We use the soft segmentation obtained from the above step and the pixel representation to represent each object region.
3. **Object-Contextual Representation** (OCR) - We use the outputs of the above 2 steps, along with the relation between each pixel and each object region (pixel-region relation), to obtain the augmented representations.

<img src="images/paper1/image01.png" alt="Pipeline of the approach" title="The Pipeline of the approach" />
<center><b>Figure 1</b>: The pipeline of the approach discussed in the paper. <a href="https://arxiv.org/pdf/1909.11065v2.pdf">Source</a></center>

### <font style = "color:rgb(8,133,37)">Differences</font>

**OCR vs Multi-Scale Context**

1. OCR differentiates contextual pixels that belong to the same class from contextual pixels that belong to a different class.
2. Multi-Scale Context approach only differentiates pixels present at different positions.

**OCR vs other Relational Context schemes**

The approach discussed in the paper considers not only the object region representations but also the pixel and pixel-region relations, unlike other approaches.

It should also be mentioned here that the current approach is also a relational context approach.

**OCR vs Coarse-to-fine Segmentation**

While "Coarse-to-fine Segmentation" is also followed in the current approach, the difference is the way the coarse segmentation is used. The OCR approach uses the coarse segmentation to generate a contextual representation, whereas the other approaches use it directly as an extra representation.

**OCR vs Region-wise Segmentation**

Region-wise segmentation first groups the pixels into **superpixels**, which are then assigned a label. OCR, on the other hand, uses the grouped regions to learn a better labelling for the pixels instead of directly using them for segmentation.

## <font style = "color:rgb(50,120,229)">Approach</font>

It's now time to go into the mathematical details of the approach.

### <font style = "color:rgb(8,133,37)">Semantic Segmentation - Problem Statement</font>

Given $K$ classes, assign each pixel $p_i$ of image $I$ a label $l_i$ (which is one of the $K$ <b>unique</b> classes).
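As a minimal sketch of this output format (assuming a model that produces one score map per class; the sizes are illustrative), each pixel's label $l_i$ can be read off as its highest-scoring class:

```python
import torch

torch.manual_seed(0)
K, H, W = 4, 3, 5                 # number of classes and image size (illustrative values)

# Hypothetical per-pixel class scores, shape (K, H, W), as a segmentation network would output.
scores = torch.randn(K, H, W)

# l_i for every pixel p_i: the index of the highest-scoring class, shape (H, W).
labels = scores.argmax(dim=0)
print(labels)
```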

<div class="alert alert-block alert-info">
<b>Multi-Scale Context (Optional)</b>

Multi-Scale context can be represented by the following equation:

$$y_i ^d = \sum_{p_s = p_i + d \Delta_t} K_t ^ d x_s \tag{1}$$

Where,

$y_i^d$ is the <b>output</b> representation of position $p_i$ for the $d$th dilated convolution,

$d$ is the dilation rate,

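$t$ is the position index within the convolution kernel (9 positions for a 3x3 convolution),

$\Delta_t$ is the offset of position $t$, i.e., $\Delta_t \in \{(\Delta_w, \Delta_h) \mid \Delta_w \in \{-1, 0, 1\}, \Delta_h \in \{-1, 0, 1\}\}$ for a 3x3 convolution,

$x_s$ is the representation at $p_s$,

$K_t^d$ is the kernel parameter at position $t$ of the $d$th dilated convolution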
</div>
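Equation (1) can be checked numerically. The sketch below uses illustrative sizes and random weights (`kernel` plays the role of $K_t^d$ and has nothing to do with the number of classes $K$), evaluates the sum over the 9 offsets $d\Delta_t$ explicitly, and compares it with PyTorch's built-in dilated convolution, assuming zero padding at the image border:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C_in, C_out, H, W, d = 4, 3, 9, 9, 2      # d is the dilation rate

x = torch.randn(1, C_in, H, W)
kernel = torch.randn(C_out, C_in, 3, 3)   # K^d: one (C_out x C_in) matrix per position t

# Reference: PyTorch dilated convolution; padding=d keeps the spatial size for a 3x3 kernel.
y_ref = F.conv2d(x, kernel, padding=d, dilation=d)

# Eq. (1): y_i^d = sum over the 9 positions t of K_t^d x_s, with p_s = p_i + d * Delta_t.
x_pad = F.pad(x, (d, d, d, d))            # zero padding, so out-of-image p_s contribute 0
y = torch.zeros(1, C_out, H, W)
for dh in (-1, 0, 1):
    for dw in (-1, 0, 1):
        K_t = kernel[:, :, dh + 1, dw + 1]                    # (C_out, C_in)
        # x_s for every p_i at once: the padded map shifted by d * Delta_t
        x_s = x_pad[:, :, d + d * dh: d + d * dh + H, d + d * dw: d + d * dw + W]
        y += torch.einsum('oc,nchw->nohw', K_t, x_s)

print(torch.allclose(y, y_ref, atol=1e-4))
```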

<div class="alert alert-block alert-info">
<b>Relational Context (Optional)</b>

Relational context can be represented by the following equation:

$$y_i = \rho \left(\sum_{s \in I} w_{is} \delta(x_s) \right) \tag{2}$$

Where,

$y_i$ is the <b>output</b> representation of position $p_i$,

$I$ refers to the image,

$w_{is}$ is the relation between $x_i$ and $x_s$,

$\delta(\cdot)$ and $\rho(\cdot)$ are transform functions,

$x_s$ is the representation at $p_s$
</div>
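A minimal sketch of equation (2), assuming that $w_{is}$ is a softmax over $s$ of the dot products $x_i^\top x_s$ (one common choice, not the only one) and that $\delta(\cdot)$ and $\rho(\cdot)$ are plain linear maps (`W_delta` and `W_rho` are hypothetical):

```python
import torch

torch.manual_seed(0)
n_pixels, C, D = 6, 8, 4

x = torch.randn(n_pixels, C)          # x_s: one C-dim representation per position in I
W_delta = torch.randn(C, D)           # hypothetical linear delta(.)
W_rho = torch.randn(D, C)             # hypothetical linear rho(.)

# w_is: softmax over s of dot products x_i . x_s (an assumed relation function).
w = torch.softmax(x @ x.T, dim=1)     # (n_pixels, n_pixels); each row sums to 1

# Eq. (2): y_i = rho( sum_s w_is delta(x_s) ), for every position i at once.
y = (w @ (x @ W_delta)) @ W_rho       # (n_pixels, C)

# The same sum written out for a single position i.
i = 2
y_i = sum(w[i, s] * (x[s] @ W_delta) for s in range(n_pixels)) @ W_rho
```

Every output $y_i$ sums over all positions in $I$, so the relation matrix `w` holds one entry per pair of pixels.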

Next comes my favorite part, the formulation of the current approach.

### <font style = "color:rgb(8,133,37)">Step 1: Soft Object Regions</font>

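The image $I$ is partitioned into $K$ soft object regions $\{M_1, M_2, \dots, M_K\}$, where $M_k$ corresponds to class $k$.

$M_k$ is a 2D map where each entry represents the probability of the corresponding pixel belonging to class $k$. These maps form a coarse segmentation that is predicted from an intermediate representation of the backbone and supervised with the ground-truth segmentation through an auxiliary loss.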
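To make the "probability map" reading concrete, here is a small sketch that turns hypothetical coarse class scores into soft regions with a softmax over the $K$ classes. This per-pixel normalization is an illustrative assumption: the reference code later in this notebook instead passes the raw class scores directly to a spatial softmax in Step 2.

```python
import torch

torch.manual_seed(0)
K, H, W = 3, 4, 4
# Hypothetical coarse class scores, e.g. from a 1x1 conv on an intermediate backbone feature.
coarse_logits = torch.randn(K, H, W)

# Soft object regions: M[k] is an H x W map whose entry (h, w) is the
# probability that pixel (h, w) belongs to class k (softmax over the K classes).
M = torch.softmax(coarse_logits, dim=0)
```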

### <font style = "color:rgb(8,133,37)">Step 2: Object Region Representation</font>

The pixel representations $x_i$ are aggregated into one representation per object region, weighting each pixel by its degree of belonging to that region:

$$f_k = \sum_{i \in I} m_{ki} x_{i} \tag{3}$$

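where $m_{ki}$ is the **normalized** degree to which pixel $p_i$ belongs to the $k$th object region. The normalization is a spatial softmax over all pixels of $M_k$, so the weights of each region sum to 1.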
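A minimal sketch of equation (3) with illustrative sizes and random inputs (`coarse_logits` stands in for the coarse segmentation scores): the spatial softmax produces $m_{ki}$, and a single matrix product computes every $f_k$ at once.

```python
import torch

torch.manual_seed(0)
K, C, H, W = 3, 8, 4, 4
x = torch.randn(C, H * W)               # pixel representations x_i, one column per pixel in I
coarse_logits = torch.randn(K, H * W)   # hypothetical coarse scores for the K soft regions

# m_ki: spatial softmax, so each region's weights sum to 1 over all pixels.
m = torch.softmax(coarse_logits, dim=1)  # (K, H*W)

# Eq. (3) for all regions at once: f_k = sum_i m_ki x_i  ->  f has shape (K, C)
f = m @ x.T

# The same sum written out for region k = 0.
f_0 = sum(m[0, i] * x[:, i] for i in range(H * W))
```

Because each region's weights sum to 1, $f_k$ is a weighted average of the pixel representations, dominated by the pixels that the coarse segmentation assigns to region $k$.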

### <font style = "color:rgb(8,133,37)">Step 3: Object Contextual Representation</font>

First we obtain the relation between a pixel and an object region:

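$$w_{ik} = \cfrac{e^{\kappa(x_i, f_k)}}{\sum_{j=1}^{K} e^{\kappa(x_i, f_j)}}\tag{4}$$

where $\kappa(x, f) = \phi(x)^\top \psi(f)$ is the unnormalized relation function and $\phi(\cdot)$ and $\psi(\cdot)$ are transform functions. Since the normalization runs over the $K$ regions, $\sum_{k=1}^{K} w_{ik} = 1$ for every pixel $p_i$.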
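A minimal sketch of equation (4), assuming random $x_i$ and $f_k$ and using `nn.Linear` layers as hypothetical stand-ins for the transforms $\phi$ and $\psi$ (the paper implements them as 1x1 conv, BN, ReLU):

```python
import torch
from torch import nn

torch.manual_seed(0)
K, C, D, n_pixels = 3, 8, 4, 10
x = torch.randn(n_pixels, C)      # pixel representations x_i
f = torch.randn(K, C)             # object region representations f_k

# Hypothetical phi/psi: linear layers standing in for 1x1 conv -> BN -> ReLU.
phi = nn.Linear(C, D)
psi = nn.Linear(C, D)

# kappa(x_i, f_k) = phi(x_i)^T psi(f_k) for every (i, k) pair -> (n_pixels, K)
kappa = phi(x) @ psi(f).T
# Eq. (4): softmax over the K regions, so each pixel's weights sum to 1.
w = torch.softmax(kappa, dim=1)

# Eq. (4) written out for one pair (i, k) = (0, 1).
w_01 = torch.exp(kappa[0, 1]) / torch.exp(kappa[0]).sum()
```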

Finally, we can obtain the object contextual representation $y_i$ for pixel $p_i$ as shown below:

$$y_i = \rho \left(\sum_{k=1}^{K} w_{ik} \delta(f_k) \right) \tag{5}$$

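Notice the similarity between equation (5) and equation (2): both are weighted sums of transformed representations, but equation (5) sums over the $K$ object region representations $f_k$ instead of over all positions $s \in I$.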
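The sketch below computes equation (5) with random relations $w_{ik}$ and hypothetical linear stand-ins for $\delta$ and $\rho$. Compared with the sketch of equation (2), each pixel now aggregates only $K$ terms, one per object region, rather than one term per pixel in the image.

```python
import torch
from torch import nn

torch.manual_seed(0)
K, C, D, n_pixels = 3, 8, 4, 10
f = torch.randn(K, C)                               # object region representations f_k
w = torch.softmax(torch.randn(n_pixels, K), dim=1)  # pixel-region relations w_ik (random here)

# Hypothetical delta/rho: linear layers standing in for 1x1 conv -> BN -> ReLU.
delta = nn.Linear(C, D)
rho = nn.Linear(D, C)

# Eq. (5) for all pixels at once: y_i = rho( sum_k w_ik delta(f_k) )  ->  (n_pixels, C)
y = rho(w @ delta(f))

# The same sum written out for pixel i = 0: only K terms.
y_0 = rho(sum(w[0, k] * delta(f[k]) for k in range(K)))
```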

### <font style = "color:rgb(8,133,37)">Step 4: Augmented Representation</font>

The final representation for pixel $p_i$ is calculated as follows:

$$z_i = g([x_i^T y_i^T]^T) \tag{6}$$

Here $[x_i^T y_i^T]^T$ is the concatenation of the original representation $x_i$ and the object contextual representation $y_i$, and $g(\cdot)$ is a transform function that fuses the two into the augmented representation $z_i$.
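A minimal sketch of equation (6) on feature maps with illustrative sizes: concatenation along the channel axis builds $[x_i^T y_i^T]^T$ for every pixel at once, and $g$ is assumed to be a 1x1 conv, BN, ReLU block mapping $2C$ channels back to $C$.

```python
import torch
from torch import nn

torch.manual_seed(0)
N, C, H, W = 2, 8, 4, 4
x = torch.randn(N, C, H, W)     # original pixel representations x_i
y = torch.randn(N, C, H, W)     # object contextual representations y_i

# g: 1x1 conv -> BN -> ReLU over the concatenated channels (2C -> C).
g = nn.Sequential(nn.Conv2d(2 * C, C, 1, bias=False), nn.BatchNorm2d(C), nn.ReLU())

# [x_i^T y_i^T]^T is a per-pixel channel concatenation.
z = g(torch.cat([x, y], dim=1))  # (N, C, H, W)
```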

## <font style = "color:rgb(50,120,229)">References</font>
1. Object-Contextual Representations for Semantic Segmentation - https://arxiv.org/pdf/1909.11065v2.pdf
2. Context Based Object Categorization: A Critical Survey - https://vision.cornell.edu/se3/wp-content/uploads/2014/09/context_review08_0.pdf
3. Jupyter Markdown - https://www.ibm.com/support/knowledgecenter/en/SSGNPV_1.1.3/dsx/markd-jupyter.html

# <font style = "color:rgb(50,120,229)">Network Architecture</font>

Code Source: https://github.com/rosinality/ocr-pytorch (by **Kim Seonghyeon**)

```python
# Import required modules
import torch
from torch import nn
from torch.nn import functional as F
```

In PyTorch, a network is defined as a subclass of `nn.Module`, and the most important method to look at is **`forward`**, which implements forward propagation. Before the network itself, two helper functions build the `conv --> BN --> ReLU` blocks it uses.

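```python
def conv2d(in_channel, out_channel, kernel_size):
    """Conv2d -> BN -> ReLU; padding keeps H and W unchanged for odd kernel sizes."""
    layers = [
        # No bias: the BatchNorm that follows has its own shift parameter.
        nn.Conv2d(
            in_channel, out_channel, kernel_size, padding=kernel_size // 2, bias=False
        ),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
    ]

    return nn.Sequential(*layers)
```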

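```python
def conv1d(in_channel, out_channel):
    """1x1 Conv1d -> BN -> ReLU on a (N, C, length) tensor, e.g. flattened pixels or K regions."""
    layers = [
        nn.Conv1d(in_channel, out_channel, 1, bias=False),
        nn.BatchNorm1d(out_channel),
        nn.ReLU(),
    ]

    return nn.Sequential(*layers)
```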
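Why can a 1-D convolution serve as a "1 x 1 conv" transform? The quick check below (illustrative sizes, random input) suggests the reason: a kernel-size-1 `nn.Conv1d` is just a linear map applied independently at every position along the last axis, so the same layer can transform flattened pixels as well as the $K$ region vectors $f_k$.

```python
import torch
from torch import nn

torch.manual_seed(0)

# f: N=2 items, C=512 channels, 5 positions (e.g. 5 object regions).
f = torch.randn(2, 512, 5)

conv = nn.Conv1d(512, 256, kernel_size=1, bias=False)
linear = nn.Linear(512, 256, bias=False)
# Share the weights: Conv1d weight is (out, in, 1); Linear weight is (out, in).
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1))

out_conv = conv(f)                                       # (2, 256, 5)
out_linear = linear(f.transpose(1, 2)).transpose(1, 2)   # (2, 256, 5)
print(torch.allclose(out_conv, out_linear, atol=1e-5))
```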

```python
class OCR(nn.Module):
    def __init__(self, n_class, backbone, feat_channels=(768, 1024)):
        super().__init__()

        # The backbone returns a sequence of feature maps; the last two are used.
        self.backbone = backbone

        ch16, ch32 = feat_channels

        # Coarse soft object regions (Step 1): one score map per class.
        self.L = nn.Conv2d(ch16, n_class, 1)
        # Pixel representations x_i.
        self.X = conv2d(ch32, 512, 3)

        # Transform functions of Eqs. (4)-(6).
        self.phi = conv1d(512, 256)
        self.psi = conv1d(512, 256)
        self.delta = conv1d(512, 256)
        self.rho = conv1d(256, 512)
        self.g = conv2d(512 + 512, 512, 1)

        self.out = nn.Conv2d(512, n_class, 1)

        # Label 0 is ignored in the loss (a choice made in the reference code).
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

    def forward(self, input, target=None):
        input_size = input.shape[2:]
        # stg16 and stg32 must share the same spatial size (e.g. a dilated backbone),
        # because M (from stg16) is combined pixel by pixel with X (from stg32).
        stg16, stg32 = self.backbone(input)[-2:]

        X = self.X(stg32)
        L = self.L(stg16)
        batch, n_class, height, width = L.shape
        l_flat = L.view(batch, n_class, -1)
        # M: NKL, spatial softmax (the normalized m_ki of Eq. (3))
        M = torch.softmax(l_flat, -1)
        channel = X.shape[1]
        X_flat = X.view(batch, channel, -1)
        # f_k: NCK, Eq. (3)
        f_k = (M @ X_flat.transpose(1, 2)).transpose(1, 2)

        # query: NKD
        query = self.phi(f_k).transpose(1, 2)
        # key: NDL
        key = self.psi(X_flat)
        logit = query @ key
        # attn: NKL, softmax over the K regions, Eq. (4)
        attn = torch.softmax(logit, 1)

        # delta: NDK
        delta = self.delta(f_k)
        # attn_sum: NDL, the weighted sum inside Eq. (5)
        attn_sum = delta @ attn
        # X_obj: NCHW, the object contextual representation y_i of Eq. (5)
        X_obj = self.rho(attn_sum).view(batch, -1, height, width)

        # Eq. (6): concatenate x_i and y_i, then fuse them with g
        concat = torch.cat([X, X_obj], 1)
        X_bar = self.g(concat)
        out = self.out(X_bar)
        out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)

        if self.training:
            # Auxiliary loss: supervises the coarse soft object regions L.
            aux_out = F.interpolate(
                L, size=input_size, mode='bilinear', align_corners=False
            )

            loss = self.criterion(out, target)
            aux_loss = self.criterion(aux_out, target)

            return {'loss': loss, 'aux': aux_loss}, None

        else:
            return {}, out
```

Notice that all the transform functions have been implemented as `1 x 1 conv --> BN --> ReLU` as per the original paper.