In [None]:
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration, AutoConfig
from PIL import Image
import numpy as np
from scipy.ndimage import uniform_filter
import matplotlib.pyplot as plt
import cv2
import re
import torch.nn.functional as F
import re
import gc
import os

# Model setup: load the LLaVA checkpoint in fp16 and let accelerate place its layers
# on the available GPUs (device_map="auto").
# NOTE(review): "./llava-7b-hf" is presumably a local copy of llava-hf/llava-1.5-7b-hf — confirm.
device_map = "auto"
model_id = "./llava-7b-hf"

load_kwargs = dict(
    torch_dtype=torch.float16,
    device_map=device_map,
    offload_folder="offload",  # optional disk-offload directory to relieve GPU memory
)
model = LlavaForConditionalGeneration.from_pretrained(model_id, **load_kwargs)

processor = AutoProcessor.from_pretrained(model_id)
model.eval()

In [None]:
# Build a single user turn holding the question text followed by an image slot, and let
# the processor's chat template turn it into the model's prompt format.
question = "Is this brain MRI image normal or does it have a tumor?"
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": question},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Load the test image and force 3-channel RGB.
image_file = "/home/zhuxy/autodl-tmp/MedCLIP-SAMv2-main/biomedclip_finetuning/open_clip/src/data/brain_tumors/test_images/66.png"
raw_image = Image.open(image_file).convert("RGB")
print(raw_image.mode)

# Vision-encoder patch size; not used further below (14 is an example value — confirm
# against the model's vision config).
image_patch_size = 14

# Tokenise the prompt and preprocess the image in one call (returns PyTorch tensors).
inputs = processor(images=raw_image, text=prompt, return_tensors='pt')


In [None]:
# Greedy decoding without a KV cache: every step re-runs the full sequence, so the
# `attentions` left over from the last step cover all prompt positions plus every
# generated token except the final one (which is never fed back into the model).
tokenizer = processor.tokenizer
max_generated_length = 50  # maximum number of new tokens
generated_tokens = []

decoder_input_ids = inputs["input_ids"]  # start from the full prompt
attention_mask = inputs["attention_mask"]  # grows by one column per generated token

# Length of the whole prompt (text tokens plus image placeholder tokens).
num_text_tokens = inputs["input_ids"].shape[1]
# Number of image placeholder tokens in the prompt: one per patch when the processor
# expands <image>, otherwise 1. (pixel_values.shape[2] is the image height in pixels,
# not a token count.)
image_token_id = getattr(model.config, "image_token_index", 32000)
num_image_tokens = int((inputs["input_ids"] == image_token_id).sum())

for step in range(max_generated_length):
    print(f"\nStep {step + 1}/{max_generated_length}")

    # Forward pass over the whole current sequence.
    with torch.no_grad():
        outputs = model(
            input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            pixel_values=inputs["pixel_values"],
            output_attentions=True,  # needed by the attention-map cells below
        )

    logits = outputs.logits  # (batch_size, seq_len, vocab_size)
    attentions = outputs.attentions  # one (batch, heads, seq_len, seq_len) tensor per layer
    print("The shape of attentions:", attentions[0].shape)

    # Greedy choice of the next token (batch size 1 assumed by .item()).
    next_token = logits[:, -1, :].argmax(dim=-1)
    generated_tokens.append(next_token.item())
    print("Final Generated Tokens:", processor.decode(generated_tokens))

    # Append the new token to the inputs and extend the mask accordingly.
    decoder_input_ids = torch.cat([decoder_input_ids, next_token.unsqueeze(-1)], dim=-1)
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token).unsqueeze(-1)], dim=-1)

    if next_token.item() == tokenizer.eos_token_id:
        print("Generated EOS token. Stopping generation.")
        break

