In [9]:
import os
from llava_code import call_engine_llava, format_prompt_llava, load_model_llava
import torch
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoProcessor
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

In [10]:
# Notebook-wide runtime configuration; later cells read `device`, `dtype`,
# `model` and `processor` as globals.
# NOTE(review): the recorded output of this cell is a CUDA out-of-memory error
# with ~39 GiB already held by this process — presumably a stale kernel still
# holding earlier tensors; restart the kernel before re-running. TODO confirm.
device = "cuda:0"
model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
dtype = torch.float32  # processor outputs are cast to this dtype in later cells
# load_model_llava is imported from the local llava_code module; its result is
# unpacked here as (model, processor).
model, processor = load_model_llava(model_id, device, dtype)

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 39.38 GiB of which 14.12 MiB is free. Including non-PyTorch memory, this process has 39.36 GiB memory in use. Of the allocated memory 38.17 GiB is allocated by PyTorch, and 687.60 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Toggle between the bundled dog picture and a dataset `sample`.
use_dog = True

if use_dog:
    # convert("RGB") yields an (H, W, 3) array for RGB and RGBA files alike and
    # also copes with palette / grayscale PNGs, where slicing [:, :, :3] on the
    # raw array would fail. The context manager closes the file handle.
    with Image.open("low_res_dog.png") as dog_file:
        raw_img = np.array(dog_file.convert("RGB"))
    text_parts = [
        "the dog is brown and running in the snow with trees behind",
        "what type of dog is in the picture",
    ]
else:
    # NOTE(review): `sample` is not defined in any visible cell; it must be
    # created (with 'image' and 'caption' entries) before taking this branch.
    raw_img = np.array(sample['image'])
    text_parts = [sample['caption'], "What color is the shape"]

