In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torchvision.models import ResNet18_Weights
import shap
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import json
import os
from tqdm import tqdm
import gc

# Configuration
class Config:
    DATA_ROOT = "/kaggle/input/real-vs-fake-faces/"
    REAL_IMAGES = DATA_ROOT + "real/"
    FAKE_IMAGES = DATA_ROOT + "fake/"
    OUTPUT_DIR = "outputs/"
    
    BATCH_SIZE = 16
    IMAGE_SIZE = 224
    NUM_EPOCHS = 10
    LEARNING_RATE = 3e-4
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    NUM_BACKGROUND = 50
    SHAP_SAMPLES = 25



# 1. Data Loading and Preparation
class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            with Image.open(img_path) as img:
                image = img.convert('RGB')
                if self.transform:
                    image = self.transform(image)
                label = self.labels[idx]
                return image, label, img_path
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            dummy_img = Image.new('RGB', (Config.IMAGE_SIZE, Config.IMAGE_SIZE), color='black')
            if self.transform:
                dummy_img = self.transform(dummy_img)
            return dummy_img, self.labels[idx], img_path

def prepare_data():
    transform = transforms.Compose([
        transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    real_images = [os.path.join(Config.REAL_IMAGES, f) for f in os.listdir(Config.REAL_IMAGES) if f.endswith(('.jpg', '.png'))]
    fake_images = [os.path.join(Config.FAKE_IMAGES, f) for f in os.listdir(Config.FAKE_IMAGES) if f.endswith(('.jpg', '.png'))]
    
    real_labels = [0] * len(real_images)
    fake_labels = [1] * len(fake_images)
    
    all_images = real_images + fake_images
    all_labels = real_labels + fake_labels
    
    indices = np.random.permutation(len(all_images))
    all_images = [all_images[i] for i in indices]
    all_labels = [all_labels[i] for i in indices]
    
    n_total = len(all_images)
    n_train = int(0.7 * n_total)
    n_val = int(0.15 * n_total)
    
    train_images = all_images[:n_train]
    train_labels = all_labels[:n_train]
    
    val_images = all_images[n_train:n_train+n_val]
    val_labels = all_labels[n_train:n_train+n_val]
    
    test_images = all_images[n_train+n_val:]
    test_labels = all_labels[n_train+n_val:]
    
    train_dataset = ImageDataset(train_images, train_labels, transform)
    val_dataset = ImageDataset(val_images, val_labels, transform)
    test_dataset = ImageDataset(test_images, test_labels, transform)
    
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, 
                             num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, 
                           num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, 
                            num_workers=2, pin_memory=True)
    
    return train_loader, val_loader, test_loader, test_dataset

# 2. Model Definition
class FakeImageDetector(nn.Module):
    def __init__(self):
        super(FakeImageDetector, self).__init__()
        self.resnet = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(inplace=False),
            nn.Dropout(0.3),
            nn.Linear(256, 2)
        )
        
        for module in self.resnet.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = False
        
    def forward(self, x):
        return self.resnet(x.clone())
    
    def get_activation_map(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x.clone())
        x = self.resnet.maxpool(x)
        
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        
        return x

# 3. Training Functions
def train_model(train_loader, val_loader):
    model = FakeImageDetector().to(Config.DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)
    
    best_val_acc = 0.0
    
    for epoch in range(Config.NUM_EPOCHS):
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        
        for images, labels, _ in tqdm(train_loader):
            images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            del images, labels, outputs, loss
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        train_acc = 100.0 * correct / total
        train_loss = train_loss / len(train_loader)
        
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels, _ in val_loader:
                images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
                
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                
                del images, labels, outputs, loss
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        val_acc = 100.0 * correct / total
        val_loss = val_loss / len(val_loader)
        
        print(f'Epoch [{epoch+1}/{Config.NUM_EPOCHS}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(Config.OUTPUT_DIR, 'best_model.pth'))
        
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    return model

# 4. SHAP Analysis
def explain_model_with_shap(model, test_dataset):
    model.eval()
    
    for module in model.modules():
        if isinstance(module, nn.ReLU):
            module.inplace = False
    
    background = []
    background_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
    
    for i, (img, _, _) in enumerate(background_loader):
        if i >= Config.NUM_BACKGROUND:
            break
        background.append(img.squeeze(0).cpu().numpy())
    
    background = np.array(background)
    
    class ModelWrapper(nn.Module):
        def __init__(self, model):
            super(ModelWrapper, self).__init__()
            self.model = model
            
        def forward(self, x):
            x = x.clone()
            return self.model(x)
    
    wrapped_model = ModelWrapper(model).to(Config.DEVICE)
    
    e = shap.GradientExplainer(wrapped_model, torch.tensor(background).to(Config.DEVICE))
    
    test_images = []
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    image_paths = []
    true_labels = []
    
    for i, (img, label, path) in enumerate(test_loader):
        if i >= Config.SHAP_SAMPLES:
            break
        test_images.append(img.squeeze(0).cpu().numpy())
        image_paths.append(path)
        true_labels.append(label.item())
    
    test_images = np.array(test_images)
    
    batch_size = 5
    num_batches = (len(test_images) + batch_size - 1) // batch_size
    all_shap_values = [[] for _ in range(2)]
    
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, len(test_images))
        
        batch_images = test_images[start_idx:end_idx]
        batch_tensor = torch.tensor(batch_images, requires_grad=True).to(Config.DEVICE)
        
        try:
            batch_shap_values = e.shap_values(batch_tensor, nsamples=50)
            for i in range(2):
                all_shap_values[i].extend(batch_shap_values[i])  # Already NumPy arrays
        except Exception as ex:
            print(f"Error in batch {batch_idx}: {ex}")
            for i in range(end_idx - start_idx):
                for j in range(2):
                    zeros = np.zeros_like(batch_images[0])
                    all_shap_values[j].append(zeros)
        
        del batch_tensor
        if 'batch_shap_values' in locals():
            del batch_shap_values
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        gc.collect()
    
    all_shap_values = [np.array(vals) for vals in all_shap_values]
    
    return all_shap_values, test_images, image_paths, true_labels

