In [1]:
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
import time
import random
import argparse
import math
import logging

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.utils.data as data

#import utils
#import models.builer as builder
#import dataloader

from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR
from torchvision.transforms import transforms

from tqdm import tqdm
from PIL import Image, ImageFilter

In [7]:
class Encoder(nn.Module):
    """VGG-style convolutional encoder mapping an RGB image to a latent vector.

    The network is a single flat ``nn.Sequential`` (``self.encoder``) made of
    3x3 Conv -> BatchNorm -> ReLU steps grouped into six blocks, with a 2x2
    max-pool after each of the first three blocks, followed by a two-layer
    Tanh MLP head. The three pools shrink the spatial size by a factor of 8,
    so a 224x224 input yields a 16 x 28 x 28 feature map before flattening.

    Args:
        latent_dim: Size of the output latent vector.
        image_size: Expected input resolution, either a single int (square
            images) or a ``(height, width)`` pair. Defaults to 224, which
            reproduces the original hard-coded ``16 * 28 * 28`` linear input.
            Sizes that are not multiples of 8 are floored by the pooling
            layers.

    Raises:
        ValueError: If either spatial dimension is smaller than 8, i.e. the
            feature map would be empty after the three pooling layers.
    """

    # Marker used in the layer plan for a 2x2, stride-2 max-pool.
    _POOL = "pool"
    # Three stride-2 pools => spatial size is divided by 2**3.
    _DOWNSCALE = 8
    # Channels produced by the last conv step (input channels of the MLP head).
    _FEATURE_CHANNELS = 16

    # (in_channels, out_channels, dilation) for every conv step, in order.
    _PLAN = (
        # block one
        (3, 64, 1), (64, 64, 1), _POOL,
        # block two
        (64, 128, 1), (128, 128, 1), _POOL,
        # block three (first conv is dilated)
        (128, 256, 2), (256, 256, 1), (256, 256, 1), (256, 256, 1), _POOL,
        # block four (first conv is dilated, no pooling)
        (256, 512, 2), (512, 512, 1), (512, 512, 1), (512, 512, 1),
        # block five
        (512, 512, 1), (512, 256, 1), (256, 256, 1), (256, 128, 1),
        # block six
        (128, 128, 1), (128, 64, 1), (64, 64, 1),
        (64, 32, 1), (32, 32, 1), (32, 16, 1),
    )

    def __init__(self, latent_dim=128, image_size=224):
        super().__init__()

        height, width = self._as_hw(image_size)
        feat_h = height // self._DOWNSCALE
        feat_w = width // self._DOWNSCALE
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"image_size must be at least {self._DOWNSCALE} in each "
                f"dimension, got {(height, width)}"
            )

        # Layers are appended in exactly the original order so that the
        # Sequential indices (and therefore state_dict keys) are unchanged.
        layers = []
        for step in self._PLAN:
            if step == self._POOL:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.extend(self._conv_bn_relu(*step))

        # Linear head: flatten the conv features and squash to the latent size.
        layers.extend([
            nn.Flatten(),
            nn.Linear(self._FEATURE_CHANNELS * feat_h * feat_w, 512),
            nn.Tanh(),
            nn.Linear(512, latent_dim),
            nn.Tanh(),
        ])

        self.encoder = nn.Sequential(*layers)

    @staticmethod
    def _as_hw(image_size):
        """Normalize an int or a (height, width) pair to a pair of ints."""
        try:
            height, width = image_size
        except TypeError:
            # Not iterable: a single number means a square image.
            height = width = image_size
        return int(height), int(width)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, dilation=1):
        """Return [Conv2d, BatchNorm2d, ReLU] for one size-preserving 3x3 conv.

        With a 3x3 kernel and stride 1, ``padding == dilation`` keeps the
        spatial size unchanged for any dilation.
        """
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                      padding=dilation, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    def forward(self, x):
        """Encode a batch of images ``(N, 3, H, W)`` into ``(N, latent_dim)``."""
        return self.encoder(x)
    

