In [4]:
import torch
import os
from transformers import FlavaProcessor, FlavaForPreTraining
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torch.amp import autocast

# Set random seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set environment variables for memory management
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()

# Load FLAVA model and processor
model_name = "facebook/flava-full"
processor = FlavaProcessor.from_pretrained(model_name)
model = FlavaForPreTraining.from_pretrained(model_name).to("cuda")
model.eval()

FlavaForPreTraining(
  (flava): FlavaModel(
    (text_model): FlavaTextModel(
      (embeddings): FlavaTextEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): FlavaEncoder(
        (layer): ModuleList(
          (0-11): 12 x FlavaLayer(
            (attention): FlavaAttention(
              (attention): FlavaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): FlavaSelfOutput(
                (dense): Linear(in_features=768, out_f

In [5]:
def compute_gradcam_flava(image_path, texts, target_label, model, processor):
    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    images = [image] * len(texts)  # Replicate image to match number of texts
    inputs = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding='max_length',
        max_length=77,
        return_codebook_pixels=True,
        return_image_mask=True,
        return_attention_mask=True
    ).to("cuda")

    # Get the target label index
    try:
        target_label_idx = texts.index(target_label)
    except ValueError:
        print(f"Target label '{target_label}' not found in texts")
        return

    # Get all layers in the FLAVA image encoder
    num_layers = len(model.flava.image_model.encoder.layer)
    print(f"Total layers in FLAVA Image Encoder: {num_layers}")

    # Dictionary to store feature maps for each layer
    feature_maps_dict = {}
    hooks = []

    # Register hooks for all layers
    for layer_idx in range(num_layers):
        layer = model.flava.image_model.encoder.layer[layer_idx]
        def hook_fn(module, input, output, idx=layer_idx):
            if isinstance(output, tuple):
                feature_maps = output[0]
            else:
                feature_maps = output
            feature_maps_dict[idx] = feature_maps
            feature_maps.retain_grad()  # Ensure gradients are retained
            def grad_hook(grad):
                print(f"Grad hook called for layer {idx} with grad shape: {grad.shape}")
                feature_maps_dict[idx].grad = grad.clone()
            feature_maps.register_hook(grad_hook)
            print(f"Layer {idx} - Hooked shape: {feature_maps.shape}, requires_grad: {feature_maps.requires_grad}")
        hook = layer.register_forward_hook(hook_fn)
        hooks.append(hook)

    # Forward and backward pass
    with torch.enable_grad(), autocast('cuda', dtype=torch.float32):
        try:
            outputs = model(**inputs)
            logits_per_image = outputs.contrastive_logits_per_image
            print(f"Contrastive logits shape: {logits_per_image.shape}")
            target_logit = logits_per_image[0, target_label_idx].float()
            print(f"Target logit: {target_logit.item()}")
            model.zero_grad()
            target_logit.backward()
            print(f"VRAM after backward: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")
        except Exception as e:
            print(f"Forward/backward error: {e}")
            for hook in hooks:
                hook.remove()
            return

    # Compute and save heatmap for each layer
    for layer_idx in range(num_layers):
        feature_maps = feature_maps_dict.get(layer_idx)
        if feature_maps is None:
            print(f"Layer {layer_idx} - Feature maps not captured")
            continue
        if not hasattr(feature_maps, 'grad') or feature_maps.grad is None:
            print(f"Layer {layer_idx} - Gradients not captured")
            continue

        gradients = feature_maps.grad
        print(f"Layer {layer_idx} - Gradients shape: {gradients.shape}")

        # Use feature maps and gradients for the first image
        feature_maps_first = feature_maps[0, :, :]  # (197, 768)
        gradients_first = gradients[0, :, :]  # (197, 768)

        # Ignore CLS token
        feature_maps_patches = feature_maps_first[1:, :]  # (196, 768)
        gradients_patches = gradients_first[1:, :]  # (196, 768)

        # Compute weights
        weights = gradients_patches.mean(dim=0, keepdim=True)  # (1, 768)

        # Compute heatmap
        heatmap = torch.relu((feature_maps_patches * weights).sum(dim=1))  # (196,)
        heatmap = heatmap.view(14, 14)  # Reshape to 14x14 grid
        heatmap = heatmap / (heatmap.max() + 1e-6)

        # Upsample to image size
        upsample = T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC)
        heatmap = upsample(heatmap.to(torch.float32).unsqueeze(0)).squeeze(0).detach().cpu().numpy()

        # Visualize
        heatmap = np.uint8(255 * heatmap)
        heatmap_colored = plt.get_cmap('jet')(heatmap / 255.0)[:, :, :3]
        image_np = np.array(image.resize((224, 224))) / 255.0
        superimposed_img = heatmap_colored * 0.4 + image_np * 0.6

        plt.figure(figsize=(8, 6))
        plt.imshow(superimposed_img)
        plt.axis('off')
        plt.title(f"Grad-CAM for FLAVA Layer {layer_idx}")
        plt.colorbar(plt.cm.ScalarMappable(cmap='jet'), label='Attention Intensity')
        #os.makedirs('/home/bboulbarss/gradcam/flava/flava_gradcam_', exist_ok=True)
        plt.savefig(f"/home/bboulbarss/gradcam_results/image1/flava-rel/thesis_flava_gradcam_rel_1_predicted_label/gradcam_flava_layer_{layer_idx}.png")
        plt.close()

        torch.cuda.empty_cache()
        print(f"Layer {layer_idx} - VRAM after heatmap: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")

    # Clean up
    for hook in hooks:
        hook.remove()
    torch.cuda.empty_cache()
    print(f"VRAM after cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")


##############################################################################################################

########################################
############# FINAL IMAGES #############
########################################

################################################################################################
## Image 1 in final_gradcam
## Relational, correct label
texts = [
    "A photo of a cylinder left of a cone",
    "A photo of a cylinder right of a cone",
    "A photo of a cone left of a cylinder",
    "A photo of a cube right of a cylinder",
    "A photo of a sphere right of a cone",]
target_label = "A photo of a cylinder left of a cone"
image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_cone/CLEVR_rel_000020.png"

## Relational, predicted label
texts = [
    "A photo of a cylinder left of a cone",
    "A photo of a cylinder right of a cone",
    "A photo of a cone left of a cylinder",
    "A photo of a cube right of a cylinder",
    "A photo of a sphere right of a cone",]
target_label = "A photo of a cube right of a cylinder"
image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_cone/CLEVR_rel_000020.png"


# Two object
#texts = [
#    "A photo of a purple cylinder",
#    "A photo of a green cylinder",
#    "A photo of a purple cone",
#    "A photo of a red sphere",
#    "A photo of a blue cube"
#]
#target_label = "A photo of a purple cylinder"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_cone/CLEVR_rel_000020.png"

################################################################################################
# Image 2 in final_gradcam
# Relational
#texts = [
#    "A photo of a cylinder left of a sphere",
#    "A photo of a cylinder right of a sphere",
#    "A photo of a sphere left of a cylinder",
#    "A photo of a cube right a cone",
#    "A photo of a sphere right of a cone",]
#target_label = "A photo of a cylinder left of a sphere"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_sphere/CLEVR_rel_000031.png"
#
# Two object
#texts = [
#    "A photo of a green sphere",
#    "A photo of a blue sphere",
#    "A photo of a green cylinder",
#    "A photo of a red cone",
#    "A photo of a purple cube",
#]
#target_label = "A photo of a green sphere"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_sphere/CLEVR_rel_000031.png"

################################################################################################

compute_gradcam_flava(image_path, texts, target_label, model, processor)

`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if you are doing inference on unmasked text...


Total layers in FLAVA Image Encoder: 12
Layer 0 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 1 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 2 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 3 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 4 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 5 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 6 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 7 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 8 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 9 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 10 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 11 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 0 - Hooked shape: torch.Size([5, 197, 768]), requires_grad: True
Layer 1 - Hooked shape: torch.Size(



Contrastive logits shape: torch.Size([5, 5])
Target logit: 21.54041862487793
Grad hook called for layer 11 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 10 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 9 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 8 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 7 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 6 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 5 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 4 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 3 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 2 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 1 with grad shape: torch.Size([5, 197, 768])
Grad hook called for layer 0 with grad shape: torch.Size([5, 197, 768])
VRAM after backward: 5.03 GiB
Layer 0 - Gradients shape: 

  plt.colorbar(plt.cm.ScalarMappable(cmap='jet'), label='Attention Intensity')


Layer 0 - VRAM after heatmap: 5.03 GiB
Layer 1 - Gradients shape: torch.Size([5, 197, 768])
Layer 1 - VRAM after heatmap: 5.03 GiB
Layer 2 - Gradients shape: torch.Size([5, 197, 768])
Layer 2 - VRAM after heatmap: 5.03 GiB
Layer 3 - Gradients shape: torch.Size([5, 197, 768])
Layer 3 - VRAM after heatmap: 5.03 GiB
Layer 4 - Gradients shape: torch.Size([5, 197, 768])
Layer 4 - VRAM after heatmap: 5.03 GiB
Layer 5 - Gradients shape: torch.Size([5, 197, 768])
Layer 5 - VRAM after heatmap: 5.03 GiB
Layer 6 - Gradients shape: torch.Size([5, 197, 768])
Layer 6 - VRAM after heatmap: 5.03 GiB
Layer 7 - Gradients shape: torch.Size([5, 197, 768])
Layer 7 - VRAM after heatmap: 5.03 GiB
Layer 8 - Gradients shape: torch.Size([5, 197, 768])
Layer 8 - VRAM after heatmap: 5.03 GiB
Layer 9 - Gradients shape: torch.Size([5, 197, 768])
Layer 9 - VRAM after heatmap: 5.03 GiB
Layer 10 - Gradients shape: torch.Size([5, 197, 768])
Layer 10 - VRAM after heatmap: 5.03 GiB
Layer 11 - Gradients shape: torch.Size(

In [6]:

##### IMAGE 1 #####
## RELATIONAL ##
image_path = "/home/bboulbarss/large_dataset/relational/train/cube_left_sphere/CLEVR_rel_000065.png"
texts = [
    "A photo of a cube left of a sphere",
    "A photo of a cube right of a sphere",
    "A photo of a sphere left of a cube",
    "A photo of a cube left of a cone",
    "A photo of a cylinder right of a cube"
]
target_label = "A photo of a cube left of a sphere"

## TWO OBJECT ##
texts = [
    "A photo of a gray cube",
    "A photo of a yellow cube",
    "A photo of a gray sphere",
    "A photo of a purple cube",
    "A photo of a red cone"
]
target_label = "A photo of a grey cube"


##### IMAGE 2 #####
## RELATIONAL ##
image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_cone/CLEVR_rel_000085.png"
texts = [
    "A photo of a cylinder left of a cone",
    "A photo of a cone left of a cylinder",
    "A photo of a cylinder right of a cone",
    "A photo of a cube left of a cone",
    "A photo of a cylinder right of a cube"
]
target_label = "A photo of a cylinder left of a cone"

## TWO OBJECT ##
texts = [
    "A photo of a blue cone",
    "A photo of a purple cone",
    "A photo of a blue cylinder",
    "A photo of a gray cube",
    "A photo of a red sphere"
]
target_label = "A photo of a blue cone"


##### IMAGE 3 #####
## RELATIONAL ##
image_path = "/home/bboulbarss/large_dataset/relational/train/cube_right_cylinder/CLEVR_rel_000057.png"
texts = [
    "A photo of a cube right of a cylinder",
    "A photo of a cube left of a cylinder",
    "A photo of a cone right of a cylinder",
    "A photo of a cube left of a cone",
    "A photo of a cylinder right of a cone"
]
target_label = "A photo of a cube right of a cylinder"

## TWO OBJECT ##
texts = [
    "A photo of a blue cylinder",
    "A photo of a cyan cylinder",
    "A photo of a blue cube",
    "A photo of a brown sphere",
    "A photo of a yellow cube"
]
target_label = "A photo of a blue cylinder"


##### IMAGE 4 #####
## RELATIONAL ##
image_path = "/home/bboulbarss/large_dataset/relational/train/sphere_left_cone/CLEVR_rel_000007.png"
texts = [
    "A photo of a sphere left of a cone",
    "A photo of a sphere right of a cone",
    "A photo of a cone left of a sphere",
    "A photo of a cube left of a cone",
    "A photo of a cylinder right of a cone"
]
target_label = "A photo of a sphere left of a cone"

## TWO OBJECT ##
texts = [
    "A photo of a purple cone",
    "A photo of a blue cone",
    "A photo of a purple sphere",
    "A photo of a green cylinder",
    "A photo of a yellow cube"
]
target_label = "A photo of a purple cone"