In [None]:
def find_tumor_token_indices(generated_tokens, processor, word="tumor"):
    """Return the span of generated tokens that spells `word` in the decoded answer.

    The decoded text is searched for the first occurrence of `word`. That character span
    is mapped back to token positions by decoding growing prefixes of
    `generated_tokens`. The returned indices therefore refer to `generated_tokens`
    itself, not to a re-tokenisation of the decoded text, which need not reproduce the
    same ids.

    Args:
        generated_tokens: list of generated token ids (prompt excluded).
        processor: object exposing ``decode(ids) -> str`` (e.g. the LLaVA processor).
        word: substring to search for (case-sensitive). Defaults to "tumor".

    Returns:
        ``(first, last)`` inclusive indices into `generated_tokens`, or ``None`` when
        `word` is empty or does not occur in the decoded text.
    """
    if not word:
        return None

    decoded_text = processor.decode(generated_tokens)
    char_start = decoded_text.find(word)
    if char_start == -1:
        return None
    char_end = char_start + len(word)

    # prefix_lengths[i] = length of the text decoded from tokens 0..i.
    prefix_lengths = [
        len(processor.decode(generated_tokens[:i + 1])) for i in range(len(generated_tokens))
    ]
    # First token whose text reaches past the word's first character ...
    first = next((i for i, n in enumerate(prefix_lengths) if n > char_start), None)
    # ... and the first token after which the whole word has been emitted.
    last = next((i for i, n in enumerate(prefix_lengths) if n >= char_end), None)
    if first is None or last is None:
        # Prefix decoding disagreed with the full decoding; no reliable span.
        return None
    return (first, last)


# Report where "tumor" occurs in the greedy answer (None when it does not occur).
tumor_indices = find_tumor_token_indices(generated_tokens, processor)
if tumor_indices is None:
    print("No 'tumor' token found in the decoded sequence.")
else:
    print(f"'tumor' found at token indices: {tumor_indices}")


In [None]:
def get_image_token_indices(inputs, image_token_id=32000):
    """Return the positions of `image_token_id` in ``inputs["input_ids"]``.

    Only the first sequence of the batch is inspected (batch size 1 is assumed).

    Args:
        inputs: mapping with an ``"input_ids"`` tensor of shape [batch, seq_len].
        image_token_id: token id marking image positions (32000 here by default).

    Returns:
        Python list of integer positions, in increasing order.
    """
    first_sequence = inputs["input_ids"][0]
    is_image_token = first_sequence == image_token_id
    positions = torch.nonzero(is_image_token, as_tuple=True)[0]
    return positions.tolist()

def get_image_token_start_end(image_token_indices):
    """Return the half-open range ``(start, end)`` covered by the image tokens.

    Args:
        image_token_indices: positions of the image tokens, e.g. from
            ``get_image_token_indices``.

    Returns:
        ``(start, end)`` such that ``seq[start:end]`` selects exactly the image tokens.

    Raises:
        ValueError: if the list is empty, or the positions are not one contiguous run.
    """
    if not image_token_indices:
        raise ValueError("未找到具有指定 ID 的图像 token。")

    first = image_token_indices[0]
    stop = image_token_indices[-1] + 1  # exclusive end for slicing

    # Contiguous means the list equals every integer between first and stop.
    if image_token_indices != list(range(first, stop)):
        raise ValueError("图像 token 不是连续的。")

    return first, stop

# Locate the contiguous run of image placeholder tokens in the prompt.
# Take the id from the model config when it is available (32000 is the fallback the
# notebook used before).
image_token_id = getattr(model.config, "image_token_index", 32000)
image_token_indices = get_image_token_indices(inputs, image_token_id=image_token_id)
print(f"Number of image tokens found: {len(image_token_indices)}")
print(f"Image Token Indices: {image_token_indices}")

# Get image_token_start / image_token_end (end is exclusive).
try:
    image_token_start, image_token_end = get_image_token_start_end(image_token_indices)
except ValueError as e:
    print(f"Error: {e}")
    # Later cells need image_token_start/image_token_end; stop here rather than fail
    # further down with a misleading NameError.
    raise
