<a href="https://colab.research.google.com/github/nackjaylor/sydney-innovation-program/blob/main/sip_unsupervised_and_deep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
from torch import nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
# Preprocessing pipeline: a Compose that wraps a single ToTensor transform.
transform = transforms.Compose([transforms.ToTensor()])

In [10]:
class AutoEncoder_Linear(nn.Module):
    """Fully connected autoencoder operating on flattened inputs.

    The encoder narrows the input through hidden widths 128 -> 64 -> 36 -> 18
    down to ``latent_dim`` features; the decoder mirrors those widths back up
    to ``input_dim`` and finishes with a sigmoid, so every reconstructed value
    lies in [0, 1].

    Args:
        input_dim: Size of the last dimension of the input. Defaults to
            ``64 * 64``.
        latent_dim: Width of the bottleneck produced by ``encoder``.
            Defaults to 9.
    """

    def __init__(self, input_dim: int = 64 * 64, latent_dim: int = 9) -> None:
        super().__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Layer order is kept fixed so the ``state_dict`` keys
        # (``encoder.0.weight``, ...) do not depend on the chosen sizes.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 36),
            nn.ReLU(),
            nn.Linear(36, 18),
            nn.ReLU(),
            # No activation on the bottleneck: the code is left unbounded.
            nn.Linear(18, latent_dim),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 18),
            nn.ReLU(),
            nn.Linear(18, 36),
            nn.ReLU(),
            nn.Linear(36, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            # Squash the reconstruction into [0, 1].
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` to the bottleneck and decode it back.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Reconstruction with the same shape as ``x`` and values in [0, 1].
        """
        return self.decoder(self.encoder(x))

In [11]:
class AutoEncoder_Convolutional(nn.Module):
    """Convolutional autoencoder for single-channel images.

    Three stride-2 convolutions reduce the image to a 32-channel 3x3 feature
    map, which is flattened and projected to ``latent_dim`` features. The
    decoder projects back to 32x3x3 and applies three stride-2 transposed
    convolutions followed by a sigmoid.

    NOTE: the hard-coded ``3 * 3 * 32`` flatten size is what the encoder
    yields for 1x28x28 inputs (28 -> 14 -> 7 -> 3), and the decoder always
    emits 1x28x28 (3 -> 7 -> 14 -> 28). Inputs of another spatial size, such
    as 64x64, will not match the first linear layer.

    Args:
        latent_dim: Width of the bottleneck code. Defaults to 9.
    """

    def __init__(self, latent_dim: int = 9) -> None:
        super().__init__()

        self.latent_dim = latent_dim

        # Submodule names and layer order are preserved so existing
        # ``state_dict`` checkpoints keep loading.
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(inplace=True),
        )
        self.flatten = nn.Flatten(start_dim=1)
        self.encoder_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, latent_dim),
        )

        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(inplace=True),
        )
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to latent codes of width ``latent_dim``."""
        features = self.encoder_cnn(x)
        return self.encoder_lin(self.flatten(features))

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent codes back to images with values in [0, 1]."""
        features = self.unflatten(self.decoder_lin(z))
        # The sigmoid is applied here rather than inside ``decoder_conv``,
        # so ``decoder_conv`` on its own returns unbounded values.
        return torch.sigmoid(self.decoder_conv(features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return its reconstruction.

        Args:
            x: Batch of single-channel images; see the class note on the
                spatial size the layer dimensions assume.

        Returns:
            Reconstructed batch with values in [0, 1].
        """
        return self.decode(self.encode(x))