In [1]:
import sys
sys.path.append("/Users/msrobin/GitHub Repositorys/Interpretable-Deep-Fake-Detection-2")
sys.argv = ["train.py"]

In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

from pytorch_grad_cam import GradCAM, XGradCAM, LayerCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from training.detectors.resnet34_detector import ResnetDetector

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  Referenced from: <FB2FD416-6C4D-3621-B677-61F07C02A3C5> /opt/anaconda3/envs/lime/lib/python3.9/site-packages/torchvision/image.so
  warn(


In [4]:
class CustomImageDataset(Dataset):
    def __init__(self, folder_paths, transform=None):
        self.image_files = []
        for folder, label in folder_paths.items():
            for f in os.listdir(folder):
                if f.endswith((".png", ".jpg")):
                    self.image_files.append((os.path.join(folder, f), label))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path, label = self.image_files[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label, img_path

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# You can adjust the dataset path here:
file_path_deepfakebench = {
    "/Users/msrobin/GitHub Repositorys/Interpretable-Deep-Fake-Detection-2/datasets/2x2_images": 1
}
dataset = CustomImageDataset(file_path_deepfakebench, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)


In [6]:
# Load pretrained ResNet34-based detector
config = {
    "pretrained": "./weights/ckpt_best.pth",
    "model_name": "resnet34",
    "backbone_name": "resnet34",
    "backbone_config": {"num_classes": 2, "inc": 3}
}
model = ResnetDetector(config)
model.eval().to(device)

# Wrap model for GradCAM compatibility
class WrappedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        data_dict = {'image': x}
        return self.model(data_dict)["cls"]

wrapped_model = WrappedModel(model)


KeyError: 'mode'

In [None]:
# Find last conv layer (excluding adjust_channel)
target_layers = []
for name, module in model.backbone.named_modules():
    if isinstance(module, torch.nn.Conv2d) and "adjust" not in name:
        target_layers.append((name, module))

# Choose the last eligible one
last_layer_name, last_layer = target_layers[-1]
print(f"Using target layer: {last_layer_name}")


In [None]:
from ipywidgets import widgets, interact

cam_methods = {
    "GradCAM": GradCAM,
    "XGradCAM": XGradCAM,
    "LayerCAM": LayerCAM
}

def run_cam(cam_type="GradCAM"):
    cam_class = cam_methods[cam_type]
    cam = cam_class(model=wrapped_model, target_layers=[last_layer])

    for img_batch, label_batch, path_batch in dataloader:
        for i in range(min(3, len(img_batch))):
            img = img_batch[i].unsqueeze(0).to(device)
            label = label_batch[i]
            targets = [ClassifierOutputTarget(label.item())]

            grayscale_cam = cam(input_tensor=img, targets=targets)[0]

            img_np = img.squeeze().cpu().numpy().transpose(1, 2, 0)
            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
            heatmap = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

            plt.figure(figsize=(10, 5))
            plt.imshow(heatmap)
            plt.axis("off")
            plt.title(f"{cam_type} | Label: {label.item()}")
            plt.show()
        break

interact(run_cam, cam_type=list(cam_methods.keys()))
