# DocSAF Demo Notebook

This notebook demonstrates the DocSAF Phase 0 vertical slice:
- Load a document image
- Extract text via OCR
- Compute cross-modal saliency
- Apply attenuation field with two knobs: `alpha` and `radius`
- Generate adversarial image

**Two tunables only:** `alpha` (field strength) and `radius` (blur kernel size)


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from pathlib import Path

# DocSAF imports
from src.docsaf.utils import load_config, pil_to_tensor, tensor_to_pil
from src.docsaf.surrogates import load_embedder
from src.docsaf.ocr import ocr_read
from src.docsaf.saliency import compute_gradient_saliency
from src.docsaf.field import apply_field_safe
from src.docsaf.pdf_io import pdf_to_pil, is_pdf_file

print("DocSAF imports successful!")

# Prefer the GPU when PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")


## Step 1: Load Configuration and Image


In [None]:
# Load the run configuration; `alpha` and `radius` are the two tunables
# reported below.
config = load_config("configs/default.yaml")
print(f"Config alpha: {config['alpha']}, radius: {config['radius']}")

# Load demo image (replace with your image path)
image_path = "demo/sample_doc.png"  # Change this to your demo image

if is_pdf_file(image_path):
    # PDFs are rasterised from their first page (page=0) at zoom 2.0.
    print("Loading PDF page...")
    pil_image = pdf_to_pil(image_path, page=0, zoom=2.0)
else:
    print("Loading image...")
    # Image.open is lazy and holds the file handle open; the context manager
    # releases it deterministically once the independent RGB copy exists.
    with Image.open(image_path) as source_image:
        pil_image = source_image.convert("RGB")

# Convert to a tensor on the selected device for the saliency/field steps.
x_orig = pil_to_tensor(pil_image, device)

# Display original image
plt.figure(figsize=(10, 8))
plt.imshow(pil_image)
plt.title("Original Document Image")
plt.axis('off')
plt.show()

print(f"Image shape: {x_orig.shape}")


## Step 2: Extract Text via OCR


In [None]:
# Run OCR on the page image; the backend name is read from the config and
# defaults to "easyocr" when the key is absent.
ocr_backend = config.get("ocr", "easyocr")
img_array = np.array(pil_image)
extracted_text = ocr_read(img_array, backend=ocr_backend)

# Report the backend, the total length, and at most the first 200 characters.
print(f"OCR Backend: {ocr_backend}")
print(f"Extracted Text ({len(extracted_text)} chars):")
preview = extracted_text[:200]
tail = "..." if len(extracted_text) > 200 else ""
print(f"'{preview}{tail}'")

# An empty or whitespace-only OCR result is replaced by placeholder text so
# the remaining cells still have a prompt to work with.
if extracted_text.strip() == "":
    extracted_text = "document text content"
    print("Using fallback text for demo")


## Step 3: Compute Cross-Modal Saliency → Apply Attenuation Field → Compare Results


In [None]:
# Load embedder model. When the config lists several surrogates, only the
# first one is used for this demo.
surrogate_specs = config.get("surrogates", ["openclip:ViT-L-14@336"])
embedder_spec = surrogate_specs[0] if isinstance(surrogate_specs, list) else surrogate_specs
embedder = load_embedder(embedder_spec, device)
print(f"Loaded embedder: {embedder_spec}")

# Compute gradient saliency on a grad-enabled copy so x_orig itself stays
# untouched for the field application below.
x_input = x_orig.clone().requires_grad_(True)
original_alignment, saliency_map = compute_gradient_saliency(
    embedder, x_input, extracted_text, normalize=True
)
# float() accepts both Python numbers and one-element tensors, so the
# formatting and comparison below do not depend on which one is returned.
original_alignment = float(original_alignment)

print(f"Original CLIP alignment score: {original_alignment:.4f}")

# Apply attenuation field with the two knobs
alpha = config['alpha']  # Field strength
radius = config['radius']  # Blur radius
print(f"Applying attenuation field with alpha={alpha}, radius={radius}")

x_adv = apply_field_safe(x_orig, saliency_map, alpha, radius)

# Compute adversarial alignment score.
# NOTE(review): the first call above passes a grad-enabled input, which
# suggests compute_gradient_saliency relies on autograd. Calling it inside
# torch.no_grad() on a tensor without requires_grad would then fail, so this
# call mirrors the first one instead - confirm against the helper.
x_adv_input = x_adv.detach().clone().requires_grad_(True)
adv_alignment, _ = compute_gradient_saliency(embedder, x_adv_input, extracted_text)
adv_alignment = float(adv_alignment)

alignment_drop = original_alignment - adv_alignment
print(f"Adversarial CLIP alignment score: {adv_alignment:.4f}")
print(f"Alignment drop: {alignment_drop:.4f}")

# Convert to PIL for visualization
pil_adv = tensor_to_pil(x_adv)

# Side-by-side comparison
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Original image
axes[0,0].imshow(pil_image)
axes[0,0].set_title(f"Original\nCLIP Score: {original_alignment:.4f}")
axes[0,0].axis('off')

# Saliency heatmap. detach() first: .numpy() raises on a tensor that still
# carries autograd history.
sal_cpu = saliency_map.detach().squeeze().cpu().numpy()
im1 = axes[0,1].imshow(sal_cpu, cmap='hot', interpolation='bilinear')
axes[0,1].set_title("Saliency Map")
axes[0,1].axis('off')
plt.colorbar(im1, ax=axes[0,1], fraction=0.046)

# Saliency overlay
axes[1,0].imshow(pil_image)
axes[1,0].imshow(sal_cpu, cmap='hot', alpha=0.4, interpolation='bilinear')
axes[1,0].set_title("Saliency Overlay")
axes[1,0].axis('off')

# Adversarial result
axes[1,1].imshow(pil_adv)
axes[1,1].set_title(f"Adversarial (α={alpha}, r={radius})\nCLIP Score: {adv_alignment:.4f}\nDrop: {alignment_drop:.4f}")
axes[1,1].axis('off')

plt.tight_layout()
plt.show()

# Save adversarial image next to the input, with an "_adv" suffix on the stem.
output_path = Path(image_path).parent / f"{Path(image_path).stem}_adv.png"
pil_adv.save(output_path)
print(f"Saved adversarial image to: {output_path}")

print("\n=== DocSAF Phase 0 Demo Complete ===")
print(f"Two tunables used: alpha={alpha}, radius={radius}")
# The demo counts as a success when the alignment score fell by more than 0.01.
print(f"Success: {alignment_drop > 0.01}")