print(f"Image Token Start Index: {image_token_start}")
print(f"Image Token End Index: {image_token_end}")

In [None]:
# Method 1: token-specific attention difference map.
# For the token(s) spelling "tumor", average their attention to every image token over
# all layers and heads. Then subtract the average attention that all generated tokens
# pay to the image, so only what is specific to "tumor" remains.
if tumor_indices is None:
    raise RuntimeError(
        "'tumor' was not found in the generated answer, so there is no token to build "
        "the attention map for."
    )

# Sequence positions = prompt length + index among the generated tokens.
start_index = tumor_indices[0] + num_text_tokens  # inclusive
end_index = tumor_indices[1] + num_text_tokens + 1  # exclusive

# Element-wise mean over layers of the last decoding step's attention maps.
sum_attention = sum(attentions)
final_attention_map = sum_attention / len(attentions)
batch_size, num_heads, seq_len, _ = final_attention_map.shape

# The last generated token was never fed back, so it has no attention row.
if end_index > seq_len:
    raise RuntimeError(
        f"'tumor' spans positions [{start_index}, {end_index}) but the attention maps "
        f"only cover {seq_len} positions."
    )

# image_token_end is already exclusive, so the image columns are exactly
# [image_token_start, image_token_end). Adding +1 to both ends would shift the window
# onto the token after the image.
image_columns = slice(image_token_start, image_token_end)

# Baseline: attention of all generated positions to the image, averaged over positions.
feature_attention_map_all = final_attention_map[:, :, num_text_tokens:seq_len, image_columns]
feature_attention_map_mean = feature_attention_map_all.mean(dim=2)
# Attention of the "tumor" position(s) to the image, minus the baseline.
final_generated_token_attention = (
    final_attention_map[:, :, start_index:end_index, image_columns].mean(dim=2)
    - feature_attention_map_mean
)
print(final_generated_token_attention.shape)
# Average over heads -> [batch, num_image_tokens].
final_image_attention_avg = final_generated_token_attention.mean(dim=1)
print(final_image_attention_avg)

# Lay the per-image-token scores out on the square patch grid (24 x 24 for 576 tokens).
num_image_patches = image_token_end - image_token_start
grid_side = int(round(num_image_patches ** 0.5))
if grid_side * grid_side != num_image_patches:
    raise RuntimeError(f"{num_image_patches} image tokens do not form a square grid.")
attention_np = final_image_attention_avg.squeeze(0).float().cpu().numpy()
final_attention_map_2d_np = attention_np.reshape(grid_side, grid_side)

print(final_attention_map_2d_np)

In [None]:
# Map the patch-level heatmap back to the original image size and scale it to [0, 255].
# NOTE(review): the LLaVA image processor may resize/center-crop the image before
# patching. Stretching the grid over the full image then only lines up exactly for
# square inputs — confirm against the processor config.
final_attention_map_2d_np = final_attention_map_2d_np.astype(np.float32)
original_size = raw_image.size  # (W, H), the order PIL's resize expects
print(raw_image.size)
heatmap_pil = Image.fromarray(final_attention_map_2d_np)
heatmap_resized = heatmap_pil.resize(original_size, resample=Image.BICUBIC)
raw_heatmap = np.array(heatmap_resized)

# Optional mean-filter smoothing (disabled): cv2.blur(raw_heatmap, (7, 7))
final_heatmap = raw_heatmap

# Min-max normalise to [0, 1]. A constant map would divide by zero, so it becomes all
# zeros instead (same convention as the method-3 normalisation cell).
heatmap_range = final_heatmap.max() - final_heatmap.min()
if heatmap_range != 0:
    normalized_heatmap = (final_heatmap - final_heatmap.min()) / heatmap_range
else:
    normalized_heatmap = np.zeros_like(final_heatmap)
heatmap_uint8 = (normalized_heatmap * 255).astype(np.uint8)



In [None]:
import matplotlib.pyplot as plt

