# DETR Evaluation on Roboflow Wildfire Smoke Test Set (Local)

This notebook loads a DETR model checkpoint (trained on a Roboflow wildfire smoke dataset in YOLO format), evaluates it on the test split, reports mAP, and visualizes the predictions.

## 1. Install and Import Dependencies

In [ ]:
# Uncomment if running for the first time
# !pip install transformers torchvision torchmetrics pycocotools ruamel.yaml opencv-python tqdm matplotlib

In [ ]:
import os
import torch
from torchvision.datasets import CocoDetection
from torchvision import transforms
from torch.utils.data import DataLoader
from transformers import DetrForObjectDetection, DetrConfig, DetrImageProcessor
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from ruamel.yaml import YAML
import json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

## 2. Set Paths and Load Class Names

In [ ]:
data_root = './roboflow'  # Update if needed
yaml_path = os.path.join(data_root, 'data.yaml')
checkpoint_path = './detr_last.ckpt'  # Path to your downloaded checkpoint

yaml = YAML()
with open(yaml_path, 'r', encoding='utf-8') as f:
    data_yaml = yaml.load(f)

# Always end up with a plain list of class-name strings, because later cells
# call len() and enumerate() on `class_list`:
#   * `names` as a sequence  -> position is the class id
#   * `names` as a mapping   -> {class id: name}, ordered by class id
#   * only `nc` (a count)    -> synthesize placeholder names
names = data_yaml.get('names')
if names is None:
    class_list = [f'class_{i}' for i in range(int(data_yaml['nc']))]
elif hasattr(names, 'items'):
    class_list = [str(names[key]) for key in sorted(names)]
else:
    class_list = [str(name) for name in names]
print('Class names:', class_list)

## 3. Convert YOLO to COCO for Test Set (if needed)

In [ ]:
def yolo_to_coco(img_dir, label_dir, class_list, output_json):
    """Convert a YOLO-format split into a COCO detection annotation file.

    Every ``.jpg``/``.jpeg``/``.png`` file in ``img_dir`` (sorted by name)
    becomes one COCO image entry with a 1-based id. Its label file is looked
    up in ``label_dir`` under the same base name with a ``.txt`` extension;
    images without a label file are kept with no annotations.

    Args:
        img_dir: Directory containing the image files.
        label_dir: Directory containing the YOLO ``.txt`` label files.
        class_list: Class names; position ``i`` becomes COCO category ``i + 1``.
        output_json: Path the COCO dictionary is written to as JSON.

    Returns:
        The COCO dictionary with ``images``, ``annotations`` and
        ``categories`` keys (the same content that is written to disk).
    """
    images = []
    annotations = []
    annotation_id = 1
    img_files = sorted(
        name for name in os.listdir(img_dir)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    for img_id, filename in enumerate(img_files, 1):
        # Only the size is needed; the context manager releases the file
        # handle right away instead of leaving one open per image.
        with Image.open(os.path.join(img_dir, filename)) as img:
            width, height = img.size
        images.append({
            'id': img_id,
            'file_name': filename,
            'width': width,
            'height': height
        })

        label_path = os.path.join(label_dir, filename.rsplit('.', 1)[0] + '.txt')
        if not os.path.exists(label_path):
            continue
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                # Only plain detection rows are converted:
                # "class x_center y_center width height".
                if len(parts) != 5:
                    continue
                class_id, x_center, y_center, w, h = map(float, parts)
                # YOLO boxes are normalized and centre-based; COCO boxes are
                # [top-left x, top-left y, width, height] in pixels.
                w_box = w * width
                h_box = h * height
                annotations.append({
                    'id': annotation_id,
                    'image_id': img_id,
                    'category_id': int(class_id) + 1,  # COCO ids are 1-based here
                    'bbox': [
                        (x_center - w / 2) * width,
                        (y_center - h / 2) * height,
                        w_box,
                        h_box,
                    ],
                    'area': w_box * h_box,
                    'iscrowd': 0
                })
                annotation_id += 1

    categories = [{"id": i + 1, "name": name} for i, name in enumerate(class_list)]
    coco_dict = {'images': images, 'annotations': annotations, 'categories': categories}
    with open(output_json, 'w') as f:
        json.dump(coco_dict, f)
    print(f'COCO annotation saved to {output_json}')
    return coco_dict

# YOLO test split: images and labels live in sibling sub-directories.
test_img_dir, test_label_dir = (
    os.path.join(data_root, 'test', sub_dir) for sub_dir in ('images', 'labels')
)
# The converted COCO annotations are cached next to the dataset.
test_coco_json = os.path.join(data_root, 'test_coco.json')

# Convert only when no cached annotation file is present.
if os.path.exists(test_coco_json):
    print(f'COCO annotation already exists: {test_coco_json}')
else:
    yolo_to_coco(test_img_dir, test_label_dir, class_list, test_coco_json)

## 4. Prepare DataLoader for Test Set

In [ ]:
def detr_collate_fn(batch):
    """Collate ``(image, coco_annotations)`` samples into a detection batch.

    Args:
        batch: Sequence of ``(image, annotations)`` pairs, where
            ``annotations`` is a list of COCO annotation dicts carrying
            ``bbox`` (``[x, y, width, height]``) and a 1-based
            ``category_id``.

    Returns:
        ``(images, targets)``: ``images`` is a tuple of the untouched images
        and ``targets`` is a list with one dict per image holding

        * ``boxes``: float32 tensor of shape ``(N, 4)`` in
          ``[xmin, ymin, xmax, ymax]`` order,
        * ``labels``: int64 tensor of shape ``(N,)`` with 0-based class ids.
    """
    images, targets = list(zip(*batch))
    new_targets = []
    for annotations in targets:
        boxes = []
        labels = []
        for obj in annotations:
            # COCO stores [x, y, w, h]; the detection metric and the DETR
            # post-processing both work with corner coordinates, so convert
            # here to keep ground truth and predictions comparable.
            x, y, w, h = obj['bbox']
            boxes.append([x, y, x + w, y + h])
            labels.append(obj['category_id'] - 1)
        if boxes:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.int64)
        else:
            # Keep the (0, 4) / (0,) shapes for images without objects.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        new_targets.append({'boxes': boxes, 'labels': labels})
    return images, new_targets

