### MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf

from diffusers import DDIMScheduler

from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import AttentionBase
from masactrl.masactrl_utils import regiter_attention_editor_diffusers

from torchvision.utils import save_image
from torchvision.io import read_image
from pytorch_lightning import seed_everything

# Pin CUDA work to GPU 0. Only do so when CUDA is actually present: on a
# CPU-only machine the unguarded call fails, which would stop the notebook in
# its first cell even though the model-construction cell falls back to the CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(0)  # set the GPU device

#### Model Construction

Load the DreamBooth-tuned Stable Diffusion model.

In [3]:
# DreamBooth (DB) tuned checkpoint.
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Location of the checkpoint. The default is the path this notebook was
# originally run with; set the MASACTRL_MODEL_PATH environment variable to
# point at a different checkpoint without editing the notebook.
model_path = os.environ.get(
    "MASACTRL_MODEL_PATH",
    "/root/DreamMatcher/concept_models/dreambooth/dreambooth-concept-original-prior-230821/black_cat/checkpoints/diffusers",
)

# The DDIM scheduler configuration is read from the checkpoint's "scheduler"
# subfolder.
scheduler = DDIMScheduler.from_pretrained(
    model_path, subfolder="scheduler", low_cpu_mem_usage=False
)

# Build the pipeline around that scheduler, with the safety checker disabled,
# and move it to the selected device.
model = MasaCtrlPipeline.from_pretrained(
    model_path,
    low_cpu_mem_usage=False,
    scheduler=scheduler,
    safety_checker=None,
).to(device)

The config attributes {'rescale_betas_zero_snr': False, 'timestep_spacing': 'leading'} were passed to DDIMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
The config attributes {'addition_embed_type': None, 'addition_embed_type_num_heads': 64, 'addition_time_embed_dim': None, 'attention_type': 'default', 'encoder_hid_dim_type': None, 'num_attention_heads': None, 'time_embedding_dim': None, 'transformer_layers_per_block': 1} were passed to UNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.


