## 📦 Setup and Install

In [None]:
!pip install -q transformers torchvision pycocotools opencv-python


## 📚 Imports

In [None]:
import os
import json
from PIL import Image
import torch
from torchvision.datasets import CocoDetection
from torchvision import transforms
from torch.utils.data import DataLoader
from transformers import DetrImageProcessor, DetrForObjectDetection
import matplotlib.pyplot as plt


## 🔁 Convert YOLO to COCO format

In [None]:
def yolo_to_coco(images_dir, labels_dir, output_json, category_name="smoke",
                 image_extensions=(".jpg",)):
    """Convert a YOLO-format detection dataset into one COCO annotation file.

    Every file in ``images_dir`` whose extension is in ``image_extensions``
    (matched case-insensitively) becomes a COCO image entry, in sorted filename
    order, with ids 0, 1, 2, ... Its label file is looked up in ``labels_dir``
    under the same stem with a ``.txt`` extension. Each non-blank label line
    ``class x_center y_center width height`` (coordinates normalized by the
    image size) becomes one COCO annotation with an absolute-pixel
    ``[x, y, w, h]`` box. Images without a label file are kept with no
    annotations.

    All boxes are assigned to a single category (id 1, named ``category_name``);
    the YOLO class id in column 0 is ignored.

    Args:
        images_dir: Directory containing the images.
        labels_dir: Directory containing the YOLO ``.txt`` label files.
        output_json: Path of the COCO JSON file to write.
        category_name: Name of the single COCO category.
        image_extensions: Extension or tuple of extensions selecting images.

    Raises:
        ValueError: If a non-blank label line does not have exactly 5 fields.
    """
    if isinstance(image_extensions, str):
        image_extensions = (image_extensions,)
    extensions = tuple(ext.lower() for ext in image_extensions)

    images = []
    annotations = []
    categories = [{"id": 1, "name": category_name}]
    # COCO annotation ids start at 1: pycocotools' COCOeval stores matched
    # ground-truth ids in an array where 0 means "unmatched".
    ann_id = 1

    image_names = [
        name for name in sorted(os.listdir(images_dir))
        if name.lower().endswith(extensions)
    ]
    for image_id, filename in enumerate(image_names):
        image_path = os.path.join(images_dir, filename)
        # splitext only swaps the final extension, unlike str.replace, which
        # would also rewrite ".jpg" appearing elsewhere in the name.
        stem = os.path.splitext(filename)[0]
        label_path = os.path.join(labels_dir, stem + ".txt")

        # Close the file handle once the size has been read.
        with Image.open(image_path) as img:
            width, height = img.size

        images.append({
            "id": image_id,
            "file_name": filename,
            "width": width,
            "height": height,
        })

        if not os.path.exists(label_path):
            continue

        with open(label_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing newline-only lines
                if len(fields) != 5:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected 5 values "
                        f"(class x_center y_center width height), got {len(fields)}"
                    )
                _cls_id, x_center, y_center, w, h = map(float, fields)
                # YOLO centre/size (normalized) -> COCO top-left/size (pixels).
                x = (x_center - w / 2) * width
                y = (y_center - h / 2) * height
                w *= width
                h *= height

                annotations.append({
                    "id": ann_id,
                    "image_id": image_id,
                    "category_id": 1,
                    "bbox": [x, y, w, h],
                    "area": w * h,
                    "iscrowd": 0,
                })
                ann_id += 1

    coco_dict = {
        "images": images,
        "annotations": annotations,
        "categories": categories,
    }

    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(coco_dict, f, indent=2)

    print(f"COCO annotations saved to {output_json}")


## ⚙️ Convert Your Dataset

In [None]:
# Dataset locations for the Kaggle runtime. When running elsewhere, point these at
# your own copy of the YOLO dataset (image and label folders) and a writable output path.
images_dir = "/kaggle/input/pyro_sdis_yolo/images"
labels_dir = "/kaggle/input/pyro_sdis_yolo/labels"
output_json = "/kaggle/working/train_annotations.json"

# Write the COCO-format annotation file consumed by the dataset cell below.
yolo_to_coco(
    images_dir=images_dir,
    labels_dir=labels_dir,
    output_json=output_json,
)


## 📂 Load COCO Dataset

In [None]:
class CocoDetectionForDetr(CocoDetection):
    """CocoDetection variant whose targets match DetrImageProcessor's input format.

    Plain CocoDetection yields ``(image, annotations)`` with ``annotations`` a
    bare list of COCO dicts. DetrImageProcessor instead expects
    ``{"image_id": ..., "annotations": [...]}``. The id is taken from
    ``self.ids`` so that images without any boxes still carry one.
    """

    def __getitem__(self, index):
        image, annotations = super().__getitem__(index)
        return image, {"image_id": self.ids[index], "annotations": annotations}