# Images are only converted to tensors here; the DETR processor is applied
# later, at inference time.
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = CocoDetection(
    root=test_img_dir,
    annFile=test_coco_json,
    transform=transform,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    collate_fn=detr_collate_fn,
)
print('Test set size:', len(test_dataset))

## 5. Load DETR Model from Checkpoint

In [ ]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Reuse the reference architecture, but size the classification head for
# this dataset's classes.
config = DetrConfig.from_pretrained('facebook/detr-resnet-50')
config.num_labels = len(class_list)
model = DetrForObjectDetection(config)

# Load checkpoint
# NOTE(security): torch.load unpickles the file - only load checkpoints that
# come from a trusted source.
# NOTE(review): recent PyTorch releases changed the default of `weights_only`;
# a Lightning checkpoint may need it set explicitly - confirm for your version.
checkpoint = torch.load(checkpoint_path, map_location=device)
if 'state_dict' in checkpoint:  # Lightning checkpoint
    # Strip only the leading wrapper prefix. The DETR parameter names contain
    # 'model.' themselves, so a global str.replace would mangle the keys and
    # make load_state_dict fail.
    wrapper_prefix = 'model.'
    model.load_state_dict({
        key[len(wrapper_prefix):]: value
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(wrapper_prefix)
    })
else:
    model.load_state_dict(checkpoint)
model = model.to(device)
model.eval()
processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
print('Loaded DETR model from checkpoint')

## 6. Evaluate on Test Set (mAP and Metrics)

In [ ]:
# mAP integrates precision over the whole recall range, so every scored
# detection has to reach the metric; a 0.5 cut-off would truncate the curve
# and under-report mAP.
eval_score_threshold = 0.0

