# Install libraries

In [None]:
# ! pip install thop
# ! pip install segmentation-models-pytorch
# ! pip install pytorch_lightning
# ! pip install torchview

# Mount Google Drive and change to the data directory

In [None]:
# Mount Google Drive and switch the working directory to the CVC-ClinicDB data
# folder. The relative paths used by later cells ('./train/...', './train.txt',
# './val.txt') are resolved against this working directory.
from google.colab import drive
drive.mount("/content/gdrive", force_remount=True)
!ls '/content/gdrive/My Drive/'
%cd /content/gdrive/My Drive/polyp_seg/data/CVC-ClinicDB
! pwd

Mounted at /content/gdrive
motya  polyp_seg


# Data Preprocessing

In [None]:
import os

# Show the working directory (plain-Python equivalent of the `! pwd` shell magic).
print(os.getcwd())

# Paths are relative to the working directory. Relative paths are used on purpose:
# the list file is later split on whitespace, and the absolute Drive path
# ('/content/gdrive/My Drive/...') contains a space.
images_dir = './train/images'
masks_dir = './train/masks'

# Output list file: one "<image path> <mask path>" pair per line.
output_file = './train.txt'

# Sorting both listings pairs the i-th image with the i-th mask, which assumes
# the two folders use matching file names.
image_files = sorted(os.listdir(images_dir))
mask_files = sorted(os.listdir(masks_dir))

if len(image_files) != len(mask_files):
    print("Warning: The number of images and masks do not match!")

# Surface pairing mistakes early: compare file names without extensions.
mismatched = [
    (img_file, mask_file)
    for img_file, mask_file in zip(image_files, mask_files)
    if os.path.splitext(img_file)[0] != os.path.splitext(mask_file)[0]
]
if mismatched:
    print(f"Warning: {len(mismatched)} image/mask pairs have different names, e.g. {mismatched[0]}")

num_pairs = 0
with open(output_file, 'w') as f:
    # zip stops at the shorter list, so unmatched trailing files are skipped.
    for img_file, mask_file in zip(image_files, mask_files):
        f.write(f"{os.path.join(images_dir, img_file)} {os.path.join(masks_dir, mask_file)}\n")
        num_pairs += 1

# Report the number of lines actually written (not len(image_files), which can
# differ when the folders have different sizes).
print(f"train.txt file has been created with {num_pairs} entries.")


/content/gdrive/MyDrive/polyp_seg/data/CVC-ClinicDB
train.txt file has been created with 488 entries.


# Utilities: weight initialization, learning-rate schedules, and helpers

In [None]:
import numpy as np
from PIL import Image
from thop import profile
from thop import clever_format
from torch import nn
import torch.nn.init as initer
import cv2, os


