In [1]:
import os

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from pythae.models import VAE, VAEConfig, AutoModel
from pythae.data.datasets import DatasetOutput
from pythae.samplers import (
    TwoStageVAESampler,
    TwoStageVAESamplerConfig,
)
from pythae.trainers import BaseTrainerConfig

import gradio as gr
import torchvision.utils as vutils
from torchvision.transforms import ToPILImage
import traceback
# Prefer Apple's Metal backend, then CUDA, and fall back to CPU so the notebook
# still starts on machines without MPS.
# NOTE(review): `device` is not referenced by any cell shown here; nothing is
# explicitly moved to it.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class UTKFaceDataset(datasets.ImageFolder):
    """ImageFolder that yields each decoded image together with its file name.

    The ImageFolder class label is dropped. The file name is returned because the
    filtering code parses subject attributes (age, gender, race) out of it.
    Items are ``DatasetOutput(data=<transformed image>, filename=<basename>)``.
    """

    def __init__(self, root, transform=None, target_transform=None):
        super().__init__(
            root=root,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        image_path = self.imgs[index][0]
        image, _label = super().__getitem__(index)
        return DatasetOutput(data=image, filename=os.path.basename(image_path))


class DatasetOutput:
    """Plain record pairing an image tensor (``data``) with its source ``filename``.

    NOTE(review): this definition shadows ``pythae.data.datasets.DatasetOutput``
    imported at the top of the notebook. ``UTKFaceDataset.__getitem__`` resolves
    the name when it is called, so it builds instances of this class.
    """

    def __init__(self, data, filename):
        self.data, self.filename = data, filename

DATA_DIR = "./data"
FIRST_STAGE_MODEL_DIR = "./my_model/bestmodel/final_model"

# Images are only converted to tensors: no resizing or normalisation.
transform = transforms.Compose([transforms.ToTensor()])
all_dataset = UTKFaceDataset(root=DATA_DIR, transform=transform)

# Previously trained first-stage model that the two-stage sampler wraps.
trained_model = AutoModel.load_from_folder(FIRST_STAGE_MODEL_DIR)

twostagevae_config = TwoStageVAESamplerConfig(reconstruction_loss="mse")
twostage_sampler = TwoStageVAESampler(model=trained_model, sampler_config=twostagevae_config)

# Training settings passed to the sampler's fit() by filter_dataset.
# NOTE(review): output_dir is the same root folder the first-stage model is
# loaded from; confirm new training runs are meant to land there.
two_stage_train_config = BaseTrainerConfig(
    output_dir="my_model",
    learning_rate=1.4e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_epochs=75,
)

In [2]:
_GENDER_CODES = {"Any": None, "Male": 0, "Female": 1}
_RACE_CODES = {"Any": None, "White": 0, "Black": 1, "Asian": 2, "Indian": 3}
_DEFAULT_AGE_RANGE = (0, 116)  # used when no age is given
_MAX_GRID_IMAGES = 25  # images shown per preview / generated grid
_GRID_COLUMNS = 5
_TRAIN_FRACTION = 0.8  # share of matching images used to train the second stage


def _lookup_code(codes, choice):
    """Map a dropdown label to its filename code; ``None``/"" are treated as "Any"."""
    # An untouched Gradio dropdown sends None; unknown labels still raise KeyError.
    return codes[choice or "Any"]


def _parse_age_range(age_input):
    """Return inclusive ``(low, high)`` ages from "", "25" or "20-30"; raise ValueError on junk."""
    text = (age_input or "").strip()
    if not text:
        return _DEFAULT_AGE_RANGE
    if "-" in text:
        low, high = (int(part) for part in text.split("-", 1))
    else:
        low = high = int(text)
    return min(low, high), max(low, high)


def _filename_matches(filename, age_low, age_high, gender_code, race_code):
    """True if an ``age_gender_race_rest`` filename satisfies all active filters."""
    parts = filename.split("_")
    if len(parts) != 4:
        return False
    try:
        return (
            age_low <= int(parts[0]) <= age_high
            and (gender_code is None or int(parts[1]) == gender_code)
            and (race_code is None or int(parts[2]) == race_code)
        )
    except ValueError:
        # Malformed field: skip this file instead of aborting the whole query.
        return False


def _make_grid_image(batch, **grid_kwargs):
    """Tile a batch of image tensors into a grid and convert it to a PIL image."""
    return ToPILImage()(vutils.make_grid(batch, nrow=_GRID_COLUMNS, **grid_kwargs))


def _fit_and_sample(images):
    """Fit the two-stage sampler on ``images`` and return a PIL grid of new samples."""
    split_index = max(1, int(_TRAIN_FRACTION * len(images)))  # never an empty train set
    train_images, eval_images = images[:split_index], images[split_index:]
    # NOTE(review): this refits the shared global sampler, so it may warm-start
    # from the previous query's second stage; confirm whether that is intended.
    twostage_sampler.fit(
        train_data=train_images,
        eval_data=eval_images if len(eval_images) > 0 else None,
        training_config=two_stage_train_config,
    )
    samples = twostage_sampler.sample(num_samples=_MAX_GRID_IMAGES)
    return _make_grid_image(samples, normalize=True, scale_each=True)


def filter_dataset(race, gender, age_input):
    """Select UTKFace images by attributes, fit the two-stage sampler on them, and sample.

    Filenames are parsed as ``<age>_<gender>_<race>_<rest>``. Files that do not
    split into exactly four ``_``-separated fields, or whose needed fields are
    not integers, are skipped. Only matching files are decoded; decoding the
    whole dataset (~23,000 images) previously took about 1000 seconds.

    Args:
        race: "Any", "White", "Black", "Asian" or "Indian"; ``None``/"" mean "Any".
        gender: "Any", "Male" or "Female"; ``None``/"" mean "Any".
        age_input: ""/``None`` for all ages, one age such as "25", or an inclusive
            range such as "20-30" (surrounding whitespace and reversed bounds are accepted).

    Returns:
        ``(status_message, preview_image, generated_image)``. The images are PIL
        grids of up to 25 pictures, or ``None`` when nothing matched or an error
        occurred. Errors are never raised: the message and traceback are printed
        and returned as the status text so the UI can show them.

    NOTE(review): a broad query still stacks every matching image into a single
    tensor, which can use a lot of memory.
    """
    try:
        gender_code = _lookup_code(_GENDER_CODES, gender)
        race_code = _lookup_code(_RACE_CODES, race)
        age_low, age_high = _parse_age_range(age_input)

        matching_indices = [
            index
            for index, (path, _label) in enumerate(all_dataset.imgs)
            if _filename_matches(os.path.basename(path), age_low, age_high, gender_code, race_code)
        ]
        if not matching_indices:
            return "No images found", None, None

        images = torch.stack([all_dataset[index].data for index in matching_indices])
        # Shuffle once. Files are listed in filename order, which begins with the
        # age, so an unshuffled 80/20 split would put only the oldest subjects in eval.
        images = images[torch.randperm(len(images))]

        preview_image = _make_grid_image(images[:_MAX_GRID_IMAGES])
        generated_image = _fit_and_sample(images)
        return f"Found {len(matching_indices)} images", preview_image, generated_image

    except Exception as e:
        # Best-effort UI handler: report the failure in the status box instead of raising.
        full_error_message = f"An error occurred: {e}\n\nTraceback:\n{traceback.format_exc()}"
        print(full_error_message)
        return full_error_message, None, None

In [3]:
genders = ["Any", "Male", "Female"]
races = ["Any", "White", "Black", "Asian", "Indian"]

# Web UI: choose attributes, then see matching dataset images next to samples
# from the sampler fitted on them.
iface = gr.Interface(
    fn=filter_dataset,
    inputs=[
        # Default to "Any": an untouched dropdown would otherwise send None,
        # which is not a key of filter_dataset's label maps.
        gr.Dropdown(choices=races, value="Any", label="Race"),
        gr.Dropdown(choices=genders, value="Any", label="Gender"),
        gr.Textbox(label="Age (specific, range like 20-30, or leave blank for all)"),
    ],
    outputs=[
        gr.Textbox(label="Status"),
        gr.Image(label="Matching dataset images"),
        gr.Image(label="Generated samples"),
    ],
)
iface.launch()
# works better when opened via the local URL


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


