In [2]:
import torch
from transformers import AutoImageProcessor, SwinModel
from PIL import Image as pimg
import sys
import torch.nn as nn
import torch
import numpy as np
import matplotlib.pyplot as plt
from captum.attr import visualization as viz
from captum.attr import Occlusion
import tempfile
import os
import uuid
import torch.nn.functional as F
from captum.attr import Occlusion

def preprocess_image(img_path):
    """Load an image from disk and run it through the global image processor.

    Args:
        img_path: Anything accepted by ``PIL.Image.open`` (path or file object).

    Returns:
        A tuple ``(image, inputs)``: ``image`` is the image converted to RGB,
        and ``inputs`` is the output of the module-level ``processor``
        (called with ``return_tensors="pt"``) moved to the module-level
        ``device``.
    """
    # Open through a context manager so the underlying file handle is
    # released right away; convert() returns a fully loaded copy, so the
    # RGB image remains valid after the file is closed.
    with pimg.open(img_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    return image, inputs

def get_model_features(inputs):
    """Return the ``pooler_output`` of the global ``swin`` model for ``inputs``.

    ``inputs`` is unpacked as keyword arguments into the model call, and the
    forward pass runs with gradient tracking disabled.
    """
    with torch.no_grad():
        model_output = swin(**inputs)
    return model_output.pooler_output

def get_cmap_by_label(label):
    """Pick a matplotlib colour-map name for a prediction label.

    Returns ``"Greens"`` for exactly ``"Real"`` and ``"Reds"`` for any other
    value.
    """
    if label == "Real":
        return "Greens"
    return "Reds"

def generate_occlusion_map3(
    img_path,
    label,
    *,
    strides=(3, 16, 16),
    sliding_window_shapes=(3, 30, 30),
    target=None,
):
    """Compute an occlusion attribution map and save it as a PNG figure.

    The image is preprocessed with ``preprocess_image``. Captum's
    ``Occlusion`` then scores occluded copies of it through
    ``classifier(swin(pixel_values=...).pooler_output)``. The attribution is
    drawn next to the original image with a colour map chosen by
    ``get_cmap_by_label(label)``.

    Args:
        img_path: Path of the image to explain.
        label: Prediction label; ``"Real"`` selects a green colour map,
            anything else a red one.
        strides: Occlusion stride per input dimension (batch excluded),
            forwarded to ``Occlusion.attribute``. The default keeps the
            previous hard-coded value.
        sliding_window_shapes: Occlusion window shape, same layout as
            ``strides``. The default keeps the previous hard-coded value.
        target: Output index forwarded to ``Occlusion.attribute``. ``None``
            (the previous behaviour) attributes the classifier output as-is.
            NOTE(review): Captum needs an explicit target if ``classifier``
            returns more than one value per example — confirm its output
            shape.

    Returns:
        Path of the saved PNG inside the system temporary directory.
    """
    image, inputs = preprocess_image(img_path)
    # preprocess_image already moved the processor output to ``device``,
    # and zeros_like keeps the same device, so no extra transfers are needed.
    img_tensor = inputs["pixel_values"]
    # All-zero baseline: occluded patches are replaced by 0 in the
    # processor's pixel-value space.
    baseline = torch.zeros_like(img_tensor)

    def model_wrapper(input_tensor):
        # Score each (occluded) input with the same Swin -> classifier
        # pipeline used for prediction.
        feats = swin(pixel_values=input_tensor).pooler_output
        return classifier(feats)

    attributions_occ = Occlusion(model_wrapper).attribute(
        img_tensor,
        strides=strides,
        sliding_window_shapes=sliding_window_shapes,
        baselines=baseline,
        target=target,
    )

    # Drop the batch dimension and move channels last (CHW -> HWC) for
    # plotting.
    attr_hwc = attributions_occ[0].detach().cpu().permute(1, 2, 0).numpy()

    # NOTE(review): the left panel shows ``image`` at its file resolution,
    # while the heat map is presumably at the processor's input resolution —
    # confirm the two panels are meant to differ in size.
    fig, _ = viz.visualize_image_attr_multiple(
        attr_hwc,
        np.array(image) / 255.0,
        methods=["original_image", "heat_map"],
        signs=["all", "positive"],
        titles=["Original", "Occlusion Map"],
        cmap=get_cmap_by_label(label),
        show_colorbar=True,
        outlier_perc=1,
        use_pyplot=False,
    )

    try:
        fig.suptitle(
            "Detection Justification via Occlusion Maps",
            fontsize=16,
            fontweight='bold',
            color="#004830",
        )
        # Reserve the top 3% of the figure for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.97])

        # A random uuid4 file name avoids collisions between concurrent calls.
        save_path = os.path.join(
            tempfile.gettempdir(),
            f"occlusion_map_{uuid.uuid4().hex}.png",
        )
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Close the figure even if layout or saving raised.
        plt.close(fig)

    return save_path