# 5. Description Generator
def generate_descriptions(model, shap_values, test_images, image_paths, true_labels):
    descriptions = []
    
    model.eval()
    
    batch_size = 8
    num_batches = (len(test_images) + batch_size - 1) // batch_size
    all_predictions = []
    
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, len(test_images))
        
        batch_images = test_images[start_idx:end_idx]
        batch_tensor = torch.tensor(batch_images).to(Config.DEVICE)
        
        with torch.no_grad():
            outputs = model(batch_tensor)
            _, predicted = outputs.max(1)
            all_predictions.extend(predicted.cpu().numpy())
        
        del batch_tensor, outputs
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    predicted = np.array(all_predictions)
    
    shap_values_class0 = shap_values[0]
    shap_values_class1 = shap_values[1]
    
    for i in range(len(test_images)):
        img_path = image_paths[i][0]
        true_label = true_labels[i]
        pred_label = predicted[i]
        
        result = {
            "image_path": img_path,
            "true_label": "Real" if true_label == 0 else "Fake",
            "predicted_label": "Real" if pred_label == 0 else "Fake", 
            "correct_prediction": bool(true_label == pred_label),
        }
        
        if pred_label == 1:
            shap_vals = shap_values_class1[i]
            
            attribution = np.sum(np.abs(shap_vals), axis=0)
            
            height, width = attribution.shape
            h_quadrants = [
                (0, height//2),
                (height//2, height)
            ]
            w_quadrants = [
                (0, width//2),
                (width//2, width)
            ]
            
            region_scores = {}
            for h_idx, (h_start, h_end) in enumerate(h_quadrants):
                for w_idx, (w_start, w_end) in enumerate(w_quadrants):
                    region_name = f"quadrant_{h_idx}_{w_idx}"
                    region_scores[region_name] = np.sum(attribution[h_start:h_end, w_start:w_end])
            
            sorted_regions = sorted(region_scores.items(), key=lambda x: x[1], reverse=True)
            top_regions = [r[0] for r in sorted_regions[:2]]
            
            artifacts = []
            
            if "quadrant_0_0" in top_regions:
                artifacts.append("unusual texture patterns in the top-left region")
            if "quadrant_0_1" in top_regions:
                artifacts.append("inconsistent lighting in the top-right region")
            if "quadrant_1_0" in top_regions:
                artifacts.append("edge artifacts in the bottom-left region")
            if "quadrant_1_1" in top_regions:
                artifacts.append("unnatural color distribution in the bottom-right region")
            
            if not artifacts:
                artifacts = ["subtle visual inconsistencies throughout the image"]
            
            description = f"This image appears to be AI-generated due to {', '.join(artifacts[:-1])}"
            if len(artifacts) > 1:
                description += f", and {artifacts[-1]}"
            else:
                description = f"This image appears to be AI-generated due to {artifacts[0]}"
            
            result["description"] = description
            result["key_artifacts"] = artifacts
            
        else:
            description = "This appears to be an authentic photograph with natural lighting, consistent textures, and realistic details."
            result["description"] = description
        
        descriptions.append(result)
        
        if i % 10 == 0:
            gc.collect()
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    return descriptions

# 6. Visualization Functions
def visualize_shap_results(image_idx, shap_values, test_images, descriptions):
    test_image = test_images[image_idx]
    
    description = descriptions[image_idx]
    
    if description["predicted_label"] == "Fake":
        shap_vals = shap_values[1][image_idx]
    else:
        shap_vals = shap_values[0][image_idx]
    
    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    
    img_display = np.transpose(test_image, (1, 2, 0))
    img_display = (img_display - img_display.min()) / (img_display.max() - img_display.min() + 1e-8)
    axes[0].imshow(img_display)
    axes[0].set_title(f"Original Image\nTrue: {description['true_label']}, Pred: {description['predicted_label']}")
    axes[0].axis('off')
    
    shap_visualization = np.sum(np.abs(shap_vals), axis=0)
    shap_visualization = shap_visualization / (shap_visualization.max() + 1e-8)
    
    axes[1].imshow(shap_visualization, cmap='hot')
    axes[1].set_title("SHAP Attribution Map")
    axes[1].axis('off')
    
    plt.tight_layout()
    
    plt.figtext(0.5, 0.01, description["description"], ha="center", fontsize=12, 
                bbox={"facecolor":"white", "alpha":0.5, "pad":5})
    
    plt.savefig(os.path.join(Config.OUTPUT_DIR, f"shap_vis_{image_idx}.png"), dpi=100)
    plt.close()
    
    gc.collect()

# 7. Main Function
def main():
    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    
    train_loader, val_loader, test_loader, test_dataset = prepare_data()
    
    model = FakeImageDetector().to(Config.DEVICE)
    
    if os.path.exists(os.path.join(Config.OUTPUT_DIR, 'best_model.pth')):
        state_dict = torch.load(os.path.join(Config.OUTPUT_DIR, 'best_model.pth'), 
                               map_location=Config.DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        print("Loaded pre-trained model.")
    else:
        model = train_model(train_loader, val_loader)
    
    for module in model.modules():
        if isinstance(module, nn.ReLU):
            module.inplace = False
    
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels, _ in test_loader:
            images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            del images, labels, outputs
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    test_acc = 100.0 * correct / total
    print(f'Test Accuracy: {test_acc:.2f}%')
    
    cm = confusion_matrix(all_labels, all_preds)
    print("Confusion Matrix:")
    print(cm)
    
    print("Generating SHAP explanations...")
    shap_values, test_images, image_paths, true_labels = explain_model_with_shap(model, test_dataset)
    
    print("Generating descriptions based on SHAP values...")
    descriptions = generate_descriptions(model, shap_values, test_images, image_paths, true_labels)
    
    with open(os.path.join(Config.OUTPUT_DIR, 'descriptions.json'), 'w') as f:
        json.dump(descriptions, f, indent=4)
    
    num_vis = min(5, len(test_images))
    for i in range(num_vis):
        print(f"Visualizing results for image {i+1}/{num_vis}...")
        visualize_shap_results(i, shap_values, test_images, descriptions)
    
    print("Done! Results saved to:", Config.OUTPUT_DIR)

if __name__ == "__main__":
    main()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 199MB/s]
100%|██████████| 90/90 [00:18<00:00,  4.76it/s]


Epoch [1/10], Train Loss: 0.6483, Train Acc: 64.29%, Val Loss: 0.6108, Val Acc: 66.01%


100%|██████████| 90/90 [00:09<00:00,  9.40it/s]


Epoch [2/10], Train Loss: 0.5009, Train Acc: 76.40%, Val Loss: 0.6260, Val Acc: 68.63%


100%|██████████| 90/90 [00:09<00:00,  9.04it/s]


Epoch [3/10], Train Loss: 0.3770, Train Acc: 83.26%, Val Loss: 0.7362, Val Acc: 64.38%


100%|██████████| 90/90 [00:09<00:00,  9.20it/s]


Epoch [4/10], Train Loss: 0.2967, Train Acc: 87.61%, Val Loss: 1.2123, Val Acc: 60.46%


100%|██████████| 90/90 [00:09<00:00,  9.25it/s]


Epoch [5/10], Train Loss: 0.2120, Train Acc: 92.72%, Val Loss: 0.8776, Val Acc: 67.65%


100%|██████████| 90/90 [00:09<00:00,  9.48it/s]


Epoch [6/10], Train Loss: 0.1559, Train Acc: 94.40%, Val Loss: 1.2689, Val Acc: 64.05%


100%|██████████| 90/90 [00:09<00:00,  9.49it/s]


Epoch [7/10], Train Loss: 0.1925, Train Acc: 92.93%, Val Loss: 1.0836, Val Acc: 61.76%


100%|██████████| 90/90 [00:09<00:00,  9.06it/s]


Epoch [8/10], Train Loss: 0.1334, Train Acc: 95.24%, Val Loss: 1.1398, Val Acc: 65.69%


100%|██████████| 90/90 [00:09<00:00,  9.45it/s]


Epoch [9/10], Train Loss: 0.0972, Train Acc: 97.90%, Val Loss: 1.3539, Val Acc: 60.46%


100%|██████████| 90/90 [00:09<00:00,  9.62it/s]


Epoch [10/10], Train Loss: 0.1372, Train Acc: 95.24%, Val Loss: 0.9830, Val Acc: 66.01%
Test Accuracy: 65.80%
Confusion Matrix:
[[123  38]
 [ 67  79]]
Generating SHAP explanations...
Generating descriptions based on SHAP values...
Visualizing results for image 1/5...
Visualizing results for image 2/5...
Visualizing results for image 3/5...
Visualizing results for image 4/5...
Visualizing results for image 5/5...
Done! Results saved to: outputs/


In [2]:
import os
import json
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import shap
import gc
import base64
from io import BytesIO
from tqdm import tqdm
from torchvision import models, transforms
from torchvision.models import ResNet18_Weights
from transformers import AutoProcessor, BlipForConditionalGeneration, CLIPModel, CLIPProcessor

# Configuration
class Config:
    DATA_ROOT = "/kaggle/input/real-vs-fake-faces/"
    REAL_IMAGES = os.path.join(DATA_ROOT, "real/")
    FAKE_IMAGES = os.path.join(DATA_ROOT, "fake/")
    OUTPUT_DIR = "outputs/"
    IMAGE_SIZE = 224
    BATCH_SIZE = 16
    NUM_EPOCHS = 5
    LEARNING_RATE = 3e-4
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_BACKGROUND = 50
    SHAP_SAMPLES = 25

# FakeImageDetector
class FakeImageDetector(nn.Module):
    def __init__(self):
        super(FakeImageDetector, self).__init__()
        self.resnet = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(inplace=False),
            nn.Dropout(0.3),
            nn.Linear(256, 2)
        )
        
        # Ensure all ReLU operations are not in-place to work with SHAP
        for module in self.resnet.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = False
        
    def forward(self, x):
        return self.resnet(x.clone())
    
    def get_activation_map(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x.clone())
        x = self.resnet.maxpool(x)
        
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        
        return x

# VLM Description Generator
class VLMDescriptionGenerator:
    def __init__(self, fake_detector_path=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Create output directory if it doesn't exist
        if not os.path.exists(Config.OUTPUT_DIR):
            os.makedirs(Config.OUTPUT_DIR)
        
        # Create image and heatmap directories
        self.images_dir = os.path.join(Config.OUTPUT_DIR, "processed_images")
        self.heatmaps_dir = os.path.join(Config.OUTPUT_DIR, "heatmaps")
        self.overlays_dir = os.path.join(Config.OUTPUT_DIR, "overlays")
        
        os.makedirs(self.images_dir, exist_ok=True)
        os.makedirs(self.heatmaps_dir, exist_ok=True)
        os.makedirs(self.overlays_dir, exist_ok=True)
        
        # Load the fake detector model
        if fake_detector_path is None:
            fake_detector_path = os.path.join(Config.OUTPUT_DIR, 'best_model.pth')
        
        # Check if the model file exists before loading
        if not os.path.exists(fake_detector_path):
            raise FileNotFoundError(f"Fake detector model not found at {fake_detector_path}")
        
        self.fake_detector = FakeImageDetector().to(self.device)
        self.fake_detector.load_state_dict(torch.load(fake_detector_path, map_location=self.device))
        self.fake_detector.eval()
        
        # Load CLIP for feature extraction
        print("Loading CLIP model...")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model.eval()
        
        # Initialize DataFrame for results
        self.results_df = pd.DataFrame(columns=[
            'image_filename',
            'image_path',
            'processed_image_path',
            'heatmap_path',
            'overlay_path',
            'is_fake',
            'confidence',
            'description',
            'manual_description',  # This will be empty initially
            'source_label'         # Whether it came from real or fake folder
        ])
        
        # Placeholders for VLM models
        self.vlm_processor = None
        self.vlm_model = None
        
        print("Models loaded successfully")
        
    def _load_vlm_model(self):
        """Load VLM model on demand to save memory"""
        if self.vlm_model is None:
            print("Loading VLM model...")
            model_name = "Salesforce/blip-image-captioning-base"
            self.vlm_processor = AutoProcessor.from_pretrained(model_name)
            self.vlm_model = BlipForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                low_cpu_mem_usage=True,
                device_map="auto" if self.device.type == 'cuda' else None
            )
            if self.device.type != 'cuda':
                self.vlm_model = self.vlm_model.to(self.device)
            self.vlm_model.eval()
            print("VLM model loaded")
        
    def _unload_vlm_model(self):
        """Unload VLM model to free memory"""
        if self.vlm_model is not None:
            del self.vlm_processor
            del self.vlm_model
            self.vlm_processor = None
            self.vlm_model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("VLM model unloaded")
    
    def process_image(self, image_path):
        """Process a single image and generate descriptions."""
        try:
            img = Image.open(image_path).convert('RGB')
            img_array = np.array(img)
            
            img_tensor = self.preprocess_for_fake_detector(img)
            
            with torch.no_grad():
                outputs = self.fake_detector(img_tensor.unsqueeze(0).to(self.device))
                _, predicted = outputs.max(1)
                is_fake = bool(predicted.item())
                confidence = torch.softmax(outputs, dim=1)[0][predicted.item()].item()
            
            del outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            # Using a simpler approach for feature importance instead of SHAP
            heatmap = self.get_feature_importance(img_tensor, img_array.shape[:2])
            description = self.generate_vlm_description(img, heatmap, is_fake, confidence)
            
            # Determine if this is from the real or fake directory
            if "real" in image_path.lower():
                source_label = "real"
            elif "fake" in image_path.lower():
                source_label = "fake"
            else:
                source_label = "unknown"
            
            # Save processed images
            image_filename = os.path.basename(image_path)
            base_name = os.path.splitext(image_filename)[0]
            
            # Save processed image
            processed_img_path = os.path.join(self.images_dir, f"{base_name}_processed.jpg")
            img.save(processed_img_path)
            
            # Save heatmap
            heatmap_path = os.path.join(self.heatmaps_dir, f"{base_name}_heatmap.jpg")
            plt.figure(figsize=(8, 8))
            plt.imshow(heatmap, cmap='jet')
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(heatmap_path)
            plt.close()
            
            # Create and save overlay
            heatmap_uint8 = (heatmap * 255).astype(np.uint8)
            heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
            heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
            heatmap_resized = cv2.resize(heatmap_colored, (img_array.shape[1], img_array.shape[0]))
            overlay = cv2.addWeighted(img_array, 0.7, heatmap_resized, 0.3, 0)
            
            overlay_path = os.path.join(self.overlays_dir, f"{base_name}_overlay.jpg")
            cv2.imwrite(overlay_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
            
            return {
                "image_filename": image_filename,
                "image_path": image_path,
                "processed_image_path": processed_img_path,
                "heatmap_path": heatmap_path,
                "overlay_path": overlay_path,
                "is_fake": is_fake,
                "confidence": confidence,
                "description": description,
                "manual_description": "",  # Empty initially
                "source_label": source_label,
                "heatmap": heatmap  # Keep heatmap for visualization if needed
            }
            
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return {
                "image_filename": os.path.basename(image_path),
                "image_path": image_path,
                "processed_image_path": "",
                "heatmap_path": "",
                "overlay_path": "",
                "is_fake": None,
                "confidence": 0.0,
                "description": f"Error processing image: {str(e)}",
                "manual_description": "",
                "source_label": "error",
                "heatmap": np.zeros((224, 224))
            }
    
    def preprocess_for_fake_detector(self, img):
        """Preprocess image for the fake detector model."""
        transform = transforms.Compose([
            transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        return transform(img)
    
    def get_feature_importance(self, img_tensor, original_shape):
        """Get feature importance using a gradient-based approach instead of SHAP."""
        try:
            img_tensor = img_tensor.to(self.device).unsqueeze(0)
            img_tensor.requires_grad_(True)
            
            # Forward pass with gradient
            self.fake_detector.eval()
            outputs = self.fake_detector(img_tensor)
            
            # Get the predicted class
            _, predicted = outputs.max(1)
            
            # Zero gradients
            self.fake_detector.zero_grad()
            
            # Backward pass for the predicted class
            one_hot = torch.zeros_like(outputs)
            one_hot[0, predicted] = 1
            outputs.backward(gradient=one_hot)
            
            # Get gradients
            gradients = img_tensor.grad.cpu().detach().numpy()[0]
            
            # Take absolute value and average over channels
            importance = np.mean(np.abs(gradients), axis=0)
            
            # Normalize importance scores
            importance_min, importance_max = importance.min(), importance.max()
            if importance_max > importance_min:
                importance = (importance - importance_min) / (importance_max - importance_min)
            
            # Resize to original shape
            importance_resized = cv2.resize(importance, (original_shape[1], original_shape[0]))
            
            return importance_resized
            
        except Exception as e:
            print(f"Error calculating feature importance: {e}")
            return np.zeros(original_shape)
    
    def generate_vlm_description(self, img, heatmap, is_fake, confidence):
        """Generate a description using a VLM, incorporating importance map insights."""
        try:
            self._load_vlm_model()
            
            img_array = np.array(img)
            # Create colored heatmap overlay
            heatmap_uint8 = (heatmap * 255).astype(np.uint8)
            heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
            heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
            
            # Resize heatmap to match image dimensions
            heatmap_colored = cv2.resize(heatmap_colored, (img_array.shape[1], img_array.shape[0]))
            
            # Create overlay image
            overlay = cv2.addWeighted(img_array, 0.7, heatmap_colored, 0.3, 0)
            overlay_img = Image.fromarray(overlay.astype('uint8'))
            
            # Prepare inputs for VLM
            inputs = self.vlm_processor(
                images=overlay_img,
                text=f"This image appears to be {'AI-generated' if is_fake else 'authentic'} with {confidence:.1%} confidence. Describe what makes it look {'artificial' if is_fake else 'real'}:",
                return_tensors="pt"
            ).to(self.device)
            
            # Generate description
            with torch.no_grad():
                outputs = self.vlm_model.generate(
                    **inputs,
                    max_new_tokens=100,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                )
            
            # Process and format description
            description = self.vlm_processor.decode(outputs[0], skip_special_tokens=True)
            prompt = f"This image appears to be {'AI-generated' if is_fake else 'authentic'} with {confidence:.1%} confidence. Describe what makes it look {'artificial' if is_fake else 'real'}:"
            if prompt in description:
                description = description.replace(prompt, "").strip()
            
            # Clean up to save memory
            del inputs, outputs, overlay, overlay_img
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            self._unload_vlm_model()
            
            return description
            
        except Exception as e:
            print(f"Error generating VLM description: {e}")
            self._unload_vlm_model()
            return f"Unable to generate description due to error: {str(e)}"
    
    def batch_process(self, image_paths, output_dir=None, batch_size=10):
        """Process a batch of images and save the results."""
        results = []
        
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        
        # Calculate total number of batches for progress tracking
        total_batches = (len(image_paths) - 1) // batch_size + 1
        
        for i in range(0, len(image_paths), batch_size):
            batch = image_paths[i:i+batch_size]
            batch_results = []
            
            for img_path in tqdm(batch, desc=f"Processing batch {i//batch_size + 1}/{total_batches}"):
                try:
                    result = self.process_image(img_path)
                    results.append(result)
                    batch_results.append(result)
                except Exception as e:
                    print(f"Error processing {img_path}: {e}")
                    # Continue with next image on error
                    continue
            
            # Add batch results to DataFrame
            batch_df = pd.DataFrame(batch_results)
            self.results_df = pd.concat([self.results_df, batch_df], ignore_index=True)
            
            # Save updated CSV after each batch
            if output_dir:
                csv_path = os.path.join(output_dir, "fake_detection_results.csv")
                # Create a copy of the dataframe without the heatmap column for CSV saving
                save_df = self.results_df.drop(columns=['heatmap'], errors='ignore')
                save_df.to_csv(csv_path, index=False)
                print(f"Updated CSV saved to {csv_path}")
            
            # Clear memory between batches
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        
        # Save final results
        if output_dir:
            # Create HTML file that shows the images, heatmaps, and descriptions
            self.create_html_report(output_dir)
        
        return results

    def create_html_report(self, output_dir):
        """Create an HTML report with all the images, heatmaps, and descriptions."""
        html_path = os.path.join(output_dir, "detection_results.html")
        
        html_content = """
        <!DOCTYPE html>
        <html>
        <head>
            <title>Fake Image Detection Results</title>
            <style>
                body { font-family: Arial, sans-serif; margin: 20px; }
                .entry { border: 1px solid #ddd; margin: 20px 0; padding: 15px; border-radius: 5px; }
                .images { display: flex; flex-wrap: wrap; gap: 10px; }
                .image-container { margin: 10px; }
                img { max-width: 300px; max-height: 300px; }
                .metadata { margin-top: 15px; }
                .fake { background-color: #ffdddd; }
                .real { background-color: #ddffdd; }
                .error { background-color: #dddddd; }
                h2 { margin-top: 0; }
                table { border-collapse: collapse; width: 100%; }
                table, th, td { border: 1px solid #ddd; }
                th, td { padding: 8px; text-align: left; }
                tr:nth-child(even) { background-color: #f2f2f2; }
            </style>
        </head>
        <body>
            <h1>Fake Image Detection Results</h1>
        """
        
        # Add summary statistics
        real_count = len(self.results_df[self.results_df['is_fake'] == False])
        fake_count = len(self.results_df[self.results_df['is_fake'] == True])
        error_count = len(self.results_df[self.results_df['is_fake'].isna()])
        
        html_content += f"""
        <div class="summary">
            <h2>Summary</h2>
            <p>Total images: {len(self.results_df)}</p>
            <p>Detected as real: {real_count}</p>
            <p>Detected as fake: {fake_count}</p>
            <p>Processing errors: {error_count}</p>
        </div>
        """
        
        # Add individual entries
        for _, row in self.results_df.iterrows():
            # Determine the CSS class based on the fake detection result
            if row['is_fake'] is None:
                entry_class = "error"
                status = "Error"
            elif row['is_fake']:
                entry_class = "fake"
                status = "Fake"
            else:
                entry_class = "real"
                status = "Real"
            
            confidence = row['confidence'] * 100 if row['confidence'] is not None else 0
            
            html_content += f"""
            <div class="entry {entry_class}">
                <h2>{row['image_filename']} - {status} ({confidence:.1f}%)</h2>
                <div class="images">
                    <div class="image-container">
                        <h3>Original</h3>
                        <img src="{row['processed_image_path']}" alt="Original Image">
                    </div>
                    <div class="image-container">
                        <h3>Heatmap</h3>
                        <img src="{row['heatmap_path']}" alt="Heatmap">
                    </div>
                    <div class="image-container">
                        <h3>Overlay</h3>
                        <img src="{row['overlay_path']}" alt="Overlay">
                    </div>
                </div>
                <div class="metadata">
                    <h3>Detection Details</h3>
                    <p><strong>Source Label:</strong> {row['source_label']}</p>
                    <p><strong>Description:</strong> {row['description']}</p>
                    <p><strong>Manual Description:</strong> {row['manual_description'] if row['manual_description'] else "N/A"}</p>
                </div>
            </div>
            """
        
        html_content += """
        </body>
        </html>
        """
        
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(html_content)
        
        print(f"HTML report saved to {html_path}")

    def export_to_excel(self, output_path):
        """Export results to an Excel file with embedded images."""
        from openpyxl import Workbook
        from openpyxl.drawing.image import Image as XLImage
        from openpyxl.utils import get_column_letter
        from PIL import Image as PILImage
        import io
        
        print("Exporting to Excel...")
        wb = Workbook()
        ws = wb.active
        ws.title = "Detection Results"
        
        # Write headers
        headers = [
            'Image ID', 'Is Fake', 'Confidence', 'Source Label', 
            'Description', 'Manual Description'
        ]
        for col_num, header in enumerate(headers, 1):
            ws.cell(row=1, column=col_num).value = header
        
        # Format header row
        for col_num in range(1, len(headers) + 1):
            cell = ws.cell(row=1, column=col_num)
            cell.font = cell.font.copy(bold=True)
        
        # Add data and images
        for idx, row in enumerate(self.results_df.iterrows(), 2):
            row_data = row[1]
            
            # Add text data
            ws.cell(row=idx, column=1).value = row_data['image_filename']
            ws.cell(row=idx, column=2).value = str(row_data['is_fake'])
            ws.cell(row=idx, column=3).value = f"{row_data['confidence']:.2%}" if row_data['confidence'] is not None else "N/A"
            ws.cell(row=idx, column=4).value = row_data['source_label']
            ws.cell(row=idx, column=5).value = row_data['description']
            ws.cell(row=idx, column=6).value = row_data['manual_description']
            
        # Adjust column widths
        for col_num in range(1, len(headers) + 1):
            column = get_column_letter(col_num)
            ws.column_dimensions[column].width = 30 if col_num >= 5 else 15
        
        # Save workbook
        wb.save(output_path)
        print(f"Excel report saved to {output_path}")
    
    def add_manual_description(self, image_filename, manual_description):
        """Add a manual description for an image."""
        idx = self.results_df[self.results_df['image_filename'] == image_filename].index
        if len(idx) > 0:
            self.results_df.loc[idx[0], 'manual_description'] = manual_description
            return True
        return False

    def visualize_result(self, result, save_path=None):
        """Visualize the result with the original image and heatmap overlay."""
        try:
            # Skip visualization if no valid result
            if result["is_fake"] is None:
                return None
                
            img = Image.open(result["image_path"]).convert('RGB')
            img_array = np.array(img)
            
            fig, axs = plt.subplots(1, 3, figsize=(18, 6))
            
            # Original image
            axs[0].imshow(img_array)
            axs[0].set_title("Original Image")
            axs[0].axis('off')
            
            # Heatmap
            heatmap_img = axs[1].imshow(result["heatmap"], cmap='jet')
            axs[1].set_title("Detection Heatmap")
            axs[1].axis('off')
            fig.colorbar(heatmap_img, ax=axs[1], fraction=0.046, pad=0.04)
            
            # Overlay
            heatmap_uint8 = (result["heatmap"] * 255).astype(np.uint8)
            heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
            heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
            heatmap_resized = cv2.resize(heatmap_colored, (img_array.shape[1], img_array.shape[0]))
            overlay = cv2.addWeighted(img_array, 0.7, heatmap_resized, 0.3, 0)
            
            axs[2].imshow(overlay)
            axs[2].set_title(f"Overlay - {'Fake' if result['is_fake'] else 'Real'} ({result['confidence']:.2%})")
            axs[2].axis('off')
            
            # Add description text - wrap text to avoid overflow
            plt.figtext(0.5, 0.01, f"Description: {result['description']}", 
                       wrap=True, horizontalalignment='center', fontsize=10)
            
            plt.tight_layout()
            
            if save_path:
                plt.savefig(save_path, bbox_inches='tight', dpi=150)
                
            return fig
        
        except Exception as e:
            print(f"Error visualizing result: {e}")
            return None

def main():
    """Main function adapted for notebook environment."""
    # Define input and output paths
    input_path = Config.DATA_ROOT  # Base dataset directory
    output_dir = Config.OUTPUT_DIR
    batch_size = 5  # Smaller batch size to avoid memory issues
    detector_path = os.path.join(Config.OUTPUT_DIR, 'best_model.pth')
    
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Check if model exists before proceeding
    if not os.path.exists(detector_path):
        print(f"Model not found at {detector_path}. Please train the model first.")
        return
    
    try:
        # Initialize the generator
        generator = VLMDescriptionGenerator(fake_detector_path=detector_path)
        
        # Get list of images recursively
        image_exts = ['.jpg', '.jpeg', '.png', '.bmp']
        image_paths = []
        if os.path.isdir(input_path):
            for root, _, files in os.walk(input_path):
                for f in files:
                    if os.path.splitext(f.lower())[1] in image_exts:
                        image_paths.append(os.path.join(root, f))
        else:
            if os.path.splitext(input_path.lower())[1] in image_exts:
                image_paths = [input_path]
        
        print(f"Found {len(image_paths)} images")
        
        if not image_paths:
            print("No images found. Please check the input path.")
            return
        
        # Process images with error handling
        try:
            results = generator.batch_process(image_paths, output_dir, batch_size)
            
            # Save results as Excel file
            excel_path = os.path.join(output_dir, "fake_detection_results.xlsx")
            generator.export_to_excel(excel_path)
            
            print(f"Processed {len(results)} images. Results saved to:")
            print(f"- CSV: {os.path.join(output_dir, 'fake_detection_results.csv')}")
            print(f"- Excel: {excel_path}")
            print(f"- HTML Report: {os.path.join(output_dir, 'detection_results.html')}")
            
        except Exception as e:
            print(f"Error in batch processing: {e}")
            
    except Exception as e:
        print(f"Error in main function: {e}")

if __name__ == "__main__":
    main()

2025-04-15 03:33:44.781008: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744688024.950881      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744688025.005340      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the funct

Loading CLIP model...


config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

Models loaded successfully
Found 2041 images


Processing batch 1/409:   0%|          | 0/5 [00:00<?, ?it/s]

Loading VLM model...


preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

VLM model loaded
VLM model unloaded


Processing batch 1/409:  20%|██        | 1/5 [00:08<00:34,  8.71s/it]

Loading VLM model...
VLM model loaded


Processing batch 1/409:  40%|████      | 2/5 [00:11<00:16,  5.41s/it]

VLM model unloaded
Loading VLM model...
VLM model loaded


Processing batch 1/409:  60%|██████    | 3/5 [00:14<00:08,  4.29s/it]

VLM model unloaded
Loading VLM model...
VLM model loaded


Processing batch 1/409:  80%|████████  | 4/5 [00:17<00:03,  3.69s/it]

VLM model unloaded
Loading VLM model...
VLM model loaded
VLM model unloaded


Processing batch 1/409: 100%|██████████| 5/5 [00:20<00:00,  4.05s/it]
The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.


Updated CSV saved to outputs/fake_detection_results.csv


Processing batch 2/409:   0%|          | 0/5 [00:00<?, ?it/s]

Loading VLM model...


Processing batch 2/409:   0%|          | 0/5 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [3]:
!zip -r /kaggle/working/output.zip /kaggle/working/

  adding: kaggle/working/ (stored 0%)
  adding: kaggle/working/outputs/ (stored 0%)
  adding: kaggle/working/outputs/fake_detection_results.csv (deflated 76%)
  adding: kaggle/working/outputs/shap_vis_1.png (deflated 0%)
  adding: kaggle/working/outputs/shap_vis_3.png (deflated 0%)
  adding: kaggle/working/outputs/shap_vis_0.png (deflated 0%)
  adding: kaggle/working/outputs/overlays/ (stored 0%)
  adding: kaggle/working/outputs/overlays/mid_233_1111_overlay.jpg (deflated 1%)
  adding: kaggle/working/outputs/overlays/hard_32_1111_overlay.jpg (deflated 0%)
  adding: kaggle/working/outputs/overlays/mid_345_1111_overlay.jpg (deflated 1%)
  adding: kaggle/working/outputs/overlays/mid_161_0110_overlay.jpg (deflated 0%)
  adding: kaggle/working/outputs/overlays/mid_200_1111_overlay.jpg (deflated 0%)
  adding: kaggle/working/outputs/descriptions.json (deflated 94%)
  adding: kaggle/working/outputs/processed_images/ (stored 0%)
  adding: kaggle/working/outputs/processed_images/mid_233_1111_pro

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


 (deflated 0%)
  adding: kaggle/working/outputs/best_model.pth (deflated 7%)
  adding: kaggle/working/outputs/heatmaps/ (stored 0%)
  adding: kaggle/working/outputs/heatmaps/mid_233_1111_heatmap.jpg (deflated 13%)
  adding: kaggle/working/outputs/heatmaps/hard_32_1111_heatmap.jpg (deflated 13%)
  adding: kaggle/working/outputs/heatmaps/mid_161_0110_heatmap.jpg (deflated 10%)
  adding: kaggle/working/outputs/heatmaps/mid_200_1111_heatmap.jpg (deflated 12%)
  adding: kaggle/working/outputs/heatmaps/mid_345_1111_heatmap.jpg (deflated 11%)
  adding: kaggle/working/.virtual_documents/ (stored 0%)


## 🔍 SHAP Analysis on Aggregated Video Frames

In [None]:
import shap
import matplotlib.pyplot as plt
import torch

# Load model and processor
model.eval()
model.cpu()  # SHAP works better on CPU

# Pick a few samples from validation set
sample_frames = []
sample_labels = []
for i in range(5):  # Pick 5 random samples
    img_tensor, label = val_dataset[i]
    sample_frames.append(img_tensor.numpy())
    sample_labels.append(label)

sample_frames = torch.tensor(sample_frames)

# Define prediction function
def predict(images):
    with torch.no_grad():
        images = torch.tensor(images).to(torch.float32)
        outputs = model(pixel_values=images)
        return outputs.logits.numpy()

# Create SHAP explainer
masker = shap.maskers.Image("inpaint_telea", sample_frames[0].shape)
explainer = shap.Explainer(predict, masker, output_names=["original", "manipulated"])

# Compute SHAP values
shap_values = explainer(sample_frames)

# Visualize
shap.image_plot(shap_values, sample_frames)
