In [52]:
import torch
import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from PIL import Image
from tqdm import tqdm
from models.unet_cond_base import Unet
from models.vqvae import VQVAE
from scheduler.linear_noise_scheduler import LinearNoiseScheduler
from scheduler.linear_noise_scheduler_ddim import LinearNoiseSchedulerDDIM
from utils.config_utils import *
from collections import OrderedDict
from datetime import datetime
from tools.sample_ddpm_attr_cond import ddim_inversion
from torchvision.transforms import Compose, Normalize

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [53]:
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Given a target color (R, G, B) return a loss for how far away on average
    the images' pixels are from that color. Defaults to a light teal: (0.1, 0.9, 0.5)"""
    target = torch.tensor(target_color).to(images.device) * 2 - 1  # Map target color to (-1, 1)
    target = target[None, :, None, None]  # Get shape right to work with the images (b, c, h, w)
    error = torch.abs(images - target).mean()  # Mean absolute difference between the image pixels and the target color
    return error

In [54]:
def glasses_loss(x, classifier_model, device='cuda'):
    # Create a resnet-18 model
    
    classifier_model.train()  # Ensure the model is in training mode

    # Move the input tensor `x` to the correct device
    x = x.to(device)

    transforms = Compose([
            Normalize(mean=[-0.5047, -0.2201,  0.0777], std=[1.0066, 0.8887, 0.6669])
        ])
    x = transforms(x)

    # Predict the glasses attribute
    pred = classifier_model(x)

    # Generate a target tensor with the same batch size as the input (assuming a binary classification task)
    target = torch.zeros(pred.size(0), 1).to(device)  # Assuming all targets are 1 (glasses present)

    # Calculate the loss using Binary Cross Entropy
    loss_fn = torch.nn.BCEWithLogitsLoss()
    loss = loss_fn(pred, target)

    # Return the loss with gradients enabled
    return loss

