In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np
from torchvision import datasets, transforms
import time
from PIL import Image

# Choose the torch device once: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# To force CPU execution regardless of hardware, assign torch.device("cpu") instead.
print(device)

cuda


In [44]:

class CNN(nn.Module):
    """Convolutional autoencoder: conv encoder -> linear bottleneck -> upsampling decoder.

    Pipeline (see ``forward``):
      1. ``encoder``    - five 3x3 convolutions (no padding argument is passed),
                          each followed by ReLU; the first four are also followed
                          by a 2x2 max-pool with stride 2.
      2. ``flatten``    - collapses the encoder feature maps to one vector per sample.
      3. ``bottleneck`` - a single Linear + ReLU producing ``nF[4]`` features.
      4. ``unflatten``  - reshapes that vector to ``(nF[4], 1, 1)``.
      5. ``decoder``    - five (bilinear 2x upsample -> 3x3 transposed conv -> ReLU)
                          stages, then one extra 3x3 transposed conv + ReLU.

    Shape notes:
      * The bottleneck's ``in_features`` is hard-coded to ``nF[4]*4*4``, i.e. the
        encoder must emit 4x4 feature maps. A 128x128 input satisfies this
        (encoder output ``[N, nF[4], 4, 4]``); other spatial sizes will fail in
        the Linear layer.
      * The decoder always starts from a 1x1 map, so its output is
        ``[N, inChannels, 96, 96]`` regardless of the input size.
        NOTE(review): this does not match a 128x128 input - confirm whether the
        reconstruction target is resized or the decoder needs adjusting.
      * The last layer is a ReLU, so every output value is >= 0.
    """

    def __init__(self, inChannels=3, nF=[16,32,64,128,256,512]):
        """Build the encoder, bottleneck and decoder stacks.

        Args:
            inChannels: number of channels of the input image; the decoder's
                final layers produce the same number of channels.
            nF: per-stage channel widths. Only ``nF[0]`` .. ``nF[4]`` are read;
                any further entries (such as the default's sixth value) are
                ignored. The default list is never mutated here.
        """
        super().__init__()

        # Layer order inside each Sequential is kept fixed so that parameter
        # names in state_dict (e.g. "encoder.0.weight") stay stable.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=inChannels, out_channels=nF[0], kernel_size=(3,3), stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(in_channels=nF[0], out_channels=nF[1], kernel_size=(3,3), stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(in_channels=nF[1], out_channels=nF[2], kernel_size=(3,3), stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(in_channels=nF[2], out_channels=nF[3], kernel_size=(3,3), stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # Final encoder conv has no pooling after it.
            nn.Conv2d(in_channels=nF[3], out_channels=nF[4], kernel_size=(3,3), stride=1),
            nn.ReLU(),
        )
        self.flatten    = nn.Flatten()
        self.bottleneck = nn.Sequential(
            # 4*4 is the encoder's spatial output size for a 128x128 input.
            nn.Linear(in_features=nF[4]*4*4, out_features=nF[4]),
            nn.ReLU(),
        )
        # Treat the bottleneck vector as nF[4] feature maps of size 1x1.
        self.unflatten  = nn.Unflatten(dim=1, unflattened_size=(nF[4],1,1))
        self.decoder = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.ConvTranspose2d(in_channels=nF[4], out_channels=nF[3], kernel_size=(3,3), stride=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.ConvTranspose2d(in_channels=nF[3], out_channels=nF[2], kernel_size=(3,3), stride=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.ConvTranspose2d(in_channels=nF[2], out_channels=nF[1], kernel_size=(3,3), stride=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.ConvTranspose2d(in_channels=nF[1], out_channels=nF[0], kernel_size=(3,3), stride=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.ConvTranspose2d(in_channels=nF[0], out_channels=inChannels, kernel_size=(3,3), stride=1),
            nn.ReLU(),
            # Extra channel-preserving transposed conv with no upsampling before it.
            nn.ConvTranspose2d(in_channels=inChannels, out_channels=inChannels, kernel_size=(3,3), stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode ``x``, squeeze it through the bottleneck and decode it.

        Args:
            x: batch of images with ``inChannels`` channels whose spatial size
                makes the encoder emit 4x4 maps (e.g. 128x128).

        Returns:
            The decoder output. No activation is applied after the decoder's
            own final ReLU.
        """
        features = self.encoder(x)
        code = self.bottleneck(self.flatten(features))
        return self.decoder(self.unflatten(code))

In [46]:

# Shape smoke test: push an all-zero 1x3x128x128 batch through the model one
# stage at a time and print the tensor shape after each stage.
# NOTE(review): `input` shadows the Python builtin of the same name; the name is
# kept here in case other cells read it.
model = CNN()
ary = np.zeros((1,3,128,128),dtype=np.float32)
input = torch.Tensor(ary)

output = input
for stage in (model.encoder, model.flatten, model.bottleneck, model.unflatten, model.decoder):
    output = stage(output)
    print(output.shape)


torch.Size([1, 256, 4, 4])
torch.Size([1, 4096])
torch.Size([1, 256])
torch.Size([1, 256, 1, 1])
torch.Size([1, 3, 96, 96])
