# One-Prompt-One-Story: Free-Lunch Consistent Text-to-Image Generation Using a Single Prompt (ICLR 2025 Spotlight)

Addresses the subject-consistency problem in text-to-image (T2I) generation; it can be seen as a follow-up to SuppressEOT.

Code components in the repository:
- Consistent Image Generation Code: `main.py`
- Gradio Code: `app.py`
- Benchmark Generation Code: `resource/gen_benchmark.py`

## Experiments


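### Context Consistency in Text Embedding
- Multi-prompt generation: $[P_0;P_i],\ i\in[1,N]$
    - $P_0$: identity prompt; $P_i$: the $i$-th frame prompt
    - Frame 1: $[P_{0};P_{1}]$ → "A cute watercolor-style kitten in a garden"
    - Frame 2: $[P_{0};P_{2}]$ → "A cute watercolor-style kitten wearing a Superman cape"
    - Frame 3: $[P_{0};P_{3}]$ → "A cute watercolor-style kitten wearing a collar and bell"
    - Frame 4: $[P_{0};P_{4}]$ → "A cute watercolor-style kitten sitting in a basket"
    - Frame 5: $[P_{0};P_{5}]$ → "A cute watercolor-style kitten wearing a cute sweater"
- Single-prompt generation: $[P_0;P_1;P_2;\ldots;P_N]$
    - Single prompt $P$: "A cute watercolor-style kitten in a garden, wearing a Superman cape, wearing a collar and bell, sitting in a basket, wearing a cute sweater". The image for each frame is then generated by adjusting the weight of each frame description. For example:
    - Frame 1: increase the weight of "in a garden" to generate the image for "A cute watercolor-style kitten in a garden".
    - Frame 2: increase the weight of "wearing a Superman cape" to generate the image for "A cute watercolor-style kitten wearing a Superman cape".
    - Frame 3: increase the weight of "wearing a collar and bell" to generate the image for "A cute watercolor-style kitten wearing a collar and bell".
    - Frame 4: increase the weight of "sitting in a basket" to generate the image for "A cute watercolor-style kitten sitting in a basket".
    - Frame 5: increase the weight of "wearing a cute sweater" to generate the image for "A cute watercolor-style kitten wearing a cute sweater".
- Multi-prompt embedding: $\mathcal{C}_i=\tau_\xi([\mathcal{P}_0;\mathcal{P}_i])=[\boldsymbol{c}^{SOT},\boldsymbol{c}^{\mathcal{P}_0},\boldsymbol{c}^{\mathcal{P}_i},\boldsymbol{c}^{EOT}],\ (i=1,\ldots,N)$
- Single-prompt embedding: $\mathcal{C}=\tau_\xi([\mathcal{P}_0;\mathcal{P}_1;\ldots;\mathcal{P}_N])=[\boldsymbol{c}^{SOT},\boldsymbol{c}^{\mathcal{P}_0},\boldsymbol{c}^{\mathcal{P}_1},\ldots,\boldsymbol{c}^{\mathcal{P}_N},\boldsymbol{c}^{EOT}]$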
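As a minimal sketch of the two prompt layouts, the identity prompt and the frame prompts can be combined as follows. The helper names `build_multi_prompts` and `build_single_prompt` are hypothetical and are not taken from the repository; joining with spaces and commas is an assumption made only for illustration.

```python
from typing import List


def build_multi_prompts(id_prompt: str, frame_prompts: List[str]) -> List[str]:
    """Multi-prompt layout: one prompt [P_0; P_i] per frame."""
    return [f"{id_prompt} {frame}" for frame in frame_prompts]


def build_single_prompt(id_prompt: str, frame_prompts: List[str]) -> str:
    """Single-prompt layout: one prompt [P_0; P_1; ...; P_N] shared by all frames."""
    return f"{id_prompt} " + ", ".join(frame_prompts)


id_prompt = "A cute watercolor-style kitten"
frame_prompts = ["in a garden", "wearing a Superman cape", "sitting in a basket"]

multi_prompts = build_multi_prompts(id_prompt, frame_prompts)   # N separate prompts
single_prompt = build_single_prompt(id_prompt, frame_prompts)   # one consolidated prompt
```

In the multi-prompt case the text encoder sees each frame prompt in isolation, while in the single-prompt case all frame descriptions share one context.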

![alt text](../../Image/t-SNE.png)

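### Context Consistency in Image Space

Naive Prompt Reweighting: scale the embeddings of a specific frame description by a scaling factor.

To **visualize identity similarity across images**, the background is removed with CarveKit, visual features are extracted with DINO-v2, and the features are then projected into 2D space with t-SNE.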
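A minimal sketch of naive prompt reweighting, assuming the embeddings are a `[num_tokens, dim]` array and the token indices of the frame description are already known. `naive_reweight` is a hypothetical helper, and NumPy is used here instead of PyTorch only to keep the example lightweight.

```python
import numpy as np


def naive_reweight(embeds: np.ndarray, token_indices, scale: float) -> np.ndarray:
    """Return a copy of `embeds` with the selected token rows multiplied by `scale`."""
    out = embeds.copy()
    out[list(token_indices)] *= scale
    return out


# Toy embedding: 6 tokens, 4 dimensions.
embeds = np.ones((6, 4))
# Emphasize the tokens at positions 2 and 3 (the frame description to express).
reweighted = naive_reweight(embeds, token_indices=[2, 3], scale=2.0)
```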

![alt text](<../../Image/屏幕截图 2025-05-26 131702.png>)

## Method
- Prompt Consolidation (above)
- Singular Value Reweighting (SVR)
- Identity-Preserving Cross-Attention (IPCA)

### Singular Value Reweighting
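Builds on the idea of SuppressEOT, with improvements.<p>
Singular-Value Reweighting reduces the blending of frame descriptions in single-prompt generation. Here $\sigma$ denotes the singular values of the selected token embeddings.
- SVR+ (enhance): $\hat{\sigma}=\beta e^{\alpha\sigma}\cdot\sigma$
- SVR- (suppress): $\tilde{\sigma}=\beta^{\prime}e^{-\alpha^{\prime}\hat{\sigma}}\cdot\hat{\sigma}$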
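The two formulas can be sketched with a single helper that rescales singular values. This is a simplified illustration, not the repository code: `reweight_singular_values` is a hypothetical name, NumPy replaces PyTorch, the parameter values are made up, and the two matrices are treated independently rather than as parts of one prompt embedding.

```python
import numpy as np


def reweight_singular_values(x: np.ndarray, alpha: float, beta: float) -> np.ndarray:
    """Apply sigma -> beta * exp(alpha * sigma) * sigma to the singular values of x.

    alpha > 0 corresponds to the SVR+ form, alpha < 0 to the SVR- form.
    """
    u, s, vh = np.linalg.svd(x, full_matrices=False)
    s_new = beta * np.exp(alpha * s) * s
    return (u * s_new) @ vh  # same as u @ diag(s_new) @ vh


rng = np.random.default_rng(0)
express = rng.normal(size=(3, 8))   # toy embeddings of the frame to express
suppress = rng.normal(size=(5, 8))  # toy embeddings of the remaining frames

enhanced = reweight_singular_values(express, alpha=0.01, beta=1.0)    # SVR+
weakened = reweight_singular_values(suppress, alpha=-0.01, beta=0.5)  # SVR-
```

Because only the singular values change, the singular vectors of each token group are preserved while its overall strength is increased or decreased.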

In [None]:
def swr_single_prompt_embeds(swr_words,prompt_embeds,prompt,tokenizer,alpha=1.0, beta=1.2, zero_eot=False):
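    # swr_words: the words to suppress (the words whose embeddings should be weakened)
    # prompt: the original full text prompt string
    # alpha, beta: parameters passed to punish_wight
    # zero_eot: selects how the EOT tokens are handled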
    punish_indices = []

    prompt_tokens = prompt2tokens(tokenizer,prompt)
    
    words_tokens = prompt2tokens(tokenizer,swr_words)
    words_tokens = [word for word in words_tokens if word != '<|endoftext|>' and word != '<|startoftext|>']
    index_of_words = find_sublist_index(prompt_tokens,words_tokens)
    
    if index_of_words != -1:
        punish_indices.extend([num for num in range(index_of_words, index_of_words+len(words_tokens))])
    
    if zero_eot:
        # Find the indices of all EOT tokens in the token sequence
        eot_indices = [index for index, word in enumerate(prompt_tokens) if word == '<|endoftext|>']
        # Scale the embeddings at the EOT indices by 0.9 to weaken them
        prompt_embeds[eot_indices] *= 9e-1
    else:
        punish_indices.extend([index for index, word in enumerate(prompt_tokens) if word == '<|endoftext|>'])

    punish_indices = list(set(punish_indices))
    # Gather from prompt_embeds the embeddings of all tokens whose indices are in punish_indices
    # wo_batch: [num_indices_to_punish, embedding_dim]
    wo_batch = prompt_embeds[punish_indices]
    # ------------------------------------------------------------------
    wo_batch = punish_wight(wo_batch.T.to(float), 
                            wo_batch.size(0), 
                            alpha=alpha, 
                            beta=beta, 
                            calc_similarity=False).T.to(prompt_embeds.dtype)
    # ------------------------------------------------------------------
    prompt_embeds[punish_indices] = wo_batch
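`find_sublist_index` is called in the function above but is not shown in these notes. The following is a guess at its behavior, inferred only from the `!= -1` check; it is not the repository's implementation, and the token strings are toy values.

```python
from typing import Sequence


def find_sublist_index(sequence: Sequence[str], sublist: Sequence[str]) -> int:
    """Return the start index of the first occurrence of `sublist` in `sequence`, or -1."""
    n, m = len(sequence), len(sublist)
    if m == 0:
        return -1
    for start in range(n - m + 1):
        if list(sequence[start:start + m]) == list(sublist):
            return start
    return -1


prompt_tokens = ["<|startoftext|>", "a</w>", "cute</w>", "kitten</w>",
                 "in</w>", "a</w>", "garden</w>", "<|endoftext|>"]
index = find_sublist_index(prompt_tokens, ["in</w>", "a</w>", "garden</w>"])
```

With such a helper, the token positions to reweight are `range(index, index + len(words_tokens))`, which is how `punish_indices` is filled.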

In [1]:
# Tokenize the prompt into input_ids padded to max_length, then decode each id into its token string
def prompt2tokens(tokenizer, prompt):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    tokens = []
    for text_input_id in text_input_ids[0]:
        token = tokenizer.decoder[text_input_id.item()]
        tokens.append(token)
    return tokens

In [3]:
import torch
from scipy.spatial.distance import cdist
def punish_wight(tensor, latent_size, alpha=1.0, beta=1.2, calc_similarity=False):
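    # tensor: [embedding_dim, num_tokens]; SVD factorizes it as u @ diag(s) @ vh
    u, s, vh = torch.linalg.svd(tensor)
    # Keep the first latent_size columns of u so the shapes match diag(s) and vh
    u = u[:,:latent_size]
    zero_idx = int(latent_size * alpha)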

    if calc_similarity:
        _s = s.clone()
        _s *= torch.exp(-alpha*_s) * beta
        _s[zero_idx:] = 0
        _tensor = u @ torch.diag(_s) @ vh
        dist = cdist(tensor[:,0].unsqueeze(0).cpu(), _tensor[:,0].unsqueeze(0).cpu(), metric='cosine')
        print(f'The distance between the word embedding before and after the punishment: {dist}')
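    # ------------------------------------------------------------
    # Reweight the singular values in place: s <- beta * exp(-alpha * s) * s
    s *= torch.exp(-alpha*s) * beta
    # ------------------------------------------------------------
    # Rebuild the token embeddings from the reweighted singular values
    tensor = u @ torch.diag(s) @ vh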
    return tensor

### Identity-Preserving Cross-Attention 
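- Cross-attention maps capture token-level feature information, i.e. the semantics tied to specific parts of the text prompt.<p>
  Self-attention maps preserve the layout and shape details of the image, i.e. visual information about its overall structure and appearance.<p>
(From "Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing", CVPR 2024)
- https://github.com/alibaba/EasyNLP/tree/master/diffusion/FreePromptEditing
- SVR reduces the blending between different frame descriptions in single-prompt generation, but it may weaken context consistency within the single prompt.

After SVR we obtain the updated text embedding $\tilde{\mathcal{C}}$; during denoising, cross-attention with $\tilde{\mathcal{C}}$ yields $\tilde{\mathcal{Q}},\tilde{\mathcal{K}},\tilde{\mathcal{V}}$.<p>
Setting the token features in $\tilde{\mathcal{K}}$ that correspond to the frame prompts $\mathcal{P}_i,\ i\in[1,N]$ to zero gives $\bar{\mathcal{K}}$, which no longer carries frame-prompt features; $\bar{\mathcal{V}}$ is obtained from $\tilde{\mathcal{V}}$ in the same way.<p>
The two are concatenated to form the new key, $\tilde{\mathcal{K}}=\mathrm{Concat}(\tilde{\mathcal{K}}^{\top},\bar{\mathcal{K}}^{\top})^{\top}$, and likewise $\tilde{\mathcal{V}}=\mathrm{Concat}(\tilde{\mathcal{V}}^{\top},\bar{\mathcal{V}}^{\top})^{\top}$.

$\tilde{\mathcal{A}}=\mathrm{softmax}\left(\tilde{\mathcal{Q}}\tilde{\mathcal{K}}^\top/\sqrt{d}\right)$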
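The concatenation step can be sketched for a single attention head as follows. This is a toy illustration under stated assumptions, not the repository code: `ipca_single_head` is a hypothetical name, NumPy replaces PyTorch, the frame-prompt token indices are assumed to be known, and the values are assumed to be zeroed at the same positions as the keys.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)  # subtract the row maximum for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def ipca_single_head(q, k, v, frame_token_idx):
    """q: [n_q, d]; k, v: [n_k, d]; frame_token_idx: rows of k/v that belong to frame prompts."""
    k_bar, v_bar = k.copy(), v.copy()
    k_bar[frame_token_idx] = 0.0  # K-bar: frame-prompt keys set to zero
    v_bar[frame_token_idx] = 0.0  # V-bar: assumed to be zeroed the same way
    k_cat = np.concatenate([k, k_bar], axis=0)  # Concat(K^T, K-bar^T)^T -> [2 * n_k, d]
    v_cat = np.concatenate([v, v_bar], axis=0)
    attn = softmax(q @ k_cat.T / np.sqrt(q.shape[-1]))  # [n_q, 2 * n_k]
    return attn @ v_cat, attn


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(6, 8))
v = rng.normal(size=(6, 8))
out, attn = ipca_single_head(q, k, v, frame_token_idx=[3, 4, 5])
```

Each query now attends over twice as many key positions: the original ones plus a copy in which only the non-frame tokens carry information.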

In [None]:
import torch
from typing import Optional

# NOTE: UNetController is defined elsewhere in the repository and is not shown in these notes.
def ipca(q, k, v, scale, unet_controller: Optional[UNetController] = None): # eg. q: [4,20,1024,64] k,v: [4,20,77,64] 
    # Split q, k, v along the batch dimension into the negative-prompt and positive-prompt halves
    # q: [batch_size, num_heads, seq_len_q, head_dim]
    # k, v: [batch_size, num_heads, seq_len_k, head_dim]
    q_neg, q_pos = torch.split(q, q.size(0) // 2, dim=0)
    k_neg, k_pos = torch.split(k, k.size(0) // 2, dim=0)
    v_neg, v_pos = torch.split(v, v.size(0) // 2, dim=0)

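    # 1. negative_attn: attention for the negative-prompt branch
    # Transpose the last two dimensions of k for the matrix multiplication
    scores_neg = torch.matmul(q_neg, k_neg.transpose(-2, -1)) * scale
    # scores_neg: [batch_size_neg, num_heads, seq_len_q, seq_len_k]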
    attn_weights_neg = torch.softmax(scores_neg, dim=-1)
    attn_output_neg = torch.matmul(attn_weights_neg, v_neg)

    # 2. positive_attn (we do ipca only on positive branch)

    # 2.1 ipca 
    # k_pos: [batch_size_pos, num_heads, seq_len_k, head_dim]
    # k_pos.transpose(-2, -1): [batch_size_pos, num_heads, head_dim, seq_len_k]
    # tuple(...): iterates over dim 0 (the batch), giving batch_size_pos tensors of shape [num_heads, head_dim, seq_len_k]
    # torch.cat(..., dim=2): [num_heads, head_dim, batch_size_pos * seq_len_k]
    # unsqueeze(0): [1, num_heads, head_dim, batch_size_pos * seq_len_k]
    # .repeat(...): [batch_size_pos, num_heads, head_dim, batch_size_pos * seq_len_k]

    # Concatenated K: every sample sees the keys of all samples in the positive batch
    k_plus = torch.cat(tuple(k_pos.transpose(-2, -1)), dim=2).unsqueeze(0).repeat(k_pos.size(0),1,1,1) # 𝐾+ = [𝐾1 ⊕ 𝐾2 ⊕ . . . ⊕ 𝐾𝑁 ]
    # Concatenated V: [batch_size_pos, num_heads, batch_size_pos * seq_len_k, head_dim]
    v_plus = torch.cat(tuple(v_pos), dim=1).unsqueeze(0).repeat(v_pos.size(0),1,1,1) # 𝑉+ = [𝑉1 ⊕ 𝑉2 ⊕ . . . ⊕ 𝑉𝑁 ]


    # 2.2 apply mask
    if unet_controller is not None:
        scores_pos = torch.matmul(q_pos, k_plus) * scale

        # 2.2.1 apply dropout mask
        dropout_mask = gen_dropout_mask(scores_pos.shape, unet_controller, unet_controller.Ipca_dropout) # eg: [a,1024,154]   


        # 2.2.2 apply embeds mask
        if unet_controller.Use_embeds_mask:
            apply_embeds_mask(unet_controller,dropout_mask, add_eot=False)

        mask = dropout_mask

        mask = mask.unsqueeze(1).repeat(1,scores_pos.size(1),1,1)
        attn_weights_pos = torch.softmax(scores_pos + torch.log(mask), dim=-1)

    else:
        scores_pos = torch.matmul(q_pos, k_plus) * scale
        attn_weights_pos = torch.softmax(scores_pos, dim=-1)


    attn_output_pos = torch.matmul(attn_weights_pos, v_plus)
    # 3. combine
    attn_output = torch.cat((attn_output_neg, attn_output_pos), dim=0)

    return attn_output
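The masked branch of the function above relies on `softmax(scores + log(mask))`: entries where the mask is 0 receive a logit of negative infinity and therefore zero attention weight. A small NumPy sketch of that trick follows; `masked_softmax` is a hypothetical helper name, and the toy values are made up.

```python
import numpy as np


def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Softmax over the last axis in which entries with mask == 0 get zero weight."""
    with np.errstate(divide="ignore"):  # log(0) = -inf is intended here
        logits = scores + np.log(mask)
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


scores = np.array([[1.0, 2.0, 3.0, 4.0]])
mask = np.array([[1.0, 0.0, 1.0, 1.0]])  # drop the second key position
weights = masked_softmax(scores, mask)
```

A row whose mask is entirely zero would produce NaN, so at least one key position per query has to stay unmasked.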

In [None]:
def ipca2(q, k, v, scale, unet_controller: Optional[UNetController] = None): # eg. q: [4,20,1024,64] k,v: [4,20,77,64] 
    # Count the cross-attention calls within one timestep: reset the index when the timestep changes
    if unet_controller.ipca_time_step != unet_controller.current_time_step:
        unet_controller.ipca_time_step = unet_controller.current_time_step
        unet_controller.ipca2_index = 0
    else:
        unet_controller.ipca2_index += 1

    if unet_controller.Store_qkv is True:
        # Store pass: cache k and v under a key built from (timestep, UNet position, call index),
        # then run ordinary cross-attention
        key = f"cross {unet_controller.current_time_step} {unet_controller.current_unet_position} {unet_controller.ipca2_index}"
        unet_controller.k_store[key] = k
        unet_controller.v_store[key] = v

        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
    else:
        # batch > 1
        if unet_controller.frame_prompt_express_list is not None:
            batch_size = q.size(0) // 2
            attn_output_list = []

            for i in range(batch_size):
                q_i = q[[i, i + batch_size], :, :, :]
                k_i = k[[i, i + batch_size], :, :, :]
                v_i = v[[i, i + batch_size], :, :, :]

                q_neg_i, q_pos_i = torch.split(q_i, q_i.size(0) // 2, dim=0)
                k_neg_i, k_pos_i = torch.split(k_i, k_i.size(0) // 2, dim=0)
                v_neg_i, v_pos_i = torch.split(v_i, v_i.size(0) // 2, dim=0)

                key = f"cross {unet_controller.current_time_step} {unet_controller.current_unet_position} {unet_controller.ipca2_index}"
                q_store = q_i
                k_store = unet_controller.k_store[key]
                v_store = unet_controller.v_store[key]

                q_store_neg, q_store_pos = torch.split(q_store, q_store.size(0) // 2, dim=0)
                k_store_neg, k_store_pos = torch.split(k_store, k_store.size(0) // 2, dim=0)
                v_store_neg, v_store_pos = torch.split(v_store, v_store.size(0) // 2, dim=0)    

                q_neg = torch.cat((q_neg_i, q_store_neg), dim=0)
                q_pos = torch.cat((q_pos_i, q_store_pos), dim=0)
                k_neg = torch.cat((k_neg_i, k_store_neg), dim=0)
                k_pos = torch.cat((k_pos_i, k_store_pos), dim=0)
                v_neg = torch.cat((v_neg_i, v_store_neg), dim=0)
                v_pos = torch.cat((v_pos_i, v_store_pos), dim=0)

                q_i = torch.cat((q_neg, q_pos), dim=0)
                k_i = torch.cat((k_neg, k_pos), dim=0)
                v_i = torch.cat((v_neg, v_pos), dim=0)

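                # Run IPCA on [neg_i, neg_store, pos_i, pos_store]; rows 0 and 2 are the outputs for the current sample
                attn_output_i = ipca(q_i, k_i, v_i, scale, unet_controller)
                attn_output_i = attn_output_i[[0, 2], :, :, :]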
                attn_output_list.append(attn_output_i)
            
            attn_output_ = torch.cat(attn_output_list, dim=0)
            attn_output = torch.zeros(size=(q.size(0), attn_output_i.size(1), attn_output_i.size(2), attn_output_i.size(3)), device=q.device, dtype=q.dtype)
            for i in range(batch_size):
                attn_output[i] = attn_output_[i*2]
            for i in range(batch_size):
                attn_output[i + batch_size] = attn_output_[i*2 + 1]
        # batch = 1
        else:
            q_neg, q_pos = torch.split(q, q.size(0) // 2, dim=0)
            k_neg, k_pos = torch.split(k, k.size(0) // 2, dim=0)
            v_neg, v_pos = torch.split(v, v.size(0) // 2, dim=0)

            key = f"cross {unet_controller.current_time_step} {unet_controller.current_unet_position} {unet_controller.ipca2_index}"
            q_store = q
            k_store = unet_controller.k_store[key]
            v_store = unet_controller.v_store[key]

            q_store_neg, q_store_pos = torch.split(q_store, q_store.size(0) // 2, dim=0)
            k_store_neg, k_store_pos = torch.split(k_store, k_store.size(0) // 2, dim=0)
            v_store_neg, v_store_pos = torch.split(v_store, v_store.size(0) // 2, dim=0)    

            q_neg = torch.cat((q_neg, q_store_neg), dim=0)
            q_pos = torch.cat((q_pos, q_store_pos), dim=0)
            k_neg = torch.cat((k_neg, k_store_neg), dim=0)
            k_pos = torch.cat((k_pos, k_store_pos), dim=0)
            v_neg = torch.cat((v_neg, v_store_neg), dim=0)
            v_pos = torch.cat((v_pos, v_store_pos), dim=0)

            q = torch.cat((q_neg, q_pos), dim=0)
            k = torch.cat((k_neg, k_pos), dim=0)
            v = torch.cat((v_neg, v_pos), dim=0)

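            # Run IPCA on [neg, neg_store, pos, pos_store]; rows 0 and 2 are the outputs for the current sample
            attn_output = ipca(q, k, v, scale, unet_controller)
            attn_output = attn_output[[0, 2], :, :, :]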
    
    return attn_output
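The store-then-reuse scheme in the function above depends on both passes producing the same key for the same cross-attention call. The sketch below isolates that bookkeeping. `CallCounter` is a hypothetical class, the position label `"down"` is a made-up value, and using one counter per pass is a simplification of the single controller in the original code.

```python
class CallCounter:
    """Tracks how many cross-attention calls have happened within the current timestep."""

    def __init__(self):
        self.last_time_step = None
        self.index = 0

    def next_key(self, time_step: int, position: str) -> str:
        if self.last_time_step != time_step:
            # First call of a new timestep: restart the numbering.
            self.last_time_step = time_step
            self.index = 0
        else:
            self.index += 1
        return f"cross {time_step} {position} {self.index}"


store_pass, reuse_pass = CallCounter(), CallCounter()
# Store pass: two calls at timestep 10, then one call at timestep 9.
stored = {store_pass.next_key(t, "down"): f"kv@{t}" for t in (10, 10, 9)}
# Reuse pass: the same call sequence reproduces the same keys.
looked_up = [stored[reuse_pass.next_key(t, "down")] for t in (10, 10, 9)]
```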

In [None]:
def apply_embeds_mask(unet_controller: Optional[UNetController],dropout_mask, add_eot=False):   
    id_prompt = unet_controller.id_prompt
    prompt_tokens = prompt2tokens(unet_controller.tokenizer,unet_controller.prompts[0])
    
    # Locate the identity-prompt tokens inside the full prompt (special tokens removed)
    words_tokens = prompt2tokens(unet_controller.tokenizer,id_prompt)
    words_tokens = [word for word in words_tokens if word != '<|endoftext|>' and word != '<|startoftext|>']
    index_of_words = find_sublist_index(prompt_tokens,words_tokens)    
    # Shift by 77 (the per-prompt key length) so the indices point into the second key segment
    index_list = [index+77 for index in range(index_of_words, index_of_words+len(words_tokens))]
    if add_eot:
        index_list.extend([index+77 for index, word in enumerate(prompt_tokens) if word == '<|endoftext|>'])

    # In the first sample's mask, zero every key column >= 78 that is not an identity-prompt token
    mask_indices = torch.arange(dropout_mask.size(-1), device=dropout_mask.device)
    mask = (mask_indices >= 78) & (~torch.isin(mask_indices, torch.tensor(index_list, device=dropout_mask.device)))
    dropout_mask[0, :, mask] = 0
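To make the index arithmetic above concrete, the sketch below computes which concatenated key columns survive for two 77-token segments. `keep_columns` is a hypothetical helper, NumPy replaces PyTorch, and the identity-prompt position (`id_start=1`, `id_len=4`) is a made-up example.

```python
import numpy as np

SEQ_LEN = 77  # tokens per prompt, matching the hard-coded 77 in the function above


def keep_columns(id_start: int, id_len: int, total_len: int = 2 * SEQ_LEN) -> np.ndarray:
    """Boolean vector over concatenated key positions: True = kept, False = zeroed."""
    cols = np.arange(total_len)
    # Identity-prompt token positions, shifted into the second segment.
    id_cols = np.arange(id_start, id_start + id_len) + SEQ_LEN
    zeroed = (cols >= SEQ_LEN + 1) & ~np.isin(cols, id_cols)
    return ~zeroed


keep = keep_columns(id_start=1, id_len=4)
```

In this example the whole first segment and column 77 are kept, and within the rest of the second segment only the four identity-prompt columns (78 to 81) remain.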


In [None]:
def gen_dropout_mask(out_shape, unet_controller: Optional[UNetController], drop_out):
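    gen_length = out_shape[3]            # total number of key positions (all prompts concatenated)
    attn_map_side_length = out_shape[2]  # number of query positions

    batch_num = out_shape[0]
    mask_list = []
    
    for prompt_index in range(batch_num):
        # Key slice that belongs to this prompt itself
        start = prompt_index * int(gen_length / batch_num)
        end = (prompt_index + 1) * int(gen_length / batch_num)
    
        # Keep each key position with probability 1 - drop_out ...
        mask = torch.bernoulli(torch.full((attn_map_side_length,gen_length), 1 - drop_out, dtype=unet_controller.torch_dtype, device=unet_controller.device))        
        # ... but never drop the prompt's own keys
        mask[:, start:end] = 1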

        mask_list.append(mask)

    concatenated_mask = torch.stack(mask_list, dim=0)
    return concatenated_mask
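The same masking logic can be reproduced in a few lines of NumPy to check its shape and invariants. `gen_dropout_mask_np` is a hypothetical stand-in for the function above, with the controller arguments replaced by explicit sizes and a seed.

```python
import numpy as np


def gen_dropout_mask_np(batch: int, n_query: int, n_key_total: int,
                        drop_out: float, seed: int = 0) -> np.ndarray:
    """Random 0/1 mask of shape [batch, n_query, n_key_total].

    Each key position is kept with probability 1 - drop_out, except that
    every prompt's own key slice is always kept.
    """
    rng = np.random.default_rng(seed)
    seg = n_key_total // batch
    masks = []
    for i in range(batch):
        m = (rng.random((n_query, n_key_total)) >= drop_out).astype(np.float64)
        m[:, i * seg:(i + 1) * seg] = 1.0  # the prompt's own keys are never dropped
        masks.append(m)
    return np.stack(masks, axis=0)


mask = gen_dropout_mask_np(batch=2, n_query=16, n_key_total=8, drop_out=0.5)
```

Because the own slice is always 1, no query row is ever fully masked, which keeps the later `softmax(scores + log(mask))` well defined.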

![alt text](../../Image/overall_2.png)