<a href="https://colab.research.google.com/github/moeghaf/Deep-unsupervised-segmentation/blob/main/WNET_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# W-Net: deep unsupervised image segmentation

PyTorch implementation of the W-Net model (Xia & Kulis, 2017): https://arxiv.org/pdf/1711.08506.pdf

The config, data loader and N-cut loss are adapted from https://github.com/Andrew-booler/W-Net/tree/master/Wnet


In [None]:
from PIL import Image
import torch
import torch.utils.data as Data
import os
import glob
import numpy as np
import pdb
import math
import cupy as cp
import matplotlib.pyplot as plt
from skimage import io, color, morphology
from torchvision import transforms as T
import matplotlib as mpl
# Render matplotlib figures at 300 dpi.
mpl.rcParams['figure.dpi'] = 300

In [None]:
class Config:
    """Hyper-parameters and constants shared by the data loader, model and loss."""

    def __init__(self):
        # Network configuration.
        self.InputCh = 3
        self.ScaleRatio = 2
        self.ConvSize = 3
        # "Same" padding for an odd ConvSize (1 for a 3x3 kernel).
        self.pad = (self.ConvSize - 1) // 2
        self.MaxLv = 5
        # Channels per level: the input channels, then 64 doubled at each of
        # the MaxLv levels -> [3, 64, 128, 256, 512, 1024].
        self.ChNum = [self.InputCh] + [64 * 2 ** level for level in range(self.MaxLv)]

        # Data loading.
        #self.imagelist = "ImageSets/Segmentation/train.txt"
        self.BatchSize = 16
        self.Shuffle = True
        self.LoadThread = 2
        self.inputsize = [224, 224]
        # Dataset locations used previously (kept for reference):
        #self.CollagenDir = '/content/drive/MyDrive/Studies/PhD_WT_IMC/Year_2 /WnET/clean_col_data/*tiff'
        #self.CollagenDir = '/content/drive/MyDrive/Studies/PhD_WT_IMC/Year_2 /WnET/clean_col_data_not_mask/*tiff'
        #self.CollagenDir = '/content/drive/MyDrive/Studies/PhD_WT_IMC/Year_2 /WnET/raw_col_data/*tiff'
        #self.CollagenDir = '/content/drive/MyDrive/Others/WnET/raw_col_data/*tiff'

        # Partition: number of segmentation classes.
        self.K = 10

        # Training configuration.
        self.init_lr = 0.05
        self.lr_decay = 0.1
        self.lr_decay_iter = 1000
        self.max_iter = 50000
        self.cuda_dev = 0
        self.cuda_dev_list = "0,1"
        self.check_iter = 1000

        # N-cut affinity: neighbourhood radius and the Gaussian bandwidths of
        # the intensity (sigmaI) and spatial (sigmaX) terms.
        self.radius = 2
        self.sigmaI = 10
        self.sigmaX = 1

        # Colour library: the 8 RGB triples whose channels are each 0 or 128,
        # presumably for colouring segments when visualising.
        # NOTE(review): only 8 colours exist while K = 10 classes.
        self.color_lib = [(r, g, b)
                          for r in range(0, 256, 128)
                          for g in range(0, 256, 128)
                          for b in range(0, 256, 128)]



