In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf

from diffusers import DDIMScheduler

from diffuser_utils import MasaCtrlPipeline
from mcc_utils import AttentionBase
from mcc_utils import register_attention_editor_diffusers

from torchvision.utils import save_image
from torchvision.io import read_image
from pytorch_lightning import seed_everything

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Note that you may add your Hugging Face token to get access to the models
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
model_path = "CompVis/stable-diffusion-v1-4" #trained from "laion-aesthetics v2 5+" and 10% dropping of the text-conditioning
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler, cross_attention_kwargs={"scale": 0.5}).to(device)

Keyword arguments {'cross_attention_kwargs': {'scale': 0.5}} are not expected by MasaCtrlPipeline and will be ignored.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.16.1",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "steps_offset": 0,
  "thresholding": false,
  "trained_betas": null
}
 is outdated. `steps_offset` should be set to 1 instead of 0. Please make sure to update the config accordingly as leaving `steps_offset` might led to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request fo

In [None]:
def check_unet(net):
    for name, net in net.named_children():
        print(net)
check_unet(model.unet)

In [None]:
seed = 42
seed_everything(seed)

prompts = [
    "1 boy",  # source prompt
]

# initialize the noise map
start_code = torch.randn([1, 4, 64, 64], device=device) #downsampling factor is 8, 512 -> 64
start_code = start_code.expand(len(prompts), -1, -1, -1)

# inference the synthesized image without MasaCtrl
editor = AttentionBase()
register_attention_editor_diffusers(model, editor)
image_ori = model(prompts, latents=start_code, guidance_scale=7.5, device = device) #__call__

In [3]:
#txt to img
from mcc import MutualSelfAttentionControl

seed = 42
seed_everything(seed)

out_dir = "./result"
os.makedirs(out_dir, exist_ok=True)
sample_count = len(os.listdir(out_dir))
out_dir = os.path.join(out_dir, f"sample_{sample_count}")
os.makedirs(out_dir, exist_ok=True)

prompts = [
    #"1 boy and 1 car",  # source prompt
    "1 running boy and 1 running horse"  # target prompt
]

# initialize the noise map
start_code = torch.randn([1, 4, 64, 64], device=device) #downsampling factor is 8, 512 -> 64
start_code = start_code.expand(len(prompts), -1, -1, -1)

# inference the synthesized image without MasaCtrl
editor = AttentionBase()
register_attention_editor_diffusers(model, editor)
image_ori = model(prompts, latents=start_code, guidance_scale=7.5, device = device) #__call__

# inference the synthesized image with MasaCtrl
STEP = 4
LAYER = 10

# hijack the attention module
editor = MutualSelfAttentionControl(STEP, LAYER)
register_attention_editor_diffusers(model, editor)

# inference the synthesized image
image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5, device = device)[-1:]

# save the synthesized image
out_image = torch.cat([image_ori, image_masactrl], dim=0)
save_image(out_image, os.path.join(out_dir, f"all_step{STEP}_layer{LAYER}.png"))
save_image(out_image[0], os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
save_image(out_image[1], os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
save_image(out_image[2], os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))

print("Syntheiszed images are saved in", out_dir)

Seed set to 42
  latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)


negative text embeddings added : torch.Size([2, 77, 768])
latents shape:  torch.Size([1, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  7.04it/s]


image shape:  torch.Size([1, 3, 512, 512])
MasaCtrl at denoising steps:  [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]
negative text embeddings added : torch.Size([2, 77, 768])
latents shape:  torch.Size([1, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.33it/s]


image shape:  torch.Size([1, 3, 512, 512])


In [None]:
#img to img

from mcc import MutualSelfAttentionControl
from torchvision.io import read_image

def load_image(image_path, device):
    image = read_image(image_path)
    image = image[:3].unsqueeze_(0).float() / 127.5 - 1.  # [-1, 1]
    image = F.interpolate(image, (512, 512))
    image = image.to(device)
    return image

seed = 42
seed_everything(seed)

out_dir = "./result"
os.makedirs(out_dir, exist_ok=True)
sample_count = len(os.listdir(out_dir))
out_dir = os.path.join(out_dir, f"sample_{sample_count}")
os.makedirs(out_dir, exist_ok=True)

# source image
SOURCE_IMAGE_PATH = "./img/corgi.jpg"
source_image = load_image(SOURCE_IMAGE_PATH, device)

source_prompt = ""
target_prompt = "a photo of a running corgi"
prompts = [source_prompt, target_prompt]

# invert the source image
start_code, latents_list = model.invert(source_image,
                                        source_prompt,
                                        guidance_scale=7.5,
                                        num_inference_steps=50,
                                        return_intermediates=True, device = device)
start_code = start_code.expand(len(prompts), -1, -1, -1)

# results of direct synthesis
editor = AttentionBase()
register_attention_editor_diffusers(model, editor)
image_fixed = model([target_prompt],
                    latents=start_code[-1:],
                    num_inference_steps=50,
                    guidance_scale=7.5, device = device)

# inference the synthesized image with MasaCtrl
STEP = 4
LAYER = 10

# hijack the attention module
editor = MutualSelfAttentionControl(STEP, LAYER)
register_attention_editor_diffusers(model, editor)

# inference the synthesized image
image_masactrl = model(prompts,
                       latents=start_code,
                       guidance_scale=7.5, device = device)
# Note: querying the inversion intermediate features latents_list
# may obtain better reconstruction and editing results
# image_masactrl = model(prompts,
#                        latents=start_code,
#                        guidance_scale=7.5,
#                        ref_intermediate_latents=latents_list)

# save the synthesized image
out_image = torch.cat([source_image * 0.5 + 0.5,
                       image_masactrl[0:1],
                       image_fixed,
                       image_masactrl[-1:]], dim=0)
save_image(out_image, os.path.join(out_dir, f"all_step{STEP}_layer{LAYER}.png"))
save_image(out_image[0], os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
save_image(out_image[1], os.path.join(out_dir, f"reconstructed_source_step{STEP}_layer{LAYER}.png"))
save_image(out_image[2], os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
save_image(out_image[3], os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))

print("Syntheiszed images are saved in", out_dir)