In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.datasets as datasets 
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import random_split, Subset
from CustomDataset import CustomDataset
from VanillaVAE import VanillaVAE 
# import wandb


In [None]:
# --- Run configuration -------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# First constructor argument of VanillaVAE.
INPUT_DIM = 3
# NOTE(review): "38804" was left here as a bare comment next to Z_DIM; its
# meaning is not evident from this notebook — confirm or remove.
# Second constructor argument of VanillaVAE; also the width of the random
# vector that generate_new() feeds to the decoder.
Z_DIM = 1000

NUM_EPOCHS = 200   # epochs passed to train(); also sets the KL warm-up length
BATCH_SIZE = 32
LR_RATE = 0.0003   # Adam learning rate
KL_COEFF_MAX = 2.5e-6   # KL weight reached once the warm-up is over
PATH = "model.pt"   # checkpoint file written by train()

current_epoch = 0

In [None]:
def kl_coeff_set(epoch, num_epochs=None, coeff_max=None):
    """Return the KL-divergence weight for ``epoch`` (linear warm-up, then flat).

    The weight grows linearly from 0 at epoch 0 and reaches ``coeff_max`` at
    ``num_epochs / 2``; from that epoch on it stays at ``coeff_max``.

    Args:
        epoch: Zero-based epoch index.
        num_epochs: Total number of epochs the schedule is laid out over.
            Defaults to the notebook-level ``NUM_EPOCHS``.
        coeff_max: Weight reached at the end of the warm-up. Defaults to the
            notebook-level ``KL_COEFF_MAX``.

    Returns:
        The factor the KL term should be multiplied by for this epoch.
    """
    # Resolve the defaults at call time so later edits to the config cell are
    # picked up without redefining this function.
    if num_epochs is None:
        num_epochs = NUM_EPOCHS
    if coeff_max is None:
        coeff_max = KL_COEFF_MAX

    warmup_epochs = num_epochs / 2
    if epoch < warmup_epochs:
        # Per-epoch increment times the number of elapsed epochs.
        return (coeff_max / warmup_epochs) * epoch
    return coeff_max

In [None]:

# Take the batch size from the configuration cell instead of repeating the
# literal here; the lowercase name is kept for any code that still reads it.
batch_size = BATCH_SIZE

# Number of images to use. NOTE(review): 202599 is presumably the size of the
# full image folder — confirm before switching to it.
data_length = 1000  # 202599

# File names are the 1-based image index, zero-padded to six digits.
dataset = CustomDataset(
    "data/img_align_celeba",
    [f"{i:06d}.jpg" for i in range(1, data_length + 1)],
    transform=transforms.ToTensor(),
)

train_size = int(data_length * 0.8)

# Training split: indices 0 .. train_size - 1.
dataset_train = Subset(dataset, range(train_size))

# Validation split: the remaining indices train_size .. data_length - 1.
dataset_val = Subset(dataset, range(train_size, data_length))

# dataset_train, dataset_val = random_split(dataset, [int(data_length*0.8), data_length- int(data_length*0.8)])


train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
# Validation order does not need to be random; a fixed order keeps the batch
# composition, and therefore the reported loss, reproducible across runs.
validation_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

In [None]:
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="machine_run",
    
#     # track hyperparameters and run metadata
#     config={
#     "learning_rate": LR_RATE,
#     "architecture": "VAE",
#     "dataset": "CELEBA",
#     "epochs": NUM_EPOCHS,
#     }
# )

In [None]:
# Define train function
def train(num_epochs, model, optimizer, loss_fn):
    """Optimise ``model`` on the notebook-level ``train_loader``.

    Every batch is moved to ``device`` and passed through ``model``, whose
    output is unpacked as ``(x_reconst, _, mu, sigma)``. The value stored under
    ``loss_fn(...)['loss']`` is unpacked as ``(reconst_loss, kl_div)``; the KL
    term is scaled by ``kl_coeff_set(epoch)`` and added to the reconstruction
    term to form the loss that is back-propagated.

    A checkpoint (epoch index, model and optimizer state, summed epoch loss) is
    written to ``PATH`` after every even-numbered epoch and after the final
    epoch, so the last weights are always on disk.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        model: Module to train; called directly on each batch.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: Callable invoked as ``loss_fn(x_reconst, x, mu, sigma)``.
    """
    # Start training
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1} of {num_epochs}")

        # validation() switches the model to eval mode; make sure every epoch
        # runs in training mode no matter which cells were executed before.
        model.train()

        # The KL weight only depends on the epoch, so compute it once here
        # rather than on every batch.
        kl_weight = kl_coeff_set(epoch)

        # Passing the loader itself (not enumerate(...)) lets tqdm pick up its
        # length and show real progress.
        loop = tqdm(train_loader)
        epoch_loss = 0.0
        epoch_reconst_loss = 0.0
        epoch_kl_div = 0.0
        num_batches = 0
        for x in loop:
            # Forward pass
            x = x.to(device)
            x_reconst, _, mu, sigma = model(x)

            # loss, formulas from https://www.youtube.com/watch?v=igP03FXZqgo&t=2182s
            reconst_loss, kl_div = loss_fn(x_reconst, x, mu, sigma)['loss']
            kl_div = kl_weight * kl_div
            loss = reconst_loss + kl_div

            # Backprop and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() forces a host sync, so read the batch loss only once.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            epoch_reconst_loss += reconst_loss.item()
            epoch_kl_div += kl_div.item()
            num_batches += 1
            loop.set_postfix(loss=batch_loss)

        # Report the per-batch averages collected above (skipped for an empty loader).
        if num_batches:
            print(
                f"  mean batch loss: {epoch_loss / num_batches:.6g} "
                f"(reconst {epoch_reconst_loss / num_batches:.6g}, "
                f"weighted kl {epoch_kl_div / num_batches:.6g})"
            )

        # wandb.log({"total_loss": loss,
        #                "reconst_loss": reconst_loss,
        #                "kl_div": kl_div})

        # Checkpoint every second epoch, and always after the last one: with an
        # even num_epochs the final epoch index is odd and would otherwise
        # never be saved.
        is_last_epoch = epoch == num_epochs - 1
        if epoch % 2 == 0 or is_last_epoch:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                }, PATH)
            

