   # Import libraries

In [1]:
import os
import cv2  # OpenCV for image processing
import itertools # Added for saliency calculation
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
import math # For math.ceil

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.models import resnet152, ResNet152_Weights

from torchvision import models, transforms
import torchvision.transforms.functional as F # For potential use in visualization (though visualize_and_save_saliency uses different method)

# Use tqdm.auto for better console/notebook detection and nesting
from tqdm.auto import tqdm
import time # Optional: Can add timing info to postfix




   # Database creations using pytorch Dataset

In [2]:
class ImageAuthenticityDataset(Dataset):
    """Dataset for image quality assessment."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the CSV file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.dir_path = os.path.dirname(csv_file)  # Directory of the CSV file

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx,):
        """
        Retrieves an image and its labels by index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple (image, labels) where:
                image (PIL.Image): The image.
                labels (torch.Tensor): Tensor containing quality and authenticity scores.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path_relative = self.data.iloc[idx, 3]
        base_dir = os.path.abspath(os.path.join(self.dir_path, '../../')) 
        img_name = os.path.join(base_dir, img_path_relative.replace("./", ""))

        if not os.path.exists(img_name):
             print(f"Warning: Image path {img_name} not found directly. Trying original relative path logic...")
             img_name_fallback = self.data.iloc[idx, 3].replace("./", "../../") 
             # Corrected fallback construction:
             img_name_fallback_abs = os.path.abspath(os.path.join(self.dir_path, img_name_fallback))

             if os.path.exists(img_name_fallback_abs):
                 img_name = img_name_fallback_abs
             elif os.path.exists(img_name_fallback): # Check original fallback relative to CWD if absolute fails
                 img_name = img_name_fallback
             else:
                  raise FileNotFoundError(f"Could not find image file at primary path: {os.path.join(base_dir, img_path_relative.replace('./', ''))} or fallback attempts: {img_name_fallback_abs}, {img_name_fallback}")


        image = Image.open(img_name).convert('RGB')
        authenticity = self.data.iloc[idx, 1]  # Authenticity column
        labels = torch.tensor([authenticity], dtype=torch.float)


        if self.transform:
            image = self.transform(image)

        return image, labels





   # Definitions of the models

In [3]:
class AuthenticityPredictor(nn.Module):
    def __init__(self, freeze_backbone=True):
        super().__init__()
        # Load pre-trained ResNet-152 instead of VGG16
        resnet = resnet152(weights=ResNet152_Weights.DEFAULT)
        
        # Freeze backbone if requested
        if freeze_backbone:
            for param in resnet.parameters():
                param.requires_grad = False
                
        # Store the backbone (excluding the final fc layer)
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.avgpool = resnet.avgpool
        
        self.regression_head = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.5),  # Reduced dropout ratio
                nn.Linear(512, 128),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(128, 1)
            )    
        
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        predictions = self.regression_head(x)
        return predictions, x 
        


   ## Setup section

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


try:
    annotations_file = '../../Dataset/AIGCIQA2023/real_images_annotations.csv'
    if not os.path.exists(annotations_file):
        script_dir = os.path.dirname(os.path.abspath(__file__)) 
        annotations_file = os.path.abspath(os.path.join(script_dir, '../../Dataset/AIGCIQA2023/real_images_annotations.csv'))
        if not os.path.exists(annotations_file):
            raise FileNotFoundError(f"Annotations file not found at relative or script-based path.")
except NameError:
     annotations_file = '../../Dataset/AIGCIQA2023/real_images_annotations.csv'
     print("Warning: __file__ not defined (e.g., running in Jupyter). Assuming relative path for annotations file from CWD.")
     if not os.path.exists(annotations_file):
        # Try one level up if in a common 'notebooks' or 'scripts' subdir
        annotations_file_alt = '../Dataset/AIGCIQA2023/real_images_annotations.csv'
        if os.path.exists(annotations_file_alt):
            annotations_file = annotations_file_alt
        else:
            raise FileNotFoundError(f"Annotations file not found at '{annotations_file}' or '{annotations_file_alt}'. Please provide absolute path if needed.")


print(f"Loading annotations from: {annotations_file}")

dataset = ImageAuthenticityDataset(csv_file=annotations_file, transform=data_transforms)

torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
np.random.seed(42)

train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
print(f"Dataset size: {len(dataset)}. Splitting into Train: {train_size}, Val: {val_size}, Test: {test_size}")
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42))

BATCH_SIZE = 1
NUM_WORKERS = min(os.cpu_count(), 4) if os.cpu_count() else 4 # Safer NUM_WORKERS
# test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# extract images based on the indices in the test dataset
test_indices = test_dataset.indices
indices_to_extract = [1,3,33,50,82]

# Create a subset of the test dataset with only the specified indices
extracted_dataset = torch.utils.data.Subset(test_dataset, indices_to_extract)
# Create a DataLoader for the extracted dataset
test_dataloader = DataLoader(extracted_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Using device: cuda
Loading annotations from: ../../Dataset/AIGCIQA2023/real_images_annotations.csv
Dataset size: 1368. Splitting into Train: 957, Val: 273, Test: 138


   # Models loading

In [5]:
BASELINE_MODEL_PATH = 'Weights/ResNet-152_real_authenticity_finetuned.pth'
PRUNED_MODEL_PATH = 'Weights/real_authenticity_noise_out_pruned_model.pth'    

baseline_model = AuthenticityPredictor(freeze_backbone=True)
baseline_model.load_state_dict(torch.load(BASELINE_MODEL_PATH, map_location=device))
baseline_model.eval().to(device)
print("Baseline model loaded and set to evaluation mode.")

pruned_model = AuthenticityPredictor(freeze_backbone=True)
pruned_model.load_state_dict(torch.load(PRUNED_MODEL_PATH, map_location=device))
pruned_model.eval().to(device)
print("Pruned model loaded and set to evaluation mode.")



  baseline_model.load_state_dict(torch.load(BASELINE_MODEL_PATH, map_location=device))


Baseline model loaded and set to evaluation mode.


  pruned_model.load_state_dict(torch.load(PRUNED_MODEL_PATH, map_location=device))


Pruned model loaded and set to evaluation mode.


   # Functions definitions (Image Utils & Saliency)

In [6]:
def generate_mask(img_size, center, sigma):
    """Generates a binary mask with a square of zeros centered at 'center' with size 'sigma x sigma'."""
    mask = torch.ones(1, 1, img_size[0], img_size[1], device=device) # Ensure mask is on the correct device
    start_x = max(0, int(center[0] - sigma // 2))
    end_x = min(img_size[1], int(center[0] + (sigma + 1) // 2))
    start_y = max(0, int(center[1] - sigma // 2))
    end_y = min(img_size[0], int(center[1] + (sigma + 1) // 2))
    if start_y < end_y and start_x < end_x:
        mask[:, :, start_y:end_y, start_x:end_x] = 0
    return mask

def calculate_saliency_map(model, image, original_score, sigma_list, mask_value=0.0, pixel_batch_size=32): # Added pixel_batch_size
    """
    Calculates the multiscale saliency map using the occlusion method
    by summing scores across scales and normalizing the result,
    with batched processing for pixel occlusions.
    """
    model.eval()
    # Ensure image is on the correct device and has a batch dimension
    if image.dim() == 3: # (C, H, W)
        img_tensor_base = image.unsqueeze(0).to(device) # (1, C, H, W)
    elif image.dim() == 4 and image.shape[0] == 1: # (1, C, H, W)
        img_tensor_base = image.to(device)
    else:
        raise ValueError(f"Input image tensor has unexpected dimensions: {image.shape}, expected (C,H,W) or (1,C,H,W)")

    img_size = img_tensor_base.shape[2:] # H, W
    saliency_map_final = torch.zeros(img_size, dtype=torch.float32, device=device)

    print(f"Calculating saliency for image size {img_size} using {len(sigma_list)} sigmas: {sigma_list} with pixel_batch_size: {pixel_batch_size}")

    outer_progress = tqdm(
        enumerate(sigma_list),
        total=len(sigma_list),
        desc="Overall Sigmas ",
        unit="sigma",
        position=0,
        leave=True
    )

    for i, sigma in outer_progress:
        saliency_map_sigma = torch.zeros(img_size, dtype=torch.float32, device=device)
        
        all_pixel_coords = list(itertools.product(range(img_size[0]), range(img_size[1]))) # (y, x)
        total_pixels = len(all_pixel_coords)
        num_batches = math.ceil(total_pixels / pixel_batch_size)

        inner_progress_bar = tqdm(
            range(num_batches),
            total=num_batches,
            desc=f"  Sigma {i+1}/{len(sigma_list)} (val={sigma: >3}) Batches",
            leave=False,
            unit="batch",
            position=1, # Nested progress bar
            mininterval=0.1
        )
        
        processed_pixels_count = 0
        for batch_idx in inner_progress_bar:
            batch_start_idx = batch_idx * pixel_batch_size
            batch_end_idx = min(total_pixels, (batch_idx + 1) * pixel_batch_size)
            
            current_coords_batch = all_pixel_coords[batch_start_idx:batch_end_idx]
            actual_batch_size = len(current_coords_batch)

            if actual_batch_size == 0:
                continue

            masked_images_list = []
            for y_coord, x_coord in current_coords_batch:
                # generate_mask expects center as (x, y)
                mask = generate_mask(img_size, (x_coord, y_coord), sigma) 
                # img_tensor_base is (1, C, H, W), mask is (1, 1, H, W)
                # Broadcasting applies mask correctly: (1,C,H,W) * (1,1,H,W) -> (1,C,H,W)
                masked_image = img_tensor_base * mask + mask_value * (1 - mask) # Occlusion happens here
                masked_images_list.append(masked_image) # Appending (1,C,H,W) tensors
            
            # Stack along a new batch dimension, result shape: (actual_batch_size, C, H, W)
            batch_of_masked_images = torch.cat(masked_images_list, dim=0)

            with torch.no_grad():
                output_batch, _ = model(batch_of_masked_images) # Model returns (predictions, features)
                # Ensure scores are flat (actual_batch_size,)
                masked_scores_tensor_batch = output_batch.squeeze()
                if masked_scores_tensor_batch.dim() == 0: # If only one item in batch and squeeze made it scalar
                    masked_scores_tensor_batch = masked_scores_tensor_batch.unsqueeze(0)

            for k in range(actual_batch_size):
                y, x = current_coords_batch[k]
                masked_score_item = masked_scores_tensor_batch[k].item()
                saliency_value = original_score - masked_score_item # Higher score drop = more salient
                saliency_map_sigma[y, x] = saliency_value
            
            processed_pixels_count += actual_batch_size
            inner_progress_bar.set_postfix_str(
                f"Pixels {processed_pixels_count}/{total_pixels}",
                refresh=True 
            )

        saliency_map_final += saliency_map_sigma
        elapsed_time_sigma = inner_progress_bar.format_dict['elapsed'] # Get elapsed time from tqdm
        tqdm.write(f"  Sigma {sigma} finished processing in {elapsed_time_sigma:.2f}s.")
        inner_progress_bar.close()


    min_val = torch.min(saliency_map_final)
    max_val = torch.max(saliency_map_final)

    if max_val > min_val:
        saliency_map_normalized = (saliency_map_final - min_val) / (max_val - min_val)
    else:
        saliency_map_normalized = torch.zeros_like(saliency_map_final)
        print("Warning: Final saliency map was constant before normalization. Result is zero map.")
    
    outer_progress.close()
    return saliency_map_normalized.cpu().numpy()

def denormalize_image(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    """Denormalizes an image tensor."""
    if tensor.dim() != 3:
        raise ValueError(f"Input tensor must have 3 dimensions (C, H, W), but got {tensor.dim()}")
    
    mean_used = mean
    std_used = std
    if tensor.shape[0] != len(mean) or tensor.shape[0] != len(std):
        if tensor.shape[0] == 1: # Grayscale
             print("Warning: Denormalizing grayscale with potentially RGB stats. Using first value of mean/std.")
             mean_used = [mean[0]] if isinstance(mean, list) else [mean]
             std_used = [std[0]] if isinstance(std, list) else [std]
        else:
            raise ValueError(f"Channel mismatch: Tensor has {tensor.shape[0]} channels, mean has {len(mean)}, std has {len(std)}")

    mean_t = torch.as_tensor(mean_used, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std_used, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)

    denormalized_tensor = tensor * std_t + mean_t
    return torch.clamp(denormalized_tensor, 0., 1.)

def visualize_and_save_saliency(
    image_tensor,
    saliency_map,
    output_dir,
    filename_prefix,
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    overlay_alpha=0.5,
    cmap_name='bwr'):
    """
    Visualizes saliency map, creates an overlay, and saves images.
    """
    if image_tensor.is_cuda:
        # print("Warning: image_tensor provided to visualize_and_save_saliency is on CUDA, moving to CPU.")
        image_tensor = image_tensor.cpu()

    # Create a sub-folder for each image's visualizations
    image_specific_output_dir = os.path.join(output_dir, filename_prefix)
    os.makedirs(image_specific_output_dir, exist_ok=True)
    
    temp_dir = os.path.join(image_specific_output_dir, 'temp_heatmap_cache') 
    os.makedirs(temp_dir, exist_ok=True)

    if saliency_map.ndim != 2:
        print(f"Error: Saliency map has unexpected dimensions {saliency_map.shape}. Expected (H, W). Skipping visualization.")
        return
    saliency_map = np.clip(saliency_map, 0.0, 1.0)
    
    NUMPY_DIR = os.path.join(output_dir, 'numpy_saliency_maps') # Centralized numpy maps
    os.makedirs(NUMPY_DIR, exist_ok=True)
    np.save(os.path.join(NUMPY_DIR, f"{filename_prefix}_saliency_map.npy"), saliency_map)

    try:
        img_denorm_tensor = denormalize_image(image_tensor, mean, std)
    except ValueError as e:
        print(f"Error during denormalization for {filename_prefix}: {e}. Skipping visualization.")
        return

    img_np = img_denorm_tensor.numpy().transpose(1, 2, 0)
    img_np = np.clip(img_np, 0.0, 1.0)
    img_uint8 = (img_np * 255).astype(np.uint8)

    if img_uint8.shape[2] == 1:
        img_display = cv2.cvtColor(img_uint8, cv2.COLOR_GRAY2RGB)
        img_bgr = cv2.cvtColor(img_uint8, cv2.COLOR_GRAY2BGR)
    elif img_uint8.shape[2] == 3:
        img_display = img_uint8
        img_bgr = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR)
    else:
         print(f"Error: Unexpected number of channels ({img_uint8.shape[2]}) for {filename_prefix}. Skipping visualization.")
         return

    orig_save_path = os.path.join(image_specific_output_dir, f"{filename_prefix}_original.png")
    plt.figure(figsize=(img_display.shape[1]/100, img_display.shape[0]/100), dpi=100) # Match size
    plt.imshow(img_display)
    plt.axis('off'); plt.title("Original Image")
    plt.savefig(orig_save_path, bbox_inches='tight', pad_inches=0)
    plt.close()

    try:
        cmap = cm.get_cmap(cmap_name)
    except ValueError:
        print(f"Warning: Colormap '{cmap_name}' not found. Using default 'viridis'.")
        cmap = cm.get_cmap('viridis')
    norm = colors.Normalize(vmin=0, vmax=1)

    heatmap_save_path = os.path.join(image_specific_output_dir, f"{filename_prefix}_heatmap_{cmap_name}.png")
    plt.figure(figsize=(saliency_map.shape[1]/100, saliency_map.shape[0]/100), dpi=100) # Match size
    plt.imshow(saliency_map, cmap=cmap, norm=norm)
    plt.colorbar(label=f'Saliency')
    plt.title(f"Saliency Heatmap ({cmap_name})"); plt.axis('off')
    plt.savefig(heatmap_save_path, bbox_inches='tight', pad_inches=0)
    plt.close()

    temp_heatmap_path = os.path.join(temp_dir, f"{filename_prefix}_temp_heatmap_for_overlay.png")
    fig_width_inches = img_display.shape[1] / 100.0
    fig_height_inches = img_display.shape[0] / 100.0
    plt.figure(figsize=(fig_width_inches, fig_height_inches), dpi=100)
    plt.imshow(saliency_map, cmap=cmap, norm=norm); plt.axis('off')
    plt.savefig(temp_heatmap_path, bbox_inches='tight', pad_inches=0, dpi=100)
    plt.close()

    colored_heatmap_bgr = cv2.imread(temp_heatmap_path)
    if os.path.exists(temp_heatmap_path): os.remove(temp_heatmap_path)
    if os.path.exists(temp_dir) and not os.listdir(temp_dir): 
        try: os.rmdir(temp_dir)
        except OSError: pass # Might fail if another process/thread is accessing

    if colored_heatmap_bgr is None:
        print(f"Error: Could not read temporary heatmap file for {filename_prefix}: {temp_heatmap_path}. Skipping overlay.")
        return

    if colored_heatmap_bgr.shape[:2] != img_bgr.shape[:2]:
         # print(f"Warning: Resizing heatmap from {colored_heatmap_bgr.shape[:2]} to {img_bgr.shape[:2]} for {filename_prefix}")
         colored_heatmap_bgr = cv2.resize(colored_heatmap_bgr, (img_bgr.shape[1], img_bgr.shape[0]), interpolation=cv2.INTER_LINEAR)

    overlay = cv2.addWeighted(img_bgr, 1.0 - overlay_alpha, colored_heatmap_bgr, overlay_alpha, 0.0)
    overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)

    overlay_save_path = os.path.join(image_specific_output_dir, f"{filename_prefix}_overlay_{cmap_name}.png")
    success = cv2.imwrite(overlay_save_path, cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR))
    if not success:
        print(f"Error: cv2.imwrite failed for overlay {overlay_save_path}. Trying plt.savefig.")
        plt.figure(figsize=(overlay_rgb.shape[1]/100, overlay_rgb.shape[0]/100), dpi=100)
        plt.imshow(overlay_rgb); plt.axis('off'); plt.title(f"Saliency Overlay ({cmap_name})")
        plt.savefig(overlay_save_path, bbox_inches='tight', pad_inches=0)
        plt.close()

def run_saliency_analysis(
    model,
    dataloader,
    output_dir,
    num_images_to_process,
    sigma_list,
    pixel_batch_size, # Added
    mask_value=0.0,
    vis_cmap='bwr',
    vis_alpha=0.6,
    device='cpu', 
    model_name="Model"
    ):
    """
    Facade function to run saliency map generation (with batched pixel processing)
    and visualization for a given model.
    """
    print(f"\n--- Starting Saliency Analysis for {model_name} ---")
    # Central output directory for this model run
    model_output_dir = os.path.join(output_dir, model_name.replace(" ", "_"))
    os.makedirs(model_output_dir, exist_ok=True)
    print(f"Output visualizations will be saved in subfolders within: {model_output_dir}")
    print(f"Processing up to {num_images_to_process} images.")
    print(f"Using sigmas: {sigma_list}, Pixel batch size: {pixel_batch_size}")

    processed_count = 0
    model.eval()

    # Determine the actual number of images to process
    num_to_iterate = min(num_images_to_process, len(dataloader.dataset) if BATCH_SIZE == 1 else len(dataloader))


    test_iterator = tqdm(
        dataloader,
        total=num_to_iterate,
        desc=f"{model_name} Image Progress"
    )

    for i, (images, labels) in enumerate(test_iterator):
        if processed_count >= num_images_to_process:
            print(f"\nReached limit of {num_images_to_process} images for {model_name}. Stopping.")
            break

        image_tensor = images[0].to(device) # (C, H, W) as BATCH_SIZE=1
        label = labels[0]                   # scalar tensor

        print(f"\nProcessing image {processed_count + 1}/{num_images_to_process} (DataLoader index: {i}) for {model_name}")

        with torch.no_grad():
            original_output, _ = model(image_tensor.unsqueeze(0)) # Model expects batch
            original_score = original_output.item()

        true_authenticity = label.item()
        print(f"  True Authenticity: {true_authenticity:.4f}")
        print(f"  {model_name} Predicted Authenticity (Original): {original_score:.4f}")

        if not sigma_list:
             print("  Warning: No sigma values provided. Skipping saliency calculation.")
             processed_count += 1
             test_iterator.set_postfix_str(f"Image {processed_count}/{num_images_to_process} (Skipped)")
             continue
        
        print(f"  Calculating saliency map...")
        saliency_map_np = calculate_saliency_map( # Calls the modified function
            model=model,
            image=image_tensor, # Pass (C,H,W) tensor
            original_score=original_score,
            sigma_list=sigma_list,
            mask_value=mask_value,
            pixel_batch_size=pixel_batch_size # Pass new arg
        )
        print(f"  Saliency map calculated with shape: {saliency_map_np.shape}")

        filename_prefix = f"img_{processed_count:03d}_auth_{true_authenticity:.2f}_pred_{original_score:.2f}"
        print(f"  Visualizing and saving results with prefix: {filename_prefix}...")
        
        # Pass the model_output_dir for this specific model run
        visualize_and_save_saliency(
            image_tensor=image_tensor.cpu(), # Ensure tensor is on CPU for visualization
            saliency_map=saliency_map_np,
            output_dir=model_output_dir, # Pass the specific dir for this model
            filename_prefix=filename_prefix,
            overlay_alpha=vis_alpha,
            cmap_name=vis_cmap
        )
        print(f"  Visualization saved for {filename_prefix}.")

        processed_count += 1
        test_iterator.set_postfix_str(f"Image {processed_count}/{num_images_to_process} Done")

    test_iterator.close()
    print(f"\n--- Saliency Analysis for {model_name} Finished ---")



   # Main Execution Block

In [7]:
# --- Configuration ---
NUM_IMAGES_TO_PROCESS = 5 # Reduced for quicker testing, adjust as needed
SIGMA_LIST = [3, 5, 9, 17, 33, 65] # Fixed list
MASK_VALUE = 0.0 
VIS_CMAP = 'bwr' 
VIS_ALPHA = 0.6 
PIXEL_BATCH_SIZE = 128


# Define a main output directory
MAIN_OUTPUT_DIR = '5_imgs_masking_experiment_outputs'
os.makedirs(MAIN_OUTPUT_DIR, exist_ok=True)


# Run for Baseline Model
run_saliency_analysis(
    model=baseline_model,
    dataloader=test_dataloader,
    output_dir=MAIN_OUTPUT_DIR, # Pass the main output directory
    num_images_to_process=NUM_IMAGES_TO_PROCESS,
    sigma_list=SIGMA_LIST,
    pixel_batch_size=PIXEL_BATCH_SIZE, 
    mask_value=MASK_VALUE,
    vis_cmap=VIS_CMAP,
    vis_alpha=VIS_ALPHA,
    device=device,
    model_name="Baseline_Model" 
)

# Run for Pruned Model
run_saliency_analysis(
    model=pruned_model,
    dataloader=test_dataloader,
    output_dir=MAIN_OUTPUT_DIR, # Pass the main output directory
    num_images_to_process=NUM_IMAGES_TO_PROCESS,
    sigma_list=SIGMA_LIST,
    pixel_batch_size=PIXEL_BATCH_SIZE,
    mask_value=MASK_VALUE,
    vis_cmap=VIS_CMAP,
    vis_alpha=VIS_ALPHA,
    device=device,
    model_name="Pruned_Model" 
)

print(f"\n--- All Saliency Analyses Completed. Outputs in '{MAIN_OUTPUT_DIR}' ---")



--- Starting Saliency Analysis for Baseline_Model ---
Output visualizations will be saved in subfolders within: 5_imgs_masking_experiment_outputs/Baseline_Model
Processing up to 5 images.
Using sigmas: [3, 5, 9, 17, 33, 65], Pixel batch size: 128


Baseline_Model Image Progress:   0%|          | 0/5 [00:00<?, ?it/s]


Processing image 1/5 (DataLoader index: 0) for Baseline_Model


  True Authenticity: 48.4945
  Baseline_Model Predicted Authenticity (Original): 46.9550
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 77.73s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.30s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.44s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.49s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.50s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.52s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_000_auth_48.49_pred_46.96...
  Visualization saved for img_000_auth_48.49_pred_46.96.

Processing image 2/5 (DataLoader index: 1) for Baseline_Model


  True Authenticity: 56.5779
  Baseline_Model Predicted Authenticity (Original): 34.3024
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


  cmap = cm.get_cmap(cmap_name)


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 78.53s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.55s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.54s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.59s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.59s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.56s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_001_auth_56.58_pred_34.30...
  Visualization saved for img_001_auth_56.58_pred_34.30.

Processing image 3/5 (DataLoader index: 2) for Baseline_Model
  True Authenticity: 51.3419
  Baseline_Model Predicted Authenticity (Original): 46.8908
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 78.56s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.48s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.53s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.55s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.57s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.58s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_002_auth_51.34_pred_46.89...
  Visualization saved for img_002_auth_51.34_pred_46.89.

Processing image 4/5 (DataLoader index: 3) for Baseline_Model
  True Authenticity: 47.7000
  Baseline_Model Predicted Authenticity (Original): 44.5880
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 78.56s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.57s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.57s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.56s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.56s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.59s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_003_auth_47.70_pred_44.59...
  Visualization saved for img_003_auth_47.70_pred_44.59.

Processing image 5/5 (DataLoader index: 4) for Baseline_Model
  True Authenticity: 56.3482
  Baseline_Model Predicted Authenticity (Original): 69.4267
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 78.58s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.53s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.51s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.54s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.49s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.51s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_004_auth_56.35_pred_69.43...
  Visualization saved for img_004_auth_56.35_pred_69.43.

--- Saliency Analysis for Baseline_Model Finished ---

--- Starting Saliency Analysis for Pruned_Model ---
Output visualizations will be saved in subfolders within: 5_imgs_masking_experiment_outputs/Pruned_Model
Processing up to 5 images.
Using sigmas: [3, 5, 9, 17, 33, 65], Pixel batch size: 128


Pruned_Model Image Progress:   0%|          | 0/5 [00:00<?, ?it/s]


Processing image 1/5 (DataLoader index: 0) for Pruned_Model
  True Authenticity: 48.4945
  Pruned_Model Predicted Authenticity (Original): 46.9565
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 78.55s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.52s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.53s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.54s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.49s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.51s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_000_auth_48.49_pred_46.96...
  Visualization saved for img_000_auth_48.49_pred_46.96.

Processing image 2/5 (DataLoader index: 1) for Pruned_Model
  True Authenticity: 56.5779
  Pruned_Model Predicted Authenticity (Original): 33.6004
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 78.55s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.58s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.57s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.58s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.54s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.53s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_001_auth_56.58_pred_33.60...
  Visualization saved for img_001_auth_56.58_pred_33.60.

Processing image 3/5 (DataLoader index: 2) for Pruned_Model
  True Authenticity: 51.3419
  Pruned_Model Predicted Authenticity (Original): 47.1830
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 78.54s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.51s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.55s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.57s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.53s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.53s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_002_auth_51.34_pred_47.18...
  Visualization saved for img_002_auth_51.34_pred_47.18.

Processing image 4/5 (DataLoader index: 3) for Pruned_Model
  True Authenticity: 47.7000
  Pruned_Model Predicted Authenticity (Original): 46.3287
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 78.52s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.49s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.47s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.55s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.56s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.57s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_003_auth_47.70_pred_46.33...
  Visualization saved for img_003_auth_47.70_pred_46.33.

Processing image 5/5 (DataLoader index: 4) for Pruned_Model
  True Authenticity: 56.3482
  Pruned_Model Predicted Authenticity (Original): 68.3922
  Calculating saliency map...
Calculating saliency for image size torch.Size([224, 224]) using 6 sigmas: [3, 5, 9, 17, 33, 65] with pixel_batch_size: 128


Overall Sigmas :   0%|          | 0/6 [00:00<?, ?sigma/s]

  Sigma 1/6 (val=  3) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 3 finished processing in 78.53s.


  Sigma 2/6 (val=  5) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 5 finished processing in 78.49s.


  Sigma 3/6 (val=  9) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 9 finished processing in 78.49s.


  Sigma 4/6 (val= 17) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 17 finished processing in 78.53s.


  Sigma 5/6 (val= 33) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 33 finished processing in 78.50s.


  Sigma 6/6 (val= 65) Batches:   0%|          | 0/392 [00:00<?, ?batch/s]

  Sigma 65 finished processing in 78.52s.
  Saliency map calculated with shape: (224, 224)
  Visualizing and saving results with prefix: img_004_auth_56.35_pred_68.39...


  Visualization saved for img_004_auth_56.35_pred_68.39.

--- Saliency Analysis for Pruned_Model Finished ---

--- All Saliency Analyses Completed. Outputs in '5_imgs_masking_experiment_outputs' ---