def initialize_weights(*models):
    """
    Re-initialise the Conv2d / Linear / BatchNorm2d layers of each model in place.

    * ``nn.Conv2d`` and ``nn.Linear``: Kaiming-normal weights, zero bias (if present).
    * ``nn.BatchNorm2d``: weight 1 and bias 0. Layers built with ``affine=False``
      have no weight/bias parameters and are skipped instead of crashing.

    All other module types are left untouched.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                # weight/bias are None when the layer was created with affine=False.
                if module.weight is not None:
                    module.weight.data.fill_(1)
                if module.bias is not None:
                    module.bias.data.zero_()


def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1):
    """Step decay: scale ``base_lr`` by ``multiplier`` once every ``step_epoch`` epochs.

    With the default multiplier of 0.1, the rate is divided by 10 at each step.
    """
    num_decays = epoch // step_epoch
    return base_lr * multiplier ** num_decays


def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    """Poly learning-rate policy: ``base_lr * (1 - curr_iter / max_iter) ** power``.

    The decay factor is clamped at 0. A call with ``curr_iter > max_iter``
    therefore returns 0.0. Without the clamp it would raise a negative base to
    a fractional power, which in Python produces a complex number.
    """
    remaining = max(0.0, 1 - float(curr_iter) / max_iter)
    return base_lr * remaining ** power


def check_mkdir(dir_name):
    """Create the single directory ``dir_name``; a no-op if the path already exists.

    Attempting the creation and catching FileExistsError avoids the
    check-then-create race: another process could create the directory
    between an ``exists`` check and ``mkdir``.
    """
    try:
        os.mkdir(dir_name)
    except FileExistsError:
        pass


def check_makedirs(dir_name):
    """Create ``dir_name`` and any missing parents; a no-op if the path already exists.

    Attempting the creation and catching FileExistsError avoids the
    check-then-create race, where the path appears between the check and the call.
    """
    try:
        os.makedirs(dir_name)
    except FileExistsError:
        pass


def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'):
    """
    Initialise conv, batch-norm, linear and LSTM layers of ``model`` in place.

    :param model: Pytorch Model which is nn.Module
    :param conv:  'kaiming' or 'xavier' (normal-distributed weights); conv biases are set to 0
    :param batchnorm: 'normal' (weight ~ N(1, 0.02)) or 'constant' (weight = 1); bias set to 0.
        Batch-norm layers built with ``affine=False`` have no parameters and are skipped.
    :param linear: 'kaiming' or 'xavier'; linear biases are set to 0
    :param lstm: 'kaiming' or 'xavier' for every LSTM weight; LSTM biases are set to 0
    :raises ValueError: if an unknown init type is requested for a layer kind
        that actually occurs in ``model``.
    """
    normal_inits = {'kaiming': initer.kaiming_normal_, 'xavier': initer.xavier_normal_}

    for m in model.modules():
        if isinstance(m, nn.modules.conv._ConvNd):
            if conv not in normal_inits:
                raise ValueError("init type of conv error.\n")
            normal_inits[conv](m.weight)
            if m.bias is not None:
                initer.constant_(m.bias, 0)

        elif isinstance(m, nn.modules.batchnorm._BatchNorm):
            if batchnorm not in ('normal', 'constant'):
                raise ValueError("init type of batchnorm error.\n")
            # weight/bias are None for affine=False layers.
            if m.weight is not None:
                if batchnorm == 'normal':
                    initer.normal_(m.weight, 1.0, 0.02)
                else:
                    initer.constant_(m.weight, 1.0)
            if m.bias is not None:
                initer.constant_(m.bias, 0.0)

        elif isinstance(m, nn.Linear):
            if linear not in normal_inits:
                raise ValueError("init type of linear error.\n")
            normal_inits[linear](m.weight)
            if m.bias is not None:
                initer.constant_(m.bias, 0)

        elif isinstance(m, nn.LSTM):
            for name, param in m.named_parameters():
                if 'weight' in name:
                    if lstm not in normal_inits:
                        raise ValueError("init type of lstm error.\n")
                    normal_inits[lstm](param)
                elif 'bias' in name:
                    initer.constant_(param, 0)


def group_weight(weight_group, module, lr):
    """Split ``module``'s parameters into two optimizer param groups and append them.

    Group 1 (default weight decay): weights of Linear and conv layers.
    Group 2 (``weight_decay=0``): their biases plus batch-norm weights/biases.
    Both groups use learning rate ``lr``. The assertion guarantees every
    parameter landed in exactly one group (other layer types are not supported).

    Returns the same ``weight_group`` list, extended in place.
    """
    decay_params = []
    no_decay_params = []
    for sub in module.modules():
        if isinstance(sub, (nn.Linear, nn.modules.conv._ConvNd)):
            decay_params.append(sub.weight)
            if sub.bias is not None:
                no_decay_params.append(sub.bias)
        elif isinstance(sub, nn.modules.batchnorm._BatchNorm):
            no_decay_params.extend(p for p in (sub.weight, sub.bias) if p is not None)
    assert len(list(module.parameters())) == len(decay_params) + len(no_decay_params)
    weight_group.append(dict(params=decay_params, lr=lr))
    weight_group.append(dict(params=no_decay_params, weight_decay=.0, lr=lr))
    return weight_group


def colorize(gray, palette):
    """Render an integer label map as a palettised PIL image.

    :param gray: numpy array of labels, cast to uint8 before conversion.
    :param palette: flat list of 3*N RGB values (one triple per label index),
        applied with ``putpalette``.
    :return: PIL image in 'P' mode.
    """
    labels_u8 = gray.astype(np.uint8)
    paletted = Image.fromarray(labels_u8).convert('P')
    paletted.putpalette(palette)
    return paletted


def overlay(img, mask, color=(1.0, 0, 0), alpha=0.4, resize=None):
    """Combines image and its segmentation mask into a single image.

    Pixels where ``mask`` is truthy are painted with ``color`` in a copy of the
    image. That copy is then alpha-blended with the original image.

    Params:
        img: Image in channel-first layout (3, H, W). This layout follows from the
            (3, 1, 1) colour reshape and the transpose(1, 2, 0) below.
        mask: Segmentation mask of shape (H, W); truthy pixels receive ``color``.
        color: One value per channel, presumably in the same value range as ``img``.
        alpha: Weight of the coloured copy in the blend; the original gets 1 - alpha.
        resize: If provided, both the image and its coloured copy are resized before
            blending. It is passed to cv2.resize as ``dsize``, which OpenCV
            interprets as (width, height).

    Returns:
        image_combined: The blended image in channel-last layout (H, W, 3).

    """
    color = np.asarray(color).reshape(3, 1, 1)
    # Replicate the 2-D mask over the three channels so it matches img's shape.
    colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
    # Masked entries are the pixels to recolour; filled() replaces them with color.
    # NOTE(review): relies on np.ma accepting a (3, 1, 1) array as fill_value and
    # broadcasting it over img — confirm against the installed numpy version.
    masked = np.ma.MaskedArray(img, mask=colored_mask, fill_value=color)
    image_overlay = masked.filled()
    # OpenCV works on channel-last (H, W, C) arrays.
    img = img.transpose(1, 2, 0)
    image_overlay = image_overlay.transpose(1, 2, 0)

    if resize is not None:
        img = cv2.resize(img, resize)
        image_overlay = cv2.resize(image_overlay, resize)

    # cv2.addWeighted requires both inputs to share size and dtype.
    image_combined = cv2.addWeighted(img, 1 - alpha, image_overlay, alpha, 0)

    return image_combined


def CalParams(model, input_tensor):
    """
    Print the operation count and parameter count of ``model``.

    Uses THOP (https://github.com/Lyken17/pytorch-OpCounter); requires
    ``from thop import profile`` and ``from thop import clever_format``.

    :param model: network to profile.
    :param input_tensor: a single example input passed to ``model`` while profiling.
    :return: None; the results are printed.

    NOTE(review): THOP's ``profile`` is generally described as counting MACs, while
    the label printed here says FLOPs. Confirm which quantity you want to report.
    """
    op_count, param_count = profile(model, inputs=(input_tensor,))
    op_str, param_str = clever_format([op_count, param_count], "%.3f")
    report = '[Statistics Information]\nFLOPs: {}\nParams: {}'.format(op_str, param_str)
    print(report)

# Build the dataset

In [None]:
import os
import os.path
import cv2
import numpy as np
import torch
from torchvision import transforms as T
from torch.utils.data import Dataset
from scipy.ndimage.morphology import distance_transform_edt
from matplotlib import pyplot as plt
import matplotlib.ticker as ticker


def make_dataset(split='train', data_root=None, data_list=None):
    """Read an image/label list file into a list of ``(image_path, label_path)`` pairs.

    Each non-blank line of ``data_list`` holds whitespace-separated paths:
    ``<image> <label>`` for 'train'/'val', or ``<image>`` alone for 'test'. In the
    'test' case the image path doubles as a placeholder label path. Paths are
    joined onto ``data_root``; when ``data_root`` is None they are used as-is.
    Because lines are split on whitespace, paths must not contain spaces.

    Raises:
        RuntimeError: if the list file is missing, or a line has the wrong number
            of fields for ``split``.
    """
    assert split in ['train', 'val', 'test']
    if data_list is None or not os.path.isfile(data_list):
        raise (RuntimeError("Image list file do not exist: " + str(data_list) + "\n"))

    # Context manager ensures the list file is closed.
    with open(data_list) as f:
        # Ignore blank lines (e.g. a trailing empty line) instead of failing on them.
        lines = [line.strip() for line in f if line.strip()]

    def _join(path):
        return path if data_root is None else os.path.join(data_root, path)

    expected_fields = 1 if split == 'test' else 2
    print("Totally {} samples in {} set.".format(len(lines), split))
    print("Starting Checking image&label pair {} list...".format(split))
    image_label_list = []
    for line in lines:
        line_split = line.split()
        if len(line_split) != expected_fields:
            raise (RuntimeError("Image list file read line error : " + line + "\n"))
        image_name = _join(line_split[0])
        # For 'test' the label entry is only a placeholder and is never used as a label.
        label_name = image_name if split == 'test' else _join(line_split[1])
        image_label_list.append((image_name, label_name))
    print("Checking image&label pair {} list done!".format(split))
    return image_label_list


class MyDataset(Dataset):
    """Polyp segmentation dataset yielding ``(image, mask, ideal_bdm)`` float tensors.

    Samples come from a list file parsed by ``make_dataset``. For each sample:
    1. The RGB image and grayscale mask are read with OpenCV.
    2. The mask is binarised at 128.
    3. The optional ``transform`` is applied jointly to image and mask.
    4. The ideal boundary distribution map (IBDM) is computed from the transformed mask.

    Args:
        split: 'train', 'val' or 'test' (forwarded to ``make_dataset``).
        data_root: root joined to the relative paths in ``data_list``.
        data_list: path of the list file.
        transform: callable invoked as ``transform(image=..., mask=...)`` that returns
            a dict with 'image' and 'mask' keys (albumentations style). None means no
            transform.
        sigma: standard deviation, in pixels, of the Gaussian used for the IBDM.
    """

    def __init__(self, split='train', data_root=None, data_list=None, transform=None, sigma=10):
        self.sigma = sigma
        print('sigma is {}'.format(sigma))
        self.split = split
        self.data_list = make_dataset(split, data_root, data_list)
        self.transform = transform
        # HWC uint8 -> CHW float in [0, 1], then ImageNet mean/std normalisation
        # (presumably to match the ImageNet-pretrained encoder).
        self.as_tensor = T.Compose([
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406],
                        [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return ``(normalised image, binary mask, IBDM)`` for sample ``index``."""
        image_path, label_path = self.data_list[index]
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)  # BGR, H x W x 3
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)  # gray, H x W
        # cv2.imread returns None (instead of raising) for missing or unreadable files.
        # Fail here with the offending path rather than deep inside cvtColor.
        if image is None:
            raise RuntimeError("Failed to read image: " + image_path + "\n")
        if label is None:
            raise RuntimeError("Failed to read label: " + label_path + "\n")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Binarise the annotation: pixels brighter than 128 are foreground.
        label = np.float32(label > 128)

        if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]:
            raise (RuntimeError("Image & label shape mismatch: " + image_path + " " + label_path + "\n"))

        if self.transform is not None:
            augmented = self.transform(image=image, mask=label)
            image, label = augmented['image'], augmented['mask']

        ibdm = self.distribution_map(label, sigma=self.sigma)

        return self.as_tensor(image), \
               torch.tensor(label, dtype=torch.float), \
               torch.tensor(ibdm, dtype=torch.float)

    def distribution_map(self, mask, sigma):
        """Ideal boundary distribution map: a Gaussian of each pixel's distance to the mask boundary.

        Returns ``sigma**2 * N(dist; 0, sigma)``, i.e.
        ``sigma / sqrt(2*pi) * exp(-dist**2 / (2*sigma**2))``. The map peaks on the
        boundary and decays with distance from it.
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
        # Morphological opening removes isolated noise pixels from the annotation.
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # Distance to the boundary for every pixel. Foreground pixels get their distance
        # to the nearest background pixel; background pixels get their distance to the
        # nearest foreground pixel. Subtracting 1 puts the pixels touching the boundary
        # (on either side) at 0.
        dist = distance_transform_edt(mask) + distance_transform_edt(1 - mask) - 1

        # The Gaussian is strictly non-negative, so no clamping is needed.
        gaussian = 1 / (np.sqrt(2 * np.pi) * sigma) * np.exp(-dist ** 2 / (2 * sigma ** 2))
        return gaussian * (sigma ** 2)




  from scipy.ndimage.morphology import distance_transform_edt


# Define metrics (Dice)

In [None]:
# from pytorch_lightning.metrics import Metric
import numpy as np
import torch


def _prepare_dice_inputs(preds: torch.Tensor, target: torch.Tensor, if_sigmoid: bool):
    """Shared preprocessing for the Dice metrics: align shapes, optional sigmoid, flatten.

    A (B, 1, H, W) ``preds`` is squeezed to match a (B, H, W) ``target``. The target is
    binarised with ``> 0``.
    """
    if preds.shape != target.shape:
        preds = preds.squeeze(1)

    assert preds.shape == target.shape

    # .float() is a no-op for float32 and upcasts half-precision (AMP) outputs.
    preds = preds.float()

    if if_sigmoid:
        preds = preds.sigmoid()

    return preds.reshape(-1), (target > 0).float().reshape(-1)


@torch.no_grad()
def dice(preds: torch.Tensor, target: torch.Tensor, th=0.5, if_sigmoid=True):
    """Dice coefficient between ``preds > th`` and the binarised target, as a Python float.

    :param preds: logits (or probabilities when ``if_sigmoid=False``).
    :param target: ground truth; values > 0 count as foreground.
    :param th: probability threshold.
    """
    preds, target = _prepare_dice_inputs(preds, target, if_sigmoid)

    p = (preds > th).float()
    inter = (p * target).sum().item()
    union = (p + target).sum().item()
    return 2.0 * inter / (union + 1e-6)


@torch.no_grad()
def mean_dice(preds: torch.Tensor, target: torch.Tensor, if_sigmoid=True):
    """Dice averaged over the thresholds ``np.arange(0, 1 + 1/255, 1/255)``.

    All per-threshold sums stay on the device and are transferred in one go. The
    original loop synchronised twice per threshold with ``.item()``. The summation
    order and float values are identical, so the result is unchanged.
    """
    preds, target = _prepare_dice_inputs(preds, target, if_sigmoid)

    thresholds = np.arange(0, 1 + 1 / 255, 1 / 255)
    inters, unions = [], []
    for th in thresholds:
        p = (preds > th).float()
        inters.append((p * target).sum())
        unions.append((p + target).sum())

    # Single device->host transfer for all thresholds.
    inter_vals = torch.stack(inters).tolist()
    union_vals = torch.stack(unions).tolist()

    mdice = 0
    for inter, union in zip(inter_vals, union_vals):
        mdice += 2.0 * inter / (union + 1e-6)

    return mdice / len(thresholds)

# Model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
# from metric.dice import mean_dice
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.heads import SegmentationHead
from torchview import draw_graph
# from utils.util import initialize_weights
import cv2, os
# from utils.util import overlay


def seg_loss(pred, mask):
    """Boundary-weighted BCE plus weighted IoU loss.

    Args:
        pred: raw logits, presumably (B, 1, H, W) — TODO confirm.
        mask: binary ground truth with the same shape as ``pred``.

    Returns:
        Scalar tensor: the batch mean of (weighted BCE + weighted IoU).
    """
    # |local mean - label| is large near the mask boundary, so pixels there get
    # weights up to 6; flat regions get weights close to 1.
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    # reduction='none' keeps the per-pixel BCE so the weights above actually apply.
    # Do not use the deprecated ``reduce`` kwarg here: a truthy value such as
    # 'none' makes it fall back to a scalar mean, silently disabling the weighting.
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)
    return (wbce + wiou).mean()


def bdm_loss(pred, target, thresh=0.002, min_ratio=0.1):
    """Hard-example-mined MSE between the predicted and ideal boundary distribution maps.

    Per-pixel squared errors below ``thresh`` are zeroed before averaging, so easy
    pixels contribute nothing. If that would leave fewer than k pixels, where
    k = clamp(round(min_ratio * N), 1, N), the threshold is lowered to the k-th
    largest error so the k hardest pixels always count. The mean is still taken
    over all N pixels.

    Args:
        pred: predicted map, any shape (flattened internally).
        target: ideal map with the same number of elements as ``pred``.
        thresh: error below which a pixel is ignored.
        min_ratio: minimum fraction of pixels that must contribute.

    Returns:
        Scalar tensor.
    """
    # reshape (not view) also accepts non-contiguous inputs.
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    loss = F.mse_loss(pred, target, reduction='none')
    num_pixels = loss.numel()

    # Number of hardest pixels that must always contribute, clamped to [1, N].
    # A bare round() gives 0 for tiny inputs, and index[-0] would then select the
    # *easiest* pixel rather than a hard one.
    num_kept = min(num_pixels, max(1, round(min_ratio * num_pixels)))
    # k-th largest error == (N - k + 1)-th smallest; kthvalue needs no full sort permutation.
    kth_largest = loss.detach().kthvalue(num_pixels - num_kept + 1).values.item()

    # Lower the threshold when needed so at least num_kept pixels survive.
    thresh = min(thresh, kth_largest)

    loss = torch.where(loss < thresh, torch.zeros_like(loss), loss)
    return loss.mean()


class Conv2dReLU(nn.Sequential):
    """Conv2d -> (BatchNorm2d | Identity) -> ReLU(inplace), packaged as an ``nn.Sequential``.

    The convolution bias is disabled whenever batch norm follows it, since the BN
    shift makes a conv bias redundant.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            use_batchnorm=True,
    ):
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=not use_batchnorm,
            )
        ]
        layers.append(nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity())
        layers.append(nn.ReLU(inplace=True))
        super().__init__(*layers)


