In [1]:
import os
import wandb
from PIL import Image
from torchvision.transforms import ToPILImage
import numpy as np

import torch
import torchvision.utils as vutils
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

from pythae.models import AutoModel
from pythae.data.datasets import DatasetOutput
from pythae.samplers import (
    TwoStageVAESampler,
    TwoStageVAESamplerConfig,
    NormalSampler
)
from pythae.trainers import BaseTrainerConfig

# Select the compute device at runtime instead of hard-coding Apple's MPS
# backend, so the notebook also runs on CUDA machines and CPU-only machines.
# Preference order: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    # getattr guard: older torch builds have no `torch.backends.mps` module.
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Report the name of CUDA device 0 when a CUDA runtime is present;
# prints nothing otherwise.
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    cuda_message = f"CUDA is available. GPU Name: {gpu_name}"
    print(cuda_message)

In [None]:
### IMAGE DATASET pytorch class
class UTKFaceDataset(datasets.ImageFolder):
    """ImageFolder variant whose items are pythae ``DatasetOutput`` objects.

    Each item exposes the (transformed) image under ``data`` and the base
    name of the image file under ``filename``. The class label produced by
    ``ImageFolder`` is discarded.
    """

    def __init__(self, root, transform=None, target_transform=None):
        # Pure pass-through to ImageFolder; kept so the constructor signature
        # is explicit.
        super().__init__(root=root, transform=transform, target_transform=target_transform)

    def __getitem__(self, index):
        """Return ``DatasetOutput(data=<image>, filename=<basename>)`` for ``index``."""
        image_path = self.imgs[index][0]
        image, _label = super().__getitem__(index)
        return DatasetOutput(data=image, filename=os.path.basename(image_path))

# The only preprocessing is the conversion of each loaded image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Create UTKFACE dataset (ImageFolder layout rooted at ./data)
all_dataset = UTKFaceDataset(root="./data", transform=transform)

# Split UTKFACE dataset into train and eval sets.
# 20000 training images are used when the dataset is large enough; for a
# smaller dataset (e.g. a subset used for quick experiments) fall back to an
# 80/20 split so that eval_size can never become negative.
total_size = len(all_dataset)
train_size = 20000 if total_size > 20000 else int(0.8 * total_size)
eval_size = total_size - train_size

# Seed the split so that "Restart & Run All" always yields the same
# train/eval partition (and therefore the same eval images further down).
split_generator = torch.Generator().manual_seed(0)
train_dataset, eval_dataset = random_split(
    all_dataset, [train_size, eval_size], generator=split_generator
)


In [None]:
# Preview: pick 25 random dataset items and tile their images with nrow=5.
indices = torch.randperm(len(all_dataset))[:25]
images = []
for i in indices:
    images.append(all_dataset[i]['data'])
grid = vutils.make_grid(images, nrow=5)
to_pil = transforms.ToPILImage()
grid_pil = to_pil(grid)
grid_pil

In [None]:
# Start the Weights & Biases run used by every wandb.log call below.
# Project and entity default to the original values but can be overridden
# through the WANDB_PROJECT / WANDB_ENTITY environment variables, so the
# notebook is runnable from accounts other than the original author's.
wandb.init(
    project=os.environ.get("WANDB_PROJECT", 'VAE_UTKFACE'),
    entity=os.environ.get("WANDB_ENTITY", 'charlesdoyne'),
)

In [None]:
# Log the dataset preview grid (a one-element list of images) to the active run.
dataset_preview = wandb.Image(grid_pil, caption="Images used to train main VAE")
wandb.log({"UTKFace dataset": [dataset_preview]})

In [None]:
### LOAD TRAINED MODEL
# The string literal below is a disabled alternative (load the last entry of
# sorted(os.listdir('my_model'))). It is a no-op expression, kept verbatim.
# The active code loads the fixed "bestmodel" folder instead.
"""
last_training = sorted(os.listdir('my_model'))[-1]
trained_model = AutoModel.load_from_folder(os.path.join('my_model', last_training, 'final_model'))
"""
best_model_dir = "./my_model/bestmodel/final_model"
trained_model = AutoModel.load_from_folder(best_model_dir)


In [None]:
### GENERATION WITH NORMAL SAMPLER
# Wrap the trained model in a NormalSampler and draw 25 samples from it.
normal_sampler = NormalSampler(model=trained_model)
gen_data = normal_sampler.sample(num_samples=25)


