In [165]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from torch.utils.data import Dataset, DataLoader
from PIL import Image


In [166]:
# run params
section = 'vae'
run_id = '0001'
data_name = 'faces'
RUN_FOLDER = 'run/{}/'.format(section)
RUN_FOLDER += '_'.join([run_id, data_name])

if not os.path.exists(RUN_FOLDER):
    os.mkdir(RUN_FOLDER)
    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))
    os.mkdir(os.path.join(RUN_FOLDER, 'images'))
    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))

mode =  'build' #'load' #


DATA_FOLDER = './data/celeb/'

In [167]:
NUM_CLASSES = 10
BATCH_SIZE = 32
transform = transforms.Compose([   
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
filenames = np.array(glob(os.path.join(DATA_FOLDER, '*/*.jpg')))
NUM_IMAGES = len(filenames)


class SimpleDataset(Dataset):
    def __init__(self, filenames, transform):
        self.filenames = filenames
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img = Image.open(self.filenames[idx])
        if self.transform:
            img = self.transform(img)
        return img

In [168]:
dataset = SimpleDataset(filenames, transform)
dataloader = DataLoader(dataset, BATCH_SIZE)
z_dim = 200
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [169]:
next(iter(dataloader)).shape

torch.Size([32, 3, 128, 128])

In [170]:
class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 2)
        self.bn1 = nn.BatchNorm2d(32)
        self.dropout1 = nn.Dropout()

        self.conv2 = nn.Conv2d(32, 64, 3, 2)
        self.bn2 = nn.BatchNorm2d(64)
        self.dropout2 = nn.Dropout()

        self.conv3 = nn.Conv2d(64, 64, 3, 2)
        self.bn3 = nn.BatchNorm2d(64)
        self.dropout3 = nn.Dropout()

        self.conv4 = nn.Conv2d(64, 64, 3, 2)
        self.bn4 = nn.BatchNorm2d(64)
        self.dropout4 = nn.Dropout()

        self.flatten = nn.Flatten()
        self.mu = nn.Linear(4096, z_dim)

        self.log_var = nn.Linear(4096, z_dim)

    def forward(self, x):
        x = F.pad(self.conv1(x), (1,0,1,0))
        x = self.bn1(x)
        x = F.leaky_relu(x)
        x = self.dropout1(x)

        x = F.pad(self.conv2(x), (1,0,1,0))
        x = self.bn2(x)
        x = F.leaky_relu(x)
        x = self.dropout2(x)

        x = F.pad(self.conv3(x), (1,0,1,0))
        x = self.bn3(x)
        x = F.leaky_relu(x)
        x = self.dropout3(x)

        x = F.pad(self.conv4(x), (1,0,1,0))
        x = self.bn4(x)
        x = F.leaky_relu(x)
        x = self.dropout4(x)

        x = self.flatten(x)

        mu, log_var = self.mu(x), self.log_var(x)

        def sampling(args):
            mu, log_var = args
            epsilon = torch.normal(0., 1., size=mu.shape).to(device)
            return mu + torch.exp(log_var / 2) * epsilon

        x = sampling([mu, log_var])

        return x, mu, log_var


In [171]:
Encoder().to(device)(torch.randn((1, 3, 128, 128)).to(device))

(tensor([[-5.1759e-01,  2.3929e+00,  1.3602e+00,  2.7202e+00,  2.4660e+00,
           1.3402e+00,  1.4359e+00,  4.8709e-01,  1.8287e-01,  1.0801e+00,
           9.8907e-01, -4.2054e-01, -1.3979e-01,  2.9253e-01,  4.7928e-01,
          -4.0135e-01, -2.7017e+00,  2.9082e-01, -6.1216e-01,  4.1141e-01,
          -1.9953e-01, -2.1355e-01,  2.1881e+00,  4.2528e+00,  1.4462e+00,
           1.5811e-01, -1.2270e+00,  6.7239e-01,  4.5569e-01, -8.3518e-01,
           7.1327e-01,  1.5285e+00, -1.3974e-01,  2.0771e-01,  2.5645e+00,
           5.4974e-01, -1.7486e+00,  1.7022e+00,  2.2719e+00, -2.6096e+00,
          -8.7677e-01,  7.9289e-01,  7.6394e-01,  2.8425e-01,  2.2144e-01,
           1.6229e+00, -2.1539e+00,  3.5489e-01, -4.8027e-01,  4.2101e-02,
          -6.0902e-01, -2.4413e+00,  5.6538e-01,  1.1811e+00,  8.5354e-01,
           6.5099e-01, -7.3052e-01, -1.7747e-02,  1.7858e-02, -1.1516e+00,
           1.0280e+00, -1.0589e+00, -2.4579e-01, -1.2576e+00, -1.3307e+00,
           8.2519e-01,  2

In [172]:
class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(z_dim, 64*6*6)

        self.convT1 = nn.ConvTranspose2d(64, 64, 3, 2, 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.dropout1 = nn.Dropout()

        self.convT2 = nn.ConvTranspose2d(64, 64, 3, 2, 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.dropout2 = nn.Dropout()

        self.convT3 = nn.ConvTranspose2d(64, 32, 3, 2, 1)
        self.bn3 = nn.BatchNorm2d(32)
        self.dropout3 = nn.Dropout()

        self.convT4 = nn.ConvTranspose2d(32, 3, 3, 2)

    def forward(self, x):
        x = self.linear1(x)
        x = x.reshape(-1, 64, 6, 6)

        x = F.pad(self.convT1(x), (2,1,2,1))      
        x = self.bn1(x)
        x = F.leaky_relu(x)
        x = self.dropout1(x)

        x = F.pad(self.convT2(x), (2,1,2,1))      
        x = self.bn2(x)
        x = F.leaky_relu(x)
        x = self.dropout2(x)

        x = F.pad(self.convT3(x), (2,1,2,1))      
        x = self.bn3(x)
        x = F.leaky_relu(x)
        x = self.dropout3(x)

        x = F.pad(self.convT4(x), (2,1,2,1)) 
        
        return torch.sigmoid(x)


In [173]:
Decoder().to(device)(torch.randn((1, 200)).to(device)).shape

torch.Size([1, 3, 128, 128])

In [174]:
class AutoEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
    def forward(self, x):
        t, mu, log_var = self.encoder(x)
        x = self.decoder(t)
        return x, t, mu, log_var

In [175]:
model =AutoEncoder().to(device)
model

AutoEncoder(
  (encoder): Encoder(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout1): Dropout(p=0.5, inplace=False)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout2): Dropout(p=0.5, inplace=False)
    (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout3): Dropout(p=0.5, inplace=False)
    (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (bn4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout4): Dropout(p=0.5, inplace=False)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mu): Linear(in_features=4096, out_features=200, bias=True)
    (log_var): Linear(in_features=4096, out_features=2

In [176]:
R_LOSS_FACTOR = 10000

def kl_loss(mu, log_var):
    kl_loss =  -0.5 * torch.sum(1 + log_var - torch.square(mu) - torch.exp(log_var))
    return kl_loss


In [177]:
optimizer = optim.Adam(model.parameters(), lr=0.0005)

criterion = nn.MSELoss()

In [178]:
for epoch in range(200):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, inputs in enumerate(dataloader, 0):
        
        # get the inputs; data is a list of [inputs, labels]
        inputs = inputs.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, _, mu, log_var = model(inputs)
        loss = criterion(outputs, inputs) 

        loss += kl_loss(mu, log_var)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # print statistics
        running_loss += loss.item()
    print(f'[{epoch + 1}] loss: {running_loss / len(dataloader):.3f}')
    torch.save(model, RUN_FOLDER + "/weights/weight.pt")

print('Finished Training')