The config attributes {'force_upcast': True} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.
Some weights of the model checkpoint at /root/DreamMatcher/concept_models/dreambooth/dreambooth-concept-original-prior-230821/black_cat/checkpoints/diffusers/vae were not used when initializing AutoencoderKL: ['encoder.mid_block.attentions.0.to_out.0.bias', 'decoder.mid_block.attentions.0.to_q.weight', 'decoder.mid_block.attentions.0.to_k.weight', 'encoder.mid_block.attentions.0.to_q.weight', 'encoder.mid_block.attentions.0.to_out.0.weight', 'decoder.mid_block.attentions.0.to_v.weight', 'decoder.mid_block.attentions.0.to_k.bias', 'encoder.mid_block.attentions.0.to_k.weight', 'encoder.mid_block.attentions.0.to_v.bias', 'encoder.mid_block.attentions.0.to_k.bias', 'encoder.mid_block.attentions.0.to_q.bias', 'decoder.mid_block.attentions.0.to_q.bias', 'decoder.mid_block.attentions.0.to_v.bias', 'encoder.mid_block.attentions.

#### Consistent synthesis with MasaCtrl

In [10]:
from masactrl.masactrl import MutualSelfAttentionControl

# Every experiment pairs the same source prompt with a different target prompt.
# Each entry of `prompts_list` is [source prompt, target prompt].
SOURCE_PROMPT = "sks black_cat, casual, outdoors, laying"
TARGET_PROMPTS = [
    "sks black_cat, casual, outdoors, standing",
    "cat, casual, outdoors, standing",
    "cat, standing",
    "fat furry white cat, standing",
    "cat jumping in the sky",
    "cat running in the garden",
    "cat in the box",
]
prompts_list = [[SOURCE_PROMPT, target_prompt] for target_prompt in TARGET_PROMPTS]

# (STEP, LAYPER) pairs handed to MutualSelfAttentionControl: the starting
# denoising step and the starting U-Net layer of the attention control.
start_step_layer = [
    (2, 10),
    (4, 10),
    (6, 10)
]

# The seed is re-applied before every prompt pair so that all runs start from
# the same initial noise map.
seed = 42
# Guidance scale used for both the baseline and the MasaCtrl run.
GUIDANCE_SCALE = 7.5

for STEP, LAYPER in start_step_layer:
    out_dir = f"./workdir/masactrl_db_{STEP}_{LAYPER}/"
    os.makedirs(out_dir, exist_ok=True)
    print(f"Starting step : {STEP}, Starting layer : {LAYPER}")

    for prompts in prompts_list:
        seed_everything(seed)

        # Number the sample folder after the current entry count of out_dir,
        # then skip forward past any index that is already taken. The entry
        # count alone can point at an existing folder (e.g. after an earlier
        # sample folder was removed), and exist_ok=True would then silently
        # overwrite that folder's images.
        sample_count = len(os.listdir(out_dir))
        sample_dir = os.path.join(out_dir, f"sample_{sample_count}")
        while os.path.exists(sample_dir):
            sample_count += 1
            sample_dir = os.path.join(out_dir, f"sample_{sample_count}")
        os.makedirs(sample_dir, exist_ok=True)

        # initialize the noise map: one sample, shared by every prompt in the pair
        start_code = torch.randn([1, 4, 64, 64], device=device)
        start_code = start_code.expand(len(prompts), -1, -1, -1)

        # baseline: synthesize both prompts without MasaCtrl
        editor = AttentionBase()
        regiter_attention_editor_diffusers(model, editor)
        image_ori = model(prompts, latents=start_code, guidance_scale=GUIDANCE_SCALE)

        # hijack the attention module with mutual self-attention control
        editor = MutualSelfAttentionControl(STEP, LAYPER)
        regiter_attention_editor_diffusers(model, editor)

        # synthesize again with MasaCtrl active; keep only the last image
        # (the one for the target prompt)
        image_masactrl = model(prompts, latents=start_code, guidance_scale=GUIDANCE_SCALE)[-1:]

        # save the synthesized images: the combined grid first, then each
        # image on its own (source, target without MasaCtrl, target with MasaCtrl)
        out_image = torch.cat([image_ori, image_masactrl], dim=0)
        save_image(out_image, os.path.join(sample_dir, f"all_step{STEP}_layer{LAYPER}.png"))
        save_image(out_image[0], os.path.join(sample_dir, f"source_step{STEP}_layer{LAYPER}.png"))
        save_image(out_image[1], os.path.join(sample_dir, f"without_step{STEP}_layer{LAYPER}.png"))
        save_image(out_image[2], os.path.join(sample_dir, f"masactrl_step{STEP}_layer{LAYPER}.png"))

        print("Syntheiszed images are saved in", sample_dir)

Seed set to 42


Starting step : 2, Starting layer : 10
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:10<00:00,  4.80it/s]


MasaCtrl at denoising steps:  [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]
Seed set to 42


Syntheiszed images are saved in ./workdir/masactrl_db_2_10/sample_0
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:10<00:00,  4.77it/s]


MasaCtrl at denoising steps:  [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]
Seed set to 42


Syntheiszed images are saved in ./workdir/masactrl_db_2_10/sample_1
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:10<00:00,  4.76it/s]


MasaCtrl at denoising steps:  [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]
Seed set to 42


Syntheiszed images are saved in ./workdir/masactrl_db_2_10/sample_2
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:10<00:00,  4.75it/s]


MasaCtrl at denoising steps:  [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]
Seed set to 42


Syntheiszed images are saved in ./workdir/masactrl_db_2_10/sample_3
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:10<00:00,  4.75it/s]


MasaCtrl at denoising steps:  [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]
Seed set to 42


Syntheiszed images are saved in ./workdir/masactrl_db_2_10/sample_4
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:10<00:00,  4.75it/s]


MasaCtrl at denoising steps:  [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]
Seed set to 42


Syntheiszed images are saved in ./workdir/masactrl_db_2_10/sample_5
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:10<00:00,  4.75it/s]


MasaCtrl at denoising steps:  [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MasaCtrl at U-Net layers:  [10, 11, 12, 13, 14, 15]
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]


Syntheiszed images are saved in ./workdir/masactrl_db_2_10/sample_6