class DataLoader():
    """Load images from disk, preprocess them and wrap them in a torch dataset.

    In ``"train"`` mode every image is paired with its N-cut affinity weights,
    computed on the GPU with CuPy by :meth:`cal_weight`; in any other mode only
    the images are stored. Reads the module-level ``config``.
    """

    def __init__(self, datapath, mode, weight_chunk=75):
        """
        Args:
            datapath: glob pattern matching the image files to load.
            mode: ``"train"`` to also compute N-cut weights; anything else
                stores the images only.
            weight_chunk: number of images handled per call to
                :meth:`cal_weight` (bounds GPU memory use). Defaults to 75.

        Raises:
            ValueError: if ``datapath`` matches no files.
        """
        self.raw_data = []
        self.mode = mode

        # Sorted so the dataset order is reproducible (glob order is arbitrary).
        file_names = sorted(glob.glob(datapath))
        if not file_names:
            raise ValueError(f"no images match {datapath!r}")

        size = (config.inputsize[0], config.inputsize[1])
        for file_name in file_names:
            with Image.open(file_name) as image:
                # MODIFIED: shift every pixel value down by one.
                # NOTE(review): for uint8 images a 0 wraps around to 255 —
                # confirm the data never contains 0, or clip before subtracting.
                pixels = Image.fromarray(np.array(image) - 1)
                if pixels.mode != "RGB":
                    pixels = pixels.convert("RGB")
                self.raw_data.append(np.array(pixels.resize(size, Image.BILINEAR)))

        self.scale()     # list of (H, W, 3) arrays -> one (N, 3, H, W) array
        self.transfer()  # cast to float64

        # get_dataset consults self.mode to decide whether to attach weights,
        # so train and eval build the dataset the same way.
        self.dataset = self.get_dataset(self.raw_data, self.raw_data.shape, weight_chunk)

    def scale(self):
        """Stack the per-image (H, W, 3) arrays into one channel-first (N, 3, H, W) array.

        Despite the name, no resizing happens here; images are resized in
        ``__init__``.
        """
        self.raw_data = np.stack(
            [image[:, :, :3].transpose(2, 0, 1) for image in self.raw_data], axis=0
        )

    def transfer(self):
        """Cast the image array to float64 (pixel values are left unscaled)."""
        self.raw_data = self.raw_data.astype(float)

    def torch_loader(self):
        """Return a torch DataLoader over ``self.dataset`` using the batch settings in ``config``."""
        return Data.DataLoader(
            self.dataset,
            batch_size=config.BatchSize,
            shuffle=config.Shuffle,
            num_workers=config.LoadThread,
            pin_memory=True,
        )

    def cal_weight(self, raw_data, shape):
        """Compute N-cut affinity weights for a batch of images on the GPU.

        With ``r = config.radius - 1`` and a window of ``win = 2r + 1`` offsets
        per axis, the weight for pixel ``p`` and offset ``(dy, dx)`` is::

            exp(-||I(p) - I(p + off)||^2 / sigmaI^2) * exp(-(dy^2 + dx^2) / sigmaX^2)

        where the squared norm sums over channels. The spatial factor is zero
        unless ``dy^2 + dx^2 < config.radius^2``. Neighbours outside the image
        are treated as zero-valued pixels (constant padding).

        Args:
            raw_data: array of shape ``shape`` = (B, C, H, W).
            shape: shape of ``raw_data``.

        Returns:
            CuPy array of shape (B, 1, H, W, win, win).
        """
        print("calculating weights.")
        r = config.radius - 1
        win = 2 * r + 1
        height, width = shape[2], shape[3]

        data = cp.asarray(raw_data)
        padded = cp.pad(data, ((0, 0), (0, 0), (r, r), (r, r)), 'constant')

        # Intensity term, filled one window offset at a time so the full
        # (B, C, H, W, win, win) difference tensor is never materialised.
        intensity = cp.empty((shape[0], 1, height, width, win, win))
        for m in range(win):
            for n in range(win):
                diff = data - padded[:, :, m:height + m, n:width + n]
                intensity[:, :, :, :, m, n] = cp.exp(
                    -cp.power(diff, 2).sum(1, keepdims=True) / config.sigmaI ** 2
                )

        # Spatial term: Gaussian of the offset length, cut off at the radius.
        spatial = cp.zeros((win, win))
        for dy in range(-r, r + 1):
            for dx in range(-r, r + 1):
                if dy ** 2 + dx ** 2 < config.radius ** 2:
                    spatial[dy + r, dx + r] = math.exp(-(dy ** 2 + dx ** 2) / config.sigmaX ** 2)

        res = cp.multiply(intensity, spatial)
        print("weight calculated.")
        return res

    def get_dataset(self, raw_data, shape, batch_size):
        """Wrap ``raw_data`` in a ConcatDataset, processed ``batch_size`` images at a time.

        In ``"train"`` mode each chunk is a TensorDataset of (images, weights);
        otherwise it holds the images only. Frees the CuPy memory pool at the end.
        """
        dataset = []
        for batch_id in range(0, shape[0], batch_size):
            print(batch_id)
            batch = raw_data[batch_id:min(shape[0], batch_id + batch_size)]
            images = torch.from_numpy(batch).float()
            if self.mode == "train":
                gpu_weight = self.cal_weight(batch, batch.shape)
                weight = cp.asnumpy(gpu_weight)
                del gpu_weight  # drop the GPU copy before the next chunk
                dataset.append(Data.TensorDataset(images, torch.from_numpy(weight).float()))
            else:
                dataset.append(Data.TensorDataset(images))
        cp.get_default_memory_pool().free_all_blocks()
        return Data.ConcatDataset(dataset)

