In [None]:
import torch
from transformers import AutoProcessor
from PIL import Image
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from huggingface_hub import snapshot_download
import sys, os
# Download the repo holding the custom LLaVA-with-registers modeling code and make it importable.
repo_path = snapshot_download("amildravid4292/llava-llama-3-8b-test-time-registers")
# Only prepend once: re-running this cell would otherwise stack duplicate sys.path entries.
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)
from modeling_custom_llava import LlavaRegistersForConditionalGeneration

# All tensors/models below are moved here; the model is loaded in float16, so a CUDA GPU is assumed.
device = "cuda:0"

In [None]:
# language model attention capture
class AttentionCaptureModel(LlavaRegistersForConditionalGeneration):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.captured_attentions = None

    def forward(self, *args, **kwargs):
        # Capture the attention weights
        output = super().forward(*args, **kwargs)
        self.captured_attentions = output.attentions
        return output


In [None]:
BASE_CHECKPOINT = "xtuner/llava-llama-3-8b-v1_1-transformers"

# Load the base checkpoint into the attention-capturing subclass of the register model.
model = AttentionCaptureModel.from_pretrained(
    BASE_CHECKPOINT,
    torch_dtype=torch.float16,
    output_attentions=True,
    # Attention weights are only materialised by the eager implementation; request it
    # explicitly instead of relying on a version-dependent fallback from SDPA.
    attn_implementation="eager",
).to(device)
# Use the original (unmodified) processor of the base checkpoint.
processor = AutoProcessor.from_pretrained(BASE_CHECKPOINT)

In [None]:
# Storage for the vision tower's last-encoder-layer output, filled by the forward hook
# below; the captured value lives under key -1.
patches = dict()
def hook_output_patch(module, input, output):
    """Forward hook: save the hooked layer's ``output`` in the global ``patches`` dict.

    The value is stored under key ``-1``, replacing any previous capture. The hook
    returns None, so it does not replace the module's output.
    """
    patches.update({-1: output})

# Remove a hook left over from an earlier run of this cell so hooks don't stack up
# (RemovableHandle.remove() is a no-op if the hook was already removed).
if "hook_handle" in globals():
    hook_handle.remove()
hook_handle = model.vision_tower.vision_model.encoder.layers[-1].register_forward_hook(hook_output_patch)

In [None]:
prompt = ("<|start_header_id|>user<|end_header_id|>\n\n<image>\nHow many tennis balls are in the dog's mouth? Use one word.<|eot_id|>"
          "<|start_header_id|>assistant<|end_header_id|>\n\n")

# Load the image, converting to RGB so the processor always receives 3 channels;
# the context manager closes the underlying file handle.
image_path = "images/dog_img.webp"
with Image.open(image_path) as img:
    raw_image = img.convert("RGB")

# Pass text/images by keyword: the positional order of the Llava processor's
# __call__ differs between transformers versions.
inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(device, torch.float16)

# Baseline: original model without a test-time register.
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=1, do_sample=False, extra_tokens=0, neuron_dict=None)

tokenizer = processor.tokenizer
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print("Decoded output:", decoded_output)


In [None]:
# Per-token L2 norm of the hooked last-layer vision output; the hook is no longer needed.
patch_norms = patches[-1][0].float().squeeze(0).norm(dim=-1).detach().cpu().numpy()
hook_handle.remove()

fig, ax = plt.subplots()
ax.axis('off')
fig.suptitle("Output Patch Norm", fontsize=25)

# Skip the first token (presumably CLS); the remaining 576 tokens form a 24x24 patch grid.
im = ax.imshow(patch_norms[1:].reshape(24, 24))
fig.colorbar(im)

fig.tight_layout()
plt.show()

In [None]:
# Attention from the last prompt position to the image tokens, averaged over layers and heads.
# With max_new_tokens=1 the captured forward pass is the prompt pass, so row -1 is the
# position whose logits produced the answer. Only that query row is gathered before the
# float32 upcast, instead of upcasting every full attention map (same result, far less memory).
last_row_atts = torch.cat([layer_att[:, :, -1, :] for layer_att in model.captured_attentions]).float()
# NOTE(review): 5:581 assumes the 576 image tokens occupy sequence positions 5..580 — confirm.
answer_to_image = last_row_atts.mean(0).mean(0)[5:581].cpu().reshape(24, 24)

im = plt.imshow(answer_to_image)
plt.axis("off")
plt.suptitle("Mean Attention Map for Answer Token ", fontsize = 20)
# Add the colorbar before tight_layout so the layout accounts for it.
plt.colorbar(im)
plt.tight_layout()
plt.show()

In [None]:
# Answer-token attention over the 24x24 image-token grid, averaged over layers and heads
# (only the last query row is upcast to float32 to avoid copying the full maps).
atts = torch.cat([layer_att[:, :, -1, :] for layer_att in model.captured_attentions]).float()
# NOTE(review): 5:581 assumes the 576 image tokens occupy sequence positions 5..580 — confirm.
atts = atts.mean(0).mean(0)[5:581].cpu().reshape(24, 24)