In [None]:
# Tile the generated samples with nrow=5; normalize=True with scale_each=True
# rescales every image to [0, 1] using its own min/max.
grid = vutils.make_grid(gen_data, nrow=5, normalize=True, scale_each=True)
# Bring the grid to host memory before converting: Tensor.numpy() raises for
# tensors that live on an accelerator (CUDA/MPS) or that require grad.
grid_np = grid.detach().cpu().permute(1, 2, 0).numpy()  # Convert to HxWxC layout
grid_pil = Image.fromarray((grid_np * 255).astype('uint8'), 'RGB')
grid_pil

In [None]:
# Log the Normal-sampler generation grid to the active run.
normal_sampler_panel = wandb.Image(grid_pil, caption="Images generated with Normal Sampler")
wandb.log({"Normal Sampler Generations": normal_sampler_panel})

In [None]:
# Run the trained model on the selected device, in inference mode.
trained_model = trained_model.to(device)
trained_model.eval()

# Reconstruct the first 25 eval images (fewer if the eval split is smaller,
# which avoids an IndexError on small datasets).
num_reconstructions = min(25, len(eval_dataset))

reconstructions = []
# no_grad: this is inference only, so skip building the autograd graph.
with torch.no_grad():
    for i in range(num_reconstructions):
        dataset_output = eval_dataset[i]
        image = dataset_output.data.to(device)
        # unsqueeze(0) adds the leading batch dimension before the model call;
        # results are stored on the CPU for the plotting cell below.
        reconstruction = trained_model.reconstruct(image.unsqueeze(0)).detach().cpu()
        reconstructions.append(reconstruction)

In [None]:
# Number of (original, reconstruction) pairs to display; never more than the
# number of reconstructions that were actually computed.
num_pairs = min(8, len(reconstructions))
if num_pairs == 0:
    raise ValueError("No reconstructions available to display; run the reconstruction cell first.")

image_pairs = []
for i in range(num_pairs):
    # The pair is assembled on the CPU, so there is no need to move the
    # original to the accelerator and back.
    original_cpu = eval_dataset[i].data.cpu()
    reconstructed = reconstructions[i]
    # squeeze(0) drops the batch dimension; concatenating on dim 2 joins the
    # two images along their last axis (side by side for CxHxW tensors).
    pair = torch.cat((original_cpu, reconstructed.squeeze(0)), 2)
    image_pairs.append(pair)

# NOTE: make_grid's `nrow` is the number of images per grid row, so this
# places 2 pairs on each row (the name is kept for compatibility).
n_rows = 2
grid = vutils.make_grid(image_pairs, nrow=n_rows, padding=2)
grid_np = grid.cpu().numpy()
grid_np = np.transpose(grid_np, (1, 2, 0))  # CxHxW -> HxWxC
grid_np = grid_np - grid_np.min()  # Shift so the minimum is 0
grid_peak = grid_np.max()
if grid_peak > 0:
    # Guard against a constant grid, where dividing by 0 would produce NaNs.
    grid_np = grid_np / grid_peak  # Normalize to the range 0 - 1
grid_np = (grid_np * 255).astype(np.uint8)  # Scale to range 0 - 255
grid_pil = Image.fromarray(grid_np)
grid_pil

In [None]:
# Log the original/reconstruction comparison grid to the active run.
reconstruction_panel = wandb.Image(grid_pil, caption="How the encoder-decoder reconstructs images")
wandb.log({"VAE reconstructions": reconstruction_panel})

In [None]:
# --- Selection criteria for the "style" subset -----------------------------
gender = "Male"
race = "White"
age_input = "30-35"   # "lo-hi", a single age, or "" for any age

# Map the human-readable choice to the integer code stored in the file name
# (None means "do not filter on this attribute").
gender_map = {"Any": None, "Male": 0, "Female": 1}
# "Others": 4 follows the UTKFace label description; confirm for your copy.
race_map = {"Any": None, "White": 0, "Black": 1, "Asian": 2, "Indian": 3, "Others": 4}

gender_code = gender_map[gender]
race_code = race_map[race]

# Inclusive [min_age, max_age] range.
if '-' in age_input:
    age_range = [int(a) for a in age_input.split('-')]
elif age_input:
    age_range = [int(age_input), int(age_input)]
else:
    age_range = [0, 116]

