In [18]:
import os
import torch
import numpy as np
from tqdm.auto import tqdm
import pandas as pd
from diffusers import StableDiffusionPipeline

In [28]:
# 1. 모델 로드
model_name = "runwayml/stable-diffusion-v1-5"
device     = "cuda:1" 

pipe = StableDiffusionPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    use_auth_token=False
).to(device)
pipe.enable_attention_slicing()
pipe.set_progress_bar_config(disable=True)  
unet = pipe.unet

# 2. Cross Attention 판별 
def is_cross_attention_module(module):
    return hasattr(module, "to_q") and hasattr(module, "to_k") and hasattr(module, "to_v")

# hidden state(query) 추출
def extract_cross_attention_steps1_to_10(
    prompts_and_seeds,  
    npy_root="path/to/your/npy",
    img_dir="path/to/your/img",
    model_name="runwayml/stable-diffusion-v1-5",
    device="cuda:1",
    num_steps=50,
    guidance=7.5,
    min_step=1,
    max_step=10,
):
    """
    NSFW 검출 시 해당 prompt는 failed_prompts에 기록하고, 이미지와 npy는 저장하지 않음.
    """
    for step in range(min_step, max_step + 1):
        os.makedirs(os.path.join(npy_root, f"step{step}"), exist_ok=True)
    os.makedirs(img_dir, exist_ok=True)

    # UNet 내 Cross-Attention 레이어 중 'attn2'만 추출
    layer_names = [
        name
        for name, module in unet.named_modules()
        if is_cross_attention_module(module) and "attn2" in name
    ]

    failed_prompts = []

    # (prompt, seed) 
    for prompt, seed in tqdm(prompts_and_seeds, desc=f"Extracting hidden states ({min_step}-{max_step})"):
        q_records = {
            step: {layer_name: None for layer_name in layer_names}
            for step in range(min_step, max_step + 1)
        }

        handles = []
        current = {"step_idx": None}

        def callback(step_idx, timestep, latents):
            current["step_idx"] = step_idx

        # 훅을 걸어 query 저장
        def make_hook(step_dict, layer_name):
            def hook(mod, inp, out):
                step_idx = current["step_idx"]
                # step_idx가 1~10 사이일 때만 저장
                if (step_idx is not None) and (min_step <= step_idx <= max_step):
                    step_dict[step_idx][layer_name] = out.detach().cpu().numpy()
            return hook

        for name, module in unet.named_modules():
            if name in layer_names:
                handles.append(module.to_q.register_forward_hook(make_hook(q_records, name)))

        # 이미지 생성 + hidden state 추출 (seed 고정)
        gen = torch.Generator(device=device).manual_seed(seed)
        # return_dict=True 로 NSFW 플래그를 확인할 수 있음
        result = pipe(
            prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance,
            generator=gen,
            callback=callback,
            callback_steps=1,
            return_dict=True
        )

        # 훅 제거 및 GPU 캐시 비우기
        for h in handles:
            h.remove()
        torch.cuda.empty_cache()

        # NSFW 검출 여부 확인
        nsfw_flags = getattr(result, "nsfw_content_detected", None)
        if nsfw_flags is not None and nsfw_flags[0]:
            # NSFW로 판단된 경우
            failed_prompts.append(prompt)
            continue  # 이미지와 npy 모두 저장하지 않고 다음 prompt로 넘어감

        # step1~10마다 attn2 query를 npy로 저장
        safe_prompt = prompt.replace(" ", "_")
        for step_idx in range(min_step, max_step + 1):
            step_folder = os.path.join(npy_root, f"step{step_idx}")
            npy_path = os.path.join(step_folder, f"{safe_prompt}_seed{seed}.npy")
            if os.path.isdir(npy_path):
                os.remove(npy_path)
            np.save(npy_path, [{ "query": q_records[step_idx] }], allow_pickle=True)

        # 최종 생성 이미지 저장
        img = result.images[0]
        img_path = os.path.join(img_dir, f"{safe_prompt}_seed{seed}.png")
        if os.path.isdir(img_path):
            os.remove(img_path)
        img.save(img_path)

    # 11) 실패한 프롬프트 목록 및 총 개수 출력
    print("\n=== NSFW로 인해 이미지 저장되지 않은 프롬프트 목록 ===")
    for p in failed_prompts:
        print(" -", p)
    print(f"총 {len(failed_prompts)}개 프롬프트가 NSFW로 판정되어 저장되지 않았습니다.")

