# Image Segmentation using Segmenter
---
TA : Jaehoon Yoo (wogns98@kaist.ac.kr)

---
## Instructions
- In this assignment, we will perform semantic segmentation on the PASCAL VOC 2011 dataset, which contains 20 object categories. We use the Semantic Boundaries Dataset (SBD) because it contains more segmentation labels than the original dataset.
- To this end, you need to implement necessary network components, load and fine-tune the pretrained network, and report segmentation performance on the validation set.
- Fill in the section marked **Px.x** with the appropriate code. **You can only modify inside those areas, and not the skeleton code.**
- To begin, copy this ipynb file to your own Google Drive by clicking `Make a copy`. Find the copy in your Drive and rename it to `Segmentation_segmenter.ipynb` if it was saved under a different name, e.g. `Copy of Segmentation_segmenter.ipynb` or `Segmentation_segmenter.ipynb의 사본`.

---
## Prerequisite: Mount your gdrive.

In [1]:
import os
# from google.colab import drive
# drive.mount('/gdrive')

---
## Prerequisite: Setup the `root` directory properly.

In [2]:
# Specify the directory path where `Segmentation_segmenter.ipynb` exists.
# For example, if you saved `Segmentation_segmenter.ipynb` in the `/gdrive/MyDrive/Segmentation_segmenter` directory,
# then set root = '/gdrive/MyDrive/Segmentation_segmenter'
root = '/gdrive/MyDrive/Segmentation_segmenter'
root = './'

---
# Basic settings

## Import libraries

In [3]:
import os
import time
import traceback
import logging
from easydict import EasyDict as edict
import numpy as np
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import VOCSegmentation, SBDataset
from torchvision.datasets.vision import StandardTransform
# from torchvision.models.vgg import VGG, vgg16, make_layers

torch.backends.cudnn.benchmark = True
torch.use_deterministic_algorithms(True, warn_only=True)

# !pip install git+https://github.com/lucasb-eyer/pydensecrf.git
import pydensecrf.densecrf as dcrf
import pydensecrf.utils as utils

!pip install einops
from einops import rearrange, reduce, repeat

!pip install timm==0.9.12
from timm.models.vision_transformer import vit_tiny_patch16_224, vit_tiny_patch16_384


## Hyperparameters

In [4]:
# Basic settings
torch.manual_seed(42)
torch.cuda.manual_seed(42)

args = edict()
args.batch_size = 1
args.lr = 1e-4
args.momentum = 0.9
args.weight_decay = 5e-4
args.epoch = 2
args.tensorboard = True
args.gpu = True

device = 'cuda' if torch.cuda.is_available() and args.gpu else 'cpu'

# Create directory name.
result_dir = Path(root) / 'results'
result_dir.mkdir(parents=True, exist_ok=True)

## Tensorboard

In [5]:
# Setup tensorboard.
if args.tensorboard:
    %reload_ext tensorboard
    %tensorboard --logdir "/2025-ai-expert/비전/03_홍승훈교수님/0725/results" --samples_per_plugin images=100
else:
    writer = None

---
# Utility functions

Here are some utility functions that we will use throughout this assignment. You don't have to modify any of these.  
**Conditional Random Field (CRF)** is a technique to further improve segmentation performance, mainly focusing on better localization. Details can be found in the [DeepLab](https://arxiv.org/abs/1606.00915) paper.

In [6]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


class toLongTensor:
    """ Convert a PIL label image to a long tensor, remapping the boundary label 255 to 21 """
    def __call__(self, img):
        output = torch.from_numpy(np.array(img).astype(np.int32)).long()
        output[output == 255] = 21
        return output


def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


def label_accuracy_score(label_trues, label_preds, n_class):
    """ Returns overall accuracy and mean IoU """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iou = np.nanmean(iou)
    return acc, mean_iou


class Colorize(object):
    """ Colorize the segmentation labels """
    def __init__(self, n=35, cmap=None):
        if cmap is None:
            raise NotImplementedError('a color map must be provided via `cmap`')
        else:
            self.cmap = cmap
        self.cmap = self.cmap[:n]

    def preprocess(self, x):
        if len(x.size()) > 3 and x.size(1) > 1:
            # if x has a shape of [B, C, H, W],
            # where B and C denote the batch size and the number of semantic classes,
            # then translate it into a shape of [B, 1, H, W]
            x = x.argmax(dim=1, keepdim=True).float()
        assert (len(x.shape) == 4) and (x.size(1) == 1), 'x should have a shape of [B, 1, H, W]'
        return x

    def __call__(self, x):
        x = self.preprocess(x)
        if (x.dtype == torch.float) and (x.max() < 2):
            x = x.mul(255).long()
        color_images = []
        gray_image_shape = x.shape[1:]
        for gray_image in x:
            color_image = torch.ByteTensor(3, *gray_image_shape[1:]).fill_(0)
            for label, cmap in enumerate(self.cmap):
                mask = (label == gray_image[0]).cpu()
                color_image[0][mask] = cmap[0]
                color_image[1][mask] = cmap[1]
                color_image[2][mask] = cmap[2]
            color_images.append(color_image)
        color_images = torch.stack(color_images)
        return color_images


def uint82bin(n, count=8):
    """ Returns the binary of integer n, count refers to amount of bits """
    return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])


def get_color_map():
    """ Returns a color map with N=25 RGB entries as a uint8 tensor """
    N=25
    color_map = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r, g, b = 0, 0, 0
        id = i
        for j in range(7):
            str_id = uint82bin(id)
            r = r ^ (np.uint8(str_id[-1]) << (7-j))
            g = g ^ (np.uint8(str_id[-2]) << (7-j))
            b = b ^ (np.uint8(str_id[-3]) << (7-j))
            id = id >> 3
        color_map[i, 0] = r
        color_map[i, 1] = g
        color_map[i, 2] = b
    color_map = torch.from_numpy(color_map)
    return color_map


def dense_crf(img, output_probs):
    """ Conditional Random Field for better segmentation
        Refer to https://github.com/lucasb-eyer/pydensecrf for details.
    """
    c = output_probs.shape[0]
    h = output_probs.shape[1]
    w = output_probs.shape[2]

    U = utils.unary_from_softmax(output_probs)
    U = np.ascontiguousarray(U)

    img = np.ascontiguousarray(img)

    d = dcrf.DenseCRF2D(w, h, c)
    d.setUnaryEnergy(U)
    d.addPairwiseGaussian(sxy=1, compat=3)
    d.addPairwiseBilateral(sxy=67, srgb=3, rgbim=img, compat=4)

    Q = d.inference(10)
    Q = np.array(Q).reshape((c, h, w))
    return Q


def add_padding(img):
    """ Zero-pad the last two dimensions of a tensor to a square of side shape[-2] + 36. """
    w, h = img.shape[-2], img.shape[-1]
    MAX_SIZE = w + 36
    IGNORE_IDX = 21

    assert max(w, h) <= MAX_SIZE, f'both height and width should be less than {MAX_SIZE}'

    _pad_left = (MAX_SIZE - w) // 2
    _pad_right = (MAX_SIZE - w + 1) // 2
    _pad_up = (MAX_SIZE - h) // 2
    _pad_down = (MAX_SIZE - h + 1) // 2

    _pad = (_pad_up, _pad_down, _pad_left, _pad_right)

    padding_img = transforms.Pad(_pad)
    padding_target = transforms.Pad(_pad, fill=IGNORE_IDX)

    img = F.pad(img, pad=_pad)
    return img


---
# Define `DataLoader` for training & validation set

If the cell below fails with error message "Destination path `./cls` already exists", try again with `download=False`.

In [7]:
mean = [.485, .456, .406]
std = [.229, .224, .225]

# define transform functions.
im_size = 384
transform_train = transforms.Compose([
    transforms.Resize((im_size, im_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
transform_train_target = transforms.Compose([
    transforms.Resize((im_size, im_size), interpolation=transforms.InterpolationMode.NEAREST),
    toLongTensor()
])
transform_test = transforms.Compose([
    transforms.Resize((im_size, im_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
transform_test_target = transforms.Compose([
    transforms.Resize((im_size, im_size), interpolation=transforms.InterpolationMode.NEAREST),
    toLongTensor()
])

# define dataloader.
sbd_transform_train = StandardTransform(transform_train, transform_train_target)
sbd_transform_test = StandardTransform(transform_test, transform_test_target)
train_dataset = SBDataset(root='../0724/', image_set='train', mode='segmentation', download=False, transforms=sbd_transform_train)
test_dataset = SBDataset(root='../0724/', image_set='val', mode='segmentation', download=False, transforms=sbd_transform_test)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

In [8]:
len(test_dataset)

2857

---
# Define networks

## P1. Implement Attention layer [(Illustration)](https://docs.google.com/drawings/d/1HOI4QoqSACBFCeW0xVTOkn3XINBoEDNbr2xbtb9LZVY)
### (a) Declare q, k, v projection layers.
### (b) Declare output projection layer.
### (c) Declare dropout layers.
### (d) Implement forward method
The `forward` method should:
- Map the input to queries, keys, and values.
- Compute the inner product between queries and keys.
  - Multiply the inner product by `self.scale` for stable training.
- Compute the attention score by applying the softmax.
- Apply dropout to the attention score.
- Aggregate the values based on the attention score.
- Apply the output projection layer.
- Apply the projection dropout.
- Return the output and the attention score.
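
The core of these steps (scaling, softmax, and value aggregation) can be sketched for a single head on a toy input. This is only an illustration in plain Python, not the assignment solution: the helper names `softmax` and `single_head_attention` and the toy `tokens` are assumptions made for this example, and the learned projections, the multi-head reshape, and both dropout layers are left out.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def single_head_attention(q, k, v):
    """Scaled dot-product attention for one head (hypothetical helper).

    q, k, v: lists of N token vectors, each of length d.
    Returns (output, attn); attn[i][j] is the weight token i puts on token j.
    """
    d = len(q[0])
    scale = d ** -0.5  # plays the role of self.scale in the Attention class

    # Scaled inner product between every query and every key: shape (N, N).
    scores = []
    for q_row in q:
        scores.append([scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k])

    # Softmax over the key dimension turns scores into attention weights.
    attn = [softmax(row) for row in scores]

    # Each output token is the attention-weighted sum of the value vectors.
    output = []
    for attn_row in attn:
        output.append([
            sum(w * v_row[c] for w, v_row in zip(attn_row, v))
            for c in range(len(v[0]))
        ])
    return output, attn


# Toy input: 3 tokens of dimension 2, used as queries, keys, and values.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
output, attn = single_head_attention(tokens, tokens, tokens)
```

Each row of `attn` sums to 1, which is a quick sanity check worth applying to the tensor version as well.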

In [9]:
class Attention(nn.Module):
    def __init__(self, hidden_size, heads, dropout):
        super().__init__()
        self.heads = heads
        head_dim = hidden_size // heads
        self.scale = head_dim ** -0.5
        self.attn = None

        ################################################
        # Write your code here
        self.q = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, hidden_size, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_drop = nn.Dropout(dropout)
        # self.q =
        # self.k =
        # self.v =
        # self.attn_drop =
        # self.proj =
        # self.proj_drop =
        ################################################

    def forward(self, x, mask=None, debug=False):
        B, N, C = x.shape

        ################################################
        # Write your code here.
        # Add more lines as you wish.
        # reshape to (B, N, nheads, d_K) and transpose to (B, nheads, N, d_K)
        q = self.q(x).reshape(B, N, self.heads, C // self.heads).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B, N, self.heads, C // self.heads).permute(0, 2, 1, 3)
        v = self.v(x).reshape(B, N, self.heads, C // self.heads).permute(0, 2, 1, 3)
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # if mask is not None:
            # mask = mask.unsqueeze(1).unsqueeze(2)
            # attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # output shape : (B, nheads, N, d_K)
        output = (attn @ v).transpose(1, 2).reshape(B, N, C)
        output = self.proj(output)
        output = self.proj_drop(output)
        # q =
        # k =
        # v =
        # attn =
        # output =
        ################################################

        if debug:
          return q, k, v, output, attn
        else:
          return output, attn

    @property
    def unwrapped(self):
        return self

In [10]:
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        if out_dim is None:
            out_dim = dim
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

    @property
    def unwrapped(self):
        return self


class Block(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout)

    def forward(self, x, mask=None, return_attention=False):
        y, attn = self.attn(self.norm1(x), mask)
        if return_attention:
            return attn
        x = x + y
        x = x + self.mlp(self.norm2(x))
        return x

## P1 Tests

This section tests your solution for P1. **Please do not modify the code!**

In [11]:
@torch.no_grad()
def run_tests_p1():
    n_pass, n_test = 0, 4
    B, N, C = 2, 10, 192
    try:
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        net = Attention(C, 3, 0.1)
        print(f"[TEST  1/{n_test} Passed] Attention.__init__ executed without errors")
        n_pass += 1
    except Exception as e:
        print(f"[TEST  1/{n_test} Failed] Attention.__init__ execution error; please see the traceback below")
        print(f"\n{traceback.format_exc()}")

        net = nn.Identity()

    try:
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        x = torch.randn(B, N, C)
        q, k, v, output, attn = net(x, debug=True)
        print(f"[TEST  2/{n_test} Passed] Attention.forward executed without errors")
        n_pass += 1

    except Exception as e:
        print(f"[TEST  2/{n_test} Failed] Attention.forward execution error;")
        print(f"\n{traceback.format_exc()}")
        return

    try:
        q_shape = k_shape = v_shape = torch.Size([2, 3, 10, 64])
        attn_shape = torch.Size([2, 3, 10, 10])
        output_shape = torch.Size([2, 10, 192])

        not_matches = []
        if q.shape != q_shape:
            not_matches.append('q')
        if k.shape != k_shape:
            not_matches.append('k')
        if v.shape != v_shape:
            not_matches.append('v')
        if attn.shape != attn_shape:
            not_matches.append('attn')
        if output.shape != output_shape:
            not_matches.append('output')
        assert len(not_matches) == 0
        print(f"[TEST  3/{n_test} Passed] Shape of q, k, v, attn, output are matched")
        n_pass += 1
    except Exception as e:
        print(f"[TEST  3/{n_test} Failed] Shape of {not_matches} are not matched")

    try:
        q_mean, q_std = q.abs().mean(), q.abs().std()
        k_mean, k_std = k.abs().mean(), k.abs().std()
        v_mean, v_std = v.abs().mean(), v.abs().std()
        attn_mean, attn_std = attn.abs().mean(), attn.abs().std()
        output_mean, output_std = output.abs().mean(), output.abs().std()
        not_matches = []
        if ((q_mean - 0.4701).abs() > 1e-4) or ((q_std - 0.3834).abs() > 1e-4):
            not_matches.append('q')
        if ((k_mean - 0.4509).abs() > 1e-4) or ((k_std - 0.3436).abs() > 1e-4):
            not_matches.append('k')
        if ((v_mean - 0.4423).abs() > 1e-4) or ((v_std - 0.3363).abs() > 1e-4):
            not_matches.append('v')
        if ((attn_mean - 0.0999).abs() > 1e-4) or ((attn_std - 0.0498).abs() > 1e-4):
            not_matches.append('attn')
        if ((output_mean - 0.1003).abs() > 1e-4) or ((output_std - 0.0870).abs() > 1e-4):
            not_matches.append('output')
        assert len(not_matches) == 0
        print(f"[TEST  4/{n_test} Passed] All values are matched")
        n_pass += 1
    except Exception as e:
        print(f"[TEST  4/{n_test} Failed] Value of {not_matches} are not matched")

    if n_pass == n_test:
        print(f"\n[TEST] 🎉🎉🥳 All {n_pass}/{n_test} tests passed!")


run_tests_p1()

[TEST  1/4 Passed] Attention.__init__ executed without errors
[TEST  2/4 Passed] Attention.forward executed without errors
[TEST  3/4 Passed] Shape of q, k, v, attn, output are matched
[TEST  4/4 Failed] Value of ['q', 'k', 'v', 'attn', 'output'] are not matched


## P2. Implement MaskTransformer [(Illustration)](https://docs.google.com/drawings/d/1TQ9lDfIimTom7df7_6ThzF9cOnedqIWN0xK8cB_-cSk)

The MaskTransformer takes the output from the ViT encoder to predict the segmentation mask. First, the visual tokens from the ViT encoder are concatenated with class embeddings and sent to the MaskTransformer. The MaskTransformer processes these tokens, and then it calculates the segmentation mask by computing the inner product between the class embeddings and the visual tokens.
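
The final mask computation described above can be sketched on toy data. This is a minimal plain-Python illustration under stated assumptions: the helper names `l2_normalize` and `class_masks` and the toy vectors are made up for this example, the inputs are assumed to be non-zero vectors, and the transformer blocks, the linear projections, and the layer normalization over classes used in the actual decoder are omitted.

```python
import math


def l2_normalize(vec):
    """Scale a non-zero vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def class_masks(patches, class_feats, grid_size):
    """Inner product between patch tokens and class embeddings (hypothetical helper).

    patches: N vectors with N == grid_size ** 2, in row-major grid order.
    class_feats: n_cls vectors of the same dimension.
    Returns masks[c][row][col], the score of class c at that grid cell.
    """
    patches = [l2_normalize(p) for p in patches]
    class_feats = [l2_normalize(c) for c in class_feats]

    # scores has shape (N, n_cls): one score per patch and class.
    scores = [
        [sum(a * b for a, b in zip(p, c)) for c in class_feats]
        for p in patches
    ]

    # Move the class axis first and unflatten N into (grid_size, grid_size).
    return [
        [
            [scores[row * grid_size + col][c] for col in range(grid_size)]
            for row in range(grid_size)
        ]
        for c in range(len(class_feats))
    ]


# Toy example: a 2x2 grid of patch tokens and 2 class embeddings.
patches = [[2.0, 0.0], [0.0, 3.0], [1.0, 1.0], [4.0, 0.0]]
class_feats = [[1.0, 0.0], [0.0, 1.0]]
masks = class_masks(patches, class_feats, grid_size=2)
```

Because both sides are normalized, every score is a cosine similarity in [-1, 1], and the result has the `(n_cls, GS, GS)` layout that the decoder returns per image.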

In [12]:
from timm.models.layers import trunc_normal_
class MaskTransformer(nn.Module):
    def __init__(
        self,
        n_cls=21,
        patch_size=16,
        d_encoder=192,
        n_layers=2,
        n_heads=3,
        d_model=192,
        dropout=0.1,
    ):
        super().__init__()
        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.n_cls = n_cls
        self.d_model = d_model
        self.d_ff = d_ff = d_model * 4
        self.scale = d_model ** -0.5

        d_e = d_encoder
        d_m = d_model

        #######################################################################
        # Write your code here
        # self.blocks =
        # self.cls_emb =
        # self.proj_dec =
        # self.proj_patch =
        # self.proj_classes =
        # self.decoder_norm =
        # self.mask_norm =
        #######################################################################

        #######################################################################
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout,) for i in range(n_layers)]
        )

        self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
        self.proj_dec = nn.Linear(d_encoder, d_model)

        # TODO: change these as nn.Linear
        self.proj_patch = nn.Linear(d_model, d_model, bias=False)
        self.proj_classes = nn.Linear(d_model, d_model, bias=False)

        self.decoder_norm = nn.LayerNorm(d_model)
        self.mask_norm = nn.LayerNorm(n_cls)
        #######################################################################

        self.apply(init_weights)
        self.proj_patch.weight.data = self.scale * torch.randn_like(self.proj_patch.weight.data)
        self.proj_classes.weight.data = self.scale * torch.randn_like(self.proj_classes.weight.data)
        trunc_normal_(self.cls_emb, std=0.02)

    def forward(self, x, im_size, debug=False):
        H, W = im_size
        GS = H // self.patch_size # grid_size

        #####################################################
        # Write your code here
        # Add more lines as you wish.
        # block_output =
        # patches =
        # cls_seg_feat =
        # masks =
        #####################################################
        x = self.proj_dec(x)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        for blk in self.blocks:
            x = blk(x)
        block_output = self.decoder_norm(x)

        patches, cls_seg_feat = block_output[:, : -self.n_cls], block_output[:, -self.n_cls :]
        patches = self.proj_patch(patches)
        cls_seg_feat = self.proj_classes(cls_seg_feat)

        patches = patches / patches.norm(dim=-1, keepdim=True)
        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)

        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)
        # masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS)) # equivalent
        masks = masks.transpose(-2, -1).reshape(x.size(0), -1, GS, GS)
        #####################################################
        if debug:
          return block_output, patches, cls_seg_feat, masks

        return masks

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"cls_emb"}


In [13]:
class DecoderLinear(nn.Module):
    def __init__(self, n_cls=21, patch_size=16, d_encoder=192):
        super().__init__()

        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_cls = n_cls

        self.head = nn.Linear(self.d_encoder, n_cls)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return set()

    def forward(self, x, im_size):
        H, W = im_size
        GS = H // self.patch_size
        x = self.head(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=GS)

        return x

### P2 Tests

In [14]:
@torch.no_grad()
def run_tests_p2():
    n_pass, n_test = 0, 4
    B, N, C = 2, 16, 192
    im_size = (64, 64)
    try:
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        net = MaskTransformer()
        print(f"[TEST  1/{n_test} Passed] MaskTransformer.__init__ executed without errors")
        n_pass += 1
    except Exception as e:
        print(f"[TEST  1/{n_test} Failed] MaskTransformer.__init__ execution error; please see the traceback below")
        print(f"\n{traceback.format_exc()}")
        return

    try:
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        x = torch.randn(B, N, C)
        block_output, patches, cls_seg_feat, masks = net(x, im_size, debug=True)
        print(f"[TEST  2/{n_test} Passed] MaskTransformer.forward executed without errors")
        n_pass += 1

    except Exception as e:
        print(f"[TEST  2/{n_test} Failed] MaskTransformer.forward execution error;")
        print(f"\n{traceback.format_exc()}")
        return

    b_mean, b_std = block_output.abs().mean(), block_output.abs().std()
    p_mean, p_std = patches.abs().mean(), patches.abs().std()
    c_mean, c_std = cls_seg_feat.abs().mean(), cls_seg_feat.abs().std()
    m_mean, m_std = masks.abs().mean(), masks.abs().std()

    try:
        not_matches = []
        if block_output.shape != torch.Size([2, 37, 192]):
            not_matches.append('block_output')
        if patches.shape != torch.Size([2, 16, 192]):
            not_matches.append('patches')
        if cls_seg_feat.shape != torch.Size([2, 21, 192]):
            not_matches.append('cls_seg_feat')
        if masks.shape != torch.Size([2, 21, 4, 4]):
            not_matches.append('masks')
        assert len(not_matches) == 0
        print(f"[TEST  3/{n_test} Passed] Shape of outputs are matched")
        n_pass += 1
    except Exception as e:
        print(f"[TEST  3/{n_test} Failed] Shape of {not_matches} are not matched")

    try:
        not_matches = []
        if ((b_mean - 0.7917).abs() > 1e-4) or ((b_std - 0.6106).abs() > 1e-4):
            not_matches.append('q')
        if ((p_mean - 0.0576).abs() > 1e-4) or ((p_std - 0.0435).abs() > 1e-4):
            not_matches.append('k')
        if ((c_mean - 0.0577).abs() > 1e-4) or ((c_std - 0.0433).abs() > 1e-4):
            not_matches.append('cls_seg_feat')
        if ((m_mean - 0.8144).abs() > 1e-4) or ((m_std - 0.5783).abs() > 1e-4):
            not_matches.append('masks')
        assert len(not_matches) == 0
        print(f"[TEST  4/{n_test} Passed] All values are matched")
        n_pass += 1
    except Exception as e:
        print(f"[TEST  4/{n_test} Failed] Value of {not_matches} are not matched")

    if n_pass == n_test:
        print(f"\n[TEST] 🎉🎉🥳 All {n_pass}/{n_test} tests passed!")


run_tests_p2()

[TEST  1/4 Passed] MaskTransformer.__init__ executed without errors
[TEST  2/4 Passed] MaskTransformer.forward executed without errors
[TEST  3/4 Passed] Shape of outputs are matched
[TEST  4/4 Failed] Value of ['q', 'cls_seg_feat', 'masks'] are not matched


## P3. Implement Segmenter [(Illustration)](https://docs.google.com/drawings/d/19914B8kWbAZrIwiFJBxkRESSN9_T5MRJ43YT6NL-SCg)
The Segmenter class connects the ViT encoder with the MaskTransformer (decoder).
It takes images of shape (B, 3, H, W) and returns a class logit map of shape (B, num_cls, H, W).
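
To keep track of tensor shapes through the encoder, the decoder, and the final upsampling, the arithmetic can be written out as a small helper. This is a sketch only: `segmenter_shapes` is a hypothetical function introduced for this example, its defaults (patch size 16, 21 classes, encoder width 192, one extra CLS token) are taken from the code in this notebook, and it assumes the image height and width are divisible by the patch size.

```python
def segmenter_shapes(batch, height, width, patch_size=16, n_cls=21,
                     d_encoder=192, num_extra_tokens=1):
    """Return the expected tensor shapes at each stage (hypothetical helper)."""
    grid_h = height // patch_size
    grid_w = width // patch_size
    n_patches = grid_h * grid_w
    return {
        # ViT output: patch tokens plus the extra CLS token.
        "encoder_output": (batch, n_patches + num_extra_tokens, d_encoder),
        # After dropping the extra tokens, only patch tokens go to the decoder.
        "patch_tokens": (batch, n_patches, d_encoder),
        # The decoder predicts one logit map per class on the patch grid.
        "decoder_masks": (batch, n_cls, grid_h, grid_w),
        # Bilinear interpolation brings the logits back to image resolution.
        "upsampled_masks": (batch, n_cls, height, width),
    }


shapes = segmenter_shapes(batch=1, height=384, width=384)
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

Comparing these tuples with `tensor.shape` at each stage of `forward` is a simple way to locate a mismatch.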

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import trunc_normal_

class Segmenter(nn.Module):
    def __init__(
        self,
        n_cls=21,
        use_tf=True,
    ):
        super().__init__()
        self.n_cls = n_cls
        # encoder = vit_tiny_patch16_224(pretrained=True)
        encoder = vit_tiny_patch16_384(pretrained=True)
        decoder = MaskTransformer() if use_tf else DecoderLinear()
        self.patch_size = 16
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, im):
        ############################################################
        H, W = im.size(2), im.size(3)

        x = self.encoder.forward_features(im)

        # remove CLS/DIST tokens for decoding
        num_extra_tokens = 1
        x = x[:, num_extra_tokens:]

        masks = self.decoder(x, (H, W))
        masks = F.interpolate(masks, size=(H, W), mode="bilinear")
        ############################################################
        return masks

    @torch.jit.ignore
    def no_weight_decay(self):
        def append_prefix_no_weight_decay(prefix, module):
            return set(map(lambda x: prefix + x, module.no_weight_decay()))

        nwd_params = append_prefix_no_weight_decay("encoder.", self.encoder).union(
            append_prefix_no_weight_decay("decoder.", self.decoder)
        )
        return nwd_params

### P3 Tests

In [16]:
@torch.no_grad()
def run_tests_p3():
    n_pass, n_test = 0, 4
    try:
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        net = Segmenter()
        print(f"[TEST  1/{n_test} Passed] Segmenter.__init__ executed without errors")
        n_pass += 1
    except Exception as e:
        print(f"[TEST  1/{n_test} Failed] Segmenter.__init__ execution error; please see the traceback below")
        print(f"\n{traceback.format_exc()}")
        return

    try:
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        im = torch.randn(1, 3, im_size, im_size)
        masks = net(im)
        print(f"[TEST  2/{n_test} Passed] Segmenter.forward executed without errors")
        n_pass += 1

    except Exception as e:
        print(f"[TEST  2/{n_test} Failed] Segmenter.forward execution error;")
        print(f"\n{traceback.format_exc()}")
        return

    m_mean, m_std = masks.abs().mean(), masks.abs().std()
    try:
        assert masks.shape == torch.Size([1, 21, im_size, im_size])
        print(f"[TEST  3/{n_test} Passed] Shape of mask is matched")
        n_pass += 1
    except Exception as e:
        print(f"[TEST  3/{n_test} Failed] Shape of mask is not matched")

    try:
        assert ((m_mean - 0.7249).abs() <= 1e-4) and ((m_std - 0.4966).abs() <= 1e-4)
        print(f"[TEST  4/{n_test} Passed] Mask logits are matched")
        n_pass += 1
    except Exception as e:
        print(f"[TEST  4/{n_test} Failed] Mask logits are not matched")

    if n_pass == n_test:
        print(f"\n[TEST] 🎉🎉🥳 All {n_pass}/{n_test} tests passed!")

run_tests_p3()

[TEST  1/4 Passed] Segmenter.__init__ executed without errors
[TEST  2/4 Passed] Segmenter.forward executed without errors
[TEST  3/4 Passed] Shape of mask is matched
[TEST  4/4 Failed] Mask logits are not matched


---
# Training function

## Training pipeline

### (a) Forward/Backward step for training
- Feed the image through the model.
- Perform a gradient step based on the loss. Loss can be calculated using `criterion`, located at the beginning of the function.
- Choose the highest logit per pixel as prediction.

### (b) Forward step for validation
- Feed the image through the model.
- Calculate loss on current image.
- Choose the highest logit per pixel as prediction.
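
The loss and prediction steps can be illustrated on a handful of pixels without any tensors. The sketch below is plain Python written for this explanation, not the training code: `pixel_loss_and_prediction` and the toy `logits` and `labels` are assumptions. It mirrors how the notebook combines `ignore_index=21` with `reduction='none'` followed by `.mean()`, so ignored pixels contribute a loss of 0 but still count in the average.

```python
import math

IGNORE_INDEX = 21  # boundary label, ignored by the loss


def pixel_loss_and_prediction(logits, labels, ignore_index=IGNORE_INDEX):
    """Per-pixel cross entropy and argmax prediction (hypothetical helper).

    logits: list of per-pixel logit lists, one logit per class.
    labels: list of integer class labels, one per pixel.
    Returns (mean_loss, preds).
    """
    losses, preds = [], []
    for pixel_logits, label in zip(logits, labels):
        # Prediction: index of the highest logit for this pixel.
        preds.append(max(range(len(pixel_logits)), key=pixel_logits.__getitem__))

        if label == ignore_index:
            # Ignored pixels get zero loss but remain in the mean's denominator.
            losses.append(0.0)
            continue

        # Cross entropy = log-sum-exp of the logits minus the true-class logit.
        m = max(pixel_logits)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in pixel_logits))
        losses.append(log_sum_exp - pixel_logits[label])

    return sum(losses) / len(losses), preds


# Toy example: 3 pixels, 3 classes, last pixel carries the ignore label.
logits = [[2.0, 0.0, 0.0], [0.0, 1.0, 3.0], [0.1, 0.5, 0.0]]
labels = [0, 2, IGNORE_INDEX]
mean_loss, preds = pixel_loss_and_prediction(logits, labels)
```

A prediction is still produced for the ignored pixel; only its contribution to the loss is zeroed.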

In [17]:
import tqdm

def get_prediction(criterion, net, image, label):
    output = net(image)
    loss = criterion(output, label).mean()
    pred = torch.argmax(output, dim=1)
    return output, loss, pred

def train_net(net, resume=False):
    # 21 is the index for boundaries: therefore we ignore this index.
    criterion = nn.CrossEntropyLoss(ignore_index=21, reduction='none')
    colorize = Colorize(21, get_color_map())
    best_valid_iou = 0

    if resume:
        checkpoint = torch.load(ckpt_path)
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch = checkpoint['epoch']
        best_valid_iou = checkpoint['best_valid_iou']
        print(f'Resume training from epoch {epoch+1}')
    else:
        epoch = 0

    while epoch < args.epoch:
        t1 = time.time()
        saved_images, saved_labels = [], []

        # start training
        net.train()

        loss_total = 0
        ious = []
        pixel_accs = []

        for batch_idx, (image, label) in tqdm.tqdm(enumerate(train_loader)):
            # save images for visualization.
            if len(saved_images) < 4:
                saved_images.append(image.cpu())
                saved_labels.append(add_padding(label.cpu()))

            # move variables to gpu.
            image = image.to(device)
            label = label.to(device)

            output, loss, pred = get_prediction(criterion, net, image, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update total loss.
            loss_total += loss.item()

            # target
            target = label.squeeze(1).cpu().numpy()

            # calculate pixel accuracy and mean IoU
            acc, mean_iu = label_accuracy_score(target, pred.cpu().detach().numpy(), n_class=21)

            pixel_accs.append(acc)
            ious.append(mean_iu)

            if batch_idx % 50 == 0:
                print(f'Epoch : {epoch} || {batch_idx}/{len(train_loader)} || loss : {loss.item():.3f}, iou : {mean_iu * 100:.3f} pixel_acc : {acc * 100:.3f}')
                writer.add_scalar('train_loss_step', loss.item(), batch_idx + epoch * len(train_loader))
                writer.add_scalar('pixel_acc_step', acc, batch_idx + epoch * len(train_loader))
                writer.add_scalar('mean_iou_step', mean_iu, batch_idx + epoch * len(train_loader))


        # calculate average IoU
        total_ious = np.array(ious).T
        total_ious = np.nanmean(total_ious).mean()
        total_pixel_acc = np.array(pixel_accs).mean()

        writer.add_scalar('train_loss', loss_total / len(train_loader), epoch)
        writer.add_scalar('pixel_acc', total_pixel_acc, epoch)
        writer.add_scalar('mean_iou', total_ious, epoch)

        # image visualization
        un_norms, preds, outputs = [], [], []
        for image, label in zip(saved_images, saved_labels):
            # denormalize the image.
            image_permuted = image.permute(1, 0, 2, 3)
            un_norm = torch.zeros_like(image_permuted)
            for idx, (im, m, s) in enumerate(zip(image_permuted, mean, std)):
                un_norm[idx] = (im * s) + m
            un_norm = un_norm.permute(1, 0, 2, 3)
            un_norms.append(add_padding(un_norm))

            with torch.no_grad():
                output = net(image.to(device))
                pred = torch.argmax(output, dim=1)
                preds.append(add_padding(pred))

        # stitch images into a grid.
        un_norm = make_grid(torch.cat(un_norms), nrow=2)
        label = make_grid(colorize(torch.stack(saved_labels)), nrow=2)
        pred = make_grid(colorize(torch.stack(preds)), nrow=2)

        # write images to Tensorboard.
        writer.add_image('img', un_norm, epoch)
        writer.add_image('gt', label, epoch)
        writer.add_image('pred', pred, epoch)

        t = time.time() - t1
        print(f'>> Epoch : {epoch} || AVG loss : {loss_total / len(train_loader):.3f}, iou : {total_ious * 100:.3f} pixel_acc : {total_pixel_acc * 100:.3f} {t:.3f} secs')

        # evaluation
        net.eval()
        saved_images, saved_labels = [], []

        valid_loss_total = 0
        valid_ious = []
        valid_pixel_accs = []

        with torch.no_grad():
            for batch_idx, (image, label) in tqdm.tqdm(enumerate(test_loader)):
                # save images for visualization.
                if len(saved_images) < 4:
                    saved_images.append(image.cpu())
                    saved_labels.append(add_padding(label.cpu()))

                # move variables to gpu.
                image = image.to(device)
                label = label.to(device)

                output, loss, pred = get_prediction(criterion, net, image, label)

                # update total loss.
                valid_loss_total += loss.item()

                output = output.data.cpu().numpy()
                target = label.squeeze(1).cpu().numpy()

                acc, mean_iu = label_accuracy_score(target, pred.cpu().numpy(), n_class=21)

                valid_pixel_accs.append(acc)
                valid_ious.append(mean_iu)

        # calculate average IoU
        total_valid_ious = np.array(valid_ious).T
        total_valid_ious = np.nanmean(total_valid_ious).mean()
        total_valid_pixel_acc = np.array(valid_pixel_accs).mean()

        writer.add_scalar('valid_train_loss', valid_loss_total / len(test_loader), epoch)
        writer.add_scalar('valid_pixel_acc', total_valid_pixel_acc, epoch)
        writer.add_scalar('valid_mean_iou', total_valid_ious, epoch)

        # image visualization + CRF
        un_norms, preds, pred_crfs, outputs = [], [], [], []
        for image, label in zip(saved_images, saved_labels):
            # denormalize the image.
            image_permuted = image.permute(1, 0, 2, 3)
            un_norm = torch.zeros_like(image_permuted)
            for idx, (im, m, s) in enumerate(zip(image_permuted, mean, std)):
                un_norm[idx] = (im * s) + m
            un_norm = un_norm.permute(1, 0, 2, 3)
            un_norms.append(add_padding(un_norm))

            with torch.no_grad():
                output = net(image.to(device))
                outputs.append(add_padding(output))
                pred = torch.argmax(output, dim=1)
                preds.append(add_padding(pred))

            # CRF
            output_softmax = torch.nn.functional.softmax(output, dim=1).detach().cpu()
            un_norm_int = (un_norm * 255).squeeze().permute(1, 2, 0).numpy().astype(np.ubyte)
            pred_crf = dense_crf(un_norm_int, output_softmax.squeeze().numpy())
            pred_crfs.append(add_padding(torch.argmax(torch.Tensor(pred_crf), dim=0)).unsqueeze(0))

        # stitch images into a grid.
        valid_un_norm = make_grid(torch.cat(un_norms), nrow=2)
        valid_label = make_grid(colorize(torch.stack(saved_labels)), nrow=2)
        valid_pred = make_grid(colorize(torch.stack(preds)), nrow=2)
        valid_pred_crf = make_grid(colorize(torch.stack(pred_crfs)), nrow=2)

        # write images to tensorboard.
        writer.add_image('valid_img', valid_un_norm, epoch)
        writer.add_image('valid_gt', valid_label, epoch)
        writer.add_image('valid_pred', valid_pred, epoch)
        writer.add_image('valid_pred_crf', valid_pred_crf, epoch)

        print(f'>> Epoch : {epoch} || AVG valid loss : {valid_loss_total / len(test_loader):.3f}, iou : {total_valid_ious * 100:.3f} pixel_acc : {total_valid_pixel_acc * 100:.3f} {t:.3f} secs')

        # update the best validation IoU and save the best weights.
        if total_valid_ious > best_valid_iou:
            best_valid_iou = total_valid_ious
            torch.save(net.state_dict(), ckpt_dir / 'best.pt')

        # save a checkpoint every epoch, after the update above, so that the
        # stored `best_valid_iou` already includes this epoch's result.
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_valid_iou': best_valid_iou
        }
        torch.save(checkpoint, ckpt_path)

        epoch += 1
    print(f'>> Best validation set iou: {best_valid_iou}')
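
The loop above averages the per-batch pixel accuracy and mean IoU returned by `label_accuracy_score`. As a rough illustration of what such a function commonly computes, both metrics can be derived from a confusion matrix. The helper `pixel_acc_and_miou` below is a hypothetical stand-in written for this sketch, not the notebook's own implementation. Classes that appear in neither the target nor the prediction produce a NaN IoU, which is why `np.nanmean` is used for averaging.

```python
import numpy as np

def pixel_acc_and_miou(target, pred, n_class):
    """Hypothetical stand-in for label_accuracy_score: returns (pixel_acc, mean_iou)."""
    target = np.asarray(target).ravel()
    pred = np.asarray(pred).ravel()
    # Ignore pixels whose label lies outside [0, n_class), e.g. a 255 "void" label.
    valid = (target >= 0) & (target < n_class)
    # Confusion matrix: rows are ground-truth classes, columns are predictions.
    hist = np.bincount(
        n_class * target[valid].astype(int) + pred[valid].astype(int),
        minlength=n_class ** 2,
    ).reshape(n_class, n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        # IoU = TP / (TP + FP + FN); a class absent everywhere gives 0/0 = NaN.
        iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    return acc, np.nanmean(iou)

# Tiny example: 4 pixels, 3 classes, class 2 never appears.
target = np.array([[0, 0, 1, 1]])
pred = np.array([[0, 1, 1, 1]])
acc, miou = pixel_acc_and_miou(target, pred, n_class=3)
print(acc, miou)
```

In this toy case three of four pixels are correct, and the mean IoU is taken over classes 0 and 1 only because class 2 contributes NaN.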

---
# Train models through the pipeline

In this section, you will
- Create/load directory.
- Select which model to train.
- Create model and optimizer.

The training process automatically saves a checkpoint to your Google Drive under `parent_dir` after every epoch. Training can take up to 40 minutes per epoch. Because we provide pretrained weights to start from, you only need to train for 2 epochs on your own. Uncomment the appropriate line after `# select model here.` to choose which model to train.  
**You must load the provided pretrained weights**, otherwise achieving reasonable performance will take much longer.  
**If you would like to resume** from an existing `model.pt`, then
- Comment out the line below `Load pretrained weights here.`,
- Specify `parent_dir` as instructed,
- Run the first code cell again, then run `train_net` with the `resume=True` argument.  
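
As a minimal sketch of what resuming involves, the snippet below restores the four entries that the training loop stores in `model.pt` (`epoch`, `state_dict`, `optimizer`, `best_valid_iou`). The helper name `load_checkpoint` and the toy `nn.Linear` model are assumptions made for illustration; how `train_net` handles `resume=True` internally may differ.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.optim import SGD

def load_checkpoint(ckpt_path, net, optimizer, device='cpu'):
    """Hypothetical helper: restore model/optimizer state and training progress."""
    checkpoint = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'], checkpoint['best_valid_iou']

# Toy stand-ins for the real network and optimizer.
net = nn.Linear(4, 2)
optimizer = SGD(net.parameters(), lr=0.01, momentum=0.9)

# Write a checkpoint with the same keys the training loop uses.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'model.pt')
torch.save({'epoch': 1,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_valid_iou': 0.61}, ckpt_path)

# Resume: training would continue from `start_epoch`.
start_epoch, best_valid_iou = load_checkpoint(ckpt_path, net, optimizer)
print(start_epoch, best_valid_iou)
```

Depending on the PyTorch version, `torch.load` may by default refuse non-tensor objects such as NumPy scalars; storing `best_valid_iou` as a plain `float` avoids that issue.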

<font color="red">Do not terminate your process right after an epoch has finished.</font> Writing the saved model back to Google Drive takes a couple of extra minutes, and aborting midway will likely corrupt your checkpoint file.
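
One common way to reduce the risk of a half-written checkpoint is to write to a temporary file first and then rename it over the final path. The sketch below illustrates the idea with a hypothetical `atomic_save` helper; `pickle_save` is only a stand-in so the example runs without PyTorch, and in the notebook one would pass `torch.save` as `save_fn` instead. Whether the rename is atomic on a mounted Google Drive folder is not guaranteed, and this does not shorten the time Drive needs to sync the file.

```python
import os
import pickle
import tempfile
from pathlib import Path

def pickle_save(obj, path):
    """Stand-in for torch.save so that this sketch runs without PyTorch."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

def atomic_save(obj, path, save_fn=pickle_save):
    """Hypothetical helper: write to a temporary file, then rename it over `path`."""
    path = Path(path)
    tmp_path = path.with_name(path.name + '.tmp')
    save_fn(obj, tmp_path)
    # On a local filesystem os.replace swaps the file in a single step, so an
    # interrupted write leaves the previous checkpoint untouched.
    os.replace(tmp_path, path)

ckpt_dir = Path(tempfile.mkdtemp())
ckpt_path = ckpt_dir / 'model.pt'
atomic_save({'epoch': 1, 'best_valid_iou': 0.61}, ckpt_path)

with open(ckpt_path, 'rb') as f:
    restored = pickle.load(f)
print(restored)
```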

In [18]:
num_trial = 0
result_dir = Path(root) / 'results'
parent_dir = result_dir / f'trial_{num_trial}'
while parent_dir.is_dir():
    num_trial = int(parent_dir.name.replace('trial_', ''))
    parent_dir = result_dir / f'trial_{num_trial + 1}'

# modify parent_dir here if you want to resume from a checkpoint, or to rename directory.
# parent_dir = result_dir / 'trial_99'
print(f'Logs and ckpts will be saved in : {parent_dir}')

log_dir = parent_dir
ckpt_dir = parent_dir
ckpt_path = parent_dir / 'model.pt'
writer = SummaryWriter(log_dir)

# select model here.
model = Segmenter(use_tf=True).cuda()
# model = Segmenter(use_tf=False).cuda()

# define optimizer.
params = list(model.parameters())
nwd_names = list(model.no_weight_decay())
wd_params = []
nwd_params = []
for n, p in model.named_parameters():
    ignore = False
    for ign in nwd_names:
        if n.startswith(ign):
            nwd_params.append(p)
            ignore = True
            # stop at the first match so a parameter is never added twice.
            break
    if not ignore:
        wd_params.append(p)
optimizer = SGD([{'params': nwd_params, 'weight_decay': 0},
                 {'params': wd_params}],
                lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

Logs and ckpts will be saved in : results/trial_0
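
The optimizer cell above places parameters whose names start with an entry of `model.no_weight_decay()` into a group with `weight_decay=0`, while all other parameters inherit the optimizer-level weight decay. The following sketch reproduces that grouping on a toy model. `TinyModel`, its parameter names, and `split_params` are hypothetical and exist only for this illustration; the hyperparameter values are placeholders rather than the notebook's `args`.

```python
import torch
import torch.nn as nn
from torch.optim import SGD

class TinyModel(nn.Module):
    """Hypothetical toy model; only the no_weight_decay() idea mirrors the notebook."""
    def __init__(self):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, 4, 8))
        self.head = nn.Linear(8, 3)

    def no_weight_decay(self):
        # Name prefixes of parameters that should not be decayed.
        return {'pos_embed'}

def split_params(model):
    """Return (decayed, non-decayed) parameter lists based on name prefixes."""
    nwd_names = tuple(model.no_weight_decay())
    wd_params, nwd_params = [], []
    for name, param in model.named_parameters():
        # str.startswith accepts a tuple, so each parameter lands in exactly one group.
        (nwd_params if name.startswith(nwd_names) else wd_params).append(param)
    return wd_params, nwd_params

model = TinyModel()
wd_params, nwd_params = split_params(model)
optimizer = SGD([{'params': nwd_params, 'weight_decay': 0},
                 {'params': wd_params}],
                lr=0.01, momentum=0.9, weight_decay=1e-4)

# The first group overrides weight decay; the second uses the default.
print([group['weight_decay'] for group in optimizer.param_groups])
```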


In [19]:
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = True
train_net(model, resume=False)

3it [00:00,  5.93it/s]

Epoch : 0 || 0/8498 || loss : 3.518, iou : 0.183 pixel_acc : 1.656


53it [00:03, 20.72it/s]

Epoch : 0 || 50/8498 || loss : 0.413, iou : 32.998 pixel_acc : 98.994


105it [00:05, 22.17it/s]

Epoch : 0 || 100/8498 || loss : 2.020, iou : 13.509 pixel_acc : 40.496


153it [00:07, 21.62it/s]

Epoch : 0 || 150/8498 || loss : 0.486, iou : 59.557 pixel_acc : 93.115


204it [00:10, 20.80it/s]

Epoch : 0 || 200/8498 || loss : 1.886, iou : 5.432 pixel_acc : 37.453


255it [00:12, 20.53it/s]

Epoch : 0 || 250/8498 || loss : 0.833, iou : 21.361 pixel_acc : 85.211


303it [00:14, 20.63it/s]

Epoch : 0 || 300/8498 || loss : 0.498, iou : 17.630 pixel_acc : 91.788


354it [00:17, 21.10it/s]

Epoch : 0 || 350/8498 || loss : 1.159, iou : 15.680 pixel_acc : 62.417


405it [00:19, 21.31it/s]

Epoch : 0 || 400/8498 || loss : 0.739, iou : 27.410 pixel_acc : 83.228


453it [00:21, 20.17it/s]

Epoch : 0 || 450/8498 || loss : 0.616, iou : 14.664 pixel_acc : 85.656


503it [00:24, 18.69it/s]

Epoch : 0 || 500/8498 || loss : 2.880, iou : 6.955 pixel_acc : 29.643


555it [00:26, 22.17it/s]

Epoch : 0 || 550/8498 || loss : 0.778, iou : 43.412 pixel_acc : 86.825


603it [00:29, 20.47it/s]

Epoch : 0 || 600/8498 || loss : 0.676, iou : 10.709 pixel_acc : 84.377


654it [00:31, 21.12it/s]

Epoch : 0 || 650/8498 || loss : 0.681, iou : 18.397 pixel_acc : 85.130


702it [00:33, 20.57it/s]

Epoch : 0 || 700/8498 || loss : 1.075, iou : 17.709 pixel_acc : 66.603


753it [00:36, 20.64it/s]

Epoch : 0 || 750/8498 || loss : 0.533, iou : 17.975 pixel_acc : 89.875


804it [00:38, 22.59it/s]

Epoch : 0 || 800/8498 || loss : 0.706, iou : 44.084 pixel_acc : 94.072


855it [00:40, 21.23it/s]

Epoch : 0 || 850/8498 || loss : 0.531, iou : 20.323 pixel_acc : 87.035


903it [00:43, 21.71it/s]

Epoch : 0 || 900/8498 || loss : 0.387, iou : 12.047 pixel_acc : 91.703


954it [00:45, 21.52it/s]

Epoch : 0 || 950/8498 || loss : 0.270, iou : 31.751 pixel_acc : 93.846


1004it [00:48, 16.60it/s]

Epoch : 0 || 1000/8498 || loss : 0.224, iou : 32.588 pixel_acc : 97.765


1053it [00:51, 16.83it/s]

Epoch : 0 || 1050/8498 || loss : 1.167, iou : 32.466 pixel_acc : 73.294


1104it [00:54, 18.42it/s]

Epoch : 0 || 1100/8498 || loss : 0.562, iou : 35.053 pixel_acc : 90.701


1155it [00:56, 21.37it/s]

Epoch : 0 || 1150/8498 || loss : 0.374, iou : 30.613 pixel_acc : 95.770


1205it [00:59, 20.99it/s]

Epoch : 0 || 1200/8498 || loss : 0.419, iou : 46.369 pixel_acc : 96.124


1254it [01:01, 20.07it/s]

Epoch : 0 || 1250/8498 || loss : 1.905, iou : 17.496 pixel_acc : 42.742


1304it [01:04, 19.54it/s]

Epoch : 0 || 1300/8498 || loss : 0.736, iou : 13.890 pixel_acc : 78.353


1353it [01:06, 18.95it/s]

Epoch : 0 || 1350/8498 || loss : 0.293, iou : 36.677 pixel_acc : 94.027


1404it [01:09, 19.82it/s]

Epoch : 0 || 1400/8498 || loss : 0.535, iou : 41.939 pixel_acc : 92.624


1455it [01:11, 20.97it/s]

Epoch : 0 || 1450/8498 || loss : 0.803, iou : 21.733 pixel_acc : 82.610


1503it [01:13, 22.50it/s]

Epoch : 0 || 1500/8498 || loss : 0.375, iou : 18.547 pixel_acc : 90.226


1553it [01:16, 21.21it/s]

Epoch : 0 || 1550/8498 || loss : 0.243, iou : 84.592 pixel_acc : 96.280


1604it [01:18, 19.92it/s]

Epoch : 0 || 1600/8498 || loss : 0.732, iou : 33.401 pixel_acc : 81.989


1655it [01:21, 21.58it/s]

Epoch : 0 || 1650/8498 || loss : 0.356, iou : 83.892 pixel_acc : 91.696


1703it [01:23, 20.29it/s]

Epoch : 0 || 1700/8498 || loss : 0.376, iou : 30.064 pixel_acc : 91.265


1754it [01:25, 20.46it/s]

Epoch : 0 || 1750/8498 || loss : 0.636, iou : 21.927 pixel_acc : 84.012


1805it [01:28, 21.22it/s]

Epoch : 0 || 1800/8498 || loss : 0.692, iou : 23.449 pixel_acc : 81.994


1853it [01:30, 20.33it/s]

Epoch : 0 || 1850/8498 || loss : 0.370, iou : 89.927 pixel_acc : 94.978


1904it [01:32, 21.49it/s]

Epoch : 0 || 1900/8498 || loss : 0.278, iou : 94.189 pixel_acc : 97.011


1955it [01:35, 20.50it/s]

Epoch : 0 || 1950/8498 || loss : 0.515, iou : 79.811 pixel_acc : 88.795


2003it [01:37, 20.59it/s]

Epoch : 0 || 2000/8498 || loss : 0.389, iou : 45.262 pixel_acc : 94.977


2054it [01:40, 22.16it/s]

Epoch : 0 || 2050/8498 || loss : 0.337, iou : 91.425 pixel_acc : 95.541


2102it [01:42, 21.53it/s]

Epoch : 0 || 2100/8498 || loss : 0.639, iou : 33.620 pixel_acc : 87.001


2153it [01:44, 19.91it/s]

Epoch : 0 || 2150/8498 || loss : 0.369, iou : 82.390 pixel_acc : 91.434


2204it [01:47, 21.03it/s]

Epoch : 0 || 2200/8498 || loss : 0.246, iou : 39.514 pixel_acc : 95.548


2255it [01:49, 21.79it/s]

Epoch : 0 || 2250/8498 || loss : 0.476, iou : 44.002 pixel_acc : 93.323


2303it [01:51, 21.29it/s]

Epoch : 0 || 2300/8498 || loss : 0.331, iou : 59.779 pixel_acc : 93.591


2354it [01:54, 22.64it/s]

Epoch : 0 || 2350/8498 || loss : 0.333, iou : 19.321 pixel_acc : 89.297


2402it [01:56, 20.35it/s]

Epoch : 0 || 2400/8498 || loss : 0.496, iou : 75.891 pixel_acc : 86.295


2453it [01:58, 20.91it/s]

Epoch : 0 || 2450/8498 || loss : 0.113, iou : 32.882 pixel_acc : 98.646


2504it [02:01, 20.51it/s]

Epoch : 0 || 2500/8498 || loss : 0.216, iou : 31.339 pixel_acc : 97.698


2555it [02:03, 20.68it/s]

Epoch : 0 || 2550/8498 || loss : 0.115, iou : 59.717 pixel_acc : 99.362


2603it [02:06, 20.48it/s]

Epoch : 0 || 2600/8498 || loss : 0.370, iou : 38.377 pixel_acc : 92.838


2654it [02:08, 20.77it/s]

Epoch : 0 || 2650/8498 || loss : 1.861, iou : 5.679 pixel_acc : 34.677


2702it [02:10, 20.68it/s]

Epoch : 0 || 2700/8498 || loss : 0.949, iou : 21.995 pixel_acc : 73.857


2753it [02:13, 20.88it/s]

Epoch : 0 || 2750/8498 || loss : 0.344, iou : 42.682 pixel_acc : 93.551


2804it [02:15, 21.10it/s]

Epoch : 0 || 2800/8498 || loss : 0.456, iou : 54.440 pixel_acc : 89.714


2854it [02:18, 20.97it/s]

Epoch : 0 || 2850/8498 || loss : 0.542, iou : 56.229 pixel_acc : 86.808


2905it [02:20, 31.29it/s]

Epoch : 0 || 2900/8498 || loss : 0.755, iou : 42.206 pixel_acc : 67.838


2957it [02:21, 32.49it/s]

Epoch : 0 || 2950/8498 || loss : 0.109, iou : 42.312 pixel_acc : 98.499


3005it [02:23, 33.04it/s]

Epoch : 0 || 3000/8498 || loss : 1.127, iou : 35.637 pixel_acc : 68.783


3057it [02:24, 32.58it/s]

Epoch : 0 || 3050/8498 || loss : 0.475, iou : 79.605 pixel_acc : 88.765


3105it [02:26, 32.16it/s]

Epoch : 0 || 3100/8498 || loss : 0.189, iou : 90.730 pixel_acc : 97.210


3157it [02:27, 33.28it/s]

Epoch : 0 || 3150/8498 || loss : 0.588, iou : 23.728 pixel_acc : 89.957


3205it [02:29, 33.50it/s]

Epoch : 0 || 3200/8498 || loss : 0.463, iou : 59.695 pixel_acc : 91.688


3253it [02:30, 33.49it/s]

Epoch : 0 || 3250/8498 || loss : 0.316, iou : 34.469 pixel_acc : 94.084


3305it [02:33, 21.99it/s]

Epoch : 0 || 3300/8498 || loss : 0.756, iou : 24.188 pixel_acc : 75.405


3353it [02:35, 22.22it/s]

Epoch : 0 || 3350/8498 || loss : 0.890, iou : 55.780 pixel_acc : 81.156


3404it [02:37, 20.86it/s]

Epoch : 0 || 3400/8498 || loss : 0.423, iou : 33.324 pixel_acc : 93.212


3455it [02:40, 20.16it/s]

Epoch : 0 || 3450/8498 || loss : 0.444, iou : 41.789 pixel_acc : 91.943


3503it [02:42, 21.69it/s]

Epoch : 0 || 3500/8498 || loss : 0.461, iou : 33.808 pixel_acc : 80.927


3554it [02:44, 20.87it/s]

Epoch : 0 || 3550/8498 || loss : 0.571, iou : 51.475 pixel_acc : 85.839


3605it [02:47, 21.04it/s]

Epoch : 0 || 3600/8498 || loss : 0.543, iou : 52.471 pixel_acc : 87.678


3653it [02:49, 20.68it/s]

Epoch : 0 || 3650/8498 || loss : 0.540, iou : 49.676 pixel_acc : 85.392


3704it [02:51, 20.25it/s]

Epoch : 0 || 3700/8498 || loss : 0.121, iou : 52.640 pixel_acc : 97.361


3755it [02:54, 20.86it/s]

Epoch : 0 || 3750/8498 || loss : 0.593, iou : 44.683 pixel_acc : 76.714


3803it [02:56, 20.71it/s]

Epoch : 0 || 3800/8498 || loss : 0.658, iou : 33.056 pixel_acc : 76.759


3854it [02:59, 21.68it/s]

Epoch : 0 || 3850/8498 || loss : 0.365, iou : 15.829 pixel_acc : 88.915


3903it [03:01, 19.18it/s]

Epoch : 0 || 3900/8498 || loss : 0.445, iou : 39.976 pixel_acc : 89.384


3953it [03:03, 21.21it/s]

Epoch : 0 || 3950/8498 || loss : 0.645, iou : 72.249 pixel_acc : 83.982


4004it [03:06, 19.64it/s]

Epoch : 0 || 4000/8498 || loss : 0.218, iou : 31.826 pixel_acc : 95.382


4053it [03:08, 20.60it/s]

Epoch : 0 || 4050/8498 || loss : 0.223, iou : 91.948 pixel_acc : 96.778


4104it [03:11, 19.76it/s]

Epoch : 0 || 4100/8498 || loss : 0.753, iou : 37.517 pixel_acc : 69.941


4154it [03:14, 18.98it/s]

Epoch : 0 || 4150/8498 || loss : 0.368, iou : 44.450 pixel_acc : 89.495


4204it [03:16, 19.55it/s]

Epoch : 0 || 4200/8498 || loss : 0.928, iou : 33.000 pixel_acc : 79.123


4254it [03:18, 20.59it/s]

Epoch : 0 || 4250/8498 || loss : 0.725, iou : 47.198 pixel_acc : 81.707


4303it [03:21, 20.13it/s]

Epoch : 0 || 4300/8498 || loss : 0.676, iou : 52.133 pixel_acc : 79.464


4354it [03:24, 17.83it/s]

Epoch : 0 || 4350/8498 || loss : 0.345, iou : 68.899 pixel_acc : 91.577


4404it [03:26, 17.48it/s]

Epoch : 0 || 4400/8498 || loss : 0.259, iou : 45.710 pixel_acc : 96.142


4454it [03:29, 17.42it/s]

Epoch : 0 || 4450/8498 || loss : 0.284, iou : 35.340 pixel_acc : 91.469


4504it [03:32, 15.42it/s]

Epoch : 0 || 4500/8498 || loss : 0.455, iou : 49.115 pixel_acc : 89.960


4554it [03:35, 17.43it/s]

Epoch : 0 || 4550/8498 || loss : 0.432, iou : 44.402 pixel_acc : 87.329


4604it [03:38, 17.03it/s]

Epoch : 0 || 4600/8498 || loss : 0.471, iou : 41.991 pixel_acc : 90.763


4654it [03:41, 17.43it/s]

Epoch : 0 || 4650/8498 || loss : 0.605, iou : 8.056 pixel_acc : 88.611


4704it [03:44, 16.74it/s]

Epoch : 0 || 4700/8498 || loss : 0.396, iou : 43.388 pixel_acc : 88.480


4754it [03:47, 17.52it/s]

Epoch : 0 || 4750/8498 || loss : 0.635, iou : 44.532 pixel_acc : 80.520


4804it [03:50, 16.88it/s]

Epoch : 0 || 4800/8498 || loss : 0.287, iou : 40.194 pixel_acc : 93.771


4854it [03:53, 17.27it/s]

Epoch : 0 || 4850/8498 || loss : 1.055, iou : 19.081 pixel_acc : 62.236


4904it [03:56, 16.77it/s]

Epoch : 0 || 4900/8498 || loss : 0.464, iou : 28.418 pixel_acc : 88.940


4954it [03:59, 16.85it/s]

Epoch : 0 || 4950/8498 || loss : 0.314, iou : 91.537 pixel_acc : 95.837


5004it [04:02, 16.53it/s]

Epoch : 0 || 5000/8498 || loss : 0.442, iou : 39.907 pixel_acc : 86.344


5053it [04:05, 19.95it/s]

Epoch : 0 || 5050/8498 || loss : 0.330, iou : 89.422 pixel_acc : 94.425


5103it [04:07, 19.81it/s]

Epoch : 0 || 5100/8498 || loss : 2.613, iou : 4.690 pixel_acc : 31.974


5154it [04:09, 22.87it/s]

Epoch : 0 || 5150/8498 || loss : 0.493, iou : 33.580 pixel_acc : 85.950


5203it [04:12, 19.25it/s]

Epoch : 0 || 5200/8498 || loss : 0.401, iou : 60.319 pixel_acc : 91.139


5253it [04:14, 20.28it/s]

Epoch : 0 || 5250/8498 || loss : 0.510, iou : 52.212 pixel_acc : 87.912


5304it [04:17, 21.88it/s]

Epoch : 0 || 5300/8498 || loss : 0.389, iou : 29.057 pixel_acc : 87.965


5355it [04:19, 21.59it/s]

Epoch : 0 || 5350/8498 || loss : 0.199, iou : 59.509 pixel_acc : 96.476


5404it [04:22, 17.75it/s]

Epoch : 0 || 5400/8498 || loss : 0.435, iou : 22.511 pixel_acc : 87.409


5454it [04:25, 16.87it/s]

Epoch : 0 || 5450/8498 || loss : 0.219, iou : 75.858 pixel_acc : 95.034


5504it [04:28, 16.48it/s]

Epoch : 0 || 5500/8498 || loss : 0.331, iou : 61.466 pixel_acc : 90.399


5554it [04:30, 16.45it/s]

Epoch : 0 || 5550/8498 || loss : 0.433, iou : 31.384 pixel_acc : 87.754


5604it [04:33, 16.58it/s]

Epoch : 0 || 5600/8498 || loss : 0.290, iou : 95.469 pixel_acc : 97.813


5653it [04:36, 18.37it/s]

Epoch : 0 || 5650/8498 || loss : 0.190, iou : 59.870 pixel_acc : 96.881


5702it [04:39, 20.70it/s]

Epoch : 0 || 5700/8498 || loss : 1.165, iou : 23.547 pixel_acc : 67.414


5752it [04:41, 21.05it/s]

Epoch : 0 || 5750/8498 || loss : 0.432, iou : 40.434 pixel_acc : 90.796


5804it [04:44, 19.65it/s]

Epoch : 0 || 5800/8498 || loss : 0.625, iou : 74.119 pixel_acc : 85.168


5855it [04:46, 21.33it/s]

Epoch : 0 || 5850/8498 || loss : 0.481, iou : 15.640 pixel_acc : 88.224


5903it [04:49, 20.99it/s]

Epoch : 0 || 5900/8498 || loss : 0.505, iou : 67.204 pixel_acc : 83.913


5954it [04:51, 20.69it/s]

Epoch : 0 || 5950/8498 || loss : 0.456, iou : 26.019 pixel_acc : 83.967


6004it [04:54, 20.39it/s]

Epoch : 0 || 6000/8498 || loss : 0.285, iou : 86.539 pixel_acc : 93.564


6052it [04:56, 18.02it/s]

Epoch : 0 || 6050/8498 || loss : 0.087, iou : 87.896 pixel_acc : 97.757


6104it [04:59, 18.64it/s]

Epoch : 0 || 6100/8498 || loss : 0.230, iou : 30.790 pixel_acc : 92.328


6153it [05:01, 20.31it/s]

Epoch : 0 || 6150/8498 || loss : 0.464, iou : 43.406 pixel_acc : 82.789


6204it [05:03, 21.51it/s]

Epoch : 0 || 6200/8498 || loss : 0.299, iou : 83.757 pixel_acc : 91.168


6252it [05:06, 20.34it/s]

Epoch : 0 || 6250/8498 || loss : 0.197, iou : 93.072 pixel_acc : 97.540


6304it [05:08, 20.14it/s]

Epoch : 0 || 6300/8498 || loss : 0.154, iou : 48.895 pixel_acc : 97.790


6352it [05:11, 19.31it/s]

Epoch : 0 || 6350/8498 || loss : 0.211, iou : 91.717 pixel_acc : 96.703


6404it [05:13, 19.77it/s]

Epoch : 0 || 6400/8498 || loss : 0.152, iou : 44.962 pixel_acc : 95.793


6452it [05:16, 20.02it/s]

Epoch : 0 || 6450/8498 || loss : 0.693, iou : 48.121 pixel_acc : 82.869


6504it [05:18, 19.70it/s]

Epoch : 0 || 6500/8498 || loss : 0.153, iou : 72.567 pixel_acc : 94.864


6554it [05:21, 19.91it/s]

Epoch : 0 || 6550/8498 || loss : 0.572, iou : 19.245 pixel_acc : 81.805


6604it [05:23, 19.78it/s]

Epoch : 0 || 6600/8498 || loss : 0.125, iou : 94.878 pixel_acc : 98.642


6653it [05:26, 20.49it/s]

Epoch : 0 || 6650/8498 || loss : 0.830, iou : 27.908 pixel_acc : 69.945


6704it [05:28, 20.57it/s]

Epoch : 0 || 6700/8498 || loss : 0.154, iou : 91.432 pixel_acc : 96.786


6754it [05:31, 20.57it/s]

Epoch : 0 || 6750/8498 || loss : 0.178, iou : 94.658 pixel_acc : 97.302


6805it [05:33, 22.59it/s]

Epoch : 0 || 6800/8498 || loss : 0.140, iou : 52.468 pixel_acc : 96.455


6855it [05:36, 18.97it/s]

Epoch : 0 || 6850/8498 || loss : 0.380, iou : 22.939 pixel_acc : 89.940


6904it [05:38, 20.13it/s]

Epoch : 0 || 6900/8498 || loss : 0.256, iou : 87.175 pixel_acc : 93.237


6954it [05:41, 19.76it/s]

Epoch : 0 || 6950/8498 || loss : 0.651, iou : 28.788 pixel_acc : 78.536


7004it [05:43, 19.79it/s]

Epoch : 0 || 7000/8498 || loss : 0.228, iou : 75.493 pixel_acc : 93.458


7053it [05:46, 20.71it/s]

Epoch : 0 || 7050/8498 || loss : 0.933, iou : 43.908 pixel_acc : 76.716


7104it [05:48, 20.71it/s]

Epoch : 0 || 7100/8498 || loss : 0.398, iou : 55.884 pixel_acc : 91.048


7154it [05:50, 20.04it/s]

Epoch : 0 || 7150/8498 || loss : 0.179, iou : 92.322 pixel_acc : 96.678


7205it [05:53, 20.09it/s]

Epoch : 0 || 7200/8498 || loss : 0.292, iou : 53.218 pixel_acc : 93.731


7254it [05:55, 19.71it/s]

Epoch : 0 || 7250/8498 || loss : 0.088, iou : 24.757 pixel_acc : 99.026


7303it [05:58, 17.05it/s]

Epoch : 0 || 7300/8498 || loss : 0.426, iou : 86.000 pixel_acc : 92.482


7353it [06:01, 16.68it/s]

Epoch : 0 || 7350/8498 || loss : 3.273, iou : 23.014 pixel_acc : 21.320


7403it [06:04, 16.90it/s]

Epoch : 0 || 7400/8498 || loss : 0.576, iou : 42.087 pixel_acc : 80.074


7453it [06:07, 16.93it/s]

Epoch : 0 || 7450/8498 || loss : 0.373, iou : 83.903 pixel_acc : 91.615


7503it [06:10, 16.91it/s]

Epoch : 0 || 7500/8498 || loss : 0.229, iou : 88.834 pixel_acc : 94.271


7553it [06:13, 17.45it/s]

Epoch : 0 || 7550/8498 || loss : 0.329, iou : 57.880 pixel_acc : 93.085


7604it [06:16, 18.62it/s]

Epoch : 0 || 7600/8498 || loss : 0.266, iou : 55.078 pixel_acc : 93.763


7654it [06:19, 17.37it/s]

Epoch : 0 || 7650/8498 || loss : 0.433, iou : 81.930 pixel_acc : 92.883


7705it [06:22, 20.01it/s]

Epoch : 0 || 7700/8498 || loss : 0.104, iou : 48.891 pixel_acc : 97.782


7753it [06:24, 17.68it/s]

Epoch : 0 || 7750/8498 || loss : 0.215, iou : 92.343 pixel_acc : 96.387


7805it [06:27, 20.49it/s]

Epoch : 0 || 7800/8498 || loss : 0.296, iou : 58.680 pixel_acc : 93.768


7854it [06:30, 19.75it/s]

Epoch : 0 || 7850/8498 || loss : 0.653, iou : 31.224 pixel_acc : 78.793


7904it [06:32, 19.55it/s]

Epoch : 0 || 7900/8498 || loss : 0.379, iou : 50.845 pixel_acc : 89.826


7953it [06:35, 19.65it/s]

Epoch : 0 || 7950/8498 || loss : 0.192, iou : 45.860 pixel_acc : 95.329


8004it [06:37, 20.34it/s]

Epoch : 0 || 8000/8498 || loss : 0.658, iou : 23.024 pixel_acc : 74.238


8055it [06:40, 21.15it/s]

Epoch : 0 || 8050/8498 || loss : 0.653, iou : 22.607 pixel_acc : 75.638


8103it [06:42, 20.68it/s]

Epoch : 0 || 8100/8498 || loss : 0.309, iou : 55.085 pixel_acc : 91.762


8153it [06:44, 19.77it/s]

Epoch : 0 || 8150/8498 || loss : 0.289, iou : 93.944 pixel_acc : 96.887


8203it [06:47, 20.48it/s]

Epoch : 0 || 8200/8498 || loss : 0.055, iou : 71.015 pixel_acc : 98.591


8255it [06:49, 21.44it/s]

Epoch : 0 || 8250/8498 || loss : 0.058, iou : 88.214 pixel_acc : 98.313


8303it [06:52, 20.37it/s]

Epoch : 0 || 8300/8498 || loss : 0.672, iou : 22.149 pixel_acc : 73.756


8352it [06:54, 14.86it/s]

Epoch : 0 || 8350/8498 || loss : 0.326, iou : 77.645 pixel_acc : 88.813


8402it [06:58, 15.40it/s]

Epoch : 0 || 8400/8498 || loss : 0.133, iou : 83.315 pixel_acc : 96.595


8454it [07:01, 15.07it/s]

Epoch : 0 || 8450/8498 || loss : 0.029, iou : 58.418 pixel_acc : 99.332


8498it [07:03, 20.05it/s]


>> Epoch : 0 || AVG loss : 0.548, iou : 48.268 pixel_acc : 86.137 424.377 secs


2857it [00:58, 48.52it/s]


>> Epoch : 0 || AVG valid loss : 0.367, iou : 61.005 pixel_acc : 89.607 424.377 secs


4it [00:00, 16.24it/s]

Epoch : 1 || 0/8498 || loss : 0.295, iou : 61.655 pixel_acc : 95.353


54it [00:03, 15.59it/s]

Epoch : 1 || 50/8498 || loss : 0.784, iou : 33.783 pixel_acc : 68.728


104it [00:06, 15.74it/s]

Epoch : 1 || 100/8498 || loss : 0.713, iou : 34.912 pixel_acc : 72.278


154it [00:09, 17.21it/s]

Epoch : 1 || 150/8498 || loss : 0.267, iou : 20.728 pixel_acc : 90.657


203it [00:12, 17.14it/s]

Epoch : 1 || 200/8498 || loss : 0.122, iou : 95.755 pixel_acc : 98.087


253it [00:15, 16.57it/s]

Epoch : 1 || 250/8498 || loss : 0.343, iou : 33.311 pixel_acc : 85.636


303it [00:18, 16.91it/s]

Epoch : 1 || 300/8498 || loss : 0.358, iou : 60.194 pixel_acc : 94.203


353it [00:21, 17.49it/s]

Epoch : 1 || 350/8498 || loss : 0.404, iou : 29.552 pixel_acc : 91.783


403it [00:24, 16.62it/s]

Epoch : 1 || 400/8498 || loss : 0.557, iou : 53.153 pixel_acc : 74.696


453it [00:27, 16.77it/s]

Epoch : 1 || 450/8498 || loss : 0.155, iou : 63.066 pixel_acc : 96.922


503it [00:30, 15.63it/s]

Epoch : 1 || 500/8498 || loss : 0.669, iou : 52.496 pixel_acc : 79.333


553it [00:33, 15.58it/s]

Epoch : 1 || 550/8498 || loss : 0.767, iou : 36.821 pixel_acc : 84.792


603it [00:36, 16.28it/s]

Epoch : 1 || 600/8498 || loss : 0.345, iou : 22.588 pixel_acc : 88.204


653it [00:39, 16.74it/s]

Epoch : 1 || 650/8498 || loss : 0.424, iou : 24.206 pixel_acc : 85.796


704it [00:42, 17.15it/s]

Epoch : 1 || 700/8498 || loss : 0.109, iou : 80.685 pixel_acc : 97.000


754it [00:45, 15.83it/s]

Epoch : 1 || 750/8498 || loss : 0.306, iou : 91.423 pixel_acc : 96.105


804it [00:48, 16.77it/s]

Epoch : 1 || 800/8498 || loss : 0.246, iou : 95.318 pixel_acc : 97.630


854it [00:51, 16.32it/s]

Epoch : 1 || 850/8498 || loss : 0.238, iou : 55.819 pixel_acc : 94.116


904it [00:54, 16.72it/s]

Epoch : 1 || 900/8498 || loss : 0.560, iou : 70.583 pixel_acc : 83.850


954it [00:57, 16.58it/s]

Epoch : 1 || 950/8498 || loss : 0.243, iou : 87.788 pixel_acc : 93.715


1004it [01:00, 17.51it/s]

Epoch : 1 || 1000/8498 || loss : 0.142, iou : 42.068 pixel_acc : 97.143


1054it [01:03, 16.93it/s]

Epoch : 1 || 1050/8498 || loss : 0.099, iou : 93.158 pixel_acc : 98.090


1104it [01:06, 17.13it/s]

Epoch : 1 || 1100/8498 || loss : 0.830, iou : 30.199 pixel_acc : 62.134


1154it [01:09, 16.54it/s]

Epoch : 1 || 1150/8498 || loss : 0.154, iou : 92.848 pixel_acc : 96.349


1204it [01:12, 17.29it/s]

Epoch : 1 || 1200/8498 || loss : 0.712, iou : 65.527 pixel_acc : 78.944


1254it [01:15, 17.16it/s]

Epoch : 1 || 1250/8498 || loss : 0.168, iou : 84.828 pixel_acc : 95.626


1303it [01:17, 17.32it/s]

Epoch : 1 || 1300/8498 || loss : 0.279, iou : 93.067 pixel_acc : 96.880


1353it [01:20, 16.45it/s]

Epoch : 1 || 1350/8498 || loss : 0.050, iou : 49.606 pixel_acc : 99.211


1403it [01:23, 16.48it/s]

Epoch : 1 || 1400/8498 || loss : 0.401, iou : 40.039 pixel_acc : 87.920


1453it [01:26, 17.09it/s]

Epoch : 1 || 1450/8498 || loss : 0.266, iou : 56.516 pixel_acc : 91.473


1503it [01:29, 16.90it/s]

Epoch : 1 || 1500/8498 || loss : 0.580, iou : 37.457 pixel_acc : 82.414


1553it [01:32, 17.06it/s]

Epoch : 1 || 1550/8498 || loss : 0.489, iou : 50.969 pixel_acc : 82.309


1603it [01:35, 17.08it/s]

Epoch : 1 || 1600/8498 || loss : 0.069, iou : 93.682 pixel_acc : 98.659


1653it [01:38, 16.37it/s]

Epoch : 1 || 1650/8498 || loss : 0.219, iou : 53.568 pixel_acc : 94.589


1703it [01:41, 15.55it/s]

Epoch : 1 || 1700/8498 || loss : 0.657, iou : 37.735 pixel_acc : 79.105


1753it [01:45, 16.06it/s]

Epoch : 1 || 1750/8498 || loss : 1.463, iou : 27.323 pixel_acc : 71.764


1803it [01:48, 15.40it/s]

Epoch : 1 || 1800/8498 || loss : 0.028, iou : 86.213 pixel_acc : 99.392


1853it [01:51, 15.40it/s]

Epoch : 1 || 1850/8498 || loss : 0.233, iou : 62.517 pixel_acc : 95.729


1903it [01:54, 15.55it/s]

Epoch : 1 || 1900/8498 || loss : 0.377, iou : 72.157 pixel_acc : 85.238


1953it [01:57, 15.50it/s]

Epoch : 1 || 1950/8498 || loss : 0.108, iou : 94.696 pixel_acc : 97.508


2003it [02:01, 15.72it/s]

Epoch : 1 || 2000/8498 || loss : 0.037, iou : 49.529 pixel_acc : 99.059


2053it [02:04, 15.90it/s]

Epoch : 1 || 2050/8498 || loss : 0.089, iou : 86.486 pixel_acc : 97.336


2103it [02:07, 15.81it/s]

Epoch : 1 || 2100/8498 || loss : 0.072, iou : 79.103 pixel_acc : 97.731


2153it [02:10, 15.81it/s]

Epoch : 1 || 2150/8498 || loss : 0.294, iou : 53.563 pixel_acc : 94.835


2203it [02:13, 15.97it/s]

Epoch : 1 || 2200/8498 || loss : 0.241, iou : 89.254 pixel_acc : 94.980


2253it [02:16, 15.33it/s]

Epoch : 1 || 2250/8498 || loss : 0.249, iou : 59.162 pixel_acc : 95.759


2303it [02:19, 15.70it/s]

Epoch : 1 || 2300/8498 || loss : 0.193, iou : 89.325 pixel_acc : 95.308


2353it [02:23, 15.76it/s]

Epoch : 1 || 2350/8498 || loss : 0.577, iou : 64.614 pixel_acc : 82.445


2403it [02:26, 15.76it/s]

Epoch : 1 || 2400/8498 || loss : 0.245, iou : 91.316 pixel_acc : 95.814


2453it [02:29, 15.31it/s]

Epoch : 1 || 2450/8498 || loss : 0.265, iou : 85.468 pixel_acc : 92.508


2503it [02:32, 15.73it/s]

Epoch : 1 || 2500/8498 || loss : 0.203, iou : 45.432 pixel_acc : 93.899


2553it [02:36, 15.59it/s]

Epoch : 1 || 2550/8498 || loss : 0.461, iou : 70.001 pixel_acc : 89.183


2603it [02:39, 15.44it/s]

Epoch : 1 || 2600/8498 || loss : 0.187, iou : 62.806 pixel_acc : 96.386


2653it [02:42, 15.49it/s]

Epoch : 1 || 2650/8498 || loss : 0.228, iou : 91.012 pixel_acc : 95.590


2703it [02:45, 15.89it/s]

Epoch : 1 || 2700/8498 || loss : 0.084, iou : 90.791 pixel_acc : 97.731


2753it [02:48, 17.76it/s]

Epoch : 1 || 2750/8498 || loss : 0.276, iou : 83.650 pixel_acc : 91.789


2803it [02:51, 16.79it/s]

Epoch : 1 || 2800/8498 || loss : 0.272, iou : 84.884 pixel_acc : 94.294


2853it [02:54, 16.94it/s]

Epoch : 1 || 2850/8498 || loss : 0.734, iou : 40.977 pixel_acc : 74.556


2903it [02:57, 16.72it/s]

Epoch : 1 || 2900/8498 || loss : 0.302, iou : 55.249 pixel_acc : 90.367


2953it [03:00, 17.42it/s]

Epoch : 1 || 2950/8498 || loss : 0.027, iou : 68.760 pixel_acc : 99.729


3003it [03:03, 17.04it/s]

Epoch : 1 || 3000/8498 || loss : 0.337, iou : 71.201 pixel_acc : 88.759


3053it [03:06, 17.25it/s]

Epoch : 1 || 3050/8498 || loss : 0.180, iou : 78.216 pixel_acc : 95.437


3103it [03:09, 17.04it/s]

Epoch : 1 || 3100/8498 || loss : 0.168, iou : 93.445 pixel_acc : 96.621


3153it [03:12, 16.61it/s]

Epoch : 1 || 3150/8498 || loss : 0.203, iou : 92.387 pixel_acc : 96.229


3203it [03:15, 15.76it/s]

Epoch : 1 || 3200/8498 || loss : 0.067, iou : 96.327 pixel_acc : 98.572


3253it [03:18, 16.98it/s]

Epoch : 1 || 3250/8498 || loss : 0.104, iou : 91.309 pixel_acc : 97.470


3304it [03:21, 16.51it/s]

Epoch : 1 || 3300/8498 || loss : 0.247, iou : 59.047 pixel_acc : 94.165


3354it [03:24, 16.62it/s]

Epoch : 1 || 3350/8498 || loss : 0.158, iou : 90.131 pixel_acc : 97.384


3403it [03:27, 17.65it/s]

Epoch : 1 || 3400/8498 || loss : 0.387, iou : 44.825 pixel_acc : 91.003


3453it [03:30, 16.64it/s]

Epoch : 1 || 3450/8498 || loss : 0.660, iou : 65.408 pixel_acc : 82.743


3503it [03:33, 16.93it/s]

Epoch : 1 || 3500/8498 || loss : 0.063, iou : 67.771 pixel_acc : 98.604


3553it [03:36, 16.64it/s]

Epoch : 1 || 3550/8498 || loss : 0.423, iou : 85.305 pixel_acc : 92.118


3603it [03:39, 16.53it/s]

Epoch : 1 || 3600/8498 || loss : 0.183, iou : 91.911 pixel_acc : 95.799


3653it [03:42, 16.73it/s]

Epoch : 1 || 3650/8498 || loss : 0.222, iou : 60.118 pixel_acc : 94.599


3703it [03:45, 16.60it/s]

Epoch : 1 || 3700/8498 || loss : 0.116, iou : 94.632 pixel_acc : 97.348


3753it [03:48, 16.75it/s]

Epoch : 1 || 3750/8498 || loss : 0.756, iou : 42.040 pixel_acc : 87.347


3803it [03:51, 16.67it/s]

Epoch : 1 || 3800/8498 || loss : 0.221, iou : 43.822 pixel_acc : 95.863


3853it [03:54, 16.62it/s]

Epoch : 1 || 3850/8498 || loss : 0.065, iou : 81.407 pixel_acc : 97.907


3903it [03:57, 16.11it/s]

Epoch : 1 || 3900/8498 || loss : 0.341, iou : 77.221 pixel_acc : 89.078


3955it [04:00, 19.55it/s]

Epoch : 1 || 3950/8498 || loss : 0.462, iou : 67.931 pixel_acc : 82.857


4003it [04:02, 22.53it/s]

Epoch : 1 || 4000/8498 || loss : 0.396, iou : 76.491 pixel_acc : 88.915


4054it [04:04, 23.43it/s]

Epoch : 1 || 4050/8498 || loss : 0.121, iou : 94.293 pixel_acc : 97.066


4105it [04:07, 21.17it/s]

Epoch : 1 || 4100/8498 || loss : 0.070, iou : 92.449 pixel_acc : 99.170


4153it [04:09, 21.57it/s]

Epoch : 1 || 4150/8498 || loss : 0.124, iou : 92.394 pixel_acc : 97.540


4203it [04:12, 15.35it/s]

Epoch : 1 || 4200/8498 || loss : 0.048, iou : 28.564 pixel_acc : 99.020


4253it [04:15, 16.16it/s]

Epoch : 1 || 4250/8498 || loss : 0.390, iou : 47.673 pixel_acc : 87.940


4303it [04:18, 16.04it/s]

Epoch : 1 || 4300/8498 || loss : 0.080, iou : 93.604 pixel_acc : 98.050


4353it [04:22, 15.87it/s]

Epoch : 1 || 4350/8498 || loss : 0.163, iou : 87.270 pixel_acc : 95.561


4403it [04:25, 15.70it/s]

Epoch : 1 || 4400/8498 || loss : 0.113, iou : 94.904 pixel_acc : 98.073


4453it [04:28, 15.36it/s]

Epoch : 1 || 4450/8498 || loss : 0.447, iou : 74.254 pixel_acc : 85.534


4503it [04:31, 16.08it/s]

Epoch : 1 || 4500/8498 || loss : 0.251, iou : 94.848 pixel_acc : 97.715


4553it [04:34, 15.90it/s]

Epoch : 1 || 4550/8498 || loss : 0.192, iou : 48.715 pixel_acc : 92.594


4603it [04:37, 15.76it/s]

Epoch : 1 || 4600/8498 || loss : 1.706, iou : 25.691 pixel_acc : 55.679


4653it [04:40, 15.67it/s]

Epoch : 1 || 4650/8498 || loss : 0.202, iou : 33.834 pixel_acc : 95.984


4703it [04:44, 15.65it/s]

Epoch : 1 || 4700/8498 || loss : 0.347, iou : 77.791 pixel_acc : 89.715


4753it [04:47, 16.01it/s]

Epoch : 1 || 4750/8498 || loss : 0.249, iou : 60.867 pixel_acc : 96.355


4803it [04:50, 14.98it/s]

Epoch : 1 || 4800/8498 || loss : 0.050, iou : 33.124 pixel_acc : 99.134


4853it [04:53, 14.22it/s]

Epoch : 1 || 4850/8498 || loss : 0.431, iou : 82.456 pixel_acc : 91.535


4903it [04:57, 14.75it/s]

Epoch : 1 || 4900/8498 || loss : 0.275, iou : 47.097 pixel_acc : 90.871


4955it [05:00, 24.07it/s]

Epoch : 1 || 4950/8498 || loss : 0.163, iou : 83.727 pixel_acc : 96.142


5007it [05:02, 33.24it/s]

Epoch : 1 || 5000/8498 || loss : 0.256, iou : 59.184 pixel_acc : 92.699


5055it [05:03, 33.62it/s]

Epoch : 1 || 5050/8498 || loss : 0.170, iou : 92.048 pixel_acc : 95.860


5103it [05:06, 14.67it/s]

Epoch : 1 || 5100/8498 || loss : 0.037, iou : 52.244 pixel_acc : 99.025


5153it [05:09, 15.86it/s]

Epoch : 1 || 5150/8498 || loss : 0.452, iou : 31.297 pixel_acc : 87.531


5203it [05:12, 16.08it/s]

Epoch : 1 || 5200/8498 || loss : 0.380, iou : 42.258 pixel_acc : 91.054


5253it [05:15, 19.01it/s]

Epoch : 1 || 5250/8498 || loss : 0.573, iou : 23.419 pixel_acc : 79.835


5303it [05:18, 16.09it/s]

Epoch : 1 || 5300/8498 || loss : 0.025, iou : 68.509 pixel_acc : 99.413


5353it [05:21, 15.90it/s]

Epoch : 1 || 5350/8498 || loss : 0.912, iou : 40.638 pixel_acc : 76.126


5403it [05:23, 21.75it/s]

Epoch : 1 || 5400/8498 || loss : 0.215, iou : 53.827 pixel_acc : 96.724


5454it [05:26, 17.16it/s]

Epoch : 1 || 5450/8498 || loss : 0.180, iou : 94.925 pixel_acc : 97.773


5503it [05:29, 17.44it/s]

Epoch : 1 || 5500/8498 || loss : 0.199, iou : 81.640 pixel_acc : 94.490


5554it [05:32, 17.07it/s]

Epoch : 1 || 5550/8498 || loss : 0.234, iou : 90.743 pixel_acc : 95.162


5603it [05:35, 16.72it/s]

Epoch : 1 || 5600/8498 || loss : 1.320, iou : 47.069 pixel_acc : 65.734


5653it [05:38, 17.18it/s]

Epoch : 1 || 5650/8498 || loss : 0.182, iou : 90.763 pixel_acc : 95.666


5703it [05:41, 17.12it/s]

Epoch : 1 || 5700/8498 || loss : 0.148, iou : 61.787 pixel_acc : 97.966


5753it [05:44, 16.32it/s]

Epoch : 1 || 5750/8498 || loss : 0.368, iou : 77.226 pixel_acc : 87.150


5803it [05:47, 17.34it/s]

Epoch : 1 || 5800/8498 || loss : 0.124, iou : 83.260 pixel_acc : 96.121


5853it [05:50, 17.40it/s]

Epoch : 1 || 5850/8498 || loss : 0.125, iou : 94.167 pixel_acc : 97.031


5903it [05:53, 16.49it/s]

Epoch : 1 || 5900/8498 || loss : 0.063, iou : 86.128 pixel_acc : 98.599


5953it [05:56, 16.91it/s]

Epoch : 1 || 5950/8498 || loss : 0.648, iou : 46.744 pixel_acc : 71.666


6003it [05:59, 16.98it/s]

Epoch : 1 || 6000/8498 || loss : 0.215, iou : 81.729 pixel_acc : 92.572


6053it [06:02, 16.67it/s]

Epoch : 1 || 6050/8498 || loss : 0.128, iou : 62.955 pixel_acc : 97.643


6103it [06:04, 16.70it/s]

Epoch : 1 || 6100/8498 || loss : 0.433, iou : 83.376 pixel_acc : 90.994


6153it [06:07, 16.49it/s]

Epoch : 1 || 6150/8498 || loss : 0.033, iou : 33.259 pixel_acc : 99.776


6203it [06:10, 17.21it/s]

Epoch : 1 || 6200/8498 || loss : 0.120, iou : 49.408 pixel_acc : 98.817


6253it [06:13, 16.98it/s]

Epoch : 1 || 6250/8498 || loss : 0.198, iou : 55.480 pixel_acc : 94.059


6303it [06:16, 16.96it/s]

Epoch : 1 || 6300/8498 || loss : 0.276, iou : 62.807 pixel_acc : 91.980


6353it [06:19, 16.53it/s]

Epoch : 1 || 6350/8498 || loss : 0.038, iou : 83.201 pixel_acc : 99.093


6403it [06:22, 17.19it/s]

Epoch : 1 || 6400/8498 || loss : 0.081, iou : 85.315 pixel_acc : 97.680


6453it [06:25, 16.81it/s]

Epoch : 1 || 6450/8498 || loss : 0.380, iou : 25.746 pixel_acc : 88.120


6503it [06:28, 16.70it/s]

Epoch : 1 || 6500/8498 || loss : 0.160, iou : 79.696 pixel_acc : 96.332


6553it [06:31, 17.25it/s]

Epoch : 1 || 6550/8498 || loss : 0.419, iou : 73.735 pixel_acc : 85.482


6603it [06:34, 16.85it/s]

Epoch : 1 || 6600/8498 || loss : 0.123, iou : 93.984 pixel_acc : 97.177


6653it [06:37, 17.45it/s]

Epoch : 1 || 6650/8498 || loss : 0.243, iou : 59.159 pixel_acc : 94.872


6703it [06:40, 16.68it/s]

Epoch : 1 || 6700/8498 || loss : 0.407, iou : 55.643 pixel_acc : 89.031


6753it [06:43, 16.54it/s]

Epoch : 1 || 6750/8498 || loss : 0.363, iou : 54.747 pixel_acc : 90.873


6803it [06:46, 17.12it/s]

Epoch : 1 || 6800/8498 || loss : 0.326, iou : 41.287 pixel_acc : 87.311


6853it [06:49, 17.32it/s]

Epoch : 1 || 6850/8498 || loss : 0.316, iou : 61.745 pixel_acc : 96.081


6903it [06:52, 17.36it/s]

Epoch : 1 || 6900/8498 || loss : 0.306, iou : 35.601 pixel_acc : 84.905


6954it [06:55, 16.65it/s]

Epoch : 1 || 6950/8498 || loss : 0.109, iou : 93.976 pixel_acc : 97.444


7004it [06:58, 16.63it/s]

Epoch : 1 || 7000/8498 || loss : 0.387, iou : 41.655 pixel_acc : 88.822


7054it [07:01, 16.91it/s]

Epoch : 1 || 7050/8498 || loss : 0.187, iou : 90.653 pixel_acc : 95.659


7104it [07:03, 17.01it/s]

Epoch : 1 || 7100/8498 || loss : 0.653, iou : 30.317 pixel_acc : 72.780


7154it [07:06, 16.72it/s]

Epoch : 1 || 7150/8498 || loss : 0.195, iou : 60.649 pixel_acc : 96.823


7204it [07:09, 16.85it/s]

Epoch : 1 || 7200/8498 || loss : 0.142, iou : 88.548 pixel_acc : 96.309


7254it [07:12, 16.68it/s]

Epoch : 1 || 7250/8498 || loss : 0.794, iou : 28.338 pixel_acc : 74.158


7304it [07:15, 16.57it/s]

Epoch : 1 || 7300/8498 || loss : 1.332, iou : 28.260 pixel_acc : 70.245


7354it [07:18, 16.52it/s]

Epoch : 1 || 7350/8498 || loss : 0.175, iou : 91.826 pixel_acc : 96.002


7404it [07:21, 16.84it/s]

Epoch : 1 || 7400/8498 || loss : 0.315, iou : 85.051 pixel_acc : 91.933


7454it [07:24, 16.64it/s]

Epoch : 1 || 7450/8498 || loss : 0.580, iou : 31.486 pixel_acc : 77.386


7504it [07:27, 16.81it/s]

Epoch : 1 || 7500/8498 || loss : 0.183, iou : 75.172 pixel_acc : 93.907


7554it [07:30, 18.51it/s]

Epoch : 1 || 7550/8498 || loss : 0.267, iou : 59.951 pixel_acc : 94.735


7604it [07:33, 16.88it/s]

Epoch : 1 || 7600/8498 || loss : 0.509, iou : 68.220 pixel_acc : 85.681


7654it [07:35, 18.88it/s]

Epoch : 1 || 7650/8498 || loss : 0.100, iou : 92.274 pixel_acc : 97.444


7704it [07:38, 16.48it/s]

Epoch : 1 || 7700/8498 || loss : 0.193, iou : 89.657 pixel_acc : 94.551


7754it [07:41, 17.44it/s]

Epoch : 1 || 7750/8498 || loss : 0.547, iou : 61.149 pixel_acc : 84.134


7804it [07:44, 16.78it/s]

Epoch : 1 || 7800/8498 || loss : 0.118, iou : 59.804 pixel_acc : 97.666


7854it [07:47, 16.87it/s]

Epoch : 1 || 7850/8498 || loss : 0.322, iou : 78.175 pixel_acc : 91.792


7904it [07:50, 16.78it/s]

Epoch : 1 || 7900/8498 || loss : 0.377, iou : 80.671 pixel_acc : 89.799


7954it [07:53, 16.89it/s]

Epoch : 1 || 7950/8498 || loss : 0.194, iou : 91.707 pixel_acc : 96.093


8003it [07:55, 21.82it/s]

Epoch : 1 || 8000/8498 || loss : 0.313, iou : 21.536 pixel_acc : 90.743


8054it [07:58, 21.84it/s]

Epoch : 1 || 8050/8498 || loss : 0.141, iou : 61.382 pixel_acc : 97.059


8105it [08:00, 20.95it/s]

Epoch : 1 || 8100/8498 || loss : 0.308, iou : 88.998 pixel_acc : 97.196


8153it [08:02, 22.26it/s]

Epoch : 1 || 8150/8498 || loss : 0.393, iou : 59.021 pixel_acc : 93.745


8204it [08:05, 19.27it/s]

Epoch : 1 || 8200/8498 || loss : 0.831, iou : 26.137 pixel_acc : 71.283


8254it [08:08, 15.98it/s]

Epoch : 1 || 8250/8498 || loss : 0.127, iou : 79.747 pixel_acc : 96.732


8304it [08:11, 14.87it/s]

Epoch : 1 || 8300/8498 || loss : 0.196, iou : 89.824 pixel_acc : 94.761


8352it [08:14, 14.53it/s]

Epoch : 1 || 8350/8498 || loss : 0.194, iou : 87.762 pixel_acc : 94.718


8404it [08:18, 15.83it/s]

Epoch : 1 || 8400/8498 || loss : 0.450, iou : 40.481 pixel_acc : 85.729


8454it [08:21, 15.93it/s]

Epoch : 1 || 8450/8498 || loss : 0.066, iou : 47.759 pixel_acc : 98.252


8498it [08:23, 16.87it/s]


>> Epoch : 1 || AVG loss : 0.320, iou : 64.337 pixel_acc : 90.971 504.364 secs


2857it [00:52, 54.14it/s]


>> Epoch : 1 || AVG valid loss : 0.329, iou : 63.801 pixel_acc : 90.100 504.364 secs
>> Best validation set iou: 0.6380050805102193


# Aggregating Results

After training Segmenter, load your best model and run the following block to check validation accuracy and, optionally, compare the IoU improvement obtained with CRF. Since the validation set contains nearly 3,000 images, this can take up to 30 minutes.

You can consider your implementation correct if its performance is within ±2 percentage points (pixel accuracy) and ±0.01 mIoU (that is, ±1 point on the percentage scale used below) of the following values:
- Segmenter(use_tf=True)
  - After Epoch 0
    - AVG valid iou : 60.29 pixel_acc : 89.83
  - After Epoch 1
    - AVG valid iou : 65.19 pixel_acc : 90.11

- Segmenter(use_tf=False)
  - After Epoch 0
    - AVG valid iou : 60.13 pixel_acc : 89.30
  - After Epoch 1
    - AVG valid iou : 62.21 pixel_acc : 90.51


The exact values are subject to change, so don't worry too much if you miss the range by a small margin.
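
As a quick sanity check, the tolerance rule above can be sketched as a small helper. The names `REFERENCE` and `within_tolerance` are hypothetical and not part of the skeleton code; the numbers are copied from the list above, and reading the 0.01 mIoU tolerance as 1 point on the percentage scale is an assumption.

```python
# Reference values copied from the list above, stored as
# (use_tf, epoch) -> (mIoU in percent, pixel accuracy in percent).
REFERENCE = {
    (True, 0): (60.29, 89.83),
    (True, 1): (65.19, 90.11),
    (False, 0): (60.13, 89.30),
    (False, 1): (62.21, 90.51),
}


def within_tolerance(use_tf, epoch, miou_pct, pixel_acc_pct,
                     miou_tol=1.0, acc_tol=2.0):
    """Return True if both metrics fall inside the stated tolerances.

    miou_tol=1.0 assumes that the 0.01 mIoU tolerance corresponds to
    1 point on the percentage scale (an interpretation, not a given).
    """
    ref_miou, ref_acc = REFERENCE[(use_tf, epoch)]
    miou_ok = abs(miou_pct - ref_miou) <= miou_tol
    acc_ok = abs(pixel_acc_pct - ref_acc) <= acc_tol
    return miou_ok and acc_ok


# Example: hypothetical results for Segmenter(use_tf=True) after epoch 1.
print(within_tolerance(True, 1, miou_pct=64.8, pixel_acc_pct=90.5))
```

Pass the metrics on the same percentage scale as the training log (for example `64.8`, not `0.648`).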

In [21]:
# specify path to your best trained model.
# for example, if you want to load Segmenter from folder 'trial_5', change 'trial_0' to 'trial_5'.
segmenter_path = result_dir / 'trial_0' / 'best.pt'

# OPTIONAL: Read text below this code cell.
use_crf = False

net = Segmenter().to(device)
net.load_state_dict(torch.load(segmenter_path, map_location=device))

criterion = nn.CrossEntropyLoss(ignore_index=21, reduction='none')
colorize = Colorize(21, get_color_map())

net.eval()

valid_loss_total = 0
valid_ious = []
valid_pixel_accs = []
valid_ious_crf = []
valid_pixel_accs_crf = []

with torch.no_grad():
    for batch_idx, (image, label) in enumerate(test_loader):
        # Move variables to gpu.
        image = image.to(device)
        label = label.to(device)

        output, loss, pred = get_prediction(criterion, net, image, label)
        # Optionally refine the prediction with CRF (the squeeze() calls below assume batch size 1).
        if use_crf:
            image_permuted = image.cpu().permute(1, 0, 2, 3)
            un_norm = torch.zeros_like(image_permuted)
            for idx, (im, m, s) in enumerate(zip(image_permuted, mean, std)):
                un_norm[idx] = (im * s) + m
            un_norm = un_norm.permute(1, 0, 2, 3)

            output_softmax = torch.nn.functional.softmax(output, dim=1).detach().cpu()
            un_norm_int = (un_norm * 255).squeeze().permute(1, 2, 0).numpy().astype(np.ubyte)
            pred_crf = dense_crf(un_norm_int, output_softmax.squeeze().numpy())
            pred_crf = np.expand_dims(np.argmax(pred_crf, 0), 0)

            target = label.squeeze(1).cpu().numpy()
            acc_crf, mean_iu_crf = label_accuracy_score(target, pred_crf, n_class=21)
            valid_pixel_accs_crf.append(acc_crf)
            valid_ious_crf.append(mean_iu_crf)

        target = label.squeeze(1).cpu().numpy()
        acc, mean_iu = label_accuracy_score(target, pred.cpu().numpy(), n_class=21)

        # update total loss.
        valid_loss_total += loss.item()

        valid_pixel_accs.append(acc)
        valid_ious.append(mean_iu)

        # # this is only for testing
        # if batch_idx > 50:
        #   break

    # calculate average IoU
    total_valid_ious = np.nanmean(np.array(valid_ious))
    total_valid_pixel_acc = np.array(valid_pixel_accs).mean()

    print(f'{type(net).__name__}:')
    print(f'Pixel accuracy: {total_valid_pixel_acc * 100:.3f}, mIoU: {total_valid_ious:.3f}')

    if use_crf:
        total_valid_ious_crf = np.nanmean(np.array(valid_ious_crf))
        total_valid_pixel_acc_crf = np.array(valid_pixel_accs_crf).mean()
        print(f'CRF Pixel accuracy: {total_valid_pixel_acc_crf * 100:.3f}, CRF mIoU: {total_valid_ious_crf:.3f}')


Segmenter:
Pixel accuracy: 90.100, mIoU: 0.638
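
For reference, pixel accuracy and mean IoU are commonly derived from a confusion matrix. The sketch below is one way to do this in NumPy. It is not the notebook's `label_accuracy_score`, whose internals are not shown here, and the function names `confusion_matrix` and `pixel_acc_and_miou` are hypothetical. It assumes that predictions lie in `[0, n_class)` and that any target value outside that range (such as the ignore label) is skipped.

```python
import numpy as np


def confusion_matrix(target, pred, n_class):
    """Count (true class, predicted class) pairs.

    Pixels whose target label is outside [0, n_class) are skipped.
    Predictions are assumed to lie in [0, n_class).
    """
    valid = (target >= 0) & (target < n_class)
    index = n_class * target[valid].astype(int) + pred[valid].astype(int)
    return np.bincount(index, minlength=n_class ** 2).reshape(n_class, n_class)


def pixel_acc_and_miou(target, pred, n_class):
    """Return (pixel accuracy, mean IoU) as fractions in [0, 1]."""
    hist = confusion_matrix(target, pred, n_class)
    true_pos = np.diag(hist)
    acc = true_pos.sum() / hist.sum()
    # Union = predicted pixels + ground-truth pixels - intersection.
    union = hist.sum(axis=0) + hist.sum(axis=1) - true_pos
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = true_pos / union  # NaN for classes absent from both maps
    return acc, np.nanmean(iou)


target = np.array([[0, 0], [1, 1]])
pred = np.array([[0, 1], [1, 1]])
acc, miou = pixel_acc_and_miou(target, pred, n_class=3)
print(f"pixel accuracy: {acc:.3f}, mIoU: {miou:.3f}")
```

Classes that appear in neither the prediction nor the ground truth produce `NaN` and are excluded by `np.nanmean`, which is why the aggregation code above also uses `np.nanmean`.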


**Optional**: One way to improve semantic segmentation is to apply a Conditional Random Field (CRF) as post-processing. In a nutshell, a CRF constrains the labeling by penalizing the assignment of different labels to similar pixels. Since the CRF operates on the original image, some of the detailed structure lost in the encoder can be recovered through this process.

You can try the CRF by setting `use_crf = True` in the code block above. Feel free to experiment and see how it refines the labels.
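
To make the idea of "penalizing different labels on similar pixels" concrete, here is a toy sketch of a pairwise penalty on a grayscale image. It is a simplification for illustration only: it considers just horizontally and vertically adjacent pixels, whereas a dense CRF connects every pair of pixels, and it is not the `dense_crf` function used in this notebook. The name `pairwise_penalty` and the parameter `theta` are hypothetical.

```python
import numpy as np


def pairwise_penalty(labels, image, theta=0.1):
    """Sum a Potts-style penalty over adjacent pixel pairs.

    A pair contributes only when its two labels differ, and it
    contributes more when the two pixel intensities are similar.
    """
    pairs = [
        # horizontal neighbours
        (labels[:, :-1], labels[:, 1:], image[:, :-1], image[:, 1:]),
        # vertical neighbours
        (labels[:-1, :], labels[1:, :], image[:-1, :], image[1:, :]),
    ]
    total = 0.0
    for lab_a, lab_b, img_a, img_b in pairs:
        similarity = np.exp(-((img_a - img_b) ** 2) / (2 * theta ** 2))
        total += float(np.sum((lab_a != lab_b) * similarity))
    return total


# A 2x4 image with a sharp edge between the second and third columns.
image = np.array([[0.0, 0.0, 1.0, 1.0],
                  [0.0, 0.0, 1.0, 1.0]])
aligned = np.array([[0, 0, 1, 1],
                    [0, 0, 1, 1]])   # label boundary follows the image edge
shifted = np.array([[0, 1, 1, 1],
                    [0, 1, 1, 1]])   # label boundary cuts a uniform region

print(pairwise_penalty(aligned, image))
print(pairwise_penalty(shifted, image))
```

The labeling whose boundary follows the image edge receives a penalty close to zero, while the one that splits a uniform region is penalized. Minimizing such a term together with the network's per-pixel scores is what pulls label boundaries toward image edges.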