class RFB_modified(nn.Module):
    """Receptive-field block: four parallel branches fused with a residual projection.

    * branch0: 1x1 conv.
    * branch{1,2,3}, for k = 3, 5, 7: 1x1 conv -> 1xk conv -> kx1 conv -> BatchNorm ->
      3x3 conv with dilation k.
    The branch outputs are concatenated and fused by a 3x3 conv. The result is added
    to a 1x1 projection of the input and passed through ReLU. Spatial size is preserved.
    """

    def __init__(self, in_channel, out_channel):
        super(RFB_modified, self).__init__()
        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1),
        )
        self.branch1 = self._dilated_branch(in_channel, out_channel, 3)
        self.branch2 = self._dilated_branch(in_channel, out_channel, 5)
        self.branch3 = self._dilated_branch(in_channel, out_channel, 7)
        self.conv_cat = nn.Conv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = nn.Conv2d(in_channel, out_channel, 1)

    @staticmethod
    def _dilated_branch(in_channel, out_channel, k):
        """Separable 1xk / kx1 convs followed by a 3x3 conv with dilation ``k``."""
        half = k // 2
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1),
            nn.Conv2d(out_channel, out_channel, kernel_size=(1, k), padding=(0, half)),
            nn.Conv2d(out_channel, out_channel, kernel_size=(k, 1), padding=(half, 0)),
            nn.BatchNorm2d(out_channel),
            nn.Conv2d(out_channel, out_channel, 3, padding=k, dilation=k),
        )

    def forward(self, x):
        branch_outputs = [branch(x) for branch in (self.branch0, self.branch1, self.branch2, self.branch3)]
        fused = self.conv_cat(torch.cat(branch_outputs, 1))
        return self.relu(fused + self.conv_res(x))


