From 3975dc57617553c758f98af1ca99e1cb165a65bf Mon Sep 17 00:00:00 2001 From: ycsxh <1002533186@qq.com> Date: Tue, 2 Sep 2025 14:45:42 +0800 Subject: [PATCH 1/2] =?UTF-8?q?debug0.1-=E6=B7=BB=E5=8A=A0=E4=BA=86?= =?UTF-8?q?=E8=87=AA=E5=B7=B1=E7=9A=84log=E5=92=8Cmain?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/diffusers/main.py | 59 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 src/diffusers/main.py diff --git a/src/diffusers/main.py b/src/diffusers/main.py new file mode 100644 index 000000000000..782747d89a3b --- /dev/null +++ b/src/diffusers/main.py @@ -0,0 +1,59 @@ +import torch +import logging +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +# 设置日志 +# 设置日志,保存到文件 + +# 设置日志,保存到文件 +logging.basicConfig( + level=logging.DEBUG, + filename='cogvideox.log', # 日志保存到当前目录的 cogvideox.log + filemode='w', # 'w' 覆盖旧日志,'a' 追加 + format='%(asctime)s - %(levelname)s - %(message)s' # 日志格式 +) +logger = logging.getLogger(__name__) + +def main(): + + # 加载CogVideoX模型 + logger.info("Loading CogVideoXPipeline...") + pipe = CogVideoXPipeline.from_pretrained( + "THUDM/CogVideoX-2b", # 或CogVideoX-5b(需更多显存) + torch_dtype=torch.bfloat16 # 节省内存 + # torch_dtype = torch.float16 + ) + + # 优化:CPU卸载,适合低显存设备 + pipe.enable_model_cpu_offload() + logger.info("Model loaded and offloaded to CPU.") + + # 输入提示 + prompt = "A cat wearing a hat dancing in a colorful garden." + logger.info(f"Using prompt: {prompt}") + + try: + # 生成视频 + logger.info("Starting video generation...") + video = pipe( + prompt=prompt, + num_inference_steps=10, # 调试时用小值 + height=480, + width=720, + guidance_scale=6.0, + num_frames=5 # 调试时减少帧数 + ).frames[0] + logger.info(f"Generated video with {len(video)} frames, shape: {video[0].shape}") + + # 导出视频 + export_to_video(video, "output_video.mp4", fps=8) + logger.info("Video saved to output_video.mp4") + + except Exception as e: + logger.error(f"Error during generation: {str(e)}") + raise + + +if __name__ == "__main__": + main() \ No newline at end of file From c57826ba609734defe6eafddb552faf86f0329fc Mon Sep 17 00:00:00 2001 From: ycsxh <1002533186@qq.com> Date: Mon, 22 Sep 2025 23:42:42 +0800 Subject: [PATCH 2/2] =?UTF-8?q?vis1.0=EF=BC=8C=E5=A2=9E=E5=8A=A0=E4=BA=86t?= =?UTF-8?q?iming=E5=92=8C=E5=8F=AF=E8=A7=86=E5=8C=96hook?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/diffusers/analysis/analyzer.py | 0 src/diffusers/main.py | 134 +++- src/diffusers/map.py | 126 ++++ src/diffusers/maputils.py | 219 +++++++ .../transformers/cogvideox_transformer_3d.py | 149 +++-- .../pipelines/cogvideo/pipeline_cogvideox.py | 194 ++++-- src/diffusers/vishook.py | 597 ++++++++++++++++++ 7 files changed, 1275 insertions(+), 144 deletions(-) create mode 100644 src/diffusers/analysis/analyzer.py create mode 100644 src/diffusers/map.py create mode 100644 src/diffusers/maputils.py create mode 100644 src/diffusers/vishook.py diff --git a/src/diffusers/analysis/analyzer.py b/src/diffusers/analysis/analyzer.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/main.py b/src/diffusers/main.py index 782747d89a3b..8964ee39932e 100644 --- a/src/diffusers/main.py +++ b/src/diffusers/main.py @@ -1,36 +1,114 @@ +import sys +sys.path.append('/home/lyc/diffusers/src') # 替换为你的 diffusers 路径 + import torch import logging + +# 先配置 logging(在导入 diffusers 前) +logger = logging.getLogger() # 根 logger +logger.setLevel(logging.DEBUG) + 
+# +output_path = '/home/lyc/diffusers_output/' + +# 文件 handler +file_handler = logging.FileHandler(output_path + 'cogvideox.log', mode='a') +file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) +logger.addHandler(file_handler) + +import time from diffusers import CogVideoXPipeline from diffusers.utils import export_to_video +from torch.nn import CosineSimilarity +from torch.nn.functional import cosine_similarity -# 设置日志 -# 设置日志,保存到文件 -# 设置日志,保存到文件 -logging.basicConfig( - level=logging.DEBUG, - filename='cogvideox.log', # 日志保存到当前目录的 cogvideox.log - filemode='w', # 'w' 覆盖旧日志,'a' 追加 - format='%(asctime)s - %(levelname)s - %(message)s' # 日志格式 -) -logger = logging.getLogger(__name__) +# 全局变量:统一管理所有模块类型(attn, ff) +module_times = {'attn': {}, 'ff': {}} # {type: {name: time}} +start_events = {'attn': {}, 'ff': {}} # {type: {name: event}} -def main(): +def pre_timing_hook(module_type, module, input, name=""): + if name not in start_events[module_type]: + start_events[module_type][name] = torch.cuda.Event(enable_timing=True) + start_events[module_type][name].record() # 记录开始事件 + +def post_timing_hook(module_type, module, input, output, name=""): + end_event = torch.cuda.Event(enable_timing=True) + end_event.record() + end_event.synchronize() # 只在结束时同步 + elapsed = start_events[module_type][name].elapsed_time(end_event) # GPU 时间(ms) + module_times[module_type][name] = module_times[module_type].get(name, 0) + elapsed + # 无需在这里记录新 start_event(移除重复逻辑) - # 加载CogVideoX模型 +def compute_cosine_similarity(block_idx1, block_idx2, blocks, layer_name="to_k"): + """计算两个块中指定层(to_k 或 to_v)的余弦相似度""" + attn1 = blocks[block_idx1].attn1 + attn2 = blocks[block_idx2].attn1 + if layer_name == "to_k": + weight1 = attn1.to_k.weight.data.flatten() + weight2 = attn2.to_k.weight.data.flatten() + elif layer_name == "to_v": + weight1 = attn1.to_v.weight.data.flatten() + weight2 = attn2.to_v.weight.data.flatten() + else: + raise ValueError("layer_name must be 'to_k' or 'to_v'") + return cosine_similarity(weight1.unsqueeze(0), weight2.unsqueeze(0)).item() + +def main(): + logger.info(logging.getLogger('diffusers.pipelines.cogvideox.pipeline_cogvideox').handlers) logger.info("Loading CogVideoXPipeline...") + print(torch.cuda.is_available()) + print(torch.version.cuda) pipe = CogVideoXPipeline.from_pretrained( "THUDM/CogVideoX-2b", # 或CogVideoX-5b(需更多显存) - torch_dtype=torch.bfloat16 # 节省内存 - # torch_dtype = torch.float16 + torch_dtype=torch.float16 # 节省内存 ) # 优化:CPU卸载,适合低显存设备 pipe.enable_model_cpu_offload() logger.info("Model loaded and offloaded to CPU.") + # 获取 Transformer 模块 + model = pipe.transformer + + # 访问Transformer块(假设模型有.model.transformer.blocks) + blocks = model.transformer_blocks # 列表 of TransformerBlock + block_idx1, block_idx2 = 0, 1 # 比较第0和第1块,可改 + + # 获取注意力层(CogVideoX用UNet-like Transformer,attn有to_q/to_k/to_v) + attn11 = blocks[block_idx1].attn1 + attn12 = blocks[block_idx2].attn1 + + # 提取to_k权重(.data避免梯度) + k_weight1 = attn11.to_k.weight.data.flatten() # flatten成1D向量 + k_weight2 = attn12.to_k.weight.data.flatten() + + # 计算余弦相似度(unsqueeze为batch dim) + sim_k = cosine_similarity(k_weight1.unsqueeze(0), k_weight2.unsqueeze(0)).item() + + # 同理to_v + v_weight1 = attn11.to_v.weight.data.flatten() + v_weight2 = attn12.to_v.weight.data.flatten() + sim_v = cosine_similarity(v_weight1.unsqueeze(0), v_weight2.unsqueeze(0)).item() + + print(f"Block {block_idx1} vs {block_idx2}:") + print(f"to_k 余弦相似度: {sim_k:.6f}") + print(f"to_v 余弦相似度: {sim_v:.6f}") + print(f"两个block的余弦相似度: 
{compute_cosine_similarity(block_idx1, block_idx2, blocks, "to_k")}") + + # 注册 Hook:避免重叠,只匹配顶层模块(调整匹配条件) + for name, module in model.named_modules(): + if "attn" in name.lower() and not any(sub in name for sub in ['proj', 'qkv']): # 避免内部子模块重叠 + module.register_forward_pre_hook(lambda m, i, n=name: pre_timing_hook('attn', m, i, n)) + module.register_forward_hook(lambda m, i, o, n=name: post_timing_hook('attn', m, i, o, n)) + logger.info(f"注册 Hook 到 attn 模块: {name}") + if "ff" in name.lower() and not any(sub in name for sub in ['net', 'proj']): # 类似,避免 FF 子层 + module.register_forward_pre_hook(lambda m, i, n=name: pre_timing_hook('ff', m, i, n)) + module.register_forward_hook(lambda m, i, o, n=name: post_timing_hook('ff', m, i, o, n)) + logger.info(f"注册 Hook 到 ff 模块: {name}") + # 输入提示 - prompt = "A cat wearing a hat dancing in a colorful garden." + prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." logger.info(f"Using prompt: {prompt}") try: @@ -38,22 +116,34 @@ def main(): logger.info("Starting video generation...") video = pipe( prompt=prompt, - num_inference_steps=10, # 调试时用小值 - height=480, - width=720, - guidance_scale=6.0, - num_frames=5 # 调试时减少帧数 + num_videos_per_prompt=1, + num_inference_steps=10, # 生成步骤越多,质量越高,但速度越慢 + height=480, # 视频高度 + width=720, # 视频宽度 + num_frames=9, # 视频帧数, 8N+1, N<=6 + guidance_scale=6, + generator=torch.Generator(device="cuda").manual_seed(42), ).frames[0] - logger.info(f"Generated video with {len(video)} frames, shape: {video[0].shape}") + logger.info(f"Generated video with {len(video)} frames, shape: {video[0].size}") # 导出视频 - export_to_video(video, "output_video.mp4", fps=8) + export_to_video(video, output_path + "output_video.mp4", fps=16) logger.info("Video saved to output_video.mp4") except Exception as e: logger.error(f"Error during generation: {str(e)}") raise + finally: + # 强制刷新日志,确保写入文件 + for handler in logger.handlers: + handler.flush() + + # 输出总耗时 + total_attention_time = sum(module_times['attn'].values()) + print(f"所有 Attention 模块总耗时: {total_attention_time:.2f} ms") + total_ff_time = sum(module_times['ff'].values()) + print(f"所有 FF 模块总耗时: {total_ff_time:.2f} ms") if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/diffusers/map.py b/src/diffusers/map.py new file mode 100644 index 000000000000..a7cb8fea636a --- /dev/null +++ b/src/diffusers/map.py @@ -0,0 +1,126 @@ +import sys +import os +sys.path.append('/home/lyc/diffusers/src') +# import torch +import logging + +# 先配置 logging(在导入 diffusers 前) +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +output_path = '/home/lyc/diffusers_output/' +save_dir = '/home/lyc/diffusers_output/attn_maps' + +# 文件 handler +file_handler = logging.FileHandler(output_path + 'cogvideox.log', mode='a') +file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) +logger.addHandler(file_handler) + +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video +import torch +import os, re, 
math, numpy as np, torch +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from diffusers.hooks import ModelHook, HookRegistry +from diffusers.maputils import compute_query_redundancy_cosine_lowmem +from diffusers.maputils import save_redundancy_heatmap_lowmem +from matplotlib.colors import LinearSegmentedColormap +import torch.nn.functional as F +from diffusers.pipelines import CogVideoXPipeline +from diffusers.vishook import AttnCaptureHook, LatentFrameVizHook, TransformerStepHook, assign_layer_ids_and_register + +# 通过,解耦:保存路径 +def testAttnCaptureHook(pipe): + shared_state = {"last_timestep": None, "step_index": -1, "timestep": None} + HookRegistry.check_if_exists_or_initialize(pipe.transformer).register_hook( + TransformerStepHook(shared_state), "step_hook" + ) + # 配置:只捕获5层,第10步,立即处理,记得设置计算的是哪种图,因为都一起计算的话显存不够? + attn_hook = AttnCaptureHook( + shared_state, + target_layers=list(range(3)), # 前5层 + target_heads=[0, 1], # 前2个头 + target_steps=[10], # 只捕获第10步 + store_limit_per_layer=1, + force_square=True, + max_sequence_length=51200, + process_immediately=True, # 立即处理 + attn_qk_map=True, # 计算QK的attention score map + redundancy_q_map=False, # 计算query之间的冗余度 + output_dir=save_dir, # 输出目录 + ) + total_layers = assign_layer_ids_and_register(pipe.transformer, attn_hook) + logger.info(f"Registered attention capture on {total_layers} layers.") + +def testLatentFrameVizHook(pipe): + # 2) 注册 hooks + shared_state = {"last_timestep": None, "step_index": -1, "timestep": None} + HookRegistry.check_if_exists_or_initialize(pipe.transformer).register_hook( + TransformerStepHook(shared_state), "step_hook" + ) + + # 2).a 注册帧级 latent 可视化 hook,只在第10步、第 {3,12} 帧做可视化,最多 4096 tokens: + frame_viz = LatentFrameVizHook( + save_root="/home/lyc/diffusers_output/frame_viz", + target_steps=[12], + target_frames=[4], # 只看第12帧,注意,这里是被 vae 时间压缩过的。。。 + query_indices=[0, 128, 512], # 在 Full Attention 图里额外展示这几个 query + max_hw_tokens=9182, + row_chunk=28, # 视内存调小 + cosine_device="cpu", + cosine_dtype=torch.float32, + temperature=1.0, + decode_latents=True, # 解码 latent + ) + HookRegistry.check_if_exists_or_initialize(pipe.transformer).register_hook(frame_viz, "latent_frame_viz") + + + +def main(): + print(torch.cuda.is_available()) + print(torch.version.cuda) + + # 1) 加载 pipeline + pipe = CogVideoXPipeline.from_pretrained( + "THUDM/CogVideoX-2b", + torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + logger.info("Model loaded and offloaded to CPU.") + + # 2) 注册 hooks + + # 2).a 注册帧级 latent 可视化 hook + testLatentFrameVizHook(pipe) + + # 2).b 注册 attention capture hook + # testAttnCaptureHook(pipe) + + # 3) 推理 + prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest..." 
+ logger.info(f"Using prompt: {prompt}") + + logger.info("Starting video generation...") + result = pipe( + prompt=prompt, + num_videos_per_prompt=1, + num_inference_steps=20, # 增加步数以确保能捕获到第10步 + height=480, + width=720, + num_frames=17, + guidance_scale=6, + generator=torch.Generator(device="cuda").manual_seed(42) if torch.cuda.is_available() else None, + ) + video = result.frames[0] + logger.info(f"Generated video with {len(video)} frames, shape: {video[0].size}") + + # 4) 保存视频 + os.makedirs(output_path, exist_ok=True) + export_to_video(video, os.path.join(output_path, "output_video.mp4"), fps=16) + logger.info("Video saved to output_video.mp4") + logger.info("All processing completed!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/diffusers/maputils.py b/src/diffusers/maputils.py new file mode 100644 index 000000000000..1ead97830663 --- /dev/null +++ b/src/diffusers/maputils.py @@ -0,0 +1,219 @@ +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap +import os + +def compute_query_redundancy_cosine_lowmem( + attn_2d: torch.Tensor, + row_chunk: int = 1024, + max_keys: int | None = None, + topk_per_row: int | None = None, + proj_dim: int | None = None, + method: str = "cosine", + device: str = "cpu", + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + 返回 [Tq, Tq] 的冗余矩阵,使用 X_norm @ X_norm.T 的方式避免 [Tq,Tq,Tk] 的广播内存爆炸。 + - max_keys: 先截断列维 Tk(例如 2048) + - topk_per_row: 每个 query 只保留注意力最大的前 k 个 key,其余置零 + - proj_dim: 用随机投影把列维从 Tk 降到 proj_dim(如 256/512) + - row_chunk: 分块计算相似度,限制峰值内存 + """ + x = attn_2d.to(device=device, dtype=dtype, copy=False) + + # 1) 可选:列维降维/裁剪 + Tq, Tk = x.shape + if max_keys is not None and Tk > max_keys: + x = x[:, :max_keys] + Tk = max_keys + + if topk_per_row is not None and topk_per_row < Tk: + # 稀疏化列:每行保留 top-k + vals, idxs = torch.topk(x, k=topk_per_row, dim=1) + x_sparse = torch.zeros_like(x) + x_sparse.scatter_(1, idxs, vals) + x = x_sparse + del x_sparse, vals, idxs + + if proj_dim is not None and proj_dim < Tk: + # 随机高斯投影(Johnson–Lindenstrauss),把列维降到 proj_dim + # 为保证可复现可设固定 seed + with torch.no_grad(): + rand_proj = torch.randn(Tk, proj_dim, device=device, dtype=dtype) / (proj_dim ** 0.5) + x = x @ rand_proj + Tk = proj_dim + + # 2) 行归一化 + x = x / (x.norm(dim=1, keepdim=True) + 1e-8) + + # 3) 分块矩阵乘法:X_norm @ X_norm.T + Tq = x.shape[0] + out = torch.empty((Tq, Tq), dtype=dtype, device=device) + for i in range(0, Tq, row_chunk): + i_end = min(i + row_chunk, Tq) + xi = x[i:i_end] # [chunk, Tk] + # 直接乘全体行;如果内存仍吃紧,可双重分块再分列块 + out[i:i_end] = xi @ x.T # [chunk, Tq] + del xi + torch.cuda.empty_cache() if device.startswith("cuda") else None + + return out + +def save_redundancy_heatmap_lowmem(attn_2d: torch.Tensor, save_path: str, title: str = None): + # 防止超大输入 + Tq, Tk = attn_2d.shape + # 若 Tq 超大,行采样到最多 Nq(例如 1024) + Nq = 1024 + if Tq > Nq: + idx = torch.linspace(0, Tq - 1, steps=Nq).long() + attn_2d = attn_2d[idx] + + # 用低内存余弦相似度 + R = compute_query_redundancy_cosine_lowmem( + attn_2d, + row_chunk=256, # 可根据内存调小 + max_keys=2048, # 限制列维 + topk_per_row=256, # 每行只保留 top-256 的 key + proj_dim=None, # 或者用 proj_dim=256 做随机投影 + device="cpu", + dtype=torch.float32, + method="cosine", + ).clamp_(0, 1).cpu() + + save_redundancy_heatmap(R, save_path, title=title, method="cosine") + del R + +def save_redundancy_heatmap(redundancy_matrix: torch.Tensor, save_path: str, + title: str = None, method: str = "cosine"): + """保存冗余度热力图""" + 
os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # 设置白色背景 + plt.style.use('default') + fig, ax = plt.subplots(figsize=(8, 8), facecolor='white') + ax.set_facecolor('white') + + # 使用蓝色系颜色映射:从白色到深蓝 + colors = ['white', 'lightblue', 'skyblue', 'steelblue', 'blue', 'darkblue', 'navy'] + n_bins = 256 + blue_cmap = LinearSegmentedColormap.from_list('blue_gradient', colors, N=n_bins) + + # 绘制热力图 + im = ax.imshow(redundancy_matrix.detach().cpu().numpy(), + cmap=blue_cmap, + aspect='auto', + vmin=0, # 最小值设为0(白色) + vmax=1) # 最大值设为1(深蓝) + + # 添加颜色条 + cbar = plt.colorbar(im, ax=ax, label=f"Query Redundancy ({method})") + cbar.ax.set_facecolor('white') + + # 设置坐标轴 + ax.set_xlabel("Query Token Index", fontsize=12) + ax.set_ylabel("Query Token Index", fontsize=12) + ax.invert_yaxis() # 反转Y轴,使得第0行在顶部 + + if title: + ax.set_title(title, fontsize=14, pad=20) + + # 设置网格线 + # ax.grid(True, alpha=0.3, color='lightgray') + + # 保存图片 + plt.tight_layout() + plt.savefig(save_path, dpi=200, bbox_inches="tight", + facecolor='white', edgecolor='none') + plt.close() + +def analyze_query_redundancy(cap, out_dir: str, batch_index: int = 0, head_index: int = 0, + methods: list = ["cosine", "pearson"], + aggregate_method: str = "single"): + """ + 分析并保存query token的冗余度 + + Args: + cap: 捕获的注意力数据 + out_dir: 输出目录 + batch_index: batch索引 + head_index: head索引 + methods: 冗余度计算方法列表 + aggregate_method: 注意力聚合方法 + """ + A = cap["attn"] # [B*H, Tq, Tk] + H = cap["num_heads"] or A.shape[0] + b = batch_index or 0 + + # 获取注意力矩阵 + if aggregate_method == "single": + idx = b * H + head_index if (cap["num_heads"] is not None and cap["batch_size"] is not None) else head_index + attn_2d = A[idx] # [Tq, Tk] + suffix = f"_head{head_index:02d}" + else: + if cap["num_heads"] is not None and cap["batch_size"] is not None: + batch_start = b * H + batch_end = (b + 1) * H + heads_attn = A[batch_start:batch_end] # [H, Tq, Tk] + else: + heads_attn = A # [H, Tq, Tk] + + if aggregate_method == "average": + attn_2d = heads_attn.mean(dim=0) # [Tq, Tk] + elif aggregate_method == "max": + attn_2d = heads_attn.max(dim=0)[0] # [Tq, Tk] + elif aggregate_method == "sum": + attn_2d = heads_attn.sum(dim=0) # [Tq, Tk] + else: + raise ValueError(f"Unknown aggregate_method: {aggregate_method}") + + suffix = f"_{aggregate_method}" + + step = cap["step_index"] + ts = cap["timestep"] + tstr = f"t{int(ts)}" if ts is not None else (f"s{step}" if step is not None else "sNA") + + # 为每种方法计算冗余度 + for method in methods: + print(f"Computing {method} redundancy for layer {cap['layer_id']}...") + + # 计算冗余度矩阵 + redundancy_matrix = compute_query_redundancy(attn_2d, method=method) + + # 保存冗余度热力图 + subdir = os.path.join(out_dir, f"{tstr}", f"layer{cap['layer_id']:03d}") + fname = f"{cap['layer_name'].replace('.', '_')}{suffix}_redundancy_{method}.png" + + save_redundancy_heatmap( + redundancy_matrix, + os.path.join(subdir, fname), + title=f"{cap['layer_name']} {tstr} Query Redundancy ({method})", + method=method + ) + + print(f"Saved redundancy heatmap: {fname}") + + # 打印统计信息 + print(f"Redundancy statistics ({method}):") + print(f" Mean: {redundancy_matrix.mean().item():.4f}") + print(f" Std: {redundancy_matrix.std().item():.4f}") + print(f" Max: {redundancy_matrix.max().item():.4f}") + print(f" Min: {redundancy_matrix.min().item():.4f}") + + # 找出最冗余的query对 + # 排除对角线(自己与自己的相似度) + mask = torch.eye(redundancy_matrix.shape[0], dtype=torch.bool) + off_diagonal = redundancy_matrix.masked_select(~mask) + + if len(off_diagonal) > 0: + max_redundancy = off_diagonal.max().item() + 
print(f" Max off-diagonal redundancy: {max_redundancy:.4f}") + + # 找出最冗余的query对 + max_indices = torch.nonzero(redundancy_matrix == max_redundancy, as_tuple=False) + if len(max_indices) > 0: + q1, q2 = max_indices[0] + print(f" Most redundant query pair: {q1.item()} <-> {q2.item()}") + diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index a8c98bccb86c..4cba504fea79 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -17,10 +17,20 @@ import torch from torch import nn +import functools +import time +import logging +import logging.handlers # 用于添加 handler + +# 先配置 logging(在导入 diffusers 前) +logger = logging.getLogger() # 根 logger +logger.setLevel(logging.DEBUG) +from contextlib import contextmanager from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +# from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 @@ -31,8 +41,29 @@ from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - +# logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +def timing_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + torch.cuda.synchronize() + start = time.perf_counter() # 高精度计时器 + result = func(*args, **kwargs) + torch.cuda.synchronize() + end = time.perf_counter() + logger.info(f"{func.__name__} 执行耗时: {(end - start) * 1000:.2f} ms") + return result + return wrapper + +@contextmanager +def timing_context(label="代码段"): + torch.cuda.synchronize() + start = time.perf_counter() + yield + torch.cuda.synchronize() + end = time.perf_counter() + # print(f"{label} 执行耗时: {(end - start) * 1000:.2f} ms") + logger.info(f"{label} 执行耗时: {(end - start) * 1000:.2f} ms") @maybe_allow_in_graph class CogVideoXBlock(nn.Module): @@ -115,6 +146,7 @@ def __init__( bias=ff_bias, ) + def forward( self, hidden_states: torch.Tensor, @@ -218,6 +250,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"] @register_to_config + @timing_decorator def __init__( self, num_attention_heads: int = 30, @@ -431,6 +464,7 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) + @timing_decorator def forward( self, hidden_states: torch.Tensor, @@ -460,71 +494,76 @@ def forward( batch_size, num_frames, channels, height, width = hidden_states.shape # 1. Time embedding - timesteps = timestep - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. 
- t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) - - if self.ofs_embedding is not None: - ofs_emb = self.ofs_proj(ofs) - ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) - ofs_emb = self.ofs_embedding(ofs_emb) - emb = emb + ofs_emb + with timing_context("去噪网络Time Embedding用时"): + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + if self.ofs_embedding is not None: + ofs_emb = self.ofs_proj(ofs) + ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) + ofs_emb = self.ofs_embedding(ofs_emb) + emb = emb + ofs_emb # 2. Patch embedding - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - hidden_states = self.embedding_dropout(hidden_states) + with timing_context("去噪网络Patch Embedding用时"): + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) - text_seq_length = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] # 3. Transformer blocks - for i, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - encoder_hidden_states, - emb, - image_rotary_emb, - attention_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - ) + with timing_context("去噪网络Expert Transformer Blocks用时"): + for i, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + attention_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + ) hidden_states = self.norm_final(hidden_states) # 4. Final block - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) + with timing_context("去噪网络Final Block用时"): + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) # 5. 
Unpatchify - p = self.config.patch_size - p_t = self.config.patch_size_t + with timing_context("去噪网络Unpatchify用时"): + p = self.config.patch_size + p_t = self.config.patch_size_t - if p_t is None: - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - else: - output = hidden_states.reshape( - batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p - ) - output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 4ac33b24bbe1..7588455fd05f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -15,6 +15,21 @@ import inspect import math +import time +import functools +import logging +import logging.handlers # 用于添加 handler +from contextlib import contextmanager + +# 先配置 logging(在导入 diffusers 前) +logger = logging.getLogger() # 根 logger +logger.setLevel(logging.DEBUG) + +# 文件 handler +# file_handler = logging.FileHandler('/home/lyc/diffusers/src/diffusers/cogvideox.log', mode='a') +# file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) +# logger.addHandler(file_handler) + from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -26,7 +41,8 @@ from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler -from ...utils import is_torch_xla_available, logging, replace_example_docstring +# from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils import is_torch_xla_available, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from .pipeline_output import CogVideoXPipelineOutput @@ -39,8 +55,6 @@ else: XLA_AVAILABLE = False -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - EXAMPLE_DOC_STRING = """ Examples: @@ -83,6 +97,39 @@ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) +def timing_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + torch.cuda.synchronize() + start = time.perf_counter() # 高精度计时器 + result = func(*args, **kwargs) + torch.cuda.synchronize() + end = time.perf_counter() + logger.info(f"{func.__name__} 执行耗时: {(end - start) * 1000:.2f} ms") + return result + return wrapper + +@contextmanager +def timing_context(label="代码段"): + torch.cuda.synchronize() + start = time.perf_counter() + yield + torch.cuda.synchronize() + end = 
time.perf_counter() + # print(f"{label} 执行耗时: {(end - start) * 1000:.2f} ms") + logger.info(f"{label} 执行耗时: {(end - start) * 1000:.2f} ms") + +@contextmanager +def timing_accumulator(total_time_list): + torch.cuda.synchronize() + start = time.perf_counter() + yield + end = time.perf_counter() + torch.cuda.synchronize() + # print(end - start) + total_time_list[0] += (end - start) # 累加到列表(可变对象) + +total_time_list = [0.0] # 用列表包裹,便于修改 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( @@ -322,6 +369,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + @timing_decorator def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -348,6 +396,7 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents + @timing_decorator def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] latents = 1 / self.vae_scaling_factor_image * latents @@ -504,6 +553,7 @@ def interrupt(self): @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) + @timing_decorator def __call__( self, prompt: Optional[Union[str, List[str]]] = None, @@ -609,7 +659,8 @@ def __call__( [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ - + logger.info("Pipeline __call__ started") + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -706,71 +757,79 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: # for DPM-solver++ old_pred_original_sample = None - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - # predict noise model_output - with self.transformer.cache_context("cond_uncond"): - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred.float() - - # perform guidance - if use_dynamic_cfg: - self._guidance_scale = 1 + guidance_scale * ( - (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 - ) - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if not isinstance(self.scheduler, CogVideoXDPMScheduler): - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - else: - latents, old_pred_original_sample = self.scheduler.step( - noise_pred, - old_pred_original_sample, - t, - timesteps[i - 1] if i > 0 else None, - latents, - **extra_step_kwargs, - return_dict=False, - ) - latents = latents.to(prompt_embeds.dtype) - - # call the callback, if provided - 
if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() + with timing_context("去噪循环用时"): + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # 累计计时 + with timing_accumulator(total_time_list): + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() self._current_timestep = None + # print("Denoising complete") + logger.info(f"scheduler.step(): {total_time_list[0] * 1000:.2f} ms") + logger.info("Denoising complete") + + # record the time taken to decode latents if not output_type == "latent": # Discard any padding frames that were added for CogVideoX 1.5 @@ -787,3 +846,4 @@ def __call__( return (video,) return CogVideoXPipelineOutput(frames=video) + diff --git a/src/diffusers/vishook.py b/src/diffusers/vishook.py new file mode 100644 index 000000000000..0fce77ae0c66 --- /dev/null +++ b/src/diffusers/vishook.py @@ -0,0 +1,597 @@ + +import sys +import os 
+sys.path.append('/home/lyc/diffusers/src') +# import torch +import logging + +# 先配置 logging(在导入 diffusers 前) +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +output_path = '/home/lyc/diffusers_output/' +save_dir = '/home/lyc/diffusers_output/attn_maps' + +# 文件 handler +file_handler = logging.FileHandler(output_path + 'cogvideox.log', mode='a') +file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) +logger.addHandler(file_handler) + +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video +import torch +import os, re, math, numpy as np, torch +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from diffusers.hooks import ModelHook, HookRegistry +from diffusers.maputils import compute_query_redundancy_cosine_lowmem +from diffusers.maputils import save_redundancy_heatmap_lowmem +from matplotlib.colors import LinearSegmentedColormap +import torch.nn.functional as F +from diffusers.pipelines import CogVideoXPipeline + +def save_attention_heatmap(attn_2d: torch.Tensor, save_path: str, title: str = None, + xlabel: str = "Query Tokens", ylabel: str = "Key Tokens"): + """保存注意力热力图,白色背景,QK依赖越强越红""" + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # 设置白色背景 + plt.style.use('default') + fig, ax = plt.subplots(figsize=(8, 6), facecolor='white') + ax.set_facecolor('white') + + # 使用红色系颜色映射:从白色/淡粉红到深红 + from matplotlib.colors import LinearSegmentedColormap + colors = ['white', 'mistyrose', 'lightpink', 'pink', 'hotpink', 'red', 'darkred'] + # n_bins = 256 + # red_cmap = LinearSegmentedColormap.from_list('red_gradient', colors, N=n_bins) + + # 自定义红色渐变:从纯白到深红 + colors = [ + (1.0, 1.0, 1.0), # 纯白色 (RGB) + (1.0, 0.9, 0.9), # 极淡粉红 + (1.0, 0.8, 0.8), # 淡粉红 + (1.0, 0.6, 0.6), # 粉红 + (1.0, 0.4, 0.4), # 中粉红 + (1.0, 0.2, 0.2), # 红色 + (0.8, 0.1, 0.1), # 深红 + (0.6, 0.0, 0.0) # 极深红 + ] + + red_cmap = LinearSegmentedColormap.from_list('custom_red', colors, N=256) + + + # 绘制热力图 + im = ax.imshow(attn_2d.detach().cpu().numpy(), + cmap=red_cmap, + aspect='auto', + vmin=0, # 最小值设为0(白色) + vmax=0.0006) # 最大值设为1(深红) + + # 添加颜色条 + cbar = plt.colorbar(im, ax=ax, label="Attention Score") + cbar.ax.set_facecolor('white') + + # 设置坐标轴 + ax.set_xlabel(xlabel, fontsize=12) + ax.set_ylabel(ylabel, fontsize=12) + ax.invert_yaxis() # 反转Y轴,使得第0行在顶部 + + if title: + ax.set_title(title, fontsize=14, pad=20) + + # 设置网格线(可选,让图更清晰) + # ax.grid(True, alpha=0.3, color='lightgray') + + # 保存图片 + plt.tight_layout() + plt.savefig(save_path, dpi=200, bbox_inches="tight", + facecolor='white', edgecolor='none') + plt.close() + print(f"Saved attention heatmap to {save_path}") + +def save_cap_head(cap, out_dir: str, batch_index: int, head_index: int, + aggregate_method: str = "single"): + """保存注意力图,支持多种聚合方式""" + A = cap["attn"] # [B*H, Tq, Tk] + H = cap["num_heads"] or A.shape[0] + b = batch_index or 0 + + if aggregate_method == "single": + # 单个head + idx = b * H + head_index if (cap["num_heads"] is not None and cap["batch_size"] is not None) else head_index + attn_2d = A[idx] # [Tq, Tk] + suffix = f"_head{head_index:02d}" + else: + # 聚合所有head + if cap["num_heads"] is not None and cap["batch_size"] is not None: + # 有明确的batch和head信息 + batch_start = b * H + batch_end = (b + 1) * H + heads_attn = A[batch_start:batch_end] # [H, Tq, Tk] + else: + # 退化情况:假设所有都是head + heads_attn = A # [H, Tq, Tk] + + if aggregate_method == "average": + attn_2d = heads_attn.mean(dim=0) # [Tq, Tk] + elif aggregate_method == "max": + attn_2d = heads_attn.max(dim=0)[0] 
# [Tq, Tk] + elif aggregate_method == "sum": + attn_2d = heads_attn.sum(dim=0) # [Tq, Tk] + else: + raise ValueError(f"Unknown aggregate_method: {aggregate_method}") + + suffix = f"_{aggregate_method}" + + step = cap["step_index"] + ts = cap["timestep"] + tstr = f"t{int(ts)}" if ts is not None else (f"s{step}" if step is not None else "sNA") + + # 目录/文件名:step/layer/head + subdir = os.path.join(out_dir, f"{tstr}", f"layer{cap['layer_id']:03d}") + fname = f"{cap['layer_name'].replace('.', '_')}{suffix}.png" + + save_attention_heatmap( + attn_2d, + os.path.join(subdir, fname), + title=f"{cap['layer_name']} {tstr} {aggregate_method}", + xlabel="Query Tokens", + ylabel="Key Tokens" + ) + +def _to_scalar_timestep(ts): + if ts is None: + return None + try: + import numpy as np + except Exception: + np = None + + if torch.is_tensor(ts): + if ts.numel() == 0: + return None + return ts.reshape(-1)[0].item() + if isinstance(ts, (list, tuple)): + return _to_scalar_timestep(ts[0]) if ts else None + if np is not None and isinstance(ts, np.ndarray): + return ts.reshape(-1)[0].item() if ts.size > 0 else None + return int(ts) + +class AttnCaptureHook(ModelHook): + def __init__( + self, + shared_state: dict, + target_layers: list[int] = None, + target_heads: list[int] = None, + target_steps: list[int] = None, # 目标步数 + store_limit_per_layer: int = 1, + eps: float = 1e-3, + force_square: bool = True, + max_sequence_length: int = 51200, + process_immediately: bool = True, # 是否立即处理 + attn_qk_map = False, # 计算QK的attention score map + redundancy_q_map = False, # 计算query之间的冗余度 + output_dir: str = None, # 输出目录 + ): + super().__init__() + self.state = shared_state + self.target_layers = set(target_layers) if target_layers else None + self.target_heads = set(target_heads) if target_heads else None + self.target_steps = set(target_steps) if target_steps else None + self.store_limit_per_layer = store_limit_per_layer + self.eps = eps + self.force_square = force_square + self.max_sequence_length = max_sequence_length + self.process_immediately = process_immediately + self.attn_qk_map = attn_qk_map + self.redundancy_q_map = redundancy_q_map + self.output_dir = output_dir + self.captured = [] + self._layer_store_count: dict[int, int] = {} + + def _layer_allowed(self, layer_id: int) -> bool: + if self.target_layers is None: + return True + return layer_id in self.target_layers + + def _step_allowed(self, step_index: int | None) -> bool: + if self.target_steps is None: + return True + return (step_index is not None) and (step_index in self.target_steps) + + def pre_forward(self, module, *args, **kwargs): + ts = kwargs.get("timestep", None) + ts_val = _to_scalar_timestep(ts) + + has_proj = all(hasattr(module, a) for a in ["to_q", "to_k", "to_v"]) + if not has_proj: + self._can_capture = False + return args, kwargs + + fqn = getattr(module, "_diffusers_fqn", module.__class__.__name__) + layer_id = getattr(module, "_attn_layer_id", -1) + + if not self._layer_allowed(layer_id): + self._can_capture = False + return args, kwargs + + step_index = self.state.get("step_index", None) + if not self._step_allowed(step_index): + self._can_capture = False + return args, kwargs + + if self._layer_store_count.get(layer_id, 0) >= self.store_limit_per_layer: + self._can_capture = False + return args, kwargs + + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + + if hidden_states is not None: + q_source = hidden_states + k_source = hidden_states + attention_type = "self" + 
elif encoder_hidden_states is not None: + q_source = encoder_hidden_states + k_source = encoder_hidden_states + attention_type = "self_encoder" + else: + self._can_capture = False + return args, kwargs + + if q_source.shape[1] > self.max_sequence_length: + print(f"Layer {layer_id}: sequence too long ({q_source.shape[1]}), skipping") + self._can_capture = False + return args, kwargs + + print(f"Capturing Layer {layer_id} at step {step_index}") + + q = module.to_q(q_source) + k = module.to_k(k_source) + + if hasattr(module, "head_to_batch_dim"): + q = module.head_to_batch_dim(q) + k = module.head_to_batch_dim(k) + num_heads = getattr(module, "heads", None) + batch_size = None + else: + b, tq, _ = q.shape + num_heads = getattr(module, "heads", None) + if num_heads is None: + self._can_capture = False + return args, kwargs + d = q.shape[-1] // num_heads + q = q.view(b, tq, num_heads, d).permute(0, 2, 1, 3).reshape(b * num_heads, tq, d) + b2, tk, _ = k.shape + k = k.view(b2, tk, num_heads, d).permute(0, 2, 1, 3).reshape(b2 * num_heads, tk, d) + batch_size = b + + self._can_capture = True + self._ctx = { + "fqn": fqn, + "layer_id": layer_id, + "num_heads": num_heads, + "batch_size": batch_size, + "q": q, "k": k, + "step_index": step_index, + "timestep": ts_val, + } + return args, kwargs + + def post_forward(self, module, output): + if not getattr(self, "_can_capture", False): + return output + + q, k = self._ctx["q"], self._ctx["k"] + scale = getattr(module, "scale", None) + if scale is None: + scale = 1.0 / math.sqrt(q.shape[-1]) + + if self.force_square and q.shape[1] != k.shape[1]: + min_len = min(q.shape[1], k.shape[1]) + q = q[:, :min_len, :] + k = k[:, :min_len, :] + + attn = torch.einsum("bqd,bkd->bqk", q, k).mul_(scale).softmax(dim=-1).detach().float().cpu() + + layer_id = self._ctx["layer_id"] + step_index = self._ctx["step_index"] + + print(f"Processing Layer {layer_id} at step {step_index}") + + # 立即处理并保存 + if self.process_immediately: + self._process_and_save_immediately(attn, layer_id, step_index) + else: + # 传统方式:存储到内存 + self._layer_store_count[layer_id] = self._layer_store_count.get(layer_id, 0) + 1 + self.captured.append({ + "layer_id": layer_id, + "layer_name": self._ctx["fqn"], + "step_index": step_index, + "timestep": self._ctx["timestep"], + "num_heads": self._ctx["num_heads"], + "batch_size": self._ctx["batch_size"], + "attn": attn, + }) + + return output + + def _process_and_save_immediately(self, attn: torch.Tensor, layer_id: int, step_index: int): + """立即处理并保存注意力数据,然后丢弃""" + try: + H = self._ctx["num_heads"] or attn.shape[0] + if self._ctx["num_heads"] is not None and self._ctx["batch_size"] is not None: + # 有明确的batch和head信息 + batch_size = self._ctx["batch_size"] + heads_attn = attn.view(batch_size, H, attn.shape[1], attn.shape[2]) + avg_attn = heads_attn.mean(dim=1) # [B, Tq, Tk] + attn_2d = avg_attn[0] # 取第一个batch + else: + # 退化情况:假设所有都是head + heads_attn = attn # [H, Tq, Tk] + attn_2d = heads_attn.mean(dim=0) # [Tq, Tk] + # 计算平均注意力(所有head的平均) + if self.attn_qk_map: + H = self._ctx["num_heads"] or attn.shape[0] + # 保存注意力热力图 + tstr = f"s{step_index}" + subdir = os.path.join(self.output_dir, f"{tstr}", f"layer{layer_id:03d}") + fname = f"{self._ctx['fqn'].replace('.', '_')}_average.png" + save_attention_heatmap( + attn_2d, + os.path.join(subdir, fname), + title=f"Layer {layer_id} Step {step_index} Average Attention", + xlabel="Query Tokens", + ylabel="Key Tokens" + ) + + # 计算并保存冗余度热力图 + if self.redundancy_q_map: + # 计算并保存冗余度热力图(低内存) + tstr = f"s{step_index}" + subdir = 
os.path.join(self.output_dir, f"{tstr}", f"layer{layer_id:03d}") + redundancy_fname = f"{self._ctx['fqn'].replace('.', '_')}_redundancy_cosine.png" + save_redundancy_heatmap_lowmem( + attn_2d, + os.path.join(subdir, redundancy_fname), + title=f"Layer {layer_id} Step {step_index} Query Redundancy", + ) + print(f"Saved attention maps for Layer {layer_id} at step {step_index}") + + # 打印统计信息 + print(f" Attention shape: {attn_2d.shape}") + print(f" Mean attention: {attn_2d.mean().item():.4f}") + # print(f" Mean redundancy: {redundancy_matrix.mean().item():.4f}") + + except Exception as e: + print(f"Error processing layer {layer_id}: {e}") + finally: + # 清理内存 + del attn + if hasattr(self, '_ctx'): + del self._ctx + torch.cuda.empty_cache() # 清理GPU缓存 + +class TransformerStepHook(ModelHook): + def __init__(self, shared_state: dict): + super().__init__() + self.state = shared_state + + def pre_forward(self, module, *args, **kwargs): + ts = kwargs.get("timestep", None) + ts_val = _to_scalar_timestep(ts) + if ts_val is not None: + last = self.state.get("last_timestep", None) + if last != ts_val: + self.state["step_index"] = self.state.get("step_index", -1) + 1 + self.state["last_timestep"] = ts_val + self.state["timestep"] = ts_val + return args, kwargs + +def _ensure_dir(path): + os.makedirs(os.path.dirname(path), exist_ok=True) + +def _save_heatmap(arr_2d: torch.Tensor, save_path: str, title: str = None, + cmap="magma", vmin=None, vmax=None, xlabel=None, ylabel=None, + invert_y=True, white_bg=True): + _ensure_dir(save_path) + plt.style.use("default") + fig, ax = plt.subplots(figsize=(6, 5), facecolor="white" if white_bg else None) + if white_bg: + ax.set_facecolor("white") + im = ax.imshow(arr_2d.detach().cpu().numpy(), cmap=cmap, aspect="auto", vmin=vmin, vmax=vmax) + plt.colorbar(im, ax=ax) + if xlabel: ax.set_xlabel(xlabel) + if ylabel: ax.set_ylabel(ylabel) + if title: ax.set_title(title) + if invert_y: ax.invert_yaxis() + plt.tight_layout() + plt.savefig(save_path, dpi=200, bbox_inches="tight", facecolor="white" if white_bg else None) + plt.close() + +def _cosine_sim_lowmem(x: torch.Tensor, row_chunk: int = 1024, device="cpu", dtype=torch.float32): + # x: [N, D] -> return [N, N] with X̂ X̂ᵀ, 分块避免内存峰值 + x = x.to(device=device, dtype=dtype, copy=False) + x = x / (x.norm(dim=1, keepdim=True) + 1e-8) + N, D = x.shape + out = torch.empty((N, N), dtype=dtype, device=device) + for i in range(0, N, row_chunk): + i_end = min(i + row_chunk, N) + out[i:i_end] = x[i:i_end] @ x.T + return out + +def _entropy_sparsity_from_sim(sim_row: torch.Tensor, temperature: float = 1.0): + # sim_row: [N], -> p=softmax(sim/τ), 稀疏度 = 1 - H_norm + p = F.softmax(sim_row / max(1e-6, temperature), dim=0) + logp = torch.log(p + 1e-12) + H = -(p * logp).sum() + H_norm = H / math.log(p.shape[0] + 1e-12) + return 1.0 - H_norm # 高→更稀疏 + +class LatentFrameVizHook(ModelHook): + """ + 在 CogVideoX3DTransformer 入口(hidden_states)上做“帧级”可视化: + - Full Attention(query->所有key 的相似度): 取某个 query 的一行,重排为 H×W。 + - Query Sparsity map:每个 query 的稀疏度,重排为 H×W。 + 计算结束立即落盘并释放内存。 + """ + def __init__( + self, + save_root: str, + target_steps: list[int] = None, # 只在这些 step_index 上可视化 + target_frames: list[int] = None, # 只在这些帧索引上可视化 + query_indices: list[int] = None, # 在 Full Attention 图中要展示的若干 query(在该帧的网格索引) + max_hw_tokens: int = 4096, # 控制 HxW 最大 tokens;过大则等距下采样 + row_chunk: int = 1024, # 相似度分块 + cosine_device: str = "cpu", # 相似度计算放 CPU,避免占用显存 + cosine_dtype = torch.float32, + temperature: float = 1.0, # 熵稀疏度的温度 + decode_latents = False, # 是否解码 latent 
+ ): + super().__init__() + self.save_root = save_root + self.target_steps = set(target_steps) if target_steps else None + self.target_frames = set(target_frames) if target_frames else None + self.query_indices = query_indices or [0] # 默认画一个 query + self.max_hw_tokens = max_hw_tokens + self.row_chunk = row_chunk + self.cosine_device = "cuda" if (cosine_device.startswith("cuda") and torch.cuda.is_available()) else "cpu" + self.cosine_dtype = cosine_dtype + self.temperature = temperature + self.decode_latents = decode_latents + + def pre_forward(self, module, *args, **kwargs): + self._do = False + step_index = getattr(module._diffusers_hook.hooks["step_hook"], "state", {}).get("step_index", None) \ + if hasattr(module, "_diffusers_hook") and "step_hook" in getattr(module._diffusers_hook, "hooks", {}) \ + else None + if (self.target_steps is not None) and (step_index not in self.target_steps): + return args, kwargs + + x = kwargs.get("hidden_states", None) + if x is None or x.dim() < 4: + return args, kwargs + + # x 可能是 [B, C, T, H, W] 或 [B, T, C, H, W];做个鲁棒判定 + # 认为 C 是 64~2048 间的典型通道数;T 通常 <= 33;H,W 16~128 + shape = x.shape + if x.dim() == 5: + b, d1, d2, d3, d4 = shape + # 判断哪个是 C + candidates = [d1, d2] + if 64 <= d1 <= 4096 and d2 <= 64: + layout = "B C T H W" + elif 64 <= d2 <= 4096 and d1 <= 64: + layout = "B T C H W" + else: + # 回退:若 d1 > d2 认为 d1 是 C + layout = "B C T H W" if d1 >= d2 else "B T C H W" + else: + # 其他情况暂不处理 + return args, kwargs + + self._ctx = {"layout": layout, "step_index": step_index} + self._x_ref = x.detach().to("cpu") # 转 CPU,避免卡显存 + self._do = True + return args, kwargs + + def post_forward(self, module, output): + if not getattr(self, "_do", False): + return output + try: + x = self._x_ref # [B,*,*,H,W] + layout = self._ctx["layout"] + step = self._ctx["step_index"] + B = x.shape[0] + + # decode 一下 + # 改进,动态使用不同 pipeline 的 vae 的 decode_latent + + + # 选择目标帧集合 + if layout == "B C T H W": + T = x.shape[2]; C = x.shape[1]; H = x.shape[3]; W = x.shape[4] + frame_take = list(self.target_frames or range(T)) + for t in frame_take: + xt = x[0, :, t] # [C, H, W] + self._process_one_frame(xt, t, H, W, step) + else: # "B T C H W" + # 注意,这里的 vae 已经压缩过了。。。 + T = x.shape[1]; C = x.shape[2]; H = x.shape[3]; W = x.shape[4] + frame_take = list(self.target_frames or range(T)) + for t in frame_take: + xt = x[0, t] # [C, H, W] + self._process_one_frame(xt, t, H, W, step) + + finally: + del self._x_ref + torch.cuda.empty_cache() + return output + + def _process_one_frame(self, xt: torch.Tensor, t: int, H: int, W: int, step: int): + # xt: [C, H, W] on CPU; 下采样到 <= max_hw_tokens + C = xt.shape[0] + h, w = H, W + N = h * w + stride = 1 + while (h // stride) * (w // stride) > self.max_hw_tokens: + stride *= 2 + if stride > 1: + xt_ds = F.avg_pool2d(xt.unsqueeze(0), kernel_size=stride, stride=stride).squeeze(0) # [C, h', w'] + else: + xt_ds = xt + h2, w2 = xt_ds.shape[1], xt_ds.shape[2] + X = xt_ds.permute(1, 2, 0).reshape(-1, C) # [N2, C] + + # 相似度矩阵(低内存) + S = _cosine_sim_lowmem(X, row_chunk=self.row_chunk, device=self.cosine_device, dtype=self.cosine_dtype) # [N2,N2] + + # Query Sparsity map(熵稀疏度) + sparsity = torch.empty(S.shape[0], dtype=torch.float32, device=S.device) + for i in range(S.shape[0]): + sparsity[i] = _entropy_sparsity_from_sim(S[i], temperature=self.temperature) + sparsity_map = sparsity.reshape(h2, w2).cpu() + + # Full attention for selected queries(每个 query 的一行) + q_indices = [q for q in self.query_indices if 0 <= q < S.shape[0]] + + # 保存 + root = 
os.path.join(self.save_root, f"s{step}", f"frame{t:03d}") + _save_heatmap( + sparsity_map, os.path.join(root, "query_sparsity.png"), + title=f"Query Sparsity (frame {t}, step {step})", + cmap=LinearSegmentedColormap.from_list("green", ["white","lightgreen","green","darkgreen"], N=256), + vmin=0.0, vmax=0.0010, xlabel="X", ylabel="Y" + ) + for qi in q_indices: + attn_q = S[qi].reshape(h2, w2).cpu() + colors = ['white', 'lightblue', 'skyblue', 'steelblue', 'blue', 'darkblue', 'navy'] + n_bins = 256 + blue_cmap = LinearSegmentedColormap.from_list('blue_gradient', colors, N=n_bins) + _save_heatmap( + attn_q, os.path.join(root, f"full_attention_q{qi:05d}.png"), + title=f"Full Attention of q={qi} (frame {t}, step {step})", + cmap=blue_cmap, + vmin=-1.0, vmax=1.0, xlabel="Key X", ylabel="Key Y" + ) + +def assign_layer_ids_and_register(model, attn_hook: AttnCaptureHook, layer_name_patterns=None): + """为每个注意力模块分配连续的 layer_id,并注册捕获 hook""" + patterns = [re.compile(p) for p in (layer_name_patterns or [])] + def allow(fqn: str): + return True if not patterns else any(p.search(fqn) for p in patterns) + + layer_id = 0 + for fqn, m in model.named_modules(): + try: + m._diffusers_fqn = fqn + except Exception: + pass + if hasattr(m, "to_q") and hasattr(m, "to_k") and allow(fqn): + setattr(m, "_attn_layer_id", layer_id) + HookRegistry.check_if_exists_or_initialize(m).register_hook(attn_hook, name=f"attn_capture_{layer_id}") + layer_id += 1 + return layer_id \ No newline at end of file
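A minimal smoke test, not part of the patch: it sketches how the chunked, low-memory cosine redundancy added in src/diffusers/maputils.py could be validated against a dense torch.nn.functional.cosine_similarity reference on a small random input (the test name and its placement are illustrative assumptions, not code from this PR):

import torch
import torch.nn.functional as F
from diffusers.maputils import compute_query_redundancy_cosine_lowmem

def test_lowmem_cosine_matches_dense_reference():
    torch.manual_seed(0)
    attn = torch.rand(64, 128)  # small [Tq, Tk] map, cheap enough for a dense check
    # Low-memory path: row-chunked X_norm @ X_norm.T with no column truncation,
    # no top-k sparsification and no random projection, so it should match exactly.
    approx = compute_query_redundancy_cosine_lowmem(
        attn, row_chunk=16, max_keys=None, topk_per_row=None, proj_dim=None,
        device="cpu", dtype=torch.float32,
    )
    # Dense reference: pairwise cosine similarity between all query rows.
    ref = F.cosine_similarity(attn.unsqueeze(1), attn.unsqueeze(0), dim=-1)
    assert torch.allclose(approx, ref, atol=1e-5)

if __name__ == "__main__":
    test_lowmem_cosine_matches_dense_reference()
    print("low-mem cosine redundancy matches the dense reference")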