# Show the method-1 heatmap with the jet colormap and a colorbar, axes hidden.
fig, ax = plt.subplots(figsize=(10, 8))
heatmap_artist = ax.imshow(heatmap_uint8, cmap='jet')
fig.colorbar(heatmap_artist, ax=ax)
ax.axis('off')
plt.show()

In [None]:
# Method 3 (tentative): capture every language-model layer's Query/Key projections with
# forward hooks and build the paper's similarity matrix from them.
# NOTE(review): these are the raw q_proj/k_proj outputs. In the HF Llama attention,
# rotary position embeddings are applied after these projections, so Q·K^T built here
# is presumably not identical to the model's attention matrix — verify.
# =================== 3. Hook function =====================
def cache_qk(name, module, output, cache_dict, target_device="cuda:0"):
    """Forward-hook body: store a language-model layer's q_proj / k_proj output.

    Args:
        name: fully qualified module name from ``model.named_modules()``.
        module: the hooked module (unused; part of the hook signature).
        output: the projection's output tensor.
        cache_dict: dict with ``"queries"`` and ``"keys"`` lists; a
            ``(layer_num, tensor)`` tuple is appended to the matching list.
        target_device: device the detached output is copied to. The default "cuda:0"
            keeps the original behaviour; pass "cpu" when no GPU is available.

    Modules outside the language model's decoder layers, and projections other than
    q_proj/k_proj, are skipped with a message.
    """
    # Accept both the older "language_model.model.layers.N" layout and the
    # "model.language_model.layers.N" layout used by newer transformers releases.
    match = re.search(
        r'(?:language_model\.model|model\.language_model)\.layers\.(\d+)\.self_attn', name
    )
    if match is None:
        print(f"Skipping module: {name}")
        return
    layer_num = int(match.group(1))

    if "self_attn.q_proj" in name:
        print(f"Caching Query for: {name}, Layer: {layer_num}, Output Shape: {output.shape}")
        cache_dict["queries"].append((layer_num, output.detach().to(target_device)))
    elif "self_attn.k_proj" in name:
        print(f"Caching Key for: {name}, Layer: {layer_num}, Output Shape: {output.shape}")
        cache_dict["keys"].append((layer_num, output.detach().to(target_device)))
    else:
        print(f"Skipping module: {name}")

def register_hooks(model, cache_dict, total_layers=32):
    """Attach `cache_qk` forward hooks to every language-model q_proj and k_proj.

    Args:
        model: the LLaVA model (any ``nn.Module`` whose module names follow that layout).
        cache_dict: dict with ``"queries"`` / ``"keys"`` lists that the hooks fill.
        total_layers: expected number of decoder layers. It is only used to warn when
            the number of registered hooks is not ``2 * total_layers``.

    Returns:
        list of hook handles; call ``handle.remove()`` on each to detach them.
    """
    hook_handles = []

    for name, module in model.named_modules():
        # Language-model decoder layers in either the old or the newer LLaVA layout.
        in_lm_layers = (
            'language_model.model.layers.' in name or 'model.language_model.layers.' in name
        )
        # endswith: hook the projection itself, not any wrapped submodule below it.
        if in_lm_layers and name.endswith(('self_attn.q_proj', 'self_attn.k_proj')):
            # Bind `name` as a default argument so each hook keeps its own module name.
            handle = module.register_forward_hook(
                lambda module, input, output, name=name: cache_qk(
                    name,
                    module,
                    output,
                    cache_dict
                )
            )
            hook_handles.append(handle)
            print(f"Hook registered for: {name}")

    print("All hooks registered for language_model's self_attn.q_proj and self_attn.k_proj.")
    print(f"Total hooks registered: {len(hook_handles)}")
    if len(hook_handles) != 2 * total_layers:
        print(f"Warning: expected {2 * total_layers} hooks for {total_layers} layers, "
              f"but registered {len(hook_handles)}.")
    return hook_handles
