In [None]:

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL

import w_plus_adapter
from script.utils_direction import *
import os.path as osp


# ---------------------------------------------------------------------------
# Model parameter settings
# ---------------------------------------------------------------------------
# Base Stable Diffusion checkpoint. Other checkpoints that can be swapped in:
#   'dreamlike-art/dreamlike-anime-1.0'              (anime model)
#   'darkstorm2150/Protogen_x3.4_Official_Release'
# NOTE(review): the runwayml Hub repo has reportedly been taken down; if loading
# fails, the mirror 'stable-diffusion-v1-5/stable-diffusion-v1-5' may be needed — confirm.
base_model_path = "runwayml/stable-diffusion-v1-5"
vae_model_path = "stabilityai/sd-vae-ft-mse"
# Target device name; note it is not applied to `pipe` in this cell (presumably
# the adapter built in a later cell moves the pipeline — TODO confirm).
device = "cuda"

# DDIM sampler: scaled-linear beta schedule over 1000 training timesteps,
# no sample clipping, and a step offset of 1.
noise_scheduler = DDIMScheduler(
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
)

# Stand-alone VAE, cast to half precision so it matches the fp16 pipeline below.
vae = AutoencoderKL.from_pretrained(vae_model_path)
vae = vae.to(dtype=torch.float16)

# fp16 text-to-image pipeline wired to the scheduler and VAE above; the safety
# checker and its feature extractor are disabled.
pipe = StableDiffusionPipeline.from_pretrained(
    base_model_path,
    scheduler=noise_scheduler,
    vae=vae,
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
)

# Fetch the released W+ adapter weights only when no local copy exists yet.
if not osp.exists('./pretrain_models/wplus_adapter.bin'):
    download_url = 'https://github.com/csxmli2016/w-plus-adapter/releases/download/v1/wplus_adapter.bin'
    load_file_from_url(
        url=download_url,
        model_dir='./pretrain_models/',
        progress=True,
        file_name=None,
    )



In [None]:

# ---------------------------------------------------------------------------
# Parameter settings
# ---------------------------------------------------------------------------
wp_ckpt = './pretrain_models/wplus_adapter.bin'
which_w_direction = 'interfacegan'  # one of: interfacegan, ganspace, latentdirection
prompt = 'a woman wearing a red shirt in a garden'
seed = 23

e4e_id_path = './test_data/e4e/1.pth'

imgs = []
direction_length = 4  # the edit strength sweeps from -direction_length to +direction_length
residual_att_scale = 0.9
use_freeu = True

# `range` excludes its stop value, so +1 is needed to include +direction_length;
# with direction_length=4 this gives the symmetric sweep [-4, -2, 0, 2, 4].
scales = list(range(-direction_length, direction_length + 1, 2))

# Fail fast on a typo: an unknown method would otherwise leave `id_noise`
# undefined (or stale from a previous run) and silently render the wrong latent.
SUPPORTED_W_DIRECTIONS = ('interfacegan', 'ganspace', 'latentdirection')
if which_w_direction not in SUPPORTED_W_DIRECTIONS:
    raise ValueError(
        f"which_w_direction must be one of {SUPPORTED_W_DIRECTIONS}, got {which_w_direction!r}"
    )

# Attribute to edit; it must be a name known to the chosen method's helper.
if which_w_direction == 'interfacegan':
    att = 'smile'  # age, smile
elif which_w_direction == 'ganspace':
    att = 'lipstick'  # big_smile, face_roundness, lipstick, overexposed, short_face, smile
else:  # latentdirection
    att = 'eyes_open'  # eyes_open, emotion_angry, emotion_disgust, emotion_fear, emotion_sad, emotion_happy, emotion_surprise, gender, width

wp_model = w_plus_adapter.WPlusAdapter(pipe, wp_ckpt, device)
# Base (unedited) e4e W+ latent; every edit below starts from this tensor.
# NOTE(review): consider torch.load(..., weights_only=True) if the installed torch supports it.
id_noise0 = torch.load(e4e_id_path, map_location='cpu')

# These direction vectors do not depend on the edit strength, so look them up once.
# (The GANSpace helper takes the latent and strength itself, so it stays in the loop.)
if which_w_direction == 'interfacegan':
    direction = get_direction_from_interfacegan(att)
elif which_w_direction == 'latentdirection':
    direction = get_direction_from_latentdirection(att)

for i in scales:
    if which_w_direction == 'interfacegan':
        id_noise = id_noise0 + i * direction
    elif which_w_direction == 'ganspace':
        # The helper returns an offset that is added to the base latent.
        id_noise = id_noise0 + get_direction_from_ganspace(att, id_noise0, i)
    else:  # latentdirection
        # Only the first 8 W+ layers are edited. Start from a fresh copy so the
        # remaining layers come from the unedited latent, not from a previous
        # iteration or run.
        id_noise = id_noise0.clone()
        id_noise[:, :8, ...] = (id_noise0 + i * direction)[:, :8, ...]

    images = wp_model.generate_idnoise(prompt=prompt, w=id_noise.repeat(1, 1, 1).to(device, torch.float16), scale=residual_att_scale, num_samples=1, num_inference_steps=50, seed=seed, use_freeu=use_freeu, negative_prompt=None)

    imgs += images


# One row, one column per edit strength, ordered from most negative to most positive.
grid = image_grid(imgs, 1, len(imgs))
display(grid)