# pixel_values are normalised with the image processor's per-channel mean/std; undo that
# so imshow receives values in [0, 1] instead of clipping them.
pixel_mean = torch.tensor(processor.image_processor.image_mean).view(3, 1, 1)
pixel_std = torch.tensor(processor.image_processor.image_std).view(3, 1, 1)
image = (inputs["pixel_values"][0].float().cpu() * pixel_std + pixel_mean).clamp(0, 1).permute(1, 2, 0)

# Upsample the token grid to the image resolution (336/24 for this model).
scale_factor = image.shape[0] / atts.shape[0]
heatmap_upsampled = zoom(atts.numpy(), scale_factor, order=1)  # bilinear interpolation

# Create the overlay
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)  # preprocessed RGB input image
ax.imshow(heatmap_upsampled, alpha=0.5, cmap='jet')  # Overlay heatmap with transparency
ax.axis('off')
plt.show()

In [None]:
# Reset the capture dict and (re-)register the vision hook; any previous handle is removed
# first so hooks don't stack if this cell is re-run.
patches = {}
if "hook_handle" in globals():
    hook_handle.remove()
hook_handle = model.vision_tower.vision_model.encoder.layers[-1].register_forward_hook(hook_output_patch)

prompt = ("<|start_header_id|>user<|end_header_id|>\n\n<image>\nHow many tennis balls are in the dog's mouth? Use one word.<|eot_id|>"
          "<|start_header_id|>assistant<|end_header_id|>\n\n")

# Load the image as RGB; the context manager closes the underlying file handle.
image_path = "images/dog_img.webp"
with Image.open(image_path) as img:
    raw_image = img.convert("RGB")

# Pass text/images by keyword: the positional order of the Llava processor's
# __call__ differs between transformers versions.
inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(device, torch.float16)

# default uses test-time register
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=1, do_sample=False)

tokenizer = processor.tokenizer
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print("Decoded output:", decoded_output)


In [None]:
# Per-token L2 norm of the hooked last-layer vision output; the hook is no longer needed.
patch_norms = patches[-1][0].float().squeeze(0).norm(dim=-1).detach().cpu().numpy()
hook_handle.remove()

fig, ax = plt.subplots()
ax.axis('off')
fig.suptitle("Output Patch Norm", fontsize=25)

# Skip the first token (presumably CLS) and the last one (presumably the test-time
# register); the remaining 576 tokens form a 24x24 patch grid.
im = ax.imshow(patch_norms[1:-1].reshape(24, 24))
fig.colorbar(im)

fig.tight_layout()
plt.show()

In [None]:
# Attention from the last prompt position to the image tokens, averaged over layers and heads.
# With max_new_tokens=1 the captured forward pass is the prompt pass, so row -1 is the
# position whose logits produced the answer. Only that query row is gathered before the
# float32 upcast, instead of upcasting every full attention map (same result, far less memory).
last_row_atts = torch.cat([layer_att[:, :, -1, :] for layer_att in model.captured_attentions]).float()
# NOTE(review): 5:581 assumes the 576 image tokens occupy sequence positions 5..580 — confirm.
answer_to_image = last_row_atts.mean(0).mean(0)[5:581].cpu().reshape(24, 24)

im = plt.imshow(answer_to_image)
plt.axis("off")
plt.suptitle("Mean Attention Map for Answer Token ", fontsize = 20)
# Add the colorbar before tight_layout so the layout accounts for it.
plt.colorbar(im)
plt.tight_layout()
plt.show()

In [None]:
# Answer-token attention over the 24x24 image-token grid, averaged over layers and heads
# (only the last query row is upcast to float32 to avoid copying the full maps).
atts = torch.cat([layer_att[:, :, -1, :] for layer_att in model.captured_attentions]).float()
# NOTE(review): 5:581 assumes the 576 image tokens occupy sequence positions 5..580 — confirm.
atts = atts.mean(0).mean(0)[5:581].cpu().reshape(24, 24)

# pixel_values are normalised with the image processor's per-channel mean/std; undo that
# so imshow receives values in [0, 1] instead of clipping them.
pixel_mean = torch.tensor(processor.image_processor.image_mean).view(3, 1, 1)
pixel_std = torch.tensor(processor.image_processor.image_std).view(3, 1, 1)
image = (inputs["pixel_values"][0].float().cpu() * pixel_std + pixel_mean).clamp(0, 1).permute(1, 2, 0)

# Upsample the token grid to the image resolution (336/24 for this model).
scale_factor = image.shape[0] / atts.shape[0]
heatmap_upsampled = zoom(atts.numpy(), scale_factor, order=1)  # bilinear interpolation

# Create the overlay
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)  # preprocessed RGB input image
ax.imshow(heatmap_upsampled, alpha=0.5, cmap='jet')  # Overlay heatmap with transparency
ax.axis('off')
plt.show()