In [None]:
import torch
from diffusers import DDIMScheduler, AutoencoderKL, EulerDiscreteScheduler

from w_plus_adapter import StableDiffusionWPlusPipeline as StableDiffusionPipeline

import w_plus_adapter
from script.utils_direction import *
import os.path as osp


'''
model parameter settings
'''
base_model_path = "runwayml/stable-diffusion-v1-5"
# base_model_path = 'dreamlike-art/dreamlike-anime-1.0'  # anime model
# base_model_path = 'darkstorm2150/Protogen_x3.4_Official_Release'

vae_model_path = "stabilityai/sd-vae-ft-mse"
device = "cuda"
wp_ckpt = './pretrain_models/wplus_adapter.bin'

# Fetch the W+ adapter checkpoint first (if it is missing) so that a download
# failure surfaces before the slow Stable Diffusion weights are loaded.
# Both the existence check and the download target are derived from `wp_ckpt`,
# so changing that one path is enough.
if not osp.exists(wp_ckpt):
    download_url = 'https://github.com/csxmli2016/w-plus-adapter/releases/download/v1/wplus_adapter.bin'
    load_file_from_url(
        url=download_url,
        model_dir=osp.dirname(wp_ckpt) or '.',  # bare filename -> current directory
        progress=True,
        file_name=osp.basename(wp_ckpt),
    )

# Load the replacement VAE in fp16 and pass it to the pipeline instead of the
# base model's own VAE; the safety checker and feature extractor are disabled.
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    vae=vae,
    feature_extractor=None,
    safety_checker=None,
)

# Swap in the Euler scheduler, reusing the base model's scheduler configuration.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

# Attach the W+ adapter weights to the pipeline. Placement on `device` is
# presumably handled inside WPlusAdapter (the pipe is never moved here) — confirm.
wp_model = w_plus_adapter.WPlusAdapter(pipe, wp_ckpt, device)


In [None]:

'''
Parameter settings
'''

which_w_direction = 'interfacegan'  # one of: interfacegan, ganspace, latentdirection
prompt = 'a woman wearing a spacesuit'
seed = 23

e4e_id_path = './test_data/e4e/1.pth'

imgs = []
direction_length = 4
residual_att_scale = 1.0
num_inference_steps = 30
start_embed_step_ratio = 20  # the first 20% of the steps do not use w injection
use_freeu = True
save_results = False  # set True to also write the grid into `save_path`
save_path = '0_results'

# Edit strengths from -direction_length to +direction_length with both ends
# included (range's stop is exclusive, hence the +1), e.g. [-4, -2, 0, 2, 4].
scales = list(range(-direction_length, direction_length + 1, 2))

# Source W+ code; every edit below builds a new tensor and never mutates it.
id_noise0 = torch.load(e4e_id_path, map_location='cpu')

# Choose the attribute for the selected method. Directions that do not depend
# on the edit strength are fetched once here rather than on every iteration.
if which_w_direction == 'interfacegan':
    att = 'smile'  # age, smile
    direction = get_direction_from_interfacegan(att)
elif which_w_direction == 'ganspace':
    att = 'lipstick'  # big_smile, face_roundness, lipstick, overexposed, short_face, smile
elif which_w_direction == 'latentdirection':
    att = 'eyes_open'  # eyes_open, emotion_angry, emotion_disgust, emotion_fear, emotion_sad, emotion_happy, emotion_surprise, gender, width
    direction = get_direction_from_latentdirection(att)
else:
    raise ValueError(
        f"unknown which_w_direction {which_w_direction!r}; "
        "expected 'interfacegan', 'ganspace' or 'latentdirection'"
    )

for i in scales:
    if which_w_direction == 'interfacegan':
        id_noise = id_noise0 + i * direction
    elif which_w_direction == 'ganspace':
        # The helper receives the source code and the strength; its result is
        # used as an additive offset.
        id_noise = id_noise0 + get_direction_from_ganspace(att, id_noise0, i)
    else:  # latentdirection
        # Only the first 8 entries along dim 1 are edited; the rest keep the
        # values of id_noise0. Start from a fresh copy every step so the tensor
        # is always bound and edits never leak between scales.
        id_noise = id_noise0.clone()
        id_noise[:, :8, ...] = (id_noise0 + i * direction)[:, :8, ...]

    images = wp_model.generate_idnoise(
        prompt=prompt,
        w=id_noise.repeat(1, 1, 1).to(device, torch.float16),
        scale=residual_att_scale,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        start_embed_step_ratio=start_embed_step_ratio,
        seed=seed,
        use_freeu=use_freeu,
        negative_prompt=None,
    )
    imgs += images

# Presumably one row with one image per scale (rows=1, cols=len(imgs)) — confirm
# against image_grid. `display` is the IPython/Jupyter built-in.
grid = image_grid(imgs, 1, len(imgs))
display(grid)

if save_results:
    import os
    import time

    timestamp = time.strftime("%m-%d_%H-%M", time.localtime())
    os.makedirs(save_path, exist_ok=True)
    grid.save(osp.join(save_path, '{}_{}.jpg'.format(timestamp, start_embed_step_ratio)))
