## Copyright 2022 Google LLC. Double-click for license information.

In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 버전 관리를 위한 코드들

In [None]:
# Start from a clean slate: remove any preinstalled copies of these packages
!pip -q uninstall -y diffusers transformers tokenizers accelerate xformers

# Keep diffusers pinned to the Prompt-to-Prompt-era release
!pip -q install --no-deps "diffusers==0.3.0" ftfy opencv-python ipywidgets

# A transformers/tokenizers pair that is compatible with Python 3.12
!pip -q install "transformers==4.38.2" "tokenizers==0.15.2"


In [None]:
# Sanity check: print the versions that were actually imported after the pinned install
import torch, diffusers, transformers, tokenizers
print("torch:", torch.__version__)
print("diffusers:", diffusers.__version__)
print("transformers:", transformers.__version__)
print("tokenizers:", tokenizers.__version__)


# Null-text inversion + Editing with Prompt-to-Prompt

In [None]:
from typing import Optional, Union, Tuple, List, Callable, Dict
from tqdm.notebook import tqdm
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch.nn.functional as nnf
import numpy as np
import abc
import ptp_utils
import seq_aligner
import shutil
from torch.optim.adam import Adam
from PIL import Image

To load Stable Diffusion with Diffusers, follow the instructions at https://huggingface.co/blog/stable_diffusion and update MY_TOKEN with your token.

In [None]:
# --- Assemble the pipeline by loading every module manually
# --- (compatible with diffusers==0.3.0; all modules are spelled out explicitly) ---

import torch
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

from transformers import CLIPTokenizer, CLIPTextModel, CLIPFeatureExtractor  # older-version compatible names
from diffusers import (
    AutoencoderKL, UNet2DConditionModel,
    PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

repo = "CompVis/stable-diffusion-v1-4"

# 1) CLIP tokenizer / text encoder
tokenizer    = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# 2) UNet / VAE / (default) PNDM scheduler - loaded directly from the repo sub-folders
unet      = UNet2DConditionModel.from_pretrained(repo, subfolder="unet")
vae       = AutoencoderKL.from_pretrained(repo, subfolder="vae")
scheduler = PNDMScheduler.from_config(repo, subfolder="scheduler")

# 3) Load safety_checker / feature_extractor as real objects
#    (diffusers==0.3.0 does not accept None or omission for these)
safety_checker    = StableDiffusionSafetyChecker.from_pretrained(repo, subfolder="safety_checker")
feature_extractor = CLIPFeatureExtractor.from_pretrained(repo, subfolder="feature_extractor")

# 4) Assemble the pipeline
ldm_stable = StableDiffusionPipeline(
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
    scheduler=scheduler,              # PNDM for now; replaced below
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
).to(device)

# 5) Swap in a DDIM scheduler with the same parameters as the original notebook.
# The parameters have to go through the constructor: the scheduler derives its
# beta / alpha tables and its config from them when it is built, so assigning
# attributes such as `ddim.beta_start = ...` on an existing instance would not
# change how it samples.
ddim = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

ldm_stable.scheduler = ddim  # final replacement

# (optional) turn off the xformers optimisation (silently skipped when unavailable)
try:
    ldm_stable.disable_xformers_memory_efficient_attention()
except AttributeError:
    pass

# Keep the variable name used by the original notebook
tokenizer = ldm_stable.tokenizer

print("Loaded:", type(ldm_stable.unet).__name__, "| Scheduler:", type(ldm_stable.scheduler).__name__)
# Report the values the scheduler was actually built with (stored on its config)
print("beta_start/end:", ldm_stable.scheduler.config.beta_start, ldm_stable.scheduler.config.beta_end)
print("Device:", device)


In [None]:
# Number of DDIM steps; read by the inversion loops, the controllers and run_and_display
NUM_DDIM_STEPS = 50
# Read by AttentionControl to decide whether a step's leading attention calls are skipped
LOW_RESOURCE = False
# Classifier-free guidance scale used during null-text optimisation and sampling
GUIDANCE_SCALE = 7.5
# Size of the token axis used for the per-word masks built by LocalBlend
MAX_NUM_WORDS = 77

## Prompt-to-Prompt code

In [None]:

class LocalBlend:
    """Restrict an edit to the regions attended to by selected prompt words.

    Word masks are built once at construction. When called on the latents of a
    step, the stored cross-attention maps are turned into a spatial mask, and
    every sample is reset to the first sample's latent outside that mask.
    """

    def __init__(self, prompts: List[str], words: List[List[str]], substruct_words=None, start_blend=0.2, th=(.3, .3)):
        """Build the per-prompt word masks.

        Args:
            prompts: One prompt per sample in the batch.
            words: For each prompt, a word or list of words whose attention defines the blend region.
            substruct_words: Optional words (same layout as ``words``) whose region is removed from the mask.
            start_blend: Fraction of ``NUM_DDIM_STEPS`` after which blending starts.
            th: Thresholds on the normalised attention; ``th[0]`` is used for ``words``
                and ``th[1]`` for ``substruct_words``.
        """
        self.alpha_layers = self._build_word_layers(prompts, words).to(device)
        if substruct_words is not None:
            self.substruct_layers = self._build_word_layers(prompts, substruct_words).to(device)
        else:
            self.substruct_layers = None
        self.start_blend = int(start_blend * NUM_DDIM_STEPS)
        self.counter = 0
        self.th = th

    @staticmethod
    def _build_word_layers(prompts, words):
        """Return a (len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) tensor with ones at the selected word positions."""
        layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
        for i, (prompt, words_) in enumerate(zip(prompts, words)):
            if isinstance(words_, str):
                words_ = [words_]
            for word in words_:
                ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
                layers[i, :, :, :, :, ind] = 1
        return layers

    def get_mask(self, maps, alpha, use_pool, x_t=None):
        """Turn attention maps into a boolean spatial mask.

        Args:
            maps: Attention maps with a trailing token axis.
            alpha: Word-selection layers broadcast against ``maps``.
            use_pool: Whether to dilate the map with a 3x3 max-pool before thresholding;
                also selects ``th[0]`` (True) or ``th[1]`` (False).
            x_t: Latents whose spatial size (``shape[2:]``) the mask is resized to.
                When omitted the mask keeps the resolution of ``maps``.

        Returns:
            A boolean mask; the first sample's mask is OR-ed into every sample's mask.
        """
        k = 1
        maps = (maps * alpha).sum(-1).mean(1)
        if use_pool:
            maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        # The target size is passed in explicitly instead of being read from a global `x_t`.
        target_size = maps.shape[2:] if x_t is None else x_t.shape[2:]
        mask = nnf.interpolate(maps, size=target_size)
        # Normalise each map by its own spatial maximum before thresholding.
        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
        mask = mask.gt(self.th[1 - int(use_pool)])
        mask = mask[:1] + mask
        return mask

    def __call__(self, x_t, attention_store):
        """Blend the latents ``x_t`` using the maps in ``attention_store``.

        Does nothing until the call counter exceeds ``start_blend``.
        """
        self.counter += 1
        if self.counter > self.start_blend:
            maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
            maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
            maps = torch.cat(maps, dim=1)
            mask = self.get_mask(maps, self.alpha_layers, True, x_t)
            if self.substruct_layers is not None:
                maps_sub = ~self.get_mask(maps, self.substruct_layers, False, x_t)
                mask = mask * maps_sub
            mask = mask.float()
            # Outside the mask every sample falls back to the first sample's latent.
            x_t = x_t[:1] + mask * (x_t - x_t[:1])
        return x_t




class EmptyControl:
    """Controller that leaves both latents and attention untouched."""

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Hand the attention back exactly as received."""
        return attn

    def between_steps(self):
        """No bookkeeping is needed between steps."""
        return None

    def step_callback(self, x_t):
        """Hand the latents back exactly as received."""
        return x_t


class AttentionControl(abc.ABC):
    """Base class for objects that observe or rewrite attention, layer by layer.

    Each call counts as one attention layer. Once the expected number of layers
    has been seen, the layer counter wraps, ``cur_step`` advances and
    ``between_steps`` runs.
    """

    def __init__(self):
        self.cur_step = 0
        # -1 is a placeholder; the real layer count has to be assigned from outside
        # (presumably by ptp_utils.register_attention_control - confirm there).
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def reset(self):
        """Rewind the step and layer counters."""
        self.cur_step = 0
        self.cur_att_layer = 0

    @property
    def num_uncond_att_layers(self):
        """Number of leading calls per step that bypass ``forward``."""
        if LOW_RESOURCE:
            return self.num_att_layers
        return 0

    @abc.abstractmethod
    def forward (self, attn, is_cross: bool, place_in_unet: str):
        """Process one attention tensor; subclasses must implement this."""
        raise NotImplementedError

    def step_callback(self, x_t):
        """Hook applied to the latents; the base class returns them unchanged."""
        return x_t

    def between_steps(self):
        """Hook run after all layers of a step; no-op in the base class."""
        return

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Route one attention tensor through ``forward`` and update the counters."""
        past_skipped_layers = self.cur_att_layer >= self.num_uncond_att_layers
        if past_skipped_layers and LOW_RESOURCE:
            attn = self.forward(attn, is_cross, place_in_unet)
        elif past_skipped_layers:
            # Only the second half of the batch axis is passed to forward and written back.
            half = attn.shape[0] // 2
            attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)

        self.cur_att_layer += 1
        layers_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer == layers_per_step:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

class SpatialReplace(EmptyControl):
    """Force every sample to share the first sample's latent during the early steps."""

    def step_callback(self, x_t):
        """Broadcast the first sample's latent over the batch while injection is active.

        Args:
            x_t: Batch of latents for the current step.

        Returns:
            ``x_t`` with every sample replaced by the first one for the first
            ``stop_inject`` calls, and ``x_t`` unchanged afterwards.
        """
        if self.cur_step < self.stop_inject:
            b = x_t.shape[0]
            x_t = x_t[:1].expand(b, *x_t.shape[1:])
        # EmptyControl keeps no step counter, so this class counts its own calls.
        # NOTE(review): assumes step_callback runs exactly once per denoising step - confirm in ptp_utils.
        self.cur_step += 1
        return x_t

    def __init__(self, stop_inject: float):
        """
        Args:
            stop_inject: Fraction of the schedule; injection is active for the
                first ``int((1 - stop_inject) * NUM_DDIM_STEPS)`` steps.
        """
        super(SpatialReplace, self).__init__()
        self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS)
        # The base class never defines cur_step; without this, step_callback raises AttributeError.
        self.cur_step = 0


