<a href="https://colab.research.google.com/github/moeghaf/Deep-unsupervised-segmentation/blob/main/WNET_paper_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WNET - Deep unsupervised segmentation

Implementation of WNet https://arxiv.org/pdf/1711.08506.pdf

In [6]:
import torch
import torch.nn as nn
import tensorflow as tf
import pandas as pd
import numpy as np

In [66]:
# Build encoder and test training
# Implementation from https://arxiv.org/pdf/1711.08506.pdf
# Modified to weigh the Ncut loss and reconstruction loss


class WNetEncoder(nn.Module):
    ''' WNet '''
    def __init__(self):
        super(WNetEncoder, self).__init__()

        self.maxpool = nn.MaxPool2d(2)

        # Module 1 - Conv 3x3
        self.module1 = torch.nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding = 1),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(64),

                                           nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding = 1),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(64),

                                           nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding = 1),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(64))

        # Module2 - separable conv 3x3, defined as a depthwise and point wise convolution
        self.module2 = torch.nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1),
                                           nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, groups=128),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(128),

                                           nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1),
                                           nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, groups=128),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(128))

        # Module3 - separable conv 3x3
        self.module3 = torch.nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1),
                                           nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1, groups=256),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(256),

                                           nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1),
                                           nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1, groups=256),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(256))

        # Module 4 - separable conv 3x3
        self.module4 = torch.nn.Sequential(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1),
                                           nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1, groups=512),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(512),

                                           nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1),
                                           nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1, groups=512),
                                           nn.ReLU(),
                                           nn.BatchNorm2d(512))




    def forward(self,x):

      # Module 1 and maxpool
      x = self.module1(x)
      x1 = self.maxpool(x)


      # Module 2 and maxpool
      x1 = self.module2(x1)
      x2 = self.module(x1)

      # Module 3 and maxpool
      x2 = self.module3(x2)
      x3 = self.maxpool(x2)







IndentationError: ignored

In [62]:
model = WNetEncoder()
x = torch.tensor(np.zeros((1,3,224,224)).astype(np.float32))

In [63]:
x1 = model(x)
x1.shape

torch.Size([1, 64, 224, 224])