In [10]:
import argparse
import copy
import os
from collections import OrderedDict
from datetime import datetime

import torch
import torchvision
import yaml
from PIL import Image
from torchvision.utils import make_grid
from tqdm import tqdm

from dataset.celeb_dataset import CelebDataset
from models.unet_cond_base import Unet
from models.vqvae import VQVAE
from scheduler.linear_noise_scheduler import LinearNoiseScheduler
from scheduler.linear_noise_scheduler_ddim import LinearNoiseSchedulerDDIM
# NOTE(review): wildcard import hides which names come from config_utils;
# consider importing the needed helpers explicitly.
from utils.config_utils import *

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [11]:
# Read the experiment config.
# Other configs that have been used with this notebook:
#   celebhq-1024-64-16k-komondor/celeba_komondor_16k.yaml
#   celebhq-512-64/celeba_komondor_512.yaml
CONFIG_PATH = 'celebhq-512-64-train-komondor_b/celeba_komondor_512_b.yaml'

with open(CONFIG_PATH, 'r') as file:
    # Parse errors are allowed to propagate. Only printing them (as before)
    # left `config` undefined -- or, in a long-lived kernel, silently stale
    # from an earlier run -- so later cells ran against the wrong config.
    config = yaml.safe_load(file)
if not isinstance(config, dict):
    # safe_load returns None for an empty file (and a scalar/list for other
    # non-mapping documents); the section lookups below need a mapping.
    raise ValueError(f'Expected a mapping at the top level of {CONFIG_PATH!r}, '
                     f'got {type(config).__name__}')
print(config)

# Split the config into its sections.
diffusion_config = config['diffusion_params']
dataset_config = config['dataset_params']
diffusion_model_config = config['ldm_params']
autoencoder_model_config = config['autoencoder_params']
train_config = config['train_params']
sample_config = config['sample_params']

# Create the noise scheduler: the DDIM variant when sample_params.use_ddim is
# truthy, the plain linear scheduler otherwise. Both take the same arguments.
if sample_config['use_ddim']:
    print('Using DDIM')
    scheduler_cls = LinearNoiseSchedulerDDIM
else:
    scheduler_cls = LinearNoiseScheduler
scheduler = scheduler_cls(num_timesteps=diffusion_config['num_timesteps'],
                          beta_start=diffusion_config['beta_start'],
                          beta_end=diffusion_config['beta_end'])



