# DETR Training on Roboflow YOLO-format Wildfire Smoke Dataset (Colab Ready)

This notebook is designed for Google Colab. It trains DETR on a wildfire smoke dataset exported from Roboflow in YOLO format, with `train`, `valid`, and `test` splits. The dataset layout and class names are read automatically from `data.yaml`.


## 1. Mount Google Drive

In [ ]:
# Mount Google Drive at /content/drive; the dataset paths used by later cells
# live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

## 2. Install Dependencies

In [ ]:
# Install the model (transformers, torchvision), training (pytorch-lightning),
# metric (torchmetrics, pycocotools), YAML (ruamel.yaml) and image (opencv-python)
# packages; -q keeps pip's output short.
!pip install -q transformers torchvision pytorch-lightning torchmetrics pycocotools ruamel.yaml opencv-python

## 3. Parse Dataset Structure and Class Names

In [ ]:
import os
from ruamel.yaml import YAML

# Root of the Roboflow export: contains data.yaml plus one folder per split,
# each with `images/` and `labels/` sub-folders.
data_root = '/content/drive/MyDrive/smoke_dataset/roboflow'
yaml_path = os.path.join(data_root, 'data.yaml')
yaml = YAML()
with open(yaml_path, 'r', encoding='utf-8') as f:
    data_yaml = yaml.load(f)

# Get class names as a plain list of strings, ordered by class index.
# `names` may be a list (['smoke']) or an index->name mapping ({0: 'smoke'}).
# `nc` is only the *number* of classes, so when `names` is missing we
# synthesise placeholder names instead of passing an int downstream (which
# would break `len(class_list)` and `enumerate(class_list)` in later cells).
names = data_yaml.get('names')
if isinstance(names, dict):
    class_list = [str(names[key]) for key in sorted(names)]
elif names:
    class_list = [str(name) for name in names]
else:
    class_list = [f'class_{i}' for i in range(int(data_yaml['nc']))]
print('Class names:', class_list)

# Peek at the first few files of every split as a sanity check of the layout.
splits = ['train', 'valid', 'test']
for split in splits:
    print(f'{split} images:', os.listdir(os.path.join(data_root, split, 'images'))[:3])
    print(f'{split} labels:', os.listdir(os.path.join(data_root, split, 'labels'))[:3])


## 4. Convert YOLO to COCO for Each Split

In [ ]:
import glob
import json
from PIL import Image
from tqdm import tqdm

def yolo_to_coco(img_dir, label_dir, class_list, output_json):
    """Convert one YOLO-format split into a COCO detection annotation file.

    Args:
        img_dir: Directory holding the split's image files.
        label_dir: Directory holding one ``<image stem>.txt`` label file per
            image. Images without a label file are kept with no annotations.
        class_list: Class names ordered by YOLO class index.
        output_json: Path the COCO JSON is written to.

    Returns:
        The COCO dictionary (``images``, ``annotations``, ``categories``) that
        was written. Image, annotation and category ids are all 1-based, so
        ``category_id`` is the YOLO class index plus one.

    Label lines with 5 fields are read as ``class cx cy w h``. Lines with an
    odd number of fields >= 7 are read as ``class x1 y1 x2 y2 ...`` polygons
    and reduced to their bounding box. All coordinates are normalised to
    [0, 1]. Other lines are ignored.
    """
    # Only real images: a bare '*' glob would also hand stray files
    # (.DS_Store, cache files, ...) to PIL and abort the conversion.
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff'}
    images = []
    annotations = []
    img_files = sorted(
        path for path in glob.glob(os.path.join(img_dir, '*'))
        if os.path.splitext(path)[1].lower() in image_exts
    )
    for img_id, img_path in enumerate(tqdm(img_files), 1):
        # Context manager closes the file handle; only the size is needed.
        with Image.open(img_path) as img:
            width, height = img.size
        filename = os.path.basename(img_path)
        images.append({
            'id': img_id,
            'file_name': filename,
            'width': width,
            'height': height
        })
        label_path = os.path.join(label_dir, os.path.splitext(filename)[0] + '.txt')
        if not os.path.exists(label_path):
            continue
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) == 5:
                    class_id, x_center, y_center, w, h = map(float, parts)
                    x_min, x_max = x_center - w / 2, x_center + w / 2
                    y_min, y_max = y_center - h / 2, y_center + h / 2
                elif len(parts) >= 7 and len(parts) % 2 == 1:
                    # Polygon label: keep it as its enclosing box rather than
                    # silently losing the object.
                    class_id = float(parts[0])
                    coords = [float(value) for value in parts[1:]]
                    xs, ys = coords[0::2], coords[1::2]
                    x_min, x_max = min(xs), max(xs)
                    y_min, y_max = min(ys), max(ys)
                else:
                    continue
                # Clamp to the image so rounding in the export cannot produce
                # negative origins or boxes that extend past the border.
                x0 = min(max(x_min, 0.0), 1.0) * width
                y0 = min(max(y_min, 0.0), 1.0) * height
                x1 = min(max(x_max, 0.0), 1.0) * width
                y1 = min(max(y_max, 0.0), 1.0) * height
                w_box = x1 - x0
                h_box = y1 - y0
                if w_box <= 0 or h_box <= 0:
                    continue  # degenerate box carries no training signal
                annotations.append({
                    'id': len(annotations) + 1,
                    'image_id': img_id,
                    'category_id': int(class_id) + 1,
                    'bbox': [x0, y0, w_box, h_box],
                    'area': w_box * h_box,
                    'iscrowd': 0
                })
    categories = [{"id": i + 1, "name": name} for i, name in enumerate(class_list)]
    coco_dict = {'images': images, 'annotations': annotations, 'categories': categories}
    with open(output_json, 'w') as f:
        json.dump(coco_dict, f)
    print(f'COCO annotation saved to {output_json}')
    return coco_dict