# Module-level configuration instance. DataLoader reads it at call time; the
# default arguments of WNetEncoder, WNetDecoder, WNet and NCutsLoss below read
# it when those classes are defined, so it must exist before them.
config = Config()


# Build the W-Net encoder/decoder and the N-cut loss, and test training
# Implementation from https://arxiv.org/pdf/1711.08506.pdf
# Modified to weigh the Ncut loss and reconstruction loss
import torch.nn as nn

def add_conv(in_ch, out_ch):
    """Return two stacked [3x3 conv (padding 1) -> ReLU -> BatchNorm] stages.

    The first conv maps ``in_ch`` to ``out_ch`` channels; the second keeps
    ``out_ch``. Spatial size is preserved.
    """
    layers = []
    for source_ch in (in_ch, out_ch):
        layers.extend((
            nn.Conv2d(in_channels=source_ch, out_channels=out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
        ))
    return nn.Sequential(*layers)

def add_sep_conv(in_ch, out_ch):
    """Return two stacked separable-conv stages.

    Each stage is a 1x1 pointwise conv, then a 3x3 depthwise conv
    (``groups=out_ch``, padding 1), then ReLU, then BatchNorm. The first
    pointwise conv maps ``in_ch`` to ``out_ch`` channels; everything after
    that keeps ``out_ch``. Spatial size is preserved.
    """
    layers = []
    for source_ch in (in_ch, out_ch):
        layers += [
            nn.Conv2d(in_channels=source_ch, out_channels=out_ch, kernel_size=1),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1, groups=out_ch),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
        ]
    return nn.Sequential(*layers)