class Agg(nn.Module):
    """Two-stage fusion of a high-resolution map ``h`` with a half-resolution map ``l``.

    Stage 1 exchanges information in both directions: ``h`` is average-pooled down and
    ``l`` is nearest-upsampled up. Stage 2 merges everything back at ``h``'s resolution.
    Every path is a 3x3 Conv2dReLU with ``channel`` input and output channels.
    Assumes ``h`` is exactly twice the spatial size of ``l``.
    """

    def __init__(self, channel=64):
        super(Agg, self).__init__()
        self.h2l_pool = nn.AvgPool2d((2, 2), stride=2)
        self.l2h_up = nn.Upsample(scale_factor=2, mode="nearest")

        def conv_block():
            return nn.Sequential(Conv2dReLU(channel, channel, 3, 1, 1))

        # stage 1
        self.h2h_1 = conv_block()
        self.h2l_1 = conv_block()
        self.l2h_1 = conv_block()
        self.l2l_1 = conv_block()

        # stage 2
        self.h2h_2 = conv_block()
        self.l2h_2 = conv_block()

    def forward(self, h, l):
        # stage 1: cross-resolution exchange
        high = self.h2h_1(h) + self.l2h_1(self.l2h_up(l))
        low = self.l2l_1(l) + self.h2l_1(self.h2l_pool(h))

        # stage 2: merge at the high resolution
        return self.h2h_2(high) + self.l2h_2(self.l2h_up(low))


