### Imports

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision


from torch import utils
from torch import nn
from torchsummary import summary
from torchvision import datasets
from torchvision import transforms

### Constants and definitions

In [2]:
%matplotlib inline

# Tunable notebook-wide settings.
BATCH_SIZE = 64                   # training batch size (the test loader uses twice this)
DATA_FOLDER = '../data/CelebA/'   # root directory handed to datasets.ImageFolder
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # prefer GPU when present
ENCODED_DIMS = 200                # width of the encoder's mu / log_var outputs
IMAGE_SIZE = 128                  # images are resized and center-cropped to IMAGE_SIZE x IMAGE_SIZE
SEED = 42                         # seed for PyTorch's global random number generator

# Seed PyTorch's global RNG so that Restart & Run All reproduces the same
# train/test split, shuffle order, weight initialization and sampled noise.
# The trailing ';' keeps the returned Generator object out of the cell output.
torch.manual_seed(SEED);

### Data loading

Load entire dataset:

In [None]:
# Preprocessing applied to every image as it is loaded:
#   1. resize, then center-crop to IMAGE_SIZE x IMAGE_SIZE pixels
#   2. convert the PIL image to a float tensor
#   3. normalize each of the three channels with mean 0.5 and std 0.5
image_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# ImageFolder discovers the images under DATA_FOLDER and applies the transform lazily.
dataset = datasets.ImageFolder(root=DATA_FOLDER, transform=image_transform)

Split into training and test sets:

In [None]:
# Keep 99% of the images for training; whatever is left over (roughly 1%)
# becomes the test set, so the two sizes always add up to the full dataset.
dataset_size = len(dataset)
train_size = int(0.99 * dataset_size)
test_size = dataset_size - train_size

# Randomly partition the dataset into two non-overlapping subsets.
train_set, test_set = utils.data.random_split(dataset, [train_size, test_size])

# Summarize the resulting split.
print(f'CelebA size: {dataset_size} images\n'
      f'Training set size: {train_size} images\n'
      f'Test set size: {test_size} images')

Create data loaders:

In [None]:
# Training batches are drawn in a freshly shuffled order on every pass.
train_loader = utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# The test loader keeps a fixed order and uses batches twice as large.
test_loader = utils.data.DataLoader(
    test_set,
    batch_size=2 * BATCH_SIZE,
    shuffle=False,
)

Visualize:

In [None]:
# Pull a single batch from the training loader; element 0 holds the image tensors.
sample = next(iter(train_loader))

# Tile at most 64 images into one grid image. normalize=True rescales the
# pixel values into [0, 1] so matplotlib can display them directly.
grid = torchvision.utils.make_grid(sample[0].cpu()[:64], padding=2, normalize=True)

# make_grid returns a channels-first (C, H, W) tensor while imshow expects
# channels last (H, W, C). Permute on the tensor itself instead of relying on
# np.transpose's implicit tensor-to-array fallback.
plt.figure(figsize=(12, 12))
plt.axis('off')
plt.imshow(grid.permute(1, 2, 0).numpy());  # trailing ';' hides the AxesImage repr

### Sampler

Sampler layer definition:

In [None]:
class Sampler(nn.Module):
    """Draw a latent sample using the reparameterization trick.

    Given a mean ``mu`` and a log-variance ``log_var`` the layer returns
    ``mu + exp(log_var / 2) * eps`` where ``eps`` is standard normal noise
    with the same shape, dtype and device as ``mu``. The layer has no
    learnable parameters.
    """

    def forward(self, mu, log_var):
        """Return one random sample for each entry of ``mu``.

        Args:
            mu: tensor of means.
            log_var: tensor of log-variances, broadcastable against ``mu``.

        Returns:
            Tensor ``mu + std * noise`` with ``std = exp(log_var / 2)``.
        """
        noise = torch.randn_like(mu)
        # exp(log_var / 2) == sqrt(exp(log_var)), i.e. the standard deviation
        std = torch.exp(log_var / 2)
        return mu + std * noise

### Encoder

Encoder model definition:

In [None]:
class Encoder(nn.Module):
    """Convolutional VAE encoder.

    Four stride-2 convolution stages (each Conv2d -> BatchNorm2d -> LeakyReLU
    -> Dropout2d) reduce a 3-channel square image to a flat feature vector.
    Two linear heads then produce the mean and log-variance, and a
    ``Sampler`` draws a latent sample from them.

    Args:
        image_size: side length in pixels of the square input images.
            Defaults to the notebook constant ``IMAGE_SIZE``.
        encoded_dims: width of the ``mu`` / ``log_var`` outputs.
            Defaults to the notebook constant ``ENCODED_DIMS``.

    Raises:
        ValueError: if ``image_size`` or ``encoded_dims`` is smaller than 1.
    """

    # (in_channels, out_channels) of each convolution stage, in order.
    _CONV_CHANNELS = ((3, 32), (32, 64), (64, 64), (64, 64))

    def __init__(self, image_size=None, encoded_dims=None):
        super().__init__()

        # Resolve defaults at construction time so the class can be defined
        # before the notebook constants are (re)assigned.
        if image_size is None:
            image_size = IMAGE_SIZE
        if encoded_dims is None:
            encoded_dims = ENCODED_DIMS
        if image_size < 1 or encoded_dims < 1:
            raise ValueError('image_size and encoded_dims must be positive')

        layers = []
        side = image_size
        for in_channels, out_channels in self._CONV_CHANNELS:
            layers += [
                nn.Conv2d(in_channels, out_channels, 3, 2, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(),
                nn.Dropout2d(),
            ]
            # Conv2d output size: floor((in + 2*padding - kernel) / stride) + 1
            # with kernel=3, stride=2, padding=1.
            side = (side + 2 * 1 - 3) // 2 + 1
        layers.append(nn.Flatten())

        # A flat Sequential keeps the sub-module indices (encode.0, encode.1, ...)
        # and therefore the state_dict keys the same as an explicit layer list.
        self.encode = nn.Sequential(*layers)

        # Flattened feature count: last stage's channels times the remaining
        # spatial area (64 * 8 * 8 = 4096 for 128x128 inputs).
        flat_features = self._CONV_CHANNELS[-1][1] * side * side

        self.mu = nn.Linear(flat_features, encoded_dims)
        self.log_var = nn.Linear(flat_features, encoded_dims)
        self.sample = Sampler()

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: tensor of shape (batch, 3, image_size, image_size).

        Returns:
            A list ``[mu, log_var, sampled]`` where ``sampled`` is the output
            of the ``Sampler`` for ``mu`` and ``log_var``.
        """
        features = self.encode(x)
        mu = self.mu(features)
        log_var = self.log_var(features)
        sampled = self.sample(mu, log_var)

        return [mu, log_var, sampled]

Create encoder:

In [None]:
# Build the encoder with freshly initialized weights; it is moved to DEVICE
# by the summary cell that follows.
encoder = Encoder()

Visualize model:

In [None]:
# Print a per-layer summary for one 3-channel image at the configured
# resolution (IMAGE_SIZE rather than a hard-coded 128, so the summary always
# matches the data pipeline). Module.to() moves the parameters in place, so
# `encoder` itself lives on DEVICE after this cell.
summary(encoder.to(DEVICE), (3, IMAGE_SIZE, IMAGE_SIZE))