# DETR Smoke Detection Training on Pyronear Dataset (Colab Ready)

This notebook trains a [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50) model on the Pyronear smoke dataset, downloading data directly from Hugging Face. It uses PyTorch Lightning and torchmetrics to provide detailed YOLO-style metrics (mAP@0.5, mAP@0.5:0.95, precision, recall, etc.) during training.


## 📦 Setup and Install Dependencies

In [None]:
!pip install -q torch torchvision pytorch-lightning torchmetrics transformers datasets huggingface_hub pycocotools opencv-python

## 📚 Imports

In [None]:
import os
import json
from PIL import Image
import torch
from torchvision.datasets import CocoDetection
from torchvision import transforms
from torch.utils.data import DataLoader
from transformers import DetrImageProcessor, DetrForObjectDetection
import matplotlib.pyplot as plt
from datasets import load_dataset
import pytorch_lightning as pl
from torchmetrics.detection.mean_ap import MeanAveragePrecision


## 📥 Download Pyronear Dataset from Hugging Face
We use the [datasets](https://huggingface.co/docs/datasets) library to download the Pyronear smoke dataset.

In [None]:
# Fetch the 'train' split of the Pyronear smoke dataset from Hugging Face.
# Change the repo id below if the dataset lives elsewhere.
dataset = load_dataset(
    path='pyronear/smoke-detection-yolo',
    split='train',
)
# If the dataset is not in COCO format, conversion logic will be added below.

## 🔁 Convert YOLO to COCO format (if needed)
If the dataset is in YOLO format, convert it to COCO format for DETR compatibility.

In [None]:
# TODO: add your YOLO-to-COCO conversion code here, or skip this cell if the dataset is already in COCO format.
# Example placeholder (`convert_yolo_to_coco` is not defined in this notebook and must be provided):
# coco_dict = convert_yolo_to_coco(dataset)


## 🗂️ Prepare DataLoaders
Wrap the COCO dataset for use with PyTorch Lightning.

In [None]:
# Example DataLoader setup (replace with your actual paths and logic).
# NOTE(review): `img_folder` and `ann_file` are not defined in any earlier cell;
# assign them before running this cell (presumably the image directory and the
# COCO annotation JSON file -- confirm against your data layout).
# NOTE(review): this rebinds `dataset`, replacing the Hugging Face dataset loaded earlier.
# NOTE(review): the training cell refers to `train_dataloader` / `val_dataloader`,
# while this cell only creates `dataloader`.
transform = transforms.Compose([transforms.ToTensor()])

dataset = CocoDetection(img_folder, ann_file, transform=transform)

# The collate function transposes a batch: a list of (image, target) samples
# becomes an (images, targets) pair of tuples instead of stacked tensors.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    collate_fn=lambda samples: tuple(zip(*samples)),
)


## ⚡ PyTorch Lightning Module with YOLO-style Metrics
We use torchmetrics' `MeanAveragePrecision`, which reports mAP@0.5, mAP@0.5:0.95, and mean average recall (mAR), along with size-specific variants.