class BDMM(nn.Module):
    """Boundary distribution map module: predicts a one-channel map from three feature levels.

    Each level is reduced to ``midplanes`` channels by an RFB block. The coarser levels
    are then fused into the finer ones with two Agg blocks (x3 into x2, then x2 into x1).
    Finally the result is projected to one channel by a Conv2dReLU, so the map is
    non-negative, and bilinearly upsampled by ``upsample``.
    Assumes each input level is twice the spatial size of the next one.
    """

    def __init__(self, inplanes: list, midplanes=32, upsample=8):
        super(BDMM, self).__init__()
        assert len(inplanes) == 3

        reducers = [RFB_modified(channels, midplanes) for channels in inplanes]
        self.rfb1, self.rfb2, self.rfb3 = reducers

        # Not used by forward(). These modules hold no parameters and are kept so
        # the module tree stays unchanged.
        self.down = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear')

        self.agg1, self.agg2 = Agg(midplanes), Agg(midplanes)

        self.conv_out = nn.Sequential(
            Conv2dReLU(midplanes, 1, 3, padding=1),
            nn.Upsample(scale_factor=upsample, mode='bilinear', align_corners=True),
        )

    def forward(self, x1, x2, x3):
        f1, f2, f3 = self.rfb1(x1), self.rfb2(x2), self.rfb3(x3)
        f2 = self.agg1(f2, f3)
        f1 = self.agg2(f1, f2)
        return self.conv_out(f1)


