<a href="https://colab.research.google.com/github/moeghaf/Deep-unsupervised-segmentation/blob/main/WNET_paper_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WNET - Deep unsupervised segmentation

Implementation of WNet https://arxiv.org/pdf/1711.08506.pdf

In [6]:
import torch
import torch.nn as nn
import tensorflow as tf
import pandas as pd
import numpy as np

In [None]:
def add_conv():


In [90]:
# Build encoder and test training
# Implementation from https://arxiv.org/pdf/1711.08506.pdf
# Modified to weigh the Ncut loss and reconstruction loss


class WNetEncoder(nn.Module):
    ''' WNet '''
    def __init__(self, k ):
        super(WNetEncoder, self).__init__()

        self.k = k
        self.maxpool = nn.MaxPool2d(2)
        self.predconv = nn.Conv2d(224,self.k,1,)
        self.softmax = nn.Softmax2d()


        def add_conv(in_ch, out_ch):
          return torch.nn.Sequential(nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding = 1),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(out_ch),

                                           nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding = 1),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(out_ch))

        def add_sep_conv(in_ch,out_ch):
          ''' defined as a depthwise and point wise convolution '''
          return torch.nn.Sequential(nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1),
                                           nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1, groups=out_ch),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(out_ch),

                                           nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=1),
                                           nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1, groups=out_ch),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(out_ch))

        # Module 1 - Conv 3x3
        self.module1 = add_conv(3, 64)


        # Module2 - separable conv 3x3,
        self.module2 = add_sep_conv(64 ,128)

        # Module3 - separable conv 3x3
        self.module3 =  add_sep_conv(128 , 256)

        # Module 4 - separable conv 3x3
        self.module4 = add_sep_conv(256 , 512)

        # Module 5 - separable conv 3x3 and upconv
        self.module5 = add_sep_conv(512 , 1024)
        self.upconv5 = torch.nn.Sequential(nn.ConvTranspose2d(1024, 512, 2,2))

        # Module 6 - separable conv 3x3 and upconv
        self.module6 = add_sep_conv(1024 , 512)
        self.upconv6 = torch.nn.Sequential(nn.ConvTranspose2d(512, 256, 2,2))

        # Module 7 - separable conv 3x3 and upconv
        self.module7 = add_sep_conv(512 , 256)
        self.upconv7 = torch.nn.Sequential(nn.ConvTranspose2d(256, 128, 2,2))

        # Module 8 - separable conv 3x3 and upconv
        self.module8 = add_sep_conv(256 , 128)
        self.upconv8 = torch.nn.Sequential(nn.ConvTranspose2d(128, 64, 2,2))

        # Module 9
        self.module9 = add_conv(128, 64)
        self.convK =








    def forward(self,x):

      # Module 1 and maxpool
      x = self.module1(x)
      x1 = self.maxpool(x)


      # Module 2 and maxpool
      x1 = self.module2(x1)
      x2 = self.module(x1)

      # Module 3 and maxpool
      x2 = self.module3(x2)
      x3 = self.maxpool(x2)

      # Module 4 and maxpool
      x3 = self.module4(x3) # Join with output of module 5
      x4 = self.maxpool(x3)

      # Module 5 and maxpool
      x4 = self.module5(x4)
      x5 = self.upconv5(x4)

      # Skip and module 6
      skip1 = torch.concat((x5,x3), dim=1)
      x6 = self.module6(skip1)












In [91]:
model = WNetEncoder()
x = torch.tensor(np.zeros((1,3,224,224)).astype(np.float32))

In [92]:
x = model.module1(x)
x1 = model.maxpool(x)


# Module 2 and maxpool
x1 = model.module2(x1)
x2 = model.maxpool(x1)

# Module 3 and maxpool
x2 = model.module3(x2)
x3 = model.maxpool(x2)

# Module 4 and maxpool
x3 = model.module4(x3) # Join with output of module 5
x4 = model.maxpool(x3)

# Module 5 and upconv
x4 = model.module5(x4)
x5 = model.upconv5(x4)

# Skip layer into module 6
x5 = torch.concat((x5,x3), dim=1)

# Module 6 and upconv






In [93]:
x5.shape

torch.Size([1, 1024, 28, 28])

In [80]:
x5.shape

torch.Size([1, 1024, 28, 28])

In [63]:
x1 = model(x)
x1.shape

torch.Size([1, 64, 224, 224])