<a href="https://colab.research.google.com/github/nackjaylor/sydney-innovation-program/blob/main/sip_unsupervised_and_deep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Preprocessing applied to every sample: a single step that converts the
# loaded image into a tensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [None]:
class AutoEncoder_Linear(nn.Module):
  """Fully connected autoencoder for flattened inputs.

  The encoder narrows the input through hidden widths 128 -> 64 -> 36 -> 18
  down to a ``latent_dim``-wide code; the decoder mirrors those widths back
  up to ``input_dim`` and ends in a sigmoid, so every reconstructed value
  lies in (0, 1).

  Args:
    input_dim: Number of features per sample, i.e. the size of the last
      dimension of the tensor given to ``forward``. Defaults to ``64 * 64``.
    latent_dim: Width of the bottleneck code. Defaults to ``9``.
  """

  def __init__(self, input_dim=64 * 64, latent_dim=9):
    super().__init__()

    # Kept on the instance so callers can reshape data without re-deriving it.
    self.input_dim = input_dim
    self.latent_dim = latent_dim

    self.encoder = nn.Sequential(
        nn.Linear(input_dim, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 36),
        nn.ReLU(),
        nn.Linear(36, 18),
        nn.ReLU(),
        # No activation on the code itself: the latent values are unbounded.
        nn.Linear(18, latent_dim),
    )

    self.decoder = nn.Sequential(
        nn.Linear(latent_dim, 18),
        nn.ReLU(),
        nn.Linear(18, 36),
        nn.ReLU(),
        nn.Linear(36, 64),
        nn.ReLU(),
        nn.Linear(64, 128),
        nn.ReLU(),
        nn.Linear(128, input_dim),
        # Squash reconstructions into (0, 1).
        nn.Sigmoid(),
    )

  def forward(self, x):
    """Encode ``x`` to the latent code and decode it back.

    Args:
      x: Tensor whose last dimension has size ``input_dim``. Inputs are not
        flattened here, so images must be flattened by the caller.

    Returns:
      Tensor of the same shape as ``x`` holding the reconstruction.
    """
    x = self.encoder(x)
    x = self.decoder(x)

    return x

In [None]:
class AutoEncoder_Convolutional(nn.Module):
  """Convolutional autoencoder with a 9-value bottleneck.

  Encoder: three stride-2 3x3 convolutions (1 -> 8 -> 16 -> 32 channels),
  a flatten, and two linear layers (288 -> 128 -> 9). The decoder mirrors
  this with two linear layers, an unflatten to (32, 3, 3) and three
  transposed convolutions back to one channel, followed by a sigmoid.

  The linear layers hard-code a 32 x 3 x 3 feature volume; with the conv
  settings used here that is what a 1 x 28 x 28 input produces
  (28 -> 14 -> 7 -> 3), and the decoder expands it back to 28 x 28.
  """

  def __init__(self):
    super().__init__()

    # --- encoder -----------------------------------------------------------
    downsampling = [
        nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3,
                  stride=2, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3,
                  stride=2, padding=1),
        nn.BatchNorm2d(num_features=16),
        nn.ReLU(inplace=True),
        # No padding on the last convolution.
        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3,
                  stride=2, padding=0),
        nn.ReLU(inplace=True),
    ]
    self.encoder_cnn = nn.Sequential(*downsampling)

    self.flatten = nn.Flatten(start_dim=1)

    to_code = [
        nn.Linear(in_features=3 * 3 * 32, out_features=128),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=128, out_features=9),
    ]
    self.encoder_lin = nn.Sequential(*to_code)

    # --- decoder -----------------------------------------------------------
    from_code = [
        nn.Linear(in_features=9, out_features=128),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=128, out_features=3 * 3 * 32),
        nn.ReLU(inplace=True),
    ]
    self.decoder_lin = nn.Sequential(*from_code)

    self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

    upsampling = [
        nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3,
                           stride=2, output_padding=0),
        nn.BatchNorm2d(num_features=16),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3,
                           stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(num_features=8),
        nn.ReLU(inplace=True),
        # Final layer emits raw values; the sigmoid is applied in forward().
        nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=3,
                           stride=2, padding=1, output_padding=1),
    ]
    self.decoder_conv = nn.Sequential(*upsampling)

  def forward(self, x):
    """Reconstruct ``x`` through the 9-value bottleneck.

    Returns the decoder output passed through a sigmoid, so every value
    lies in (0, 1).
    """
    # Encode: conv features -> flat vector -> latent code.
    features = self.flatten(self.encoder_cnn(x))
    code = self.encoder_lin(features)

    # Decode: latent code -> (32, 3, 3) volume -> image.
    volume = self.unflatten(self.decoder_lin(code))
    reconstruction = self.decoder_conv(volume)

    return torch.sigmoid(reconstruction)