class BDGD_A(nn.Module):
    """BDM-guided decoder stage without a skip input; output is twice the input resolution.

    ``x`` is first mapped to ``out_channels`` by a depthwise-separable conv. A branch
    modulated by the boundary map (resized to ``x``) refines it. Both the refined and
    the plain features are nearest-upsampled x2 and fused through Conv2dReLU blocks.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.l2h_up = nn.Upsample(scale_factor=2, mode="nearest")

        # stage 1: depthwise 3x3 then pointwise 1x1 (in_channels -> out_channels)
        self.l2l_0 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 1, 1, groups=in_channels),
            nn.Conv2d(in_channels, out_channels, 1),
        )

        def conv_block():
            return nn.Sequential(Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1))

        # stage 2
        self.l2h_1 = conv_block()
        self.l2l_1 = conv_block()

        # stage 3
        self.l2h_2 = conv_block()

    def forward(self, x, dist):
        # Resize the boundary map to x's spatial size to use it as a multiplicative guide.
        guide = F.interpolate(dist, x.size()[2:], mode='bilinear')

        base = self.l2l_0(x)
        refined = self.l2l_1(base * guide)
        upsampled = self.l2h_1(self.l2h_up(base + refined))

        return self.l2h_2(self.l2h_up(base) + upsampled)


class BDGD_B(nn.Module):
    """BDM-guided decoder stage that fuses a low-resolution input with a skip feature.

    Output has ``out_channels`` channels at the skip feature's resolution. The stage
    assumes ``skip`` is twice the spatial size of ``x``, because it uses a x2
    upsample / 2x2 avg-pool exchange.

    Args:
        in_channels: channels of the low-resolution input ``x``.
        skip_channels: channels of the high-resolution encoder feature ``skip``.
        out_channels: channels of every internal path and of the output.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()

        self.h2l_pool = nn.AvgPool2d((2, 2), stride=2)
        self.l2h_up = nn.Upsample(scale_factor=2, mode="nearest")

        # stage 1: depthwise 3x3 + pointwise 1x1 projections to out_channels
        self.h2h_0 = nn.Sequential(
            nn.Conv2d(skip_channels, skip_channels, 3, 1, 1, groups=skip_channels),
            nn.Conv2d(skip_channels, out_channels, 1),
        )

        self.l2l_0 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 1, 1, groups=in_channels),
            nn.Conv2d(in_channels, out_channels, 1),
        )

        # stage 2: high<->low exchange (h = high resolution, l = low resolution)
        self.h2h_1 = nn.Sequential(
            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1),
        )
        self.h2l_1 = nn.Sequential(
            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1),
        )
        self.l2h_1 = nn.Sequential(
            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1),
        )
        self.l2l_1 = nn.Sequential(
            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1),
        )

        # stage 3: merge back at high resolution
        self.h2h_2 = nn.Sequential(
            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1),
        )
        self.l2h_2 = nn.Sequential(
            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x, skip, dist):
        # Resize the boundary map to both resolutions; it gates the same-resolution paths.
        dist_h = F.interpolate(dist, skip.size()[2:], mode='bilinear')
        dist_l = F.interpolate(dist, x.size()[2:], mode='bilinear')

        # stage 1
        h_in = self.h2h_0(skip)
        l_in = self.l2l_0(x)

        # stage 2
        h2h = self.h2h_1(h_in * dist_h)
        l2h = self.l2h_1(self.l2h_up(l_in))

        l2l = self.l2l_1(l_in * dist_l)
        h2l = self.h2l_1(self.h2l_pool(h_in))

        h = h2h + l2h
        l = l2l + h2l

        # stage 3
        h2h = self.h2h_2(h)
        # Note: the stage-2 upsampled path (l2h) is added again here as a residual.
        l2h = self.l2h_2(self.l2h_up(l)) + l2h
        out = h2h + l2h
        return out


