  # Import libraries

In [15]:
import os
import cv2  # OpenCV for image processing
import itertools # Added for saliency calculation
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torch.hub
# import torch.nn.functional as F # Already imported via torchvision below

from torchvision import models, transforms
import torchvision.transforms.functional as F # For potential use in visualization (though visualize_and_save_saliency uses different method)

# Use tqdm.auto for better console/notebook detection and nesting
from tqdm.auto import tqdm
import time # Optional: Can add timing info to postfix



  # Database creations using pytorch Dataset

In [16]:
class ImageAuthenticityDataset(Dataset):
    """Dataset for image quality assessment."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the CSV file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.dir_path = os.path.dirname(csv_file)  # Directory of the CSV file

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx,):
        """
        Retrieves an image and its labels by index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple (image, labels) where:
                image (PIL.Image): The image.
                labels (torch.Tensor): Tensor containing quality and authenticity scores.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # TODO: to be fixed, right now is folder dependent
        img_path_relative = self.data.iloc[idx, 3]
        # Construct absolute path based on CSV location
        base_dir = os.path.abspath(os.path.join(self.dir_path, '../../')) # Go up two levels from CSV dir
        img_name = os.path.join(base_dir, img_path_relative.replace("./", "")) # Combine and remove './'

        # Ensure path exists before opening
        if not os.path.exists(img_name):
             # Fallback or error handling if path logic is complex
             print(f"Warning: Image path {img_name} not found directly. Trying original relative path logic...")
             img_name = self.data.iloc[idx, 3].replace("./", "../../") # Original logic as fallback
             if not os.path.exists(img_name):
                  raise FileNotFoundError(f"Could not find image file at primary path: {os.path.join(base_dir, img_path_relative.replace('./', ''))} or fallback: {img_name}")


        image = Image.open(img_name).convert('RGB')
        authenticity = self.data.iloc[idx, 1]  # Authenticity column
        labels = torch.tensor([authenticity], dtype=torch.float)


        if self.transform:
            image = self.transform(image)

        return image, labels




  # Definitions of the models

In [17]:
class AuthenticityPredictor(nn.Module):
    def __init__(self, freeze_backbone=True):
        super().__init__()
        # Load pre-trained BarlowTwins ResNet50 instead of ResNet-152
        barlow_twins_resnet = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')

        # Freeze backbone if requested
        if freeze_backbone:
            for param in barlow_twins_resnet.parameters():
                param.requires_grad = False
            print("Model backbone frozen.")
        else:
            print("Model backbone NOT frozen (trainable).")


        self.features = nn.Sequential(*list(barlow_twins_resnet.children())[:-2])
        self.avgpool = barlow_twins_resnet.avgpool


        self.regression_head = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, 128),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(128, 1)
            )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        predictions = self.regression_head(x)
        # Return both predictions and flattened features (useful for some saliency methods)
        return predictions, x



  ## Setup section

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Data transformations for the ImageNet dataset
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Define path relative to the script location or use an absolute path
# Assuming the script is run from a location where '../../Dataset/...' is valid
try:
    # Try relative path first
    annotations_file = '../../Dataset/AIGCIQA2023/real_images_annotations.csv'
    # Check if the file exists using the relative path
    if not os.path.exists(annotations_file):
        # If relative path fails, try constructing from script directory
        script_dir = os.path.dirname(__file__) # Get directory of the script
        annotations_file = os.path.abspath(os.path.join(script_dir, '../../Dataset/AIGCIQA2023/real_images_annotations.csv'))
        if not os.path.exists(annotations_file):
            raise FileNotFoundError(f"Annotations file not found at relative or script-based path: {annotations_file}")
except NameError:
     # __file__ is not defined (e.g., running in interactive environment like Jupyter)
     # Fallback to assuming relative path from CWD or specify absolute path directly
     annotations_file = '../../Dataset/AIGCIQA2023/real_images_annotations.csv' # Or provide absolute path
     print("Warning: __file__ not defined. Assuming relative path for annotations file.")
     if not os.path.exists(annotations_file):
        raise FileNotFoundError(f"Annotations file not found at relative path: {annotations_file}. Please provide absolute path if needed.")


print(f"Loading annotations from: {annotations_file}")

# Create the dataset
dataset = ImageAuthenticityDataset(csv_file=annotations_file, transform=data_transforms)