# NOTE: the metric's default box_format is 'xyxy'. The post-processed DETR
# predictions are already in that format; the target boxes must be too.
map_metric = MeanAveragePrecision(class_metrics=True)
for images, targets in tqdm(test_loader):
    with torch.no_grad():
        encoding = processor(images, return_tensors="pt", do_rescale=False).to(device)
        outputs = model(**encoding)
    # Post-process outputs: rescale boxes to each image's own (height, width)
    if isinstance(images[0], torch.Tensor):
        target_sizes = torch.stack([torch.tensor(img.shape[-2:]) for img in images]).to(device)
    else:
        target_sizes = torch.tensor([img.size[::-1] for img in images]).to(device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=eval_score_threshold
    )
    # Accumulate on the CPU so the metric state does not pin GPU memory for
    # the whole test set.
    formatted_preds = [
        {name: value.detach().cpu() for name, value in result.items()}
        for result in results
    ]
    formatted_targets = [
        {"boxes": t["boxes"].cpu(), "labels": t["labels"].cpu()}
        for t in targets
    ]
    map_metric.update(formatted_preds, formatted_targets)
metrics = map_metric.compute()

# Collapse multi-element (per-class) entries into a single mean.
for k in list(metrics.keys()):
    value = metrics[k]
    if not (isinstance(value, torch.Tensor) and value.numel() > 1):
        continue
    if k == 'classes':
        # Integer class ids, not scores: averaging them is meaningless and
        # Tensor.mean() rejects integer dtypes.
        continue
    valid = value[value != -1]  # -1 marks entries the metric could not compute
    metrics[k + '_mean'] = valid.float().mean().item() if valid.numel() > 0 else float('nan')
    del metrics[k]

# Keep only plain numbers and one-element tensors for printing.
scalar_metrics = {}
for k, v in metrics.items():
    if isinstance(v, torch.Tensor):
        if v.numel() == 1:
            scalar_metrics[k] = v.item()
    elif isinstance(v, (float, int)):
        scalar_metrics[k] = v
print('Test metrics:', scalar_metrics)

## 7. Visualize Predictions

In [ ]:
def plot_predictions(model, processor, dataloader, device, class_names, num_images=4, score_threshold=0.5):
    """Draw the model's detections on images taken from ``dataloader``.

    Args:
        model: Detection model; called with the processor's encoding.
        processor: Image processor used for pre- and post-processing.
        dataloader: Iterable yielding ``(images, targets)`` batches; the
            targets are not used. Images are assumed to be channel-first
            tensors already scaled to [0, 1] (they are permuted to HWC for
            display and passed on with ``do_rescale=False``) - confirm
            against the dataset transform.
        device: Device the encoded batch is moved to before inference.
        class_names: Indexable collection mapping a label id to its name.
        num_images: Maximum number of images to display; nothing is drawn
            when it is zero or negative.
        score_threshold: Minimum score for a detection to be drawn.

    Returns:
        None. One matplotlib figure is shown per image.
    """
    if num_images <= 0:
        return
    model.eval()
    images_shown = 0
    for images, _targets in dataloader:
        # Forward the whole encoding (including the padding mask, when the
        # processor produces one) exactly as the evaluation cell does, so the
        # plotted boxes match the ones that were scored.
        encoding = processor(images, return_tensors="pt", do_rescale=False).to(device)
        with torch.no_grad():
            outputs = model(**encoding)
        target_sizes = [tuple(img.shape[-2:]) for img in images]
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=score_threshold
        )
        for image, result in zip(images, results):
            plt.figure(figsize=(8, 6))
            plt.imshow(image.permute(1, 2, 0).cpu().numpy())
            ax = plt.gca()
            boxes = result["boxes"].cpu().numpy()
            scores = result["scores"].cpu().numpy()
            labels = result["labels"].cpu().numpy()
            for box, score, label in zip(boxes, scores, labels):
                xmin, ymin, xmax, ymax = box
                ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color='red', linewidth=2))
                ax.text(xmin, ymin, f'{class_names[int(label)]}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10, color='black')
            plt.axis('off')
            plt.show()
            images_shown += 1
            if images_shown >= num_images:
                return

# Example: show up to six test images with detections scoring at least 0.5.
plot_predictions(
    model=model,
    processor=processor,
    dataloader=test_loader,
    device=device,
    class_names=class_list,
    num_images=6,
    score_threshold=0.5,
)