# Segmenter
Original paper (Strudel et al., "Segmenter: Transformer for Semantic Segmentation"): https://arxiv.org/abs/2105.05633

In [1]:
%pip install einops timm imutils torchvision lightning torchmetrics 

You should consider upgrading via the '/Users/haily/.pyenv/versions/3.10.4/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


# Segmenter implementation
The implementation is adapted from https://github.com/rstrudel/segmenter/

In [4]:
import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import _load_weights
from einops import rearrange
from torchinfo import summary
from tqdm import tqdm

# Import utils
from lightning.pytorch.loggers import CSVLogger
from lightning_utils import SegModule, SegDM
import lightning as L

def init_weights(m):
    """Initialise one module in place; meant to be passed to ``nn.Module.apply``.

    * ``nn.Linear``: weight is drawn with ``trunc_normal_(std=0.02)`` and the
      bias, when present, is set to 0.
    * ``nn.LayerNorm``: weight is set to 1 and bias to 0, when those
      parameters exist.
    * Every other module type is left untouched.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # A LayerNorm built with elementwise_affine=False (or bias=False) has
        # None in place of these parameters; skip them instead of crashing.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)

def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    """Bilinearly resample the patch-grid part of a position embedding.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: position embedding indexed as ``[batch, token, channel]``;
            only batch entry 0 of the grid tokens is used.
        grid_old_shape: ``(h, w)`` of the current patch grid, or ``None`` to
            assume a square grid derived from the number of grid tokens.
        grid_new_shape: ``(h, w)`` of the target patch grid.
        num_extra_tokens: number of leading tokens (CLS/DIST) that are copied
            through unchanged.

    Returns:
        Tensor with ``num_extra_tokens + h * w`` tokens along dim 1.
    """
    # Leading special tokens keep their batch dim; the grid tokens drop it.
    extra_tokens = posemb[:, :num_extra_tokens]
    grid_tokens = posemb[0, num_extra_tokens:]

    if grid_old_shape is not None:
        old_h, old_w = grid_old_shape
    else:
        old_h = int(math.sqrt(len(grid_tokens)))
        old_w = old_h

    new_h, new_w = grid_new_shape

    # (tokens, C) -> (1, C, old_h, old_w) so F.interpolate sees an image.
    grid = grid_tokens.reshape(1, old_h, old_w, -1)
    grid = grid.permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bilinear")
    # Back to a flat token sequence: (1, new_h * new_w, C).
    grid = grid.permute(0, 2, 3, 1)
    grid = grid.reshape(1, new_h * new_w, -1)

    return torch.cat([extra_tokens, grid], dim=1)

def padding(im, patch_size, fill_value=0):
    """Pad the last two dims of ``im`` up to the next multiple of ``patch_size``.

    Padding is appended after the existing data (end of dim 2 and end of
    dim 3) and filled with ``fill_value``. When both dims are already
    divisible, the input object itself is returned.

    Args:
        im: tensor whose dims 2 and 3 are the spatial dims.
        patch_size: side length the spatial dims must become divisible by.
        fill_value: constant written into the padded area.
    """
    height, width = im.size(2), im.size(3)
    # (-n) % p is the distance from n up to the next multiple of p (0 if aligned).
    extra_rows = -height % patch_size
    extra_cols = -width % patch_size
    if extra_rows <= 0 and extra_cols <= 0:
        return im
    return F.pad(im, (0, extra_cols, 0, extra_rows), value=fill_value)


def unpadding(y, target_size):
    """Crop the trailing rows/columns of ``y`` so dims 2 and 3 match ``target_size``.

    A dim is only cropped when it is larger than the target; dims that are
    already equal or smaller are left as they are.

    Args:
        y: tensor with at least four dims; dims 2 and 3 are cropped.
        target_size: ``(H, W)`` to crop to.
    """
    target_h, target_w = target_size
    surplus = {2: y.size(2) - target_h, 3: y.size(3) - target_w}
    for axis in (2, 3):
        extra = surplus[axis]
        if extra > 0:
            # Keep every leading dim whole and drop the last `extra` entries
            # of this axis.
            selector = (slice(None),) * axis + (slice(None, -extra),)
            y = y[selector]
    return y

class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability, applied after the activation and again
            after the second linear layer (the same ``nn.Dropout`` is reused).
        out_dim: output feature size; defaults to ``dim`` when ``None``.
    """

    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
        super().__init__()
        output_width = dim if out_dim is None else out_dim
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_width)
        self.drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        """Return this module itself."""
        return self

    def forward(self, x):
        """Apply the MLP to the last dim of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: token feature size; split evenly across ``heads``.
        heads: number of attention heads.
        dropout: dropout probability for both the attention weights and the
            projected output.
    """

    def __init__(self, dim, heads, dropout):
        super().__init__()
        self.heads = heads
        # Scores are multiplied by 1 / sqrt(per-head width).
        self.scale = (dim // heads) ** -0.5
        self.attn = None

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        """Return this module itself."""
        return self

    def forward(self, x, mask=None):
        """Attend over the token axis of ``x``.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            mask: accepted for interface compatibility; this implementation
                does not use it.

        Returns:
            ``(out, attn)``: the projected output reshaped to ``(B, N, C)``
            and the attention weights after softmax and dropout.
        """
        batch, n_tokens, channels = x.shape
        per_head = channels // self.heads

        # One projection yields q, k and v; move the q/k/v axis to the front
        # and the head axis ahead of the token axis.
        stacked = self.qkv(x).reshape(batch, n_tokens, 3, self.heads, per_head)
        stacked = stacked.permute(2, 0, 3, 1, 4)
        q, k, v = stacked[0], stacked[1], stacked[2]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into a single feature axis.
        merged = (attn @ v).transpose(1, 2).reshape(batch, n_tokens, channels)
        out = self.proj_drop(self.proj(merged))
        return out, attn

class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x = x + drop_path(attn(norm1(x)))`` followed by
    ``x = x + drop_path(mlp(norm2(x)))``. Both residual additions use the
    un-normalised stream.

    Args:
        dim: token feature size.
        heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability passed to the attention and MLP layers.
        drop_path: stochastic-depth rate; ``0`` disables it (``nn.Identity``).
    """

    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout)
        if drop_path > 0.0:
            # Imported here so a non-zero rate does not depend on a file-level
            # DropPath import (the visible import cell only pulls in
            # trunc_normal_ from this module).
            from timm.models.layers import DropPath

            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x, mask=None, return_attention=False):
        """Run the block on ``x``.

        Args:
            x: input token tensor.
            mask: forwarded to the attention layer.
            return_attention: when true, return the attention weights of this
                block instead of the transformed tokens.

        Returns:
            The attention weights if ``return_attention`` is set, otherwise
            the updated token tensor.
        """
        # Normalise only the branch input; `x` itself must stay untouched so
        # the skip connection carries the original residual stream.
        y, attn = self.attn(self.norm1(x), mask)
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    The embedding is a strided convolution whose kernel size and stride both
    equal ``patch_size``.

    Args:
        image_size: ``(height, width)`` the model is configured for.
        patch_size: side length of a square patch.
        embed_dim: output feature size per patch.
        channels: number of input image channels.

    Raises:
        ValueError: if either image dimension is not divisible by ``patch_size``.
    """

    def __init__(self, image_size, patch_size, embed_dim, channels):
        super().__init__()

        self.image_size = image_size
        img_h, img_w = image_size[0], image_size[1]
        if img_h % patch_size or img_w % patch_size:
            raise ValueError("image dimensions must be divisible by the patch size")

        rows, cols = img_h // patch_size, img_w // patch_size
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, im):
        """Return patch embeddings as ``(B, num_patches, embed_dim)``."""
        # Unpacking also rejects inputs that are not 4-D.
        _batch, _chans, _height, _width = im.shape
        feature_map = self.proj(im)
        # Flatten the spatial grid, then put the token axis before the channels.
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)