In [87]:
def sample(model, classifier_model, cond, scheduler, train_config, diffusion_model_config,
           autoencoder_model_config, diffusion_config, dataset_config, vae, use_ddim=False, start_step=0, num_steps=1000, noise_input=None, dir=''):
    r"""
    Sample stepwise by going backward one timestep at a time.
    We save the x0 predictions
    """    

    # seed random for reproducibility
    #torch.manual_seed(9)

    im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
    
    ########### Sample random noise latent ##########
    if noise_input is not None:
        xt = noise_input.to(device)
    else:
        xt = torch.randn((train_config['num_samples'],
                        autoencoder_model_config['z_channels'],
                        im_size,
                        im_size)).to(device)
    ###############################################


    ############# Validate the config #################
    condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None)
    assert condition_config is not None, ("This sampling script is for class conditional "
                                          "but no conditioning config found")
    condition_types = get_config_value(condition_config, 'condition_types', [])
    assert 'attribute' in condition_types, ("This sampling script is for attribute conditional "
                                          "but no class condition found in config")
    #validate_class_config(condition_config)
    ###############################################
    
    ############ Create Conditional input ###############
    num_classes = condition_config['attribute_condition_config']['attribute_condition_num']
    #sample_classes = torch.randint(0, num_classes, (train_config['num_samples'], ))
    #print('Generating images for {}'.format(list(sample_classes.numpy())))
    cond_input = {
        # 'class': torch.nn.functional.one_hot(sample_classes, num_classes).to(device)
        #  ['Male', 'Young', 'Bald', 'Bangs', 'Receding_Hairline', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair', 'Straight_Hair', 'Wavy_Hair', 'No_Beard', 'Goatee', 'Mustache', 'Sideburns', 'Narrow_Eyes', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose']
        'attribute': cond

    }
    # Unconditional input for classifier free guidance
    uncond_input = {
        'attribute': cond_input['attribute'] * 0
    }
    ###############################################
    
    # By default classifier free guidance is disabled
    # Change value in config or change default value here to enable it
    cf_guidance_scale = get_config_value(train_config, 'cf_guidance_scale', 1.0)
    
    current_time = datetime.now().strftime("%Y%m%d-%H%M%S")

    if not use_ddim:
        num_steps = diffusion_config['num_timesteps']
    
    ################# Sampling Loop ########################
    for i in tqdm(reversed(range(num_steps - start_step)), total=num_steps):
        torch.set_grad_enabled(True)
        xt_in = xt.clone()

        # activate gradient for xt
        xt.requires_grad_(True)
        
        timestep = ((i-1) * (1000 // num_steps)) + 1
        #print(timestep)
        
        # Get prediction of noise
        t = (torch.ones((xt.shape[0],))*timestep).long().to(device)
        
        
        noise_pred_cond = model(xt, t, cond_input)
        
        if cf_guidance_scale > 1:
            noise_pred_uncond = model(xt, t, uncond_input)
            noise_pred = noise_pred_uncond + cf_guidance_scale*(noise_pred_cond - noise_pred_uncond)
        else:
            noise_pred = noise_pred_cond
        
        # If DDIM is enabled, we need to also compute t_prev for the DDIM reverse process
        
        if use_ddim:
            t_prev = (torch.ones((xt.shape[0],)).to(device) * max(t - (1000 // num_steps), 1)).long().to(device)
            xt_new, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, t, t_prev)  # Use DDIM sampling
        else:
            xt_new, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))  # Use DDPM sampling
        

        #loss = color_loss(x0_pred) * 2
        # if not first step, use glasses loss

        loss = glasses_loss(x0_pred, classifier_model) * 0.5

        # set the loss to require grad
        
        if i % 10 == 0:
            print(i, "loss:", loss.item())

        cond_grad = -torch.autograd.grad(loss, xt, retain_graph=True)[0] 

        xt = xt + cond_grad
        print(cond_grad.max(), cond_grad.min())

        if use_ddim:
            t_prev = (torch.ones((xt.shape[0],)).to(device) * max(t - (1000 // num_steps), 1)).long().to(device)
            xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, t, t_prev)  # Use DDIM sampling
        else:
            xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))  # Use DDPM sampling

        if i == 0:
            # Decode ONLY the final image to save time
            ims = vae.decode(xt)
        else:
            ims = x0_pred
        
        ims = torch.clamp(ims, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims, nrow=1)
        img = torchvision.transforms.ToPILImage()(grid)

        if not os.path.exists(os.path.join(train_config['task_name'], 'cond_attr_samples', dir, current_time)):
            os.makedirs(os.path.join(train_config['task_name'], 'cond_attr_samples', dir, current_time), exist_ok=True)
        img.save(os.path.join(train_config['task_name'], 'cond_attr_samples', dir, current_time, 'x0_{}.png'.format(i)))
        img.close()

        # save latent to pt        
        torch.save(xt, os.path.join(train_config['task_name'], 'cond_attr_samples', dir, current_time, 'xt_{}.pt'.format(i)))
    ##############################################################

    return ims, cond_input

    
def ddim_inversion(scheduler, vae, xt, diffusion_config, condition_input, model, train_config, num_inference_steps=None, dir='', save_img=True):
    r"""
    Reverse the process by diffusing the image forward in time.
    :param scheduler: the noise scheduler used (e.g., LinearNoiseSchedulerDDIM)
    :param vae: the variational autoencoder (VAE) to encode and decode images
    :param xt: image tensor that will be diffused forward
    :param diffusion_config: configuration for the diffusion process
    :param condition_input: the conditioning input for the image
    :param model: the diffusion model (e.g., Unet)
    :param train_config: the training configuration
    """

    xt = xt.to(device)  # Ensure image is on the correct device
    xt = (xt * 2) - 1  # Rescale from [0, 1] to [-1, 1] to match the model's input range

    # First, encode the image into latent space using the VAE
    z, _ = vae.encode(xt)

    all_timesteps = diffusion_config['num_timesteps']

    # If the number of inference steps is not provided, use all timesteps
    if num_inference_steps is None:
        num_timesteps = all_timesteps

    current_time = datetime.now().strftime("%Y%m%d-%H%M%S")

    
    intermediate_latents = []
    # Move forward in time by applying noise progressively
    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps):
        t_val = (i * (all_timesteps // num_inference_steps) + 1)
        t = (torch.ones((z.shape[0],)) * (i * (all_timesteps // num_inference_steps) + 1)).long().to(z.device)

        if i >= num_inference_steps - 1: continue

        # Predict noise based on current step and conditions
        noise_pred = model(z, t, condition_input)
        
        next_timestep = t
        current_timestep = max(0, t_val - (all_timesteps // num_inference_steps))
        current_timestep =  (torch.ones((z.shape[0],)) * current_timestep).long().to(z.device)

        # Use the noise prediction to forward-sample to the next timestep using DDIM forward equation
        # Reverse the reverse process from sample_prev_timestep
        alpha_t = scheduler.alpha_cum_prod.to(z.device)[current_timestep]
        alpha_t_next = scheduler.alpha_cum_prod.to(z.device)[next_timestep]
        
        '''
        z_next = (
            torch.sqrt(alpha_t_next) * z +
            torch.sqrt(1 - alpha_t_next) * noise_pred
        )
        '''

        z_next = (z - torch.sqrt(1 - alpha_t)[0] * noise_pred) * (torch.sqrt(alpha_t_next)[0] / torch.sqrt(alpha_t))[0] + torch.sqrt(1 - alpha_t_next)[0] * noise_pred
        
        # Optionally, if stochasticity is involved (if ddim_eta > 0), add noise at each step
        if scheduler.ddim_eta > 0:
            variance = (1 - alpha_t_next) / (1 - alpha_t) * scheduler.betas.to(z.device)[t]
            sigma = scheduler.ddim_eta * torch.sqrt(variance)
            z_next = z_next + sigma * torch.randn_like(z_next)
        
        z = z_next  # Move to the next time step

        intermediate_latents.append(z)

        if save_img:
            ims_clamped = torch.clamp(z, -1., 1.).detach().cpu()
            ims_clamped = (ims_clamped + 1) / 2  # Rescale to [0, 1]
            
            # Convert to image and save
            grid = make_grid(ims_clamped, nrow=1)
            img = torchvision.transforms.ToPILImage()(grid)
            
            # Save images at each step for visualization
            save_dir = os.path.join(train_config['task_name'], 'cond_attr_samples', dir, current_time)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)

            
            # Save the image corresponding to the current timestep
            img.save(os.path.join(save_dir, 'x0_{}.png'.format(i)))
            img.close()

    # convert to torch tensor
    intermediate_latents = torch.stack(intermediate_latents, dim=0)

    # Return the final noisy latent z and the predicted noise used for the inversion
    return intermediate_latents

In [56]:
# Read the config file #
with open('celebhq-512-64-train-komondor_b/celeba_komondor_512_b.yaml', 'r') as file:
#with open('celebhq-1024-64-16k-komondor/celeba_komondor_16k.yaml', 'r') as file:
#with open('celebhq-512-64/celeba_komondor_512.yaml', 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)
print(config)
########################

diffusion_config = config['diffusion_params']
dataset_config = config['dataset_params']
diffusion_model_config = config['ldm_params']
autoencoder_model_config = config['autoencoder_params']
train_config = config['train_params']
sample_config = config['sample_params']

########## Create the noise scheduler #############

if sample_config['use_ddim']:
    print('Using DDIM')
    scheduler = LinearNoiseSchedulerDDIM(num_timesteps=diffusion_config['num_timesteps'],
                                            beta_start=diffusion_config['beta_start'],
                                            beta_end=diffusion_config['beta_end'])
else:
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
                                    beta_start=diffusion_config['beta_start'],
                                    beta_end=diffusion_config['beta_end'])
###############################################

########## Load Unet #############
model = Unet(im_channels=autoencoder_model_config['z_channels'],
                model_config=diffusion_model_config).to(device)

if os.path.exists(os.path.join(train_config['task_name'],
                                train_config['ldm_ckpt_name'])):
    

    ddp_state_dict = torch.load(os.path.join(train_config['task_name'],
                                                    train_config['ldm_ckpt_name']),
                                        map_location=device)
    new_state_dict = OrderedDict()
    for k, v in ddp_state_dict.items():
        if k.startswith('module.'):
            name = k[7:] # remove `module.`
        new_state_dict[name] = v
    
    ddp_state_dict = new_state_dict
    print('Loaded unet checkpoint')
    model.load_state_dict(ddp_state_dict)
else:
    raise Exception('Model checkpoint {} not found'.format(os.path.join(train_config['task_name'],
                                                                        train_config['ldm_ckpt_name'])))

model.train()
#####################################

# Create output directories
if not os.path.exists(train_config['task_name']):
    os.mkdir(train_config['task_name'])

########## Load VQVAE #############
vae = VQVAE(im_channels=dataset_config['im_channels'],
            model_config=autoencoder_model_config).to(device)
vae.eval()

# Load vae if found
if os.path.exists(os.path.join(train_config['task_name'],
                                train_config['vqvae_autoencoder_ckpt_name'])):
    print('Loaded vae checkpoint')

    vae_state_dict = torch.load(os.path.join(train_config['task_name'],
                                                train_config['vqvae_autoencoder_ckpt_name']),
                                    map_location=device)
    
    
    new_state_dict = OrderedDict()

    for k, v in vae_state_dict.items():
        if k.startswith('module.'):
            name = k[7:]
        new_state_dict[name] = v   

    #new_state_dict = vae_state_dict     
    
    vae.load_state_dict(new_state_dict, strict=True)
else:
    raise Exception('VAE checkpoint {} not found'.format(os.path.join(train_config['task_name'],
                                                train_config['vqvae_autoencoder_ckpt_name'])))
#####################################
'''
classifier_model = torchvision.models.resnet18(pretrained=False)

num_ftrs = classifier_model.fc.in_features

# Modify the last fully connected layer for binary classification (1 output)
classifier_model.fc = torch.nn.Linear(num_ftrs, 1)

# Load weights from 'celeba_resnet18_latent_glasses_classifier_1.pth'
state = torch.load('celeba_resnet18_latent_glasses_classifier_1.pth', map_location=device)

new_state_dict = OrderedDict()
for k, v in state.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v

classifier_model.load_state_dict(new_state_dict)
'''




{'task_name': 'celebhq-512-64-train-komondor_b', 'continue': True, 'last_step': 0, 'last_epoch': 199, 'dataset_params': {'im_path': 'data/CelebAMask-HQ', 'im_channels': 3, 'im_size': 512, 'name': 'celebhq'}, 'diffusion_params': {'num_timesteps': 1000, 'beta_start': 0.0015, 'beta_end': 0.0195}, 'ldm_params': {'down_channels': [512, 768, 768, 1024], 'mid_channels': [1024, 768], 'down_sample': [True, True, True], 'attn_down': [True, True, True], 'time_emb_dim': 512, 'norm_channels': 32, 'num_heads': 16, 'conv_out_channels': 128, 'num_down_layers': 2, 'num_mid_layers': 2, 'num_up_layers': 2, 'condition_config': {'condition_types': ['attribute'], 'attribute_condition_config': {'attribute_condition_num': 19, 'attribute_condition_selected_attrs': ['Male', 'Young', 'Bald', 'Bangs', 'Receding_Hairline', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair', 'Straight_Hair', 'Wavy_Hair', 'No_Beard', 'Goatee', 'Mustache', 'Sideburns', 'Narrow_Eyes', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose']}}}, 'a

"\nclassifier_model = torchvision.models.resnet18(pretrained=False)\n\nnum_ftrs = classifier_model.fc.in_features\n\n# Modify the last fully connected layer for binary classification (1 output)\nclassifier_model.fc = torch.nn.Linear(num_ftrs, 1)\n\n# Load weights from 'celeba_resnet18_latent_glasses_classifier_1.pth'\nstate = torch.load('celeba_resnet18_latent_glasses_classifier_1.pth', map_location=device)\n\nnew_state_dict = OrderedDict()\nfor k, v in state.items():\n    name = k[7:]  # remove `module.`\n    new_state_dict[name] = v\n\nclassifier_model.load_state_dict(new_state_dict)\n"

In [57]:
im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])

torch.manual_seed(1)

noise_input = torch.randn((train_config['num_samples'],
                        autoencoder_model_config['z_channels'],
                        im_size,
                        im_size)).to(device)

#noise_input += feature_direction_disentangled.to(device) * 1

In [58]:
from dataset.celeb_dataset import CelebDataset

im_dataset = CelebDataset(split='val',
                                im_path=dataset_config['im_path'],
                                im_size=dataset_config['im_size'],
                                im_channels=dataset_config['im_channels'],
                                use_latents=False,
                                latent_path=os.path.join(train_config['task_name'],
                                                         train_config['vqvae_latent_dir_name']),
                                #condition_config=temp_conf['condition_config'],
                                condition_config=diffusion_model_config['condition_config'],
                                )

100%|██████████| 30000/30000 [00:00<00:00, 77351.29it/s]

Found 30000 images
Found 0 masks
Found 0 captions
Found 30000 attributes





In [59]:
sampled_im, sampled_cond = im_dataset[75]

In [60]:
current_time = datetime.now().strftime("%Y%m%d-%H%M%S")

In [61]:
print(sampled_im.shape)

torch.Size([3, 512, 512])


In [62]:
sampled_im = sampled_im.unsqueeze(0).to(device)

In [63]:
sampled_cond = sampled_cond['attribute']

In [64]:
sampled_cond = torch.from_numpy(sampled_cond).unsqueeze(0)

In [65]:
sampled_cond = {'attribute': sampled_cond.to(device)}

In [66]:
print(sampled_cond['attribute'].shape)

torch.Size([1, 19])


In [67]:
# ddim inversion
with torch.no_grad():
    intermediate_latents = ddim_inversion(scheduler, vae, sampled_im, diffusion_config, sampled_cond, model, train_config, 250, dir=current_time, save_img=True)

start_latent = intermediate_latents[-1]

  0%|          | 0/250 [00:00<?, ?it/s]

100%|█████████▉| 249/250 [00:43<00:00,  5.77it/s]


In [68]:
from models.simple_cnn import SimpleCNN

classifier_model = SimpleCNN()
classifier_model.load_state_dict(torch.load('celeba_cnn_latent_glasses_classifier_0.pth', map_location=device))


classifier_model.to(device)

SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=4096, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=1, bias=True)
)

In [69]:
start_latent = start_latent.detach().clone().to(device)

# add gradient to latent
start_latent.requires_grad_(True)

print(start_latent.shape)


torch.Size([1, 3, 64, 64])


In [70]:
['Male', 'Young', 'Bald', 'Bangs', 'Receding_Hairline', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair', 'Straight_Hair', 'Wavy_Hair', 'No_Beard', 'Goatee', 'Mustache', 'Sideburns', 'Narrow_Eyes', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose']

['Male',
 'Young',
 'Bald',
 'Bangs',
 'Receding_Hairline',
 'Black_Hair',
 'Blond_Hair',
 'Brown_Hair',
 'Gray_Hair',
 'Straight_Hair',
 'Wavy_Hair',
 'No_Beard',
 'Goatee',
 'Mustache',
 'Sideburns',
 'Narrow_Eyes',
 'Oval_Face',
 'Pale_Skin',
 'Pointy_Nose']

In [86]:
#with torch.no_grad():
for i in range(1):
    ims, cond_trans = sample(model, classifier_model, sampled_cond['attribute'], scheduler, train_config, diffusion_model_config,
                autoencoder_model_config, diffusion_config, dataset_config, vae, use_ddim=True, dir=current_time, noise_input=start_latent, num_steps=1000, start_step=0)

    

  0%|          | 1/1000 [00:18<5:13:51, 18.85s/it]

tensor(-0., device='cuda:0') tensor(-0., device='cuda:0')


  0%|          | 2/1000 [00:28<3:43:00, 13.41s/it]

tensor(-0., device='cuda:0') tensor(-0., device='cuda:0')
tensor(-0., device='cuda:0') tensor(-0., device='cuda:0')


  0%|          | 3/1000 [00:39<3:24:14, 12.29s/it]