Keyword arguments {'use_auth_token': False} are not expected by StableDiffusionPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 16.08it/s]


In [29]:
if __name__ == "__main__":
    # 엑셀 파일에서 prompt와 seed
    excel_path = "path/to/your/data/excel"
    df = pd.read_excel(excel_path, engine="openpyxl")
    prompts_and_seeds = [(row["prompt"], int(row["seed"])) for _, row in df.iterrows()]

    extract_cross_attention_steps1_to_10(
        prompts_and_seeds,
        npy_root="path/to/your/npy",
        img_dir="path/to/your/img",
        model_name="runwayml/stable-diffusion-v1-5",
        device="cuda:1",
        num_steps=50,
        guidance=7.5,
        min_step=1,
        max_step=5
    )

  deprecate(
  deprecate(
Extracting hidden states (1-5):   5%|▌         | 40/800 [06:35<2:05:17,  9.89s/it]Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
Extracting hidden states (1-5):  14%|█▎        | 109/800 [17:57<1:53:58,  9.90s/it]Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
Extracting hidden states (1-5):  16%|█▋        | 130/800 [21:25<1:51:12,  9.96s/it]Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
Extracting hidden states (1-5):  16%|█▋        | 132/800 [21:44<1:49:36,  9.85s/it]Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
Extracting hidden states (1-5):  17%|█▋        | 135/800 [2


=== NSFW로 인해 이미지 저장되지 않은 프롬프트 목록 ===
 - two horses
 - four lemons
 - two melons
 - two papayas
 - three apricots
 - two cantaloupes
 - three spinaches
 - two eggplants
 - two zucchinis
 - two yams
 - one hammer
 - three flutes
총 12개 프롬프트가 NSFW로 판정되어 저장되지 않았습니다.





# npy 파일 확인

In [21]:
import os
import numpy as np

npy_root = "path/to/your/npy"
prompt = "one_lion"  
step = "step1"     

npy_path = os.path.join(npy_root, step, f"{prompt}.npy")

if not os.path.isfile(npy_path):
    raise FileNotFoundError(f"{npy_path} 파일이 존재하지 않습니다.")

data = np.load(npy_path, allow_pickle=True)

entry = data[0]
print("\nKeys in entry dictionary:", entry.keys())  # query, key, value 여부

# 7) 각 필드(query, key, value) 내부 구조 확인
for field in ["query", "key", "value"]:
    if field not in entry:
        print(f"\n'{field}' 키가 존재하지 않습니다.")
        continue

    subdict = entry[field]
    if not isinstance(subdict, dict):
        print(f"\n'{field}' 필드가 dict 타입이 아닙니다:", type(subdict))
        continue

    print(f"\n=== Field: '{field}' ===")
    print(f"  레이어 수: {len(subdict)}")
    for layer_name, arr in subdict.items():
        if arr is not None:
            print(f"    - {layer_name}: shape = {arr.shape}")
        else:
            print(f"    - {layer_name}: 값이 None")

Loaded data type: <class 'numpy.ndarray'>
Length of outer array: 1

Keys in entry dictionary: dict_keys(['query'])

=== Field: 'query' ===
  레이어 수: 16
    - down_blocks.0.attentions.0.transformer_blocks.0.attn2: shape = (2, 4096, 320)
    - down_blocks.0.attentions.1.transformer_blocks.0.attn2: shape = (2, 4096, 320)
    - down_blocks.1.attentions.0.transformer_blocks.0.attn2: shape = (2, 1024, 640)
    - down_blocks.1.attentions.1.transformer_blocks.0.attn2: shape = (2, 1024, 640)
    - down_blocks.2.attentions.0.transformer_blocks.0.attn2: shape = (2, 256, 1280)
    - down_blocks.2.attentions.1.transformer_blocks.0.attn2: shape = (2, 256, 1280)
    - up_blocks.1.attentions.0.transformer_blocks.0.attn2: shape = (2, 256, 1280)
    - up_blocks.1.attentions.1.transformer_blocks.0.attn2: shape = (2, 256, 1280)
    - up_blocks.1.attentions.2.transformer_blocks.0.attn2: shape = (2, 256, 1280)
    - up_blocks.2.attentions.0.transformer_blocks.0.attn2: shape = (2, 1024, 640)
    - up_blocks.2