# Settings

In [None]:
# install packages
!pip install -r requirements.txt --user --upgrade --quiet -U
# !apt updates

In [1]:
# !pip install fil_finder astropy==4.3.1 typing 

import numpy as np
import cv2

import matplotlib.pyplot as plt
import monai
import torch
    
from fil_finder import FilFinder2D
import astropy.units as u

In [35]:

import torch.nn as nn

from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks import Convolution, UpSample
from monai.networks.layers.factories import Conv, Pool
from monai.utils import deprecated_arg, ensure_tuple_rep

import torch
from torch import nn
from torch.nn import functional as F

class NLBlockND(nn.Module):
    """Non-local block (Wang et al., "Non-local Neural Networks") for 1-D/2-D/3-D inputs.

    Pairwise affinities between all positions are used to aggregate ``g(x)``; the result
    is projected back to ``in_channels`` by ``W_z`` and added to the input (residual).
    ``W_z`` is zero-initialised, so a freshly built block returns its input unchanged.
    The subsampling trick of the paper is not implemented.
    """

    def __init__(self, in_channels, inter_channels=None, mode='embedded',
                 dimension=3, bn_layer=True):
        """
        Args:
            in_channels: number of input (and output) channels (1024 in the paper).
            inter_channels: channel size inside the block; if None it is
                ``in_channels // 2`` but at least 1 (512 in the paper).
            mode: pairwise function: ``'gaussian'``, ``'embedded'`` (embedded Gaussian),
                ``'dot'`` (dot product) or ``'concatenate'``.
            dimension: 1 (temporal), 2 (spatial) or 3 (spatiotemporal).
            bn_layer: if True, ``W_z`` is a 1x1 conv followed by batch norm.

        Raises:
            ValueError: if ``dimension`` or ``mode`` is not supported.
        """
        super(NLBlockND, self).__init__()

        # Explicit exceptions (not assert) so validation survives `python -O`.
        if dimension not in (1, 2, 3):
            raise ValueError('`dimension` must be 1, 2 or 3')
        if mode not in ('gaussian', 'embedded', 'dot', 'concatenate'):
            raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`')

        self.mode = mode
        self.dimension = dimension
        self.in_channels = in_channels
        # Channel size is halved inside the block unless given explicitly (never below 1).
        self.inter_channels = inter_channels if inter_channels is not None else max(in_channels // 2, 1)

        conv_nd, bn = {
            1: (nn.Conv1d, nn.BatchNorm1d),
            2: (nn.Conv2d, nn.BatchNorm2d),
            3: (nn.Conv3d, nn.BatchNorm3d),
        }[dimension]

        # Function g in the paper: 1x1 conv to inter_channels.
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

        if bn_layer:
            self.W_z = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
                bn(self.in_channels),
            )
            # Section 4.1: zero BN affine params make the initial block an identity mapping.
            nn.init.constant_(self.W_z[1].weight, 0)
            nn.init.constant_(self.W_z[1].bias, 0)
        else:
            self.W_z = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1)
            # Section 3.3: zero W_z lets the block be inserted into any pretrained network.
            nn.init.constant_(self.W_z.weight, 0)
            nn.init.constant_(self.W_z.bias, 0)

        # theta and phi are used by every mode except gaussian.
        if self.mode in ('embedded', 'dot', 'concatenate'):
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

        if self.mode == 'concatenate':
            # Always 2-D: it runs on the (N, 2*inter, P, P) pairwise tensor built in forward.
            self.W_f = nn.Sequential(
                nn.Conv2d(in_channels=self.inter_channels * 2, out_channels=1, kernel_size=1),
                nn.ReLU(),
            )

    def forward(self, x):
        """
        Args:
            x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1.
                May be non-contiguous.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size = x.size(0)

        # (N, P, inter) with P = number of positions; follows spacetime_nonlocal in the Caffe2 code.
        g_x = self.g(x).reshape(batch_size, self.inter_channels, -1).permute(0, 2, 1)

        if self.mode == 'gaussian':
            # reshape (not view): x itself may be non-contiguous, e.g. after a transpose.
            flat_x = x.reshape(batch_size, self.in_channels, -1)
            f = torch.matmul(flat_x.permute(0, 2, 1), flat_x)
        elif self.mode in ('embedded', 'dot'):
            theta_x = self.theta(x).reshape(batch_size, self.inter_channels, -1).permute(0, 2, 1)
            phi_x = self.phi(x).reshape(batch_size, self.inter_channels, -1)
            f = torch.matmul(theta_x, phi_x)
        else:  # concatenate
            theta_x = self.theta(x).reshape(batch_size, self.inter_channels, -1, 1)
            phi_x = self.phi(x).reshape(batch_size, self.inter_channels, 1, -1)
            h = theta_x.size(2)
            w = phi_x.size(3)
            concat = torch.cat([theta_x.repeat(1, 1, 1, w), phi_x.repeat(1, 1, h, 1)], dim=1)
            f = self.W_f(concat)
            f = f.view(f.size(0), f.size(2), f.size(3))

        if self.mode in ('gaussian', 'embedded'):
            f_div_C = F.softmax(f, dim=-1)
        else:  # dot / concatenate: normalise by the number of positions
            f_div_C = f / f.size(-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])

        # Residual connection.
        return self.W_z(y) + x


# if __name__ == '__main__':
#     import torch

#     for bn_layer in [True, False]:
#         img = torch.zeros(2, 3, 20)
#         net = NLBlockND(in_channels=3, mode='concatenate', dimension=1, bn_layer=bn_layer)
#         out = net(img)
#         print(out.size())
        
class TwoConv(nn.Sequential):
    """Sequential container holding two MONAI ``Convolution`` blocks (``conv_0`` then ``conv_1``)."""

    @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.")
    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        out_chns: int,
        act: Union[str, tuple],
        norm: Union[str, tuple],
        bias: bool,
        dropout: Union[float, tuple] = 0.0,
        dim: Optional[int] = None,
    ):
        """
        Args:
            spatial_dims: how many spatial axes the input has.
            in_chns: channels entering ``conv_0``.
            out_chns: channels produced by both convolutions.
            act: activation name or ``(name, kwargs)`` tuple.
            norm: normalisation name or ``(name, kwargs)`` tuple.
            bias: add a bias term to each convolution.
            dropout: dropout ratio; 0.0 disables dropout.

        .. deprecated:: 0.6.0
            ``dim`` is deprecated, use ``spatial_dims`` instead.
        """
        super().__init__()

        ndim = spatial_dims if dim is None else dim
        # Both convolutions share every option except their input width; padding=1 keeps size.
        shared = dict(act=act, norm=norm, dropout=dropout, bias=bias, padding=1)
        for layer_name, chns_in in (("conv_0", in_chns), ("conv_1", out_chns)):
            self.add_module(layer_name, Convolution(ndim, chns_in, out_chns, **shared))


class Down(nn.Sequential):
    """Encoder step: 2x max-pooling (``max_pooling``) followed by a ``TwoConv`` (``convs``)."""

    @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.")
    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        out_chns: int,
        act: Union[str, tuple],
        norm: Union[str, tuple],
        bias: bool,
        dropout: Union[float, tuple] = 0.0,
        dim: Optional[int] = None,
    ):
        """
        Args:
            spatial_dims: how many spatial axes the input has.
            in_chns: channels entering the block.
            out_chns: channels leaving the block.
            act: activation name or ``(name, kwargs)`` tuple.
            norm: normalisation name or ``(name, kwargs)`` tuple.
            bias: add a bias term to the convolutions.
            dropout: dropout ratio; 0.0 disables dropout.

        .. deprecated:: 0.6.0
            ``dim`` is deprecated, use ``spatial_dims`` instead.
        """
        super().__init__()
        ndim = spatial_dims if dim is None else dim
        self.add_module("max_pooling", Pool["MAX", ndim](kernel_size=2))
        self.add_module("convs", TwoConv(ndim, in_chns, out_chns, act, norm, bias, dropout))

class UpCat(nn.Module):
    """Decoder step: upsample, concatenate with the encoder skip features, then ``TwoConv``."""

    @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.")
    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        cat_chns: int,
        out_chns: int,
        act: Union[str, tuple],
        norm: Union[str, tuple],
        bias: bool,
        dropout: Union[float, tuple] = 0.0,
        upsample: str = "deconv",
        pre_conv: Optional[Union[nn.Module, str]] = "default",
        interp_mode: str = "linear",
        align_corners: Optional[bool] = True,
        halves: bool = True,
        dim: Optional[int] = None,
    ):
        """
        Args:
            spatial_dims: how many spatial axes the input has.
            in_chns: channels of the tensor to be upsampled.
            cat_chns: channels of the skip tensor concatenated after upsampling.
            out_chns: channels produced by the block.
            act: activation name or ``(name, kwargs)`` tuple.
            norm: normalisation name or ``(name, kwargs)`` tuple.
            bias: add a bias term to the convolutions.
            dropout: dropout ratio; 0.0 disables dropout.
            upsample: ``"deconv"``, ``"pixelshuffle"`` or ``"nontrainable"``.
            pre_conv: conv block run before upsampling
                (only for ``"nontrainable"`` / ``"pixelshuffle"``).
            interp_mode: ``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"`` or ``"trilinear"``
                (only for ``"nontrainable"``).
            align_corners: forwarded to the upsampler (only for ``"nontrainable"``). Defaults to True.
            halves: halve the channel count while upsampling; ignored in
                ``"nontrainable"`` mode when ``pre_conv`` is None.

        .. deprecated:: 0.6.0
            ``dim`` is deprecated, use ``spatial_dims`` instead.
        """
        super().__init__()
        ndim = spatial_dims if dim is None else dim

        # Non-trainable upsampling without a pre-conv cannot change the channel count.
        keeps_channels = upsample == "nontrainable" and pre_conv is None
        up_chns = in_chns if (keeps_channels or not halves) else in_chns // 2

        self.upsample = UpSample(
            ndim,
            in_chns,
            up_chns,
            2,
            mode=upsample,
            pre_conv=pre_conv,
            interp_mode=interp_mode,
            align_corners=align_corners,
        )
        self.convs = TwoConv(ndim, cat_chns + up_chns, out_chns, act, norm, bias, dropout)

    def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
        """
        Args:
            x: tensor to upsample.
            x_e: encoder skip tensor, or None to skip concatenation.
        """
        upsampled = self.upsample(x)
        if x_e is None:
            return self.convs(upsampled)

        # Odd edge lengths before 2x pooling leave the upsampled map one voxel short;
        # pad the trailing side of each such axis (F.pad lists the last axis first).
        n_spatial = len(x.shape) - 2
        padding = []
        for axis in range(1, n_spatial + 1):
            padding.extend((0, int(x_e.shape[-axis] != upsampled.shape[-axis])))
        upsampled = torch.nn.functional.pad(upsampled, padding, "replicate")
        # input channels: (cat_chns + up_chns)
        return self.convs(torch.cat([x_e, upsampled], dim=1))
    
class monai_unet(nn.Module):
    def __init__(
        self,
        spatial_dims: int = 2,
        net_inputch: int = 1,
        net_outputch: int = 2,
        net_bayesian = 0,
        # features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        features: Sequence[int] = (32, 32, 64, 128, 256, 512, 32),
        act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
        norm: Union[str, tuple] = ("group", {"num_groups": 8}),
        bias: bool = True,
        dropout: Union[float, tuple] = 0.0,
        upsample: str = "deconv",
        dimensions: Optional[int] = None,
        bottleneck_channels = None,
    ):
        """
        A six-level UNet (a variant of MONAI's BasicUNet) with non-local blocks and
        optional Monte-Carlo dropout in the encoder.

        Based on:

            Falk et al. "U-Net – Deep Learning for Cell Counting, Detection, and
            Morphometry". Nature Methods 16, 67–70 (2019), DOI:
            http://dx.doi.org/10.1038/s41592-018-0261-2

        Args:
            spatial_dims: number of spatial dimensions (1, 2 or 3). Defaults to 2.
            net_inputch: number of input channels. Defaults to 1.
            net_outputch: number of output channels. Defaults to 2.
            net_bayesian: probability of the MCDropout layers applied to the outputs of
                encoder levels 2-4. MCDropout stays active in eval mode. 0 disables them.
            features: seven integers as numbers of features.
                Defaults to ``(32, 32, 64, 128, 256, 512, 32)``.

                - the first six values are the encoder sizes (conv_0, down_1 ... down_5).
                - the last value is the feature size after the last upsampling.

            act: activation type and arguments. Defaults to LeakyReLU(0.1).
            norm: feature normalization type and arguments. Defaults to group norm (8 groups).
            bias: whether to have a bias term in convolution blocks. Defaults to True.
            dropout: dropout ratio of the conv blocks. Defaults to no dropout. When truthy,
                the optional bottleneck head also applies ``Dropout(p=0.2)``.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
            dimensions: deprecated alias of ``spatial_dims``.
            bottleneck_channels: if not None, an auxiliary head
                (global average pool -> flatten -> dropout -> linear -> sigmoid) maps the
                deepest encoder features to this many outputs, and ``forward`` returns
                ``(logits, head_output)``.

        Examples::

            >>> net = monai_unet(spatial_dims=2, net_inputch=3, net_outputch=2, net_bayesian=0.2)
        """
        super().__init__()
        if dimensions is not None:
            spatial_dims = dimensions

        fea = ensure_tuple_rep(features, 7)
        print(f"BasicUNet features: {fea}.")
        in_channels = net_inputch
        out_channels = net_outputch

        # Non-local blocks follow spatial_dims so 1-D/3-D inputs get matching conv types.
        # NLBlock_1 and skipAtt are unused in forward but kept so checkpoint keys still match.
        self.NLBlock_1 = NLBlockND(in_channels=fea[1], mode='dot', dimension=spatial_dims)
        self.NLBlock_2 = NLBlockND(in_channels=fea[2], mode='dot', dimension=spatial_dims)
        self.NLBlock_3 = NLBlockND(in_channels=fea[3], mode='dot', dimension=spatial_dims)
        self.NLBlock_4 = NLBlockND(in_channels=fea[4], mode='dot', dimension=spatial_dims)
        self.NLBlock_5 = NLBlockND(in_channels=fea[5], mode='dot', dimension=spatial_dims)

        self.skipAtt = attention_block(F_g=fea[4], F_l=fea[4], F_int=fea[4])
        # fea (not the raw `features` argument) so an int `features` also works.
        self.conv_0 = TwoConv(spatial_dims, in_channels, fea[0], act, norm, bias, dropout)
        self.down_1 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout)
        self.down_2 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout)
        self.down_3 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout)
        self.down_4 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout)
        self.down_5 = Down(spatial_dims, fea[4], fea[5], act, norm, bias, dropout)

        self.upcat_5 = UpCat(spatial_dims, fea[5], fea[4], fea[4], act, norm, bias, dropout, upsample)
        self.upcat_4 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)
        self.upcat_3 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)
        self.upcat_2 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)
        self.upcat_1 = UpCat(spatial_dims, fea[1], fea[0], fea[6], act, norm, bias, dropout, upsample, halves=False)

        self.final_conv = Conv["conv", spatial_dims](fea[6], out_channels, kernel_size=1)
        self.bottleneck_channels = bottleneck_channels

        if self.bottleneck_channels is not None:
            # Global pool matching the spatial rank (AdaptiveAvgPool1d rejects 4-D/5-D maps).
            pool = {1: nn.AdaptiveAvgPool1d, 2: nn.AdaptiveAvgPool2d, 3: nn.AdaptiveAvgPool3d}[spatial_dims](1)
            flatten = nn.Flatten()
            head_dropout = nn.Dropout(p=.2, inplace=True) if dropout else nn.Identity()
            linear = nn.Linear(fea[5], bottleneck_channels, bias=True)
            activation = nn.Sigmoid()
            self.bottleneck_head = nn.Sequential(pool, flatten, head_dropout, linear, activation)

        # forward always calls self.MCDropout, so it must exist even when MC dropout is off.
        self.MCDropout = MCDropout(p=net_bayesian) if net_bayesian != 0 else nn.Identity()

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: input of shape ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])`` with
                ``spatial_dims`` spatial axes. Five 2x max-poolings are applied, so
                ``dim_n % 32 == 0`` keeps every pooling input even-sized. Otherwise
                UpCat pads the upsampled maps.

        Returns:
            Raw logits ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``. When
            ``bottleneck_channels`` was given, returns the tuple ``(logits, head_output)``
            with ``head_output`` of shape ``(Batch, bottleneck_channels)``.
        """
        # Encoder: non-local blocks from level 2 on; MC dropout on levels 2-4 outputs
        # (the dropped-out tensors are also the skip connections).
        x0 = self.conv_0(x)
        x1 = self.down_1(x0)  # NLBlock_1 is deliberately not applied here
        x2 = self.MCDropout(self.NLBlock_2(self.down_2(x1)))
        x3 = self.MCDropout(self.NLBlock_3(self.down_3(x2)))
        x4 = self.MCDropout(self.NLBlock_4(self.down_4(x3)))
        x5 = self.NLBlock_5(self.down_5(x4))

        # Decoder: NLBlock_3 / NLBlock_2 are reused, i.e. weights are shared with the encoder.
        u5 = self.upcat_5(x5, x4)
        u4 = self.NLBlock_3(self.upcat_4(u5, x3))
        u3 = self.NLBlock_2(self.upcat_3(u4, x2))
        u2 = self.upcat_2(u3, x1)
        u1 = self.upcat_1(u2, x0)

        logits = self.final_conv(u1)

        if self.bottleneck_channels is None:
            return logits
        # The head's Linear expects fea[5] channels, i.e. the deepest features x5.
        return logits, self.bottleneck_head(x5)


class attention_block(nn.Module):
    """Additive attention gate on 2-D feature maps.

    ``W_g(g) + W_x(x)`` is passed through ReLU, then ``psi`` (1x1 conv -> BN -> sigmoid)
    yields a single-channel map in (0, 1) that rescales ``x``.

    Args:
        F_g: channels of the gating tensor ``g``.
        F_l: channels of the gated tensor ``x``.
        F_int: intermediate channel count.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = self._conv1x1_bn(F_g, F_int)
        self.W_x = self._conv1x1_bn(F_l, F_int)
        self.psi = nn.Sequential(*self._conv1x1_bn(F_int, 1), nn.Sigmoid())
        # Not in-place: g1 + x1 is a fresh tensor anyway; in-place buys nothing here.
        self.relu = nn.ReLU(inplace=False)

    @staticmethod
    def _conv1x1_bn(chns_in, chns_out):
        """1x1 Conv2d (with bias) followed by BatchNorm2d."""
        return nn.Sequential(
            nn.Conv2d(chns_in, chns_out, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(chns_out),
        )

    def forward(self, g, x):
        gate = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(gate)
    

class MCDropout(nn.Dropout):
    """Monte-Carlo dropout: identical to ``nn.Dropout`` but also active in eval mode."""

    def forward(self, input):
        # training=True regardless of self.training, so .eval() does not switch it off.
        return F.dropout(input, p=self.p, training=True, inplace=self.inplace)


In [36]:
# Smoke test: build the 2-D model (3 input channels, 2 output channels, MC dropout
# p=0.2) and push a random batch of two 3x320x320 images through it.
net = monai_unet(spatial_dims= 2,
                net_inputch = 3,
                net_outputch = 2,net_bayesian=0.2,
                norm=("group", {"num_groups": 8})
               )

# 320 = 10 * 2**5, so all five max-pooling levels see even edge lengths.
a = torch.rand(2,3,320,320)
b = net(a)
# net = monai.networks.nets.UNet(spatial_dims= 3,
#                 in_channels = 3,
#                 out_channels = 2,channels= 
#                 norm=("group", {"num_groups": 8})
#                )
# Last expression: display the module tree.
net

BasicUNet features: (32, 32, 64, 128, 256, 512, 32).


monai_unet(
  (NLBlock_1): NLBlockND(
    (g): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    (W_z): Sequential(
      (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (theta): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    (phi): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
  )
  (NLBlock_2): NLBlockND(
    (g): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (W_z): Sequential(
      (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (theta): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (phi): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (NLBlock_3): NLBlockND(
    (g): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    (W_z): Sequential(
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, e

In [13]:
# NOTE(review): in the recorded run this cell raised a RuntimeError. The first conv had a
# 5-D weight [32, 3, 3, 3, 3] even though spatial_dims=2 was passed. Presumably the installed
# MONAI version builds UNETR as 3-D here; confirm against the MONAI version in use.
net = monai.networks.nets.UNETR(spatial_dims=2, in_channels=3, out_channels=2, img_size=320, feature_size=32, norm_name='batch')
a = torch.rand(2,3,320,320)
b = net(a)

RuntimeError: Expected 5-dimensional input for 5-dimensional weight [32, 3, 3, 3, 3], but got 4-dimensional input of size [2, 3, 320, 320] instead

In [None]:
# Shape of the most recent forward-pass output.
b.shape

In [None]:
# NOTE(review): img_size is 2-D but no spatial_dims is passed; confirm that UNETR's
# default spatial_dims matches 2-D input for the installed MONAI version.
net = monai.networks.nets.UNETR(
        in_channels=3,
        out_channels=2,
        img_size=(320, 320),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0)

In [None]:
# Build DRIVE2CH train/test loaders from the project-local `datasets` module
# (presumably CLAHE-preprocessed, per the class name) and preview training images.
import datasets
from torch.utils.data import DataLoader
train_dataset = datasets.dataset_CLAHE('DRIVE2CH','train',transform=datasets.augmentation_train(),adaptive_hist_range=False)
# DataLoader defaults: batch_size=1, no shuffling.
train_loader = DataLoader(train_dataset)

test_dataset = datasets.dataset_CLAHE('DRIVE2CH','test',adaptive_hist_range=False)
test_loader = DataLoader(test_dataset)

# test_dataset = datasets.dataset('Fundusphotography','test', adaptive_hist_range=False)
# test_loader = DataLoader(test_dataset)

# test_dataset = datasets.dataset('../Retina/dataset/HRF','test',adaptive_hist_range=False)
# test_loader = DataLoader(test_dataset)

# NOTE(review): this loop shows every training image and never breaks, so `x` and `y`
# are left bound to the LAST batch. Later cells (medial_axis / FilFinder analysis) use them.
for idx,batch in enumerate(train_loader):
    x,y = batch['x'], batch['y']
    plt.imshow(x.permute(0,2,3,1)[0])
    # plt.imshow(y[0,0],alpha=0.5,cmap='gray')
    plt.show()    

In [None]:
from fil_finder import FilFinder2D
import astropy.units as u
import skimage
import skimage.morphology
from skimage.morphology import skeletonize, medial_axis, dilation, disk, remove_small_objects

def postprocess(mask, min_size=400):
    """Remove connected components smaller than ``min_size`` pixels from a mask.

    Args:
        mask: array-like mask; any non-zero value counts as foreground.
        min_size: smallest component size that is kept (default 400, the value
            previously hard-coded).

    Returns:
        Integer array with the same shape as ``mask`` holding 0/1 values.
    """
    # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    mask = np.asarray(mask).astype(bool)
    return skimage.morphology.remove_small_objects(mask, min_size) * 1


def skeletion_analysis(image, skeleton):
    """Analyse a pre-computed skeleton with FilFinder2D and plot the result.

    The left panel overlays ``fil.skeleton`` (plus a red contour) on ``image``. The
    right panel plots the ridge profile of the first filament. The number of
    filaments and the keys of ``fil.branch_properties`` are printed. If no filament
    remains after ``analyze_skeletons``, only the left panel is shown and the
    function returns early instead of indexing an empty list.

    Args:
        image: 2-D image handed to ``FilFinder2D``.
        skeleton: mask handed to ``FilFinder2D`` and reused via ``use_existing_mask=True``.
    """
    fil = FilFinder2D(image, mask=skeleton)
    fil.create_mask(border_masking=True, verbose=False, use_existing_mask=True)
    fil.medskel(verbose=False)
    # Branch/skeleton length thresholds of 10 pixels.
    fil.analyze_skeletons(branch_thresh=10 * u.pix, skel_thresh=10 * u.pix, prune_criteria='all')

    plt.figure(figsize=(18, 8))
    plt.subplot(121)
    plt.imshow(image, cmap='gray')
    plt.imshow(fil.skeleton, cmap='gray', alpha=0.5)
    plt.contour(fil.skeleton, colors='r', alpha=0.3)
    plt.axis('off')

    n_filaments = len(fil.filaments)
    print(n_filaments)
    if n_filaments == 0:
        # Nothing to profile: show the overlay only.
        plt.tight_layout()
        plt.show()
        return

    fil1 = fil.filaments[0]
    plt.subplot(122)
    plt.plot(fil1.ridge_profile(fil.image))
    plt.xlabel('Length(Pixel)')
    plt.ylabel('Thickness(Pixel)')
    plt.tight_layout()
    plt.show()
    print(fil.branch_properties.keys())
    
# Medial-axis skeleton (and distance map) of the first channel of the first sample of `y`;
# `image` is the first channel of the first sample of `x` from the same batch.
skeleton,distance = medial_axis(y.numpy()[0,0], return_distance=True)
image = x.numpy()[0,0]
# skeletion_analysis(image,skeleton)

In [None]:
# Same steps as skeletion_analysis, but run at top level so `fil` and `fil1` stay
# available to the later cells (branch points, plot_graph, branch_properties).
# NOTE(review): fil.filaments[0] raises IndexError if no filament survives pruning.
# fil = FilFinder2D(distance, mask = skeleton)
fil = FilFinder2D(image, mask = skeleton)
# filfind.preprocess_image(flatten_percent=85)
fil.create_mask(border_masking = True, verbose = False, use_existing_mask = True)
fil.medskel(verbose = False)
fil.analyze_skeletons(branch_thresh = 10 * u.pix, skel_thresh = 10 * u.pix, prune_criteria = 'all')
# fil.analyze_skeletons(branch_thresh = 10 * u.pix, skel_thresh = 10 * u.pix, prune_criteria = 'all', verbose=True)

# Show the longest path
plt.figure(figsize=(18,8))
plt.subplot(121)
plt.imshow(image, cmap = 'gray')
# plt.imshow(distance, cmap = 'gray')
plt.imshow(fil.skeleton, cmap = 'gray',alpha=0.5)
plt.contour(fil.skeleton, colors = 'r', alpha=0.3)
plt.axis('off')

print(len(fil.filaments))
fil1 = fil.filaments[0]
plt.subplot(122)
plt.plot(fil1.ridge_profile(fil.image))
plt.xlabel('Length(Pixel)')
plt.ylabel('Thickness(Pixel)')
plt.tight_layout()
plt.show()
# fil1.skeleton_analysis(fil.image, verbose=True)
print(fil.branch_properties.keys())
# fil.find_widths(verbose=True, max_dist=200 *u.pix, pad_to_distance= 0 *u.pix, use_longest_path=True, xunit=u.pix)

In [None]:
# branchpt

In [None]:
# len(fil1.branch_pts())
# Rasterise the branch points of the first filament onto an image-sized canvas.
# NOTE(review): confirm branch_pts() returns image-frame (not filament-local)
# coordinates; otherwise the points land in the wrong place on this canvas.
branchpt = fil1.branch_pts()
# Canvas shaped like `skeleton` (same shape as `label`, which is only created in a later cell).
array_branchpt = np.zeros(skeleton.shape)
for pts in branchpt:
    # row/col rather than x/y so the notebook's `x`/`y` batch tensors are not overwritten.
    for row, col in pts:
        array_branchpt[row, col] = 1
plt.imshow(array_branchpt)

In [None]:
# Plot the branch graph of the first filament.
fil1.plot_graph()
# fil1.skeleton_analysis(fil.image, verbose=True)

In [None]:
fil.branch_properties['number']

In [None]:
# NOTE(review): `label` is first assigned in a later cell (label = np.exp(-distance/16));
# running the notebook top-down raises NameError here.
np.unique(label)

In [None]:
# Pixels with distance 0 are set to 1000, so exp(-distance/16) is ~0 there.
# This mutates `distance` in place.
distance[distance==0]=1000
label = np.exp(-distance/16)
plt.figure(figsize=(20,20))
plt.imshow(label)

In [None]:
print(np.unique(label))

In [None]:
print(np.unique(1/distance))

In [None]:
# NOTE(review): `y` must still be the label tensor from the data-loading cell;
# confirm no cell in between has rebound it.
# plt.figure(figsize=(20,20))
plt.imshow(y[0][0])

In [None]:
plt.figure(figsize=(20,20))
plt.subplot(121)
# plt.imshow(distance)
plt.imshow(y[0][0])
# Repeats the in-place substitution above (a no-op if that cell already ran).
distance[distance==0]=1000
plt.subplot(122)
plt.imshow(1/distance)

In [None]:
# gpu status
!nvidia-smi
import multiprocessing
print(multiprocessing.cpu_count())
# !pip install tensorboard==1.15
gpus= "0"

import warnings
warnings.filterwarnings(action='ignore')

import os
os.environ["CUDA_DEVICE_ORDER"]= "PCI_BUS_ID";
os.environ["CUDA_VISIBLE_DEVICES"]= gpus;
    
import torch
gpu_count = torch.cuda.device_count()
if gpu_count >=1:
    torch.multiprocessing.set_start_method('spawn')
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
device

print(torch.cuda.is_available())
print(gpu_count)
print(torch.__version__)

In [None]:
# Build the project's `nets.unet`; net_bayesian=0.2 is presumably the MC-dropout rate
# (project-local module, signature not visible here).
import nets
net= nets.unet(net_bayesian=0.2)
net

In [None]:
import nets
import torch
# Smoke test: smp-based UNet on a random batch of two 3x128x128 images.
net= nets.smp_unet(net_bayesian=True)
a= torch.rand(2,3,128,128)
b= net(a)
b.shape

In [None]:
# Inspect the encoder of a segmentation_models_pytorch UNet with random weights
# (encoder_weights=None). Alternative encoders are kept commented out.
import segmentation_models_pytorch as smp
net = smp.Unet(encoder_name='resnet50',encoder_weights=None)
# net = smp.Unet(encoder_name='densenet169',encoder_weights=None)
# net = smp.Unet(encoder_name='timm-efficientnet-b5',encoder_weights=None)
# net = smp.Unet(encoder_name='timm-regnety_040',encoder_weights=None)
# net = smp.Unet(encoder_name='timm-skresnext50_32x4d',encoder_weights='imagenet')
# net = smp.DeepLabV3Plus(encoder_name='resnet50',encoder_weights=None)

net.encoder
# net.decoder

In [None]:
# net.encoder.layer4

In [None]:
# list(net.encoder.children())[-1]

In [None]:
# NOTE(review): assumes the encoder's last child is indexable and its last element has a
# `b1` attribute; that holds only for some encoder families. Confirm for the one selected.
list(net.encoder.children())[-1][-1].b1

In [None]:
# net.encoder.s3

In [None]:
import torch.nn as nn
import nets
# NOTE(review): this rebinds the name `MCDropout` (a class defined earlier in this
# notebook) to an instance of nets.MCDropout. Any later `MCDropout(p=...)` call, e.g.
# constructing monai_unet, would call this instance instead of the class.
MCDropout = nets.MCDropout()

# Prepend MC dropout to layer4. Re-running this cell nests another wrapper around
# the already-wrapped layer.
net.encoder.layer4 = nn.Sequential(MCDropout, net.encoder.layer4)
net.encoder.layer4
# net.encoder.s4 = nn.Sequential(MCDropout,net.encoder.s4) # regnet
# net.encoder.s4

In [None]:
import torch.nn as nn
import nets
MCDropout = nets.MCDropout()

# Replace the last sub-block of the encoder's last stage with (MC dropout -> that sub-block).
# Assumes the stage supports item assignment (e.g. nn.Sequential) so the module is
# updated in place; confirm for the selected encoder.
list(net.encoder.children())[-1][-1] = nn.Sequential(MCDropout,list(net.encoder.children())[-1][-1]) # resnet, densenet
list(net.encoder.children())[-1][-1]

In [None]:
# Append MC dropout after the encoder's final BatchNorm. Wrap only once: the duplicated
# line previously applied dropout twice in a row.
net.encoder.bn2 = nn.Sequential(net.encoder.bn2, MCDropout)  # efficientnet

In [None]:
import torch.nn as nn
import nets
MCDropout = nets.MCDropout()

# Wrap the encoder's last top-level child with MC dropout. Assigning into
# `list(net.encoder.children())[-1]` only rebinds an element of a throwaway list,
# so the child has to be replaced on the module via its registered name.
last_name, last_child = list(net.encoder.named_children())[-1]
setattr(net.encoder, last_name, nn.Sequential(MCDropout, last_child))
list(net.encoder.children())[-1]

In [None]:
net.encoder

In [None]:
# Abandoned attempt to find the deepest layer by repeated indexing; kept for reference.
# layers = list(net.encoder.children())
# layer_last = 0

# while layer_last==0:
#     try:
#         layers = layers[-1]
#     except:
#         layer_last = layers
        
# print('before',layer_last)
# layer_last = nn.Sequential(layer_last, MCDropout)
# print('after',layer_last)

In [None]:
# Full forward pass on the random batch `a`.
b= net(a)
# b= net.encoder(a)

In [None]:
# net.encoder

In [None]:
# Flatten the encoder into its top-level layers. `modules()` also yields the encoder
# itself and every nested submodule, so layers would be run repeatedly (and the
# original encoder first). NOTE(review): this matches the original forward only if the
# encoder's forward is a plain chain of its children; confirm for the selected encoder.
net.encoder = nn.Sequential(*net.encoder.children())

In [None]:
# Run the (rewired) encoder alone on a single random 3x128x128 image.
a = torch.rand(1,3,128,128)
b = net.encoder(a)

In [None]:
net.encoder

In [None]:
net.encoder

In [None]:
b= net.encoder(a)

In [None]:
import train
import torch
import glob
import nets
import monai

# # !CUDA_VISIBLE_DEVICES=2 python train.py --project Retina \
# #                                         --data_dir DRIVE2CH \
# #                                         --data_module dataset \
# #                                         --net_name smp_unet \
# #                                         --net_inputch 3 \
# #                                         --net_outputch 2 \
# #                                         --net_encoder_name resnet152 \
# #                                         --net_norm batch \
# #                                         --net_baysian True \
# #                                         --lossfn skel_FocalLoss \
# #                                         --data_padsize None \
# #                                         --data_cropsize None \
# #                                         --data_resize 584_565 \
# #                                         --data_patchsize 224_224 \
# #                                         --batch_size 12 \
# #                                         --lr 1e-3 \
# #                                         --precision 32 \

net = train.SegModel(data_dir='DRIVE2CH',project='Retina',
                     net_name='smp_unet',
                     net_inputch=3,
                     net_outputch=2,
                     net_encoder_name='resnet50',
                     net_norm = 'batch',
                     net_baysian=True,
                     lossfn='BoundaryFocalLoss',
                     data_padsize = None,
                     data_cropsize = None,
                     # String "H_W" like data_patchsize / the CLI flag; the bare literal
                     # 584_565 would be the integer 584565.
                     data_resize = '584_565',
                     data_patchsize = '320_320',
                    ).cuda()

# PATH = 'logs/Retina/2zs1jtfs/checkpoints/' # AMC
PATH = 'logs/Retina/f1trgt7t/checkpoints/' # DRIVE
# PATH = 'logs/Retina/i7s4t3w2/checkpoints/' # DRIVE
# keewonshin/Retina/f1trgt7t
# glob order is arbitrary; sort so FILE[-1] is deterministic.
FILE = sorted(glob.glob(PATH + '*.ckpt'))
if not FILE:
    raise FileNotFoundError(f'no *.ckpt files found in {PATH}')
print(FILE,'\n',FILE[-1])
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
weight = torch.load(FILE[-1])

# strict=False silently skips mismatched keys; the displayed result lists them.
net.load_state_dict(weight['state_dict'],strict=False)
net = net.to(device)

In [None]:
# import train
# import torch
# import glob
# import nets

# net = train.SegModel(data_dir='DRIVE_2c',project='Retina',
#                      net='smp_FPNRecSoft',
#                      net_inputch=3,
#                      net_outputch=2,
#                      net_finalActivation=True,
#                      net_nnblock=True,
#                      net_supervision=True,
#                      net_skipatt=True,
#                      net_wavelet=True,
#                      net_rcnn=False,
#                      net_reconstruction=True,
#                      data_padsize =None,
#                      data_cropsize =None,
#                      data_resize =None,
#                      data_patchsize = '256_256',
#                     ).cuda()

# PATH = 'logs/Retina/13hcx6pe/checkpoints/'

# FILE = glob.glob(PATH+'*.ckpt')
# print(FILE,'\n',FILE[-1])
# weight = torch.load(FILE[-1])

# net.load_state_dict(weight['state_dict'],strict=False)
# # net = net.to(device)

In [None]:
# import train
# import torch
# import glob
# import nets

# net = train.SegModel(data_dir='DRIVE_2c',project='Retina',
#                      net='smp_FPNRecHard',
#                      net_inputch=3,
#                      net_outputch=2,
#                      net_finalActivation=True,
#                      net_nnblock=True,
#                      net_supervision=True,
#                      net_skipatt=True,
#                      net_wavelet=True,
#                      net_rcnn=False,
#                      net_reconstruction=True,
#                      data_padsize =None,
#                      data_cropsize =None,
#                      data_resize =None,
#                      data_patchsize = '256_256',
#                     ).cuda()

# PATH = 'logs/Retina/1muq8j90/checkpoints/'

# FILE = glob.glob(PATH+'*.ckpt')
# # print(FILE,'\n',FILE[-1])
# weight = torch.load(FILE[-1])

# net.load_state_dict(weight['state_dict'],strict=False)
# # net = net.to(device)

In [None]:
# import pytorch_lightning as pl
# import train
# import torch
# import glob
# import nets

# model = MyLightingModule.load_from_checkpoint(PATH)

# print(model.learning_rate)
# # prints the learning_rate you used in this checkpoint

# model.eval()
# y_hat = model(x)

# trainer = pl.Trainer(gpus = -1)

In [None]:
# import glob
# import cv2
# import numpy as np

# files = glob.glob('DRIVE2CH_CHASE_HRF/*.jpg')
# for idx in range(len(files)):
#     img = cv2.imread(files[idx])
#     img[img!=0] = 0
#     img[img==0] = 23
#     cv2.imwrite(files[idx],img)

In [None]:
# idx = 0
# img = cv2.imread(files[idx])
# np.unique(img)

In [None]:
# idx = 0
# # img = cv2.imread(files[idx])

# img = np.load(files[idx])
# artery = img[...,0] + img[...,2]
# vein = img[...,1] + img [...,2]
# artery[artery!=0] = 1
# vein[vein!=0] = 1

# plt.imshow(artery*255)
# plt.show()
# plt.imshow(vein*255)
# plt.show()

In [None]:
import datasets
from torch.utils.data import DataLoader
train_dataset = datasets.dataset_CLAHE('DRIVE2CH','train',transform=datasets.augmentation_train(),adaptive_hist_range=False)
train_loader = DataLoader(train_dataset)

test_dataset = datasets.dataset_CLAHE('DRIVE2CH','test',adaptive_hist_range=False)
test_loader = DataLoader(test_dataset)

# test_dataset = datasets.dataset('Fundusphotography','test', adaptive_hist_range=False)
# test_loader = DataLoader(test_dataset)

# test_dataset = datasets.dataset('../Retina/dataset/HRF','test',adaptive_hist_range=False)
# test_loader = DataLoader(test_dataset)

for idx,batch in enumerate(train_loader):
    x,y = batch['x'], batch['y']
    plt.imshow(x.permute(0,2,3,1)[0])
    # plt.imshow(y[0,0],alpha=0.5,cmap='gray')
    plt.show()    

In [None]:
for idx,batch in enumerate(train_loader):
    x,y = batch['x'], batch['y']
    plt.imshow(x.permute(0,2,3,1)[0])
    # plt.imshow(y[0,0],alpha=0.5,cmap='gray')
    plt.show()    

In [None]:
torch.unique(y)

In [None]:
# for idx,batch in enumerate(train_loader):
#     x,y = batch['x'], batch['y']
#     plt.imshow(x.permute(0,2,3,1)[0])
# #     plt.imshow(y[0,0],alpha=0.5,cmap='gray')
#     plt.show()    

In [None]:
from sklearn.metrics import *
def metric(yhat,y):
    """
    Compute binary segmentation metrics between a prediction and a ground truth.

    Args:
        yhat: predicted labels -- torch tensor (any device), numpy array or
            array-like. Any non-zero value counts as the positive class.
        y: ground-truth labels with the same number of elements as ``yhat``.

    Returns:
        dict with keys 'specificity', 'sensitivity', 'dice', 'iou' and
        'accuracy'. A ratio whose denominator is zero (e.g. 'iou' when neither
        mask contains a positive pixel) is reported as NaN instead of failing.
    """
    def _to_flat_numpy(a):
        # Torch tensors (possibly on GPU / tracking grad) -> host numpy array.
        if hasattr(a, 'detach'):
            a = a.detach().cpu().numpy()
        return np.asarray(a).ravel()

    def _ratio(num, den):
        return num / den if den else float('nan')

    yhat_pos = _to_flat_numpy(yhat) != 0
    y_pos = _to_flat_numpy(y) != 0

    # Count the confusion-matrix cells directly. sklearn's confusion_matrix
    # returns a 1x1 table when only one class is present, which made the
    # 4-way `tn, fp, fn, tp = ...ravel()` unpacking fail.
    tp = int(np.count_nonzero(y_pos & yhat_pos))
    tn = int(np.count_nonzero(~y_pos & ~yhat_pos))
    fp = int(np.count_nonzero(~y_pos & yhat_pos))
    fn = int(np.count_nonzero(y_pos & ~yhat_pos))

    accuracy = _ratio(tp + tn, tn + fp + fn + tp)
    iou = _ratio(tp, tp + fp + fn)
    dice = _ratio(2 * tp, 2 * tp + fp + fn)
    specificity = _ratio(tn, tn + fp)
    sensitivity = _ratio(tp, tp + fn)

    return {'specificity':specificity, 'sensitivity':sensitivity, 'dice':dice, 'iou':iou, 'accuracy':accuracy}

def show_image_samples(x, y, yhat, message=''):
    '''
    Plot one 2D segmentation sample: the input, the ground truth and an error
    overlay (Green: FP, Red: FN, White: TP).

    Args:
        x: input image of shape HxW (or HxWxC); shown with a gray colormap and
            used to size the overlay.
        y: ground-truth mask of shape HxW, expected to contain 0/1 values.
        yhat: predicted mask of shape HxW, already argmax-ed/thresholded to 0/1.
        message: prefix for the subplot titles.
    '''
    # Work in float64 so `y - yhat` can go negative: with uint8 masks the
    # difference wraps to 255 (false positives were never detected) and bool
    # arrays do not support `-` at all.
    y = np.asarray(y, dtype=np.float64)
    yhat = np.asarray(yhat, dtype=np.float64)

    plt.figure(figsize=(24,16))
    plt.subplot(131)
    plt.title(str(message)+'_x')
    plt.imshow(x,cmap='gray')
    plt.subplot(132)
    plt.title(str(message)+'_y')
    plt.imshow(y,cmap='gray')
    plt.subplot(133)
    plt.title(str(message)+'_y-yhat (Green:FP, Red:FN, White:TP)')

    diff = y - yhat
    fp = (diff == -1).astype(np.float64)  # predicted 1 where ground truth is 0
    fn = (diff == 1).astype(np.float64)   # predicted 0 where ground truth is 1

    # Start white wherever the ground truth is set, then drop green/blue on
    # FN pixels (-> red) and add green on FP pixels (-> green).
    overlay = np.zeros((x.shape[0], x.shape[1], 3))
    overlay[...] = y[..., None]
    overlay[..., 1] -= fn
    overlay[..., 2] -= fn
    overlay[..., 1] += fp
    overlay[overlay != 0] = 1

    plt.imshow(overlay,alpha=1,cmap='gray')
    plt.show()
    
def show_batch_samples(batch_x, batch_y, batch_yhat, message='', x_ch = 0):
    '''
    Visualise the first sample of a 2D segmentation batch.

    Three panels are drawn: input channel ``x_ch``, the ground-truth mask and
    an error overlay (white = TP, green = FP, red = FN).

    Args:
        batch_x: input batch, BxCxHxW tensor.
        batch_y: ground-truth batch, Bx1xHxW tensor (0/1 labels expected).
        batch_yhat: predictions. A BxKxHxW tensor with K > 1 is reduced with
            argmax over K; a BxHxW tensor gets a channel axis added.
        message: prefix for the subplot titles.
        x_ch: index of the input channel to display.
    '''
    n_dims = len(batch_yhat.shape)
    if n_dims == 3:
        batch_yhat = batch_yhat.unsqueeze(1)
    elif n_dims == 4 and batch_yhat.shape[1] > 1:
        batch_yhat = torch.argmax(batch_yhat, 1).unsqueeze(1)

    first = 0
    plt.figure(figsize=(24,16))

    plt.subplot(131)
    plt.title(str(message) + '_x')
    plt.imshow(batch_x[first, x_ch].cpu().detach(), cmap='gray')

    gt = batch_y[first, 0].cpu().detach()
    plt.subplot(132)
    plt.title(str(message) + '_y')
    plt.imshow(gt, cmap='gray')

    plt.subplot(133)
    # The '_yhat' title is immediately replaced by the legend title below.
    plt.title(str(message) + '_yhat')
    plt.title(str(message) + '_y-yhat (Green:FP, Red:FN, White:TP)')

    height, width = batch_x[first, 0].shape[0], batch_x[first, 0].shape[1]
    overlay = np.zeros((height, width, 3))
    overlay[...] = np.asarray(gt)[..., None]  # white wherever the ground truth is set

    diff = (batch_y[first, 0].float().cpu().detach().numpy()
            - batch_yhat[first, 0].float().cpu().detach().numpy())
    false_pos = (diff == -1).astype(diff.dtype)
    false_neg = (diff == 1).astype(diff.dtype)

    overlay[..., 1] -= false_neg  # FN: remove green/blue -> red
    overlay[..., 2] -= false_neg
    overlay[..., 1] += false_pos  # FP: add green
    overlay[overlay != 0] = 1

    plt.imshow(overlay, alpha=1, cmap='gray')
    plt.show()

In [None]:
# torch.max(uncert),torch.mean(uncert),torch.std(uncert)
# plt.imshow(1-yhat[0][0][0].cpu().detach().numpy())

In [None]:
import utils
import kornia
import torch
import numpy as np
import cv2
import pylab as plt

def bayesian_inference_seg(net, x, iteration=50, roi_size=320, sw_batch_size=4, overlap=0.75):
    """
    Monte-Carlo inference: run ``net`` ``iteration`` times over ``x`` with
    sliding-window inference and aggregate the predictions.

    ``net.eval()`` is called first, so any variation between passes must come
    from layers that stay stochastic in eval mode (presumably the MC-dropout
    layers inserted elsewhere in this notebook -- confirm for your model).

    Args:
        net: segmentation model. Its output may be a tensor, or a list/tuple
            whose first element is used.
        x: input batch passed to MONAI ``sliding_window_inference``.
        iteration: number of forward passes; must be >= 1.
        roi_size: sliding-window size (previously hard-coded to 320).
        sw_batch_size: windows per forward call (previously hard-coded to 4).
        overlap: window overlap ratio (previously hard-coded to 0.75).

    Returns:
        (pred, uncert):
        - ``pred`` is the argmax over dim 1 of the mean activation across
          passes, with dim 1 re-added via ``unsqueeze(1)``.
        - ``uncert`` is the variance across passes, averaged over dim 1.
        Assuming outputs are (B, C, H, W) -- TODO confirm -- these are
        (B, 1, H, W) and (B, H, W). ``torch.var`` is unbiased by default, so
        ``uncert`` is NaN when ``iteration == 1``.

    Raises:
        ValueError: if ``iteration`` < 1 (``torch.stack`` of nothing would fail).
    """
    if iteration < 1:
        raise ValueError('iteration must be >= 1, got %r' % (iteration,))

    # Imported here: the notebook only imports it in a later cell, so running
    # cells top-down raised NameError.
    from monai.inferers import sliding_window_inference

    def predictor(inputs, return_idx=0):
        # Some networks return a list/tuple of outputs; keep only one of them.
        result = net(inputs)
        if isinstance(result, (list, tuple)):
            return result[return_idx]
        return result

    net.eval()
    yhat_stack = []
    with torch.no_grad():
        for _ in range(iteration):
            yhat = sliding_window_inference(inputs=x, roi_size=roi_size, sw_batch_size=sw_batch_size,
                                            predictor=predictor, overlap=overlap, mode='constant')
            yhat_stack.append(utils.Activation(yhat))

    # Stack the passes along a new leading dim, then reduce over it.
    yhat_stack = torch.stack(yhat_stack)
    pred = torch.argmax(torch.mean(yhat_stack, 0), 1).unsqueeze(1)
    uncert = torch.mean(torch.var(yhat_stack, 0), 1)
    return pred, uncert

In [None]:
import monai
from monai.inferers import sliding_window_inference

metrics = list()
batch = next(iter(test_loader))
x,y = batch['x'].cuda(), batch['y'].cuda()

def predictor(x, return_idx=0):
    """
    Run the notebook-global ``net`` on ``x`` in eval mode and return one output.

    If the network returns a list or tuple, the element at ``return_idx`` is
    returned; otherwise the raw output is passed through unchanged.
    """
    net.eval()
    output = net(x)
    return output[return_idx] if isinstance(output, (list, tuple)) else output
# yhat,uncert = bayesian_inference_seg(net, x, iteration=2)
roi_size=320
yhat = sliding_window_inference(inputs=x, roi_size=roi_size, sw_batch_size=4, predictor=predictor, overlap=0.75, mode='constant')
# print(metric(yhat,y))
show_batch_samples(x,y,yhat)
        
# metrics = torch.tensor(metrics)
# print(torch.mean(metrics))

In [None]:
yhat.shape,y.shape

In [None]:
import monai
from monai.inferers import sliding_window_inference

metrics = list()
for idx,batch in enumerate(test_loader):
    net.eval()
    with torch.no_grad():
        x,y = batch['x'].cuda(), batch['y'].cuda()
        yhat,uncert = bayesian_inference_seg(net, x, iteration=20)

        # Compute once per batch (it was previously evaluated twice).
        sample_metrics = metric(yhat,y)
        metrics.append(sample_metrics)
        print(sample_metrics)
        show_batch_samples(x,y,yhat)

# `metric` returns one dict per batch, which torch.tensor() cannot convert.
# Average each metric over the test set instead; nanmean skips undefined ratios.
if metrics:
    mean_metrics = {key: float(np.nanmean([m[key] for m in metrics])) for key in metrics[0]}
    print(mean_metrics)

# .1 tensor([[0.8290]], device='cuda:0')
# .2 tensor([[0.8299]], device='cuda:0')
# .3 tensor([[0.8289]], device='cuda:0')
# tensor(0.8135) with out tophat

# yhat = bayesian_inference_seg(net, x, iteration=20, vote_threshold=0.2, tophat_param=7) tensor(0.7921)
# yhat = bayesian_inference_seg(net, x, iteration=20, vote_threshold=0.2, tophat_param=9) tensor(0.8121)

In [None]:
# import monai
# from monai.inferers import sliding_window_inference

# batch = next(iter(test_loader))
# x,y = batch['x'].cuda(), batch['y'].cuda()
# yhat = bayesian_inference_seg(net,x,iteration=20, vote_threshold=0.1, tophat_param=5)

In [None]:
metric(torch.argmax(yhat[0],1).unsqueeze(1),y)

In [None]:
plt.hist(yhat[1].cpu().detach().numpy().flatten(),bins=10)

In [None]:
torch.mean(yhat[1]),torch.max(yhat[1]),torch.std(yhat[1])

In [None]:
-torch.mean(yhat[1])+torch.std(yhat[1])*1

In [None]:
yhat[0].shape,yhat[1].shape

In [None]:
plt.imshow(yhat[1][0].cpu().detach().numpy())
# torch.unique(yhat_matmul)

In [None]:
# yhat_matmul = (1-torch.argmax(yhat[0],1))*yhat[1]
yhat_matmul = (torch.argmax(yhat[0],1))*yhat[1]
# yhat_matmul = yhat[1]
print(torch.unique(yhat_matmul))
# yhat_matmul[yhat_matmul!=0] = 1
# yhat_matmul[yhat_matmul<torch.std(yhat[1])] = 0
yhat_matmul[yhat_matmul<torch.mean(yhat[1])] = 0
yhat_matmul[yhat_matmul!=0] = 1

plt.figure(figsize=(10,10))
plt.imshow(yhat_matmul[0].cpu().detach().numpy())

In [None]:
plt.figure(figsize=(24,24))
plt.subplot(141)
plt.imshow(y[0][0].cpu().detach().numpy())
plt.subplot(142)
# plt.imshow(yhat[0][0].cpu().detach().numpy())
plt.imshow(yhat[0][0][1].cpu().detach().numpy())
plt.subplot(143)
uncent = yhat[1][0].cpu().detach().numpy()
uncent[uncent<0.0001] = 0
uncent[uncent!=0] = 1
plt.imshow(uncent)
plt.subplot(144)
final = yhat[0][0][1].cpu().detach().numpy()+yhat[1][0].cpu().detach().numpy()
final[final<0.001] = 0
final[final!=0] = 1
plt.imshow(x[0].permute(1,2,0).cpu().detach().numpy())
plt.imshow(final,alpha=0.4)

In [None]:
plt.figure(figsize=(24,24))
plt.imshow(uncent)

In [None]:
metrics = trainer.test(net, test_loader)

In [None]:
outputs = trainer.predict(net, test_loader)

In [None]:
import utils 
for output in outputs:
    x,y,yhat = output['x'],output['y'],output['yhat']
    yhat = utils.Activation(yhat)
    plt.figure(figsize=(25,25))
    plt.subplot(131)
    plt.imshow(x[0].permute(1,2,0).cpu().detach().numpy())
    plt.subplot(132)
    plt.imshow(y[0,0].cpu().detach().numpy())
    plt.subplot(133)
    plt.imshow(yhat[0,1].cpu().detach().numpy())
#     plt.subplot(144)
#     plt.imshow(torch.argmax(yhat,1)[0].cpu().detach().numpy())
    plt.show()

In [None]:
# Smoke test: push a random batch through the network and inspect the output
# shape. The argument-less `plt.imshow()` that stood here raised TypeError
# before the test could run, so it has been removed.
a= torch.rand(2,3,128,128).cuda()
b = net(a)
b.shape

In [None]:
batch = next(iter(train_loader))
x = batch['x']
plt.imshow(x[0,0])
plt.show()
plt.imshow(x[0,1])
plt.show()
plt.imshow(x[0,2])
plt.show()

In [None]:
batch['mask']

In [None]:
import utils
import pylab as plt

from monai.inferers import sliding_window_inference
for idx, batch in enumerate(test_loader):
    x,y = batch['x'].to(device), batch['y'].to(device)
    
    with torch.no_grad():
        
        def predictor(x, return_idx=0): # in case of prediction is type of list
            result = net(x)
            if isinstance(result, list) or isinstance(result, tuple):
                return result[return_idx]
            else:
                return result

        roi_size = 128
        yhat = sliding_window_inference(inputs=x, roi_size=roi_size,sw_batch_size=4,predictor=predictor,overlap=0.75,mode='constant')
        yhat = utils.Activation(yhat)

        plt.figure(figsize=(12,12))
        plt.subplot(121)
        plt.imshow(x[0].cpu().detach().permute(1,2,0))
        plt.imshow(y[0,0].cpu().detach(),alpha=0.5,cmap='gray')
        plt.subplot(122)
        plt.imshow(torch.argmax(yhat,1)[0].cpu().detach(),cmap='gray')
        plt.show()

In [None]:
import pylab as plt
plt.imshow(x[0].permute(1,2,0))
x.shape

In [None]:
import pylab as plt
plt.imshow(x[0].permute(1,2,0))
x.shape

In [None]:
# install packages
# !pip install -r requirements.txt --user --quiet -U
!pip uninstall torch -y
!pip install torch==1.8.0
# !apt updates

In [None]:
# !sudo apt update
# !sudo apt install libgl1-mesa-glx ffmpeg libsm6 libxext6 -y

In [None]:
import segmentation_models_pytorch as smp
net = smp.Unet()
# net

In [None]:
import pywt
import torch
from torch.autograd import Variable

w=pywt.Wavelet('db1')
# w=pywt.Wavelet('haar')
# w=pywt.Wavelet('rbio1.1')
dec_hi = torch.Tensor(w.dec_hi[::-1]) 
dec_lo = torch.Tensor(w.dec_lo[::-1])
rec_hi = torch.Tensor(w.rec_hi)
rec_lo = torch.Tensor(w.rec_lo)

filters = torch.stack([dec_lo.unsqueeze(0)*dec_lo.unsqueeze(1),
                       dec_lo.unsqueeze(0)*dec_hi.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_lo.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_hi.unsqueeze(1)], dim=0)
inv_filters = torch.stack([rec_lo.unsqueeze(0)*rec_lo.unsqueeze(1),
                           rec_lo.unsqueeze(0)*rec_hi.unsqueeze(1),
                           rec_hi.unsqueeze(0)*rec_lo.unsqueeze(1),
                           rec_hi.unsqueeze(0)*rec_hi.unsqueeze(1)], dim=0)

filters = torch.stack([dec_lo.unsqueeze(0)*dec_lo.unsqueeze(1),
                       dec_lo.unsqueeze(0)*dec_hi.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_lo.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_hi.unsqueeze(1)], dim=0)
inv_filters = torch.stack([rec_lo.unsqueeze(0)*rec_lo.unsqueeze(1),
                           rec_lo.unsqueeze(0)*rec_hi.unsqueeze(1),
                           rec_hi.unsqueeze(0)*rec_lo.unsqueeze(1),
                           rec_hi.unsqueeze(0)*rec_hi.unsqueeze(1)], dim=0)

def wt(vimg, wavelet_filters=None):
    """
    Single-level 2D discrete wavelet transform, applied independently per channel.

    Args:
        vimg: 4D tensor (B, C, H, W).
        wavelet_filters: optional (4, k, k) analysis filter bank. Defaults to
            the module-level ``filters`` built in this cell from
            ``pywt.Wavelet('db1')``.

    Returns:
        Tensor (B, 4*C, H', W') on ``vimg``'s device, where H' = H // 2 for
        the 2x2 db1 filters. Channels ``4*c .. 4*c+3`` hold the four sub-bands
        of input channel ``c``, in filter-bank order.
    """
    bank = filters if wavelet_filters is None else wavelet_filters
    channels = vimg.shape[1]
    # One grouped (depthwise) convolution replaces the per-channel loop:
    # group c applies all four filters to channel c, which gives the same
    # 4*c.. channel ordering. The filters are fixed, so they need no gradient.
    # They follow the input's device/dtype instead of being forced onto CUDA.
    weight = bank[:, None].to(device=vimg.device, dtype=vimg.dtype).repeat(channels, 1, 1, 1)
    return torch.nn.functional.conv2d(vimg, weight, stride=2, groups=channels)

def iwt(vres, wavelet_filters=None):
    """
    Inverse of ``wt``: rebuild each channel from its four wavelet sub-bands.

    Args:
        vres: 4D tensor (B, 4*C, H, W). Channels ``4*c .. 4*c+3`` are the
            sub-bands of output channel ``c``. Trailing channels beyond a
            multiple of 4 are ignored.
        wavelet_filters: optional (4, k, k) synthesis filter bank. Defaults to
            the module-level ``inv_filters`` built in this cell from
            ``pywt.Wavelet('db1')``.

    Returns:
        Tensor (B, C, 2*H, 2*W) for the 2x2 db1 filters, on ``vres``'s device.
    """
    bank = inv_filters if wavelet_filters is None else wavelet_filters
    out_channels = vres.shape[1] // 4
    coeffs = vres[:, :4 * out_channels]
    # Grouped transposed convolution: group c maps input channels 4*c..4*c+3
    # to output channel c, matching the former per-channel loop. The fixed
    # filters follow the input's device/dtype and need no gradient.
    weight = bank[:, None].to(device=vres.device, dtype=vres.dtype).repeat(out_channels, 1, 1, 1)
    return torch.nn.functional.conv_transpose2d(coeffs, weight, stride=2, groups=out_channels)
import torch.nn as nn

class WT(nn.Module):
    """
    Single-level db1 wavelet decomposition layer, followed by a channel
    "reduction" convolution. It is used by ``maxpool2wt`` as an experimental
    MaxPool2d replacement.

    NOTE(review): the reduction step in ``forward`` is broken (its output is
    always zero) -- see the comment there before using this layer.
    """
    def __init__(self):
        super(WT,self).__init__()

        w=pywt.Wavelet('db1')
        # w=pywt.Wavelet('haar')
        # w=pywt.Wavelet('rbio1.1')
        dec_hi = torch.Tensor(w.dec_hi[::-1])
        dec_lo = torch.Tensor(w.dec_lo[::-1])
        rec_hi = torch.Tensor(w.rec_hi)
        rec_lo = torch.Tensor(w.rec_lo)

        # Fixed (non-learned) analysis filters: outer products of the
        # low/high-pass taps, stacked into a (4, k, k) bank.
        self.filters = torch.stack([dec_lo.unsqueeze(0)*dec_lo.unsqueeze(1),
                               dec_lo.unsqueeze(0)*dec_hi.unsqueeze(1),
                               dec_hi.unsqueeze(0)*dec_lo.unsqueeze(1),
                               dec_hi.unsqueeze(0)*dec_hi.unsqueeze(1)], dim=0)
        # Matching synthesis filters (not used by forward).
        self.inv_filters = torch.stack([rec_lo.unsqueeze(0)*rec_lo.unsqueeze(1),
                                   rec_lo.unsqueeze(0)*rec_hi.unsqueeze(1),
                                   rec_hi.unsqueeze(0)*rec_lo.unsqueeze(1),
                                   rec_hi.unsqueeze(0)*rec_hi.unsqueeze(1)], dim=0)

    def forward(self,vimg):
        channels = vimg.shape[1]
        # Grouped conv == the former per-channel loop: channels 4*c..4*c+3
        # are the four sub-bands of input channel c. The filters follow the
        # input's device/dtype (no hard-coded .cuda()) and need no gradient.
        # The debug print() calls that ran on every forward were removed.
        weight = self.filters[:, None].to(device=vimg.device, dtype=vimg.dtype).repeat(channels, 1, 1, 1)
        res = torch.nn.functional.conv2d(vimg, weight, stride=2, groups=channels)

        # NOTE(review): BUG preserved as-is. This "4*C -> C" reduction uses an
        # all-zero weight whose kernel spans the whole (H/2, W/2) map. The
        # weight is re-created on every call and never registered as a
        # parameter, so for finite inputs the output is always zeros of shape
        # (B, C, 1, 1) -- not a usable pooling replacement. A learnable 1x1
        # conv, or returning only the LL band, is probably what was meant;
        # confirm intent before changing.
        reduce_weight = torch.zeros(channels, 4 * channels, res.shape[2], res.shape[3],
                                    device=vimg.device, dtype=vimg.dtype)
        return torch.nn.functional.conv2d(res, reduce_weight, stride=1)
    
class IWT(nn.Module):
    """
    Single-level db1 inverse wavelet transform layer.

    ``forward`` maps (B, 4*C, H, W) sub-band tensors to (B, C, 2*H, 2*W).
    Channels ``4*c .. 4*c+3`` are the sub-bands of output channel ``c``;
    trailing channels beyond a multiple of 4 are ignored.
    """
    def __init__(self):
        super(IWT,self).__init__()

        w=pywt.Wavelet('db1')
        # w=pywt.Wavelet('haar')
        # w=pywt.Wavelet('rbio1.1')
        dec_hi = torch.Tensor(w.dec_hi[::-1])
        dec_lo = torch.Tensor(w.dec_lo[::-1])
        rec_hi = torch.Tensor(w.rec_hi)
        rec_lo = torch.Tensor(w.rec_lo)

        # Keep the filter banks on the instance. Previously they were locals
        # of __init__, so forward silently read the notebook-global
        # `inv_filters` instead.
        self.filters = torch.stack([dec_lo.unsqueeze(0)*dec_lo.unsqueeze(1),
                               dec_lo.unsqueeze(0)*dec_hi.unsqueeze(1),
                               dec_hi.unsqueeze(0)*dec_lo.unsqueeze(1),
                               dec_hi.unsqueeze(0)*dec_hi.unsqueeze(1)], dim=0)
        self.inv_filters = torch.stack([rec_lo.unsqueeze(0)*rec_lo.unsqueeze(1),
                                   rec_lo.unsqueeze(0)*rec_hi.unsqueeze(1),
                                   rec_hi.unsqueeze(0)*rec_lo.unsqueeze(1),
                                   rec_hi.unsqueeze(0)*rec_hi.unsqueeze(1)], dim=0)

    def forward(self,vres):
        out_channels = vres.shape[1] // 4
        coeffs = vres[:, :4 * out_channels]
        # Grouped transposed conv == the former per-channel loop. The fixed
        # filters follow the input's device/dtype and need no gradient.
        weight = self.inv_filters[:, None].to(device=vres.device, dtype=vres.dtype).repeat(out_channels, 1, 1, 1)
        return torch.nn.functional.conv_transpose2d(coeffs, weight, stride=2, groups=out_channels)
    

In [None]:
import segmentation_models_pytorch as smp
net = smp.Unet(in_channels=4).cuda()

def maxpool2wt(module):
    """
    Recursively swap every ``torch.nn.MaxPool2d`` inside *module* for a ``WT`` layer.

    A replaced pooling layer passes its ``qconfig`` (if any) on to the new
    layer. Children are re-attached to whichever module is returned, so the
    tree is updated in place. If the root itself is not a MaxPool2d, the
    root is returned unchanged.
    """
    is_pool = isinstance(module, torch.nn.modules.MaxPool2d)
    replacement = WT() if is_pool else module
    if is_pool and hasattr(module, "qconfig"):
        replacement.qconfig = module.qconfig

    for child_name, child in module.named_children():
        replacement.add_module(child_name, maxpool2wt(child))

    return replacement

def upsample2iwt(module):
    """
    Recursively replace every ``torch.nn.Upsample`` submodule of *module*
    with an ``IWT`` layer.

    A replaced layer passes its ``qconfig`` (if any) on to the new layer.
    Children are re-attached to whichever module is returned, so the tree is
    updated in place.

    Only module-based upsampling can be swapped this way. Upsampling done by
    calling ``F.interpolate`` inside some other module's ``forward`` is
    invisible to module traversal and is left untouched.
    """
    module_output = module
    # Previously `isinstance(module, torch.nn.functional.interpolate)`:
    # interpolate is a function, not a type, so that raised TypeError on the
    # very first call. nn.Upsample (and its Upsampling*2d subclasses) is the
    # module form of that operation.
    if isinstance(module, torch.nn.Upsample):
        module_output = IWT()

        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    for name, child in module.named_children():
        module_output.add_module(name, upsample2iwt(child))

    return module_output


In [None]:
import torch
a = torch.rand(1,4,64,64).cuda()
b= net(a)
b.shape

In [None]:
net_ = maxpool2wt(net)
net_ = net_.cuda()
net_

In [None]:
import torch
a = torch.rand(2,4,64,64).cuda()
b= net_(a)
b.shape

In [None]:
net_ =upsample2iwt(net)