# Re-run greedy decoding with the Q/K hooks attached. The cache is cleared at the start of
# every step, so afterwards it holds the projections of the final forward pass only.
# This loop does not touch the `attentions` variable produced by the first decoding cell.
tokenizer = processor.tokenizer
max_generated_length = 30  # maximum number of new tokens
generated_tokens = []

decoder_input_ids = inputs["input_ids"]  # start from the full prompt
attention_mask = inputs["attention_mask"]

# Full prompt length, and the number of image placeholder tokens in it.
# (pixel_values.shape[2] is the image height in pixels, not a token count.)
num_text_tokens = inputs["input_ids"].shape[1]
image_token_id = getattr(model.config, "image_token_index", 32000)
num_image_tokens = int((inputs["input_ids"] == image_token_id).sum())

num_llm_layers = getattr(getattr(model.config, "text_config", None), "num_hidden_layers", 32)
cache_dict = {"queries": [], "keys": []}
hook_handles = register_hooks(model, cache_dict, num_llm_layers)
try:
    for step in range(max_generated_length):
        print(f"\nStep {step + 1}/{max_generated_length}")

        # Keep only the current step's projections.
        cache_dict["queries"].clear()
        cache_dict["keys"].clear()

        with torch.no_grad():
            outputs = model(
                input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                pixel_values=inputs["pixel_values"],
                # Same flags as the first decoding cell, so both runs take the same path.
                output_attentions=True,
            )

        # Sanity check of what the hooks captured (a compact summary per step).
        print(f"Number of Queries Cached: {len(cache_dict['queries'])}")
        if cache_dict["queries"]:
            first_layer, first_q = cache_dict["queries"][0]
            print(f"Query 1: Layer {first_layer}, Shape: {first_q.shape}, Device: {first_q.device}")
        else:
            print("No Queries Cached!")
        print(f"Number of Keys Cached: {len(cache_dict['keys'])}")
        if cache_dict["keys"]:
            first_layer, first_k = cache_dict["keys"][0]
            print(f"Key 1: Layer {first_layer}, Shape: {first_k.shape}, Device: {first_k.device}")
        else:
            print("No Keys Cached!")

        # Greedy choice of the next token (batch size 1 assumed by .item()).
        logits = outputs.logits  # (batch_size, seq_len, vocab_size)
        next_token = logits[:, -1, :].argmax(dim=-1)
        generated_tokens.append(next_token.item())
        print("Final Generated Tokens:", processor.decode(generated_tokens))

        decoder_input_ids = torch.cat([decoder_input_ids, next_token.unsqueeze(-1)], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token).unsqueeze(-1)], dim=-1)

        if next_token.item() == tokenizer.eos_token_id:
            print("Generated EOS token. Stopping generation.")
            break

        # Free the large per-step outputs (including the attention tuple) before the next pass.
        del outputs, next_token, logits
        gc.collect()
        torch.cuda.empty_cache()
finally:
    # Always detach the hooks, even if a forward pass fails (e.g. OOM). Otherwise a
    # re-run would register a second set and every layer would be cached twice.
    for handle in hook_handles:
        handle.remove()
    hook_handles.clear()
    print("All hooks have been removed.")

print("Final Generated Tokens:", tokenizer.decode(generated_tokens))

In [None]:
# =================== Reshape queries and keys =====================
# Split each captured layer's projections into heads -> [L, H, N, D].
import math

num_heads = getattr(getattr(model.config, "text_config", None), "num_attention_heads", 32)

# Order the cache by layer number instead of relying on hook firing order, and require
# exactly one query and one key per captured layer (a duplicated hook set would
# otherwise interleave entries silently).
queries_by_layer = sorted(cache_dict["queries"], key=lambda item: item[0])
keys_by_layer = sorted(cache_dict["keys"], key=lambda item: item[0])
query_layers = [layer for layer, _ in queries_by_layer]
key_layers = [layer for layer, _ in keys_by_layer]
if not query_layers:
    raise RuntimeError("No queries/keys were cached; run the hooked decoding cell first.")
