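# DIRE Detection
This notebook implements the DIffusion REconstruction Error (DIRE) detector for distinguishing real images from diffusion-generated ones. Each image is noised and then reconstructed with a pretrained diffusion model. Images whose reconstruction error falls below a threshold are flagged as diffusion-generated.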

In [None]:
%pip install torch

import torch
import torch.nn.functional as F
from diffusers import DDPMPipeline, DDPMScheduler
from torchvision import transforms
from PIL import Image

# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision saves GPU memory, but many CPU kernels (e.g. conv2d) have no
# float16 implementation in common torch builds, so use float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
model_id = "google/ddpm-cifar10-32"
num_inference_steps = 100
# Images whose DIRE score is strictly below this value are flagged as
# diffusion-generated. NOTE(review): calibrate on labeled data for the chosen
# model; the value is not derived from anything visible here.
dire_threshold = 0.01

# Load the pipeline, then build a separate scheduler from the pipeline's own
# scheduler config (no second download, and the pipeline's scheduler is left
# untouched) and shorten its timestep schedule to `num_inference_steps`.
pipe = DDPMPipeline.from_pretrained(model_id, torch_dtype=dtype).to(device)
scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
scheduler.set_timesteps(num_inference_steps)

# Preprocessing: resize to the UNet's native resolution and convert to a
# (C, H, W) float tensor with values in [0, 1].
_sample_size = pipe.unet.config.sample_size
if isinstance(_sample_size, int):
    # An int passed to Resize only fixes the shorter edge; force a square.
    _sample_size = (_sample_size, _sample_size)
preprocess = transforms.Compose([
    transforms.Resize(tuple(_sample_size)),
    transforms.ToTensor(),
])

def compute_dire(img: Image.Image, *, strength: float = 1.0):
    """Compute the DIffusion REconstruction Error (DIRE) of ``img``.

    The image is preprocessed, mapped to the [-1, 1] range DDPM UNets are
    trained on, noised to the first timestep of the (possibly shortened)
    schedule, and then denoised back over the remaining timesteps. The score
    is the mean squared error between input and reconstruction, measured in
    [0, 1] pixel space.

    NOTE(review): the DIRE paper reconstructs via deterministic DDIM
    inversion; this uses stochastic DDPM noising and sampling, so scores vary
    between calls unless the torch RNG is seeded.

    Args:
        img: input image (RGB expected by ``preprocess``).
        strength: fraction of the configured timestep schedule to noise and
            then denoise, in (0, 1]. ``1.0`` starts at the noisiest scheduled
            timestep and walks the whole schedule; smaller values start at a
            less noisy timestep and presumably keep more of the input.

    Returns:
        ``(dire, recon)``: ``dire`` is a Python float; ``recon`` is the
        reconstruction as a float32 CPU tensor (batch dim removed) clamped
        to [0, 1].

    Raises:
        ValueError: if ``strength`` is not in (0, 1].
    """
    if not 0.0 < strength <= 1.0:
        raise ValueError(f"strength must be in (0, 1], got {strength}")

    unet = pipe.unet
    # ToTensor yields [0, 1]; rescale to [-1, 1] and match the UNet's dtype.
    x0 = preprocess(img).unsqueeze(0).to(device)
    x0_model = (x0 * 2.0 - 1.0).to(unet.dtype)

    # scheduler.timesteps runs from noisiest to cleanest. Noise the input to
    # the SAME timestep the reverse loop starts from, so both passes agree on
    # the noise level.
    timesteps = scheduler.timesteps
    n_steps = len(timesteps)
    start = n_steps - max(1, round(n_steps * strength))
    timesteps = timesteps[start:]

    noise = torch.randn_like(x0_model)
    sample = scheduler.add_noise(x0_model, noise, timesteps[:1])
    with torch.no_grad():
        for t in timesteps:
            model_output = unet(sample, t).sample
            sample = scheduler.step(model_output, t, sample).prev_sample

    # Back to [0, 1] in float32 so the score is comparable to the input.
    recon = (sample.float() / 2.0 + 0.5).clamp(0.0, 1.0)
    dire = F.mse_loss(recon, x0.float()).item()
    return dire, recon.squeeze(0).cpu()

def detect_via_dire(img: Image.Image):
    """Classify ``img`` as diffusion-generated or real from its DIRE score.

    Returns:
        ``(is_diffusion, dire, recon)``: ``is_diffusion`` is True when the
        score returned by ``compute_dire`` is strictly below the module-level
        ``dire_threshold``; ``dire`` and ``recon`` are passed through from
        ``compute_dire`` unchanged.
    """
    score, reconstruction = compute_dire(img)
    return score < dire_threshold, score, reconstruction


In [None]:

# Example usage: python detect_dire.py path/to/image.png
# NOTE(review): inside a Jupyter kernel sys.argv holds the kernel's own
# arguments, so this path prints the usage message and exits; call
# detect_via_dire(...) directly there instead.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: python detect_dire.py path/to/image.png", file=sys.stderr)
        sys.exit(1)
    # Close the underlying file handle once the RGB copy has been made.
    with Image.open(sys.argv[1]) as src:
        img = src.convert("RGB")
    is_diff, dire_score, recon_img = detect_via_dire(img)
    print(f"DIRE score = {dire_score:.6f}")
    print("Detected as", "diffusion-generated" if is_diff else "real")
    # ToPILImage multiplies floats by 255 and casts to uint8 without clipping,
    # so out-of-range values would wrap around; upcast and clamp first.
    recon_pil = transforms.ToPILImage()(recon_img.float().clamp(0.0, 1.0))
    recon_pil.save("reconstruction.png")