class VisionTransformer(nn.Module):
    """ViT encoder with an optional distillation token and a linear class head.

    Args:
        image_size: ``(height, width)`` used to size the position embedding.
        patch_size: side length of a square patch.
        n_layers: number of transformer blocks.
        d_model: token feature size.
        d_ff: hidden width of each block's feed-forward layer.
        n_heads: attention heads per block.
        n_cls: output size of the classification head(s).
        dropout: dropout probability used after the position embedding and
            inside every block.
        drop_path_rate: maximum stochastic-depth rate; block ``i`` gets the
            ``i``-th value of a linear ramp from 0 to this rate.
        distilled: when true, add a distillation token and a second head.
        channels: number of input image channels.
    """

    def __init__(self, image_size, patch_size, n_layers, d_model, 
                 d_ff, n_heads, n_cls,
                 dropout=0.1, drop_path_rate=0.0, distilled=False, channels=3,
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, d_model, channels)
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.n_cls = n_cls

        # cls and pos tokens
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.distilled = distilled
        if self.distilled:
            self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.patch_embed.num_patches + 2, d_model)
            )
            self.head_dist = nn.Linear(d_model, n_cls)
        else:
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.patch_embed.num_patches + 1, d_model)
            )

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
        )

        # output head
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_cls)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        if self.distilled:
            trunc_normal_(self.dist_token, std=0.02)
        self.pre_logits = nn.Identity()

        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names this model lists as weight-decay exempt."""
        return {"pos_embed", "cls_token", "dist_token"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        """Load weights from ``checkpoint_path`` via timm's ``_load_weights``."""
        _load_weights(self, checkpoint_path, prefix)

    def _embed_tokens(self, im):
        """Patchify ``im``, prepend CLS/DIST tokens and add position embeddings.

        Shared by ``forward`` and ``get_attention_map``. The position
        embedding is resized whenever the token count of ``im`` differs from
        the one the model was built for. No dropout is applied here.
        """
        B, _, H, W = im.shape
        PS = self.patch_size

        x = self.patch_embed(im)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.distilled:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            x = torch.cat((cls_tokens, x), dim=1)

        pos_embed = self.pos_embed
        num_extra_tokens = 1 + self.distilled
        if x.shape[1] != pos_embed.shape[1]:
            pos_embed = resize_pos_embed(
                pos_embed,
                self.patch_embed.grid_size,
                (H // PS, W // PS),
                num_extra_tokens,
            )
        return x + pos_embed

    def forward(self, im, return_features=False):
        """Encode ``im``.

        Args:
            im: image batch unpacked as ``(B, C, H, W)``.
            return_features: when true, return the normalised token sequence
                (special tokens included) instead of class logits.

        Returns:
            The token features, or the class logits computed from the CLS
            token (averaged with the DIST-token logits when ``distilled``).
        """
        x = self._embed_tokens(im)
        x = self.dropout(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        if return_features:
            return x

        if self.distilled:
            x, x_dist = x[:, 0], x[:, 1]
            x = self.head(x)
            x_dist = self.head_dist(x_dist)
            x = (x + x_dist) / 2
        else:
            x = x[:, 0]
            x = self.head(x)
        return x

    def get_attention_map(self, im, layer_id):
        """Return the attention weights produced by block ``layer_id`` for ``im``.

        Raises:
            ValueError: if ``layer_id`` is outside ``[0, n_layers)``.
        """
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(
                f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
            )
        x = self._embed_tokens(im)

        # Run the blocks before `layer_id` normally, then ask that block for
        # its attention weights and stop.
        for i, blk in enumerate(self.blocks):
            if i < layer_id:
                x = blk(x)
            else:
                return blk(x, return_attention=True)

class DecoderLinear(nn.Module):
    """Linear segmentation head: per-token class logits laid out on the patch grid.

    Args:
        n_cls: number of output classes.
        patch_size: patch side length, used to derive the grid height.
        d_encoder: feature size of the incoming encoder tokens.
    """

    def __init__(self, n_cls, patch_size, d_encoder):
        super().__init__()

        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_cls = n_cls

        self.head = nn.Linear(self.d_encoder, n_cls)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return an empty set: this decoder lists no weight-decay exemptions."""
        return set()

    def forward(self, x, im_size):
        """Project tokens to class logits and fold them into a 2-D map.

        Args:
            x: patch tokens laid out as ``(b, h * w, c)``.
            im_size: ``(H, W)`` of the image; only ``H`` is used, the grid
                width follows from the token count.

        Returns:
            Logits arranged as ``(b, n_cls, h, w)`` with ``h = H // patch_size``.
        """
        height, _width = im_size
        grid_rows = height // self.patch_size
        logits = self.head(x)
        # Token axis is row-major over the patch grid: h * w = n.
        return rearrange(logits, "b (h w) c -> b c h w", h=grid_rows)
    
class Segmenter(nn.Module):
    """Encoder/decoder segmentation model producing per-pixel class logits.

    Args:
        encoder: backbone exposing ``patch_size``, ``distilled``,
            ``no_weight_decay()`` and a ``return_features`` forward flag.
        decoder: head called as ``decoder(tokens, (H, W))``.
        n_cls: number of segmentation classes.
    """

    def __init__(self, encoder, decoder, n_cls):
        super().__init__()
        self.n_cls = n_cls
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = decoder

    @torch.jit.ignore
    def no_weight_decay(self):
        """Collect the sub-modules' no-weight-decay names, prefixed by owner."""
        from_encoder = {"encoder." + name for name in self.encoder.no_weight_decay()}
        from_decoder = {"decoder." + name for name in self.decoder.no_weight_decay()}
        return from_encoder | from_decoder

    def forward(self, im):
        """Segment ``im`` and return masks at the input's spatial size.

        The image is padded to a multiple of the patch size, encoded, decoded
        on the patch tokens only, upsampled bilinearly to the padded size and
        finally cropped back to the original size.
        """
        original_hw = (im.size(2), im.size(3))
        im = padding(im, self.patch_size)
        padded_hw = (im.size(2), im.size(3))

        tokens = self.encoder(im, return_features=True)

        # Drop the CLS token (and the DIST token for distilled encoders);
        # only patch tokens are decoded.
        n_special = 1 + self.encoder.distilled
        patch_tokens = tokens[:, n_special:]

        masks = self.decoder(patch_tokens, padded_hw)
        masks = F.interpolate(masks, size=padded_hw, mode="bilinear")
        return unpadding(masks, original_hw)

In [3]:
# darwin_config = {
#     'encoder': {
#         'image_size': (512, 512),
#         'patch_size': 16,
#         'd_model': 192,
#         'n_heads': 8,
#         'd_ff': 128,
#         'n_layers': 12,
#         'distilled': False,
#         'channels': 3,
#     },
#     'decoder': {
#         'drop_path_rate': 0.0,
#         'dropout': 0.1,
#         'n_layers': 2,
#     },
# }

# dm = SegDM(
#     batch_size=4,
#     mask_dir='/kaggle/input/img-segmentation/Darwin/mask',
#     img_dir='/kaggle/input/img-segmentation/Darwin/img',
# )
# logger = CSVLogger("logs", name=f"darwin", flush_logs_every_n_steps=1)
# model = SegmenterModule(darwin_config, num_classes=2)
# trainer = L.Trainer(fast_dev_run=False, logger=logger, max_time="00:11:00:00")
# trainer.fit(model, dm)
# trainer.test(model, dm)

# Shenzhen

In [6]:
# Build a Segmenter (ViT encoder + linear decoder), then train and test it on
# the Shenzhen image/mask folders with a Lightning Trainer.
# Encoder hyper-parameters are unpacked straight into VisionTransformer(...);
# dropout and drop_path_rate are not listed, so they keep that constructor's
# defaults (0.1 and 0.0).
shenzen_config = {
    'encoder': {
        # 512 x 512 input with 16-pixel patches gives a 32 x 32 patch grid.
        'image_size': (512, 512),
        'patch_size': 16,
        'd_model': 768,
        # Per-head width is d_model // n_heads = 256.
        # NOTE(review): ViT-Base-style configs usually pair d_model=768 with
        # 12 heads — confirm that 3 is intended.
        'n_heads': 3,
        'd_ff': 768 * 4, # 4 x d_model
        'n_layers': 12,
        # Size of the encoder's own classification head. Segmenter calls the
        # encoder with return_features=True, so that head's output is not
        # used for the masks.
        'n_cls': 1000,
        'distilled': False,
        'channels': 3,
    },
}

# Number of segmentation classes predicted by the decoder.
n_classes = 2
# Create model: the decoder takes its patch size and input width from the encoder.
encoder = VisionTransformer(**shenzen_config['encoder'])
decoder = DecoderLinear(
            patch_size=encoder.patch_size, 
            n_cls=n_classes, 
            d_encoder=encoder.d_model)

shenzen_model = Segmenter(encoder, decoder, n_classes)

# SegDM and SegModule come from the project-local lightning_utils module;
# their behaviour is not visible in this notebook.
shenzen_dm = SegDM(
    batch_size=2,
    mask_dir='./datasets/Shenzhen/mask',
    img_dir='./datasets/Shenzhen/img',
)

# CSV logs are written under logs/shenzen_segmenter.
shenzen_logger = CSVLogger("logs", name=f"shenzen_segmenter")
shenzen_module = SegModule(shenzen_model, num_classes=n_classes, result_path='shenzen_segmenter_result.csv')

# Train for up to 20 epochs, then run the test loop on the same module.
shenzen_trainer = L.Trainer(fast_dev_run=False, logger=shenzen_logger, max_epochs=20)
shenzen_trainer.fit(shenzen_module, shenzen_dm)
shenzen_trainer.test(shenzen_module, shenzen_dm)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Missing logger folder: logs/shenzen_segmenter

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | Segmenter        | 87.2 M | train
1 | loss_fn   | CrossEntropyLoss | 0      | train
2 | f1        | BinaryF1Score    | 0      | train
3 | accuracy  | BinaryAccuracy   | 0      | train
4 | recall    | BinaryRecall     | 0      | train
5 | precision | BinaryPrecision  | 0      | train
6 | mean_iou  | MeanIoU          | 0      | train
-------------------------------------------------------
87.2 M    Trainable params
0         Non-trainable params
87.2 M    Total params
348.820   Total estimated model params size (MB)


318 examples in the training set...
106 examples in the validation set...
Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0:  27%|██▋       | 43/159 [00:47<02:07,  0.91it/s, v_num=0]

/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


142 examples in the test set...
Testing DataLoader 0:  69%|██████▉   | 49/71 [00:15<00:07,  3.07it/s]

In [None]:
# import matplotlib.pyplot as plt

# shenzen_dm = SegDM(
#     batch_size=4,
#     mask_dir='/kaggle/input/img-segmentation/Shenzhen/mask',
#     img_dir='/kaggle/input/img-segmentation/Shenzhen/img',
# )
# shenzen_dm.setup('test')

# loader = shenzen_dm.test_dataloader()
# test_imgs, test_masks = next(iter(loader))
# print(test_imgs.shape)
# loss = torch.nn.BCEWithLogitsLoss()


# trained_shenzen = SegmenterModule.load_from_checkpoint("/kaggle/input/segmenter/pytorch/shenzen-e9/2/shenzen-epoch9-step1590.ckpt")
# trained_shenzen.freeze()

# pred_masks = trained_shenzen.model(test_imgs)

# print('loss ', loss(pred_masks[1], test_masks[1]))
# pred_mask = torch.argmax(pred_masks[1], dim=0)
# test_mask = torch.argmax(test_masks[1], dim=0)

# print(pred_masks[1])
# print(test_masks[1])
# plt.imshow(pred_mask, cmap='gray')
# plt.imshow(test_mask, cmap='gray')

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# # Create a random RGB image
# rgb_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)

# # Display the RGB image
# plt.imshow(rgb_image)
# plt.title('Random RGB Image')
# plt.axis('off')  # Hide the axis
# plt.show()

# Covid

In [None]:
# covid_config = {
#     'encoder': {
#         'image_size': (299, 299),
#         'patch_size': 23,
#         'd_model': 192,
#         'n_heads': 8,
#         'd_ff': 128,
#         'n_layers': 12,
#         'distilled': False,
#         'channels': 1,
#     },
#     'decoder': {
#         'drop_path_rate': 0.0,
#         'dropout': 0.1,
#         'n_layers': 2,
#     },
# }

# covid_dm = SegDM(
#     batch_size=4,
#     img_dir='/kaggle/input/img-segmentation/Covid19 Radiography/COVID-19_Radiography_Dataset/COVID/images',
#     mask_dir='/kaggle/input/img-segmentation/Covid19 Radiography/COVID-19_Radiography_Dataset/COVID/masks',
# )
# covid_logger = CSVLogger("logs", name=f"covid", flush_logs_every_n_steps=1)
# covid_model = SegmenterModule(covid_config, num_classes=2)
# covid_trainer = L.Trainer(fast_dev_run=True, logger=covid_logger, max_epochs=10)
# covid_trainer.fit(covid_model, covid_dm)
# covid_trainer.test(covid_model, covid_dm)