## Download the source code, images & target labels for training AMN

In [None]:
# Fetch the AMN source-code archive (saved as source_code.zip) from Google Drive.
!gdown https://drive.google.com/uc?id=1eJTnYHar27Eeo8hvK8YRS6Iwgl94Ngww

# `pv` is only used to turn unzip's verbose listing into a compact line counter.
!apt install pv

!unzip -o source_code.zip | pv -l >/dev/null

# Fetch data.zip (images & target labels) and extract it into ./AMN.
!gdown https://drive.google.com/uc?id=1nNO5uVH5FzBuiVG85D2gLrG3tel87NHv

!unzip -o data.zip -d ./AMN | pv -l >/dev/null

Downloading...
From: https://drive.google.com/uc?id=1eJTnYHar27Eeo8hvK8YRS6Iwgl94Ngww
To: /content/source_code.zip
100% 557k/557k [00:00<00:00, 124MB/s]
Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
Suggested packages:
  doc-base
The following NEW packages will be installed:
  pv
0 upgraded, 1 newly installed, 0 to remove and 19 not upgraded.
Need to get 48.3 kB of archives.
After this operation, 123 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic/main amd64 pv amd64 1.6.6-1 [48.3 kB]
Fetched 48.3 kB in 1s (43.2 kB/s)
Selecting previously unselected package pv.
(Reading database ... 155680 files and directories currently installed.)
Preparing to unpack .../archives/pv_1.6.6-1_amd64.deb ...
Unpacking pv (1.6.6-1) ...
Setting up pv (1.6.6-1) ...
Processing triggers f

In [None]:
# Work from ./AMN (where data.zip was extracted): the relative paths used in the
# configuration below (e.g. 'data', 'voc12/train_aug.txt') are resolved from here.
cd AMN

/content/AMN


In [None]:
# Extra dependencies. chainercv provides the VOC dataset/confusion-matrix helpers used
# for evaluation below; pydensecrf is presumably needed by the project modules
# (not used directly in this notebook). NOTE(review): versions are unpinned.
!pip install git+https://github.com/lucasb-eyer/pydensecrf.git
!pip install chainercv

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/lucasb-eyer/pydensecrf.git
  Cloning https://github.com/lucasb-eyer/pydensecrf.git to /tmp/pip-req-build-omxcr5x8
  Running command git clone -q https://github.com/lucasb-eyer/pydensecrf.git /tmp/pip-req-build-omxcr5x8
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: pydensecrf
  Building wheel for pydensecrf (PEP 517) ... [?25l[?25hdone
  Created wheel for pydensecrf: filename=pydensecrf-1.0rc2-cp37-cp37m-linux_x86_64.whl size=2781694 sha256=6d0236b6c8b21047cd9a4c8c223f413da205a5e8b51d36bf236dac2d41b17153
  Stored in directory: /tmp/pip-ephem-wheel-cache-ojlk8gaw/wheels/c1/7e/80/99adc0b2f215180486e24dd9c700028343ba5f566514a0ef05
Successfully built pydensecrf
Installing collected packages: pydensecrf
Su

In [None]:
import torch
import torch.nn as nn
from torch.backends import cudnn
cudnn.enabled = True
from torch.utils.data import DataLoader
import torch.nn.functional as F

import voc12.dataloader
from misc import pyutils, imutils
from net import resnet50

import numpy as np
from chainercv.datasets import VOCSemanticSegmentationDataset
from chainercv.evaluations import calc_semantic_segmentation_confusion
from tqdm.auto import tqdm
from PIL import Image

--------------------------------------------------------------------------------
CuPy (cupy-cuda111) version 9.4.0 may not be compatible with this version of Chainer.
Please consider installing the supported version by running:
  $ pip install 'cupy-cuda111>=7.7.0,<8.0.0'

See the following page for more details:
  https://docs.cupy.dev/en/latest/install.html
--------------------------------------------------------------------------------

  requirement=requirement, help=help))


## Configuration for training AMN

In [None]:
class arguments:
    """Paths and hyper-parameters for training and evaluating AMN.

    A plain attribute container; every setting becomes an instance attribute.
    Relative paths are resolved from the current working directory.
    """

    def __init__(self):
        settings = {
            # dataset root, image lists and the chainercv split used for evaluation
            'voc12_root': 'data',
            'train_list': 'voc12/train_aug.txt',
            'infer_list': 'voc12/train.txt',
            'chainer_eval_set': 'train',
            # directory of target label maps, passed as `label_dir` to the datasets
            'ir_label_out_dir': 'result/ir_label',
            # AMN training setup
            'amn_network': 'resnet50_amn',
            'amn_crop_size': 512,
            'amn_batch_size': 16,
            'amn_num_epochs': 1,
            'num_workers': 2,
            # label-smoothing strength for the per-pixel targets
            'eps': 0.4,
        }
        for name, value in settings.items():
            setattr(self, name, value)


args = arguments()

## Load and prepare the PASCAL VOC 2012 datasets


In [None]:
# Training split: labels are read from `args.ir_label_out_dir`; augmentation uses
# random crops of `args.amn_crop_size`, horizontal flips and rescale=(0.5, 1.5)
# (presumably a random scale range -- see voc12.dataloader).
train_dataset = voc12.dataloader.VOC12SegmentationDataset(
    args.train_list,
    voc12_root=args.voc12_root,
    label_dir=args.ir_label_out_dir,
    crop_size=args.amn_crop_size,
    crop_method="random",
    hor_flip=True,
    rescale=(0.5, 1.5),
)
train_data_loader = DataLoader(
    train_dataset,
    batch_size=args.amn_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.num_workers,
    pin_memory=True,
)

# Evaluation split: full images (no crop), one image per batch, fixed order.
val_dataset = voc12.dataloader.VOC12SegmentationDataset(
    args.infer_list,
    voc12_root=args.voc12_root,
    label_dir=args.ir_label_out_dir,
    crop_size=None,
    crop_method="none",
)
val_data_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
    pin_memory=True,
)

## Define an Activation Manipulation Network (AMN)

1.   ResNet50 encoder
2.   ASPP decoder for per-pixel classification (PCL)
3.   Label conditioning (LC) module

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class _ASPP(nn.Module):
    """Atrous spatial pyramid pooling (ASPP) head.

    Builds one 3x3 convolution per entry of ``rates`` (registered as ``c0``,
    ``c1``, ...), each with padding equal to its dilation so the spatial size is
    preserved, and returns the element-wise sum of all branch outputs.

    Reference: DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,
    Atrous Convolution, and Fully Connected CRFs -- https://arxiv.org/abs/1606.00915
    """

    def __init__(self, in_ch, out_ch, rates):
        super(_ASPP, self).__init__()
        # Register every branch first and initialise afterwards, so the RNG is
        # consumed in the same order (all constructions, then all inits).
        for branch_idx, dilation in enumerate(rates):
            branch = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1,
                               padding=dilation, dilation=dilation, bias=True)
            self.add_module(f"c{branch_idx}", branch)

        # Small Gaussian weights and zero biases for every branch.
        for branch in self.children():
            nn.init.normal_(branch.weight, mean=0, std=0.01)
            nn.init.constant_(branch.bias, 0)

    def forward(self, x):
        fused = 0
        for branch in self.children():
            fused = fused + branch(x)
        return fused


class AMN(nn.Module):
    """Activation Manipulation Network (AMN).

    Components (see ``forward``):
      1. ResNet-50 encoder pretrained on ImageNet (``stage1``..``stage4``),
         built with strides (2, 2, 2, 1) and producing 2048 channels;
      2. label-conditioning (LC) module: a linear layer mapping the 20-dim
         image-level class vector to a 2048-dim embedding that rescales the
         encoder feature map channel-wise;
      3. ASPP decoder (dropout + ``_ASPP``) producing 21-channel per-pixel logits.
    """

    def __init__(self):
        super(AMN, self).__init__()

        ###################################################################
        # 1.Resnet50 encoder pretrained on ImageNet
        ###################################################################
        self.resnet50 = resnet50.resnet50(pretrained=True, strides=(2, 2, 2, 1))

        self.stage1 = nn.Sequential(self.resnet50.conv1, self.resnet50.bn1, self.resnet50.relu, self.resnet50.maxpool,
                                    self.resnet50.layer1)
        self.stage2 = nn.Sequential(self.resnet50.layer2)
        self.stage3 = nn.Sequential(self.resnet50.layer3)
        self.stage4 = nn.Sequential(self.resnet50.layer4)

        atrous_rates = [6, 12, 18, 24]
        ###################################################################
        # 2.ASPP decoder
        # [Dropout] probability: 0.1
        # [ASPP module] in_ch: 2048, out_ch=21, rates=atrous_rates
        # 21 outputs = background + the 20 classes of the label vector.
        ###################################################################
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.1),
            _ASPP(in_ch=2048, out_ch=21, rates=atrous_rates),
        )

        ###################################################################
        # 3.LC module
        # [fc layer] in_features: 20, out_features: 2048
        ###################################################################
        self.label_enc = nn.Linear(in_features=20, out_features=2048)

        # Parameter groups returned by `trainable_parameters` (optimised with
        # different learning rates).
        self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4])
        self.newly_added = nn.ModuleList([self.classifier, self.label_enc])

    def forward(self, img, label_cls):
        """Predict per-pixel class logits conditioned on image-level labels.

        :param img: image batch; presumably (N, 3, H, W) -- TODO confirm normalisation.
        :param label_cls: image-level class vector of shape (N, 20) (presumably multi-hot).
        :return: logits of shape (N, 21, h, w) at the encoder's output resolution.
        """
        x = self.stage1(img)
        x = self.stage2(x)
        x = self.stage3(x)
        feature_map = self.stage4(x)

        ###################################################################
        # 1.Obtain label embedding using LC
        # 2.Multiplying the label embedding to the feature map
        ###################################################################
        # Cast so integer/bool label vectors also work with the float Linear layer.
        label_embedding = self.label_enc(label_cls.to(feature_map.dtype))  # (N, 2048)
        # Broadcast the embedding over the spatial dims: channel-wise gating.
        feature_map = feature_map * label_embedding[:, :, None, None]

        logit = self.classifier(feature_map)

        return logit

    def train(self, mode=True):
        """Set train/eval mode while keeping the ResNet stem (conv1, bn1) frozen.

        Delegates to ``nn.Module.train`` so ``self.training`` and every submodule
        (notably the classifier's Dropout) actually follow ``mode``. Without this,
        ``model.eval()`` would leave Dropout active during evaluation.

        :return: ``self``, as ``nn.Module.train`` does (``eval()`` returns it too).
        """
        super(AMN, self).train(mode)
        for p in self.resnet50.conv1.parameters():
            p.requires_grad = False
        for p in self.resnet50.bn1.parameters():
            p.requires_grad = False
        return self

    def trainable_parameters(self):
        """Return ``(backbone_params, newly_added_params)`` as two lists."""
        return (list(self.backbone.parameters()), list(self.newly_added.parameters()))

In [None]:
# Build AMN (pretrained=True downloads ImageNet ResNet-50 weights on first use)
# and move its parameters to the GPU.
model = AMN()
model = model.cuda()

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

## Define the loss function and optimizer
*   Balanced cross-entropy loss
*   Adam optimizer

In [None]:
def balanced_cross_entropy(logits, labels, labels_smooth):
    """Foreground/background-balanced cross-entropy against soft targets.

    The per-pixel cross-entropy between ``softmax(logits)`` and the target
    distribution ``labels_smooth`` is averaged separately over background pixels
    (label 0) and foreground pixels (any label other than 0 and the ignore value
    255). The loss is the mean of the two averages, so the usually dominant
    background does not swamp the foreground term.

    :param logits: raw class scores, shape (N, C, H, W)
    :param labels: hard per-pixel labels, shape (N, H, W); 0 = background, 255 = ignore
    :param labels_smooth: per-pixel target distribution, shape (N, C, H, W)
    :return: scalar loss tensor
    """
    N, C, H, W = logits.shape
    assert labels_smooth.shape == logits.shape, \
        f'labels_smooth shape {tuple(labels_smooth.shape)} != logits shape {(N, C, H, W)}'

    log_probs = F.log_softmax(logits, dim=1)
    # Per-pixel soft-target cross-entropy: (N, H, W).
    loss_structure = -(log_probs * labels_smooth).sum(dim=1)

    bg_mask = labels == 0
    fg_mask = (labels != 0) & (labels != 255)  # 255 pixels are ignored entirely

    # clamp(min=1): a batch without bg (or fg) pixels yields a 0 term instead of 0/0 = NaN.
    loss_bg = loss_structure[bg_mask].sum() / bg_mask.sum().clamp(min=1)
    loss_fg = loss_structure[fg_mask].sum() / fg_mask.sum().clamp(min=1)

    loss = (loss_fg + loss_bg) / 2

    return loss

In [None]:
param_groups = model.trainable_parameters()

###################################################################
# Adam optimizer with two parameter groups:
#   group 0 -- backbone parameters, small learning rate (5e-6);
#   group 1 -- newly added layers, 20x larger learning rate (1e-4);
# both with weight decay 1e-4.
###################################################################
optimizer = torch.optim.Adam([
    dict(params=param_groups[0], lr=5e-06, weight_decay=1.0e-4),
    dict(params=param_groups[1], lr=1e-04, weight_decay=1.0e-4),
])

## Train the Activation Manipulation Network

In [None]:
model.train()

avg_meter = pyutils.AverageMeter()

# Ground-truth masks for the eval phase, built once rather than once per epoch.
# NOTE: the eval phase pairs the i-th batch of `val_data_loader` (shuffle=False)
# with the i-th example of this dataset, so `args.infer_list` and
# `args.chainer_eval_set` must list the same images in the same order.
dataset = VOCSemanticSegmentationDataset(split=args.chainer_eval_set, data_dir=args.voc12_root)

for ep in range(args.amn_num_epochs):
    loader_iter = iter(train_data_loader)

    pbar = tqdm(
        range(1, len(train_data_loader) + 1),
        total=len(train_data_loader),
        dynamic_ncols=True,
    )

    ###################################################################
    # train phase
    ###################################################################
    for _ in pbar:
        optimizer.zero_grad()
        try:
            pack = next(loader_iter)
        except StopIteration:
            # Restart only on exhaustion; any other loader error (corrupt file,
            # KeyboardInterrupt, ...) must propagate instead of being swallowed.
            loader_iter = iter(train_data_loader)
            pack = next(loader_iter)

        img = pack['img'].cuda(non_blocking=True)
        label_amn = pack['label'].long().cuda(non_blocking=True)
        label_cls = pack['label_cls'].cuda(non_blocking=True)

        ###################################################################
        # forward pass
        ###################################################################
        logit = model(img, label_cls)

        B, C, H, W = logit.shape

        # Resize the target label map to the logit resolution.
        label_amn = imutils.resize_labels(label_amn.cpu(), size=logit.shape[-2:]).cuda()

        # Copy with the ignore value 255 mapped to 0 so it is a valid scatter index;
        # the loss receives the unmodified `label_amn`, presumably to mask those pixels out.
        label_ = label_amn.clone()
        label_[label_amn == 255] = 0

        ###################################################################
        # Label smoothing softens the noisy initial seed labels,
        # e.g. (eps=0.4, C=21): [0 1 0 0 ... 0] ==> [0.02 0.6 0.02 0.02 ... 0.02]
        ###################################################################
        label_smooth = torch.full((B, C, H, W), args.eps / (C - 1),
                                  dtype=logit.dtype, device=logit.device)
        label_smooth.scatter_(dim=1, index=label_.unsqueeze(1), value=1 - args.eps)

        ###################################################################
        # compute loss value w/ balanced cross-entropy
        ###################################################################
        loss_pcl = balanced_cross_entropy(logit, label_amn, label_smooth)

        loss = loss_pcl
        loss.backward()

        optimizer.step()

        avg_meter.add({'loss_pcl': loss_pcl.item()})

        pbar.set_description(f"[{ep + 1}/{args.amn_num_epochs}] "
                             f"PCL: [{avg_meter.pop('loss_pcl'):.4f}]")

    ###################################################################
    # eval phase
    ###################################################################
    with torch.no_grad():
        model.eval()
        labels = []
        preds = []

        for i, pack in enumerate(tqdm(val_data_loader)):
            img = pack['img'].cuda()
            label_cls = pack['label_cls'][0]

            ###################################################################
            # forward pass
            ###################################################################
            logit = model(img, pack['label_cls'].cuda())

            size = img.shape[-2:]
            # 16: presumably the encoder's output stride (strides (2, 2, 2, 1)).
            strided_up_size = imutils.get_strided_up_size(size, 16)

            # Candidate classes: background (0) plus the classes in the image-level label.
            valid_cat = torch.nonzero(label_cls)[:, 0]
            keys = np.pad(valid_cat.numpy() + 1, (1, 0), mode='constant')

            logit_up = F.interpolate(logit, strided_up_size, mode='bilinear', align_corners=False)
            logit_up = logit_up[0, :, :size[0], :size[1]]

            # Argmax over the candidate classes only, then map back to class ids.
            logit_up = F.softmax(logit_up, dim=0)[keys].cpu().numpy()

            cls_labels = np.argmax(logit_up, axis=0)
            cls_labels = keys[cls_labels]

            preds.append(cls_labels.copy())

            gt_label = dataset.get_example_by_keys(i, (1,))[0]

            labels.append(gt_label.copy())

        ###################################################################
        # compute IoU (Intersection over Union) per class
        ###################################################################
        # IoU = Area of Overlap / Area of Union = TP / (GT + Pred - TP)
        ###################################################################
        confusion = calc_semantic_segmentation_confusion(preds, labels)

        gtj = confusion.sum(axis=1)
        resj = confusion.sum(axis=0)
        gtjresj = np.diag(confusion)
        denominator = gtj + resj - gtjresj
        # Classes absent from both GT and predictions give 0/0 = NaN; nanmean skips them.
        with np.errstate(divide='ignore', invalid='ignore'):
            iou = gtjresj / denominator

        print(f'[{ep + 1}/{args.amn_num_epochs}] miou: {np.nanmean(iou):.4f}')

        model.train()

torch.cuda.empty_cache()

  0%|          | 0/661 [00:00<?, ?it/s]

  0%|          | 0/1464 [00:00<?, ?it/s]

[1/1] miou: 0.5874


## What is Intersection over Union?
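Intersection over Union is an **evaluation metric** commonly used to measure the accuracy of an object detector on a particular dataset. The same idea applies to semantic segmentation, as in this notebook: for each class, the intersection and union are taken over the sets of predicted and ground-truth pixels, and the per-class IoUs are averaged into the mean IoU (mIoU) printed after training.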

To calculate IOU, we need:

1. The ground-truth bounding boxes
2. The predicted bounding boxes from our model

![IOU_1](https://upload.wikimedia.org/wikipedia/commons/thumb/2/2d/Intersection_over_Union_-_object_detection_bounding_boxes.jpg/300px-Intersection_over_Union_-_object_detection_bounding_boxes.jpg) 


Computing Intersection over Union can be determined via:

![IOU_2](https://upload.wikimedia.org/wikipedia/commons/thumb/c/c7/Intersection_over_Union_-_visual_equation.png/300px-Intersection_over_Union_-_visual_equation.png) ![IOU_3](https://upload.wikimedia.org/wikipedia/commons/thumb/e/e6/Intersection_over_Union_-_poor%2C_good_and_excellent_score.png/300px-Intersection_over_Union_-_poor%2C_good_and_excellent_score.png)

**In object detection, an Intersection over Union score >= 0.5 is normally considered a “good” prediction.**



## [reference]

1. https://en.wikipedia.org/wiki/Jaccard_index
2. https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/