def collate_detr_batch(batch):
    """Split a list of ``(image, target)`` pairs into a list of images and a list of targets."""
    images, targets = zip(*batch)
    return list(images), list(targets)


# No torchvision transform on purpose: DetrImageProcessor resizes, rescales and
# normalizes the PIL images itself and adjusts the COCO boxes to match. A Resize
# here would change the pixels but not the boxes, and ToTensor would make the
# processor divide pixel values by 255 a second time.
dataset = CocoDetectionForDetr(
    root=images_dir,
    annFile=output_json,
)

data_loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_detr_batch)


## 🧠 Load Pretrained DETR

In [None]:
# The "no_timm" revision of this checkpoint uses the transformers-native ResNet
# backbone, so loading does not require the optional `timm` package (the install
# cell does not install it).
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")

# yolo_to_coco assigns every box category_id 1. The pretrained checkpoint's head
# is the COCO one (where id 1 means "person"), so keeping it would train smoke as
# "person" while still reporting every other COCO object. Replace it with a fresh
# head covering ids {0, 1}; id 0 is unused. ignore_mismatched_sizes lets
# from_pretrained re-initialise the resized classification layer.
model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50",
    revision="no_timm",
    num_labels=2,
    id2label={0: "N/A", 1: "smoke"},
    label2id={"N/A": 0, "smoke": 1},
    ignore_mismatched_sizes=True,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


## 🏋️ Train DETR

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)


def _as_detr_target(target, fallback_image_id):
    """Return ``target`` in the ``{"image_id", "annotations"}`` form DetrImageProcessor expects.

    Dict targets are returned unchanged. A bare list of COCO annotation dicts
    (what plain torchvision CocoDetection yields) is wrapped. Its image id is
    read from the first annotation, and ``fallback_image_id`` is used when the
    image has no boxes.

    NOTE(review): the fallback id is only a placeholder. It is assumed to be
    informational (carried into the processed labels) and not used by the loss;
    confirm if ids are consumed downstream.
    """
    if isinstance(target, dict):
        return target
    annotations = list(target)
    image_id = annotations[0]["image_id"] if annotations else fallback_image_id
    return {"image_id": image_id, "annotations": annotations}


model.train()
for epoch in range(3):  # Train for 3 epochs (adjust as needed)
    epoch_loss = 0.0
    num_batches = 0
    for images, targets in data_loader:
        detr_targets = [_as_detr_target(t, i) for i, t in enumerate(targets)]
        # DetrImageProcessor prepares pixel_values/pixel_mask and per-image labels.
        # It pads batches by default, so no extra padding kwarg is needed.
        encoding = processor(images=list(images), annotations=detr_targets, return_tensors="pt")

        # encoding["labels"] is a list of per-image dicts of tensors, so it cannot
        # be moved with a single .to(); move every tensor individually.
        pixel_values = encoding["pixel_values"].to(device)
        pixel_mask = encoding["pixel_mask"].to(device)
        labels = [{k: v.to(device) for k, v in label.items()} for label in encoding["labels"]]

        outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the epoch mean instead of the last batch's loss, which is noisy and
    # undefined when the loader yields no batches.
    mean_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch + 1}: Loss = {mean_loss}")

torch.save(model.state_dict(), "/kaggle/working/detr_smoke.pth")


## 🔍 Inference and Visualization

In [None]:
model.eval()

# Pick the first image file in sorted order. os.listdir order is arbitrary, and
# the directory may contain non-image entries.
image_names = sorted(
    name for name in os.listdir(images_dir)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
)
if not image_names:
    raise FileNotFoundError(f"No .jpg/.jpeg/.png images found in {images_dir}")
image_path = os.path.join(images_dir, image_names[0])
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

inputs = processor(images=image, return_tensors="pt").to(device)
# Inference only: skip building the autograd graph.
with torch.no_grad():
    outputs = model(**inputs)
# PIL reports (width, height); the post-processor takes (height, width).
target_sizes = torch.tensor([image.size[::-1]], device=device)
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.8)[0]

# Plot results. Convert to plain Python floats first, because matplotlib
# cannot consume CUDA tensors.
plt.imshow(image)
ax = plt.gca()
for score, box in zip(results["scores"].tolist(), results["boxes"].tolist()):
    xmin, ymin, xmax, ymax = box
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                               fill=False, color='red', linewidth=2))
    ax.text(xmin, ymin, f"{score:.2f}", fontsize=12, color='white', bbox=dict(facecolor='red', alpha=0.5))
plt.axis('off')
plt.show()
