In [None]:
import sys
import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from attribute_control import EmbeddingDelta
from attribute_control.model import SDXL
from attribute_control.prompt_utils import get_mask, get_mask_regex
from diffusers import StableDiffusionXLPipeline

# Allow float32 matmuls to use faster, lower-precision internal algorithms
# (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision('high')

DEVICE = 'cuda:0'
DTYPE = torch.float16

# Options handed to the SDXL wrapper as `pipe_kwargs`; presumably forwarded to
# the diffusers pipeline loader (half-precision dtype, fp16 weight variant,
# safetensors files) — confirm against attribute_control.model.SDXL.
_pipe_kwargs = {
    'torch_dtype': DTYPE,
    'variant': 'fp16',
    'use_safetensors': True,
}

model = SDXL(
    pipeline_type='diffusers.StableDiffusionXLPipeline',
    model_name='stabilityai/stable-diffusion-xl-base-1.0',
    pipe_kwargs=_pipe_kwargs,
    device=DEVICE,
)

# Seed used to re-seed the generator before every sampling call.
seed = 42
# Passed as `delay_relative` to model.sample_delayed; presumably the fraction of
# denoising steps run on the unmodified embedding before the edit kicks in —
# TODO confirm.
delay_relative = 0.20

In [None]:
# Output folder, presumably for images sampled from the unmodified prompt.
gen_out_folder = '/content/drive/MyDrive/crime_image_generator/Data/crime/xlgen/'
# Output folder, presumably for attribute-edited images.
edit_out_folder = '/content/drive/MyDrive/crime_image_generator/Data/crime/xledit/'
# Folder holding the source images.
src_folder = '/content/drive/MyDrive/crime_image_generator/Data/crime/images/'

# Names of the 40 facial attributes (these appear to be the CelebA attributes in
# annotation order). Position in this list is the index into each per-image
# attribute vector, so the order must not change.
attrs_40 = """
    5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald
    Bangs Big_Lips Big_Nose Black_Hair Blond_Hair
    Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin
    Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones
    Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard
    Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks
    Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings
    Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
""".split()

# Attributes that have a pretrained embedding delta available.
delta_attrs = [
    'Bald',
    'Young',
    'Pale_Skin',
    'Heavy_Makeup',
    'Smiling',
    'Wavy_Hair',
    'Chubby',
]

def get_delta(path, device=None):
    """Load a pretrained EmbeddingDelta checkpoint and move it to a device.

    Args:
        path: Path to a file written with ``torch.save`` whose ``'delta'``
            entry holds the delta's state dict.
        device: Target device for the loaded delta. Defaults to the
            module-level ``DEVICE`` (read at call time).

    Returns:
        An ``EmbeddingDelta`` sized for ``model.dims`` with the checkpoint's
        weights, placed on ``device``.
    """
    delta = EmbeddingDelta(model.dims)
    # Deserialize onto the CPU. Without map_location, tensors saved from a GPU
    # are restored to that same GPU index, which fails on CPU-only hosts or when
    # that index does not exist; the module is moved to the target device below.
    state_dict = torch.load(path, map_location='cpu')
    delta.load_state_dict(state_dict['delta'])
    return delta.to(DEVICE if device is None else device)

# Attribute name -> checkpoint of the pretrained delta used to steer it.
# Several are the closest available delta rather than an exact match
# (e.g. 'Young' -> age, 'Chubby' -> width, 'Wavy_Hair' -> curly hair).
_DELTA_PATHS = {
    'Bald': './pretrained_deltas/person_bald.pt',
    'Young': './pretrained_deltas/person_age.pt',
    'Pale_Skin': './pretrained_deltas/person_pale.pt',
    'Heavy_Makeup': './pretrained_deltas/person_makeup.pt',
    'Smiling': './pretrained_deltas/person_smile.pt',
    'Wavy_Hair': './pretrained_deltas/person_curly_hair.pt',
    'Chubby': './pretrained_deltas/person_width.pt',
}

# Load every delta once, in the table's order, keyed by attribute name.
deltas = {attr_name: get_delta(ckpt_path) for attr_name, ckpt_path in _DELTA_PATHS.items()}

def apply_all_deltas(attr, emb, characterwise_mask, delta_names):
    """Apply each named pretrained delta to a prompt embedding, one after another.

    Args:
        attr: Per-image attribute scores, indexed in ``attrs_40`` order.
        emb: The prompt embedding to edit.
        characterwise_mask: Mask forwarded unchanged to every delta's ``apply``.
        delta_names: Attribute names (keys of ``deltas``) to apply, in order.

    Returns:
        The embedding after all deltas have been applied in sequence, or
        ``emb`` itself when ``delta_names`` is empty.
    """
    edited = emb
    for name in delta_names:
        # Affine map of the score: 0 -> -2, 0.5 -> 0, 1 -> +2
        # (assumes scores lie in [0, 1] — TODO confirm).
        raw = attr[attrs_40.index(name)] * 4 - 2
        # NOTE(review): the 'Young' strength is negated, presumably because the
        # pretrained age delta points toward older faces — confirm.
        strength = -raw if name == "Young" else raw
        edited = deltas[name].apply(edited, characterwise_mask, strength)
    return edited

i = 0

# NOTE(review): file names are enumerated from the edit *output* folder, while
# the matching source image is opened from src_folder; confirm this is intended.
# os.listdir returns entries in arbitrary order, so sort them to make index `i`
# refer to the same file on every run.
file = sorted(os.listdir(edit_out_folder))[i]

# Per-image continuous attribute vectors and captions, keyed by image file name.
# Use context managers so the file handles are closed.
with open('/content/drive/MyDrive/crime_image_generator/Data/crime/attrs_cont.json', encoding='utf-8') as fh:
    attrs = json.load(fh)
with open('/content/drive/MyDrive/crime_image_generator/Data/crime/captions.json', encoding='utf-8') as fh:
    captions = json.load(fh)

attr = attrs[file]
prompt = "a portrait photo with high facial detailed. " + captions[file].lower()
# Target the subject word: 'woman' when it appears in the prompt, otherwise 'man'.
pattern_target = r'\b(woman)\b' if 'woman' in prompt else r'\b(man)\b'
# Character-wise mask of the regex match over the prompt; presumably confines the
# deltas to the subject token — confirm against get_mask_regex.
characterwise_mask = get_mask_regex(prompt, pattern_target)
emb = model.embed_prompt(prompt)
src_img = Image.open(os.path.join(src_folder, file))


def _sample_seeded(embs, embs_unmodified):
    """Run one delayed SDXL sample with fixed settings and return the first image.

    The generator is re-seeded with `seed` on every call, so the original and
    edited images are sampled from the same seed.
    """
    return model.sample_delayed(
        embs=embs,
        embs_unmodified=embs_unmodified,
        embs_neg=[None],
        delay_relative=delay_relative,
        generator=torch.manual_seed(seed),
        guidance_scale=7.5,
        num_inference_steps=30,
    )[0]


gen_img = _sample_seeded([emb], [emb])
# Previously delta_names=[] applied no deltas, making edit_img a redundant
# re-render of gen_img; apply every attribute that has a pretrained delta.
edit_img = _sample_seeded(
    [apply_all_deltas(attr, emb, characterwise_mask, delta_names=delta_attrs)],
    [emb],
)
