In [125]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from torch.utils.data import Dataset, DataLoader
from PIL import Image


In [126]:
# run params
section = 'vae'
run_id = '0001'
data_name = 'faces'
RUN_FOLDER = 'run/{}/'.format(section)
RUN_FOLDER += '_'.join([run_id, data_name])

if not os.path.exists(RUN_FOLDER):
    os.mkdir(RUN_FOLDER)
    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))
    os.mkdir(os.path.join(RUN_FOLDER, 'images'))
    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))

mode =  'build' #'load' #


DATA_FOLDER = './data/celeb/'

In [127]:
NUM_CLASSES = 10
BATCH_SIZE = 32
transform = transforms.Compose([   
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
filenames = np.array(glob(os.path.join(DATA_FOLDER, '*/*.jpg')))
NUM_IMAGES = len(filenames)


class SimpleDataset(Dataset):
    def __init__(self, filenames, transform):
        self.filenames = filenames
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img = Image.open(self.filenames[idx])
        if self.transform:
            img = self.transform(img)
        return img

In [128]:
dataset = SimpleDataset(filenames, transform)
dataloader = DataLoader(dataset, BATCH_SIZE)
z_dim = 200
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [129]:
next(iter(dataloader)).shape

torch.Size([32, 3, 128, 128])

In [130]:
class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 2)
        self.bn1 = nn.BatchNorm2d(32)
        self.dropout1 = nn.Dropout()

        self.conv2 = nn.Conv2d(32, 64, 3, 2)
        self.bn2 = nn.BatchNorm2d(64)
        self.dropout2 = nn.Dropout()

        self.conv3 = nn.Conv2d(64, 64, 3, 2)
        self.bn3 = nn.BatchNorm2d(64)
        self.dropout3 = nn.Dropout()

        self.conv4 = nn.Conv2d(64, 64, 3, 2)
        self.bn4 = nn.BatchNorm2d(64)
        self.dropout4 = nn.Dropout()

        self.flatten = nn.Flatten()
        self.mu = nn.Linear(4096, z_dim)

        self.log_var = nn.Linear(4096, z_dim)

    def forward(self, x):
        x = F.pad(self.conv1(x), (1,0,1,0))
        x = self.bn1(x)
        x = F.leaky_relu(x)
        x = self.dropout1(x)

        x = F.pad(self.conv2(x), (1,0,1,0))
        x = self.bn2(x)
        x = F.leaky_relu(x)
        x = self.dropout2(x)

        x = F.pad(self.conv3(x), (1,0,1,0))
        x = self.bn3(x)
        x = F.leaky_relu(x)
        x = self.dropout3(x)

        x = F.pad(self.conv4(x), (1,0,1,0))
        x = self.bn4(x)
        x = F.leaky_relu(x)
        x = self.dropout4(x)

        x = self.flatten(x)

        mu, log_var = self.mu(x), self.log_var(x)

        def sampling(args):
            mu, log_var = args
            epsilon = torch.normal(0., 1., size=mu.shape).to(device)
            return mu + torch.exp(log_var / 2) * epsilon

        x = sampling([mu, log_var])

        return x, mu, log_var


In [131]:
Encoder().to(device)(torch.randn((1, 3, 128, 128)).to(device))

(tensor([[-6.9292e-02, -1.7313e+00, -2.4216e+00,  1.8383e+00,  3.3043e-03,
           6.6487e-03,  6.0011e-01,  1.1177e-01, -2.5705e-01,  4.3092e-01,
          -2.8566e+00,  5.2822e-01, -1.2150e+00, -2.3023e+00, -5.6485e-02,
           1.3575e+00,  9.9283e-01, -4.7911e-02,  1.3520e+00,  5.8625e-02,
           6.8887e-01, -3.6084e-02, -2.9957e-01,  1.1842e+00,  5.5238e-01,
           4.6386e-01,  9.4920e-01, -5.2600e-01,  2.4729e+00,  4.4372e-01,
           1.0076e+00, -1.6362e-01, -2.3901e-01, -5.4897e-01,  1.1118e-01,
          -1.5194e+00, -1.2079e-01,  3.3674e-01,  5.8869e-01,  1.2844e+00,
          -1.5311e+00,  1.2524e+00, -3.6729e+00,  1.1425e+00, -1.5921e+00,
           5.8042e-01,  9.4766e-01,  9.1553e-01,  1.9330e+00, -3.5187e+00,
          -1.0907e+00, -1.5125e+00,  2.3162e+00, -1.4165e+00, -9.0981e-01,
           1.5136e+00, -7.3439e-01, -1.9492e-01, -3.8977e+00, -7.9682e-01,
          -1.0914e+00, -5.5134e-01,  1.7630e+00,  2.3743e-01, -8.0812e-02,
           2.7078e+00, -5

In [132]:
class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(z_dim, 64*6*6)

        self.convT1 = nn.ConvTranspose2d(64, 64, 3, 2, 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.dropout1 = nn.Dropout()

        self.convT2 = nn.ConvTranspose2d(64, 64, 3, 2, 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.dropout2 = nn.Dropout()

        self.convT3 = nn.ConvTranspose2d(64, 32, 3, 2, 1)
        self.bn3 = nn.BatchNorm2d(32)
        self.dropout3 = nn.Dropout()

        self.convT4 = nn.ConvTranspose2d(32, 3, 3, 2)

    def forward(self, x):
        x = self.linear1(x)
        x = x.reshape(-1, 64, 6, 6)

        x = F.pad(self.convT1(x), (2,1,2,1))      
        x = self.bn1(x)
        x = F.leaky_relu(x)
        x = self.dropout1(x)

        x = F.pad(self.convT2(x), (2,1,2,1))      
        x = self.bn2(x)
        x = F.leaky_relu(x)
        x = self.dropout2(x)

        x = F.pad(self.convT3(x), (2,1,2,1))      
        x = self.bn3(x)
        x = F.leaky_relu(x)
        x = self.dropout3(x)

        x = F.pad(self.convT4(x), (2,1,2,1)) 
        
        return torch.sigmoid(x)


In [133]:
Decoder().to(device)(torch.randn((1, 200)).to(device)).shape

torch.Size([1, 3, 128, 128])

In [134]:
class AutoEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
    def forward(self, x):
        t, mu, log_var = self.encoder(x)
        x = self.decoder(t)
        return x, t, mu, log_var

In [135]:
model =AutoEncoder().to(device)
model

AutoEncoder(
  (encoder): Encoder(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout1): Dropout(p=0.5, inplace=False)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout2): Dropout(p=0.5, inplace=False)
    (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout3): Dropout(p=0.5, inplace=False)
    (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (bn4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout4): Dropout(p=0.5, inplace=False)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mu): Linear(in_features=4096, out_features=200, bias=True)
    (log_var): Linear(in_features=4096, out_features=2

In [161]:
R_LOSS_FACTOR = 10000

def kl_loss(mu, log_var):
    kl_loss =  -0.5 * torch.sum(1 + log_var - torch.square(mu) - torch.exp(log_var))
    return kl_loss


In [162]:
optimizer = optim.Adam(model.parameters(), lr=0.0005)

criterion = nn.MSELoss()

In [164]:
for epoch in range(200):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, inputs in enumerate(dataloader, 0):
        
        # get the inputs; data is a list of [inputs, labels]
        inputs = inputs.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, _, mu, log_var = model(inputs)
        loss = criterion(outputs, inputs) * R_LOSS_FACTOR

        loss += kl_loss(mu, log_var)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # print statistics
        running_loss += loss.item()
    print(f'[{epoch + 1}] loss: {running_loss / len(dataloader):.3f}')
    torch.save(model, RUN_FOLDER + "/weights/weight.pt")

print('Finished Training')

[1] loss: 1738.996
[2] loss: 1622.546
[3] loss: 1596.606
[4] loss: 1590.767
[5] loss: 1589.142
