# Alignment Score

In [None]:
import numpy as np
# Load the real-video defocus maps and the SHAP attributions of both models.
# np.load on an .npz returns an NpzFile that keeps the archive open; the
# context manager closes it once the arrays have been read into memory.
# NOTE(review): allow_pickle=True can execute code stored in the file - only
# load archives you produced yourself.
with np.load("../Shapley Save/shap_data_real.npz", allow_pickle=True) as data:
    real_dmaps = data["dmaps"]
    real_defocus_shaps = data["defocus_shaps"]
    real_rgb_shaps = data["rgb_shaps"]


In [None]:
# Load one fake dataset at a time: change the dataset name in this path (and in
# the save_dir names of the comparison cells below) for each manipulation
# method, then rerun the cells to compute its alignment scores.
# The context manager closes the .npz archive once the arrays are loaded.
# NOTE(review): allow_pickle=True can execute code stored in the file - only
# load archives you produced yourself.
with np.load("../Shapley Save/shap_data_fake_Deepfakes.npz", allow_pickle=True) as data:
    fake_dmaps = data["dmaps"]
    fake_defocus_shaps = data["defocus_shaps"]
    fake_rgb_shaps = data["rgb_shaps"]

In [3]:
# Display the (num_samples, height, width) layout of the real defocus maps.
np.shape(real_dmaps)

(1000, 299, 299)

In [None]:
# Display the layout of the fake defocus maps; it should match the real maps.
np.shape(fake_dmaps)

# Pixel Difference with SHAP

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy.stats
from brokenaxes import brokenaxes

def _alignment_histograms(fake_map, real_map, shap_map, bins, *, verbose=False, eps=1e-8):
    """Bin one sample and return ``(diff_hist, shap_hist, kl_div, overlap)``.

    Both histograms are binned on the fake map's values and weighted by
    ``|fake - real|`` and ``|shap|`` respectively, then L1-normalised.
    Inputs are expected to be float32 arrays with equal element counts.
    """
    fake = fake_map.ravel()
    real = real_map.ravel()
    # Only attribution magnitude matters; np.abs already makes every weight >= 0.
    shap = np.abs(shap_map.ravel())
    pixel_diff = np.abs(fake - real)

    hist_diff, _ = np.histogram(fake, bins=bins, weights=pixel_diff)
    hist_shap, _ = np.histogram(fake, bins=bins, weights=shap)
    if verbose:
        print(hist_diff)

    # L1 normalisation turns each histogram into a distribution; eps guards
    # against an all-zero histogram (e.g. identical real and fake maps).
    hist_diff_norm = hist_diff / (hist_diff.sum() + eps)
    hist_shap_norm = hist_shap / (hist_shap.sum() + eps)

    # KL(shap || diff); eps avoids log(0) / division by zero on empty bins.
    kl_div = scipy.stats.entropy(hist_shap_norm + eps, hist_diff_norm + eps)
    # Histogram intersection: 1.0 for identical non-empty distributions,
    # 0.0 when they share no bins.
    overlap = np.minimum(hist_shap_norm, hist_diff_norm).sum()
    return hist_diff_norm, hist_shap_norm, kl_div, overlap


def _plot_alignment(i, bin_centers, bin_width, hist_diff_norm, hist_shap_norm,
                    kl_div, overlap, save_dir, xlims, ylims):
    """Save a broken-y-axis bar chart comparing the two normalised histograms."""
    fig = plt.figure(figsize=(5, 2.5))
    bax = brokenaxes(xlims=xlims, ylims=ylims, hspace=0.25)

    # 0.8 / 0.4 of a bin reproduce the former 0.04 / 0.02 widths at 20 bins
    # and keep the overlaid bars proportional for any other bin count.
    bax.bar(bin_centers, hist_diff_norm, width=0.8 * bin_width, alpha=0.6, color='navy', label='Pixel Diff')
    bax.bar(bin_centers, hist_shap_norm, width=0.4 * bin_width, alpha=0.8, color='gold', label='SHAP')

    bax.set_xlabel("Defocus Value", fontsize=10)
    bax.set_ylabel("Normalized Weight", fontsize=10)
    bax.legend(loc='upper right', fontsize=10)
    fig.suptitle(f"Sample {i:02d} | KL: {kl_div:.3f} | Alignment Score: {overlap:.3f}", fontsize=10)

    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, f"align_{i:02d}.png"), bbox_inches='tight', dpi=300)
    # Close this specific figure so long runs do not accumulate open figures.
    plt.close(fig)


