In [1]:
import argparse
import logging
import os
import random
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from mcode.utils import clip_gradient, AvgMeter

from glob import glob
from skimage.io import imread
import matplotlib.pyplot as plt
import pandas as pd
from collections import OrderedDict
from torch.autograd import Variable
from datetime import datetime
import torch.nn.functional as F
import cv2
from albumentations.augmentations import transforms
from albumentations.core.composition import Compose, OneOf
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import argparse
import copy
import os
import os.path as osp
import time

import mmcv
import torch
from mmcv.runner import init_dist
from mmcv.utils import Config, DictAction, get_git_hash

from mmseg import __version__
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger

In [3]:
# model segformerb4 settings
from mmseg.models import build_segmentor
from mmcv.runner.optimizer import build_optimizer
import torch

def segformer(arch):
    num_layers = []
    pretrained = f'pretrained/mit_{arch}_mmseg.pth'
    if arch == 'b1':
        num_layers = [2,2,2,2]
    if arch == 'b2':
        num_layers = [3, 4, 6, 3 ]
    if arch == 'b3':
        num_layers = [3, 4, 18, 3]
    if arch == 'b4':
        num_layers = [3, 8, 27, 3]
    model = dict(
        type='SunSegmentor',
        backbone=dict(
            type='MixVisionTransformer',
            in_channels=3,
            embed_dims=64,
            num_stages=4,
            num_layers=num_layers,
            num_heads=[1, 2, 5, 8],
            patch_sizes=[7, 3, 3, 3],
            sr_ratios=[8, 4, 2, 1],
            out_indices=(0, 1, 2, 3),
            mlp_ratio=4,
            qkv_bias=True,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            drop_path_rate=0.1,
            pretrained=pretrained),
        decode_head=dict(
            type='DRPHead',
            in_channels=[64, 128, 320, 512],
            in_index=[0, 1, 2, 3],
            channels=128,
            dropout_ratio=0.1,
            num_classes=1,
            norm_cfg=dict(type='BN', requires_grad=True),
            align_corners=False,
            loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
        # model training and testing settings
        train_cfg=dict(),
        test_cfg=dict(mode='whole'))
    model = build_segmentor(model)
    model.init_weights()
    return model

In [4]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, img_paths, mask_paths, aug=True, transform=None):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.aug = aug
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        mask_path = self.mask_paths[idx]
        # image = imread(img_path)
        # mask = imread(mask_path)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, 0)
        # name = self.img_paths[idx].split('/')[-1]

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        else:
            image = cv2.resize(image, (352, 352))
            mask = cv2.resize(mask, (352, 352)) 

        image = image.astype('float32') / 255
        image = image.transpose((2, 0, 1))

        mask = mask[:,:,np.newaxis]
        mask = mask.astype('float32') / 255
        mask = mask.transpose((2, 0, 1))

        return np.asarray(image), np.asarray(mask)

In [5]:
epsilon = 1e-7
def recall_m(y_true, y_pred):
  true_positives = torch.sum(torch.round(torch.clip(y_true * y_pred, 0, 1)))
  possible_positives = torch.sum(torch.round(torch.clip(y_true, 0, 1)))
  recall = true_positives / (possible_positives + epsilon)
  return recall

def precision_m(y_true, y_pred):
  true_positives = torch.sum(torch.round(torch.clip(y_true * y_pred, 0, 1)))
  predicted_positives = torch.sum(torch.round(torch.clip(y_pred, 0, 1)))
  precision = true_positives / (predicted_positives + epsilon)
  return precision

def dice_m(y_true, y_pred):
  precision = precision_m(y_true, y_pred)
  recall = recall_m(y_true, y_pred)
  return 2*((precision*recall)/(precision+recall+epsilon))

def iou_m(y_true, y_pred):
  precision = precision_m(y_true, y_pred)
  recall = recall_m(y_true, y_pred)
  return recall*precision/(recall+precision-recall*precision +epsilon)

In [6]:
class FocalLossV1(nn.Module):

    def __init__(self,
                 alpha=0.25,
                 gamma=2,
                 reduction='mean',):
        super(FocalLossV1, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.crit = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, label):
        # compute loss
        logits = logits.float() # use fp32 if logits is fp16
        with torch.no_grad():
            alpha = torch.empty_like(logits).fill_(1 - self.alpha)
            alpha[label == 1] = self.alpha

        probs = torch.sigmoid(logits)
        pt = torch.where(label == 1, probs, 1 - probs)
        ce_loss = self.crit(logits, label.float())
        loss = (alpha * torch.pow(1 - pt, self.gamma) * ce_loss)
        if self.reduction == 'mean':
            loss = loss.mean()
        if self.reduction == 'sum':
            loss = loss.sum()
        return loss

def structure_loss(pred, mask):
    weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wfocal = FocalLossV1()(pred, mask)
    wfocal = (wfocal*weit).sum(dim=(2,3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask)*weit).sum(dim=(2, 3))
    union = ((pred + mask)*weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1)/(union - inter+1)
    return (wfocal + wiou).mean()