class WNetEncoder(nn.Module):
    """U-shaped encoder half of W-Net: RGB image -> per-pixel soft K-way segmentation."""

    def __init__(self, k=config.K, radius=config.radius):
        """
        Args:
            k: number of segmentation classes (output channels).
            radius: N-cut neighbourhood radius; the padded output is padded by
                ``radius - 1`` pixels on each spatial side.
        """
        super(WNetEncoder, self).__init__()
        self.radius = radius
        self.k = k
        self.maxpool = nn.MaxPool2d(2)
        # NCutsLoss crops window offsets 0..2*(radius-1) from the padded
        # segmentation, and DataLoader.cal_weight pads images by radius-1 when
        # building the weights. The padding here must therefore also be
        # radius-1 so that crop (m, n) is the neighbour at offset
        # (m - (radius-1), n - (radius-1)), matching weight[..., m, n].
        # (A fixed pad of 4 misaligned them for radius = 2.)
        self.pad = nn.ConstantPad2d(radius - 1, 0)

        # Module 1 - conv 3x3
        self.module1 = add_conv(3, 64)

        # Modules 2-4 - separable conv 3x3 (contracting path)
        self.module2 = add_sep_conv(64, 128)
        self.module3 = add_sep_conv(128, 256)
        self.module4 = add_sep_conv(256, 512)

        # Module 5 - bottleneck separable conv 3x3 and upconv
        self.module5 = add_sep_conv(512, 1024)
        self.upconv5 = nn.ConvTranspose2d(1024, 512, 2, 2)

        # Modules 6-8 - separable conv 3x3 and upconv (expanding path; the
        # input channels are doubled by the skip concatenation)
        self.module6 = add_sep_conv(1024, 512)
        self.upconv6 = nn.ConvTranspose2d(512, 256, 2, 2)

        self.module7 = add_sep_conv(512, 256)
        self.upconv7 = nn.ConvTranspose2d(256, 128, 2, 2)

        self.module8 = add_sep_conv(256, 128)
        self.upconv8 = nn.ConvTranspose2d(128, 64, 2, 2)

        # Module 9 - conv 3x3, then a 1x1 projection to k classes and softmax
        self.module9 = add_conv(128, 64)
        self.predconv = nn.Conv2d(64, self.k, 1)
        self.softmax = nn.Softmax2d()

    def forward(self, x):
        """Segment ``x``.

        Args:
            x: image batch with 3 channels (module1 expects 3). There are four
                2x max-pools matched by four 2x upconvs, so H and W must be
                divisible by 16 for the skip concatenations to line up.

        Returns:
            ``(k_pred, padded_k_pred)``. ``k_pred`` holds the softmax across the
            k channels at every pixel. ``padded_k_pred`` is the same tensor
            zero-padded by ``radius - 1`` on each spatial side. ``k_pred`` is
            also stored on ``self.k_pred``.
        """
        # Contracting path; the *_skip tensors feed the expanding path.
        x1_skip = self.module1(x)
        x2_skip = self.module2(self.maxpool(x1_skip))
        x3_skip = self.module3(self.maxpool(x2_skip))
        x4_skip = self.module4(self.maxpool(x3_skip))

        # Bottleneck.
        x5 = self.upconv5(self.module5(self.maxpool(x4_skip)))

        # Expanding path: concat with the matching skip, convolve, upsample.
        x6 = self.upconv6(self.module6(torch.concat((x5, x4_skip), dim=1)))
        x7 = self.upconv7(self.module7(torch.concat((x6, x3_skip), dim=1)))
        x8 = self.upconv8(self.module8(torch.concat((x7, x2_skip), dim=1)))
        x9 = self.module9(torch.concat((x8, x1_skip), dim=1))

        self.k_pred = self.softmax(self.predconv(x9))
        return self.k_pred, self.pad(self.k_pred)


class WNetDecoder(nn.Module):
    """U-shaped decoder half of W-Net: k-channel soft segmentation -> 3-channel reconstruction."""

    def __init__(self, k=config.K):
        super(WNetDecoder, self).__init__()
        self.k = k
        self.maxpool = nn.MaxPool2d(2)

        # Contracting path; each stage's output is also kept as a skip.
        self.module10 = add_conv(self.k, 64)
        self.module11 = add_sep_conv(64, 128)
        self.module12 = add_sep_conv(128, 256)
        self.module13 = add_sep_conv(256, 512)

        # Bottleneck and expanding path. Each stage ends in a 2x2, stride-2
        # transposed conv that doubles the spatial size and halves the
        # channels. Unpacking into a new Sequential keeps the same layer
        # indices as appending to the returned one.
        self.module14 = nn.Sequential(*add_sep_conv(512, 1024), nn.ConvTranspose2d(1024, 512, 2, 2))
        self.module15 = nn.Sequential(*add_sep_conv(1024, 512), nn.ConvTranspose2d(512, 256, 2, 2))
        self.module16 = nn.Sequential(*add_sep_conv(512, 256), nn.ConvTranspose2d(256, 128, 2, 2))
        self.module17 = nn.Sequential(*add_sep_conv(256, 128), nn.ConvTranspose2d(128, 64, 2, 2))

        # Final stage: conv block, then a 1x1 projection to 3 output channels.
        self.module18 = nn.Sequential(*add_conv(128, 64), nn.Conv2d(64, 3, 1))

    def forward(self, x):
        """Reconstruct an image from the k-channel segmentation ``x``.

        There are four 2x max-pools matched by four 2x upconvs, so H and W of
        ``x`` must be divisible by 16. Returns a 3-channel tensor with the same
        spatial size as ``x``.
        """
        skips = []
        h = x
        for stage in (self.module10, self.module11, self.module12, self.module13):
            h = stage(h)
            skips.append(h)
            h = self.maxpool(h)

        h = self.module14(h)

        # Deepest skip first; the skip goes before the upsampled features.
        up_stages = (self.module15, self.module16, self.module17, self.module18)
        for skip, stage in zip(reversed(skips), up_stages):
            h = stage(torch.concat((skip, h), dim=1))
        return h