if query_layers != key_layers or len(set(query_layers)) != len(query_layers):
    raise RuntimeError(
        f"Inconsistent Q/K cache: query layers {query_layers}, key layers {key_layers}."
    )
num_layers_captured = len(queries_by_layer)

hidden_dim = queries_by_layer[0][1].shape[-1]  # total projection width (4096 for 7B)
head_dim = hidden_dim // num_heads  # 128 for 7B
# NOTE(review): the HF Llama attention scales scores by head_dim ** -0.5. sqrt(hidden_dim)
# is kept here as the similarity scale; confirm which one the paper intends.
# scale = head_dim ** 0.5
scale = math.sqrt(hidden_dim)


def _split_heads(x):
    """[1, N, hidden_dim] -> [H, N, D] (batch size 1 assumed)."""
    return x.view(x.size(0), x.size(1), num_heads, head_dim).transpose(1, 2).squeeze(0)


reshaped_queries = torch.stack([_split_heads(q) for _, q in queries_by_layer], dim=0)  # [L, H, N, D]
reshaped_keys = torch.stack([_split_heads(k) for _, k in keys_by_layer], dim=0)  # [L, H, N, D]

print(f"Shape of reshaped queries: {reshaped_queries.shape}")
print(f"Shape of reshaped keys: {reshaped_keys.shape}")

In [None]:
# =================== 13. Similarity matrix S =====================
# S[i, j, h, l] = causal softmax over j of (q_i · k_j / scale) for head h of layer l.
N = reshaped_queries.size(2)  # seq_len
H = reshaped_queries.size(1)  # number of heads
L = reshaped_queries.size(0)  # number of layers

# S holds N*N*H*L float32 values (about 1.5 GB for N=612, H=L=32). Keep it on a second
# GPU when there is one, as before; otherwise use the device that holds Q/K.
S_device = torch.device("cuda:1") if torch.cuda.device_count() > 1 else reshaped_queries.device
S = torch.zeros((N, N, H, L), device=S_device)  # [N, N, H, L]

# The causal mask is the same for every layer and head, so build it once.
causal_mask = torch.tril(torch.ones((N, N), device=reshaped_queries.device)).bool()  # [N, N]

for layer_idx in range(L):
    # All heads of this layer at once, in float32 for a numerically stable softmax.
    q = reshaped_queries[layer_idx].float()  # [H, N, D]
    k = reshaped_keys[layer_idx].float()  # [H, N, D]
    scores = torch.matmul(q, k.transpose(-1, -2)) / scale  # [H, N, N]
    scores = scores.masked_fill(~causal_mask, float('-inf'))
    similarity_map = torch.softmax(scores, dim=-1)  # [H, N, N]
    S[:, :, :, layer_idx] = similarity_map.permute(1, 2, 0).to(S_device)  # [N, N, H]

print(f"Shape of similarity matrix S: {S.shape}")  # [N, N, H, L]

In [None]:
# =================== 14. Aggregation weights W =====================
# W[0, 0, h, l] = mean over query rows i of max_j S[i, j, h, l]. A head/layer whose rows
# are sharply peaked therefore gets a larger weight.
Max_S = S.max(dim=1).values  # [N, H, L]
W = Max_S.mean(dim=0)[None, None, :, :]  # [1, 1, H, L]

print(f"Shape of attention head weights W: {W.shape}")  # [1, 1, H, L]


In [None]:
# W is derived from S, so both already live on the same device.
# Weight every head by W, then average over heads -> one [N, N] map per layer.
S_prime = (S * W).mean(dim=2)  # [N, N, L]

# =================== 16. Attention rollout =====================
# Accumulate R <- (I + S'_l) @ R layer by layer, starting from R = I (no renormalisation).
N = S_prime.shape[0]  # seq_len
L = S_prime.shape[-1]  # num_layers
device = S_prime.device

I = torch.eye(N, device=device)
S_rollout = I.clone()