# Set random seeds for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42) # Use manual_seed_all for multi-GPU setups if relevant
np.random.seed(42)
# Potentially add for DataLoader determinism (might impact performance)
# def seed_worker(worker_id):
#     worker_seed = torch.initial_seed() % 2**32
#     numpy.random.seed(worker_seed)
#     random.seed(worker_seed)
# g = torch.Generator()
# g.manual_seed(42)

# Split the dataset into training, validation, and test sets
train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
print(f"Dataset size: {len(dataset)}. Splitting into Train: {train_size}, Val: {val_size}, Test: {test_size}")
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)) # Add generator for split reproducibility


# Create data loaders
BATCH_SIZE = 1 # Set batch size to 1 for easier processing of individual images for saliency
NUM_WORKERS = 4 # Adjust based on your system
# Consider adding pin_memory=True if using GPU for potentially faster data transfer
# Create only the data loader for test
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)#, worker_init_fn=seed_worker, generator=g)



Using device: cuda
Loading annotations from: ../../Dataset/AIGCIQA2023/real_images_annotations.csv
Dataset size: 1368. Splitting into Train: 957, Val: 273, Test: 138


  # Models loading

In [19]:
# Define path relative to script or use absolute path
try:
    script_dir = os.path.dirname(__file__)
    BASELINE_MODEL_PATH = os.path.abspath(os.path.join(script_dir, 'Weights/BarlowTwins_real_authenticity_finetuned.pth'))
except NameError:
    # Fallback for interactive environments
    BASELINE_MODEL_PATH = 'Weights/BarlowTwins_real_authenticity_finetuned.pth'
    print("Warning: __file__ not defined. Assuming relative path for model weights.")

print(f"Loading baseline model weights from: {BASELINE_MODEL_PATH}")
if not os.path.exists(BASELINE_MODEL_PATH):
     raise FileNotFoundError(f"Model weights file not found at {BASELINE_MODEL_PATH}")

# Instantiate model with frozen backbone by default
baseline_model = AuthenticityPredictor(freeze_backbone=True)
# Load weights - ensure map_location handles CPU/GPU loading correctly
baseline_model.load_state_dict(torch.load(BASELINE_MODEL_PATH, map_location=device))
baseline_model.eval().to(device) # Set to evaluation mode and move to device
print("Baseline model loaded and set to evaluation mode.")



Loading baseline model weights from: Weights/BarlowTwins_real_authenticity_finetuned.pth


Using cache found in /home/icaro.redepaolini@unitn.it/.cache/torch/hub/facebookresearch_barlowtwins_main


Model backbone frozen.
Baseline model loaded and set to evaluation mode.


  baseline_model.load_state_dict(torch.load(BASELINE_MODEL_PATH, map_location=device))


  # Functions definitions (Image Utils & Saliency)

In [20]:
def create_sigma_list_from_image(image_tensor, percentages=[0.05, 0.10, 0.20], min_sigma=3):
    """
    Generates a list of sigma values based on image dimensions.

    Args:
        image_tensor (torch.Tensor): The input image tensor (e.g., shape [C, H, W]).
        percentages (list, optional): List of percentages of the smaller dimension
                                      to use for sigma values. Defaults to [0.05, 0.10, 0.20].
        min_sigma (int, optional): The minimum allowed sigma value. Defaults to 3.

    Returns:
        list: A sorted list of unique integer sigma values.
    """
    if image_tensor.dim() < 2:
        raise ValueError("image_tensor must have at least 2 dimensions (H, W)")

    # Get height and width (assuming shape [..., H, W])
    img_height = image_tensor.shape[-2]
    img_width = image_tensor.shape[-1]
    min_dim = min(img_height, img_width)

    print(f"Image dimensions: H={img_height}, W={img_width}. Using min_dim={min_dim} for sigma calculation.")

    sigma_list_calculated = []
    for p in percentages:
        # Calculate sigma based on percentage
        sigma = int(min_dim * p) # Truncate to integer

        # Ensure sigma is at least min_sigma
        sigma = max(min_sigma, sigma)

        # Ensure sigma is odd (can help with centering)
        if sigma % 2 == 0:
            sigma += 1

        # Ensure sigma doesn't exceed image dimensions (unlikely if based on min_dim, but safe check)
        sigma = min(sigma, min_dim)

        sigma_list_calculated.append(sigma)

    # Remove duplicates and sort the list
    sigma_list_final = sorted(list(set(sigma_list_calculated)))

    print(f"Generated sigma list: {sigma_list_final}")
    return sigma_list_final

