# Attention couple
latent couple (two shot diffusion)の発展版です。機能自体はlatent coupleと同じですが、latent coupleに比べて高速に生成できる（はず）。さらにクオリティもあがる（といいなあ・・・）。

参考：latent couple (two shot diffusion)

https://note.com/kizamimi/n/nab766a7484fe

In [1]:
#diffusers==0.11以降のcross attentionはよくわかんない
!pip install diffusers==0.10.2 transformers accelerate safetensors

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting diffusers==0.10.2
  Downloading diffusers-0.10.2-py3-none-any.whl (503 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m503.1/503.1 KB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m82.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.17.0-py3-none-any.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.8/212.8 KB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors
  Downloading safetensors-0.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub>=0.10

In [2]:
!pip install xformers triton

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting xformers
  Downloading xformers-0.0.16-cp39-cp39-manylinux2014_x86_64.whl (50.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.9/50.9 MB[0m [31m28.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting triton
  Downloading triton-2.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (63.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.3/63.3 MB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
Collecting pyre-extensions==0.0.23
  Downloading pyre_extensions-0.0.23-py3-none-any.whl (11 kB)
Collecting typing-inspect
  Downloading typing_inspect-0.8.0-py3-none-any.whl (8.7 kB)
Collecting lit
  Downloading lit-15.0.7.tar.gz (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.3/132.3 KB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mypy-extensions

In [3]:
#実装をわかりやすくするためのシンプルな画像生成パイプラインです
#()[]やトークン長の拡張・CLIP skipはありません。
#samplerもDDIMのみ

import torch
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm


class AttentionCoupleGenerator:
    """Minimal DDIM text-to-image pipeline with "attention couple" regional prompting.

    Every cross-attention module (``attn2``) of the UNet is monkey-patched. While
    coupling is active, each conditional query attends to three prompts (base, left,
    right). The left/right attention outputs are stitched together along the width
    axis: the left half comes from the left prompt and the right half from the right
    prompt. The result is blended with the base-prompt output using ``twoshot_weight``.

    The patched forward targets the ``CrossAttention`` API of diffusers==0.10.2
    (``reshape_heads_to_batch_dim``, ``_attention``, ``_sliced_attention``, ...).
    """

    def __init__(self, model_id, dtype=torch.float32, device="cuda"):
        """Load tokenizer, text encoder, VAE, UNet and DDIM scheduler from ``model_id``.

        Args:
            model_id: Hugging Face Hub id or local path of a diffusers-format checkpoint.
            dtype: torch dtype the text encoder, VAE and UNet are cast to.
            device: device the models are moved to.
        """
        self.tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder='tokenizer')
        self.text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder='text_encoder').eval().to(device, dtype=dtype)

        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder='vae').eval().to(device, dtype=dtype)
        self.vae.enable_slicing()

        self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder='unet').eval().to(device, dtype=dtype)
        # Requires xformers to be installed.
        self.unet.set_use_memory_efficient_attention_xformers(True)

        self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

        self.dtype = dtype
        self.device = device

        # Coupling state read by the patched attention modules. It is initialised here so
        # the hooks act as plain cross-attention whenever __call__ is not running.
        self.couple = False
        self.twoshot_weight = 0.0
        self.height = self.width = self.pixels = 0

        self.hook_forwards(self.unet)  # patch the cross-attention modules

    def encode_prompts(self, prompts):
        """Tokenize ``prompts`` and return the text encoder's last hidden state.

        Prompts are padded to ``tokenizer.model_max_length``, as the reference diffusers
        Stable Diffusion pipeline does. The UNet therefore always sees the context length
        it was trained with, independent of how long the prompts are.

        Args:
            prompts: list of prompt strings.

        Returns:
            The text encoder's ``last_hidden_state`` on ``self.device`` with ``self.dtype``.
        """
        with torch.no_grad():
            tokens = self.tokenizer(
                prompts,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors='pt',
            ).input_ids.to(self.device)
            embs = self.text_encoder(tokens).last_hidden_state.to(self.device, dtype=self.dtype)
        return embs

    def decode_latents(self, latents):
        """Decode VAE latents into a list of PIL images (uint8 RGB)."""
        # Undo the latent scaling. diffusers 0.10.2 configs do not store it, so fall back
        # to the Stable Diffusion value 0.18215 used originally.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        images = (images * 255).round().astype("uint8")
        return [Image.fromarray(image) for image in images]

    def __call__(self, prompts, negative_prompt, batch_size = 4, height:int = 512, width:int = 512, guidance_scale:float = 7.0, num_inference_steps:int = 50, twoshot_weight=0.8, end_steps:float = 1, generator=None):
        """Generate ``batch_size`` images with attention couple.

        Args:
            prompts: exactly three prompts in the order [base, left, right].
            negative_prompt: negative prompt used for classifier-free guidance.
            batch_size: number of images generated at once.
            height, width: output size in pixels.
            guidance_scale: classifier-free guidance scale.
            num_inference_steps: number of DDIM steps.
            twoshot_weight: weight of the stitched left/right attention output.
                0 uses only the base prompt; 1 ignores the base prompt.
            end_steps: fraction (0-1) of the steps that use attention couple.
                1 couples every step; 0 is plain generation with the base prompt.
            generator: optional CPU ``torch.Generator`` for the initial noise. The noise
                is drawn on the CPU, so pass a CPU generator for reproducible results.

        Returns:
            List of PIL images.

        Raises:
            ValueError: if ``prompts`` does not contain exactly three prompts.
        """
        if len(prompts) != 3:
            raise ValueError(f"prompts must be [base, left, right] (3 prompts), got {len(prompts)}")

        self.twoshot_weight = twoshot_weight
        # Layout [base*b, left*b, right*b, neg*b] so .chunk(4) splits the rows by role.
        all_prompts = [prompt for prompt in prompts for _ in range(batch_size)]
        all_prompts.extend([negative_prompt] * batch_size)

        text_embs = self.encode_prompts(all_prompts)

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # Initial noise is drawn on the CPU, then moved to the target device.
        latents = torch.randn(batch_size, 4, height // 8, width // 8, generator=generator).to(self.device, dtype=self.dtype)
        latents = latents * self.scheduler.init_noise_sigma

        # Latent size, read by the attention hooks to recover the spatial layout.
        self.height = height // 8
        self.width = width // 8
        self.pixels = self.height * self.width

        # Coupling is used for steps i < couple_until.
        couple_until = num_inference_steps * end_steps

        self.couple = True
        try:
            for i, t in enumerate(tqdm(timesteps, desc="Total Steps", leave=False)):
                # Stop coupling: keep only the base (cond) and negative embeddings.
                if self.couple and i >= couple_until:
                    cond, _, _, negative = text_embs.chunk(4)
                    text_embs = torch.cat([cond, negative])
                    self.couple = False

                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                with torch.no_grad():
                    noise_pred = self.unet(sample=latent_model_input, timestep=t, encoder_hidden_states=text_embs).sample

                # Classifier-free guidance using the negative prompt as the unconditional branch.
                noise_pred_text, noise_pred_negative = noise_pred.chunk(2)
                noise_pred = noise_pred_negative + guidance_scale * (noise_pred_text - noise_pred_negative)

                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
        finally:
            # Leave the patched UNet in plain cross-attention mode after the call.
            self.couple = False

        return self.decode_latents(latents)

    def _feature_map_size(self, sequence_length):
        """Return (height, width) of the UNet feature map that holds ``sequence_length`` tokens.

        Starts from the latent size and halves with ceiling rounding until the area
        matches. This assumes each UNet downsampler maps n -> ceil(n / 2), as a stride-2,
        padding-1, 3x3 conv does; verify this for non-standard UNet configs.

        Raises:
            ValueError: if no level derived from the latent size has that many tokens.
        """
        height, width = self.height, self.width
        while height * width > sequence_length and (height > 1 or width > 1):
            height, width = (height + 1) // 2, (width + 1) // 2
        if height * width != sequence_length:
            raise ValueError(
                f"cannot map {sequence_length} tokens onto a feature map derived from a "
                f"{self.height}x{self.width} latent"
            )
        return height, width

    def hook_forward(self, module):
        """Return a replacement ``forward`` for a diffusers==0.10.2 ``CrossAttention`` module."""
        def forward(hidden_states, context=None, mask=None):
            batch_size, sequence_length, _ = hidden_states.shape

            query = module.to_q(hidden_states)

            if self.couple:
                # (q_cond, q_uncond) -> (q_cond, q_cond, q_cond, q_uncond) so that each
                # conditional query attends to the base, left and right prompts.
                query_cond, query_uncond = query.chunk(2)
                query = torch.cat([query_cond, query_cond, query_cond, query_uncond])

            context = context if context is not None else hidden_states
            key = module.to_k(context)
            value = module.to_v(context)

            dim = query.shape[-1]

            query = module.reshape_heads_to_batch_dim(query)
            key = module.reshape_heads_to_batch_dim(key)
            value = module.reshape_heads_to_batch_dim(value)

            # `mask` is accepted for signature compatibility but, as in the upstream
            # diffusers 0.10.2 implementation, it is not used.
            if module._use_memory_efficient_attention_xformers:
                hidden_states = module._memory_efficient_attention_xformers(query, key, value)
                # Some xformers versions return fp32; cast back to the input dtype.
                hidden_states = hidden_states.to(query.dtype)
            else:
                if module._slice_size is None or query.shape[0] // module._slice_size == 1:
                    hidden_states = module._attention(query, key, value)
                else:
                    hidden_states = module._sliced_attention(query, key, value, sequence_length, dim)

            # cond * (1 - w) + [left | right] * w
            if self.couple:
                height, width = self._feature_map_size(sequence_length)
                cond, left, right, uncond = hidden_states.chunk(4)

                # Back to image layout (batch, height, width, channels).
                left = left.reshape(left.shape[0], height, width, left.shape[2])
                right = right.reshape(right.shape[0], height, width, right.shape[2])

                # Left half from the left prompt, right half from the right prompt.
                half = width // 2
                couple = torch.cat([left[:, :, :half, :], right[:, :, half:, :]], dim=2)
                couple = couple.reshape(cond.shape[0], -1, cond.shape[2])

                cond = cond * (1 - self.twoshot_weight) + couple * self.twoshot_weight
                hidden_states = torch.cat([cond, uncond])

            hidden_states = module.to_out[0](hidden_states)  # linear projection
            hidden_states = module.to_out[1](hidden_states)  # dropout
            return hidden_states

        return forward

    def hook_forwards(self, root_module: torch.nn.Module):
        """Patch ``forward`` of every cross-attention (``attn2``) ``CrossAttention`` module."""
        for name, module in root_module.named_modules():
            if "attn2" in name and module.__class__.__name__ == "CrossAttention":
                module.forward = self.hook_forward(module)

    

In [5]:
# Hugging Face Hub id of the checkpoint to load. The author only verified SD2.x-based
# models; other Stable Diffusion versions probably behave the same.
model_id = "hakurei/waifu-diffusion"
pipe = AttentionCoupleGenerator(
    model_id,
    dtype=torch.float16,  # half precision to reduce GPU memory use
)

Downloading (…)tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

Downloading (…)tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/819 [00:00<?, ?B/s]

Downloading (…)_encoder/config.json:   0%|          | 0.00/620 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

Downloading (…)on_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]

Downloading (…)main/vae/config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

Downloading (…)on_pytorch_model.bin:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

Downloading (…)ain/unet/config.json:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

Downloading (…)cheduler_config.json:   0%|          | 0.00/341 [00:00<?, ?B/s]

In [17]:
# Prompts in the order expected by AttentionCoupleGenerator: [base, left, right].
# All three share the same leading tags; only the suffix differs per region.
prompt = [
    "masterpiece, best quality, 2girl anime, touhou" + suffix
    for suffix in (" ", ", hakurei reimu", ", kirisame marisa")
]
negative_prompt = ", ".join([
    "worst quality", "low quality", "medium quality", "deleted", "lowres", "comic",
    "bad anatomy", "bad hands", "text", "error", "missing fingers", "extra digit",
    "fewer digits", "cropped", "jpeg artifacts", "signature", "watermark", "username",
    "blurry",
])

In [20]:
# Seed the global RNG so the initial latent noise (drawn with torch.randn on the CPU
# inside the pipeline) and therefore the generated images are reproducible.
torch.manual_seed(0)
images = pipe(
    prompt,
    negative_prompt,
    batch_size=4,  # number of images generated at once
    num_inference_steps=50,  # number of sampling steps
    height=768,
    width=768,
    end_steps=0.7,  # fraction (0-1) of steps using attention couple: 1 = every step, 0 = plain generation
    twoshot_weight=0.8,  # weight of the left/right prompts: 0 = base prompt only, 1 = ignore the base prompt
)

Total Steps:   0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
import matplotlib.pyplot as plt
import math

# Show the generated images in a grid with N_COLS images per row.
N_COLS = 4
n_rows = max(1, math.ceil(len(images) / N_COLS))
# Scale the figure height with the number of rows instead of a fixed 20x20 canvas.
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(20, 5 * n_rows), squeeze=False)
for ax in axes.flat:
    ax.axis('off')  # also hides the empty cells of a partially filled last row
for ax, image in zip(axes.flat, images):
    ax.imshow(np.array(image))
plt.show()