class WNet(nn.Module):
    """Full W-Net: an encoder that segments, chained into a decoder that reconstructs."""

    def __init__(self, K=config.K):
        super().__init__()
        self.K = K
        self.encoder = WNetEncoder(K)
        self.decoder = WNetDecoder(K)

    def forward(self, x):
        """Return ``(reconstruction, segmentation, padded_segmentation)`` for input ``x``."""
        segmentation, padded_segmentation = self.encoder(x)
        reconstruction = self.decoder(segmentation)
        return reconstruction, segmentation, padded_segmentation



class NCutsLoss(nn.Module):
    """Soft normalized-cut loss: ``K - sum_k assoc(A_k, A_k) / assoc(A_k, V)``."""

    def __init__(self, radius=config.radius):
        """
        Args:
            radius: neighbourhood radius; the comparison window spans
                ``2 * (radius - 1) + 1`` offsets per spatial axis.
        """
        super(NCutsLoss, self).__init__()
        self.radius = radius
        self.gpu_list = []  # not used by this class; kept for compatibility

    def forward(self, seg, padded_seg, weight, sum_weight):
        """
        Args:
            seg: soft segmentation; dim 1 indexes the K classes, dims 2-3 are
                spatial (H, W).
            padded_seg: ``seg`` zero-padded by ``radius - 1`` on each spatial
                side, so that crop ``(m, n)`` is the neighbour at offset
                ``(m - (radius-1), n - (radius-1))``. A different padding
                misaligns the crops with ``weight``.
            weight: affinities whose last two dims index the window offset
                ``(m, n)``, broadcastable against ``(B, K, H, W, win, win)``.
            sum_weight: per-pixel normaliser broadcastable to ``seg``;
                presumably ``weight`` summed over the window — TODO confirm
                with the caller.

        Returns:
            Tensor of shape (B,): ``K - sum_k assocA_k / assocV_k`` with
            ``K = seg.size(1)``.
        """
        win = (self.radius - 1) * 2 + 1
        height, width = seg.size(2), seg.size(3)

        # cropped_seg[..., m, n] is seg shifted by window offset (m, n);
        # torch.stack copies, so no explicit clone is needed.
        cropped_seg = torch.stack(
            [torch.stack([padded_seg[:, :, m:m + height, n:n + width] for n in range(win)], dim=4)
             for m in range(win)],
            dim=4,
        )

        multi1 = cropped_seg.mul(weight)
        multi2 = multi1.sum(-1).sum(-1).mul(seg)
        multi3 = sum_weight.mul(seg)

        assocA = multi2.view(multi2.shape[0], multi2.shape[1], -1).sum(-1)
        assocV = multi3.view(multi3.shape[0], multi3.shape[1], -1).sum(-1)
        assoc = assocA.div(assocV).sum(-1)

        # The constant is the number of classes, K, as in the soft N-cut
        # definition. It was previously hard-coded as 10 (== config.K).
        num_classes = seg.size(1)
        return num_classes - assoc