# Single user turn: every text part in order, followed by the image placeholder.
conversation = [
    {
        "role": "user",
        "content": [{"type": "text", "text": text} for text in text_parts] + [{"type": "image"}],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Tokenise the prompt and preprocess the image, then move to the model device/dtype.
inputs = processor(images=raw_img, text=prompt, return_tensors='pt').to(device, dtype)

In [None]:
# Track gradients with respect to the preprocessed pixels.
inputs['pixel_values'].requires_grad_(True)
outputs = model(**inputs, output_attentions=True)
logits = outputs.logits[0]

# One attention tensor per layer, stacked on a new leading axis.
# Assumed layout: (num_layers, batch, num_heads, seq_len, seq_len) — TODO confirm.
attention_maps = torch.stack(outputs.attentions)

# Decode the first sequence one id at a time so that every entry is a single
# token string, then flag the image placeholder positions.
tokens = processor.batch_decode(inputs.input_ids[0])
image_mask = [token == "<image>" for token in tokens]

# Batch item 0 only: keep the rows whose query is an image token ...
img_attn = attention_maps[:, 0, :, image_mask, :]
# ... and then the columns whose key is an image token.
img_attn = img_attn[:, :, :, image_mask]

# Scalar objective: mean image-to-image attention over all layers and heads.
mean_attn = img_attn.mean()
# Only one backward pass runs through this graph, so it is not retained;
# keeping it alive would only hold on to GPU memory.
mean_attn.backward()
grad = inputs['pixel_values'].grad[0].detach().clone()
inputs['pixel_values'].grad.zero_()  # Do not let a re-run accumulate into the old gradient
"done"


In [None]:
# Preview the entry at index [0, 0] of the processed pixels in channels-last
# layout. The (x + 1) / 2 mapping assumes the processor normalised pixels to
# [-1, 1] — TODO confirm.
img = torch.permute(
    inputs['pixel_values'].cpu().clone().detach()[0, 0] + 1,
    (1, 2, 0),
) / 2
plt.imshow(img)

In [None]:
# L2 norm over the leading axis of the gradient, moved to channels-last for imshow.
g = torch.permute(torch.norm(grad, dim=[0]).cpu(), (1, 2, 0))
# Min-max normalise to [0, 1] in two in-place steps.
g -= g.min()
g /= g.max()
print(g.shape)

# The 0.33 exponent brightens small values so faint structure stays visible.
plt.imshow(g ** 0.33, cmap="viridis", interpolation="nearest")


In [None]:
import numpy as np
import torch

def compute_noisy_gradients(base_img, prompt, N=10, noise_std=10.0, seed=None):
    """
    Collect pixel gradients of the mean image-to-image attention over N noisy
    copies of an image.

    Each copy is base_img plus Gaussian noise, clipped to 0–255 and cast back to
    the input dtype. The copy is run through the notebook-level `processor` and
    `model` with the given prompt, and the mean attention between "<image>"
    token positions (all layers and heads) is backpropagated to the processed
    pixels.

    Args:
        base_img (np.ndarray): The original image; values are assumed to be on
            the 0–255 scale because the noisy copy is clipped to that range.
        prompt (str): The prompt string produced by processor.apply_chat_template.
        N (int): Number of noisy images to generate (must be >= 1).
        noise_std (float): Standard deviation of the Gaussian noise, in pixel units.
        seed (int, optional): Seed for a dedicated NumPy generator so that runs
            are repeatable. When None, the global np.random state is used.

    Returns:
        torch.Tensor: The N gradients stacked on a new leading axis. Each entry
        is `pixel_values.grad[0]`, so the trailing shape is whatever the
        processor produces for one image (not necessarily (C, H, W)).

    Raises:
        ValueError: If N is smaller than 1.
    """
    if N < 1:
        raise ValueError("N must be at least 1")

    rng = np.random if seed is None else np.random.default_rng(seed)
    base_img = np.asarray(base_img)
    grad_list = []

    for _ in tqdm(range(N)):
        # Perturb, then clip to the valid pixel range and restore the input dtype.
        noise = rng.normal(loc=0.0, scale=noise_std, size=base_img.shape)
        noisy_img = np.clip(base_img + noise, 0, 255).astype(base_img.dtype)

        # Same device/dtype as the rest of the notebook instead of a hard-coded GPU index.
        inputs = processor(images=noisy_img, text=prompt, return_tensors='pt').to(device, dtype)
        inputs['pixel_values'].requires_grad_(True)

        outputs = model(**inputs, output_attentions=True)

        # One attention tensor per layer, stacked on a new leading axis.
        attention_maps = torch.stack(outputs.attentions)

        # Decode ids one at a time so each entry is a single token string.
        tokens = processor.batch_decode(inputs.input_ids[0])
        image_mask = [token == "<image>" for token in tokens]

        # Batch item 0: image tokens as queries, then image tokens as keys.
        img_attn = attention_maps[:, 0, :, image_mask, :]
        img_attn = img_attn[:, :, :, image_mask]

        # A single backward pass per graph, so the graph is not retained.
        img_attn.mean().backward()

        # `inputs` is rebuilt on every iteration, so the gradient needs no reset.
        grad_list.append(inputs['pixel_values'].grad[0].detach().clone())

    return torch.stack(grad_list)


In [None]:
# Rebuild the chat-formatted prompt and gather gradients over 50 noisy copies.
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
gradients = compute_noisy_gradients(
    base_img=np.array(raw_img),
    prompt=prompt,
    N=50,
    noise_std=10.0,
)

In [None]:
# Average over the noisy samples, then move axis 1 to the end
# (presumably channels-last — TODO confirm the processor's layout).
mean = torch.permute(gradients.mean(dim=[0]).cpu(), (0, 2, 3, 1))

# L2 norm over the remaining leading axis, min-max normalised in place.
g = torch.norm(mean, dim=[0]).cpu()
g -= g.min()
g /= g.max()
print(g.shape)

# Left: normalised map as-is; right: its norm over the last axis.
fig, axs = plt.subplots(1, 2)
axs[0].imshow(g ** 0.33, cmap="viridis", interpolation="nearest")
axs[1].imshow(torch.norm(g, dim=[2]) ** 0.33, cmap="viridis", interpolation="nearest")


In [None]:
import torch

def compute_integrated_gradients_torch(base_img, prompt, baseline=None, steps=50, do_rescale=True):
    """
    Average the attention gradient along a straight path from a baseline image
    to base_img, one interpolation step per forward/backward pass.

    For each alpha in linspace(0.01, 1, steps) the interpolated image is run
    through the notebook-level `processor` and `model`, and the mean attention
    between "<image>" token positions (all layers and heads) is backpropagated
    to the processed pixels. The per-step gradients are averaged.

    Note that the result is the path-averaged gradient only: the usual final
    multiplication by (base_img - baseline) is NOT applied, because the
    gradient lives in the processor's pixel space rather than in base_img's.

    Args:
        base_img (torch.Tensor): Target image, indexed as (C, H, W).
        prompt (str): The text prompt produced by processor.apply_chat_template.
        baseline (torch.Tensor, optional): Baseline image with the same shape as
            base_img; if None, an all-zero (black) image is used.
        steps (int): Number of interpolation steps (must be >= 1).
        do_rescale (bool): True when base_img/baseline hold values on the 0–255
            scale; each interpolated image is then rounded to uint8 before the
            PIL conversion. False when they are floats in [0, 1], which
            to_pil_image converts on its own.

    Returns:
        torch.Tensor: Mean of `pixel_values.grad[0]` over the steps; its shape
        is that of one processed image, not that of base_img.

    Raises:
        ValueError: If steps is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")

    # Black baseline on the same device/dtype as the target image.
    if baseline is None:
        baseline = torch.zeros_like(base_img)

    integrated_grad = None
    for alpha in tqdm(torch.linspace(0.01, 1, steps, device=base_img.device)):
        scaled_img = baseline + alpha * (base_img - baseline)

        if do_rescale:
            # to_pil_image multiplies floating-point tensors by 255 before
            # casting to bytes, so a float image already on the 0–255 scale
            # would overflow. Hand it a uint8 tensor instead.
            scaled_img = scaled_img.round().clamp(0, 255).to(torch.uint8)

        # Same device/dtype as the rest of the notebook instead of a hard-coded GPU index.
        inputs = processor(
            images=to_pil_image(scaled_img.cpu()),
            text=prompt,
            return_tensors='pt',
        ).to(device, dtype)
        inputs['pixel_values'].requires_grad_(True)

        outputs = model(**inputs, output_attentions=True)
        # One attention tensor per layer, stacked on a new leading axis.
        attention_maps = torch.stack(outputs.attentions)

        # Decode ids one at a time so each entry is a single token string.
        tokens = processor.batch_decode(inputs.input_ids[0])
        image_mask = [token == "<image>" for token in tokens]

        # Batch item 0: image tokens as queries, then image tokens as keys.
        img_attn = attention_maps[:, 0, :, image_mask, :]
        img_attn = img_attn[:, :, :, image_mask]

        # A single backward pass per graph, so the graph is not retained.
        img_attn.mean().backward()

        grad = inputs['pixel_values'].grad[0].detach().clone()
        integrated_grad = grad if integrated_grad is None else integrated_grad + grad

    # Average the gradients over all interpolation steps.
    integrated_grad /= steps
    return integrated_grad


In [None]:
# Channels-first tensor copy of the image on the model device, then the
# chat-formatted prompt, then the integrated-gradients run.
tImg = torch.tensor(raw_img).to(device, dtype).permute(2, 0, 1)
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
ig_tensor = compute_integrated_gradients_torch(
    base_img=tImg,
    prompt=prompt,
    baseline=None,
    steps=50,
)


In [None]:
# Average the integrated gradients over their leading axis.
g = ig_tensor.mean(dim=[0]).cpu()

# L2 norm over the next axis. This used to read `mean`, a leftover from the
# noisy-gradient plotting cell, so the figure showed that older result (or
# failed on a fresh kernel) instead of the integrated gradients.
g = torch.norm(g, dim=[0])
# Min-max normalise to [0, 1] for display.
g = (g - g.min()) / (g.max() - g.min())

plt.imshow(g ** 0.33)

In [None]:
import torch
from tqdm import tqdm

def compute_integrated_gradients_torch(base_img, prompt, baseline=None, steps=50, do_rescale=True):
    """
    Batched attention-gradient attribution along a baseline-to-image path.

    All interpolated images (alpha in linspace(0.01, 1, steps)) are pushed
    through the notebook-level `processor` and `model` in ONE forward pass with
    output_attentions=True. The attention between "<image>" token positions is
    then backpropagated to the processed pixels, once as an overall mean and
    once per (layer, head).

    NOTE: this definition reuses the name of the earlier looped variant and
    shadows it; unlike that variant it returns a 3-tuple. Memory use grows with
    `steps` because every step's attention maps and graph are alive at once.

    Args:
        base_img (torch.Tensor): Target image tensor indexed as (C, H, W).
        prompt (str | list): The chat-formatted prompt string to pair with every
            image. A list of conversation dicts is run through
            processor.apply_chat_template; any other list is joined with spaces.
        baseline (torch.Tensor, optional): Baseline image tensor; if None, uses a black image.
        steps (int): Number of interpolation steps (must be >= 1).
        do_rescale (bool): True when base_img/baseline hold values on the 0–255
            scale; the interpolated images are then rounded to uint8 before the
            PIL conversion. False when they are floats in [0, 1].

    Returns:
        integrated_grad (torch.Tensor): Gradient of the overall mean attention,
            averaged over the steps; shaped like one processed image
            (pixel_values.shape[1:]), not like base_img.
        pixel_gradients_per_layer_head (torch.Tensor): The same quantity per
            layer and head, shape (num_layers, num_heads, *pixel_values.shape[1:]).
        grad_magnitude_per_layer_head (torch.Tensor): L2 norm of the above over
            axis 2.

    Raises:
        ValueError: If steps is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")

    # Normalise the prompt to a single string.
    if isinstance(prompt, list):
        if prompt and all(isinstance(turn, dict) for turn in prompt):
            # A raw conversation: format it properly rather than stringifying
            # the dicts, which would lose the image placeholder.
            prompt = processor.apply_chat_template(prompt, add_generation_prompt=True)
        else:
            prompt = " ".join(map(str, prompt))
    elif not isinstance(prompt, str):
        prompt = str(prompt)

    # Black baseline on the same device/dtype as the target image.
    if baseline is None:
        baseline = torch.zeros_like(base_img)

    # Build every interpolated image at once: (steps, C, H, W).
    alphas = torch.linspace(0.01, 1.0, steps, device=base_img.device).view(-1, 1, 1, 1)
    start = baseline.unsqueeze(0)
    scaled_imgs = start + alphas * (base_img.unsqueeze(0) - start)

    if do_rescale:
        # to_pil_image multiplies floating-point tensors by 255 before casting
        # to bytes, so floats already on the 0–255 scale would overflow.
        scaled_imgs = scaled_imgs.round().clamp(0, 255).to(torch.uint8)

    pil_images = [to_pil_image(img) for img in scaled_imgs.cpu()]

    # One prompt per image.
    inputs = processor(
        images=pil_images,
        text=[prompt] * len(pil_images),
        return_tensors='pt',
    )

    # Move tensors to the model device; only pixel_values is cast to float32,
    # every other tensor (e.g. input_ids) keeps its dtype.
    for key, value in list(inputs.items()):
        if isinstance(value, torch.Tensor):
            if key == 'pixel_values':
                inputs[key] = value.to(model.device, torch.float32)
            else:
                inputs[key] = value.to(model.device)

    inputs['pixel_values'].requires_grad_(True)

    outputs = model(**inputs, output_attentions=True)
    # Assumed layout after stacking: (num_layers, steps, num_heads, seq_len, seq_len) — TODO confirm.
    attention_maps = torch.stack(outputs.attentions)

    # Decode the first sequence one id at a time so each entry is a single token
    # string. Decoding the row as a whole yields one string, and iterating over
    # that would compare individual characters with "<image>".
    # The mask from row 0 is reused for every row (same prompt, same image size).
    tokens = processor.batch_decode(inputs['input_ids'][0])
    image_mask = [token == "<image>" for token in tokens]

    # Image tokens as queries, then image tokens as keys.
    img_attn = attention_maps[:, :, :, image_mask, :]
    img_attn = img_attn[:, :, :, :, image_mask]

    # --- Overall gradient ---
    # The mean already averages over the steps axis, so each row of .grad
    # carries a 1/steps factor; summing the rows gives the step-averaged
    # gradient. (Taking .mean(dim=0) / steps on top would divide by steps two
    # more times.)
    overall_mean_attn = img_attn.mean()
    overall_mean_attn.backward(retain_graph=True)  # graph is reused below
    integrated_grad = inputs['pixel_values'].grad.sum(dim=0).detach().clone()
    inputs['pixel_values'].grad.zero_()

    # --- Per-layer, per-head gradients ---
    # Mean over steps and both token axes -> (num_layers, num_heads).
    mean_attention_per_layer_head = img_attn.mean(dim=[1, 3, 4])

    pixel_gradients_per_layer_head_list = []
    num_layers, num_heads = mean_attention_per_layer_head.shape
    for layer in tqdm(range(num_layers), desc="Computing per-layer/head gradients"):
        for head in range(num_heads):
            # Every (layer, head) scalar backpropagates through the same graph,
            # hence retain_graph=True here.
            mean_attention_per_layer_head[layer, head].backward(retain_graph=True)
            pixel_gradients_per_layer_head_list.append(
                inputs['pixel_values'].grad.sum(dim=0).detach().clone()
            )
            inputs['pixel_values'].grad.zero_()

    pixel_gradients_per_layer_head = torch.stack(pixel_gradients_per_layer_head_list).view(
        num_layers, num_heads, *inputs['pixel_values'].shape[1:]
    )

    # NOTE(review): intended as a norm over the channel axis; if pixel_values
    # carries an extra per-image patch axis, axis 2 is that axis — confirm.
    grad_magnitude_per_layer_head = torch.norm(pixel_gradients_per_layer_head, dim=2)

    return integrated_grad, pixel_gradients_per_layer_head, grad_magnitude_per_layer_head


In [None]:
# Requires raw_img, conversation, device and dtype from the earlier cells.
tImg = torch.tensor(raw_img).to(device, dtype).permute((2, 0, 1))
# The function pairs a prompt *string* with every interpolated image, so format
# the conversation with the chat template rather than passing the raw list.
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
ig_tensor, per_head_gradients, grad_magnitude = compute_integrated_gradients_torch(
    tImg,
    prompt,
    baseline=None,
    steps=50,
    do_rescale=True
)