class BDM_Net(pl.LightningModule):
    """BDG-Net style polyp segmentation model wrapped as a LightningModule.

    Components (wired together in ``forward``):
      * ``encoder``: timm EfficientNet-B5 (noisy-student weights) from
        segmentation_models_pytorch; returns a list of multi-scale features.
      * ``agg`` (BDMM): predicts a one-channel boundary distribution map (BDM)
        from the three deepest features.
      * ``dec4`` .. ``dec1``: BDM-guided decoder stages, coarse to fine.
      * ``seg_head``: projects the last decoder output to ``nclass`` logit maps.

    ``forward`` returns ``(seg_logits, bdm)``. Training minimises
    ``seg_loss(seg_logits, mask) + bdm_loss(bdm, ideal_bdm)``.

    Args:
        nclass: number of output segmentation channels.
        max_epoch: horizon of the poly LR schedule. When None, the attached
            Trainer's ``max_epochs`` is used.
    """

    def __init__(self, nclass=1, max_epoch=None):
        super().__init__()
        self.encoder = get_encoder('timm-efficientnet-b5', weights='noisy-student')
        self.agg = BDMM(self.encoder.out_channels[-3:], 32, upsample=8)

        self.dec1 = BDGD_A(64, 32)
        self.dec2 = BDGD_B(128, self.encoder.out_channels[-4], 64)
        self.dec3 = BDGD_B(256, self.encoder.out_channels[-3], 128)
        self.dec4 = BDGD_B(self.encoder.out_channels[-1], self.encoder.out_channels[-2], 256)

        self.seg_head = SegmentationHead(32, nclass, upsampling=2)

        self.learning_rate = 1e-4
        self.max_epoch = max_epoch

        # Re-initialise the newly added modules; the encoder keeps its pretrained weights.
        for module in (self.dec1, self.dec2, self.dec3, self.dec4, self.seg_head, self.agg):
            initialize_weights(module)

        # Running index used to name the figures written by test_step.
        self.num = 0

    def forward(self, x):
        x = self.encoder(x)
        bdm = self.agg(x[-3], x[-2], x[-1])
        c4 = self.dec4(x[-1], x[-2], bdm)
        c3 = self.dec3(c4, x[-3], bdm)
        c2 = self.dec2(c3, x[-4], bdm)
        c1 = self.dec1(c2, bdm)
        seg = self.seg_head(c1)

        return seg, bdm

    def _forward_losses(self, batch):
        """Run the model on an ``(image, mask, ideal_bdm)`` batch; return logits, mask and both losses."""
        x, y, ibdm = batch
        y_hat, bdm = self(x)
        loss_seg = seg_loss(y_hat, y.unsqueeze(1))
        loss_bdm = bdm_loss(bdm.squeeze(1), ibdm)
        return y_hat, y, loss_seg, loss_bdm

    def training_step(self, batch, batch_idx):
        y_hat, y, train_loss_seg, train_loss_bdm = self._forward_losses(batch)
        train_mean_dice = mean_dice(y_hat, y)

        self.log('train_loss_seg', train_loss_seg, on_epoch=True)
        self.log('train_loss_bdm', train_loss_bdm, on_epoch=True)
        self.log('train_mean_dice', train_mean_dice, on_epoch=True)

        return train_loss_seg + train_loss_bdm

    def validation_step(self, batch, batch_idx):
        y_hat, y, val_loss_seg, val_loss_bdm = self._forward_losses(batch)
        val_mean_dice = mean_dice(y_hat, y)

        self.log('val_loss_seg', val_loss_seg)
        self.log('val_loss_bdm', val_loss_bdm)
        self.log('val_mean_dice', val_mean_dice)

        return val_loss_seg + val_loss_bdm

    def test_step(self, batch, batch_idx):
        """Log test mean Dice and save a 2x3 diagnostic figure per sample to ./save/<n>.png."""
        x, y, ibdm = batch
        y_hat, bdm = self(x)
        test_loss_seg = seg_loss(y_hat, y.unsqueeze(1))
        test_mean_dice = mean_dice(y_hat, y)

        TH = 0.5  # probability threshold for the overlay panel
        save_path = './save'
        os.makedirs(save_path, exist_ok=True)
        # ImageNet statistics used by the dataset's normalisation, shaped (3, 1, 1).
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device, dtype=x.dtype).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device, dtype=x.dtype).view(3, 1, 1)

        for i in range(y_hat.size(0)):
            # De-normalise out of place so the batch tensor ``x`` is not mutated.
            img = (x[i] * std + mean).cpu().numpy().transpose(1, 2, 0).astype(np.float32)
            seg = y_hat[i, 0].sigmoid().cpu().numpy().astype(np.float32)
            seg_gt = y[i].cpu().numpy().astype(np.float32)
            dh = bdm[i, 0].cpu().numpy().astype(np.float32)
            dh_gt = ibdm[i].cpu().numpy().astype(np.float32)

            panels = [
                ('image', img, None),
                ('seg', seg, plt.cm.gray),
                ('ground truth', seg_gt, plt.cm.gray),
                ('overlay', overlay(img.transpose(2, 0, 1), (seg > TH)), None),
                ('bdm', dh, plt.cm.jet),
                ('ideal bdm', dh_gt, plt.cm.jet),
            ]
            fig, axes = plt.subplots(2, 3)
            for ax, (title, data, cmap) in zip(axes.flat, panels):
                ax.set_title(title)
                ax.imshow(data, cmap=cmap)
                # Hide axis ticks.
                ax.set_xticks([])
                ax.set_yticks([])

            fig.savefig(os.path.join(save_path, '{}.png'.format(self.num)), dpi=400)
            # Close explicitly so figures don't accumulate across test batches.
            plt.close(fig)
            self.num += 1

        self.log('test_mean_dice', test_mean_dice)

        return test_loss_seg

    def configure_optimizers(self):
        """AdamW with a per-epoch poly decay ``(1 - epoch / max_epoch) ** 0.9``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

        # Fall back to the Trainer's horizon when the model was built without
        # max_epoch (e.g. BDM_Net(nclass=1)); otherwise the lambda would divide by None.
        max_epoch = self.max_epoch if self.max_epoch is not None else self.trainer.max_epochs
        if max_epoch is None or max_epoch <= 0:
            raise ValueError("BDM_Net needs a positive max_epoch for the poly LR schedule")

        def poly_learning_rate(epoch):
            # Clamp at 0: past the horizon, a negative base ** 0.9 would be complex.
            return max(0.0, 1 - float(epoch) / max_epoch) ** 0.9

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, poly_learning_rate)
        return [optimizer], [scheduler]


# Build the network once to inspect it: draw its computation graph and report
# THOP's operation and parameter counts for a 352x352 input.
# In a notebook ``__name__`` is always '__main__', so no guard is needed. The graph
# must be the cell's final *top-level* expression for Jupyter to render it; inside an
# ``if`` block its value is silently discarded.
model = BDM_Net(nclass=1)
# torchview needs the full input shape, including the batch dimension.
model_graph = draw_graph(model, input_size=(1, 3, 352, 352), expand_nested=True)
CalParams(model, torch.rand(1, 3, 352, 352))
model_graph.visual_graph

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[INFO] Register count_avgpool() for <class 'torch.nn.modules.pooling.AvgPool2d'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.UpsamplingBilinear2d'>.
[Statistics Information]
FLOPs: 10.739G
Params: 32.561M


# Train model

In [None]:
import torch
import gc


# Free memory left over from earlier runs in this kernel. First collect unreachable
# Python objects, which drops their tensor references. Then release PyTorch's
# now-unused cached CUDA blocks. The reverse order would leave memory freed by
# gc.collect() cached.
gc.collect()
torch.cuda.empty_cache()

import random
import os
import numpy as np
import torch
import argparse
import albumentations as A
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from pytorch_lightning.trainer import Trainer
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
# from utils import dataset
# from BDM_Net import BDM_Net

# Restrict this process to the listed GPU ids via CUDA_VISIBLE_DEVICES.
# setdefault leaves an already-set CUDA_VISIBLE_DEVICES untouched.
# NOTE(review): the variable only takes effect if it is set before CUDA is first
# initialised in this process. Earlier GPU work in the same kernel may already
# have done that — confirm.
gpu_list = [0]
gpu_list_str = ','.join(map(str, gpu_list))
os.environ.setdefault("CUDA_VISIBLE_DEVICES", gpu_list_str)


# parser = argparse.ArgumentParser(description='BDM-Net')
# parser.add_argument('--sigma', '-s', type=float, default=5)
# args = parser.parse_args()

def _init_fn(worker_id, seed=42):
    random.seed(seed + worker_id)


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs with ``seed``.

    PYTHONHASHSEED is exported as well. Note it only affects Python interpreters
    launched afterwards; the current interpreter's hash seed is fixed at startup.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def data_loader(batch_size=20, num_workers=4, data_root='./', img_size=352, sigma=5):
    """Build the training and validation DataLoaders.

    Reads ``train.txt`` and ``val.txt`` from ``data_root``. Each line holds one
    ``<image> <mask>`` pair, with paths relative to ``data_root``; both files
    must already exist.

    Args:
        batch_size: training batch size (validation always uses batch size 1).
        num_workers: worker processes per loader.
        data_root: directory holding the list files; also the root the listed paths are joined to.
        img_size: side length images and masks are cropped/resized to.
        sigma: Gaussian width passed to MyDataset for the ideal boundary distribution map.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_list = os.path.join(data_root, 'train.txt')
    val_list = os.path.join(data_root, 'val.txt')

    # Training augmentation: random-scale crop plus horizontal/vertical flips.
    # NOTE(review): newer albumentations releases take ``size=(h, w)`` instead of
    # positional height/width here — confirm against the installed version.
    train_trfm = A.Compose([
        A.RandomResizedCrop(img_size, img_size, scale=(0.75, 1)),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ])

    val_trfm = A.Resize(img_size, img_size)

    train_data = MyDataset(split='train', data_root=data_root, data_list=train_list, transform=train_trfm, sigma=sigma)
    val_data = MyDataset(split='val', data_root=data_root, data_list=val_list, transform=val_trfm, sigma=sigma)

    # drop_last discards the final incomplete training batch.
    train_loader = DataLoader(train_data, shuffle=True, drop_last=True, batch_size=batch_size,
                              num_workers=num_workers, pin_memory=True, worker_init_fn=_init_fn)
    val_loader = DataLoader(val_data, batch_size=1, num_workers=num_workers, pin_memory=True, worker_init_fn=_init_fn)

    return train_loader, val_loader