# Convert every split and remember where its COCO JSON was written.
# The JSON files are placed directly in data_root as '<split>_coco.json'.
coco_jsons = {}
for split in splits:
    img_dir = os.path.join(data_root, split, 'images')
    label_dir = os.path.join(data_root, split, 'labels')
    output_json = os.path.join(data_root, f'{split}_coco.json')
    coco_jsons[split] = output_json
    yolo_to_coco(img_dir, label_dir, class_list, output_json)


## 5. Prepare DataLoaders for Each Split

In [ ]:
from torchvision.datasets import CocoDetection
from torchvision import transforms
from torch.utils.data import DataLoader

# ToTensor is the only image transform applied to each loaded image.
transform = transforms.Compose([transforms.ToTensor()])

# One CocoDetection dataset and one DataLoader per split, keyed by split name.
datasets = {}
dataloaders = {}
for split in splits:
    img_dir = os.path.join(data_root, split, 'images')
    coco_json = coco_jsons[split]
    datasets[split] = CocoDetection(img_dir, coco_json, transform=transform)
    # Only the training split is shuffled. The collate function regroups a list
    # of (image, target) samples into an (images, targets) pair of tuples
    # instead of stacking them into a single tensor.
    dataloaders[split] = DataLoader(datasets[split], batch_size=4, shuffle=(split=='train'), num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

print('Train set size:', len(datasets['train']))
print('Valid set size:', len(datasets['valid']))
print('Test set size:', len(datasets['test']))


## 6. Define DETR Lightning Module (with metrics logging)

In [ ]:
from transformers import DetrForObjectDetection, DetrConfig, DetrImageProcessor
import torch
import pytorch_lightning as pl
from torchmetrics.detection.mean_ap import MeanAveragePrecision

class DETRLightningModule(pl.LightningModule):
    """Fine-tune ``facebook/detr-resnet-50`` and report validation mAP.

    Batches are ``(images, targets)`` pairs of equal-length sequences, as
    produced by a ``tuple(zip(*samples))`` collate function. Each target is
    either

    * a list of COCO annotation dicts (``bbox`` as ``[x, y, w, h]`` in pixels
      plus ``category_id``), which is what ``CocoDetection`` yields, or
    * a dict already in DETR format: ``boxes`` as normalised
      ``(cx, cy, w, h)`` plus ``class_labels`` (or ``labels``), 0-based.

    The training loss is logged as ``train_loss``. Validation logs
    ``val_loss`` and, once per validation epoch, the scalar mAP/mAR values.
    """

    # COCO category ids written by this notebook's YOLO->COCO conversion are
    # 1-based, while DETR expects class indices in [0, num_labels - 1]
    # (index ``num_labels`` is its "no object" class).
    category_id_offset = 1

    def __init__(self, num_classes):
        """Load the pretrained DETR weights with a fresh ``num_classes`` head."""
        super().__init__()
        config = DetrConfig.from_pretrained('facebook/detr-resnet-50')
        config.num_labels = num_classes
        self.model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50', config=config, ignore_mismatched_sizes=True)
        self.processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
        self.map_metric = MeanAveragePrecision(class_metrics=True)

    def forward(self, pixel_values, pixel_mask=None):
        """Run DETR on preprocessed pixels (optionally with their padding mask)."""
        return self.model(pixel_values, pixel_mask=pixel_mask)

    @staticmethod
    def _image_sizes(images):
        """Return ``[(height, width), ...]`` for tensor (C, H, W) or PIL images."""
        sizes = []
        for img in images:
            if isinstance(img, torch.Tensor):
                height, width = img.shape[-2:]
            else:
                width, height = img.size
            sizes.append((int(height), int(width)))
        return sizes

    def _detr_target(self, target, size):
        """Build DETR's label dict (normalised cxcywh boxes, 0-based classes)."""
        height, width = size
        if isinstance(target, dict):
            # Assumed to be in DETR format already; only dtype/device are fixed.
            labels_key = "class_labels" if "class_labels" in target else "labels"
            boxes = torch.as_tensor(target["boxes"], dtype=torch.float32, device=self.device).reshape(-1, 4)
            class_labels = torch.as_tensor(target[labels_key], dtype=torch.long, device=self.device).reshape(-1)
        else:
            # reshape keeps the (0, 4) shape for images without annotations.
            xywh = torch.tensor([ann["bbox"] for ann in target], dtype=torch.float32, device=self.device).reshape(-1, 4)
            centres = xywh[:, :2] + xywh[:, 2:] / 2
            scale = torch.tensor([width, height, width, height], dtype=torch.float32, device=self.device)
            boxes = torch.cat([centres, xywh[:, 2:]], dim=1) / scale
            class_labels = torch.tensor(
                [ann["category_id"] - self.category_id_offset for ann in target],
                dtype=torch.long, device=self.device,
            )
        return {"class_labels": class_labels, "boxes": boxes}

    @staticmethod
    def _to_pixel_xyxy(boxes, size):
        """Convert normalised cxcywh boxes to absolute ``(x0, y0, x1, y1)``."""
        height, width = size
        cx, cy, w, h = boxes.unbind(-1)
        xyxy = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)
        return xyxy * boxes.new_tensor([width, height, width, height])

    @staticmethod
    def _scalar_metrics(metrics):
        """Reduce a torchmetrics result dict to loggable Python scalars.

        Single-element tensors are logged under their own name. Multi-element
        tensors are averaged over entries that are not the ``-1`` "undefined"
        marker and logged as ``<name>_mean``. ``classes`` is an integer id
        list, not a metric, so it is skipped.
        """
        scalars = {}
        for name, value in metrics.items():
            if name == "classes":
                continue
            if isinstance(value, torch.Tensor):
                if value.numel() == 1:
                    scalars[name] = value.item()
                elif value.numel() > 1:
                    valid = value[value != -1]
                    scalars[name + '_mean'] = valid.float().mean().item() if valid.numel() > 0 else float('nan')
            elif isinstance(value, (float, int)):
                scalars[name] = value
        return scalars

    def _run_batch(self, batch):
        """Preprocess a batch, run DETR with labels, return (outputs, labels, sizes)."""
        images, targets = batch
        sizes = self._image_sizes(images)
        processor_kwargs = {}
        first = images[0]
        if isinstance(first, torch.Tensor) and first.is_floating_point():
            # Float tensors (e.g. from ToTensor) are assumed to be in [0, 1]
            # already; the processor's default 1/255 rescale would otherwise
            # be applied a second time.
            processor_kwargs["do_rescale"] = False
        # The processor works on host data; Lightning may already have moved
        # the image tensors to the accelerator.
        cpu_images = [img.detach().cpu() if isinstance(img, torch.Tensor) else img for img in images]
        encoding = self.processor(cpu_images, return_tensors="pt", **processor_kwargs).to(self.device)
        labels = [self._detr_target(t, s) for t, s in zip(targets, sizes)]
        outputs = self.model(**encoding, labels=labels)
        return outputs, labels, sizes

    def training_step(self, batch, batch_idx):
        """Return (and log) the DETR loss for one training batch."""
        outputs, _, _ = self._run_batch(batch)
        loss = outputs.loss
        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=len(batch[0]))
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate predictions for mAP."""
        outputs, labels, sizes = self._run_batch(batch)
        loss = outputs.loss

        target_sizes = torch.tensor(sizes, device=self.device)
        results = self.processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.5
        )
        predictions = [
            {"boxes": r["boxes"], "scores": r["scores"], "labels": r["labels"]}
            for r in results
        ]
        # The metric compares pixel-space xyxy boxes, matching the format
        # produced by post-processing.
        metric_targets = [
            {"boxes": self._to_pixel_xyxy(lbl["boxes"], size), "labels": lbl["class_labels"]}
            for lbl, size in zip(labels, sizes)
        ]
        self.map_metric.update(predictions, metric_targets)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=len(batch[0]))
        return loss

    def on_validation_epoch_end(self):
        """Compute mAP once per validation epoch, log it, and reset the state."""
        # Computing per step is very slow, and without a reset every epoch's
        # score would include predictions from earlier epochs.
        metrics = self.map_metric.compute()
        self.log_dict(self._scalar_metrics(metrics), prog_bar=True)
        self.map_metric.reset()

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at a learning rate of 1e-4."""
        return torch.optim.AdamW(self.parameters(), lr=1e-4)


