In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.datasets as datasets 
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import random_split, Subset
from CustomDataset import CustomDataset
from VanillaVAE import VanillaVAE 


In [11]:
# Configuration
# Fixed seed so weight init, loader shuffling and latent sampling repeat across runs.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model sizes: INPUT_DIM and Z_DIM are the two positional arguments handed to VanillaVAE.
INPUT_DIM = 3
Z_DIM = 1000
H_DIM = 2000

# Optimisation settings.
NUM_EPOCHS = 50
BATCH_SIZE = 32
LR_RATE = 3e-4

# File the training loop writes its checkpoint to.
PATH = "model.pt"


In [4]:
# Scratch check: element count of one image, presumably 178 x 218 pixels x 3 channels — TODO confirm.
178*218*3

116412

In [12]:

# Reuse the configured batch size so the loaders cannot drift from the config cell.
batch_size = BATCH_SIZE

# Number of images to use; the original note lists 202599 as the alternative value.
data_length = 10000

# File names are the 1-based image index, zero-padded to six digits: 000001.jpg, 000002.jpg, ...
image_names = [f"{i:06d}.jpg" for i in range(1, data_length + 1)]
dataset = CustomDataset("data/img_align_celeba", image_names, transform=transforms.ToTensor())

train_size = int(data_length * 0.8)

# Sequential split: indices [0, train_size) train, [train_size, data_length) validate.
# The upper bound is exclusive; index data_length itself is one past the last file name.
dataset_train = Subset(dataset, range(train_size))
dataset_val = Subset(dataset, range(train_size, data_length))

# A random split is the alternative:
# dataset_train, dataset_val = random_split(dataset, [train_size, data_length - train_size])

train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
# Validation only accumulates a total loss, so the visiting order is irrelevant.
validation_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

In [6]:
# Define train function
def train(num_epochs, model, optimizer, loss_fn):
    """Train ``model`` on the module-level ``train_loader`` and checkpoint to ``PATH``.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        model: Module whose forward pass returns four items; items 0, 2 and 3
            are used as the reconstruction, ``mu`` and ``sigma``.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable invoked as ``loss_fn(x_reconst, x, mu, sigma)`` that
            returns a mapping with a ``'loss'`` entry.

    A checkpoint (epoch index, model and optimizer state, summed epoch loss)
    is written to ``PATH`` after every even-indexed epoch and after the last
    epoch, so a completed run always ends with an up-to-date checkpoint.
    """
    # Make sure layers such as dropout/batch-norm are in training mode, even
    # if an earlier evaluation left the model in eval mode.
    model.train()
    last_epoch = num_epochs - 1

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1} of {num_epochs}")
        # Wrapping the loader directly lets tqdm read its length and show a total.
        loop = tqdm(train_loader)
        epoch_loss = 0.0
        for x in loop:
            # Forward pass
            x = x.to(device)
            x_reconst, _, mu, sigma = model(x)

            # loss, formulas from https://www.youtube.com/watch?v=igP03FXZqgo&t=2182s
            # NOTE(review): if loss_fn's 'loss' already contains a KL term, adding
            # kl_div below counts it twice — confirm against the model's loss function.
            # NOTE(review): assumes the model's last output is a standard deviation,
            # not a log-variance — TODO confirm.
            reconst_loss = loss_fn(x_reconst, x, mu, sigma)['loss']
            kl_div = - torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

            # Backprop and optimize
            loss = reconst_loss + kl_div
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            loop.set_postfix(loss=batch_loss)

        # Save on even epochs as before, and additionally on the final epoch so
        # the last weights are never lost when num_epochs is even.
        if epoch % 2 == 0 or epoch == last_epoch:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                }, PATH)

In [7]:
# Build the standalone reconstruction loss, the model, and the optimizer that updates it.
# NOTE(review): the training cell passes model.loss_function to train(), not this loss_fn.
loss_fn = nn.BCELoss(reduction="sum")
model = VanillaVAE(INPUT_DIM, Z_DIM)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR_RATE)

In [8]:
# Kick off training, using the model's own loss function as the reconstruction term.
train(
    num_epochs=NUM_EPOCHS,
    model=model,
    optimizer=optimizer,
    loss_fn=model.loss_function,
)

Epoch 1 of 50


500it [17:23,  2.09s/it, loss=1.98e+4]


Epoch 2 of 50


500it [9:50:42, 70.88s/it, loss=2.33e+4]  


Epoch 3 of 50


500it [21:43,  2.61s/it, loss=3.89e+4]


Epoch 4 of 50


500it [21:30,  2.58s/it, loss=5.22e+4] 


Epoch 5 of 50


500it [25:17,  3.04s/it, loss=5.16e+4]


Epoch 6 of 50


449it [1:28:02, 11.77s/it, loss=4.93e+4] 


KeyboardInterrupt: 

In [None]:
def validation(model, loss_fn, loader=None):
    """Print and return the mean per-sample loss of ``model`` over a data loader.

    Args:
        model: Module whose forward pass returns a sequence; its first item is
            used as the reconstruction and its last two items as ``mu`` and
            ``sigma``. This accepts the four-item output that ``train`` unpacks.
        loss_fn: Callable invoked as ``loss_fn(x_reconst, x)`` that returns a
            scalar tensor.
        loader: Loader to evaluate on; defaults to the module-level
            ``validation_loader``. It must expose ``.dataset`` with a length.

    Returns:
        The summed loss divided by ``len(loader.dataset)``.
    """
    if loader is None:
        loader = validation_loader

    # Remember the current mode so evaluation does not silently leave the
    # model in eval mode for any training that follows.
    was_training = model.training
    model.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            for x in loader:
                # Feed the batch unflattened, exactly as the training loop does.
                x = x.to(device)
                outputs = model(x)
                x_reconst, mu, sigma = outputs[0], outputs[-2], outputs[-1]
                reconst_loss = loss_fn(x_reconst, x)
                # NOTE(review): assumes the last output is a standard deviation,
                # not a log-variance — TODO confirm.
                kl_div = - torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
                total_loss += (reconst_loss + kl_div).item()
    finally:
        model.train(was_training)

    mean_loss = total_loss / len(loader.dataset)
    print("Validation loss: ", mean_loss)
    return mean_loss


In [None]:
# Evaluate the current model on the validation split with the standalone loss.
validation(model=model, loss_fn=loss_fn)

NameError: name 'validation_loader' is not defined

In [None]:
def test_inference():

    images = []
    idx = 0
    for x in dataset:
        images.append(x)
        idx += 1
        if idx == 10:
            break

    encodings = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, INPUT_DIM))
        encodings.append((mu, sigma))

    mu, sigma = encodings[0]
    epsilon = torch.randn_like(sigma)
    z = mu + sigma * epsilon
    out = model.decode(z)
    out = out.view(-1, 3, 178, 218)
    save_image(out, f"generated_ex.png")

In [None]:
# Run the inference smoke test; it writes its output image to "generated_ex.png".
test_inference()