def model_init(encoder_idx=0, path=None, max_epochs=100):
    """Create a BDM_Net, optionally restoring weights from a Lightning checkpoint.

    Args:
        encoder_idx: unused; kept so existing calls keep working.
        path: optional checkpoint file, e.g.
            './logs/default/version_0/checkpoints/BDM-epoch=98-val_mean_dice=0.9050.ckpt'.
            Its ``'state_dict'`` entry is loaded with strict key matching.
        max_epochs: number of training epochs; also the poly-LR horizon of the model.

    Returns:
        ``(model, max_epochs)``.
    """
    model = BDM_Net(nclass=1, max_epoch=max_epochs)
    if path:
        # torch.load unpickles the file, which can execute arbitrary code —
        # only load checkpoints from trusted sources.
        pretrained_dict = torch.load(path, map_location='cpu')['state_dict']
        model.load_state_dict(pretrained_dict)

    return model, max_epochs


def train_process(model, train_loader, val_loader, max_epochs):
    """Fit ``model`` on a single GPU with mixed precision and TensorBoard logging.

    Logs to ``logs/``, records the learning rate every epoch, and keeps the 5 best
    weight-only checkpoints ranked by ``val_mean_dice``.

    Returns:
        The fitted ``Trainer``, so callers can run e.g.
        ``trainer.test(model, dataloaders=val_loader)``.
    """
    tb_logger = pl_loggers.TensorBoardLogger('logs/')
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    checkpoint_callback = ModelCheckpoint(monitor='val_mean_dice',
                                          filename='BDM-{epoch:02d}-{val_mean_dice:.4f}',
                                          save_top_k=5,
                                          mode='max',
                                          save_weights_only=True)

    # Lightning warns and skips step-level logging when an epoch has fewer batches
    # than log_every_n_steps (default 50), so clamp it to the epoch length.
    log_every_n_steps = max(1, min(50, len(train_loader)))

    trainer = Trainer(max_epochs=max_epochs, logger=tb_logger, accelerator='gpu', devices=[0, ],
                      precision=16, check_val_every_n_epoch=1, benchmark=True,
                      log_every_n_steps=log_every_n_steps,
                      callbacks=[lr_monitor, checkpoint_callback])  # single GPU

    trainer.fit(model, train_loader, val_loader)
    return trainer
    # trainer.test(model, test_dataloaders=val_loader)


def main():
    """Run the whole pipeline: seed RNGs, build the data loaders and model, then train."""
    seed_everything(seed=42)
    loaders = data_loader()
    model, max_epochs = model_init()
    train_process(model, *loaders, max_epochs)


if __name__ == '__main__':
    main()

sigma is 5
Totally 488 samples in train set.
Starting Checking image&label pair train list...
Checking image&label pair train list done!
sigma is 5
Totally 62 samples in val set.
Starting Checking image&label pair val list...
Checking image&label pair val list done!


INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type                | Params | Mode 
---------------------------------------------------------
0 | encoder  | EfficientNetEncoder | 28.3 M | train
1 | agg      | BDMM                | 519 K  | train
2 | dec1     | BDGD_A              | 30.6 K | train
3 | dec2     | BDGD_B              | 234 K  | train
4 | dec3     | BDGD_B              | 930 K  | train
5 | dec4     | BDGD_B              | 3.7 M  | train
6 | seg_head | SegmentationHead    | 289    | train
---------------------------------------------------------
33.8 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (24) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

# References

* Paper: https://arxiv.org/pdf/2201.00767
* Github: https://github.com/zihuanqiu/BDG-Net