In [None]:
class DETRLightningModule(pl.LightningModule):
    """Fine-tune ``facebook/detr-resnet-50`` and track COCO-style mAP on validation data.

    Every batch is expected to be an ``(images, targets)`` pair in which each
    target is a dict holding ``'labels'`` and ``'boxes'``. The boxes are passed
    to DETR unchanged, so they must already be in the format DETR's loss
    expects (normalized ``(center_x, center_y, width, height)`` per the
    Hugging Face DETR documentation -- confirm your data pipeline produces this).

    Logged values: ``train_loss``, ``val_loss`` and, once per validation epoch,
    the scalar entries returned by ``MeanAveragePrecision.compute()`` under
    their torchmetrics names (``map``, ``map_50``, ``mar_100``, ...).

    Args:
        num_classes: Number of object classes for the new classification head.
    """

    def __init__(self, num_classes):
        super().__init__()
        # The checkpoint ships a COCO-sized classification head; without
        # ignore_mismatched_sizes loading it with a different num_labels fails.
        self.model = DetrForObjectDetection.from_pretrained(
            'facebook/detr-resnet-50',
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        self.processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
        self.map_metric = MeanAveragePrecision(class_metrics=True)

    def forward(self, pixel_values):
        """Run the wrapped DETR model on an already preprocessed image batch."""
        return self.model(pixel_values)

    @staticmethod
    def _image_hw(image):
        """Return ``(height, width)`` of a ``(..., H, W)`` tensor or a PIL image."""
        if torch.is_tensor(image):
            height, width = image.shape[-2:]
            return int(height), int(width)
        width, height = image.size  # PIL reports (width, height)
        return int(height), int(width)

    @staticmethod
    def _to_absolute_xyxy(boxes, height, width):
        """Convert normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2) pixels."""
        boxes = boxes.reshape(-1, 4).float()  # also handles images without any box
        cx, cy, w, h = boxes.unbind(-1)
        corners = torch.stack(
            [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1
        )
        return corners * corners.new_tensor([width, height, width, height])

    def _forward_batch(self, batch):
        """Preprocess a batch, run DETR with labels and return (images, labels, outputs)."""
        images, targets = batch
        images = list(images)
        # NOTE(review): if the images are already float tensors scaled to [0, 1]
        # (e.g. from transforms.ToTensor), the processor's own rescaling may be
        # applied on top of that -- confirm against the processor settings.
        encoding = self.processor(images, return_tensors="pt", padding=True).to(self.device)
        labels = [
            {
                "class_labels": torch.as_tensor(t['labels'], device=self.device),
                "boxes": torch.as_tensor(t['boxes'], device=self.device),
            }
            for t in targets
        ]
        outputs = self.model(**encoding, labels=labels)
        return images, labels, outputs

    def training_step(self, batch, batch_idx):
        """Compute and log the DETR training loss.

        mAP is intentionally not computed here: the metric needs post-processed
        detections, and evaluating it on every optimisation step is expensive.
        """
        images, _, outputs = self._forward_batch(batch)
        self.log(
            "train_loss",
            outputs.loss,
            prog_bar=True,
            on_step=True,
            on_epoch=True,
            batch_size=len(images),
        )
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate detections for the epoch-level mAP."""
        images, labels, outputs = self._forward_batch(batch)
        self.log(
            "val_loss",
            outputs.loss,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            batch_size=len(images),
        )

        sizes = [self._image_hw(image) for image in images]
        # threshold=0.0 keeps every query so the metric can sweep over all scores.
        preds = self.processor.post_process_object_detection(
            outputs, threshold=0.0, target_sizes=sizes
        )
        # The metric compares boxes in one coordinate system, so ground truth is
        # brought to the same absolute xyxy pixels as the post-processed predictions.
        ground_truth = [
            {
                "boxes": self._to_absolute_xyxy(label["boxes"], height, width),
                "labels": label["class_labels"],
            }
            for label, (height, width) in zip(labels, sizes)
        ]
        self.map_metric.update(preds, ground_truth)

    def on_validation_epoch_end(self):
        """Compute mAP once per validation epoch, log the scalar entries and reset."""
        results = self.map_metric.compute()
        # Only single-element tensors can be logged; 'classes' holds class ids,
        # and per-class entries may be vectors, so those are left out.
        scalars = {
            name: value.float()
            for name, value in results.items()
            if name != "classes" and torch.is_tensor(value) and value.numel() == 1
        }
        self.log_dict(scalars, prog_bar=True)
        self.map_metric.reset()

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at a learning rate of 1e-4."""
        return torch.optim.AdamW(self.parameters(), lr=1e-4)


## 🚂 Train the Model
Set up the PyTorch Lightning Trainer and start training.

In [None]:
# Example usage (replace with your actual dataloaders and num_classes).
# NOTE(review): `train_dataloader` and `val_dataloader` are not defined in the
# cells above (the data cell only creates `dataloader`); define them before running.
model = DETRLightningModule(num_classes=2)
trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
)
trainer.fit(model, train_dataloader, val_dataloader)


## 📊 Visualize Metrics and Predictions
Plot or print the metrics after each epoch, and visualize sample predictions.

In [None]:
# 📊 Plot mAP, Precision, and Recall per epoch from Lightning logs

import matplotlib.pyplot as plt

def plot_metrics(trainer):
    """Plot logged detection metrics against the epoch number.

    ``trainer.callback_metrics`` only holds the most recent value of each
    logged metric, so per-epoch history is read from the ``metrics.csv`` file
    in ``trainer.logger.log_dir`` when one exists (written by Lightning's CSV
    logger). Otherwise the latest values are plotted as single points.

    Series for which nothing was logged are skipped.
    NOTE(review): torchmetrics' MeanAveragePrecision reports ``map*`` / ``mar*``
    entries; the "Precision" / "Recall" series only appear if metrics are
    logged under the names ``precision`` / ``recall``.

    Args:
        trainer: The ``pl.Trainer`` used for training.
    """
    import csv  # local import: only this helper needs it

    # Legend label -> name under which the metric is logged.
    series_keys = {
        "mAP@0.5": "map_50",
        "mAP@0.5:0.95": "map",
        "Precision": "precision",
        "Recall": "recall",
    }
    history = {label: {} for label in series_keys}  # label -> {epoch: value}

    log_dir = getattr(getattr(trainer, "logger", None), "log_dir", None)
    csv_path = os.path.join(log_dir, "metrics.csv") if log_dir else None

    if csv_path and os.path.isfile(csv_path):
        with open(csv_path, newline="") as handle:
            for row in csv.DictReader(handle):
                if not row.get("epoch"):
                    continue
                epoch = int(float(row["epoch"])) + 1  # logged epochs are 0-based
                for label, key in series_keys.items():
                    # Metrics logged with on_step and on_epoch get an "_epoch" suffix.
                    raw = row.get(key) or row.get(f"{key}_epoch")
                    if raw:
                        history[label][epoch] = float(raw)
    else:
        latest = trainer.callback_metrics
        for label, key in series_keys.items():
            value = latest.get(key, latest.get(f"{key}_epoch"))
            if value is not None:
                history[label][max(trainer.current_epoch, 1)] = float(value)

    if not any(history.values()):
        print("No detection metrics have been logged yet; nothing to plot.")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    for label, points in history.items():
        if not points:
            continue
        epochs = sorted(points)
        # Markers keep single-point series visible.
        ax.plot(epochs, [points[e] for e in epochs], marker="o", label=label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Metric Value")
    ax.set_title("Detection Metrics per Epoch")
    ax.legend()
    ax.grid()
    plt.show()

# Usage example (run after training):
# plot_metrics(trainer)


In [None]:
# 🖼️ Visualize sample predictions from the trained model

import torch
import matplotlib.pyplot as plt
import numpy as np

def _image_height_width(image):
    """Return ``(height, width)`` of a ``(..., H, W)`` tensor or a PIL image."""
    if torch.is_tensor(image):
        height, width = image.shape[-2:]
        return int(height), int(width)
    width, height = image.size  # PIL reports (width, height)
    return int(height), int(width)


def _as_displayable(image):
    """Return ``image`` in a layout ``imshow`` accepts (channels last for tensors)."""
    if not torch.is_tensor(image):
        return image
    array = image.detach().cpu()
    if array.ndim == 3:
        array = array.permute(1, 2, 0)  # CHW -> HWC
        if array.shape[-1] == 1:
            array = array.squeeze(-1)  # single-channel images are shown as 2-D
    return array.numpy()


def plot_predictions(model, processor, dataloader, device, class_names, num_images=4, score_threshold=0.5):
    """Draw the model's detections on images taken from ``dataloader``.

    Args:
        model: Detection model called as ``model(pixel_values=...)`` whose
            output is accepted by ``processor.post_process_object_detection``.
        processor: Image processor used for preprocessing and post-processing.
        dataloader: Iterable yielding ``(images, targets)`` batches; images may
            be PIL images or CHW tensors. Targets are not used.
        device: Device on which inference is run; the model is moved there.
        class_names: Sequence or mapping from label id to display name.
        num_images: Maximum number of images to display.
        score_threshold: Minimum score for a detection to be drawn.
    """
    model.to(device)  # inputs are moved to `device`, so the model must live there too
    model.eval()
    images_shown = 0

    for images, _targets in dataloader:
        images = list(images)
        pixel_values = processor(images, return_tensors="pt", padding=True).pixel_values.to(device)
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)

        # Scale the normalized predictions back to each image's own (height, width).
        results = processor.post_process_object_detection(
            outputs,
            target_sizes=[_image_height_width(img) for img in images],
            threshold=score_threshold,
        )

        for image, result in zip(images, results):
            if images_shown >= num_images:
                return

            fig, ax = plt.subplots(figsize=(8, 6))
            ax.imshow(_as_displayable(image))
            boxes = result["boxes"].cpu().numpy()
            scores = result["scores"].cpu().numpy()
            labels = result["labels"].cpu().numpy()

            for box, score, label in zip(boxes, scores, labels):
                xmin, ymin, xmax, ymax = box
                try:
                    name = class_names[int(label)]
                except (IndexError, KeyError):
                    # Do not abort the whole plot on an unexpected label id.
                    name = f"class {int(label)}"
                ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color='red', linewidth=2))
                ax.text(xmin, ymin, f'{name}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10, color='black')
            ax.axis('off')
            plt.show()
            images_shown += 1

        if images_shown >= num_images:
            return

# Usage example (after training):
# class_names = ["background", "smoke"]  # adjust as needed
# plot_predictions(model.model, model.processor, val_dataloader, device="cuda" if torch.cuda.is_available() else "cpu", class_names=class_names)