{'task_name': 'celebhq-512-64-train-komondor_b', 'continue': True, 'last_step': 0, 'last_epoch': 199, 'dataset_params': {'im_path': 'data/CelebAMask-HQ', 'im_channels': 3, 'im_size': 512, 'name': 'celebhq'}, 'diffusion_params': {'num_timesteps': 1000, 'beta_start': 0.0015, 'beta_end': 0.0195}, 'ldm_params': {'down_channels': [512, 768, 768, 1024], 'mid_channels': [1024, 768], 'down_sample': [True, True, True], 'attn_down': [True, True, True], 'time_emb_dim': 512, 'norm_channels': 32, 'num_heads': 16, 'conv_out_channels': 128, 'num_down_layers': 2, 'num_mid_layers': 2, 'num_up_layers': 2, 'condition_config': {'condition_types': ['attribute'], 'attribute_condition_config': {'attribute_condition_num': 19, 'attribute_condition_selected_attrs': ['Male', 'Young', 'Bald', 'Bangs', 'Receding_Hairline', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair', 'Straight_Hair', 'Wavy_Hair', 'No_Beard', 'Goatee', 'Mustache', 'Sideburns', 'Narrow_Eyes', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose']}}}, 'a

In [12]:
# Re-bind the LDM/UNet parameter section (the same dict object the config
# cell already bound) and print it for inspection.
ldm_section = 'ldm_params'
diffusion_model_config = config[ldm_section]
print(diffusion_model_config)

{'down_channels': [512, 768, 768, 1024], 'mid_channels': [1024, 768], 'down_sample': [True, True, True], 'attn_down': [True, True, True], 'time_emb_dim': 512, 'norm_channels': 32, 'num_heads': 16, 'conv_out_channels': 128, 'num_down_layers': 2, 'num_mid_layers': 2, 'num_up_layers': 2, 'condition_config': {'condition_types': ['attribute'], 'attribute_condition_config': {'attribute_condition_num': 19, 'attribute_condition_selected_attrs': ['Male', 'Young', 'Bald', 'Bangs', 'Receding_Hairline', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair', 'Straight_Hair', 'Wavy_Hair', 'No_Beard', 'Goatee', 'Mustache', 'Sideburns', 'Narrow_Eyes', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose']}}}


In [13]:
# Show the attribute names listed in the model's attribute condition config.
attribute_cfg = diffusion_model_config['condition_config']['attribute_condition_config']
print(attribute_cfg['attribute_condition_selected_attrs'])

['Male', 'Young', 'Bald', 'Bangs', 'Receding_Hairline', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair', 'Straight_Hair', 'Wavy_Hair', 'No_Beard', 'Goatee', 'Mustache', 'Sideburns', 'Narrow_Eyes', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose']


In [14]:
# Deep copy: the next cell edits lists/dicts nested inside condition_config.
# A shallow dict.copy() shares those nested objects, so that edit would also
# change diffusion_model_config (and config['ldm_params']) in place.
temp_conf = copy.deepcopy(diffusion_model_config)

In [15]:
# Add 'Mouth_Slightly_Open' to the end of the selected-attribute list and
# keep attribute_condition_num in sync with the list length.
# The membership check keeps this cell idempotent: re-running it does not
# append a duplicate entry (the previous .extend() did).
extended_attr_cfg = temp_conf['condition_config']['attribute_condition_config']
extended_attrs = extended_attr_cfg['attribute_condition_selected_attrs']
if 'Mouth_Slightly_Open' not in extended_attrs:
    extended_attrs.append('Mouth_Slightly_Open')
extended_attr_cfg['attribute_condition_num'] = len(extended_attrs)

In [16]:
# Image-space (use_latents=False) train-split dataset, conditioned on the
# extended attribute list in temp_conf rather than the model's original one.
dataset_kwargs = dict(
    split='train',
    im_path=dataset_config['im_path'],
    im_size=dataset_config['im_size'],
    im_channels=dataset_config['im_channels'],
    use_latents=False,
    latent_path=os.path.join(train_config['task_name'],
                             train_config['vqvae_latent_dir_name']),
    condition_config=temp_conf['condition_config'],
)
im_dataset = CelebDataset(**dataset_kwargs)

100%|██████████| 30000/30000 [00:00<00:00, 78304.83it/s]

Found 30000 images
Found 0 masks
Found 0 captions
Found 30000 attributes





In [17]:
# Collect up to SAMPLES_PER_CLASS dataset indices where TARGET_ATTR is 1
# (mouth open) and the same number where it is not (mouth closed).
# NOTE: the "chubby" list names are historical; they hold mouth-open /
# mouth-closed indices and are kept because the save cell uses them.
TARGET_ATTR = 'Mouth_Slightly_Open'
SAMPLES_PER_CLASS = 2000

attr_names = temp_conf['condition_config']['attribute_condition_config']['attribute_condition_selected_attrs']
# Look the position up by name instead of hard-coding 19, so it stays correct
# if the attribute list changes.
target_pos = attr_names.index(TARGET_ATTR)

indexes_chubby = []      # indices with cond['attribute'][target_pos] == 1
indexes_not_chubby = []  # indices where that entry is not 1
for i in tqdm(range(len(im_dataset))):
    # Only the condition is used; the image part of the sample is discarded.
    _, cond = im_dataset[i]
    if cond['attribute'][target_pos] == 1:
        # Cap the positive class too; previously it could grow past the limit,
        # after which the "== 2000" stop check could never fire and the loop
        # scanned the entire dataset.
        if len(indexes_chubby) < SAMPLES_PER_CLASS:
            indexes_chubby.append(i)
    elif len(indexes_not_chubby) < SAMPLES_PER_CLASS:
        indexes_not_chubby.append(i)

    if len(indexes_chubby) >= SAMPLES_PER_CLASS and len(indexes_not_chubby) >= SAMPLES_PER_CLASS:
        break

if len(indexes_chubby) < SAMPLES_PER_CLASS or len(indexes_not_chubby) < SAMPLES_PER_CLASS:
    print(f'Warning: only found {len(indexes_chubby)} positive and '
          f'{len(indexes_not_chubby)} negative samples for {TARGET_ATTR}')

print(len(indexes_chubby))

 16%|█▌        | 4258/27000 [00:51<04:32, 83.31it/s]

2000





In [18]:
# Write the two sampled index lists to pickle files in the working directory.
import pickle

for out_path, index_list in (
    ('indexes_mouth_open_2000.pkl', indexes_chubby),
    ('indexes_mouth_closed_2000.pkl', indexes_not_chubby),
):
    with open(out_path, 'wb') as fh:
        pickle.dump(index_list, fh)