print("Starting Attention Rollout computation...")
for l, S_prime_l in enumerate(S_prime.unbind(dim=2)):
    print(f"Processing layer {l + 1}/{L}")
    A_l = I + S_prime_l  # [N, N]
    S_rollout = A_l @ S_rollout  # newest layer applied on the left
    print(f"S_rollout after layer {l + 1}: {S_rollout.shape}")
print("Attention Rollout computation completed.")

In [None]:
# Position regularisation (tentative): multiply column j of the rollout matrix by
# (j + 1) / N. A length-N vector broadcasts over the last dimension, i.e. per column.
multiplier = torch.arange(1, N + 1, device=device).float() / N  # [N]
S_rollout_reg = S_rollout * multiplier  # [N, N]

In [None]:
# Rollout attention from the "tumor" token(s) to the image tokens.
# The token positions come from this (hooked) run's `generated_tokens`, since it may stop
# at a different length than the first decoding run. With several positions, their rows
# are averaged.
rollout_tumor_indices = find_tumor_token_indices(generated_tokens, processor)
if rollout_tumor_indices is None:
    raise RuntimeError("'tumor' was not found in the generated answer; nothing to visualise.")
generated_token_index = list(range(
    rollout_tumor_indices[0] + num_text_tokens,
    rollout_tumor_indices[1] + num_text_tokens + 1,
))
# The final generated token was never fed back, so it has no row in S_rollout_reg.
if generated_token_index[-1] >= N:
    raise RuntimeError(
        f"'tumor' position {generated_token_index[-1]} is outside the {N} rollout rows."
    )
image_tokens_attention = S_rollout_reg[generated_token_index, image_token_start:image_token_end].mean(dim=0)

In [None]:
# Turn the per-image-token scores into a smoothed heatmap at the original image size.
grid_size = 24
expected_tokens = grid_size * grid_size

assert image_tokens_attention.size(0) == expected_tokens, \
    f"Expected {grid_size * grid_size} image tokens, but got {image_tokens_attention.size(0)}."

# 1-D scores -> 2-D patch grid.
attention_grid = image_tokens_attention.view(grid_size, grid_size)  # [24, 24]
print(f"Shape of attention_grid: {attention_grid.shape}")

# =================== 5. Interpolate to the original image size =====================
# F.interpolate expects [batch, channels, H, W].
attention_grid = attention_grid[None, None]  # [1, 1, 24, 24]
original_width, original_height = raw_image.size  # PIL reports (width, height)
attention_resized = F.interpolate(
    attention_grid,
    size=(original_height, original_width),  # torch wants (H, W)
    mode='bilinear',
    align_corners=False,
)[0, 0]  # [original_height, original_width]
print(f"Shape of attention_resized: {attention_resized.shape}")

# =================== 6. Gaussian smoothing =====================
attention_np = attention_resized.cpu().numpy()
kernel_size = (3, 3)  # tune as needed
sigma = 1.0
attention_smoothed = cv2.GaussianBlur(attention_np, kernel_size, sigma)
print(f"Shape of attention_smoothed: {attention_smoothed.shape}")


In [None]:
# =================== 7. Normalise and colour as a blue-white-red heatmap =====================
# Min-max normalise to [0, 1]; a constant map becomes all zeros instead of dividing by 0.
min_val, max_val = attention_smoothed.min(), attention_smoothed.max()
attention_normalized = (
    (attention_smoothed - min_val) / (max_val - min_val)
    if max_val - min_val != 0
    else np.zeros_like(attention_smoothed)
)

# 'bwr' runs blue -> white -> red; keep RGB and drop the alpha channel.
cmap = plt.get_cmap('bwr')
attention_colored = cmap(attention_normalized)[..., :3]

# [0, 1] floats -> uint8 RGB -> PIL image, ready to display or save.
attention_uint8 = (attention_colored * 255).astype(np.uint8)
attention_image = Image.fromarray(attention_uint8)