## 7. Train the Model

In [ ]:
# One output class per entry in class_list.
model = DETRLightningModule(num_classes=len(class_list))
# Train for 10 epochs, on the GPU when CUDA is available and on the CPU otherwise.
trainer = pl.Trainer(max_epochs=10, accelerator='gpu' if torch.cuda.is_available() else 'cpu')
# The 'valid' split is used as the validation dataloader during fitting.
trainer.fit(model, dataloaders['train'], dataloaders['valid'])


## 8. Visualize Metrics and Predictions

In [ ]:
import matplotlib.pyplot as plt
import torch

def plot_predictions(model, processor, dataloader, device, class_names, num_images=4, score_threshold=0.5):
    """Draw the model's detections on the first ``num_images`` images.

    Args:
        model: Detection model called with ``pixel_values`` (and
            ``pixel_mask`` when a batch mixes image sizes), e.g. the
            ``DetrForObjectDetection`` held by the Lightning module.
        processor: Matching image processor, used for preprocessing and for
            ``post_process_object_detection``.
        dataloader: Iterable of ``(images, targets)`` batches, where images
            are ``(C, H, W)`` tensors. Targets are not used.
        device: Device the model and its inputs are placed on.
        class_names: Class names indexed by predicted label.
        num_images: Maximum number of figures to show. Nothing is drawn when
            this is zero or negative.
        score_threshold: Minimum confidence for a detection to be drawn.

    Returns:
        None. One matplotlib figure is shown per image.
    """
    if num_images <= 0:
        return
    # Inputs are moved to `device` below; the model has to live there as well.
    model.to(device)
    model.eval()
    images_shown = 0
    for images, _targets in dataloader:
        processor_kwargs = {}
        if images[0].is_floating_point():
            # Float tensors (e.g. from ToTensor) are assumed to be in [0, 1]
            # already; skip the processor's default 1/255 rescale.
            processor_kwargs["do_rescale"] = False
        encoding = processor(list(images), return_tensors="pt", **processor_kwargs)
        sizes = [tuple(int(dim) for dim in img.shape[1:]) for img in images]
        model_inputs = {"pixel_values": encoding["pixel_values"].to(device)}
        if len(set(sizes)) > 1 and "pixel_mask" in encoding:
            # Mixed sizes mean the batch was padded. Without the mask the
            # padding counts as image content and boxes come out mis-scaled.
            model_inputs["pixel_mask"] = encoding["pixel_mask"].to(device)
        with torch.no_grad():
            outputs = model(**model_inputs)
        results = processor.post_process_object_detection(outputs, target_sizes=sizes, threshold=score_threshold)
        for image, result in zip(images, results):
            plt.figure(figsize=(8, 6))
            img_np = image.permute(1, 2, 0).cpu().numpy()
            plt.imshow(img_np)
            ax = plt.gca()
            boxes = result["boxes"].cpu().numpy()
            scores = result["scores"].cpu().numpy()
            labels = result["labels"].cpu().numpy()
            for box, score, label in zip(boxes, scores, labels):
                xmin, ymin, xmax, ymax = box
                ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color='red', linewidth=2))
                ax.text(xmin, ymin, f'{class_names[label]}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10, color='black')
            plt.axis('off')
            plt.show()
            images_shown += 1
            if images_shown >= num_images:
                return
# Usage example:
# plot_predictions(model.model, model.processor, dataloaders['test'], device="cuda" if torch.cuda.is_available() else "cpu", class_names=class_list)
