## MasaCtrl: Tuning-free <span style="text-decoration: underline"><font color="Tomato">M</font></span>utu<span style="text-decoration: underline"><font color="Tomato">a</font></span>l <span style="text-decoration: underline"><font color="Tomato">S</font></span>elf-<span style="text-decoration: underline"><font color="Tomato">A</font></span>ttention <span style="text-decoration: underline"><font color="Tomato">Control</font></span> for Consistent Image Synthesis and Editing

PyTorch implementation of [MasaCtrl: Tuning-free Mutual Self-Attention Control for **Consistent Image Synthesis and Editing**](https://arxiv.org/abs/2304.08465)

[Mingdeng Cao](https://github.com/ljzycmd),
[Xintao Wang](https://xinntao.github.io/),
[Zhongang Qi](https://scholar.google.com/citations?user=zJvrrusAAAAJ),
[Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ),
[Xiaohu Qie](https://scholar.google.com/citations?user=mk-F69UAAAAJ),
[Yinqiang Zheng](https://scholar.google.com/citations?user=JD-5DKcAAAAJ)

[![arXiv](https://img.shields.io/badge/ArXiv-2304.08465-brightgreen)](https://arxiv.org/abs/2304.08465)
[![Project page](https://img.shields.io/badge/Project-Page-brightgreen)](https://ljzycmd.github.io/projects/MasaCtrl/)
[![demo](https://img.shields.io/badge/Demo-Hugging%20Face-brightgreen)](https://huggingface.co/spaces/TencentARC/MasaCtrl)

---

<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/overview.png">
<i> MasaCtrl enables a range of consistent, non-rigid image synthesis and editing tasks without fine-tuning or optimization. </i>
</div>
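As the title says, the core of MasaCtrl is mutual self-attention. During editing, the target image's self-attention is computed with the target's own queries but with keys and values taken from the source image, so the target can retrieve the source's content and textures while following its own prompt. The NumPy sketch below illustrates only that idea under simplifying assumptions of ours: a single head, no classifier-free guidance split, no masks, and a batch of exactly one source and one target. `attention` and `mutual_self_attention` are hypothetical helpers, not functions from the `masactrl` package.

```python
import numpy as np


def softmax(x, axis=-1):
    # subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v):
    # scaled dot-product attention: softmax(q k^T / sqrt(d)) v
    d = q.shape[-1]
    weights = softmax(q @ np.swapaxes(k, -1, -2) / np.sqrt(d))
    return weights @ v


def mutual_self_attention(q, k, v):
    # q, k, v: arrays of shape (2, N, d); index 0 = source, index 1 = target.
    # Both images query the SOURCE keys/values: the source output equals its
    # ordinary self-attention, while the target borrows the source's content.
    k_src = np.broadcast_to(k[:1], k.shape)
    v_src = np.broadcast_to(v[:1], v.shape)
    return attention(q, k_src, v_src)
```

In the notebook, this kind of replacement is only applied at some denoising steps and attention layers, which are selected by the two arguments passed to `MutualSelfAttentionControl`.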

In [None]:
!pwd

/content


In [None]:
!git clone https://github.com/TencentARC/MasaCtrl.git



In [None]:
%cd MasaCtrl
!pip install -r requirements.txt

/content/MasaCtrl
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf

from diffusers import DDIMScheduler

from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import AttentionBase
from masactrl.masactrl_utils import regiter_attention_editor_diffusers

from torchvision.utils import save_image
from torchvision.io import read_image
from pytorch_lightning import seed_everything

if torch.cuda.is_available():
    torch.cuda.set_device(0)  # select the first GPU; skipped on CPU-only machines

### Model Construction

In [None]:
# You may need to log in with your Hugging Face token to download some models
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_path = "andite/anything-v4.0"
# model_path = "runwayml/stable-diffusion-v1-5"
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
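The scheduler above is configured with what we understand to be the Stable Diffusion v1 training noise schedule. The pure-Python sketch below shows what `beta_schedule="scaled_linear"` computes and which 50 timesteps DDIM visits (50 is the step count shown in the sampler log). It is modeled on our reading of diffusers' `DDIMScheduler` with its default timestep spacing, so treat it as an approximation: the helper names are hypothetical, and the sketch uses float64 while diffusers uses float32 tensors, so the last digits can differ.

```python
import math


def scaled_linear_betas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    # "scaled_linear": interpolate linearly between sqrt(beta_start) and
    # sqrt(beta_end), then square each value
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    n = num_train_timesteps
    return [(lo + (hi - lo) * i / (n - 1)) ** 2 for i in range(n)]


def alphas_cumprod(betas):
    # cumulative product of (1 - beta_t), i.e. alpha-bar_t
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


def ddim_timesteps(num_inference_steps=50, num_train_timesteps=1000, steps_offset=0):
    # evenly spaced training timesteps, visited from noisiest to cleanest
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio + steps_offset for i in reversed(range(num_inference_steps))]


betas = scaled_linear_betas()
acp = alphas_cumprod(betas)
timesteps = ddim_timesteps()  # 980, 960, ..., 20, 0
```

With `set_alpha_to_one=False`, DDIM uses `alphas_cumprod[0]` (about 0.99915 here) instead of 1.0 as the "previous" cumulative alpha at the final step.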


#### Consistent synthesis with MasaCtrl

In [None]:
from masactrl.masactrl import MutualSelfAttentionControl


seed = 42
seed_everything(seed)

out_dir = "./workdir/masactrl_exp/"
os.makedirs(out_dir, exist_ok=True)
sample_count = len(os.listdir(out_dir))
out_dir = os.path.join(out_dir, f"sample_{sample_count}")
os.makedirs(out_dir, exist_ok=True)

prompts = [
    "1boy, casual, outdoors, sitting",  # source prompt
    "1boy, casual, outdoors, standing"  # target prompt
]

# initialize a single noise map and share it across both prompts
start_code = torch.randn([1, 4, 64, 64], device=device)
start_code = start_code.expand(len(prompts), -1, -1, -1)

# synthesize both images without MasaCtrl (AttentionBase keeps the default self-attention)
editor = AttentionBase()
regiter_attention_editor_diffusers(model, editor)
image_ori = model(prompts, latents=start_code, guidance_scale=7.5).cpu()

# synthesize the images again with MasaCtrl
STEP = 4  # apply mutual self-attention from denoising step 4 onward
LAYER = 10  # ... and from attention layer 10 onward

# hijack the attention modules with the MasaCtrl editor
editor = MutualSelfAttentionControl(STEP, LAYER)
regiter_attention_editor_diffusers(model, editor)

# run the pipeline and keep only the last image (the target synthesized with MasaCtrl)
image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5)[-1:].cpu()

# save the results: source, target without MasaCtrl, target with MasaCtrl
out_image = torch.cat([image_ori, image_masactrl], dim=0)
save_image(out_image, os.path.join(out_dir, f"all_step{STEP}_layer{LAYER}.png"))
save_image(out_image[0], os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
save_image(out_image[1], os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
save_image(out_image[2], os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))

print("Synthesized images are saved in", out_dir)

INFO:lightning_fabric.utilities.seed:Global seed set to 42


input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [00:56<00:00,  1.14s/it]


step_idx:  [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
layer_idx:  [10, 11, 12, 13, 14, 15]
input text embeddings : torch.Size([2, 77, 768])
latents shape:  torch.Size([2, 4, 64, 64])


DDIM Sampler: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it]


Synthesized images are saved in ./workdir/masactrl_exp/sample_2
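The printed `step_idx` and `layer_idx` show where MasaCtrl intervenes for the two arguments passed to `MutualSelfAttentionControl` (4 and 10 here): every denoising step from 4 to 49 and attention layers 10 to 15. The sketch below reproduces that selection logic. The helpers `masactrl_schedule` and `use_mutual_attention` are hypothetical; `total_steps=50` matches the 50 DDIM steps in the log, `total_layers=16` is inferred from the printed `layer_idx`, and skipping cross-attention reflects our understanding that MasaCtrl only modifies self-attention. The real class may count layers differently.

```python
def masactrl_schedule(start_step=4, start_layer=10, total_steps=50, total_layers=16):
    # steps and layers at which mutual self-attention replaces self-attention
    step_idx = list(range(start_step, total_steps))
    layer_idx = list(range(start_layer, total_layers))
    return step_idx, layer_idx


def use_mutual_attention(cur_step, cur_layer, is_cross, step_idx, layer_idx):
    # cross-attention (text conditioning) is left untouched in this sketch;
    # self-attention is replaced only inside the selected steps and layers
    if is_cross:
        return False
    return cur_step in step_idx and cur_layer in layer_idx


step_idx, layer_idx = masactrl_schedule()
```

Raising the first argument keeps more early steps unmodified, so the target follows its own prompt more freely before borrowing source content; raising the second restricts the control to fewer layers.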