class Decoder(nn.Module):
    """Convolutional decoder mapping a latent vector back to an RGB image.

    The network is a single flat ``nn.Sequential`` (``self.decoder``): a
    two-layer Tanh MLP expands the latent vector to a 16-channel feature map,
    which then passes through pre-activation BatchNorm -> ReLU -> 3x3 Conv
    steps grouped into six blocks. Blocks three, two and one each start with
    a stride-2 transposed convolution, so the feature map is upsampled by a
    total factor of 8 (block four does not upsample). A final Sigmoid keeps
    the output in [0, 1].

    Args:
        latent_dim: Size of the input latent vector.
        image_size: Resolution of the reconstructed image, either a single
            int (square images) or a ``(height, width)`` pair. Defaults to
            224, which reproduces the original hard-coded ``16 * 28 * 28``
            feature map.

    Raises:
        ValueError: If either spatial dimension is not a positive multiple
            of 8; three 2x upsamplings cannot produce any other size.
    """

    # Three stride-2 transposed convs => spatial size is multiplied by 2**3.
    _UPSCALE = 8
    # Channels of the feature map produced by the MLP.
    _FEATURE_CHANNELS = 16
    # Marker used in the layer plan for a 2x2, stride-2 transposed conv.
    _UP = "up"

    # Either (in_channels, out_channels) for a BN -> ReLU -> Conv step, or
    # (_UP, channels) for a channel-preserving 2x upsampling, in order.
    _PLAN = (
        # block six
        (16, 32), (32, 32), (32, 64), (64, 64), (64, 128), (128, 128),
        # block five
        (128, 256), (256, 256), (256, 512), (512, 512),
        # block four (no upsampling)
        (512, 512), (512, 512), (512, 512), (512, 256),
        # block three
        (_UP, 256), (256, 256), (256, 256), (256, 256), (256, 128),
        # block two
        (_UP, 128), (128, 128), (128, 64),
        # block one
        (_UP, 64), (64, 64), (64, 3),
    )

    def __init__(self, latent_dim=128, image_size=224):
        super().__init__()

        height, width = self._as_hw(image_size)
        if (height <= 0 or width <= 0
                or height % self._UPSCALE or width % self._UPSCALE):
            raise ValueError(
                f"image_size must be a positive multiple of {self._UPSCALE} "
                f"in each dimension, got {(height, width)}"
            )
        feat_h = height // self._UPSCALE
        feat_w = width // self._UPSCALE

        # Layers are appended in exactly the original order so that the
        # Sequential indices (and therefore state_dict keys) are unchanged.
        layers = [
            nn.Linear(latent_dim, 512),
            nn.Tanh(),
            nn.Linear(512, self._FEATURE_CHANNELS * feat_h * feat_w),
            nn.Tanh(),
            nn.Unflatten(1, (self._FEATURE_CHANNELS, feat_h, feat_w)),
        ]
        for first, second in self._PLAN:
            if first == self._UP:
                layers.append(
                    nn.ConvTranspose2d(in_channels=second, out_channels=second,
                                       kernel_size=2, stride=2)
                )
            else:
                layers.extend(self._bn_relu_conv(first, second))

        # Squash the reconstruction into the [0, 1] range.
        layers.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*layers)

    @staticmethod
    def _as_hw(image_size):
        """Normalize an int or a (height, width) pair to a pair of ints."""
        try:
            height, width = image_size
        except TypeError:
            # Not iterable: a single number means a square image.
            height = width = image_size
        return int(height), int(width)

    @staticmethod
    def _bn_relu_conv(in_channels, out_channels):
        """Return [BatchNorm2d, ReLU, Conv2d] for one size-preserving 3x3 conv."""
        return [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                      padding=1),
        ]

    def forward(self, x):
        """Decode latent vectors ``(N, latent_dim)`` into images ``(N, 3, H, W)``."""
        return self.decoder(x)


class Autoencoder(nn.Module):
    """Encoder/decoder pair sharing one latent dimensionality.

    Args:
        latent_dim: Size of the bottleneck vector passed from the encoder to
            the decoder; forwarded unchanged to both sub-modules.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        # Both halves must agree on the bottleneck size.
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, x):
        """Encode ``x`` to the latent space, then decode it back."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction

In [9]:
import torch

def trace_layer_shapes(layers, x):
    """Feed ``x`` through ``layers`` one by one, printing each output shape.

    Args:
        layers: Iterable of callables (e.g. an ``nn.Sequential``) applied in order.
        x: Input tensor for the first layer.

    Returns:
        The output of the last layer (or ``x`` itself if ``layers`` is empty).
    """
    for layer in layers:
        x = layer(x)
        print(f"Shape after {layer.__class__.__name__}: {x.shape}")
    return x


if __name__ == "__main__":
    # Instantiate the autoencoder
    autoencoder = Autoencoder(latent_dim=128)

    # Create a random input tensor (batch_size=1, channels=3, height=224, width=224)
    input_tensor = torch.rand(1, 3, 224, 224)

    # This is only a shape check: run it in eval mode without autograd so the
    # random tensor does not update the BatchNorm running statistics and no
    # graph is kept for the large activations. The previous train/eval mode
    # is restored afterwards so later cells see the model unchanged.
    was_training = autoencoder.training
    autoencoder.eval()
    try:
        with torch.no_grad():
            # Pass the tensor through the encoder
            print("Passing through Encoder:")
            x = trace_layer_shapes(autoencoder.encoder.encoder, input_tensor)

            # Pass the tensor through the decoder
            print("\nPassing through Decoder:")
            x = trace_layer_shapes(autoencoder.decoder.decoder, x)
    finally:
        autoencoder.train(was_training)

    # The reconstruction must have the same shape as the input image.
    assert x.shape == input_tensor.shape, (x.shape, input_tensor.shape)


Passing through Encoder:
Shape after Conv2d: torch.Size([1, 64, 224, 224])
Shape after BatchNorm2d: torch.Size([1, 64, 224, 224])
Shape after ReLU: torch.Size([1, 64, 224, 224])
Shape after Conv2d: torch.Size([1, 64, 224, 224])
Shape after BatchNorm2d: torch.Size([1, 64, 224, 224])
Shape after ReLU: torch.Size([1, 64, 224, 224])
Shape after MaxPool2d: torch.Size([1, 64, 112, 112])
Shape after Conv2d: torch.Size([1, 128, 112, 112])
Shape after BatchNorm2d: torch.Size([1, 128, 112, 112])
Shape after ReLU: torch.Size([1, 128, 112, 112])
Shape after Conv2d: torch.Size([1, 128, 112, 112])
Shape after BatchNorm2d: torch.Size([1, 128, 112, 112])
Shape after ReLU: torch.Size([1, 128, 112, 112])
Shape after MaxPool2d: torch.Size([1, 128, 56, 56])
Shape after Conv2d: torch.Size([1, 256, 56, 56])
Shape after BatchNorm2d: torch.Size([1, 256, 56, 56])
Shape after ReLU: torch.Size([1, 256, 56, 56])
Shape after Conv2d: torch.Size([1, 256, 56, 56])
Shape after BatchNorm2d: torch.Size([1, 256, 56, 56])