In [None]:
class Segmentation_Network(nn.Module):
  """Feature-pyramid segmentation head on top of a convolutional encoder.

  Data flow:
    * ``encoder_cnn`` turns the image into a 32-channel feature map (c1).
    * ``feature_extract_1`` (stride 1) and ``feature_extract_2`` (stride 3,
      resized back) both read c1; their sum is the 64-channel map c2.
    * ``conv3_x`` / ``conv4_x`` / ``conv5_x`` each halve the resolution and
      widen to 128 / 256 / 512 channels (c3, c4, c5).
    * ``top_conv`` and the lateral 1x1 convolutions project c5..c2 to 128
      channels, merged top-down with ``upsample_add``.
    * The merged map is upsampled by 2 and ``segmentation_conv`` emits
      ``n_classes + 1`` logits per pixel (the extra channel is background,
      label 0).

  Args:
    n_classes: Number of foreground classes; must be at least 1.
      NOTE(review): the default of 1 is an assumption, confirm against the
      dataset.
    encoder_cnn: Module mapping the input image to a 32-channel feature
      map, e.g. ``trained_autoencoder.encoder_cnn`` to reuse trained
      weights. When omitted, a freshly initialised
      ``AutoEncoder_Convolutional().encoder_cnn`` is used.

  NOTE(review): the 3x3 kernels of the feature extractors, the summation
  that forms c2, and the layout of ``conv3_x``..``conv5_x`` are design
  choices made to obtain a runnable network; confirm they match the
  intended architecture.

  NOTE: the BatchNorm layers need more than one value per channel in
  training mode, so very small inputs require a batch size above 1.
  """

  def __init__(self, n_classes=1, encoder_cnn=None):
    super().__init__()

    if n_classes < 1:
      raise ValueError("n_classes must be at least 1")
    self.n_classes = n_classes

    # encoder_cnn is an instance attribute of the autoencoder, so it has to
    # come from an instance (ideally a trained one), never from the class.
    if encoder_cnn is None:
      encoder_cnn = AutoEncoder_Convolutional().encoder_cnn
    self.encoder_cnn = encoder_cnn

    # Full-resolution branch over the encoder output.
    self.feature_extract_1 = nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),
    )

    # Coarse (stride-3) branch; forward() resizes it back onto the
    # full-resolution branch, which also copes with sizes not divisible by 3.
    self.feature_extract_2 = nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=3, stride=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),
    )

    # Downsampling stages whose widths match the lateral convolutions below.
    self.conv3_x = self._down_stage(64, 128)
    self.conv4_x = self._down_stage(128, 256)
    self.conv5_x = self._down_stage(256, 512)

    self.top_conv = nn.Sequential(
        nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1),
        nn.ReLU())

    self.lateral_conv1 = nn.Sequential(
        nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1),
        nn.ReLU())

    self.lateral_conv2 = nn.Sequential(
        nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1),
        nn.ReLU())

    self.lateral_conv3 = nn.Sequential(
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1),
        nn.ReLU())

    # background is considered as one additional class
    #   with label '0' by default
    self.segmentation_conv = nn.Sequential(
        nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, n_classes + 1, kernel_size=1)
    )

  @staticmethod
  def _down_stage(in_channels, out_channels):
    """Build a 3x3 stride-2 convolution followed by BatchNorm and ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2,
                  padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )

  @staticmethod
  def _resize_like(source, reference):
    """Bilinearly resize ``source`` to the spatial size of ``reference``."""
    return nn.functional.interpolate(
        source, size=reference.shape[-2:], mode='bilinear',
        align_corners=True)

  def upsample_add(self, low_res_map, high_res_map):
    """Upsample ``low_res_map`` onto ``high_res_map`` and add the two.

    The target size is taken from ``high_res_map``. For an exact 2x ratio
    this equals bilinear 2x upsampling with aligned corners; unlike a fixed
    scale factor it also works when the high-resolution side is odd.
    """
    return self._resize_like(low_res_map, high_res_map) + high_res_map

  def forward(self, x):
    """Return per-pixel class logits for the image batch ``x``.

    The output has ``n_classes + 1`` channels and twice the spatial size of
    the c2 feature map.
    """
    # Encoder
    c1 = self.encoder_cnn(x)
    fine = self.feature_extract_1(c1)
    coarse = self.feature_extract_2(c1)
    c2 = fine + self._resize_like(coarse, fine)
    c3 = self.conv3_x(c2)
    c4 = self.conv4_x(c3)
    c5 = self.conv5_x(c4)
    # Decoder
    p5 = self.top_conv(c5)
    p4 = self.upsample_add(p5, self.lateral_conv1(c4))
    p3 = self.upsample_add(p4, self.lateral_conv2(c3))
    p2 = self.upsample_add(p3, self.lateral_conv3(c2))
    out = nn.functional.interpolate(
        p2, scale_factor=2, mode='bilinear', align_corners=True)
    out = self.segmentation_conv(out)
    return out


