# Road Defect Detection with Grounding DINO

This notebook demonstrates using **Grounding DINO** for zero-shot road defect detection.

## What is Grounding DINO?

Grounding DINO is an **open-set object detector** that combines vision and language understanding:

- **Zero-shot detection**: Detect objects without training on specific classes
- **Text-based queries**: Describe what you want to detect (e.g., "a pothole. a crack. a manhole cover.")
- **No training required**: Works out of the box with natural language prompts
- **Strong zero-shot accuracy**: the paper reports 52.5 AP on COCO zero-shot transfer (i.e., without COCO training data)

## Grounding DINO vs YOLOv8

| Feature | Grounding DINO | YOLOv8 |
|---------|----------------|--------|
| Detection approach | Zero-shot with text prompts | Trained on specific classes |
| Training required | No | Yes (on labeled data) |
| Flexibility | Detect anything with text | Limited to trained classes |
| Speed (rough, GPU-dependent) | Slower (~10-20 FPS) | Faster (~100+ FPS) |
| Accuracy (trained domain) | Lower | Higher |
| Accuracy (new domains) | Better generalization | Requires fine-tuning |
| Use case | Exploration, prototyping, rare objects | Production, real-time systems |

**When to use Grounding DINO:**
- Quick prototyping without labeled data
- Detecting rare defect types
- Exploring new road damage categories
- Research and experimentation

**When to use YOLOv8:**
- Production deployment with labeled data
- Real-time processing requirements
- Higher accuracy on known defect types

## 1. Setup and Installation

In [None]:
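# Install required packages
!pip install -U transformers  # recent release: the code below uses `threshold` (older releases: `box_threshold`)
!pip install torch torchvision
!pip install pillow opencv-python matplotlib numpy pandas  # pandas is used in Section 9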
!pip install timm  # Required for Grounding DINO

In [None]:
# Import libraries
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from PIL import Image, ImageDraw, ImageFont
import requests
import matplotlib.pyplot as plt
import numpy as np
import cv2
from pathlib import Path

# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## 2. Load Grounding DINO Model

Available model sizes:
- `IDEA-Research/grounding-dino-tiny` - Fastest, lower accuracy
- `IDEA-Research/grounding-dino-base` - Balanced (recommended)

In [None]:
# Load model and processor
model_id = "IDEA-Research/grounding-dino-tiny"
# model_id = "IDEA-Research/grounding-dino-base"  # Use this for better accuracy

print(f"Loading {model_id}...")
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

print("Model loaded successfully!")

## 3. Define Road Defect Text Prompts

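Grounding DINO takes a natural-language description of what to detect. Write each object type as a **lowercase** phrase and end every phrase with a **period** (e.g. `"pothole. crack."`); this is the prompt format the Hugging Face processor expects.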

In [None]:
# Define text prompts for road defects
# Format: Use periods (.) to separate different objects

# Option 1: Simple class names
text_prompt_simple = "pothole. crack. manhole cover. patch."

# Option 2: Descriptive phrases (sometimes work better; compare on your own images)
text_prompt_detailed = (
    "a pothole in the road. "
    "a crack in the pavement. "
    "an alligator crack. "
    "a longitudinal crack. "
    "a transverse crack. "
    "a road patch. "
    "a manhole cover."
)

# Option 3: All SVRDD classes
text_prompt_svrdd = (
    "alligator crack. "
    "longitudinal crack. "
    "transverse crack. "
    "pothole. "
    "longitudinal patch. "
    "transverse patch. "
    "manhole cover."
)

# Choose which prompt to use
text_prompt = text_prompt_svrdd

print(f"Text prompt: {text_prompt}")
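Since the processor expects lowercase phrases that each end with a period, generating the prompt from a list of class names avoids formatting slips. The helper below is a hypothetical convenience function (`build_prompt` is not part of transformers); it lowercases names, collapses whitespace, drops empty or duplicate entries, and adds the trailing periods.

```python
def build_prompt(class_names):
    """Join class names into a Grounding DINO-style prompt.

    Hypothetical helper: lowercases each name, collapses extra whitespace,
    drops empty/duplicate entries, and terminates every phrase with a period.
    """
    phrases = []
    for name in class_names:
        # Normalize case, strip trailing periods, then collapse inner whitespace
        phrase = " ".join(name.lower().strip().rstrip(".").split())
        if phrase and phrase not in phrases:
            phrases.append(phrase)
    return " ".join(f"{p}." for p in phrases)


svrdd_classes = [
    "Alligator Crack", "Longitudinal Crack", "Transverse Crack", "Pothole",
    "Longitudinal Patch", "Transverse Patch", "Manhole Cover",
]
print(build_prompt(svrdd_classes))
```

With the seven SVRDD class names this reproduces `text_prompt_svrdd`.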

## 4. Helper Functions for Visualization

In [None]:
def visualize_predictions(image, boxes, labels, scores, threshold=0.3):
    """
    Draw predicted boxes and labels on an image (modifies `image` in place)
    
    Args:
        image: PIL Image (pass a copy if you want to keep the original)
        boxes: Bounding boxes as [x_min, y_min, x_max, y_max] in pixels
        labels: List of label strings
        scores: List of confidence scores
        threshold: Minimum confidence required to draw a box
    """
    draw = ImageDraw.Draw(image)
    
    # Try to load a TrueType font, fall back to PIL's built-in font if not available
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except OSError:
        font = ImageFont.load_default()
    
    # Colors for different defect types. Specific keywords come first so that,
    # e.g., "alligator crack" gets the 'alligator' color, not the generic 'crack' one
    colors = [
        ('alligator', 'orange'),
        ('longitudinal', 'blue'),
        ('transverse', 'green'),
        ('pothole', 'red'),
        ('manhole', 'cyan'),
        ('patch', 'purple'),
        ('crack', 'yellow'),
    ]
    
    # Draw boxes and labels
    for box, label, score in zip(boxes, labels, scores):
        if score < threshold:
            continue
        
        # First matching keyword wins; unmatched labels are drawn in white
        color = next((c for key, c in colors if key in label.lower()), 'white')
        
        # Convert to plain Python floats so PIL accepts NumPy rows
        box = [float(v) for v in box]
        
        # Draw bounding box
        draw.rectangle(box, outline=color, width=3)
        
        # Draw label with score on a filled background
        text = f"{label}: {score:.2f}"
        text_bbox = draw.textbbox((box[0], box[1]), text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((box[0], box[1]), text, fill='black', font=font)
    
    return image


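def postprocess_predictions(outputs, input_ids, target_sizes, threshold=0.3, text_threshold=0.25):
    """
    Post-process model outputs into boxes, labels, and scores
    
    `threshold` filters boxes by confidence; `text_threshold` decides which
    prompt tokens make up each returned label. Recent transformers releases
    name the box-score argument `threshold`; older ones call it `box_threshold`.
    """
    results = processor.post_process_grounded_object_detection(
        outputs,
        input_ids,
        threshold=threshold,
        text_threshold=text_threshold,
        target_sizes=target_sizes
    )
    return results


def detect_defects(image, text_prompt, threshold=0.3):
    """
    Detect road defects in an image using text prompts
    
    Args:
        image: PIL Image or path to image
        text_prompt: Lowercase, period-separated phrases describing the objects
        threshold: Confidence threshold (0-1)
    
    Returns:
        Dictionary with boxes, labels, scores, and the (RGB) image
    """
    # Load image if a path is provided, and make sure it has 3 RGB channels
    if isinstance(image, (str, Path)):
        image = Image.open(image)
    image = image.convert("RGB")
    
    # Prepare inputs
    inputs = processor(images=image, text=text_prompt, return_tensors="pt").to(device)
    
    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Post-process; target sizes are (height, width), PIL's size is (width, height)
    target_sizes = torch.tensor([image.size[::-1]]).to(device)
    results = postprocess_predictions(outputs, inputs.input_ids, target_sizes, threshold=threshold)
    
    # Extract boxes, labels, and scores for the single image in the batch
    result = results[0]
    boxes = result['boxes'].cpu().numpy()
    # Newer transformers versions return the label strings under 'text_labels'
    labels = result['text_labels'] if 'text_labels' in result else result['labels']
    scores = result['scores'].cpu().numpy()
    
    return {
        'boxes': boxes,
        'labels': labels,
        'scores': scores,
        'image': image
    }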

print("Helper functions loaded!")
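The `labels` returned by Grounding DINO are the prompt phrases the model matched, so they can be partial (e.g. just `crack`) or merge several phrases. For counting or color-coding by SVRDD class, a normalization step like the hypothetical `normalize_label` below maps each label to the longest class name it contains and returns `None` when no class name appears. Picking the longest match for merged labels is a heuristic, not something the model guarantees.

```python
def normalize_label(label, class_names):
    """Map a raw Grounding DINO label to a canonical class name.

    Hypothetical helper: returns the longest class name that appears in the
    label (case-insensitive), or None if no class name is contained in it.
    """
    text = " ".join(label.lower().split())
    matches = [name for name in class_names if name.lower() in text]
    if not matches:
        return None
    # Prefer the most specific (longest) match, e.g. "alligator crack" over "crack"
    return max(matches, key=len)


classes = ["alligator crack", "longitudinal crack", "transverse crack", "pothole",
           "longitudinal patch", "transverse patch", "manhole cover"]
for raw in ["alligator crack", "a pothole in the road", "crack"]:
    print(f"{raw!r} -> {normalize_label(raw, classes)}")
```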

## 5. Test on Sample Image

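Let's test the model on a sample image. The default URL points to a generic COCO image (cats on a sofa, not a road scene), so it only checks that the pipeline runs; replace it with your own road images.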

In [None]:
# Load a sample image (replace with your own road image path)
# Option 1: Load from URL
sample_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # Replace with road image
image = Image.open(requests.get(sample_url, stream=True, timeout=30).raw).convert("RGB")

# Option 2: Load from local file (uncomment to use)
# image = Image.open("path/to/your/road/image.jpg")

# Display original image
plt.figure(figsize=(12, 8))
plt.imshow(image)
plt.title("Original Image")
plt.axis('off')
plt.show()

print(f"Image size: {image.size}")

In [None]:
# Run detection
print(f"Detecting with prompt: {text_prompt}\n")

results = detect_defects(
    image=image,
    text_prompt=text_prompt,
    threshold=0.25  # Adjust threshold (0.2-0.4 typically works well)
)

# Print results
print(f"Found {len(results['boxes'])} detections:")
for label, score in zip(results['labels'], results['scores']):
    print(f"  - {label}: {score:.3f}")

In [None]:
# Visualize results
annotated_image = image.copy()
annotated_image = visualize_predictions(
    annotated_image,
    results['boxes'],
    results['labels'],
    results['scores'],
    threshold=0.25
)

# Display
plt.figure(figsize=(15, 10))
plt.imshow(annotated_image)
plt.title("Detections")
plt.axis('off')
plt.show()

## 6. Batch Processing on SVRDD Dataset

In [None]:
# Process multiple images from the SVRDD dataset
data_dir = Path("./data/svrdd/images/val")  # Adjust path as needed
results_list = []  # Keep every image's results for the statistics in Section 9

if data_dir.exists():
    # Take the first 5 images in a deterministic (sorted) order
    image_files = sorted(data_dir.glob("*.jpg"))[:5]
    
    print(f"Processing {len(image_files)} images...\n")
    
    for img_path in image_files:
        print(f"Processing {img_path.name}...")
        
        # Run detection
        results = detect_defects(
            image=img_path,
            text_prompt=text_prompt,
            threshold=0.3
        )
        results_list.append(results)
        
        # Visualize
        annotated = results['image'].copy()
        annotated = visualize_predictions(
            annotated,
            results['boxes'],
            results['labels'],
            results['scores'],
            threshold=0.3
        )
        
        # Display
        plt.figure(figsize=(15, 10))
        plt.imshow(annotated)
        plt.title(f"{img_path.name} - {len(results['boxes'])} detections")
        plt.axis('off')
        plt.show()
        
        print(f"Found {len(results['boxes'])} defects\n")
else:
    print(f"Data directory not found: {data_dir}")
    print("Please download and extract the SVRDD dataset first.")

## 7. Process Video Frames

Extract and process frames from road inspection videos.

In [None]:
def process_video(video_path, output_dir, text_prompt, frame_interval=30, threshold=0.3):
    """
    Process video and detect road defects
    
    Args:
        video_path: Path to video file
        output_dir: Directory to save annotated frames
        text_prompt: Detection text prompt
        frame_interval: Process every Nth frame
        threshold: Detection confidence threshold
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise IOError(f"Could not open video: {video_path}")
    frame_count = 0
    processed_count = 0
    total_detections = 0
    
    print(f"Processing video: {video_path}")
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        
        if frame_count % frame_interval == 0:
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame_rgb)
            
            # Run detection
            results = detect_defects(image, text_prompt, threshold)
            
            # Visualize
            annotated = image.copy()
            annotated = visualize_predictions(
                annotated,
                results['boxes'],
                results['labels'],
                results['scores'],
                threshold=threshold
            )
            
            # Save
            output_path = output_dir / f"frame_{processed_count:06d}_detections_{len(results['boxes'])}.jpg"
            annotated.save(output_path)
            
            total_detections += len(results['boxes'])
            processed_count += 1
            
            if processed_count % 10 == 0:
                print(f"Processed {processed_count} frames, found {total_detections} total detections")
        
        frame_count += 1
    
    cap.release()
    
    print("\nProcessing complete!")
    print(f"Total frames: {frame_count}")
    print(f"Processed frames: {processed_count}")
    print(f"Total detections: {total_detections}")
    # Guard against division by zero when no frame was processed
    if processed_count > 0:
        print(f"Average detections per processed frame: {total_detections / processed_count:.2f}")
    print(f"Results saved to: {output_dir}")

# Example usage (uncomment when you have a video)
# process_video(
#     video_path="path/to/road/video.mp4",
#     output_dir="outputs/video_detections",
#     text_prompt=text_prompt_svrdd,
#     frame_interval=30,
#     threshold=0.3
# )
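`frame_interval=30` samples about one frame per second only when the video runs at 30 FPS. To sample at a fixed rate regardless of frame rate, the interval can be derived from the FPS OpenCV reports via `cap.get(cv2.CAP_PROP_FPS)`. A minimal sketch; the `interval_for_rate` name and the fallback value of 30 are assumptions (some containers report an FPS of 0).

```python
import math


def interval_for_rate(video_fps, samples_per_second=1.0, fallback=30):
    """Return how many frames to step to sample ~samples_per_second frames/s.

    Hypothetical helper. video_fps would typically come from
    cap.get(cv2.CAP_PROP_FPS); if it is missing or invalid, `fallback` is used.
    """
    if (video_fps is None or not math.isfinite(video_fps) or video_fps <= 0
            or samples_per_second <= 0):
        return fallback
    # Never return 0: process_video uses `frame_count % frame_interval`
    return max(1, round(video_fps / samples_per_second))


print(interval_for_rate(30))        # -> 30 (30 FPS video, 1 sample per second)
print(interval_for_rate(59.94, 2))  # -> 30 (~60 FPS video, 2 samples per second)
```

Inside `process_video`, this could replace the fixed argument, e.g. `frame_interval = interval_for_rate(cap.get(cv2.CAP_PROP_FPS))` right after the capture is opened.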

## 8. Experiment with Different Prompts

Try different text descriptions to see what works best for your use case.

In [None]:
# Test different prompts on the same image
test_prompts = [
    "damage on road. defect on pavement.",
    "pothole. crack. hole in road.",
    "road surface damage. pavement defect.",
    "broken asphalt. damaged concrete. road crack.",
]

# Load test image (replace with your road image)
# test_image = Image.open("path/to/road/image.jpg")

# for prompt in test_prompts:
#     print(f"\nTesting prompt: {prompt}")
#     results = detect_defects(test_image, prompt, threshold=0.25)
#     print(f"Detections: {len(results['boxes'])}")
#     
#     # Visualize
#     annotated = test_image.copy()
#     annotated = visualize_predictions(annotated, results['boxes'], results['labels'], results['scores'], threshold=0.25)
#     
#     plt.figure(figsize=(15, 10))
#     plt.imshow(annotated)
#     plt.title(f"Prompt: {prompt}")
#     plt.axis('off')
#     plt.show()

print("Prompt experimentation code ready. Uncomment to test with your images.")

## 9. Calculate Detection Statistics

In [None]:
import pandas as pd
from collections import Counter

def calculate_detection_stats(results_list):
    """
    Calculate statistics from multiple detection results
    
    Args:
        results_list: List of detection result dictionaries
    """
    all_labels = []
    all_scores = []
    
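    for results in results_list:
        all_labels.extend(results['labels'])
        all_scores.extend(results['scores'])
    
    # Nothing to summarize (also avoids division by zero and NaN means below)
    if not all_labels:
        print("No detections to summarize.")
        return pd.DataFrame(columns=['Defect Type', 'Count', 'Percentage'])
    
    # Count detections by type
    label_counts = Counter(all_labels)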
    
    # Create DataFrame
    stats_df = pd.DataFrame([
        {'Defect Type': label, 'Count': count, 'Percentage': f"{count/len(all_labels)*100:.1f}%"}
        for label, count in label_counts.most_common()
    ])
    
    print(f"\n=== Detection Statistics ===")
    print(f"Total detections: {len(all_labels)}")
    print(f"Unique defect types: {len(label_counts)}")
    print(f"Average confidence: {np.mean(all_scores):.3f}")
    print(f"\nDetection breakdown:")
    print(stats_df.to_string(index=False))
    
    return stats_df

# Example usage after processing multiple images
# stats = calculate_detection_stats(results_list)

## 10. Tips for Better Results

### Prompt Engineering Tips:
1. **Be specific**: "pothole in asphalt" vs "pothole"
2. **Use lowercase phrases ending in periods**: e.g. "pothole. crack." (the prompt format the processor expects)
3. **Try synonyms**: "crack", "fissure", "split in road"
4. **Add context**: "damage on road surface" can work better than just "damage"

### Threshold Tuning:
- **Higher threshold (0.4-0.6)**: Fewer false positives, may miss defects
- **Lower threshold (0.2-0.3)**: More detections, more false positives
- **Start at 0.3** and adjust based on results
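Because the box threshold only filters already-scored boxes, one inference pass at a low threshold (say 0.2) should return a superset of what higher thresholds keep. Counting how many detections survive at several cut-offs, without rerunning the model, is then a cheap way to choose a threshold. A sketch, assuming `scores` is the `results['scores']` array from `detect_defects` run at the lowest threshold of interest; the strict `>` comparison is our assumption about how the transformers post-processing filters boxes.

```python
def sweep_thresholds(scores, thresholds):
    """Count detections whose score exceeds each threshold.

    `scores` should come from a single run at (or below) the smallest threshold.
    """
    return {t: sum(1 for s in scores if s > t) for t in sorted(thresholds)}


example_scores = [0.22, 0.27, 0.31, 0.45, 0.62]  # made-up scores for illustration
for t, n in sweep_thresholds(example_scores, [0.2, 0.3, 0.4, 0.5]).items():
    print(f"threshold {t:.1f}: {n} detections")
```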

### Model Selection:
- **grounding-dino-tiny**: Faster, less accurate (good for prototyping)
- **grounding-dino-base**: Slower, more accurate (recommended for final results)

### Performance:
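- Grounding DINO is much slower than YOLOv8 (roughly 10-20 FPS vs 100+ FPS; actual numbers depend on the GPU, model size, and image resolution)
- Use GPU for acceptable speed
- Process every Nth frame for videos (e.g. `frame_interval=30` is about one frame per second for 30 FPS footage)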
- Consider using Grounding DINO for labeling data, then train YOLOv8 for deployment
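The FPS figures above are rough and vary with GPU, checkpoint, and image size, so it is worth timing the pipeline on your own hardware. A minimal sketch with a hypothetical `measure_fps` helper:

```python
import time


def measure_fps(run_once, n_warmup=2, n_runs=10):
    """Estimate throughput (calls per second) of a zero-argument callable.

    Warm-up calls are excluded so one-time costs (CUDA init, caching)
    don't skew the result.
    """
    for _ in range(n_warmup):
        run_once()
    start = time.perf_counter()
    for _ in range(n_runs):
        run_once()
    elapsed = time.perf_counter() - start
    return n_runs / elapsed if elapsed > 0 else float("inf")
```

For example, `measure_fps(lambda: detect_defects(image, text_prompt, threshold=0.3))` times the full pipeline including pre- and post-processing. `detect_defects` copies its results to the CPU, which waits for pending GPU work, so wall-clock timing is meaningful here; for code that leaves tensors on the GPU, call `torch.cuda.synchronize()` before reading the clock.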

## 11. Integration with Segment Anything (SAM)

Grounding DINO can be combined with SAM for pixel-accurate segmentation masks.

In [None]:
# Install SAM
# !pip install segment-anything

# Example workflow (requires SAM setup):
# 1. Use Grounding DINO to detect defects (get bounding boxes)
# 2. Use SAM to generate masks from those boxes
# 3. Get pixel-accurate segmentation of cracks, potholes, etc.

print("SAM integration can provide pixel-accurate masks for detected defects.")
print("This is useful for measuring defect size, calculating area, etc.")
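Once SAM gives you a binary mask per detected defect (masks from the transformers `SamModel`/`SamProcessor` integration are tensors and can be converted with `.numpy()`), size measurements reduce to simple array operations. The sketch below assumes a boolean NumPy mask of shape (H, W). Converting pixels to square meters with a single `meters_per_pixel` value is an assumption that only holds for top-down or orthorectified imagery, not perspective street-view photos.

```python
import numpy as np


def mask_measurements(mask, meters_per_pixel=None):
    """Summarize a binary defect mask.

    Returns the pixel area, the tight bounding box (x_min, y_min, x_max, y_max,
    inclusive pixel coordinates) and, if `meters_per_pixel` is given, an
    approximate area in m^2 assuming a constant ground resolution.
    """
    mask = np.asarray(mask, dtype=bool)
    ys, xs = np.nonzero(mask)  # row (y) and column (x) indices of mask pixels
    area_px = int(xs.size)
    if area_px == 0:
        return {"area_px": 0, "bbox": None, "area_m2": 0.0 if meters_per_pixel else None}
    bbox = (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))
    area_m2 = area_px * meters_per_pixel ** 2 if meters_per_pixel else None
    return {"area_px": area_px, "bbox": bbox, "area_m2": area_m2}


demo = np.zeros((100, 200), dtype=bool)
demo[40:60, 50:90] = True  # a 20 x 40 pixel rectangle standing in for a pothole mask
print(mask_measurements(demo, meters_per_pixel=0.01))
```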

## Next Steps

### Option 1: Use Grounding DINO for Prototyping
1. Test on your road footage without training
2. Experiment with different text prompts
3. Use it to label data for YOLOv8 training

### Option 2: Combine with YOLOv8
1. Use Grounding DINO to generate pseudo-labels
2. Manually verify and correct
3. Train YOLOv8 for faster inference
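Step 1 amounts to writing Grounding DINO boxes in the YOLO label format: one text file per image, one line per object, `class_id x_center y_center width height` with coordinates normalized to [0, 1]. A minimal sketch; `to_yolo_lines` is a hypothetical helper, and labels that do not exactly match a class name are skipped here rather than guessed, which is one reason the manual check in step 2 matters.

```python
def to_yolo_lines(boxes, labels, class_names, img_w, img_h):
    """Convert pixel boxes [x_min, y_min, x_max, y_max] to YOLO label lines."""
    lookup = {name.lower(): i for i, name in enumerate(class_names)}
    lines = []
    for (x0, y0, x1, y1), label in zip(boxes, labels):
        class_id = lookup.get(label.strip().lower())
        if class_id is None:
            continue  # unmatched or merged phrase: leave it for manual review
        # Clip to the image so normalized values stay within [0, 1]
        x0, x1 = max(0.0, float(x0)), min(float(img_w), float(x1))
        y0, y1 = max(0.0, float(y0)), min(float(img_h), float(y1))
        if x1 <= x0 or y1 <= y0:
            continue  # degenerate box after clipping
        xc = (x0 + x1) / 2 / img_w
        yc = (y0 + y1) / 2 / img_h
        w = (x1 - x0) / img_w
        h = (y1 - y0) / img_h
        lines.append(f"{class_id} {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}")
    return lines


classes = ["alligator crack", "longitudinal crack", "transverse crack", "pothole",
           "longitudinal patch", "transverse patch", "manhole cover"]
print(to_yolo_lines([[10, 20, 110, 70]], ["pothole"], classes, img_w=200, img_h=100))
```

With `results` from `detect_defects`, the image size is `results['image'].size` (width, height), and the lines can be saved with e.g. `Path("labels/img.txt").write_text("\n".join(lines))`.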

### Option 3: Hybrid Approach
1. Use YOLOv8 for common defects (fast)
2. Use Grounding DINO for rare/new defect types
3. Best of both worlds!
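One possible way to combine the two detectors is to trust YOLOv8 for the classes it was trained on and keep Grounding DINO detections only for other classes, dropping any Grounding DINO box that largely overlaps a YOLOv8 box. The sketch below assumes both outputs were already converted to plain `(box, label, score)` tuples with pixel boxes `[x_min, y_min, x_max, y_max]`; `merge_hybrid` is hypothetical and the 0.5 IoU cut-off is an arbitrary starting point.

```python
def iou(a, b):
    """Intersection-over-union of two [x_min, y_min, x_max, y_max] boxes."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def merge_hybrid(yolo_dets, gdino_dets, yolo_classes, iou_cutoff=0.5):
    """Keep all YOLOv8 detections plus Grounding DINO detections for other classes.

    Grounding DINO detections whose label is a YOLOv8 class, or whose box
    overlaps any YOLOv8 box with IoU >= iou_cutoff, are dropped.
    """
    merged = list(yolo_dets)
    for box, label, score in gdino_dets:
        if label in yolo_classes:
            continue  # YOLOv8 is the authority for its trained classes
        if any(iou(box, ybox) >= iou_cutoff for ybox, _, _ in yolo_dets):
            continue  # likely the same object already reported by YOLOv8
        merged.append((box, label, score))
    return merged


yolo = [([0, 0, 10, 10], "pothole", 0.9)]
gdino = [([0, 0, 10, 10], "pothole", 0.5),
         ([1, 1, 10, 10], "manhole cover", 0.4),
         ([50, 50, 60, 60], "manhole cover", 0.4)]
print(merge_hybrid(yolo, gdino, {"pothole"}))
```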

## Resources

- Grounding DINO Paper: https://arxiv.org/abs/2303.05499
- Hugging Face Docs: https://huggingface.co/docs/transformers/en/model_doc/grounding-dino
- Model Hub: https://huggingface.co/IDEA-Research
- Segment Anything: https://segment-anything.com/