<a href="https://colab.research.google.com/github/moeghaf/Deep-unsupervised-segmentation/blob/main/WNET_paper_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WNET - Deep unsupervised segmentation

Implementation of WNet https://arxiv.org/pdf/1711.08506.pdf

In [6]:
import torch
import torch.nn as nn
import tensorflow as tf
import pandas as pd
import numpy as np

In [74]:
# Build encoder and test training
# Implementation from https://arxiv.org/pdf/1711.08506.pdf
# Modified to weigh the Ncut loss and reconstruction loss


class WNetEncoder(nn.Module):
    ''' WNet '''
    def __init__(self):
        super(WNetEncoder, self).__init__()

        self.maxpool = nn.MaxPool2d(2)

        # Module 1 - Conv 3x3
        self.module1 = torch.nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding = 1),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(64),

                                           nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding = 1),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(64),

                                           nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding = 1),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(64))

        # Module2 - separable conv 3x3, defined as a depthwise and point wise convolution
        self.module2 = torch.nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1),
                                           nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, groups=128),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(128),

                                           nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1),
                                           nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, groups=128),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(128))

        # Module3 - separable conv 3x3
        self.module3 = torch.nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1),
                                           nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1, groups=256),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(256),

                                           nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1),
                                           nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1, groups=256),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(256))

        # Module 4 - separable conv 3x3
        self.module4 = torch.nn.Sequential(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1),
                                           nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1, groups=512),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(512),

                                           nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1),
                                           nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1, groups=512),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(512))

        # Module 5 - separable conv 3x3
        self.module5 = torch.nn.Sequential(nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1),
                                           nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1, groups=1024),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(1024),

                                           nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1),
                                           nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1, groups=1024),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(1024))

        self.upconv5 = torch.nn.Sequential(nn.ConvTranspose2d(1024, 1024, 2,2))

        # Module 6 - separable conv 3x3
        self.m




    def forward(self,x):

      # Module 1 and maxpool
      x = self.module1(x)
      x1 = self.maxpool(x)


      # Module 2 and maxpool
      x1 = self.module2(x1)
      x2 = self.module(x1)

      # Module 3 and maxpool
      x2 = self.module3(x2)
      x3 = self.maxpool(x2)

      # Module 4 and maxpool
      x3 = self.module4(x3) # Join with output of module 5
      x4 = self.maxpool(x3)

      # Module 5 and maxpool
      x4 = self.module5(x4)
      x5 = self.upconv5(x4)

      # Add output of 4
      #x3







In [78]:
model = WNetEncoder()
x = torch.tensor(np.zeros((1,3,224,224)).astype(np.float32))

In [79]:
x = model.module1(x)
x1 = model.maxpool(x)


# Module 2 and maxpool
x1 = model.module2(x1)
x2 = model.maxpool(x1)

# Module 3 and maxpool
x2 = model.module3(x2)
x3 = model.maxpool(x2)

# Module 4 and maxpool
x3 = model.module4(x3) # Join with output of module 5
x4 = model.maxpool(x3)

# Module 5 and upconv
x4 = model.module5(x4)
x5 = model.upconv5(x4)

# Skip layer into module 6
x5 = torch.concat((x5,x3), dim=1)
# Module 6 and upconv






tensor([[[[-2.5994e-01, -7.9578e-01,  5.0006e-01,  ..., -3.3107e-01,
            3.1169e-01, -5.4255e-01],
          [-8.1609e-02, -3.3610e-01, -2.4916e-01,  ..., -7.4583e-01,
            4.1038e-01, -9.0952e-01],
          [ 5.5636e-01, -2.2577e-01,  3.6912e-01,  ..., -2.7182e-01,
           -6.2048e-01,  6.6920e-02],
          ...,
          [ 4.9273e-01,  9.5852e-01,  8.9347e-02,  ...,  1.2182e+00,
            3.1169e-01,  7.4757e-01],
          [ 6.5358e-01,  7.9171e-02,  5.9858e-01,  ..., -5.3064e-01,
           -7.1281e-01,  5.6261e-01],
          [-9.0767e-01, -3.3618e-01, -3.5847e-03,  ...,  1.1084e+00,
            6.6298e-01,  6.0464e-01]],

         [[-3.4689e-01,  5.7234e-01,  2.2426e-01,  ...,  6.5859e-01,
           -3.3813e-01, -1.0753e-01],
          [ 4.6524e-01, -2.9018e-02,  2.2370e-01,  ...,  8.2671e-01,
           -4.3728e-01,  4.3573e-01],
          [ 7.9725e-01,  5.0398e-01,  1.3812e+00,  ..., -5.0382e-01,
           -7.5192e-01, -1.7212e-01],
          ...,
     

In [80]:
x5.shape

torch.Size([1, 1024, 28, 28])

In [63]:
x1 = model(x)
x1.shape

torch.Size([1, 64, 224, 224])