class AttentionStore(AttentionControl):
    """Controller that accumulates the attention maps it sees across steps."""

    def __init__(self):
        super().__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def reset(self):
        """Rewind the counters and drop everything collected so far."""
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    @staticmethod
    def get_empty_store():
        """Return a fresh mapping with one empty list per (location, attention type) pair."""
        return {
            f"{place}_{kind}": []
            for kind in ("cross", "self")
            for place in ("down", "mid", "up")
        }

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` for the current step and return it unchanged."""
        kind = 'cross' if is_cross else 'self'
        # Maps whose second dimension exceeds 32 * 32 are not kept, to limit memory use.
        if attn.shape[1] <= 32 ** 2:
            self.step_store[f"{place_in_unet}_{kind}"].append(attn)
        return attn

    def between_steps(self):
        """Fold the maps of the finished step into the running sum."""
        if not len(self.attention_store):
            # First completed step: adopt its maps as the accumulator.
            self.attention_store = self.step_store
        else:
            for key, accumulated in self.attention_store.items():
                fresh = self.step_store[key]
                for idx in range(len(accumulated)):
                    accumulated[idx] += fresh[idx]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated maps divided by the number of completed steps."""
        averaged = {}
        for key, accumulated in self.attention_store.items():
            averaged[key] = [item / self.cur_step for item in accumulated]
        return averaged


class AttentionControlEdit(AttentionStore, abc.ABC):
    """Base class for controllers that inject the first prompt's attention into the others.

    The first sample of the batch is the base; all remaining samples are the
    edited ones whose attention is rewritten.
    """

    def step_callback(self, x_t):
        """Apply the optional local blend to the latents."""
        if self.local_blend is not None:
            x_t = self.local_blend(x_t, self.attention_store)
        return x_t

    def replace_self_attention(self, attn_base, att_replace, place_in_unet):
        """Return the self-attention to use for the edited samples.

        The base attention is broadcast over the edited samples when the map is
        at most 32 * 32 along ``att_replace.shape[2]``; larger maps are kept as they are.
        """
        if att_replace.shape[2] <= 32 ** 2:
            attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
            return attn_base
        else:
            return att_replace

    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        """Return the cross-attention for the edited samples; subclasses must implement this."""
        raise NotImplementedError

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Store ``attn`` and rewrite the edited samples' attention when editing is active.

        Cross-attention is always edited, mixed per step by ``cross_replace_alpha``;
        self-attention is edited only while ``cur_step`` lies inside ``num_self_replace``.
        """
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // (self.batch_size)
            # Split the leading axis into (sample, head) so the base sample can be separated.
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_replace = attn[0], attn[1:]
            if is_cross:
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (1 - alpha_words) * attn_replace
                attn[1:] = attn_replace_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_replace, place_in_unet)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn

    def __init__(self, prompts, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]],
                 local_blend: Optional[LocalBlend]):
        """
        Args:
            prompts: Prompts of the batch; the first one is the base prompt.
            num_steps: Number of diffusion steps.
            cross_replace_steps: Passed to ``ptp_utils.get_time_words_attention_alpha``
                to build the per-step cross-attention mixing weights.
            self_replace_steps: Either a ``(start, end)`` pair of fractions of
                ``num_steps``, or a single number meaning ``(0, number)``.
            local_blend: Optional blend applied to the latents in ``step_callback``.
        """
        super(AttentionControlEdit, self).__init__()
        self.batch_size = len(prompts)
        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
        # Accept any plain number (e.g. 1 or 0), not only float instances.
        if isinstance(self_replace_steps, (int, float)):
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend

class AttentionReplace(AttentionControlEdit):
    """Edit controller that maps the base cross-attention through a replacement mapper."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None):
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        replacement_mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer)
        self.mapper = replacement_mapper.to(device)

    def replace_cross_attention(self, attn_base, att_replace):
        """Project the base attention's token axis through ``self.mapper`` for each edited sample."""
        projected = torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
        return projected


class AttentionRefine(AttentionControlEdit):
    """Edit controller that blends re-indexed base attention with the edited attention."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None):
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
        self.mapper = mapper.to(device)
        alphas = alphas.to(device)
        # Insert two singleton axes so the weights broadcast over the middle dimensions.
        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])

    def replace_cross_attention(self, attn_base, att_replace):
        """Mix base attention (gathered through ``self.mapper``) with ``att_replace`` using ``self.alphas``."""
        gathered = attn_base[:, :, self.mapper]
        gathered = gathered.permute(2, 0, 1, 3)
        # NOTE: the blended result is not re-normalised over the token axis.
        return gathered * self.alphas + att_replace * (1 - self.alphas)


class AttentionReweight(AttentionControlEdit):
    """Edit controller that scales the attention of individual tokens by an equalizer."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
                local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        self.prev_controller = controller
        self.equalizer = equalizer.to(device)

    def replace_cross_attention(self, attn_base, att_replace):
        """Scale the (optionally pre-edited) base attention token-wise by ``self.equalizer``."""
        chained = self.prev_controller
        if chained is not None:
            # Let the wrapped controller produce the attention that gets re-weighted.
            attn_base = chained.replace_cross_attention(attn_base, att_replace)
        token_weights = self.equalizer[:, None, None, :]
        # NOTE: the re-weighted result is not re-normalised over the token axis.
        return attn_base[None, :, :, :] * token_weights


def get_equalizer(text: str, word_select: Union[int, str, Tuple[Union[int, str], ...]],
                  values: Union[List[float], Tuple[float, ...]]):
    """Build per-token attention multipliers for ``text``.

    Args:
        text: Prompt whose tokens are re-weighted.
        word_select: A word (or word index), or a tuple of them, to re-weight.
        values: One multiplier per selected word. Pairs are formed with ``zip``,
            so surplus entries on either side are ignored.

    Returns:
        A ``(1, MAX_NUM_WORDS)`` tensor of ones with the selected tokens set to their multiplier.
    """
    if isinstance(word_select, (int, str)):
        word_select = (word_select,)
    # Same token-axis length as the word masks used elsewhere in this file.
    equalizer = torch.ones(1, MAX_NUM_WORDS)

    for word, val in zip(word_select, values):
        inds = ptp_utils.get_word_inds(text, word, tokenizer)
        equalizer[:, inds] = val
    return equalizer

def aggregate_attention(attention_store: "AttentionStore", res: int, from_where: List[str], is_cross: bool, select: int,
                        batch_size: Optional[int] = None):
    """Average the stored attention maps of one sample at a given resolution.

    Args:
        attention_store: Object exposing ``get_average_attention()``.
        res: Spatial resolution; only maps with ``res ** 2`` entries along dimension 1 are used.
        from_where: Location prefixes to read (combined with ``_cross`` / ``_self`` to form the store keys).
        is_cross: Select cross-attention (True) or self-attention (False) maps.
        select: Index of the sample to extract.
        batch_size: Number of samples the maps were recorded for. When omitted,
            the length of the module-level ``prompts`` list is used, as before.

    Returns:
        A CPU tensor: the mean over all matching maps of the selected sample.
    """
    if batch_size is None:
        # Backward-compatible fallback on the notebook-global prompt list.
        batch_size = len(prompts)
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    suffix = 'cross' if is_cross else 'self'
    for location in from_where:
        for item in attention_maps[f"{location}_{suffix}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(batch_size, -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out.cpu()


def make_controller(prompts: List[str], is_replace_controller: bool, cross_replace_steps: Dict[str, float], self_replace_steps: float, blend_words=None, equilizer_params=None) -> AttentionControlEdit:
    """Build the attention controller for an edit.

    Args:
        prompts: Prompts of the batch; the first one is the base prompt.
        is_replace_controller: True builds an ``AttentionReplace``, False an ``AttentionRefine``.
        cross_replace_steps: Forwarded to the controller.
        self_replace_steps: Forwarded to the controller.
        blend_words: Optional words for a ``LocalBlend``; None disables local blending.
        equilizer_params: Optional dict with ``"words"`` and ``"values"``; when given,
            the controller is wrapped in an ``AttentionReweight`` built on ``prompts[1]``.

    Returns:
        The configured controller.
    """
    if blend_words is None:
        lb = None
    else:
        # Use the argument itself, not a same-looking notebook global.
        lb = LocalBlend(prompts, blend_words)
    if is_replace_controller:
        controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb)
    else:
        controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb)
    if equilizer_params is not None:
        eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"])
        controller = AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
                                       self_replace_steps=self_replace_steps, equalizer=eq, local_blend=lb, controller=controller)
    return controller


def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    """Display one captioned cross-attention heat-map per token of ``prompts[select]``.

    The prompt list and tokenizer are read from module level.
    """
    token_ids = tokenizer.encode(prompts[select])
    attention_maps = aggregate_attention(attention_store, res, from_where, True, select)
    panels = []
    for position, token_id in enumerate(token_ids):
        heat = attention_maps[:, :, position]
        # Scale so the strongest response of this token maps to 255.
        heat = 255 * heat / heat.max()
        heat = heat.unsqueeze(-1).expand(*heat.shape, 3)
        pixels = heat.numpy().astype(np.uint8)
        pixels = np.array(Image.fromarray(pixels).resize((256, 256)))
        caption = tokenizer.decode(int(token_id))
        panels.append(ptp_utils.text_under_image(pixels, caption))
    ptp_utils.view_images(np.stack(panels, axis=0))


def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                        max_com=10, select: int = 0):
    """Display the first ``max_com`` right singular vectors of the row-centred self-attention matrix."""
    side = res ** 2
    aggregated = aggregate_attention(attention_store, res, from_where, False, select)
    attention_maps = aggregated.numpy().reshape((side, side))
    centred = attention_maps - np.mean(attention_maps, axis=1, keepdims=True)
    _, _, vh = np.linalg.svd(centred)
    tiles = []
    for i in range(max_com):
        component = vh[i].reshape(res, res)
        # Shift and scale the component into the 0..255 range.
        component = component - component.min()
        component = 255 * component / component.max()
        rgb = np.repeat(component[:, :, None], 3, axis=2).astype(np.uint8)
        tiles.append(np.array(Image.fromarray(rgb).resize((256, 256))))
    ptp_utils.view_images(np.concatenate(tiles, axis=1))

## Null Text Inversion code

In [None]:
def load_512(image_path, left=0, right=0, top=0, bottom=0):
    """Load an image, crop the given margins, centre-crop to a square and resize to 512x512.

    Args:
        image_path: Path to an image file, or an already loaded ``(H, W, C)`` array.
        left, right, top, bottom: Number of pixels to remove from each side.
            Each value is clamped so that at least one pixel remains.

    Returns:
        The processed image as a ``(512, 512, C)`` numpy array.
    """
    if isinstance(image_path, str):
        # Close the file promptly, and normalise to three channels so that
        # grayscale / palette / RGBA files work as well.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
    else:
        image = image_path
    h, w, c = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # The vertical clamp depends only on the height (it previously used `left`).
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image


class NullInversion:
    """Null-text inversion for a Stable Diffusion pipeline.

    An image is first inverted with DDIM using the conditional prediction only;
    then, for every step, the unconditional embedding is optimised so that a
    guided DDIM step reproduces the recorded inversion trajectory.
    """

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        """Take one DDIM step from ``timestep`` to the preceding scheduler timestep.

        Uses ``final_alpha_cumprod`` when the preceding timestep would be negative.
        """
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        """Take one inverse DDIM step: move ``sample`` from the preceding timestep up to ``timestep``."""
        # `sample` is treated as living at the preceding timestep; `timestep` is the target.
        timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample

    def get_noise_pred_single(self, latents, t, context):
        """Run the UNet once with ``context`` and return its ``"sample"`` output."""
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        """Predict guided noise and apply one step; returns the updated latents.

        With ``is_forward`` the guidance scale is 1 and ``next_step`` is applied;
        otherwise ``GUIDANCE_SCALE`` and ``prev_step`` are used. ``context``
        defaults to the embeddings prepared by ``init_prompt``.
        """
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else GUIDANCE_SCALE
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        """Decode latents with the VAE.

        Returns a uint8 ``(H, W, C)`` array of the first sample when
        ``return_type == 'np'``, otherwise the raw decoder output.
        """
        latents = 1 / 0.18215 * latents.detach()
        image = self.model.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    @torch.no_grad()
    def image2latent(self, image):
        """Encode an image into scaled VAE latents.

        Args:
            image: A 4-D tensor (returned unchanged, taken to be latents already),
                a PIL image, or an ``(H, W, C)`` numpy array with values in 0..255.

        Returns:
            The mean of the VAE latent distribution multiplied by 0.18215.
        """
        if isinstance(image, torch.Tensor) and image.dim() == 4:
            return image
        # `Image` is the PIL module; the image class is `Image.Image`.
        if isinstance(image, Image.Image):
            image = np.array(image)
        image = torch.from_numpy(image).float() / 127.5 - 1
        # Use the pipeline's own device, like the rest of this class.
        image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device)
        latents = self.model.vae.encode(image)['latent_dist'].mean
        latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def init_prompt(self, prompt: str):
        """Encode ``prompt`` and the empty prompt; store them as ``self.context`` (unconditional first)."""
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt"
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self, latent):
        """Invert ``latent`` over ``NUM_DDIM_STEPS`` steps using the conditional prediction only.

        Returns:
            List of latents, starting with the input and ending with the most inverted one.
        """
        _, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in range(NUM_DDIM_STEPS):
            # Walk the scheduler's timesteps from the last entry to the first.
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

    @property
    def scheduler(self):
        """The scheduler of the wrapped pipeline."""
        return self.model.scheduler

    @torch.no_grad()
    def ddim_inversion(self, image):
        """Encode ``image``, decode it back for reference, and run the DDIM inversion.

        Returns:
            ``(image_rec, ddim_latents)``: the VAE reconstruction and the latent trajectory.
        """
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent)
        return image_rec, ddim_latents

    def null_optimization(self, latents, num_inner_steps, epsilon):
        """Optimise one unconditional embedding per step against the inversion trajectory.

        Args:
            latents: Trajectory returned by ``ddim_loop``.
            num_inner_steps: Maximum number of optimiser iterations per step.
            epsilon: Base early-stopping threshold on the loss; it grows by 2e-5 per step.

        Returns:
            List with one detached unconditional embedding per step.
        """
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS)
        for i in range(NUM_DDIM_STEPS):
            # Each step starts from the embedding optimised at the previous step.
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
            latent_prev = latents[len(latents) - i - 2]
            t = self.model.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
            steps_done = 0
            for _ in range(num_inner_steps):
                noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                steps_done += 1
                bar.update()
                if loss.item() < epsilon + i * 2e-5:
                    break
            # Account for iterations skipped by early stopping (also safe when num_inner_steps is 0).
            bar.update(num_inner_steps - steps_done)
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                context = torch.cat([uncond_embeddings, cond_embeddings])
                latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        bar.close()
        return uncond_embeddings_list

    def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
        """Run the full inversion of an image for ``prompt``.

        Args:
            image_path: Passed to ``load_512``.
            prompt: Text describing the image.
            offsets: Crop margins forwarded to ``load_512`` as ``(left, right, top, bottom)``.
            num_inner_steps: Optimiser iterations per step for the null-text optimisation.
            early_stop_epsilon: Base early-stopping threshold for the optimisation.
            verbose: Print progress messages.

        Returns:
            ``((image_gt, image_rec), x_T, uncond_embeddings)``: the loaded image and its VAE
            reconstruction, the final inverted latent, and the per-step unconditional embeddings.
        """
        self.init_prompt(prompt)
        # NOTE(review): presumably clears any previously registered controller - confirm in ptp_utils.
        ptp_utils.register_attention_control(self.model, None)
        image_gt = load_512(image_path, *offsets)
        if verbose:
            print("DDIM inversion...")
        image_rec, ddim_latents = self.ddim_inversion(image_gt)
        if verbose:
            print("Null-text optimization...")
        uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
        return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings


    def __init__(self, model):
        """Wrap ``model`` and set its scheduler to ``NUM_DDIM_STEPS`` inference steps."""
        self.model = model
        self.tokenizer = self.model.tokenizer
        self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
        self.prompt = None
        self.context = None

# Inversion helper bound to the loaded pipeline (its constructor also sets the scheduler's timesteps)
null_inversion = NullInversion(ldm_stable)


## Inference Code

In [None]:
@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt:  List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    uncond_embeddings=None,
    start_time=50,
    return_type='image'
):
    """Sample images for ``prompt`` with an attention controller attached.

    When ``uncond_embeddings`` is given it is indexed once per step and used as
    the unconditional context; otherwise the empty prompt is encoded and reused.

    Returns:
        ``(image, latent)``: decoded images (or the final latents when
        ``return_type`` is not ``'image'``) and the initial latent from ``ptp_utils.init_latent``.
    """
    batch_size = len(prompt)
    ptp_utils.register_attention_control(model, controller)
    height = width = 512

    tokens = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(tokens.input_ids.to(model.device))[0]

    # One shared unconditional embedding, used only when no per-step ones are supplied.
    shared_uncond = None
    if uncond_embeddings is None:
        empty_tokens = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=tokens.input_ids.shape[-1], return_tensors="pt"
        )
        shared_uncond = model.text_encoder(empty_tokens.input_ids.to(model.device))[0]

    latent, latents = ptp_utils.init_latent(latent, model, height, width, generator, batch_size)
    model.scheduler.set_timesteps(num_inference_steps)
    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        if shared_uncond is None:
            uncond = uncond_embeddings[i].expand(*text_embeddings.shape)
        else:
            uncond = shared_uncond
        context = torch.cat([uncond, text_embeddings])
        latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False)

    if return_type == 'image':
        return ptp_utils.latent2image(model.vae, latents), latent
    return latents, latent



def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None, uncond_embeddings=None, verbose=True):
    """Generate images for ``prompts`` with ``controller`` and optionally display them.

    With ``run_baseline`` a pass with ``EmptyControl`` runs first, and the latent
    it returns is reused for the controlled pass.

    Returns:
        ``(images, x_t)`` as returned by ``text2image_ldm_stable``.
    """
    if run_baseline:
        print("w.o. prompt-to-prompt")
        _, latent = run_and_display(prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator)
        print("with prompt-to-prompt")
    outcome = text2image_ldm_stable(
        ldm_stable,
        prompts,
        controller,
        latent=latent,
        num_inference_steps=NUM_DDIM_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
        uncond_embeddings=uncond_embeddings,
    )
    images, x_t = outcome
    if verbose:
        ptp_utils.view_images(images)
    return images, x_t

In [None]:
# Source image and the prompt that describes it
image_path = "./gnochi_mirror.jpeg"
prompt = "a cat sitting next to a mirror"
# offsets are forwarded to load_512 as (left, right, top, bottom); here up to 200 pixel rows are cropped from the top
(image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, prompt, offsets=(0,0,200,0), verbose=True)

print("Modify or remove offsets according to your image!")

In [None]:
# Regenerate the image from the inverted latent with the optimised unconditional
# embeddings, recording attention maps in an AttentionStore along the way
prompts = [prompt]
controller = AttentionStore()
image_inv, x_t = run_and_display(prompts, controller, run_baseline=False, latent=x_t, uncond_embeddings=uncond_embeddings, verbose=False)
print("showing from left to right: the ground truth image, the vq-autoencoder reconstruction, the null-text inverted image")
ptp_utils.view_images([image_gt, image_enc, image_inv[0]])
# Cross-attention maps at 16x16 resolution, read from the "up" and "down" store entries
show_cross_attention(controller, 16, ["up", "down"])


In [None]:
# Edit 1: swap "cat" for "tiger"; the first prompt is the base, the second the edit
prompts = ["a cat sitting next to a mirror",
           "a tiger sitting next to a mirror"
        ]

cross_replace_steps = {'default_': .8,}
self_replace_steps = .5
blend_word = ((('cat',), ("tiger",))) # for local edit. If it is not local yet - use only the source object: blend_word = ((('cat',), ("cat",))).
eq_params = {"words": ("tiger",), "values": (2,)} # amplify attention to the word "tiger" by *2

# is_replace_controller=True builds an AttentionReplace; eq_params wraps it in an AttentionReweight
controller = make_controller(prompts, True, cross_replace_steps, self_replace_steps, blend_word, eq_params)
images, _ = run_and_display(prompts, controller, run_baseline=False, latent=x_t, uncond_embeddings=uncond_embeddings)

print("Image is highly affected by the self_replace_steps, usually 0.4 is a good default value, but you may want to try the range 0.3,0.4,0.5,0.7 ")

In [None]:
# Edit 2: add words to the prompt ("silver ... sculpture"); the first prompt is the base
prompts = ["a cat sitting next to a mirror",
           "a silver cat sculpture sitting next to a mirror"
        ]

cross_replace_steps = {'default_': .8, }
self_replace_steps = .6
blend_word = ((('cat',), ("cat",))) # for local edit
eq_params = {"words": ("silver", 'sculpture', ), "values": (2,2,)}  # amplify attention to the words "silver" and "sculpture" by *2

# is_replace_controller=False builds an AttentionRefine; eq_params wraps it in an AttentionReweight
controller = make_controller(prompts, False, cross_replace_steps, self_replace_steps, blend_word, eq_params)
images, _ = run_and_display(prompts, controller, run_baseline=False, latent=x_t, uncond_embeddings=uncond_embeddings)


In [None]:
# Edit 3: global style change (no local blend); the first prompt is the base
prompts = ["a cat sitting next to a mirror",
           "watercolor painting of a cat sitting next to a mirror"
        ]

cross_replace_steps = {'default_': .8, }
self_replace_steps = .7
blend_word = None
# One value per selected word: a surplus entry would be silently dropped when words and values are paired.
eq_params = {"words": ("watercolor",  ), "values": (5,)}  # amplify attention to the word "watercolor" by 5

# is_replace_controller=False builds an AttentionRefine; eq_params wraps it in an AttentionReweight
controller = make_controller(prompts, False, cross_replace_steps, self_replace_steps, blend_word, eq_params)
images, _ = run_and_display(prompts, controller, run_baseline=False, latent=x_t, uncond_embeddings=uncond_embeddings)

# 새로운 실험 1: TI 형태의 최적화

1: placeholder 토큰 추가 함수

In [None]:
# === Textual Inversion: placeholder 토큰 추가 ===

import torch
import torch.nn.functional as F

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def add_placeholder_token(pipe,
                          placeholder_token: str = "<sks-cat>",
                          initializer_token: str = "cat"):
    """
    Register ``placeholder_token`` with the pipeline and seed its embedding.

    The token is added to ``pipe.tokenizer``, the text encoder's input
    embedding matrix is resized to the new vocabulary size, and the new row
    is initialised with a copy of the embedding of ``initializer_token``.

    Args:
        pipe: Object exposing ``tokenizer`` and ``text_encoder`` attributes.
        placeholder_token: String to add to the tokenizer vocabulary.
        initializer_token: Word whose embedding seeds the placeholder. It must
            tokenize to exactly one token.

    Returns:
        The vocabulary id of ``placeholder_token``.

    Raises:
        ValueError: If ``initializer_token`` does not tokenize to exactly one
            token.
    """
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder

    # 1) Add the token (add_tokens returns how many were actually new).
    num_added = tokenizer.add_tokens(placeholder_token)
    if num_added == 0:
        print(f"[add_placeholder_token] '{placeholder_token}' is already in tokenizer.")
    else:
        print(f"[add_placeholder_token] Added {num_added} token(s): {placeholder_token}")

    # 2) Grow the text encoder's embedding matrix to match the tokenizer.
    text_encoder.resize_token_embeddings(len(tokenizer))

    # 3) Initialise the placeholder row from the initializer word.
    token_embeds = text_encoder.get_input_embeddings().weight
    placeholder_id = tokenizer.convert_tokens_to_ids(placeholder_token)

    # Run the word through the tokenizer instead of looking up the raw string:
    # convert_tokens_to_ids() matches literal vocabulary entries, so a plain
    # word can resolve to a different (or unknown) entry than the token the
    # tokenizer actually emits for it inside a prompt.
    init_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
    if len(init_ids) != 1:
        raise ValueError(
            f"initializer_token must map to exactly one token, "
            f"but '{initializer_token}' produced {len(init_ids)}."
        )
    init_id = init_ids[0]

    with torch.no_grad():
        token_embeds[placeholder_id] = token_embeds[init_id].clone()

    print(f"[add_placeholder_token] placeholder id = {placeholder_id}, init id = {init_id}")
    return placeholder_id


2: B 이미지를 latent로 변환

In [None]:
# === Textual Inversion: 학습용 latent 준비 ===

from glob import glob
from PIL import Image
import numpy as np

def prepare_latents_for_textual_inversion(null_inversion,
                                          image_paths,
                                          device=device):
    """
    Encode every image in ``image_paths`` to a latent and stack the results.

    Each path is loaded with ``load_512`` and encoded with
    ``null_inversion.image2latent``; the per-image latents are concatenated
    along dim 0 and moved to ``device``.

    Returns:
        The concatenated latent tensor (per the original author's notes,
        each latent is (1, 4, 64, 64), giving (N, 4, 64, 64) overall).
    """
    per_image_latents = [
        null_inversion.image2latent(load_512(path)) for path in image_paths
    ]
    stacked = torch.cat(per_image_latents, dim=0).to(device)
    print(f"[prepare_latents_for_textual_inversion] latents shape: {stacked.shape}")
    return stacked

# Example: glob pattern selecting the identity (B) image(s).
# Here it matches a single file; use e.g. "./b_images/*.jpg" for several.
b_image_paths = sorted(glob("./cat.jpeg"))
print("num B images:", len(b_image_paths))

# Fail early with a clear message instead of an opaque torch.cat error on an empty list.
if not b_image_paths:
    raise ValueError("No identity (B) images matched the glob pattern; check the path.")

latents_B = prepare_latents_for_textual_inversion(null_inversion, b_image_paths)


3: textual inversion 학습

In [None]:
import torch
import torch.nn.functional as F
from ptp_utils import register_attention_control

def train_textual_inversion_for_identity(
    pipe,
    latents_B: torch.Tensor,
    placeholder_token: str = "<sks-cat>",
    prompt_template: str = "a photo of a {}",
    num_train_steps: int = 800,
    batch_size: int = 1,
    lr: float = 5e-4,
    max_grad_norm: float = 1.0,
    device=None,
    initializer_token: str = "cat",
):
    """Learn a Textual-Inversion embedding for ``placeholder_token``.

    Only the embedding row of the placeholder token is updated; the U-Net and
    every other text-encoder parameter are frozen. The objective is the MSE
    between the U-Net output and the injected noise (epsilon prediction is
    assumed, as in the original code).

    Args:
        pipe: Pipeline exposing ``tokenizer``, ``text_encoder``, ``unet``,
            ``vae`` and ``scheduler``.
        latents_B: Latents of the identity images; dim 0 indexes images.
        placeholder_token: Token to add and train.
        prompt_template: Format string receiving the placeholder token.
        num_train_steps: Number of optimisation steps.
        batch_size: Number of latents sampled (with replacement) per step.
        lr: Adam learning rate.
        max_grad_norm: Clip threshold for the placeholder gradient, or
            ``None`` to disable clipping.
        device: Target device; defaults to CUDA when available.
        initializer_token: Word whose embedding seeds the placeholder.

    Returns:
        The vocabulary id of ``placeholder_token``.

    Raises:
        ValueError: If ``latents_B`` contains no latents.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if latents_B.shape[0] == 0:
        raise ValueError("latents_B is empty; at least one identity latent is required.")

    # 0. Switch Prompt-to-Prompt control off for the duration of training
    #    (None is passed as the controller).
    register_attention_control(pipe, None)

    # 1. Device placement.
    pipe.to(device)
    pipe.unet.to(device)
    pipe.vae.to(device)
    pipe.text_encoder.to(device)

    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    unet = pipe.unet
    noise_scheduler = pipe.scheduler

    # 2. Move any tensor-valued scheduler buffers to the device so that
    #    add_noise() can index them with on-device timesteps.
    for attr in ("alphas_cumprod", "alphas", "betas", "final_alpha_cumprod", "timesteps"):
        value = getattr(noise_scheduler, attr, None)
        if isinstance(value, torch.Tensor):
            setattr(noise_scheduler, attr, value.to(device))

    # 3. Add the placeholder token, seeded from ``initializer_token``.
    placeholder_id = add_placeholder_token(
        pipe,
        placeholder_token=placeholder_token,
        initializer_token=initializer_token,
    )
    print(f"[TI] placeholder token '{placeholder_token}' id = {placeholder_id}")

    # 4. Freeze the U-Net and text encoder; only the embedding matrix gets a
    #    gradient, and only the placeholder row of it is kept (see below).
    for p in unet.parameters():
        p.requires_grad_(False)
    for p in text_encoder.parameters():
        p.requires_grad_(False)

    token_embeds = text_encoder.get_input_embeddings().weight  # author's note: (V, d)
    token_embeds.requires_grad_(True)

    optimizer = torch.optim.Adam([token_embeds], lr=lr)

    latents_B = latents_B.to(device)
    N = latents_B.shape[0]
    global_step = 0

    # The prompt never changes, so tokenize once. The text *encoding* has to
    # stay inside the loop because the embedding is being updated.
    prompts = [prompt_template.format(placeholder_token)] * batch_size
    text_inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    input_ids = text_inputs.input_ids

    print(f"[TI] start training for token {placeholder_token}")
    print(f"     num_train_steps={num_train_steps}, batch_size={batch_size}, N={N}")

    try:
        while global_step < num_train_steps:
            # 4-1. Sample a batch of latents, noise and timesteps.
            idx = torch.randint(0, N, (batch_size,), device=device)
            clean_latents = latents_B[idx]

            noise = torch.randn_like(clean_latents)
            timesteps = torch.randint(
                0,
                noise_scheduler.num_train_timesteps,
                (batch_size,),
                device=device,
                dtype=torch.long,
            )

            noisy_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps)

            # 4-2. Text encoding with the current embedding.
            encoder_hidden_states = text_encoder(input_ids)[0]

            # 4-3. U-Net prediction and loss (epsilon prediction assumed).
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
            ).sample
            target = noise
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # 4-4. Backprop, then keep only the placeholder row's gradient.
            optimizer.zero_grad()
            loss.backward()

            grad = token_embeds.grad
            if grad is not None:
                with torch.no_grad():
                    placeholder_grad = grad[placeholder_id].clone()
                    grad.zero_()
                    grad[placeholder_id] = placeholder_grad

                # Clip AFTER masking: clipping first would scale the
                # placeholder gradient by a norm dominated by the rows of the
                # other prompt tokens, which are discarded anyway.
                if max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_([token_embeds], max_grad_norm)

            optimizer.step()

            if global_step % 50 == 0:
                print(f"[TI] step {global_step:04d}/{num_train_steps} | loss = {loss.item():.6f}")

            global_step += 1
    finally:
        # Do not leave the embedding matrix trainable: otherwise later
        # inference code that encodes text outside no_grad builds autograd graphs.
        token_embeds.requires_grad_(False)

    print(f"[TI] finished. token '{placeholder_token}' embedding updated.")
    return placeholder_id


실행

In [None]:
# --- P2P AttentionStore: between_steps를 항상 no_grad에서 실행하도록 패치 ---

import torch
import ptp_utils
import types

def patch_between_steps_no_grad(namespace=None):
    """Wrap every ``between_steps`` method found in ``namespace`` in ``torch.no_grad()``.

    Args:
        namespace: Where to look for classes. Either a module-like object
            (inspected with ``vars``) or a dict such as ``globals()``.
            Defaults to the ``ptp_utils`` module, as before. Pass
            ``globals()`` to patch controller classes defined in the notebook
            itself.

    Classes whose ``between_steps`` is already wrapped are skipped, so calling
    this function repeatedly is safe.
    """
    import functools

    if namespace is None:
        namespace = ptp_utils
    members = namespace if isinstance(namespace, dict) else vars(namespace)
    label = getattr(namespace, "__name__", "namespace")

    def make_patched(orig):
        @functools.wraps(orig)
        def patched_between(self, *args, **kwargs):
            # Gradients are disabled only for the duration of the wrapped call.
            with torch.no_grad():
                return orig(self, *args, **kwargs)
        # Marker used to detect (and skip) already-wrapped methods.
        patched_between._patched_no_grad = True
        return patched_between

    patched = False
    # Snapshot the items so that patching cannot disturb the iteration.
    for name, obj in list(members.items()):
        # Only classes that expose a between_steps attribute are of interest.
        if not (isinstance(obj, type) and hasattr(obj, "between_steps")):
            continue

        orig_between = obj.between_steps
        if getattr(orig_between, "_patched_no_grad", False):
            continue

        obj.between_steps = make_patched(orig_between)
        print(f"[patch_between_steps_no_grad] Patched {name}.between_steps")
        patched = True

    if not patched:
        print(f"[patch_between_steps_no_grad] No class with between_steps found in {label}.")

patch_between_steps_no_grad()


In [None]:
# ================================
# A(배경) + B(정체성) 이미지 합성 전체 파이프라인
#   - 1) B로 Textual Inversion
#   - 2) A에 Null-text Inversion
#   - 3) P2P로 "cat" -> placeholder_token 치환 (B 정체성 삽입)
# ================================
import torch
from glob import glob

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# --------------------------------------------------
# 0. User settings: A/B paths, prompt, token names
# --------------------------------------------------
# (1) A: background image path + a prompt that contains the source word
image_A_path = "./gnochi_mirror.jpeg"   # change to your own image A
prompt_A     = "a cat sitting next to a mirror"

# (2) B: glob pattern for the identity image(s) (several images recommended)
b_image_glob = "./cat.jpeg"

# (3) Placeholder token for Textual Inversion
placeholder_token = "<sks-cat>"   # used verbatim inside the prompt
initializer_token = "cat"         # NOTE(review): not passed to the training call below; confirm intent

# (4) Word in prompt A that should be swapped for the placeholder
source_word = "cat"

# --------------------------------------------------
# 1. Textual Inversion on the B images
#    (can be skipped if the token has already been trained)
# --------------------------------------------------
b_image_paths = sorted(glob(b_image_glob))
print(f"#B images = {len(b_image_paths)}")
if len(b_image_paths) == 0:
    raise ValueError("B 정체성 이미지가 없습니다. b_image_glob 경로를 확인하세요.")

# Encode the B images to latents (null_inversion is reused for its VAE).
latents_B = prepare_latents_for_textual_inversion(
    null_inversion,
    b_image_paths,
    device=device,
)

# Train the placeholder embedding.
_ = train_textual_inversion_for_identity(
    pipe=ldm_stable,
    latents_B=latents_B,
    placeholder_token=placeholder_token,
    prompt_template="a photo of a {}",   # -> "a photo of a <sks-cat>"
    num_train_steps=800,
    batch_size=1,
    lr=5e-4,
    max_grad_norm=1.0,
    device=device,
)

# --------------------------------------------------
# 2. Null-text inversion of image A (prompt describes A as-is)
# --------------------------------------------------
(image_gt, image_rec), x_t, uncond_embeddings = null_inversion.invert(
    image_A_path,
    prompt_A,
)  # remaining arguments use their defaults

print("Inversion done.")
print("  x_t shape:", x_t.shape)
print("  len(uncond_embeddings):", len(uncond_embeddings))

# --------------------------------------------------
# 3. P2P edit: source prompt vs. target prompt
#    - source: A (original)
#    - target: A with the source word replaced by the placeholder token
# --------------------------------------------------
prompt_B = prompt_A.replace(source_word, placeholder_token)

prompts = [prompt_A, prompt_B]
print("Source prompt :", prompts[0])
print("Target prompt :", prompts[1])

# P2P settings (close to the notebook's earlier examples).
cross_replace_steps = {"default_": 0.8}
self_replace_steps  = 0.5

# Local blend: ONE tuple of words per prompt, i.e. (source_words, target_words),
# the same layout as the earlier local-edit examples in this notebook.
# The previous value had a trailing comma that wrapped this pair in an extra
# tuple, so the blend words no longer lined up with the prompts.
blend_word = ((source_word,), (placeholder_token,))

# No attention re-weighting in this experiment.
eq_params = None

controller = make_controller(
    prompts,
    True,                  # is_replace_controller -> word-swap (replace) edit
    cross_replace_steps,
    self_replace_steps,
    blend_word,
    eq_params,
)

# --------------------------------------------------
# 4. Sample the edit from the NTI result (x_t, uncond_embeddings)
# --------------------------------------------------
images_edit, _ = text2image_ldm_stable(
    ldm_stable,
    prompts,
    controller,
    num_inference_steps=NUM_DDIM_STEPS,
    guidance_scale=GUIDANCE_SCALE,
    generator=None,
    latent=x_t,                           # final latent from NTI
    uncond_embeddings=uncond_embeddings,  # optimised unconditional embeddings from NTI
    start_time=NUM_DDIM_STEPS,
    return_type="image",
)

# --------------------------------------------------
# 5. Show results (left: GT, middle: reconstruction, right: edit)
# --------------------------------------------------
# One image is produced per prompt: index 0 belongs to the source prompt and
# index 1 to the target prompt, so the edited image is images_edit[1].
ptp_utils.view_images([image_gt, image_rec, images_edit[1]])


# 실험 2: masking 학습

디퓨저 사용으로 수정..?

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
# [핵심] 구버전 diffusers에서 Attention 모듈 클래스 직접 가져오기
from diffusers.models.attention import CrossAttention

# ------------------------------------------------------------------------------
# 1. [구버전 호환] 순정 상태 복구 기능을 포함한 마스크 생성 함수
# ------------------------------------------------------------------------------
def get_automask_from_attention(model, tokenizer, latents, prompt, target_word_index, resolution=16,
                                threshold=0.3, timestep=400):
    """Build a soft spatial mask from the cross-attention of one prompt token.

    The latents are noised to ``timestep``, a single U-Net forward pass is run
    with an ``AttentionStore`` registered, and the attention map of the token
    at ``target_word_index`` is upsampled to 64x64, min-max normalised and
    thresholded.

    Args:
        model: Pipeline exposing ``unet``, ``text_encoder``, ``scheduler`` and ``device``.
        tokenizer: Tokenizer used to encode ``prompt``.
        latents: Clean latents to be noised.
        prompt: Text prompt describing the image.
        target_word_index: Position of the token along the text axis of the
            attention map. This is a *token* position in the padded sequence,
            so any special start token the tokenizer prepends counts as position 0.
        resolution: Preferred attention-map side length.
        threshold: Normalised values below this are set to 0.
        timestep: Diffusion timestep used for the single forward pass.

    Returns:
        A detached mask tensor of shape (1, 1, 64, 64).

    Side effects:
        The ``forward`` of every ``CrossAttention`` module in ``model.unet``
        is reset to the class implementation afterwards, even on error.
    """
    # 1. Register the P2P controller (this swaps the attention forward methods).
    controller = AttentionStore()
    register_attention_control(model, controller)

    try:
        # 2. One forward pass to populate the attention store. Nothing here
        #    needs gradients, including the text encoding.
        t = torch.tensor([timestep]).to(model.device)
        text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")

        with torch.no_grad():
            noisy_latents = model.scheduler.add_noise(latents, torch.randn_like(latents), t)
            text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
            model.unet(noisy_latents, t, encoder_hidden_states=text_embeddings)

        # 3. Pick the attention map.
        attention_maps = controller.get_average_attention()
        target_map = None

        # (a) Store keyed by resolution (the layout this code originally assumed).
        if resolution in attention_maps:
            target_map = attention_maps[resolution][0].mean(0)[:, target_word_index]

        # (b) NOTE(review): the store may instead be keyed by strings such as
        #     "down_cross" mapping to lists of per-layer tensors; confirm against
        #     AttentionStore. In that case average every cross-attention map
        #     whose query count equals resolution**2.
        if target_map is None:
            matching = [
                item
                for key, items in attention_maps.items()
                if isinstance(key, str) and "cross" in key
                for item in items
                if torch.is_tensor(item) and item.dim() == 3 and item.shape[1] == resolution ** 2
            ]
            if matching:
                target_map = torch.cat(matching, dim=0).mean(0)[:, target_word_index]

        # (c) Original fallback: first stored map, whatever its resolution.
        if target_map is None:
            target_map = list(attention_maps.values())[0][0].mean(0)[:, target_word_index]

        res_int = int(np.sqrt(target_map.shape[0]))
        target_map = target_map.reshape(1, 1, res_int, res_int)
        mask = F.interpolate(target_map, size=(64, 64), mode='bilinear')
        # Min-max normalise; the clamp avoids 0/0 when the map is constant.
        mask = (mask - mask.min()) / (mask.max() - mask.min()).clamp_min(1e-8)
        mask[mask < threshold] = 0   # hard threshold

        controller.reset()
    finally:
        # 4. Undo the monkey patch: re-bind the class-level forward on every
        #    CrossAttention module so later training sees the stock U-Net.
        print("  Restoring U-Net to original state for training...")
        for name, module in model.unet.named_modules():
            if isinstance(module, CrossAttention):
                module.forward = CrossAttention.forward.__get__(module, CrossAttention)

        print("  -> Done.")

    return mask.detach()

# ------------------------------------------------------------------------------
# 2. [유지] 최적화 함수 (이전과 동일)
# ------------------------------------------------------------------------------
def optimize_embedding_with_automask(
    model, tokenizer, latents_B,
    placeholder_token,
    initializer_token,
    prompt_template="a photo of a {}",
    target_word_index_in_template=4,
    num_train_steps=300
):
    """Textual Inversion restricted to an attention-derived spatial mask.

    A mask is built from the cross-attention of ``initializer_token`` in
    ``prompt_template``; the placeholder embedding is then optimised with a
    noise-prediction MSE evaluated only where the mask is non-zero.

    Args:
        model: Pipeline exposing ``unet``, ``text_encoder``, ``scheduler`` and ``device``.
        tokenizer: Tokenizer paired with ``model.text_encoder``.
        latents_B: Latents of the identity image(s).
        placeholder_token: Token to add and train.
        initializer_token: Word used to seed the embedding and to locate the object.
        prompt_template: Format string with one ``{}`` slot.
        target_word_index_in_template: Position of the object word. If the
            token at this position of the tokenized prompt is not the
            initializer token (for example because a special start token
            shifts everything by one), the nearest position that does hold
            the initializer token is used instead and a note is printed.
        num_train_steps: Number of optimisation steps.

    Returns:
        ``placeholder_token``.
    """
    print(f"Initializing '{placeholder_token}' with '{initializer_token}'...")

    # 1. Add the token.
    tokenizer.add_tokens(placeholder_token)
    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

    # 2. Resize the embedding matrix and seed the new row.
    text_encoder = model.text_encoder
    text_encoder.resize_token_embeddings(len(tokenizer))
    embedding_weight = text_encoder.get_input_embeddings().weight
    init_token_id = tokenizer.encode(initializer_token, add_special_tokens=False)[0]
    with torch.no_grad():
        embedding_weight[placeholder_token_id] = embedding_weight[init_token_id].clone()
        # Kept for the before/after comparison printed at the end.
        initial_embedding = embedding_weight[placeholder_token_id].clone()

    # 3. Build the mask from the initializer word's attention.
    print("Generating Auto-Mask...")
    prompt_for_mask = prompt_template.format(initializer_token)

    # Check that the requested position really holds the initializer token.
    mask_ids = tokenizer(
        prompt_for_mask, padding="max_length", max_length=tokenizer.model_max_length,
        truncation=True, return_tensors="pt",
    ).input_ids[0].tolist()
    token_index = target_word_index_in_template
    in_range = 0 <= token_index < len(mask_ids)
    if not in_range or mask_ids[token_index] != init_token_id:
        hits = [pos for pos, tok in enumerate(mask_ids) if tok == init_token_id]
        if hits:
            token_index = min(hits, key=lambda pos: abs(pos - target_word_index_in_template))
            print(f"  NOTE: position {target_word_index_in_template} is not '{initializer_token}' "
                  f"in the tokenized prompt; using token position {token_index} instead.")

    mask = get_automask_from_attention(
        model, tokenizer, latents_B,
        prompt=prompt_for_mask,
        target_word_index=token_index,
        resolution=16
    ).detach().clone()

    # Idempotent: get_automask_from_attention already zeroes values below 0.3.
    mask[mask < 0.3] = 0

    train_prompt = prompt_template.format(placeholder_token)
    text_input = tokenizer(train_prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    input_ids = text_input.input_ids.to(model.device)

    latents_B = latents_B.detach()

    # Normaliser for the masked mean: total mask weight over all elements the
    # loss is evaluated on (clamped so an all-zero mask yields loss 0, not NaN).
    mask_total = mask.expand_as(latents_B).sum().clamp_min(1e-8)

    # 4. Training state. Remember the current flags so they can be restored.
    was_training = text_encoder.training
    saved_flags = [
        (p, p.requires_grad)
        for module in (model.unet, text_encoder)
        for p in module.parameters()
    ]
    # Freeze everything: without this, backward() also fills .grad for every
    # U-Net / text-encoder weight although only the embedding is optimised.
    for p, _ in saved_flags:
        p.requires_grad_(False)

    text_encoder.train()
    embedding_weight.requires_grad_(True)
    optimizer = torch.optim.Adam([embedding_weight], lr=1e-3)

    print(f"Start Optimization (Steps: {num_train_steps})...")
    try:
        for step in range(num_train_steps):
            noise = torch.randn_like(latents_B)
            bsz = latents_B.shape[0]
            timesteps = torch.randint(0, model.scheduler.num_train_timesteps, (bsz,), device=latents_B.device).long()
            noisy_latents = model.scheduler.add_noise(latents_B, noise, timesteps)

            encoder_hidden_states = text_encoder(input_ids)[0]
            noise_pred = model.unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Masked MSE, averaged over the masked region instead of over every
            # pixel (a plain mean shrinks the loss by the mask's coverage).
            loss_pixel = F.mse_loss(noise_pred, noise, reduction="none")
            loss = (loss_pixel * mask).sum() / mask_total

            loss.backward()

            # Keep only the placeholder row's gradient so that no other token
            # embedding is modified.
            grads = embedding_weight.grad
            if grads is not None:
                target_grad = grads[placeholder_token_id, :].clone()
                grads.data.zero_()
                grads.data[placeholder_token_id, :] = target_grad

            optimizer.step()
            optimizer.zero_grad()

            if step % 50 == 0:
                print(f"  Step {step}: Loss {loss.item():.4f}")
    finally:
        # Hand the model back in the state it was received in.
        text_encoder.train(was_training)
        for p, flag in saved_flags:
            p.requires_grad_(flag)

    # Report how far the embedding moved from its initial value.
    final_embedding = text_encoder.get_input_embeddings().weight.data[placeholder_token_id]
    delta = (final_embedding - initial_embedding).norm().item()
    print(f"Optimization Done. Embedding Change (L2 Norm): {delta:.6f}")

    if delta < 1e-4:
        print("WARNING: 임베딩이 거의 변하지 않았습니다! 학습률(LR)을 높이거나 코드를 점검하세요.")
    else:
        print("SUCCESS: 임베딩이 유의미하게 업데이트되었습니다.")

    return placeholder_token

In [None]:
import torch
from glob import glob

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# --------------------------------------------------
# 0. User settings
# --------------------------------------------------
image_A_path = "./gnochi_mirror.jpeg"
prompt_A     = "a cat sitting next to a mirror"
b_image_glob = "./cat.jpeg"

placeholder_token = "<sks-cat>"
initializer_token = "cat"
source_word = "cat"

# --------------------------------------------------
# 1. Textual Inversion on the B image (masked version)
# --------------------------------------------------
b_image_paths = sorted(glob(b_image_glob))
print(f"#B images = {len(b_image_paths)}")

if len(b_image_paths) == 0:
    print("Warning: B 이미지가 없습니다.")
else:
    # Use a single image only: even if several match, only the first one is
    # used to extract the identity (avoids batch-shape issues with the mask).
    b_image_paths = b_image_paths[:1]
    print(f"Using single image for optimization: {b_image_paths}")

    raw_output = prepare_latents_for_textual_inversion(
        null_inversion,
        b_image_paths,
        device=device,
    )

    # Defensive: accept a (images, latents)-style sequence as well as a bare tensor.
    if isinstance(raw_output, (tuple, list)):
        latents_B = raw_output[1]
    else:
        latents_B = raw_output

    latents_B = latents_B.to(ldm_stable.device)
    print(f"Latents loaded. Shape: {latents_B.shape}")

    # Masked optimisation instead of train_textual_inversion_for_identity.
    optimized_token = optimize_embedding_with_automask(
        model=ldm_stable,
        tokenizer=tokenizer,
        latents_B=latents_B,
        placeholder_token=placeholder_token,
        initializer_token=initializer_token,
        prompt_template="a photo of a {}",
        # Token position of the object word in the tokenized prompt. The CLIP
        # tokenizer prepends a start token, so: 0=<start>, 1=a, 2=photo,
        # 3=of, 4=a, 5=cat. (4 selected the article "a".)
        target_word_index_in_template=5,
        num_train_steps=500,
    )

    print(f"Optimization Finished. Token: {optimized_token}")


# --------------------------------------------------
# 2. Null-text inversion of image A
# --------------------------------------------------
(image_gt, image_rec), x_t, uncond_embeddings = null_inversion.invert(
    image_A_path,
    prompt_A,
)
print("Inversion done.")

# --------------------------------------------------
# 3. P2P edit settings
# --------------------------------------------------
prompt_B = prompt_A.replace(source_word, placeholder_token)
prompts = [prompt_A, prompt_B]

print("Source prompt :", prompts[0])
print("Target prompt :", prompts[1])

cross_replace_steps = {"default_": 0.8}
self_replace_steps  = 0.6  # too low breaks the structure; tune as needed

# Local blend: ONE tuple of words per prompt, i.e. (source_words, target_words),
# the same layout as the earlier local-edit examples in this notebook. The
# previous value carried a trailing comma that added an extra nesting level.
blend_word = ((source_word,), (placeholder_token,))

controller = make_controller(
    prompts,
    True,   # word-swap (replace) edit
    cross_replace_steps,
    self_replace_steps,
    blend_word,
    None,
)

# --------------------------------------------------
# 4. Image generation
# --------------------------------------------------
images_edit, _ = text2image_ldm_stable(
    ldm_stable,
    prompts,
    controller,
    num_inference_steps=NUM_DDIM_STEPS,
    guidance_scale=GUIDANCE_SCALE,
    generator=None,
    latent=x_t,
    uncond_embeddings=uncond_embeddings,
    start_time=NUM_DDIM_STEPS,
    return_type="image",
)

# --------------------------------------------------
# 5. Show results (left: GT, middle: reconstruction, right: edit)
# --------------------------------------------------
# One image per prompt: index 0 is the source prompt, index 1 the target
# prompt, so the edited image is images_edit[1].
ptp_utils.view_images([image_gt, image_rec, images_edit[1]])

어째서인지 masking 없을때와의 차이가 없음...

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# ======================================================
# [Verification] How well does the learned embedding reproduce image B?
# ======================================================

# 1. Verification prompt.
# NOTE(review): training used the template "a photo of a {}" while this
# prompt omits the article; confirm the mismatch is intended.
test_prompt = f"a photo of {placeholder_token}"
print(f"Testing Prompt: {test_prompt}")

# 2. Reference image B (for display only).
if len(b_image_paths) > 0:
    img_b_origin = Image.open(b_image_paths[0]).convert("RGB").resize((512, 512))
else:
    print("Warning: B 이미지가 없습니다.")
    img_b_origin = Image.new("RGB", (512, 512), (255, 255, 255))  # blank placeholder

# 3. Plain text-to-image generation without any P2P editing.
num_ver_images = 4  # number of seeds to sample
ver_images = []
print("Generating verification images...")

for i in range(num_ver_images):
    g_cpu = torch.Generator().manual_seed(8888 + i)

    # AttentionStore defined in this notebook (not the one in ptp_utils).
    dummy_controller = AttentionStore()

    image, _ = text2image_ldm_stable(
        ldm_stable,
        [test_prompt],
        dummy_controller,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=g_cpu,
        latent=None,  # start from random noise
        uncond_embeddings=None,
        start_time=50,
        return_type="image"
    )

    # image[0] is assumed to be a (512, 512, 3) numpy array.
    ver_images.append(image[0])

    # Free the stored attention maps.
    dummy_controller.reset()

# 4. Plot: reference image followed by the generated samples.
n_cols = len(ver_images) + 1
plt.figure(figsize=(20, 5))

# (1) Reference image B
plt.subplot(1, n_cols, 1)
plt.imshow(img_b_origin)
plt.title("Reference (Input)", fontsize=12)
plt.axis('off')

# (2) Generated images; the title shows the token that was actually used.
for i, img in enumerate(ver_images):
    plt.subplot(1, n_cols, i + 2)
    plt.imshow(img)
    plt.title(f"Gen {i+1} ({placeholder_token})", fontsize=12)
    plt.axis('off')

plt.show()

mask 부터 다시 잘 찾아보자..

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from diffusers.models.attention import CrossAttention

# ------------------------------------------------------------------------------
# 시각화 전용 함수: Raw Attention Map을 뽑아냅니다.
# ------------------------------------------------------------------------------
def visualize_attention_thresholds(model, tokenizer, image_path, prompt, target_word_index, thresholds=(0.2, 0.4, 0.6, 0.8)):
    """Plot the cross-attention of one prompt token on an image, at several thresholds.

    The image is encoded to a latent, noised to timestep 400, and passed once
    through the U-Net with an ``AttentionStore`` registered. The attention map
    of the token at ``target_word_index`` is upsampled to 512x512, min-max
    normalised and shown raw, overlaid, and binarised at up to three thresholds.

    Args:
        model: Pipeline exposing ``vae``, ``unet``, ``text_encoder``, ``scheduler`` and ``device``.
        tokenizer: Tokenizer used to encode ``prompt``.
        image_path: Path of the image to inspect.
        prompt: Text prompt describing the image.
        target_word_index: *Token* position along the text axis of the
            attention map; any special start token counts as position 0.
        thresholds: Cut-offs to visualise; only the first three are drawn.

    Side effects:
        Shows a matplotlib figure and resets the ``forward`` of every
        ``CrossAttention`` module in ``model.unet`` to the class implementation.
    """
    # 1. Load and preprocess the image to [-1, 1], channels first.
    image_pil = Image.open(image_path).convert("RGB").resize((512, 512))
    image_np = np.array(image_pil).astype(np.float32) / 127.5 - 1.0
    image_tensor = torch.from_numpy(image_np.transpose(2, 0, 1)).unsqueeze(0).to(model.device)

    with torch.no_grad():
        latents = model.vae.encode(image_tensor).latent_dist.sample() * 0.18215

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    # Decode the token that is actually being visualised so that an
    # off-by-one index is visible in the plot title.
    token_text = tokenizer.decode([int(text_input.input_ids[0][target_word_index])]).strip()

    # 2. Register the attention store (this patches the attention forward methods).
    controller = AttentionStore()
    register_attention_control(model, controller)

    try:
        # 3. One forward pass on the noised latent; no gradients are needed.
        t = torch.tensor([400]).to(model.device)
        with torch.no_grad():
            noisy_latents = model.scheduler.add_noise(latents, torch.randn_like(latents), t)
            text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
            model.unet(noisy_latents, t, encoder_hidden_states=text_embeddings)

        # 4. Pick the attention map (16x16 preferred).
        attention_maps = controller.get_average_attention()
        target_map = None
        resolution = 16

        # (a) Store keyed by resolution (the layout this code originally assumed).
        if resolution in attention_maps:
            target_map = attention_maps[resolution][0].mean(0)[:, target_word_index]

        # (b) NOTE(review): the store may instead be keyed by strings such as
        #     "down_cross" mapping to lists of per-layer tensors; confirm
        #     against AttentionStore. In that case average every
        #     cross-attention map with resolution**2 queries.
        if target_map is None:
            matching = [
                item
                for key, items in attention_maps.items()
                if isinstance(key, str) and "cross" in key
                for item in items
                if torch.is_tensor(item) and item.dim() == 3 and item.shape[1] == resolution ** 2
            ]
            if matching:
                target_map = torch.cat(matching, dim=0).mean(0)[:, target_word_index]

        # (c) Original fallback: first stored map, whatever its resolution.
        if target_map is None:
            target_map = list(attention_maps.values())[0][0].mean(0)[:, target_word_index]

        # 5. Upscale to 512x512 for display.
        res_int = int(np.sqrt(target_map.shape[0]))
        target_map = target_map.reshape(1, 1, res_int, res_int)

        # Raw statistics (before normalisation).
        raw_map_flat = target_map.flatten()
        print(f"Attention Value Stats | Min: {raw_map_flat.min():.4f}, Max: {raw_map_flat.max():.4f}, Mean: {raw_map_flat.mean():.4f}")

        attn_map_highres = F.interpolate(target_map, size=(512, 512), mode='bilinear')
        # Min-max normalise; the clamp avoids 0/0 for a constant map.
        value_range = (attn_map_highres.max() - attn_map_highres.min()).clamp_min(1e-8)
        attn_map_norm = (attn_map_highres - attn_map_highres.min()) / value_range
        attn_map_norm = attn_map_norm.squeeze().cpu().numpy()

        controller.reset()
    finally:
        # 6. Cleanup: restore the stock attention forward even if something failed.
        for name, module in model.unet.named_modules():
            if isinstance(module, CrossAttention):
                module.forward = CrossAttention.forward.__get__(module, CrossAttention)

    # --------------------------------------------------------------------------
    # 7. Plot
    # --------------------------------------------------------------------------
    plt.figure(figsize=(20, 10))

    # (1) Original image
    plt.subplot(2, 3, 1)
    plt.imshow(image_pil)
    plt.title("Original Image")
    plt.axis('off')

    # (2) Normalised attention heat map
    plt.subplot(2, 3, 2)
    plt.imshow(attn_map_norm, cmap='jet')
    plt.title(f"Raw Attention Map ('{token_text}')")
    plt.colorbar()
    plt.axis('off')

    # (3) Heat map over the image
    plt.subplot(2, 3, 3)
    plt.imshow(image_pil)
    plt.imshow(attn_map_norm, cmap='jet', alpha=0.5)
    plt.title("Overlay")
    plt.axis('off')

    # (4) Binarised masks; the bottom row of the 2x3 grid holds three panels.
    image_unit = np.array(image_pil).astype(np.float32) / 255.0
    for i, thresh in enumerate(list(thresholds)[:3]):
        mask_binary = (attn_map_norm > thresh).astype(np.float32)
        masked_img = image_unit * mask_binary[:, :, None]  # black out the background

        plt.subplot(2, 3, 4 + i)
        plt.imshow(masked_img)
        plt.title(f"Threshold > {thresh}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# ==============================================================================
# Run: check where 'cat' is attended to on image B
# ==============================================================================
target_image_path = b_image_paths[0]  # path of image B
check_prompt = "a photo of a cat"     # generic prompt describing image B

visualize_attention_thresholds(
    model=ldm_stable,
    tokenizer=tokenizer,
    image_path=target_image_path,
    prompt=check_prompt,
    # Token positions include the start token the CLIP tokenizer prepends:
    # 0=<start>, 1=a, 2=photo, 3=of, 4=a, 5=cat. Index 4 was the article "a".
    target_word_index=5,
    thresholds=[0.3, 0.5, 0.7]  # thresholds to inspect
)

In [None]:
import torch
from glob import glob

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# --------------------------------------------------
# 0. User settings
# --------------------------------------------------
image_A_path = "./gnochi_mirror.jpeg"
prompt_A     = "a cat sitting next to a mirror"
b_image_glob = "./cat.jpeg"  # a photo with clear subject/background contrast is recommended

placeholder_token = "<sks-cat>"
initializer_token = "cat"
source_word = "cat"

# --------------------------------------------------
# 1. Textual Inversion on the B image (masked version)
# --------------------------------------------------
b_image_paths = sorted(glob(b_image_glob))
print(f"#B images = {len(b_image_paths)}")

if len(b_image_paths) == 0:
    print("Warning: B 이미지가 없습니다.")
else:
    # Use a single image only.
    b_image_paths = b_image_paths[:1]
    print(f"Using single image for optimization: {b_image_paths}")

    # Encode to latents.
    raw_output = prepare_latents_for_textual_inversion(
        null_inversion,
        b_image_paths,
        device=device,
    )

    # Defensive: accept a (images, latents)-style sequence as well as a bare tensor.
    if isinstance(raw_output, (tuple, list)):
        latents_B = raw_output[1]
    else:
        latents_B = raw_output

    latents_B = latents_B.to(ldm_stable.device)
    print(f"Latents loaded. Shape: {latents_B.shape}")

    # Masked optimisation (optimize_embedding_with_automask must already be defined).
    optimized_token = optimize_embedding_with_automask(
        model=ldm_stable,
        tokenizer=tokenizer,
        latents_B=latents_B,
        placeholder_token=placeholder_token,
        initializer_token=initializer_token,
        prompt_template="a photo of a {}",
        # Token position of the object word in the tokenized prompt. The CLIP
        # tokenizer prepends a start token, so: 0=<start>, 1=a, 2=photo,
        # 3=of, 4=a, 5=cat. (4 selected the article "a".)
        target_word_index_in_template=5,
        num_train_steps=500,
    )

    print(f"Optimization Finished. Token: {optimized_token}")


# --------------------------------------------------
# 2. Null-text inversion of image A
# --------------------------------------------------
(image_gt, image_rec), x_t, uncond_embeddings = null_inversion.invert(
    image_A_path,
    prompt_A,
)
print("Inversion done.")

# --------------------------------------------------
# 3. P2P edit settings
# --------------------------------------------------
prompt_B = prompt_A.replace(source_word, placeholder_token)
prompts = [prompt_A, prompt_B]

print("Source prompt :", prompts[0])
print("Target prompt :", prompts[1])

# Structure-preservation ratio lowered from 0.8 to 0.4: only the first 40% of
# the steps follow the source attention; the rest is driven by the new embedding.
cross_replace_steps = {"default_": 0.4}
self_replace_steps  = 0.4

# Local blend: ONE tuple of words per prompt, i.e. (source_words, target_words),
# the same layout as the earlier local-edit examples in this notebook. The
# previous value carried a trailing comma that added an extra nesting level.
blend_word = ((source_word,), (placeholder_token,))

controller = make_controller(
    prompts,
    True,   # word-swap (replace) edit
    cross_replace_steps,
    self_replace_steps,
    blend_word,
    None,
)

# --------------------------------------------------
# 4. Image generation
# --------------------------------------------------
# The NTI-optimised unconditional embeddings are deliberately NOT used in this
# experiment (uncond_embeddings=None). Pass uncond_embeddings=uncond_embeddings
# to switch null-text inversion back on.
images_edit, _ = text2image_ldm_stable(
    ldm_stable,
    prompts,
    controller,
    num_inference_steps=NUM_DDIM_STEPS,
    guidance_scale=GUIDANCE_SCALE,
    generator=None,
    latent=x_t,
    uncond_embeddings=None,  # NTI disabled for this run
    start_time=NUM_DDIM_STEPS,
    return_type="image",
)

# --------------------------------------------------
# 5. Show results (left: GT, middle: reconstruction, right: edit)
# --------------------------------------------------
# One image per prompt: index 0 is the source prompt, index 1 the target
# prompt, so the edited image is images_edit[1].
ptp_utils.view_images([image_gt, image_rec, images_edit[1]])

검정 고양이 어디감..

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# ------------------------------------------------------------------------------
# [New] Slerp 함수 정의: 노이즈의 분산(에너지)을 유지하며 섞어줍니다.
# ------------------------------------------------------------------------------
def slerp(val, low, high):
    """
    Spherical linear interpolation between two batches of vectors.

    The angle between `low` and `high` is measured along dim=1, so every
    index outside dim=1 is interpolated independently.

    Args:
        val: interpolation factor (0.0 = low, 1.0 = high).
        low: starting tensor.
        high: target tensor, same shape as `low`.

    Returns:
        Tensor with the same shape as `low`.

    Entries whose vectors are (anti-)parallel, i.e. sin(angle) ~ 0, fall back
    to plain linear interpolation. The fallback is decided per entry: testing
    the batch-wide sum lets one degenerate entry in a mixed batch divide by
    zero and produce NaN.
    """
    eps = 1e-6

    # Clamp the norms so an all-zero vector cannot produce NaN directions.
    low_norm = low / torch.norm(low, dim=1, keepdim=True).clamp_min(eps)
    high_norm = high / torch.norm(high, dim=1, keepdim=True).clamp_min(eps)

    # Angle between the two directions (clamped for numerical stability).
    dot = (low_norm * high_norm).sum(1)
    dot = torch.clamp(dot, -1, 1)
    omega = torch.acos(dot).unsqueeze(1)
    so = torch.sin(omega)

    # Per-entry handling of the degenerate (parallel) case.
    degenerate = so.abs() < eps
    safe_so = torch.where(degenerate, torch.ones_like(so), so)

    lerp = (1.0 - val) * low + val * high
    res = (torch.sin((1.0 - val) * omega) / safe_so) * low \
        + (torch.sin(val * omega) / safe_so) * high
    return torch.where(degenerate, lerp, res)

# ------------------------------------------------------------------------------
# 1. Safety check: make sure x_t exists, then normalise its type and device
# ------------------------------------------------------------------------------
if 'x_t' not in locals() or x_t is None:
    print("Error: 'x_t' 변수가 없습니다. 위쪽의 [2. Null-text Inversion] 셀을 먼저 실행해주세요!")
else:
    print("x_t found. Checking type...")

    # If x_t is a list, take its first element.
    if isinstance(x_t, list):
        start_noise = x_t[0]
    else:
        start_noise = x_t

    # Move the latent to the same device as the model.
    device = ldm_stable.device
    start_noise = start_noise.to(device)

    # ------------------------------------------------------------------------------
    # 2. Parameters (Soft NTI + relaxed P2P)
    # ------------------------------------------------------------------------------
    # [Modified] Cross-replace lowered to 0.4 to leave room for the new identity.
    cross_replace_steps = {"default_": 0.4}

    # Relaxed texture (colour) constraint.
    self_replace_steps  = 0.2

    # Background preservation strength (0.6 ~ 0.7 suggested by the author).
    nti_strength = 0.6

    print(f"Applying Soft NTI (Strength: {nti_strength}) using Slerp")
    print(f"Relaxing P2P Constraints (Self-Replace: {self_replace_steps})")

    # ------------------------------------------------------------------------------
    # 3. Build the Soft-NTI latent with slerp
    # ------------------------------------------------------------------------------
    random_latents = torch.randn_like(start_noise)

    # [Modified] Slerp instead of a plain linear mix.
    # The closer nti_strength is to 1, the more of start_noise is kept, so the
    # fraction moved towards random_latents is (1 - nti_strength).
    start_latents = slerp(1.0 - nti_strength, start_noise, random_latents)

    # ------------------------------------------------------------------------------
    # 4. P2P controller
    # ------------------------------------------------------------------------------
    prompt_B = prompt_A.replace(source_word, placeholder_token)
    prompts = [prompt_A, prompt_B]

    # NOTE: blend_word is not defined in this cell; it must already exist in
    # the notebook namespace from an earlier cell.
    controller = make_controller(
        prompts,
        True,
        cross_replace_steps,
        self_replace_steps,
        blend_word,
        None,
    )

    # ------------------------------------------------------------------------------
    # 5. Image generation
    # ------------------------------------------------------------------------------
    print("Generating image...")

    try:
        # Caution (author's note): if the NTI uncond_embeddings are too strong
        # (overfitted) they can block the identity change. If nothing changes,
        # consider ldm_stable.get_learned_conditioning([""]) instead.
        images_edit, _ = text2image_ldm_stable(
            ldm_stable,
            prompts,
            controller,
            num_inference_steps=NUM_DDIM_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            generator=None,
            latent=start_latents,
            uncond_embeddings=uncond_embeddings,
            start_time=NUM_DDIM_STEPS,
            return_type="image",
        )
        print("Generation finished.")

    except Exception as e:
        # Best-effort cell: report the failure and continue with no images.
        print(f"Generation Failed with Error: {e}")
        import traceback
        traceback.print_exc()
        images_edit = []

    # ------------------------------------------------------------------------------
    # 6. Visualise the result
    # ------------------------------------------------------------------------------
    if images_edit is None or len(images_edit) == 0:
        print("이미지가 생성되지 않았습니다.")
    else:
        print("Displaying results...")
        plt.figure(figsize=(15, 5))

        # (1) Original ground truth
        plt.subplot(1, 3, 1)
        if isinstance(image_gt, torch.Tensor):
            # Tensors are moved to CPU and permuted to channel-last for imshow.
            disp_img = image_gt.detach().cpu()
            if disp_img.ndim == 4: disp_img = disp_img[0]
            plt.imshow(disp_img.permute(1, 2, 0).numpy())
        else:
            plt.imshow(image_gt)
        plt.title("Original (A)")
        plt.axis('off')

        # (2) Result
        plt.subplot(1, 3, 2)
        plt.imshow(images_edit[0])
        plt.title(f"Result (<sks-cat>)\nSoft NTI: {nti_strength}")
        plt.axis('off')

        # (3) Reference B (only if reference paths were defined earlier)
        if 'b_image_paths' in locals() and len(b_image_paths) > 0:
            plt.subplot(1, 3, 3)
            try:
                ref_img = Image.open(b_image_paths[0]).resize((512,512))
                plt.imshow(ref_img)
            except Exception as e:
                # Stay best-effort, but do not swallow KeyboardInterrupt and
                # do report why the reference could not be shown.
                print(f"Reference image load failed: {e}")
            plt.title("Reference (B)")
            plt.axis('off')

        plt.show()

In [None]:
# ------------------------------------------------------------------------------
# [Experiment tool] Parameter grid search with side-by-side visualisation
# ------------------------------------------------------------------------------

# 1. List the combinations to test.
# Order: (nti_strength, cross_replace, self_replace)
experiments = [
    (0.7, 0.4, 0.2),  # current setting (baseline)
    (0.8, 0.7, 0.2),  # stronger background / composition
    (0.9, 0.4, 0.3),  # stronger overall preservation
    (0.9, 0.9, 0.3),  # strong structure, flexible texture
]

results = []

print(f"총 {len(experiments)}개의 조합을 실험합니다...")

# NOTE: start_noise, prompts, blend_word and uncond_embeddings are not defined
# in this cell; they must already exist in the notebook namespace.
for idx, (nti, cross, self_r) in enumerate(experiments):
    print(f"\n[Experiment {idx+1}/{len(experiments)}] NTI:{nti}, Cross:{cross}, Self:{self_r}")

    # 1. Soft NTI (slerp): move (1 - nti) of the way towards fresh noise.
    random_latents = torch.randn_like(start_noise)
    start_latents = slerp(1.0 - nti, start_noise, random_latents)

    # 2. P2P controller
    cross_replace_steps = {"default_": cross}
    self_replace_steps = self_r

    controller = make_controller(
        prompts, True, cross_replace_steps, self_replace_steps, blend_word, None
    )

    # 3. Generation (a failed run is recorded as a black placeholder image)
    try:
        imgs, _ = text2image_ldm_stable(
            ldm_stable, prompts, controller,
            num_inference_steps=NUM_DDIM_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            generator=None,
            latent=start_latents,
            uncond_embeddings=uncond_embeddings, # swap for ldm_stable.get_learned_conditioning([""]) to test without NTI
            start_time=NUM_DDIM_STEPS,
            return_type="image"
        )
        results.append((imgs[0], f"NTI:{nti}\nCr:{cross}, Sf:{self_r}"))
    except Exception as e:
        print(f"Error: {e}")
        results.append((np.zeros((512,512,3)), "Error"))

# ------------------------------------------------------------------------------
# Visualise the results in a single row
# ------------------------------------------------------------------------------
plt.figure(figsize=(20, 6))

# (1) Original
plt.subplot(1, len(experiments)+2, 1)
if isinstance(image_gt, torch.Tensor):
    disp_img = image_gt.detach().cpu()
    if disp_img.ndim == 4: disp_img = disp_img[0]
    plt.imshow(disp_img.permute(1, 2, 0).numpy())
else:
    plt.imshow(image_gt)
plt.title("Original (A)")
plt.axis('off')

# (2) Experiment results
for i, (img, label) in enumerate(results):
    plt.subplot(1, len(experiments)+2, i+2)
    plt.imshow(img)
    plt.title(label, fontsize=10)
    plt.axis('off')

# (3) Reference B (only if reference paths were defined earlier)
if 'b_image_paths' in locals() and len(b_image_paths) > 0:
    plt.subplot(1, len(experiments)+2, len(experiments)+2)
    try:
        plt.imshow(Image.open(b_image_paths[0]).resize((512,512)))
        plt.title("Ref (B)")
    except Exception as e:
        # Stay best-effort, but report the failure instead of hiding it and
        # let KeyboardInterrupt through.
        print(f"Reference image load failed: {e}")
    plt.axis('off')

plt.tight_layout()
plt.show()

## 모듈화

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from torch.optim.adam import Adam
import matplotlib.pyplot as plt

# -------------------------------------------------------------------------
# 1. Helper Functions
# -------------------------------------------------------------------------

def slerp(val, low, high):
    """Spherical linear interpolation between two batches of vectors.

    The angle between `low` and `high` is measured along dim=1, so callers
    that want one angle per sample pass tensors flattened to 2-D.

    Args:
        val: interpolation factor (0.0 returns `low`, 1.0 returns `high`).
        low: starting tensor.
        high: target tensor, same shape as `low`.

    Returns:
        Tensor with the same shape as `low`.

    Entries whose vectors are (anti-)parallel fall back to linear
    interpolation. The check is per entry: a batch-wide sum lets a single
    degenerate row in a mixed batch divide by zero and return NaN.
    """
    eps = 1e-6

    # Clamp the norms so an all-zero vector does not turn into NaN.
    low_norm = low / torch.norm(low, dim=1, keepdim=True).clamp_min(eps)
    high_norm = high / torch.norm(high, dim=1, keepdim=True).clamp_min(eps)

    omega = torch.acos(torch.clamp((low_norm * high_norm).sum(1), -1, 1))
    so = torch.sin(omega)

    # Replace near-zero sines by 1 for the division; those entries are
    # overwritten by the linear fallback below.
    degenerate = so.abs() < eps
    safe_so = torch.where(degenerate, torch.ones_like(so), so)

    w_low = (torch.sin((1.0 - val) * omega) / safe_so).unsqueeze(1)
    w_high = (torch.sin(val * omega) / safe_so).unsqueeze(1)

    lerp = (1.0 - val) * low + val * high
    return torch.where(degenerate.unsqueeze(1), lerp, w_low * low + w_high * high)

def reset_attention_hooks(model):
    """Clear forward hooks on every UNet sub-module and restore default attention.

    Args:
        model: object exposing a `unet` attribute that is a torch module.

    The call to `unet.set_default_attn_processor()` is best-effort: it is made
    only when the method exists, and an ordinary failure is reported rather
    than raised. KeyboardInterrupt / SystemExit are no longer swallowed.
    """
    for module in model.unet.modules():
        hooks = getattr(module, "_forward_hooks", None)
        if hooks is not None:
            hooks.clear()

    restore_default = getattr(model.unet, "set_default_attn_processor", None)
    if restore_default is None:
        return
    try:
        restore_default()
    except Exception as e:
        print(f"set_default_attn_processor failed (ignored): {e}")

def get_automask_from_attention(model, tokenizer, latents, prompt, target_word_index, resolution=16):
    """Build a soft spatial mask from the cross-attention of one prompt token.

    Registers an AttentionStore on the model, runs a single UNet forward pass
    on a noised copy of `latents` (timestep 400), reads the stored attention
    for `target_word_index`, upsamples it to 64x64, min-max normalises it and
    zeroes everything below 0.3.

    Args:
        model: pipeline exposing `unet`, `text_encoder`, `scheduler`, `device`.
        tokenizer: tokenizer used to encode `prompt`.
        latents: image latents to which noise is added.
        prompt: text prompt fed to the text encoder.
        target_word_index: token position whose attention column is used.
        resolution: key looked up in the averaged attention dict.

    Returns:
        Detached mask tensor of shape (1, 1, 64, 64).
    """
    from diffusers.models.attention import CrossAttention

    # Register the P2P controller (depends on names defined elsewhere in the notebook).
    controller = AttentionStore()
    ptp_utils.register_attention_control(model, controller)

    try:
        # Forward pass. The timestep stays on CPU for the scheduler and is
        # moved to the model device for the UNet (avoids a device mismatch).
        t = torch.tensor([400]).long()
        noisy_latents = model.scheduler.add_noise(latents, torch.randn_like(latents), t.to("cpu"))

        text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                               truncation=True, return_tensors="pt")

        # Only the stored attention is needed, so no autograd graph is built
        # for either the text encoder or the UNet.
        with torch.no_grad():
            text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
            model.unet(noisy_latents, t.to(model.device), encoder_hidden_states=text_embeddings)

        # Extract the attention map.
        attention_maps = controller.get_average_attention()
        target_map = None

        # NOTE(review): `resolution` is compared against the dict keys of
        # get_average_attention(). If those keys are not integer resolutions
        # this loop never matches and the fallback below is always used;
        # confirm against AttentionStore.
        for key in attention_maps:
            if key == resolution:
                attn = attention_maps[key]
                target_map = attn[0].mean(0)[:, target_word_index]
                break

        if target_map is None:
            target_map = list(attention_maps.values())[0][0].mean(0)[:, target_word_index]

        res_int = int(np.sqrt(target_map.shape[0]))
        target_map = target_map.view(1, 1, res_int, res_int)
        mask = F.interpolate(target_map, size=(64, 64), mode='bilinear')

        # Min-max normalise. The clamp keeps a flat attention map from
        # dividing by zero and filling the mask with NaN.
        mask = mask - mask.min()
        mask = mask / mask.max().clamp_min(1e-8)
        mask[mask < 0.3] = 0  # Hard threshold

        controller.reset()
        return mask.detach()
    finally:
        # Restore the original CrossAttention.forward on every attention
        # module, even when the steps above raised.
        for module in model.unet.modules():
            if isinstance(module, CrossAttention):
                module.forward = CrossAttention.forward.__get__(module, CrossAttention)

# -------------------------------------------------------------------------
# 2. Textual Inversion Function
# -------------------------------------------------------------------------
def train_textual_inversion_with_mask(
    model, tokenizer, ref_image_path, placeholder_token, initializer_token,
    prompt_template="a photo of a {}", target_word_index=4, num_train_steps=500, lr=1e-3, device="cuda"
):
    """Learn an embedding for `placeholder_token` from one reference image.

    Adds the placeholder token to the tokenizer, initialises its embedding
    from the initializer token, builds an attention mask with
    get_automask_from_attention, then optimises only the placeholder row of
    the text-encoder input embedding with a masked noise-prediction MSE loss.

    Args:
        model: pipeline exposing `unet`, `text_encoder` and `scheduler`.
        tokenizer: tokenizer paired with `model.text_encoder`.
        ref_image_path: path of the reference image.
        placeholder_token: new token string to learn (e.g. "<sks-car>").
        initializer_token: existing word whose embedding seeds the new token.
        prompt_template: format string with one `{}` slot for the token.
        target_word_index: token position used to build the attention mask.
        num_train_steps: number of optimisation steps.
        lr: Adam learning rate.
        device: device for the latents and input ids.

    Returns:
        The tokenizer id of `placeholder_token`.
    """
    # NOTE(review): if the tokenizer prepends a begin-of-text token, index 4 of
    # "a photo of a {}" is the second "a", not the object word. Confirm the
    # intended default for target_word_index.
    print(f"\n[Step 1] Training Textual Inversion for '{placeholder_token}'")

    # Load the image and convert it to latents.
    img = load_512(ref_image_path)
    latents_B = null_inversion.image2latent(img).to(device)

    # Add the token and initialise its embedding.
    tokenizer.add_tokens(placeholder_token)
    placeholder_id = tokenizer.convert_tokens_to_ids(placeholder_token)

    # Tokenise the initializer word rather than looking the raw string up in
    # the vocabulary: a raw lookup can miss the word-final entry and resolve
    # to a different or unknown token.
    init_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
    if len(init_ids) > 1:
        print(f"  Warning: initializer '{initializer_token}' maps to {len(init_ids)} tokens; "
              "using the first one.")
    init_id = init_ids[0] if init_ids else tokenizer.convert_tokens_to_ids(initializer_token)

    model.text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = model.text_encoder.get_input_embeddings().weight.data
    token_embeds[placeholder_id] = token_embeds[init_id].clone()

    # Build the mask.
    print("  Generating attention mask...")
    mask = get_automask_from_attention(
        model, tokenizer, latents_B,
        prompt=prompt_template.format(initializer_token),
        target_word_index=target_word_index
    ).detach().clone()

    # Training setup. The previous train/eval flag is restored at the end so
    # later inference does not run with the encoder left in train mode.
    was_training = model.text_encoder.training
    model.text_encoder.train()
    embedding_weight = model.text_encoder.get_input_embeddings().weight
    embedding_weight.requires_grad_(True)
    optimizer = Adam([embedding_weight], lr=lr)

    train_prompt = prompt_template.format(placeholder_token)
    text_input = tokenizer(train_prompt, padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    input_ids = text_input.input_ids.to(device)

    # Training loop.
    print(f"  Optimizing ({num_train_steps} Steps)...")
    latents_B = latents_B.detach()

    try:
        for step in range(num_train_steps):
            noise = torch.randn_like(latents_B)
            bsz = latents_B.shape[0]
            # Timesteps stay on CPU for the scheduler and are moved to
            # `device` for the UNet (avoids a device mismatch).
            timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (bsz,), device="cpu").long()
            noisy_latents = model.scheduler.add_noise(latents_B, noise, timesteps)

            encoder_hidden_states = model.text_encoder(input_ids)[0]
            noise_pred = model.unet(noisy_latents, timesteps.to(device), encoder_hidden_states).sample

            loss_pixel = F.mse_loss(noise_pred, noise, reduction="none")
            loss = (loss_pixel * mask).mean()

            # Ask autograd for the embedding gradient only. loss.backward()
            # would also accumulate .grad on every other trainable parameter,
            # and nothing here ever clears those.
            full_grad = torch.autograd.grad(loss, embedding_weight)[0]

            # Keep the gradient of the placeholder row and zero all others so
            # the existing vocabulary is left untouched.
            row_grad = torch.zeros_like(full_grad)
            row_grad[placeholder_id, :] = full_grad[placeholder_id, :]
            embedding_weight.grad = row_grad

            optimizer.step()
            optimizer.zero_grad()

            if step % 100 == 0:
                print(f"    Step {step}: Loss {loss.item():.4f}")
    finally:
        model.text_encoder.train(was_training)

    print(f"Embedding optimized successfully!")
    return placeholder_id

In [None]:
# =============================================================================
# [Config] Image paths and prompt
# =============================================================================
source_image_path = "./car1.jpg"      # image A (structure)
reference_image_path = "./car2.jpg"    # image B (style / identity)
source_prompt = "car standing on a beach"
source_word = "car"
placeholder_token = "<sks-car>"

# =============================================================================
# [Step 1] Textual Inversion (learn the identity)
# =============================================================================
# Clear any hooks left on the UNet before training.
reset_attention_hooks(ldm_stable)

# The returned placeholder id is not used here.
train_textual_inversion_with_mask(
    model=ldm_stable,
    tokenizer=tokenizer,
    ref_image_path=reference_image_path,
    placeholder_token=placeholder_token,
    initializer_token=source_word,
    num_train_steps=500,  # train long enough
    lr=1e-3,
    device="cuda"
)

# =============================================================================
# [Step 2] Null-text Inversion (extract the source structure)
# =============================================================================
print(f"\n[Step 2] Running Null-text Inversion on Source...")

# x_t and uncond_embeddings stay in the notebook namespace and are reused by
# the synthesis cells below.
(image_gt, image_rec), x_t, uncond_embeddings = null_inversion.invert(
    source_image_path,
    source_prompt,
    offsets=(0, 0, 0, 0),
    verbose=True
)

print(f"Inversion Completed.")
print(f"   x_t shape: {x_t.shape}")

# Show the original next to its reconstruction.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1); plt.imshow(image_gt); plt.title("Original"); plt.axis('off')
plt.subplot(1, 2, 2); plt.imshow(image_rec); plt.title("Reconstruction"); plt.axis('off')
plt.show()

In [None]:
target_nti_strength = 0.3      # structure preservation strength (lower follows the reference shape more)
target_cross_replace = 0.7     # replace strength (author's note: higher gives a stronger reference style)
target_self_replace = 0.3      # self-attention replace strength

device = "cuda" if torch.cuda.is_available() else "cpu"

# =============================================================================
# 3. Synthesis (Step 3) - generate the image with the target parameters
# =============================================================================
print(f"\n--- [합성 시작] NTI:{target_nti_strength}, Cross:{target_cross_replace}, Self:{target_self_replace} ---")

try:
    torch.cuda.empty_cache()
    reset_attention_hooks(ldm_stable)

    # (1) Prompts: the target is the source with the word swapped for the placeholder token.
    target_prompt = source_prompt.replace(source_word, placeholder_token)
    prompts = [source_prompt, target_prompt]
    blend_word = (((source_word,), (placeholder_token,)),)

    # (2) Soft NTI (slerp): latents are flattened per sample for slerp and
    #     reshaped back to x_t's shape afterwards.
    random_noise = torch.randn_like(x_t)
    start_latents = slerp(1.0 - target_nti_strength, x_t.flatten(1), random_noise.flatten(1))
    start_latents = start_latents.view_as(x_t)

    # (3) Controller with the target parameters
    controller = make_controller(
        prompts,
        True, # is_replace_controller
        {"default_": target_cross_replace},
        target_self_replace,
        blend_word,
        None
    )

    # (4) Run generation
    # NOTE(review): the step count (50) and guidance scale (7.5) are hard-coded
    # here while earlier cells use NUM_DDIM_STEPS / GUIDANCE_SCALE - confirm
    # they match the values used for the inversion.
    images_edit, _ = text2image_ldm_stable(
        ldm_stable,
        prompts,
        controller,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=None,
        latent=start_latents,
        uncond_embeddings=uncond_embeddings,
        start_time=50,
        return_type="image",
    )

    result_image = images_edit[0]
    print("합성 완료.")

except Exception as e:
    print(f"합성 중 에러 발생: {e}")
    result_image = np.zeros((512, 512, 3)) # black placeholder image on error


# =============================================================================
# 4. Side-by-side comparison (source vs reference vs result)
# =============================================================================
print("\n--- 결과 시각화 ---")

# Load the source and reference images from disk.
src_img_pil = Image.open(source_image_path).resize((512, 512))
ref_img_pil = Image.open(reference_image_path).resize((512, 512))

plt.figure(figsize=(18, 6))

# [Left] Source image
plt.subplot(1, 3, 1)
plt.imshow(src_img_pil)
plt.title("Source (Car 1)\nOriginal Structure", fontsize=14, fontweight='bold')
plt.axis('off')

# [Middle] Reference image
plt.subplot(1, 3, 2)
plt.imshow(ref_img_pil)
plt.title("Reference (Car 2)\nTarget Style", fontsize=14, fontweight='bold')
plt.axis('off')

# [Right] Result image, titled with the parameters used
plt.subplot(1, 3, 3)
plt.imshow(result_image)
param_str = f"N:{target_nti_strength}, C:{target_cross_replace}, S:{target_self_replace}"
plt.title(f"Result (Synthesized)\n{param_str}", fontsize=14, fontweight='bold', color='blue')
plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# =============================================================================
# Experiment parameters (edit only this section)
# =============================================================================
# Example: 2 NTI values (structure preservation) x 5 cross values (replace
# strength) = 10 combinations.
nti_list = [0.3, 0.7]
cross_list = [0.4,0.5, 0.6, 0.7, 0.8]
fixed_self = 0.3                # self-attention replace is held fixed

# One (nti, cross, self) tuple per combination, NTI-major order.
params_to_test = [(nti, cross, fixed_self) for nti in nti_list for cross in cross_list]

print(f"Generating {len(params_to_test)} images...")

# =============================================================================
# Generation loop
# =============================================================================
results_images = []
results_labels = []

target_prompt = source_prompt.replace(source_word, placeholder_token)
prompts = [source_prompt, target_prompt]
blend_word = (((source_word,), (placeholder_token,)),)

for idx, (nti, cross, self_rep) in enumerate(params_to_test):
    print(f"[{idx+1}/{len(params_to_test)}] NTI:{nti}, Cross:{cross} ...", end="")

    try:
        # Free cached GPU memory and clear hooks before every run.
        torch.cuda.empty_cache()
        reset_attention_hooks(ldm_stable)

        # 1. Soft NTI (slerp) - reuses x_t from the inversion cell
        random_noise = torch.randn_like(x_t)
        start_latents = slerp(1.0 - nti, x_t.flatten(1), random_noise.flatten(1))
        start_latents = start_latents.view_as(x_t)

        # 2. Controller
        controller = make_controller(
            prompts,
            True,
            {"default_": cross},
            self_rep,
            blend_word,
            None
        )

        # 3. Generation
        images_edit, _ = text2image_ldm_stable(
            ldm_stable,
            prompts,
            controller,
            num_inference_steps=50,
            guidance_scale=7.5,
            generator=None,
            latent=start_latents,
            uncond_embeddings=uncond_embeddings, # reused from the inversion cell
            start_time=50,
            return_type="image",
        )

        results_images.append(images_edit[0])
        results_labels.append(f"N:{nti}, C:{cross}")
        print(" Done.")

    except Exception as e:
        # A failed run is recorded as a black placeholder so the grid stays aligned.
        print(f" Error: {e}")
        results_images.append(np.zeros((512, 512, 3)))
        results_labels.append("Error")

# =============================================================================
# Visualisation
# =============================================================================
# One row per NTI value and one column per cross value. The grid follows the
# lists above, so editing them no longer drops results silently; the default
# lists still give the original 2 x 5 layout and a (20, 8) figure.
rows, cols = len(nti_list), len(cross_list)

if results_images:
    plt.figure(figsize=(4 * cols, 4 * rows))

    for i, (img, label) in enumerate(zip(results_images, results_labels)):
        ax = plt.subplot(rows, cols, i + 1)
        ax.imshow(img)
        ax.set_title(label, fontsize=12, fontweight='bold')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## Grid search

In [None]:
# ============================================================================
# Grid Search Pipeline for Hyperparameter Tuning
# ============================================================================

def identity_swapping_grid_search(
    model,
    tokenizer,
    null_inversion,
    source_image_path: str,
    reference_image_path: str,
    source_prompt: str,
    source_word: str,
    placeholder_token: str = "<sks-cat>",
    # Grid-search parameters
    experiments: list = None,
    ti_steps: int = 500,
    ti_lr: float = 1e-3,
    num_ddim_steps: int = 50,
    guidance_scale: float = 7.5,
    device: str = "cuda"
):
    """
    Run several hyperparameter combinations in one go.

    Textual inversion (on the reference image) and null-text inversion (on the
    source image) are each performed once; every experiment then generates one
    edited image from a slerp-mixed start latent.

    Args:
        model: pipeline passed to the training and generation helpers.
        tokenizer: tokenizer paired with the model's text encoder.
        null_inversion: object whose `invert(...)` performs null-text inversion.
        source_image_path: path of image A (structure).
        reference_image_path: path of image B (identity).
        source_prompt: prompt describing image A.
        source_word: word in `source_prompt` replaced by the placeholder token.
        placeholder_token: token learned by textual inversion.
        experiments: list of (nti_strength, cross_replace, self_replace)
            tuples, e.g. [(0.7, 0.4, 0.2), (0.8, 0.7, 0.2), ...]. A built-in
            list of four combinations is used when None.
        ti_steps: textual-inversion training steps.
        ti_lr: textual-inversion learning rate.
        num_ddim_steps: number of inference steps for generation.
        guidance_scale: guidance scale for generation.
        device: device passed to textual-inversion training.

    Returns:
        results: one dict per experiment with keys 'image' and 'params'; a
            failed experiment yields a black image labelled "Error".
        image_gt: the original image returned by the inversion.
        ref_image: the reference image as a 512x512 RGB array.
    """

    # Default experiment set
    if experiments is None:
        experiments = [
            (0.7, 0.4, 0.2),  # (NTI, Cross, Self)
            (0.8, 0.7, 0.2),
            (0.9, 0.4, 0.3),
            (0.9, 0.9, 0.3),
        ]

    print("="*70)
    print("Grid Search Started")
    print("="*70)
    print(f"  Total Experiments: {len(experiments)}")
    print(f"  Source (A): {source_image_path}")
    print(f"  Reference (B): {reference_image_path}")
    print()

    # If the word is missing, the replacement below leaves the prompt
    # unchanged. Warn before the expensive training and inversion steps.
    if source_word not in source_prompt:
        print(f"Warning: source_word '{source_word}' does not occur in source_prompt; "
              "the target prompt will be identical to the source prompt.")

    reset_attention_hooks(model)

    # Everything that may touch the UNet runs inside try/finally so the final
    # hook reset happens even when training or inversion raises.
    try:
        # ===================================================================
        # Step 1: Textual Inversion (performed once)
        # ===================================================================
        print("[Step 1] Training Textual Inversion (Once)")
        train_textual_inversion_with_mask(
            model, tokenizer, reference_image_path,
            placeholder_token=placeholder_token,
            initializer_token=source_word,
            num_train_steps=ti_steps,
            lr=ti_lr,
            device=device
        )

        # ===================================================================
        # Step 2: Null-text Inversion (performed once)
        # ===================================================================
        print(f"\n[Step 2] Null-text Inversion on Source Image")
        (image_gt, image_rec), x_t, uncond_embeddings = null_inversion.invert(
            source_image_path,
            source_prompt,
            offsets=(0, 0, 0, 0),
            verbose=False
        )
        print(f"Inversion completed")

        # Load the reference image for the caller's visualisation.
        ref_image = Image.open(reference_image_path).convert("RGB").resize((512, 512))
        ref_image = np.array(ref_image)

        # ===================================================================
        # Step 3: Generate one image per hyperparameter combination
        # ===================================================================
        results = []
        target_prompt = source_prompt.replace(source_word, placeholder_token)

        print(f"\n[Step 3] Running {len(experiments)} Experiments")
        print(f"  Source Prompt: '{source_prompt}'")
        print(f"  Target Prompt: '{target_prompt}'")
        print()

        # The prompts and blend words are the same for every experiment.
        prompts = [source_prompt, target_prompt]
        blend_word = (((source_word,), (placeholder_token,)),)

        for idx, (nti_strength, cross_replace, self_replace) in enumerate(experiments, 1):
            print(f"  [{idx}/{len(experiments)}] NTI:{nti_strength}, Cross:{cross_replace}, Self:{self_replace}")

            try:
                # Soft NTI: slerp from the inverted latent towards fresh noise.
                random_noise = torch.randn_like(x_t)
                start_latents = slerp(1.0 - nti_strength, x_t.flatten(1), random_noise.flatten(1))
                start_latents = start_latents.view_as(x_t)

                # P2P controller
                controller = make_controller(
                    prompts,
                    True,
                    {"default_": cross_replace},
                    self_replace,
                    blend_word,
                    None
                )

                # Generation
                images_edit, _ = text2image_ldm_stable(
                    model,
                    prompts,
                    controller,
                    num_inference_steps=num_ddim_steps,
                    guidance_scale=guidance_scale,
                    generator=None,
                    latent=start_latents,
                    uncond_embeddings=uncond_embeddings,
                    start_time=num_ddim_steps,
                    return_type="image",
                )

                results.append({
                    'image': images_edit[0],
                    'params': f"NTI:{nti_strength}\nCr:{cross_replace}, Sf:{self_replace}"
                })

            except Exception as e:
                # Best effort: one failed combination must not abort the sweep.
                print(f"Error: {e}")
                results.append({
                    'image': np.zeros((512, 512, 3), dtype=np.uint8),
                    'params': "Error"
                })
    finally:
        reset_attention_hooks(model)

    print("\n" + "="*70)
    print("Grid Search Completed!")
    print("="*70)

    return results, image_gt, ref_image


# ============================================================================
# 결과 시각화 함수
# ============================================================================

def visualize_grid_results(results, image_gt, ref_image, figsize=(20, 6)):
    """
    Show the grid-search output as a single row of images.

    Layout: [original] [experiment 1] ... [experiment N] [reference]

    Args:
        results: list of dicts with keys 'image' and 'params' (the panel title).
        image_gt: original image, drawn in the first panel.
        ref_image: reference image, drawn in the last panel.
        figsize: size passed to plt.figure.
    """
    import matplotlib.pyplot as plt

    headline_style = {'fontsize': 10, 'fontweight': 'bold'}

    # Collect every panel as (image, caption, title kwargs) in display order.
    panels = [(image_gt, "Original (A)", headline_style)]
    panels.extend((entry['image'], entry['params'], {'fontsize': 9}) for entry in results)
    panels.append((ref_image, "Ref (B)", headline_style))

    panel_count = len(results) + 2  # original + experiments + reference

    plt.figure(figsize=figsize)

    for slot, (picture, caption, title_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, panel_count, slot)
        plt.imshow(picture)
        plt.title(caption, **title_kwargs)
        plt.axis('off')

    plt.tight_layout()
    plt.show()




In [None]:
# ============================================================================
# Example run
# ============================================================================

# Hyperparameter combinations to test: (nti_strength, cross_replace, self_replace)
experiments = [
    (0.7, 0.4, 0.2),  # current setting (baseline)
    (0.8, 0.7, 0.2),  # stronger background / composition
    (0.9, 0.4, 0.3),  # stronger overall preservation
    (0.9, 0.9, 0.3),  # strong structure, flexible texture
]

# Run the grid search: returns the per-experiment results, the original image
# and the reference image.
results, gt, ref = identity_swapping_grid_search(
    model=ldm_stable,
    tokenizer=tokenizer,
    null_inversion=null_inversion,
    source_image_path="./gnochi_mirror.jpeg",
    reference_image_path="./cat.jpg",
    source_prompt="a cat sitting next to a mirror",
    source_word="cat",
    placeholder_token="<sks-cat>",
    experiments=experiments,
    ti_steps=500,
)

# Show all results in one row.
visualize_grid_results(results, gt, ref, figsize=(20, 6))