In [7]:
def train(train_loader, model, optimizer, epoch, lr_scheduler, deep=False):
    model.train()
    # ---- multi-scale training ----
    size_rates = [0.75, 1, 1.25]
    loss_record = AvgMeter()
    dice, iou = AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        if epoch <= 1:
                optimizer.param_groups[0]["lr"] = (epoch * i) / (1.0 * total_step) * init_lr
        else:
            lr_scheduler.step()

        for rate in size_rates: 
            optimizer.zero_grad()
            # ---- data prepare ----
            images, gts = pack
            images = Variable(images).cuda(1)
            gts = Variable(gts).cuda(1)
            # ---- rescale ----
            trainsize = int(round(trainsize_init*rate/32)*32)
            images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            # ---- forward ----
            map5, map4, map3, map2, map1 = model(images)
            map1 = F.upsample(map1, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            map2 = F.upsample(map2, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            map3 = F.upsample(map3, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            map4 = F.upsample(map4, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            map5 = F.upsample(map5, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            loss = structure_loss(map1, gts) + structure_loss(map2, gts) + structure_loss(map3, gts) + structure_loss(map4, gts) + structure_loss(map5, gts)
            # ---- metrics ----
            dice_score = dice_m(map2, gts)
            iou_score = iou_m(map2, gts)
            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, clip)
            optimizer.step()
            # ---- recording loss ----
            if rate == 1:
                loss_record.update(loss.data, batchsize)
                dice.update(dice_score.data, batchsize)
                iou.update(iou_score.data, batchsize)

        # ---- train visualization ----
        if i == total_step:
            print('{} Training Epoch [{:03d}/{:03d}], '
                  '[loss: {:0.4f}, dice: {:0.4f}, iou: {:0.4f}]'.
                  format(datetime.now(), epoch, num_epochs,\
                         loss_record.show(), dice.show(), iou.show()))

    ckpt_path = save_path + 'last.pth'
    print('[Saving Checkpoint:]', ckpt_path)
    checkpoint = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': lr_scheduler.state_dict()
    }
    torch.save(checkpoint, ckpt_path)

    log = OrderedDict([
        ('loss', loss_record.show()), ('dice', dice.show()), ('iou', iou.show()),
    ])

    return log

In [8]:
def recall_np(y_true, y_pred):
    true_positives = np.sum(np.round(np.clip(y_true * y_pred, 0, 1)))
    possible_positives = np.sum(np.round(np.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + epsilon)
    return recall

def precision_np(y_true, y_pred):
    true_positives = np.sum(np.round(np.clip(y_true * y_pred, 0, 1)))
    predicted_positives = np.sum(np.round(np.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + epsilon)
    return precision

def dice_np(y_true, y_pred):
    precision = precision_np(y_true, y_pred)
    recall = recall_np(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+epsilon))

def iou_np(y_true, y_pred):
    intersection = np.sum(np.round(np.clip(y_true * y_pred, 0, 1)))
    union = np.sum(y_true)+np.sum(y_pred)-intersection
    return intersection/(union+epsilon)

def get_scores(gts, prs):
    mean_precision = 0
    mean_recall = 0
    mean_iou = 0
    mean_dice = 0
    for gt, pr in zip(gts, prs):
        mean_precision += precision_np(gt, pr)
        mean_recall += recall_np(gt, pr)
        mean_iou += iou_np(gt, pr)
        mean_dice += dice_np(gt, pr)

    mean_precision /= len(gts)
    mean_recall /= len(gts)
    mean_iou /= len(gts)
    mean_dice /= len(gts)        
    
    print(f"scores: dice={mean_dice}, miou={mean_iou}, precision={mean_precision}, recall={mean_recall}")

    return (mean_iou, mean_dice, mean_precision, mean_recall)

from mcode.config import *
from tabulate import tabulate

def inference(model):
    print("#" * 20)
    model.eval()
    dataset_names = ['Kvasir', 'CVC-ClinicDB', 'CVC-ColonDB', 'CVC-300', 'ETIS-LaribPolypDB']
    table = []
    headers = ['Dataset', 'IoU', 'Dice']
    ious, dices = AverageMeter(), AverageMeter()

    for dataset_name in dataset_names:
        data_path = f'{test_folder}/{dataset_name}'
        X_test = glob.glob('{}/images/*'.format(data_path))
        X_test.sort()
        y_test = glob.glob('{}/masks/*'.format(data_path))
        y_test.sort()

        test_dataset = Dataset(X_test, y_test)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            pin_memory=True,
            drop_last=False)

        print('Dataset_name:', dataset_name)
        tp_all = 0
        fp_all = 0
        fn_all = 0
        mean_iou = 0
        gts = []
        prs = []
        for i, pack in enumerate(test_loader, start=1):
            image, gt = pack
            # name = name[0]
            gt = gt[0][0]
            gt = np.asarray(gt, np.float32)
            image = image.cuda(1)

            res, res2, res3, res4, res5 = model(image)
            res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)
            res = res.sigmoid().data.cpu().numpy().squeeze()
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            pr = res.round()
            gts.append(gt)
            prs.append(pr)
            # cv2.imwrite(os.path.join(save_path, dataset_name, name), res)
        mean_iou, mean_dice, _, _ = get_scores(gts, prs)
        ious.update(mean_iou)
        dices.update(mean_dice)
        table.append([dataset_name, mean_iou, mean_dice])
    table.append(['Total', ious.avg, dices.avg])
    print(tabulate(table, headers=headers, tablefmt="fancy_grid"))
    print("#"*20)



In [9]:
init_lr = 1e-4
batchsize = 8
trainsize_init = 352
clip = 0.5
num_epochs= 20
train_save = 'ColonFormerB4'

save_path = 'run/test/{}/'.format(train_save)
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)
else:
    print("Save path existed")
#     sys.exit(1)


log = pd.DataFrame(index=[], columns=[
    'epoch', 'lr', 'loss', 'dice', 'iou', 'val_loss', 'val_dice', 'val_iou'
])
train_img_paths = []
train_mask_paths = []
train_img_paths = glob.glob('/home/nguyen.van.quan/scatsimclr/TrainDataset/image/*')
train_mask_paths = glob.glob('/home/nguyen.van.quan/scatsimclr/TrainDataset/mask/*')
train_img_paths.sort()
train_mask_paths.sort()

train_dataset = Dataset(train_img_paths, train_mask_paths)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batchsize,
    shuffle=True,
    pin_memory=True,
    drop_last=True
)