def generate_mask(img_size, center, sigma):
    """Generates a binary mask with a square of zeros centered at 'center' with size 'sigma x sigma'."""
    mask = torch.ones(1, 1, img_size[0], img_size[1], device=device)
    start_x = max(0, int(center[0] - sigma // 2))
    end_x = min(img_size[1], int(center[0] + (sigma + 1) // 2))
    start_y = max(0, int(center[1] - sigma // 2))
    end_y = min(img_size[0], int(center[1] + (sigma + 1) // 2))
    if start_y < end_y and start_x < end_x:
        mask[:, :, start_y:end_y, start_x:end_x] = 0
    return mask

def calculate_saliency_map(model, image, original_score, sigma_list, mask_value=0.0):
    """
    Calculates the multiscale saliency map using the occlusion method
    by summing scores across scales and normalizing the result,
    with nested progress bars and showing current pixel impact in the postfix.

    Args:
        model (torch.nn.Module): The model to use for inference. Must be on the correct device.
        image (torch.Tensor): The input image tensor (C, H, W), must be on the correct device.
        original_score (float): The model's score for the original, unoccluded image.
        sigma_list (list): List of integers representing the sizes (side length) of the occlusion squares.
        mask_value (float, optional): Value to use for occluded regions. Defaults to 0.0.

    Returns:
        numpy.ndarray: A normalized saliency map (H, W) as a NumPy array on the CPU.
    """
    model.eval()
    # Ensure image is on the correct device and add batch dimension
    img_tensor = image.unsqueeze(0).to(device)
    img_size = img_tensor.shape[2:] # H, W
    # Initialize the final map to store the sum of per-scale saliencies
    saliency_map_final = torch.zeros(img_size, dtype=torch.float32, device=device)

    print(f"Calculating saliency for image size {img_size} using {len(sigma_list)} sigmas: {sigma_list}")

    # --- Outer Progress Bar for Sigmas ---
    outer_progress = tqdm(
        enumerate(sigma_list),
        total=len(sigma_list),
        desc="Overall Sigmas ", # Add space for better alignment
        unit="sigma",
        position=0,
        leave=True # Keep after finishing
    )

    for i, sigma in outer_progress:
        # Temporary map for the current sigma scale
        saliency_map_sigma = torch.zeros(img_size, dtype=torch.float32, device=device)

        # --- Inner Progress Bar for Pixels ---
        pixel_iterator = itertools.product(range(img_size[0]), range(img_size[1]))
        total_pixels = img_size[0] * img_size[1]
        inner_progress_bar = tqdm(
            pixel_iterator,
            total=total_pixels,
            desc=f"  Sigma {i+1}/{len(sigma_list)} (val={sigma: >3}) Pixels",
            leave=False, # Remove after each sigma finishes
            unit="pixel",
            
            position=1,
            mininterval=0.1 # Refresh rate throttle (optional)
        )

        start_time = time.time() # For calculating rate
        for y, x in inner_progress_bar:
            # Generate mask for the current pixel and sigma
            mask = generate_mask(img_size, (x, y), sigma)
            # Apply mask
            masked_image = img_tensor * mask + mask_value * (1 - mask)

            # Get model prediction for the masked image
            with torch.no_grad():
                output = model(masked_image)
                # Handle cases where model returns multiple outputs (e.g., prediction, features)
                if isinstance(output, tuple) and len(output) > 0:
                    masked_score_tensor = output[0]
                else:
                    masked_score_tensor = output
                # Ensure score is a scalar on CPU
                masked_score_item = masked_score_tensor.detach().cpu().item()

            # Calculate saliency value for this pixel and sigma
            saliency_value = original_score - masked_score_item
            saliency_map_sigma[y, x] = saliency_value

            # --- Update Postfix with Current Pixel Info ---
            inner_progress_bar.set_postfix(
                pixel=f"({y},{x})",
                impact=f"{saliency_value:.4f}", # Format saliency value
                refresh=False # Update display on tqdm's schedule
            )
            # -----------------------------------------------

        saliency_map_final += saliency_map_sigma
        # ---------------------------------------------------------------

        # Optional: Print time taken per sigma
        elapsed_time = time.time() - start_time
        pixels_per_sec = total_pixels / elapsed_time if elapsed_time > 0 else float('inf')
        tqdm.write(f"  Sigma {sigma} finished in {elapsed_time:.2f}s ({pixels_per_sec:.1f} pixels/sec)")

    # --- Normalization ---
    # Normalize the *summed* final map
    min_val = torch.min(saliency_map_final)
    max_val = torch.max(saliency_map_final)

    if max_val > min_val:
        # Perform min-max normalization to range [0, 1]
        saliency_map_normalized = (saliency_map_final - min_val) / (max_val - min_val)
    else:
        # Handle the case where the map is constant (all saliency values were the same)
        saliency_map_normalized = torch.zeros_like(saliency_map_final)
        print("Warning: Final saliency map was constant before normalization. Result is zero map.")

    # Return the normalized map as a NumPy array on the CPU
    return saliency_map_normalized.cpu().numpy()
def denormalize_image(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    """Denormalizes an image tensor."""
    if tensor.dim() != 3:
        raise ValueError(f"Input tensor must have 3 dimensions (C, H, W), but got {tensor.dim()}")
    if tensor.shape[0] != len(mean) or tensor.shape[0] != len(std):
         # Handle grayscale - assume mean/std are single values or adaptable
        if tensor.shape[0] == 1 and len(mean) == 3 and len(std) == 3:
             print("Warning: Denormalizing grayscale with potentially RGB stats. Using first value.")
             mean_used = [mean[0]]
             std_used = [std[0]]
        elif tensor.shape[0] == 1 and isinstance(mean, (int, float)) and isinstance(std, (int, float)):
             mean_used = [mean]
             std_used = [std]
        elif tensor.shape[0] == 1 and isinstance(mean, (list, tuple)) and isinstance(std, (list, tuple)) and len(mean) > 0 and len(std) > 0:
             print("Warning: Denormalizing grayscale with potentially multi-channel stats. Using first value.")
             mean_used = [mean[0]]
             std_used = [std[0]]
        else:
            raise ValueError(f"Channel mismatch: Tensor has {tensor.shape[0]} channels, mean has {len(mean)}, std has {len(std)}")
    else:
        mean_used = mean
        std_used = std

    mean_t = torch.as_tensor(mean_used, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std_used, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)

    denormalized_tensor = tensor * std_t + mean_t
    return torch.clamp(denormalized_tensor, 0., 1.)

def visualize_and_save_saliency(
    image_tensor,
    saliency_map,
    output_dir,
    filename_prefix,
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    overlay_alpha=0.5,
    cmap_name='bwr'): # Added cmap_name parameter
    """
    Visualizes saliency map using a specified colormap, creates an overlay
    using OpenCV, and saves the original, heatmap, and overlay images.

    Args:
        image_tensor (torch.Tensor): Original image tensor (C, H, W), must be on CPU.
        saliency_map (numpy.ndarray): Calculated saliency map (H, W), normalized [0, 1].
        output_dir (str): Directory to save the output images.
        filename_prefix (str): Prefix for the saved filenames (e.g., 'sample_01').
        mean (list, optional): Mean used for image normalization.
        std (list, optional): Standard deviation used for image normalization.
        overlay_alpha (float, optional): Opacity of the heatmap in the overlay. Defaults to 0.5.
        cmap_name (str, optional): Name of the matplotlib colormap to use. Defaults to 'bwr'.
    """
    if image_tensor.is_cuda:
        print("Warning: image_tensor provided to visualize_and_save_saliency is on CUDA, moving to CPU.")
        image_tensor = image_tensor.cpu()

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)
    temp_dir = os.path.join(output_dir, 'temp_heatmap') # For temporary heatmap file
    os.makedirs(temp_dir, exist_ok=True)

    # 1. Prepare Original Image
    try:
        img_denorm_tensor = denormalize_image(image_tensor, mean, std)
    except ValueError as e:
        print(f"Error during denormalization: {e}")
        print(f"Image tensor shape: {image_tensor.shape}")
        return # Skip visualization for this image if denormalization fails

    img_np = img_denorm_tensor.numpy().transpose(1, 2, 0) # H, W, C
    # Ensure image values are in [0, 1] before scaling to [0, 255]
    img_np = np.clip(img_np, 0.0, 1.0)
    img_uint8 = (img_np * 255).astype(np.uint8)


    # Handle grayscale conversion for saving/display if needed
    if img_uint8.shape[2] == 1:
        img_display = cv2.cvtColor(img_uint8, cv2.COLOR_GRAY2RGB) # Keep 3 channels for consistency
        img_bgr = cv2.cvtColor(img_uint8, cv2.COLOR_GRAY2BGR) # For OpenCV overlay
    elif img_uint8.shape[2] == 3:
        img_display = img_uint8 # Already RGB H,W,C
        img_bgr = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR) # Convert to BGR for OpenCV
    else:
         print(f"Error: Unexpected number of channels ({img_uint8.shape[2]}) in denormalized image.")
         return


    # Save original image
    orig_save_path = os.path.join(output_dir, f"{filename_prefix}_original.png")
    plt.figure(figsize=(8, 8))
    plt.imshow(img_display)
    plt.axis('off')
    plt.title("Original Image")
    plt.savefig(orig_save_path, bbox_inches='tight', pad_inches=0)
    plt.close()
    # print(f"Saved original image to {orig_save_path}") # Less verbose

    # 2. Prepare and Save Standalone Heatmap
    try:
        cmap = cm.get_cmap(cmap_name)
    except ValueError:
        print(f"Warning: Colormap '{cmap_name}' not found. Using default 'viridis'.")
        cmap = cm.get_cmap('viridis')
    norm = colors.Normalize(vmin=0, vmax=1) # Normalize from 0 to 1

    heatmap_save_path = os.path.join(output_dir, f"{filename_prefix}_heatmap_{cmap_name}.png")
    plt.figure(figsize=(8, 8))
    # Ensure saliency map has correct dimensions (H, W)
    if saliency_map.ndim != 2:
        print(f"Error: Saliency map has unexpected dimensions {saliency_map.shape}. Expected (H, W).")
        plt.close()
        # Clean up temp dir if created
        if os.path.exists(temp_heatmap_path): os.remove(temp_heatmap_path)
        if os.path.exists(temp_dir) and not os.listdir(temp_dir): os.rmdir(temp_dir)
        return

    plt.imshow(saliency_map, cmap=cmap, norm=norm)
    plt.colorbar(label=f'Normalized Saliency (0: Low/{cmap(0.0)[:3]}, 1: High/{cmap(1.0)[:3]})') # Indicate colors
    plt.title(f"Saliency Heatmap ({cmap_name})")
    plt.axis('off')
    plt.savefig(heatmap_save_path, bbox_inches='tight', pad_inches=0)
    plt.close()
    # print(f"Saved heatmap to {heatmap_save_path}") # Less verbose


    # 3. Create Overlay using OpenCV
    # Generate colored heatmap image *without* axes/colorbar
    temp_heatmap_path = os.path.join(temp_dir, f"{filename_prefix}_temp_heatmap.png")
    # Match aspect ratio and use known DPI for predictable sizing
    fig_width_inches = img_display.shape[1] / 100.0
    fig_height_inches = img_display.shape[0] / 100.0
    plt.figure(figsize=(fig_width_inches, fig_height_inches), dpi=100)
    plt.imshow(saliency_map, cmap=cmap, norm=norm)
    plt.axis('off')
    plt.savefig(temp_heatmap_path, bbox_inches='tight', pad_inches=0, dpi=100)
    plt.close()

    # Read the saved heatmap with OpenCV
    colored_heatmap_bgr = cv2.imread(temp_heatmap_path)

    # Clean up temporary file
    if os.path.exists(temp_heatmap_path):
        os.remove(temp_heatmap_path)
        try:
            # Attempt to remove temp dir only if it's empty
            if not os.listdir(temp_dir):
                 os.rmdir(temp_dir)
        except OSError:
            pass # Ignore if not empty (e.g., race condition in parallel runs)

    if colored_heatmap_bgr is None:
        print(f"Error: Could not read temporary heatmap file: {temp_heatmap_path}")
        return

    # Resize heatmap to match original image size (important safety check)
    if colored_heatmap_bgr.shape[:2] != img_bgr.shape[:2]:
         print(f"Warning: Resizing heatmap from {colored_heatmap_bgr.shape[:2]} to {img_bgr.shape[:2]}")
         colored_heatmap_bgr = cv2.resize(colored_heatmap_bgr, (img_bgr.shape[1], img_bgr.shape[0]),
                                          interpolation=cv2.INTER_LINEAR)


    # Blend the images using cv2.addWeighted
    overlay = cv2.addWeighted(
        src1=img_bgr,             # Original image (BGR)
        alpha=1.0 - overlay_alpha,# Weight for original image
        src2=colored_heatmap_bgr, # Colored heatmap (BGR)
        beta=overlay_alpha,       # Weight for heatmap
        gamma=0.0                 # Scalar added to each sum
    )

    # Convert overlay back to RGB for saving with matplotlib/saving directly
    overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)

    # Save overlay image
    overlay_save_path = os.path.join(output_dir, f"{filename_prefix}_overlay_{cmap_name}.png")
    # Save directly using OpenCV for potentially better fidelity than matplotlib savefig
    try:
         # Convert RGB back to BGR for cv2.imwrite
         success = cv2.imwrite(overlay_save_path, cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR))
         if not success:
              print(f"Error: cv2.imwrite failed to save overlay to {overlay_save_path}")
              # Fallback to matplotlib saving if cv2 fails
              plt.figure(figsize=(8, 8))
              plt.imshow(overlay_rgb)
              plt.axis('off')
              plt.title(f"Saliency Overlay ({cmap_name})")
              plt.savefig(overlay_save_path, bbox_inches='tight', pad_inches=0)
              plt.close()

         # print(f"Saved overlay to {overlay_save_path}") # Less verbose
    except Exception as e:
         print(f"Exception during overlay saving: {e}")
         # Fallback just in case
         plt.figure(figsize=(8, 8))
         plt.imshow(overlay_rgb)
         plt.axis('off')
         plt.title(f"Saliency Overlay ({cmap_name})")
         plt.savefig(overlay_save_path, bbox_inches='tight', pad_inches=0)
         plt.close()



  # Main Execution Block

In [21]:
if __name__ == "__main__":
    print("\n--- Starting Saliency Map Generation and Visualization ---")
    # Define the directory to save output images
    output_visualization_dir = 'saliency_visualizations_real'
    os.makedirs(output_visualization_dir, exist_ok=True)
    print(f"Output visualizations will be saved in: {output_visualization_dir}")

    # --- Configuration ---
    NUM_IMAGES_TO_PROCESS = 1 # Number of images to process from the test set
    # Sigma generation parameters
    SIGMA_PERCENTAGES = None
    SIGMA_LIST = [3,5,9,17,33,65] # Extracted from the original code, but can be None to generate based on percentages
    MIN_SIGMA = 5 # Minimum occlusion size in pixels
    MASK_VALUE = 0.0 # Value for occluded pixels (0.0 for black, could be mean image value)
    VIS_CMAP = 'jet' # Colormap for visualization ('jet', 'bwr', 'viridis', etc.)
    VIS_ALPHA = 0.6 # Overlay opacity

    print(f"Processing {NUM_IMAGES_TO_PROCESS} images from the test set.")
    print(f"Using device: {device}")
    print(f"Saliency generation sigmas based on percentages: {SIGMA_PERCENTAGES}, min_sigma: {MIN_SIGMA}")
    print(f"Occlusion mask value: {MASK_VALUE}")
    print(f"Visualization colormap: {VIS_CMAP}, alpha: {VIS_ALPHA}")


    # --- Processing Loop ---
    processed_count = 0
    # Ensure model is in eval mode (already done after loading, but good practice)
    baseline_model.eval()

    # Iterate through the test dataloader
    # Wrap the dataloader with tqdm for overall progress
    test_iterator = tqdm(test_dataloader, total=min(NUM_IMAGES_TO_PROCESS, len(test_dataloader)), desc="Overall Progress")

    for i, (images, labels) in enumerate(test_iterator):
        if processed_count >= NUM_IMAGES_TO_PROCESS:
            print(f"\nReached limit of {NUM_IMAGES_TO_PROCESS} images. Stopping.")
            break

        # Since BATCH_SIZE=1, images and labels contain single items
        image_tensor = images[0].to(device) # Get the image tensor, move to device
        label = labels[0]                   # Get the corresponding label tensor

        print(f"\nProcessing image {processed_count + 1}/{NUM_IMAGES_TO_PROCESS} (DataLoader index: {i})")

        # 1. Get Original Model Score
        with torch.no_grad():
            # Add batch dimension for the model, ensure image is on device
            original_output = baseline_model(image_tensor.unsqueeze(0))
            # Unpack score based on model's forward method
            if isinstance(original_output, tuple) and len(original_output) > 0:
                original_score_tensor = original_output[0]
            else:
                original_score_tensor = original_output
            # Get the scalar score value
            original_score = original_score_tensor.item()

        true_authenticity = label.item() # Get scalar value from label tensor
        print(f"  True Authenticity: {true_authenticity:.4f}")
        print(f"  Original Predicted Authenticity: {original_score:.4f}")


        # 2. Generate Sigma List for this image
        sigma_list = SIGMA_LIST if SIGMA_LIST else create_sigma_list_from_image(
            image_tensor=image_tensor, # Pass the single image tensor (C, H, W) on device
            percentages=SIGMA_PERCENTAGES if SIGMA_PERCENTAGES else [0.05, 0.10, 0.20],
            min_sigma=MIN_SIGMA
        )
        print(f"  Sigma list for this image: {sigma_list}")
        
        if not sigma_list:
             print("  Warning: No sigma values generated. Skipping saliency calculation.")
             processed_count += 1
             test_iterator.set_postfix_str(f"Image {processed_count}/{NUM_IMAGES_TO_PROCESS} (Skipped)")
             continue # Skip to the next image


        # 3. Calculate Saliency Map
        print(f"  Calculating saliency map...")
        saliency_map_np = calculate_saliency_map(
            model=baseline_model,
            image=image_tensor, # Pass the single image tensor (C, H, W) on device
            original_score=original_score,
            sigma_list=sigma_list,
            mask_value=MASK_VALUE
        )
        print(f"  Saliency map calculated with shape: {saliency_map_np.shape}")


        # 4. Visualization and Saving
        filename_prefix = f"image_{processed_count:03d}_auth_{true_authenticity:.2f}_pred_{original_score:.2f}"
        print(f"  Visualizing and saving results with prefix: {filename_prefix}...")

        # Ensure the image tensor passed to visualize is on CPU and without batch dim
        image_to_visualize = image_tensor.cpu()

        visualize_and_save_saliency(
            image_tensor=image_to_visualize,
            saliency_map=saliency_map_np, # Pass the calculated saliency map (numpy array)
            output_dir=output_visualization_dir,
            filename_prefix=filename_prefix,
            overlay_alpha=VIS_ALPHA,
            cmap_name=VIS_CMAP
        )
        print(f"  Visualization saved.")

        processed_count += 1
        # Update overall progress bar postfix
        test_iterator.set_postfix_str(f"Image {processed_count}/{NUM_IMAGES_TO_PROCESS}")


    # Close the main progress bar upon completion
    test_iterator.close()
    print("\n--- Saliency Map Generation and Visualization Finished ---")




--- Starting Saliency Map Generation and Visualization ---
Output visualizations will be saved in: saliency_visualizations_real
Processing 1 images from the test set.
Using device: cuda
Saliency generation sigmas based on percentages: None, min_sigma: 5
Occlusion mask value: 0.0
Visualization colormap: jet, alpha: 0.6


Overall Progress:   0%|          | 0/1 [00:00<?, ?it/s]


Processing image 1/1 (DataLoader index: 0)
  True Authenticity: 36.9786
  Original Predicted Authenticity: 33.1730
  Sigma list for this image: [3, 5, 9, 17, 33, 65]
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65]


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Pixels:   0%|          | 0/50176 [00:00<?, ?pixel/s]

  Sigma 3 finished in 218.27s (229.9 pixels/sec)


  Sigma 2/6 (val=  5) Pixels:   0%|          | 0/50176 [00:00<?, ?pixel/s]

  Sigma 5 finished in 218.20s (230.0 pixels/sec)


  Sigma 3/6 (val=  9) Pixels:   0%|          | 0/50176 [00:00<?, ?pixel/s]

  Sigma 9 finished in 218.21s (229.9 pixels/sec)


  Sigma 4/6 (val= 17) Pixels:   0%|          | 0/50176 [00:00<?, ?pixel/s]

  Sigma 17 finished in 218.78s (229.3 pixels/sec)


  Sigma 5/6 (val= 33) Pixels:   0%|          | 0/50176 [00:00<?, ?pixel/s]

  Sigma 33 finished in 218.62s (229.5 pixels/sec)


  Sigma 6/6 (val= 65) Pixels:   0%|          | 0/50176 [00:00<?, ?pixel/s]

  Sigma 65 finished in 218.32s (229.8 pixels/sec)
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: image_000_auth_36.98_pred_33.17...


  cmap = cm.get_cmap(cmap_name)


  Visualization saved.

Reached limit of 1 images. Stopping.

--- Saliency Map Generation and Visualization Finished ---