# --- Filter on file names first, then load only the matching images --------
# File names are expected to look like "<age>_<gender>_<race>_<rest>".
# Checking the names in `all_dataset.imgs` avoids decoding every image in the
# dataset just to read its file name.
matching_indices = []
for index, (path, _) in enumerate(all_dataset.imgs):
    parts = os.path.basename(path).split('_')
    if len(parts) != 4:
        continue  # name does not follow the expected 4-field pattern
    try:
        file_age, file_gender, file_race = (int(field) for field in parts[:3])
    except ValueError:
        continue  # non-numeric attribute field: treat like any other malformed name

    age_condition = age_range[0] <= file_age <= age_range[1]
    gender_condition = gender_code is None or file_gender == gender_code
    race_condition = race_code is None or file_race == race_code

    if age_condition and gender_condition and race_condition:
        matching_indices.append(index)

filtered_dataset = [all_dataset[index] for index in matching_indices]

# Fail fast with a clear message: without matches the variables below would
# stay undefined and later cells would fail with a confusing error.
if not filtered_dataset:
    raise ValueError(
        f"No images match gender={gender!r}, race={race!r}, age range={age_range}"
    )

# Preview grid of up to 25 randomly chosen matches.
num_images_to_show = min(len(filtered_dataset), 25)
indices = torch.randperm(len(filtered_dataset))[:num_images_to_show]
images = [filtered_dataset[i].data for i in indices]
grid = vutils.make_grid(images, nrow=5)
attribute_data_pil = ToPILImage()(grid)

# Stack all matches into one tensor and split 80/20 (in dataset order) into
# the train/eval tensors used to fit the two-stage sampler.
filtered_dataset_tensors = [item.data for item in filtered_dataset]
filtered_dataset_tensor = torch.stack(filtered_dataset_tensors)

split_index = int(0.8 * len(filtered_dataset_tensor))
filtered_dataset_train = filtered_dataset_tensor[:split_index]
filtered_dataset_eval = filtered_dataset_tensor[split_index:]

In [None]:
# Preview: up to 25 random items of the filtered (style) dataset, tiled with nrow=5.
indices = torch.randperm(len(filtered_dataset))[:25]
images = []
for i in indices:
    images.append(filtered_dataset[i]['data'])
grid = vutils.make_grid(images, nrow=5)
to_pil = transforms.ToPILImage()
grid_pil = to_pil(grid)
grid_pil

In [None]:
# Log the style-dataset preview grid (a one-element list of images) to the active run.
style_preview = wandb.Image(grid_pil, caption="Style dataset for second VAE")
wandb.log({"Style Dataset": [style_preview]})

In [None]:
### GENERATION WITH TWO STAGE VAE SAMPLER
# Sampler configuration: reconstruction_loss is set to 'mse'.
twostagevae_config = TwoStageVAESamplerConfig(reconstruction_loss='mse')

# The sampler wraps the already-trained main model.
twostage_sampler = TwoStageVAESampler(model=trained_model, sampler_config=twostagevae_config)

# Trainer settings used when fitting the sampler.
two_stage_train_config = BaseTrainerConfig(
    num_epochs=75,  # Change this to train the model a bit more
    learning_rate=2e-4,
    per_device_eval_batch_size=4,
    per_device_train_batch_size=4,
    output_dir='my_model',
)

# Fit the sampler on the filtered (style) tensors, then draw 25 samples.
twostage_sampler.fit(
    training_config=two_stage_train_config,
    train_data=filtered_dataset_train,
    eval_data=filtered_dataset_eval,
)

gen_data2 = twostage_sampler.sample(num_samples=25)

In [None]:
# Tile the generated samples with nrow=5; normalize=True with scale_each=True
# rescales every image to [0, 1] using its own min/max.
grid = vutils.make_grid(gen_data2, nrow=5, normalize=True, scale_each=True)
# Bring the grid to host memory before converting: Tensor.numpy() raises for
# tensors that live on an accelerator (CUDA/MPS) or that require grad.
grid_np = grid.detach().cpu().permute(1, 2, 0).numpy()  # CxHxW -> HxWxC
grid_pil = Image.fromarray((grid_np * 255).astype('uint8'), 'RGB')
grid_pil

In [None]:
# Log the Two-Stage-sampler generation grid to the active run.
# NOTE(review): the key repeats the word "Sampler"; it is left unchanged here
# because renaming it would start a new panel/history in the W&B project.
two_stage_panel = wandb.Image(grid_pil, caption="Images generated with Two Stage Sampler")
wandb.log({"Two Stage Sampler Sampler Generations": two_stage_panel})

In [None]:
# End the Weights & Biases run that was opened with wandb.init() earlier in
# this notebook; run this last, after all wandb.log calls.
wandb.finish()