# Initialize model, optimizer, loss
model = VanillaVAE(INPUT_DIM, Z_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = model.loss_function

# Run training. Pass the loss callable bound above instead of looking it up on
# the model a second time, so there is a single place to swap the loss.
train(NUM_EPOCHS, model, optimizer, loss_fn)

In [None]:
# Run training


In [None]:
def validation(model, loss_fn):
    """Print the loss of ``model`` over the notebook-level ``validation_loader``.

    For every batch the reconstruction term plus ``KL_COEFF_MAX`` times the KL
    term is added to a running total; the total divided by the number of
    samples in the validation dataset is printed.

    The model is switched to eval mode and left there when this returns.

    Args:
        model: Module to evaluate; its output is unpacked as
            ``(x_reconst, _, mu, sigma)``.
        loss_fn: Callable invoked as ``loss_fn(x_reconst, x, mu, sigma)``; the
            value under its ``'loss'`` key is unpacked as
            ``(reconst_loss, kl_div)``.
    """
    model.eval()
    # Accumulate a plain Python float (as train() does) so the printed value is
    # a number rather than a tensor repr carrying a device tag.
    total_loss = 0.0
    with torch.no_grad():
        loop = tqdm(validation_loader, total=len(validation_loader), leave=False)
        for x in loop:
            x = x.to(device)
            x_reconst, _, mu, sigma = model(x)
            reconst_loss, kl_div = loss_fn(x_reconst, x, mu, sigma)['loss']
            total_loss += (reconst_loss + kl_div * KL_COEFF_MAX).item()

    # NOTE(review): this divides a sum of per-batch losses by the number of
    # samples. That is a per-sample average only if loss_fn returns a
    # batch-summed loss; if it returns a batch mean, divide by
    # len(validation_loader) instead — confirm against VanillaVAE.loss_function.
    print("Validation loss: ", total_loss / len(validation_loader.dataset))


In [None]:
# Evaluate the model held in this notebook on the validation split.
validation(model=model, loss_fn=model.loss_function)

In [None]:
def test_inference(image_path="data/img_align_celeba/000004.jpg"):
    """Encode one image, sample a latent around its encoding, decode and plot it.

    The image is converted with ``transforms.ToTensor()``, given a batch
    dimension and passed to ``model.encode``, whose result is unpacked as
    ``(mu, sigma)``. A latent ``mu + sigma * epsilon`` with standard-normal
    ``epsilon`` is decoded, viewed as ``(-1, 3, 224, 192)`` and the first image
    is shown with matplotlib.

    Args:
        image_path: Image file to reconstruct. Defaults to the sample that was
            previously hard-coded.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image

    # Do not rely on an earlier cell having put the model in eval mode.
    model.eval()
    # The input has to live on the same device as the model weights; without
    # this the encode call fails when the model was moved to the GPU.
    model_device = next(model.parameters()).device
    to_tensor = transforms.ToTensor()

    # The context manager closes the file handle once the pixels are read.
    with Image.open(image_path) as image:
        x = to_tensor(image).unsqueeze(0).to(model_device)

    # Nothing here is back-propagated, so keep autograd off for the whole pass.
    with torch.no_grad():
        mu, sigma = model.encode(x)
        # NOTE(review): the second encoder output is used as a standard
        # deviation here; if VanillaVAE.encode returns a log-variance this
        # should be exp(0.5 * value) — confirm.
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)

    out = out.view(-1, 3, 224, 192)

    # Channels-first tensor -> channels-last array, as imshow expects.
    out = out.cpu().detach().numpy()
    out = np.transpose(out, (0, 2, 3, 1))
    plt.imshow(out[0])

test_inference()

In [None]:
def generate_new():
    """Decode one random latent vector and plot the resulting image.

    A ``(1, Z_DIM)`` standard-normal tensor is moved to ``device``, passed to
    ``model.decode``, viewed as ``(-1, 3, 224, 192)`` and the first image is
    shown with matplotlib.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Do not rely on an earlier cell having put the model in eval mode.
    model.eval()

    z = torch.randn(1, Z_DIM).to(device)
    # Sampling needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        out = model.decode(z)
    out = out.view(-1, 3, 224, 192)
    #save_image(out, f"generated_ex.png")

    # Channels-first tensor -> channels-last array, as imshow expects.
    out = out.cpu().detach().numpy()
    out = np.transpose(out, (0, 2, 3, 1))
    plt.imshow(out[0])

generate_new()