def compare_shap_diff(fake_dmaps, real_dmaps, shap_list,
                      num_bins=20, save_dir="./align", *,
                      num_samples=50, plot=True, verbose=False,
                      xlims=((0, 0.6),), ylims=((0, 0.1), (0.4, 0.5))):
    """Measure how well SHAP attributions align with real/fake defocus differences.

    For each sample, the fake defocus map's values are binned into
    ``num_bins`` equal-width bins over [0, 1]. Two histograms are built on
    those bins: one weighted by the absolute pixel difference
    ``|fake - real|`` and one weighted by ``|shap|``. After L1
    normalisation they are compared with the KL divergence KL(shap || diff)
    and with their overlap (sum of element-wise minima), reported as the
    "Alignment Score".

    Parameters
    ----------
    fake_dmaps, real_dmaps : sequence of arrays
        Paired defocus maps; ``fake_dmaps[i]`` and ``real_dmaps[i]`` must
        have the same shape. Values outside [0, 1] fall outside the bins and
        are ignored by ``np.histogram``.
    shap_list : sequence of arrays
        SHAP attributions; ``shap_list[i]`` must have as many elements as
        ``fake_dmaps[i]``. Only magnitudes are used.
    num_bins : int
        Number of histogram bins over [0, 1].
    save_dir : str
        Directory receiving one ``align_XX.png`` per sample (only when
        ``plot`` is True).
    num_samples : int or None
        Maximum number of samples to evaluate, starting at index 0; clamped
        to ``len(fake_dmaps)``. ``None`` evaluates every sample.
    plot : bool
        Whether to render and save the comparison figures.
    verbose : bool
        Whether to print each sample's raw pixel-difference histogram.
    xlims, ylims : sequences of (low, high) pairs
        Axis ranges forwarded to ``brokenaxes``. Bars outside them (e.g.
        heights between 0.1 and 0.4 with the default ``ylims``) are not
        visible in the figures.

    Returns
    -------
    kl_scores, overlap_scores : list
        Per-sample KL divergences and alignment scores.

    Raises
    ------
    ValueError
        If ``real_dmaps`` / ``shap_list`` hold fewer samples than requested,
        or a sample's real/fake/SHAP maps do not match in size.
    """
    if num_samples is None:
        n = len(fake_dmaps)
    else:
        n = min(num_samples, len(fake_dmaps))
    if len(real_dmaps) < n or len(shap_list) < n:
        raise ValueError(
            f"Need {n} samples but got {len(real_dmaps)} real maps "
            f"and {len(shap_list)} SHAP maps")

    if plot:
        os.makedirs(save_dir, exist_ok=True)

    # Bin edges depend only on num_bins, so build them once.
    bins = np.linspace(0, 1, num_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_width = bins[1] - bins[0]

    kl_scores = []
    overlap_scores = []
    for i in range(n):
        fake_map = np.asarray(fake_dmaps[i], dtype=np.float32)
        real_map = np.asarray(real_dmaps[i], dtype=np.float32)
        shap_map = np.asarray(shap_list[i], dtype=np.float32)
        # Explicit checks instead of assert, which is stripped under -O.
        if fake_map.shape != real_map.shape:
            raise ValueError(
                f"Real/Fake map shape mismatch at sample {i}: "
                f"{fake_map.shape} vs {real_map.shape}")
        if shap_map.size != fake_map.size:
            raise ValueError(
                f"SHAP map size mismatch at sample {i}: "
                f"{shap_map.size} vs {fake_map.size}")

        hist_diff_norm, hist_shap_norm, kl_div, overlap = _alignment_histograms(
            fake_map, real_map, shap_map, bins, verbose=verbose)
        kl_scores.append(kl_div)
        overlap_scores.append(overlap)

        if plot:
            _plot_alignment(i, bin_centers, bin_width, hist_diff_norm, hist_shap_norm,
                            kl_div, overlap, save_dir, xlims, ylims)

    # Print the average KL and Alignment scores (skipped when nothing ran,
    # which would otherwise print NaN with a RuntimeWarning).
    if kl_scores:
        print(f"✅ Average KL Divergence: {np.mean(kl_scores):.4f} ± {np.std(kl_scores):.4f}")
        print(f"✅ Average Alignment Score: {np.mean(overlap_scores):.4f} ± {np.std(overlap_scores):.4f}")
    return kl_scores, overlap_scores


In [None]:
# Results from the Defocus model. They are stored under model-specific names
# because the RGB cell also assigns kl_scores / overlap_scores and would
# otherwise overwrite them; the generic names stay bound for later cells.
defocus_kl_scores, defocus_overlap_scores = compare_shap_diff(
    fake_dmaps=fake_dmaps,
    real_dmaps=real_dmaps,
    shap_list=fake_defocus_shaps,
    num_bins=20,
    save_dir="./compare_shap_diff_defocus_Deepfakes"
)
kl_scores, overlap_scores = defocus_kl_scores, defocus_overlap_scores


In [None]:
# Results from the RGB model, stored under model-specific names so they can be
# compared with the Defocus results; the generic names are kept bound as
# before for any later cells.
rgb_kl_scores, rgb_overlap_scores = compare_shap_diff(
    fake_dmaps=fake_dmaps,
    real_dmaps=real_dmaps,
    shap_list=fake_rgb_shaps,
    num_bins=20,
    save_dir="./compare_shap_diff_rgb_Deepfakes"
)
kl_scores, overlap_scores = rgb_kl_scores, rgb_overlap_scores
