# Notebook for experimenting with different convolutional architectures

In [1]:
import torch
import torch.nn as nn

import numpy as np
import sys
import cv2

import os, shutil
from pathlib import Path
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cpu


## Model

* https://github.com/AntixK/PyTorch-VAE
* https://ml-cheatsheet.readthedocs.io/en/latest/architectures.html#vae

In [18]:
class SkipConnection_encoder(nn.Module):
    """Pass-through layer that records its input in a shared cache slot.

    Args:
        cache: mutable, indexable container shared with the matching
            decoder-side layer.
        index: position in ``cache`` this layer writes to.
    """

    def __init__(self, cache, index):
        super().__init__()
        self.index = index
        self.cache = cache

    def forward(self, x):
        """Store ``x`` at ``cache[index]`` and return it unchanged."""
        slot = self.index
        self.cache[slot] = x
        return x

class SkipConnection_decoder(nn.Module):
    """Layer that adds a previously cached value onto its input.

    Args:
        cache: indexable container shared with the matching encoder-side
            layer.
        index: position in ``cache`` this layer reads from.
    """

    def __init__(self, cache, index):
        super().__init__()
        self.index = index
        self.cache = cache

    def forward(self, x):
        """Return ``x`` plus whatever currently sits at ``cache[index]``."""
        skipped = self.cache[self.index]
        return x + skipped
    
#https://towardsdatascience.com/using-skip-connections-to-enhance-denoising-autoencoder-algorithms-849e049c0ac9    
class AE_1_skipped(nn.Module):
    """Convolutional autoencoder with three additive skip connections.

    The encoder applies five convolutions; each of the first four is followed
    by a 2x2 max-pool, so the spatial size shrinks by a factor of 16. The
    decoder applies four bilinear x2 upsampling stages, each followed by a
    size-preserving transposed convolution. The pre-activation outputs of
    encoder convolutions 1-3 are stored in ``self.cache`` and added back in
    the decoder right after upsampling stages U4, U3 and U2 respectively.

    Shape comments below use H x W for the input height and width. The skip
    additions need matching spatial sizes, which holds e.g. when H and W are
    multiples of 16.

    Args:
        test_input: tensor of shape (C, H, W) or (N, C, H, W). Its channel
            count sizes the first and last layers, and it is pushed through
            the network once in the constructor as a shape check.
        device: device that ``forward`` moves its input to. The module's own
            parameters are not moved by this class.
    """

    def __init__(self, test_input, device):
        super().__init__()

        # Accept a single un-batched image by adding a leading batch axis.
        if len(test_input.shape) == 3:
            test_input = test_input.unsqueeze(dim=0)

        channels = test_input.shape[1]
        print(channels)

        self.device = device

        # Slots shared by the encoder-side writers and decoder-side readers.
        # The skip layers hold a reference to this exact list, so it must only
        # ever be mutated in place, never rebound.
        self.cache = [0, 0, 0]

        self.encoder = nn.Sequential(
            nn.Dropout(p=0.2),  # zeroes input elements with probability 0.2 (training mode only)
            nn.Conv2d(in_channels=channels, out_channels=64, kernel_size=5, stride=1, padding=2),  # C1 out: 64, H, W
            SkipConnection_encoder(self.cache, 0),  # index 0, skipping to U4
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # M1 out: 64, H/2, W/2

            # conv 2
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=2),  # C2 out: 64, H/2, W/2
            SkipConnection_encoder(self.cache, 1),  # index 1, skipping to U3
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # M2 out: 64, H/4, W/4

            # conv 3
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),  # C3 out: 128, H/4, W/4
            SkipConnection_encoder(self.cache, 2),  # index 2, skipping to U2
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # M3 out: 128, H/8, W/8

            # conv 4
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),  # C4 out: 128, H/8, W/8
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # M4 out: 128, H/16, W/16

            # conv 5
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),  # C5 out: 128, H/16, W/16
            nn.LeakyReLU(),
        )

        # align_corners=False is the library default for bilinear mode; it is
        # spelled out so the choice is explicit and unchanged.
        self.decoder = nn.Sequential(
            # conv 6
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # U1 out: 128, H/8, W/8
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2),  # C6 out: 128, H/8, W/8
            nn.LeakyReLU(),

            # conv 7
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # U2 out: 128, H/4, W/4
            SkipConnection_decoder(self.cache, 2),  # adds C3 output
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, stride=1, padding=2),  # C7 out: 64, H/4, W/4
            nn.LeakyReLU(),

            # conv 8
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # U3 out: 64, H/2, W/2
            SkipConnection_decoder(self.cache, 1),  # adds C2 output
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=2),  # C8 out: 64, H/2, W/2
            nn.LeakyReLU(),

            # conv 9
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # U4 out: 64, H, W
            SkipConnection_decoder(self.cache, 0),  # adds C1 output
            nn.ConvTranspose2d(in_channels=64, out_channels=channels, kernel_size=3, stride=1, padding=1),  # C9 out: channels, H, W
            nn.Sigmoid()
        )

        # Test the size changes with one gradient-free pass.
        with torch.no_grad():
            x = self.encoder(test_input)
            print('Encoded from ', test_input.shape, 'to', x.shape)
            x = self.decoder(x)
            print('output shape', x.shape)

        # The shape check left full-resolution activations in the cache.
        # Clear the slots in place so that memory is released and a later
        # decode() cannot pick up stale test activations.
        self.cache[:] = [0, 0, 0]

    def forward(self, x):
        """Move ``x`` to ``self.device`` and return its reconstruction."""
        x = x.to(self.device)
        return self.decoder(self.encoder(x))

    def encode(self, x):
        """Run only the encoder on ``x``; also refreshes the skip cache."""
        return self.encoder(x)

    def decode(self, z):
        """Run only the decoder on ``z``.

        The skip additions use whatever the most recent encoder pass stored
        in ``self.cache``, so call ``encode`` on the matching input first.
        """
        return self.decoder(z)

    def loss(self, x_input, x_ground_truth):
        """Compute the mean-squared reconstruction error.

        Args:
            x_input: network input, e.g. a corrupted image batch.
            x_ground_truth: target the reconstruction is compared against;
                must be broadcast-compatible with the network output.

        Returns:
            Tuple ``(loss, x_hat)`` where ``loss`` is the scalar MSE tensor
            (still attached to the graph) and ``x_hat`` is the reconstruction
            detached from the graph.
        """
        x_hat = self.forward(x_input)
        # forward() moved the input to self.device, so bring the target to
        # wherever the prediction lives before comparing them.
        x_ground_truth = x_ground_truth.to(x_hat.device)
        # Alternative objective: nn.functional.binary_cross_entropy(x_hat, x_ground_truth)
        # nn.functional is used because the notebook never imports an `F` alias.
        loss = nn.functional.mse_loss(x_hat, x_ground_truth)
        return loss, x_hat.detach()

In [19]:
#from models.custom import AE_1_skipped, AE_2_skipped,AE_3_skipped, DAE1
#%load_ext autoreload
#%reload_ext autoreload
#%autoreload 2

In [20]:
# Build the model around a blank single-channel 1024x1024 probe image; the
# constructor prints the encoded and reconstructed shapes.
test_input = torch.zeros(1, 1, 1024, 1024)
ae_model = AE_1_skipped(test_input, device)

1
Encoded from  torch.Size([1, 1, 1024, 1024]) to torch.Size([1, 128, 64, 64])




output shape torch.Size([1, 1, 1024, 1024])
