In [179]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from torch.utils.data import Dataset, DataLoader
from PIL import Image


In [180]:
# run params
section = 'vae'
run_id = '0001'
data_name = 'faces'
RUN_FOLDER = 'run/{}/'.format(section)
RUN_FOLDER += '_'.join([run_id, data_name])

if not os.path.exists(RUN_FOLDER):
    os.mkdir(RUN_FOLDER)
    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))
    os.mkdir(os.path.join(RUN_FOLDER, 'images'))
    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))

mode =  'build' #'load' #


DATA_FOLDER = './data/celeb/'

In [181]:
NUM_CLASSES = 10
BATCH_SIZE = 32
transform = transforms.Compose([   
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
filenames = np.array(glob(os.path.join(DATA_FOLDER, '*/*.jpg')))
NUM_IMAGES = len(filenames)


class SimpleDataset(Dataset):
    def __init__(self, filenames, transform):
        self.filenames = filenames
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img = Image.open(self.filenames[idx])
        if self.transform:
            img = self.transform(img)
        return img

In [182]:
dataset = SimpleDataset(filenames, transform)
dataloader = DataLoader(dataset, BATCH_SIZE)
z_dim = 200
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [183]:
next(iter(dataloader)).shape

torch.Size([32, 3, 128, 128])

In [184]:
class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 2)
        self.bn1 = nn.BatchNorm2d(32)
        self.dropout1 = nn.Dropout()

        self.conv2 = nn.Conv2d(32, 64, 3, 2)
        self.bn2 = nn.BatchNorm2d(64)
        self.dropout2 = nn.Dropout()

        self.conv3 = nn.Conv2d(64, 64, 3, 2)
        self.bn3 = nn.BatchNorm2d(64)
        self.dropout3 = nn.Dropout()

        self.conv4 = nn.Conv2d(64, 64, 3, 2)
        self.bn4 = nn.BatchNorm2d(64)
        self.dropout4 = nn.Dropout()

        self.flatten = nn.Flatten()
        self.mu = nn.Linear(4096, z_dim)

        self.log_var = nn.Linear(4096, z_dim)

    def forward(self, x):
        x = F.pad(self.conv1(x), (1,0,1,0))
        x = self.bn1(x)
        x = F.leaky_relu(x)
        x = self.dropout1(x)

        x = F.pad(self.conv2(x), (1,0,1,0))
        x = self.bn2(x)
        x = F.leaky_relu(x)
        x = self.dropout2(x)

        x = F.pad(self.conv3(x), (1,0,1,0))
        x = self.bn3(x)
        x = F.leaky_relu(x)
        x = self.dropout3(x)

        x = F.pad(self.conv4(x), (1,0,1,0))
        x = self.bn4(x)
        x = F.leaky_relu(x)
        x = self.dropout4(x)

        x = self.flatten(x)

        mu, log_var = self.mu(x), self.log_var(x)

        def sampling(args):
            mu, log_var = args
            epsilon = torch.normal(0., 1., size=mu.shape).to(device)
            return mu + torch.exp(log_var / 2) * epsilon

        x = sampling([mu, log_var])

        return x, mu, log_var


In [185]:
Encoder().to(device)(torch.randn((1, 3, 128, 128)).to(device))

(tensor([[-0.5133,  0.6966,  2.7349, -1.6937,  0.5595, -0.2291,  0.9756,  1.9582,
           1.2483,  0.4462,  0.8618,  0.3774,  0.6404, -1.0817, -0.0484, -0.5272,
          -1.4039,  1.2009,  1.2682,  2.3018, -2.1301,  0.7920,  1.1097, -0.7670,
           0.1408, -1.4177, -0.7196, -0.4259, -0.2756,  2.1120, -1.5853, -0.0338,
           0.3483,  1.0878, -0.1384,  0.2455, -0.6010, -0.8431, -2.3416, -1.4999,
          -2.1460,  1.4183,  1.0914, -0.8100,  1.1454, -1.5163, -1.3616,  1.0333,
           0.1617, -1.2724, -1.2692, -2.3126,  1.2156,  0.0880, -0.5975,  0.1915,
          -0.1213, -1.2717, -0.1225, -1.1866,  0.9553,  1.1409,  0.3968, -2.6861,
           2.5686, -0.4801, -0.0055,  0.0073,  1.6582,  0.7641, -0.6670, -1.4761,
           0.3748, -0.7810,  0.5270, -2.5496,  0.8778,  0.4304, -0.1256, -0.2407,
          -0.3008, -4.7570, -0.3675, -0.4609,  0.1293,  1.9294,  0.0860, -0.1600,
           1.4306, -0.4745, -0.2443,  0.4694, -0.3394,  0.3838, -0.0125,  1.1362,
          -0.972

In [186]:
class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(z_dim, 64*6*6)

        self.convT1 = nn.ConvTranspose2d(64, 64, 3, 2, 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.dropout1 = nn.Dropout()

        self.convT2 = nn.ConvTranspose2d(64, 64, 3, 2, 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.dropout2 = nn.Dropout()

        self.convT3 = nn.ConvTranspose2d(64, 32, 3, 2, 1)
        self.bn3 = nn.BatchNorm2d(32)
        self.dropout3 = nn.Dropout()

        self.convT4 = nn.ConvTranspose2d(32, 3, 3, 2)

    def forward(self, x):
        x = self.linear1(x)
        x = x.reshape(-1, 64, 6, 6)

        x = F.pad(self.convT1(x), (2,1,2,1))      
        x = self.bn1(x)
        x = F.leaky_relu(x)
        x = self.dropout1(x)

        x = F.pad(self.convT2(x), (2,1,2,1))      
        x = self.bn2(x)
        x = F.leaky_relu(x)
        x = self.dropout2(x)

        x = F.pad(self.convT3(x), (2,1,2,1))      
        x = self.bn3(x)
        x = F.leaky_relu(x)
        x = self.dropout3(x)

        x = F.pad(self.convT4(x), (2,1,2,1)) 
        
        return torch.sigmoid(x)


In [187]:
Decoder().to(device)(torch.randn((1, 200)).to(device)).shape

torch.Size([1, 3, 128, 128])

In [188]:
class AutoEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
    def forward(self, x):
        t, mu, log_var = self.encoder(x)
        x = self.decoder(t)
        return x, t, mu, log_var

In [189]:
model =AutoEncoder().to(device)
model

AutoEncoder(
  (encoder): Encoder(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout1): Dropout(p=0.5, inplace=False)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout2): Dropout(p=0.5, inplace=False)
    (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout3): Dropout(p=0.5, inplace=False)
    (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (bn4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout4): Dropout(p=0.5, inplace=False)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mu): Linear(in_features=4096, out_features=200, bias=True)
    (log_var): Linear(in_features=4096, out_features=2

In [190]:
R_LOSS_FACTOR = 100

def kl_loss(mu, log_var):
    kl_loss =  -0.5 * torch.sum(1 + log_var - torch.square(mu) - torch.exp(log_var))
    return kl_loss


In [191]:
optimizer = optim.Adam(model.parameters(), lr=0.0005)

criterion = nn.MSELoss()

In [192]:
for epoch in range(200):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, inputs in enumerate(dataloader, 0):
        
        # get the inputs; data is a list of [inputs, labels]
        inputs = inputs.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, _, mu, log_var = model(inputs)
        loss = criterion(outputs, inputs) * R_LOSS_FACTOR

        loss += kl_loss(mu, log_var)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # print statistics
        running_loss += loss.item()
    print(f'[{epoch + 1}] loss: {running_loss / len(dataloader):.3f}')
    torch.save(model, RUN_FOLDER + "/weights/weight.pt")

print('Finished Training')

[1] loss: 72.955
[2] loss: 16.845
[3] loss: 16.421
[4] loss: 16.348
[5] loss: 16.306
[6] loss: 17.016
[7] loss: 16.379
[8] loss: 16.407
[9] loss: 16.452
[10] loss: 16.437