total_step = len(train_loader)
model = segformer('b1').cuda(1)

# ---- flops and params ----
params = model.parameters()
optimizer = torch.optim.Adam(params, init_lr)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                                    T_max=len(train_loader)*num_epochs,
                                    eta_min=init_lr/1000)



start_epoch = 1

print("#"*20, f"Start Training", "#"*20)
for epoch in range(start_epoch, num_epochs+1):
    train_log = train(train_loader, model, optimizer, epoch, lr_scheduler)


    if epoch >= num_epochs-20:
        inference(model)

2023-01-09 03:37:38,981 - mmcv - INFO - initialize MixVisionTransformer with init_cfg {'type': 'Pretrained', 'checkpoint': 'pretrained/mit_b1_mmseg.pth'}
2023-01-09 03:37:38,981 - mmcv - INFO - load model from: pretrained/mit_b1_mmseg.pth
2023-01-09 03:37:38,983 - mmcv - INFO - load checkpoint from local path: pretrained/mit_b1_mmseg.pth


Save path existed


2023-01-09 03:37:39,074 - mmcv - INFO - initialize DRPHead with init_cfg {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
2023-01-09 03:37:39,085 - mmcv - INFO - 
backbone.layers.0.0.projection.weight - torch.Size([64, 3, 7, 7]): 
PretrainedInit: load from pretrained/mit_b1_mmseg.pth 
 
2023-01-09 03:37:39,085 - mmcv - INFO - 
backbone.layers.0.0.projection.bias - torch.Size([64]): 
PretrainedInit: load from pretrained/mit_b1_mmseg.pth 
 
2023-01-09 03:37:39,086 - mmcv - INFO - 
backbone.layers.0.0.norm.weight - torch.Size([64]): 
PretrainedInit: load from pretrained/mit_b1_mmseg.pth 
 
2023-01-09 03:37:39,086 - mmcv - INFO - 
backbone.layers.0.0.norm.bias - torch.Size([64]): 
PretrainedInit: load from pretrained/mit_b1_mmseg.pth 
 
2023-01-09 03:37:39,087 - mmcv - INFO - 
backbone.layers.0.1.0.norm1.weight - torch.Size([64]): 
PretrainedInit: load from pretrained/mit_b1_mmseg.pth 
 
2023-01-09 03:37:39,087 - mmcv - INFO - 
backbone.layers.0.1.0.norm1.bias - torch.Size

#################### Start Training ####################




2023-01-09 03:40:50.198157 Training Epoch [001/020], [loss: 2.6127, dice: 0.8152, iou: 0.6943]
[Saving Checkpoint:] run/test/ColonFormerB4/last.pth
####################
Dataset_name: Kvasir
scores: dice=0.8331712367622537, miou=0.7496770392146545, precision=0.8682838595882381, recall=0.8579024576624572
Dataset_name: CVC-ClinicDB
scores: dice=0.7613430344557335, miou=0.6711967503626015, precision=0.8022565800783995, recall=0.8044255898888888
Dataset_name: CVC-ColonDB
scores: dice=0.6564078183398577, miou=0.54951816153832, precision=0.7284840184488995, recall=0.7065402135883325
Dataset_name: CVC-300
scores: dice=0.8052087174475677, miou=0.7066440335083727, precision=0.836788197344132, recall=0.8227181564226879
Dataset_name: ETIS-LaribPolypDB
scores: dice=0.5710594247054775, miou=0.4767441417064318, precision=0.5330815315925166, recall=0.790445817790942
╒═══════════════════╤══════════╤══════════╕
│ Dataset           │      IoU │     Dice │
╞═══════════════════╪══════════╪══════════╡
│ Kva