In [None]:
# Extract the annotation archive; the dataset code below reads
# "Annotations/Horizontal Bounding Boxes" and "Annotations/Oriented Bounding Boxes".
!unzip Annotations.zip

In [None]:
# Extract the image-set archive; the split files read below are
# Main/train.txt, Main/val.txt and Main/test.txt.
!unzip ImageSets.zip

In [None]:
# torchmetrics provides MeanAveragePrecision, used in the evaluation cells.
# NOTE(review): unpinned install — consider pinning a version for reproducibility.
!pip install torchmetrics

In [None]:
# Mount Google Drive: the JPEG image folders are read from drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import DetrForObjectDetection, DetrConfig
from PIL import Image
import xml.etree.ElementTree as ET

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained DOFA ViT-Base from torch.hub and move it to the chosen device.
dofa_model = torch.hub.load('zhu-xlab/DOFA', 'vit_base_dofa', pretrained=True)
dofa_model = dofa_model.to(device)

# Freeze every DOFA weight: it is used as a fixed feature extractor, never fine-tuned.
dofa_model.requires_grad_(False)

class DOFAFeatureExtractor(nn.Module):
    """Wrap a DOFA ViT and return its per-patch token features.

    The token pipeline follows the pretrained model's own feature path: dynamic
    patch embedding, positional embedding, [CLS] token, transformer blocks and
    finally the post-block norm. The [CLS] token is dropped before returning, so
    the output only contains the patch-grid tokens.

    Args:
        dofa_model: DOFA model exposing ``patch_embed``, ``blocks`` and ``fc_norm``;
            ``pos_embed`` and ``cls_token`` are used when present.
    """

    def __init__(self, dofa_model):
        super().__init__()
        self.dofa = dofa_model

    def forward(self, pixel_values, wavelengths=None):
        """Return patch-token features of shape ``[B, N, D]`` (one row per image patch).

        Args:
            pixel_values: image batch passed to ``dofa.patch_embed``.
            wavelengths: per-channel wavelengths passed as ``wvs`` to the patch embedding.
        """
        patch_embed_output = self.dofa.patch_embed(pixel_values, wvs=wavelengths)
        # The dynamic patch embedding may return (features, extra); keep the features.
        if isinstance(patch_embed_output, tuple):
            tokens = patch_embed_output[0]
        else:
            tokens = patch_embed_output
        num_patches = tokens.shape[1]

        # The pretrained blocks expect position-aware tokens; without pos_embed the
        # transformer sees an unordered bag of patches (fatal for detection).
        pos_embed = getattr(self.dofa, "pos_embed", None)
        if pos_embed is not None:
            tokens = tokens + pos_embed[:, 1:, :]

        # Prepend the [CLS] token so attention matches the pretraining setup.
        cls_token = getattr(self.dofa, "cls_token", None)
        if cls_token is not None:
            if pos_embed is not None:
                cls_token = cls_token + pos_embed[:, :1, :]
            tokens = torch.cat((cls_token.expand(tokens.shape[0], -1, -1), tokens), dim=1)

        for block in self.dofa.blocks:
            tokens = block(tokens)

        # Keep only the patch tokens, then apply the norm *after* the blocks
        # (it is the final norm, not an input normalization).
        tokens = tokens[:, -num_patches:, :]
        return self.dofa.fc_norm(tokens)

# Wrap the frozen DOFA model so it returns patch-token features, on the selected device.
feature_extractor = DOFAFeatureExtractor(dofa_model).to(device)

class DIORDataset(Dataset):
    """DIOR object-detection dataset yielding images and normalized box targets.

    Each item is ``(image, target)`` where ``target`` contains:

    * ``"boxes"``: float32 tensor ``[M, 4]`` of (xmin, ymin, xmax, ymax) divided by
      the original image width/height and clamped to [0, 1];
    * ``"class_labels"``: int64 tensor ``[M]`` of indices into ``self.classes``.

    ``M`` may be 0 for an image without annotated objects.
    """

    # (x_tag, y_tag) pairs of the four corners of an oriented box in the OBB XML files.
    _OBB_CORNERS = (
        ("x_left_top", "y_left_top"),
        ("x_right_top", "y_right_top"),
        ("x_right_bottom", "y_right_bottom"),
        ("x_left_bottom", "y_left_bottom"),
    )

    def __init__(self, images_dir, annotation_dir_hbb, annotation_dir_obb, txt_path, transform=None):
        """
        Initializes the DIORDataset.

        Args:
            images_dir (str): Path to the directory containing the images.
            annotation_dir_hbb (str): Path to the directory containing horizontal bounding box annotations.
            annotation_dir_obb (str): Path to the directory containing oriented bounding box annotations.
            txt_path (str): Path to the text file containing the list of image filenames for the split.
            transform (callable, optional): Optional transform to be applied to the images. Defaults to None.
        """
        self.images_dir = images_dir
        self.annotation_dir_hbb = annotation_dir_hbb
        self.annotation_dir_obb = annotation_dir_obb
        self.transform = transform
        self.img_names = self._load_image_names(txt_path)
        self.classes = ["airplane", "airport", "baseballfield", "basketballcourt", "bridge",
                        "chimney", "dam", "expresswayservicearea", "expresswaytollstation",
                        "golffield", "groundtrackfield", "harbor", "overpass", "ship",
                        "stadium", "storagetank", "tenniscourt", "trainstation", "vehicle", "windmill"]
        self.class_to_index = {label: i for i, label in enumerate(self.classes)}

    def _load_image_names(self, txt_path):
        """Load image ids (one per line) from ``txt_path``, ignoring blank lines."""
        with open(txt_path, 'r') as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _class_name(obj):
        """Normalize an XML ``<name>``: lower-case with '-' and spaces removed."""
        return obj.find('name').text.lower().replace('-', '').replace(' ', '')

    def _parse_hbb(self, path):
        """Return ``(boxes, names)`` from a horizontal-box XML file (pixel xyxy boxes)."""
        boxes, names = [], []
        for obj in ET.parse(path).getroot().findall('.//object'):
            bndbox = obj.find('bndbox')
            if bndbox is None:
                continue
            # float() also accepts coordinates written as e.g. "12.0".
            boxes.append([float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
            names.append(self._class_name(obj))
        return boxes, names

    def _parse_obb(self, path):
        """Return ``(boxes, names)`` from an oriented-box XML file.

        Each oriented box is reduced to the axis-aligned rectangle enclosing its corners.
        """
        boxes, names = [], []
        for obj in ET.parse(path).getroot().findall('.//object'):
            robndbox = obj.find('robndbox')
            if robndbox is None:
                continue
            xs = [float(robndbox.find(x_tag).text) for x_tag, _ in self._OBB_CORNERS]
            ys = [float(robndbox.find(y_tag).text) for _, y_tag in self._OBB_CORNERS]
            boxes.append([min(xs), min(ys), max(xs), max(ys)])
            names.append(self._class_name(obj))
        return boxes, names

    def _load_annotations(self, img_name):
        """Return ``(boxes, names)`` for ``img_name``, preferring HBB over OBB annotations.

        Returns two empty lists when neither annotation file exists.
        """
        hbb_path = os.path.join(self.annotation_dir_hbb, img_name + '.xml')
        if os.path.exists(hbb_path):
            return self._parse_hbb(hbb_path)
        obb_path = os.path.join(self.annotation_dir_obb, img_name + '.xml')
        if os.path.exists(obb_path):
            return self._parse_obb(obb_path)
        return [], []

    def __len__(self):
        """Returns the total number of items in the dataset."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """
        Retrieves an item from the dataset at the specified index.

        Args:
            idx (int): Index of the item to retrieve.

        Returns:
            tuple: ``(image, target)`` — the (optionally transformed) image and the target dict.

        Raises:
            FileNotFoundError: if the image file is missing.
        """
        img_name = self.img_names[idx]
        img_path = os.path.join(self.images_dir, img_name + '.jpg')
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")

        # Context manager closes the underlying file handle.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        w, h = image.size  # original size, needed before any resizing transform
        if self.transform:
            image = self.transform(image)

        box_list, names = self._load_annotations(img_name)

        # reshape keeps an image without objects as a well-formed [0, 4] tensor.
        boxes = torch.tensor(box_list, dtype=torch.float32).reshape(-1, 4)
        # Normalize boxes to [0, 1] relative to the original image size.
        boxes[:, [0, 2]] /= w
        boxes[:, [1, 3]] /= h
        # Clamp to guard against annotations that overshoot the image border.
        boxes = boxes.clamp(0, 1)
        labels = torch.tensor([self.class_to_index[name] for name in names], dtype=torch.int64)

        target = {"boxes": boxes, "class_labels": labels}
        return image, target

# Split files: one image id per line.
train_txt_path = "Main/train.txt"
val_txt_path = "Main/val.txt"
test_txt_path = "Main/test.txt"

# Train and val images share one folder; test images live in another.
train_images_dir = "drive/MyDrive/JPEGImages-trainval"
test_images_dir = "drive/MyDrive/JPEGImages-test"

annotation_dir_hbb = "Annotations/Horizontal Bounding Boxes"
annotation_dir_obb = "Annotations/Oriented Bounding Boxes"

# Resize to 224x224 and convert to a float tensor in [0, 1].
# NOTE(review): no mean/std normalization is applied — confirm this matches DOFA's pretraining.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])


def collate_detection_batch(batch):
    """Regroup ``[(image, target), ...]`` into ``((image, ...), (target, ...))``.

    Targets hold a different number of boxes per image, so they cannot be stacked.
    """
    return tuple(zip(*batch))


# One dataset per split, built in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    DIORDataset(images_dir, annotation_dir_hbb, annotation_dir_obb, split_path, transform=transform)
    for images_dir, split_path in (
        (train_images_dir, train_txt_path),
        (train_images_dir, val_txt_path),
        (test_images_dir, test_txt_path),
    )
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_detection_batch)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_detection_batch)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, collate_fn=collate_detection_batch)




#  FakeBackbone: replaces DETR's ResNet backbone with DOFA features.
class FakeBackbone(nn.Module):
    """Drop-in replacement for the convolutional backbone of HF DETR.

    HF ``DetrModel`` calls ``backbone(pixel_values, pixel_mask)`` and expects
    ``(features_list, object_queries_list)``. In that return value:

    * ``features_list[-1]`` is a ``(feature_map, mask)`` pair, where ``mask`` is True
      on valid (non-padded) positions;
    * ``object_queries_list[-1]`` is the ``[B, d_model, H, W]`` spatial position
      embedding of that feature map.

    Args:
        feature_extractor: module mapping ``(pixel_values, wavelengths)`` to patch
            tokens ``[B, N, feature_dim]`` laid out on a square grid.
        feature_dim: channel size of those tokens.
        d_model: DETR hidden size, i.e. the number of position-embedding channels.
        wavelengths: wavelength passed to DOFA for each input channel.
    """

    def __init__(self, feature_extractor, feature_dim, d_model=256, wavelengths=(0.665, 0.56, 0.49)):
        super().__init__()
        self.feature_extractor = feature_extractor
        # DETR's input_projection expects the 2048 channels of a ResNet-50 feature map.
        self.proj = nn.Linear(feature_dim, 2048)
        self.d_model = d_model
        self.wavelengths = tuple(wavelengths)

    def _sine_position_embedding(self, mask):
        """DETR sine position embedding ``[B, d_model, H, W]`` for a bool ``mask`` ``[B, H, W]``."""
        import math

        num_feats = self.d_model // 2
        valid = mask.to(torch.float32)
        y_embed = valid.cumsum(1)
        x_embed = valid.cumsum(2)
        scale = 2 * math.pi
        y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * scale
        x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * scale

        dim_t = torch.arange(num_feats, dtype=torch.float32, device=mask.device)
        dim_t = 10000.0 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_feats)

        pos_x = x_embed[..., None] / dim_t
        pos_y = y_embed[..., None] / dim_t
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

    def forward(self, pixel_values, *args, **kwargs):
        """Return ``([(feature_map, mask)], [position_embedding])`` for DETR.

        ``pixel_mask`` may be given positionally (as HF DETR does) or by keyword;
        when absent, every position is treated as valid.

        Raises:
            ValueError: if the extractor does not return a square grid of tokens.
        """
        pixel_mask = args[0] if args else kwargs.get("pixel_mask")
        batch_size = pixel_values.shape[0]
        device = pixel_values.device
        wavelengths = torch.tensor(self.wavelengths, dtype=torch.float32, device=device)

        features = self.feature_extractor(pixel_values, wavelengths)  # [B, N, D]
        features = self.proj(features)  # [B, N, 2048]

        num_tokens, channels = features.shape[1], features.shape[2]
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise ValueError(f"Expected a square grid of patch tokens, got {num_tokens} tokens")
        feature_map = features.transpose(1, 2).reshape(batch_size, channels, side, side)  # [B, 2048, H, W]

        # HF convention: True/1 marks a *valid* position. An all-False mask would make
        # DETR's encoder and cross-attention ignore every feature.
        if pixel_mask is None:
            mask = torch.ones((batch_size, side, side), dtype=torch.bool, device=device)
        else:
            mask = nn.functional.interpolate(pixel_mask[None].float(), size=(side, side)).to(torch.bool)[0]

        # Real positional embeddings: the pretrained DETR transformer relies on them.
        object_queries = self._sine_position_embedding(mask).to(feature_map.dtype)
        return [(feature_map, mask)], [object_queries]



#  DetrWithDOFA: builds DETR on top of the DOFA-based FakeBackbone.
class DetrWithDOFA(nn.Module):
    """DETR detector whose ResNet backbone is replaced by a frozen DOFA encoder.

    Box convention: ``targets[i]["boxes"]`` and the returned ``outputs.pred_boxes``
    are normalized (xmin, ymin, xmax, ymax), matching DIORDataset. DETR itself works
    in normalized (cx, cy, w, h), so boxes are converted on the way in and out.

    Args:
        feature_extractor: frozen DOFA feature extractor returning tokens ``[B, N, D]``.
        num_labels: number of object classes (20 for DIOR).
        pretrained_name: HF checkpoint providing the config and transformer weights.
    """

    # State-dict prefixes of HF DETR's encoder/decoder ("transformer") and its
    # learned object queries: the only pretrained weights reused here.
    _TRANSFORMER_PREFIXES = ("model.encoder.", "model.decoder.", "model.query_position_embeddings.")

    def __init__(self, feature_extractor, num_labels=20, pretrained_name="facebook/detr-resnet-50"):
        super().__init__()
        self.feature_extractor = feature_extractor

        # Build DETR from the config without a pretrained backbone (it is replaced below).
        detr_config = DetrConfig.from_pretrained(pretrained_name)
        detr_config.use_pretrained_backbone = False
        detr_config.num_labels = num_labels
        self.detr_model = DetrForObjectDetection(detr_config)

        # Copy the pretrained transformer weights. HF names them model.encoder.* /
        # model.decoder.* (no key contains "transformer"). The shape check skips
        # tensors that differ, e.g. heads sized for another label count.
        pretrained_state = DetrForObjectDetection.from_pretrained(pretrained_name).state_dict()
        own_state = self.detr_model.state_dict()
        transformer_weights = {
            k: v for k, v in pretrained_state.items()
            if k.startswith(self._TRANSFORMER_PREFIXES) and k in own_state and own_state[k].shape == v.shape
        }
        if not transformer_weights:
            raise RuntimeError(
                f"No pretrained DETR weights matched {self._TRANSFORMER_PREFIXES}; "
                "check the parameter names of the installed transformers version."
            )
        self.detr_model.load_state_dict(transformer_weights, strict=False)

        # Probe DOFA once to get its token width D, on whatever device it lives on.
        first_param = next(self.feature_extractor.parameters(), None)
        probe_device = first_param.device if first_param is not None else torch.device("cpu")
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224, device=probe_device)
            dummy_wvs = torch.tensor([0.665, 0.56, 0.49], dtype=torch.float32, device=probe_device)
            feats = self.feature_extractor(dummy, dummy_wvs)  # [1, N, D]
            self.feature_dim = feats.shape[2]

        # Replace the ResNet backbone with the DOFA-based FakeBackbone.
        self.detr_model.model.backbone = FakeBackbone(self.feature_extractor, self.feature_dim)

    def train(self, mode=True):
        """Set train/eval mode, but keep the frozen DOFA extractor in eval mode.

        This keeps any dropout / drop-path inside the frozen extractor disabled.
        """
        super().train(mode)
        self.feature_extractor.eval()
        return self

    @staticmethod
    def _xyxy_to_cxcywh(boxes):
        """Convert ``[..., 4]`` (xmin, ymin, xmax, ymax) boxes to (cx, cy, w, h)."""
        x0, y0, x1, y1 = boxes.unbind(-1)
        return torch.stack(((x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0), dim=-1)

    @staticmethod
    def _cxcywh_to_xyxy(boxes):
        """Convert ``[..., 4]`` (cx, cy, w, h) boxes to (xmin, ymin, xmax, ymax)."""
        cx, cy, w, h = boxes.unbind(-1)
        return torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=-1)

    def forward(self, images, targets=None):
        """Run DETR; the loss is included in the output when ``targets`` is given.

        ``targets`` are not mutated; converted copies are passed to DETR.
        """
        labels = None
        if targets is not None:
            labels = [{**t, "boxes": self._xyxy_to_cxcywh(t["boxes"])} for t in targets]
        outputs = self.detr_model(pixel_values=images, labels=labels)
        # The loss was computed above in DETR's native format; expose xyxy predictions.
        outputs.pred_boxes = self._cxcywh_to_xyxy(outputs.pred_boxes)
        return outputs



# Build the DETR detector with DOFA as its backbone and move it to the selected device.
detr_with_dofa = DetrWithDOFA(feature_extractor).to(device)



def train_model(detr_with_dofa, train_loader, val_loader, epochs=1, lr=1e-4, max_grad_norm=None):
    """Train ``detr_with_dofa`` and plot per-epoch train/validation losses.

    Args:
        detr_with_dofa: DETR+DOFA model; only parameters with ``requires_grad`` are optimized.
        train_loader: DataLoader yielding ``(images_tuple, targets_tuple)`` for training.
        val_loader: same format, used for the per-epoch validation loss.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        max_grad_norm: if set, clip the global gradient norm to this value before each step.

    Returns:
        ``(train_losses, val_losses)``: per-epoch losses averaged over the dataset size.
    """
    # Use the model's own device instead of a notebook-global variable.
    model_device = next(detr_with_dofa.parameters()).device
    # DOFA weights are frozen; only hand trainable tensors to the optimizer.
    trainable_params = [p for p in detr_with_dofa.detr_model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)

    train_losses, val_losses = [], []

    for epoch in range(epochs):
        detr_with_dofa.train()
        running_train_loss = 0.0

        for i, (images, targets) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Training]")):
            images = torch.stack(images).to(model_device)
            targets = [{k: v.to(model_device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()
            outputs = detr_with_dofa(images, targets)
            loss = outputs.loss
            loss.backward()
            if max_grad_norm is not None:
                nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
            optimizer.step()

            running_train_loss += loss.item() * images.size(0)

            # Sanity check on the very first batch: a few predictions vs. ground truth.
            if epoch == 0 and i == 0:
                print("First batch - Predicted boxes:", outputs.pred_boxes[0, :5])
                print("First batch - Target boxes:", targets[0]['boxes'][:5])

        avg_train_loss = running_train_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}")

        detr_with_dofa.eval()
        running_val_loss = 0.0

        with torch.no_grad():
            for images, targets in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Validation]"):
                images = torch.stack(images).to(model_device)
                targets = [{k: v.to(model_device) for k, v in t.items()} for t in targets]

                outputs = detr_with_dofa(images, targets)
                running_val_loss += outputs.loss.item() * images.size(0)

        avg_val_loss = running_val_loss / len(val_loader.dataset)
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {avg_val_loss:.4f}")

    fig, ax = plt.subplots()
    epoch_numbers = range(1, len(train_losses) + 1)
    ax.plot(epoch_numbers, train_losses, label='Train Loss')
    ax.plot(epoch_numbers, val_losses, label='Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    plt.show()

    return train_losses, val_losses

# Train the DETR parameters, with the DOFA backbone kept frozen.
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
train_model(detr_with_dofa, train_loader, val_loader, epochs=NUM_EPOCHS, lr=LEARNING_RATE)

In [None]:
from torchmetrics.detection.mean_ap import MeanAveragePrecision

def compute_metrics(predicted_boxes, true_boxes, predicted_labels, scores, true_labels, class_names,
                    box_format="xyxy"):
    """Compute mAP@0.5, mAP@0.5:0.95 and per-class AP with torchmetrics.

    All sequence arguments hold one entry per image, in the same image order.

    Args:
        predicted_boxes: per-image ``[K, 4]`` predicted boxes.
        true_boxes: per-image ``[M, 4]`` ground-truth boxes.
        predicted_labels: per-image ``[K]`` integer class ids.
        scores: per-image ``[K]`` confidence scores.
        true_labels: per-image ``[M]`` integer class ids.
        class_names: names indexed by class id, used as keys of ``"AP_per_class"``.
        box_format: box layout passed to torchmetrics (default ``"xyxy"``).

    Returns:
        dict with ``"mAP@0.5"`` and ``"mAP@0.5:0.95"`` (percent), plus ``"AP_per_class"``.
        ``"AP_per_class"`` maps class name -> AP in percent, or NaN when torchmetrics
        reports -1 (not computable for that class).
    """
    metric = MeanAveragePrecision(
        iou_type="bbox",
        box_format=box_format,
        class_metrics=True,
    ).to(device)

    preds = []
    targets = []

    for pred_box, pred_label, pred_score, gt_box, gt_label in zip(
            predicted_boxes, predicted_labels, scores, true_boxes, true_labels):
        # reshape keeps empty inputs well-formed ([0, 4] boxes, [0] labels/scores).
        preds.append({
            "boxes": torch.as_tensor(pred_box, dtype=torch.float32, device=device).reshape(-1, 4),
            "scores": torch.as_tensor(pred_score, dtype=torch.float32, device=device).reshape(-1),
            "labels": torch.as_tensor(pred_label, dtype=torch.int64, device=device).reshape(-1),
        })
        targets.append({
            "boxes": torch.as_tensor(gt_box, dtype=torch.float32, device=device).reshape(-1, 4),
            "labels": torch.as_tensor(gt_label, dtype=torch.int64, device=device).reshape(-1),
        })

    metric.update(preds, targets)
    result = metric.compute()

    AP_per_class = {}
    map_per_class = result.get("map_per_class")
    if map_per_class is not None:
        # reshape(-1) also handles 0-d results (e.g. a single observed class).
        class_ids = result["classes"].reshape(-1).tolist()
        for label_idx, ap in zip(class_ids, map_per_class.reshape(-1).tolist()):
            label_idx = int(label_idx)
            if 0 <= label_idx < len(class_names):
                AP_per_class[class_names[label_idx]] = ap * 100 if ap >= 0 else float("nan")

    metrics = {
        "mAP@0.5": result["map_50"].item() * 100,
        "mAP@0.5:0.95": result["map"].item() * 100,
        "AP_per_class": AP_per_class
    }
    return metrics



In [None]:
def test_model(detr_with_dofa, test_loader, class_names):
    """Evaluate ``detr_with_dofa`` on ``test_loader`` and return ``compute_metrics``' dict.

    Scoring and boxes:

    * For each query, the best real class is kept after dropping DETR's trailing
      "no object" logit, and its softmax probability is the detection score.
    * Predicted boxes are passed on exactly as the model returns them. They must
      therefore share the layout of the dataset's ground-truth boxes (normalized xyxy,
      which is also ``compute_metrics``' default format).
    """
    detr_with_dofa.eval()
    # Use the model's own device instead of a notebook-global variable.
    model_device = next(detr_with_dofa.parameters()).device

    predicted_boxes, true_boxes, predicted_labels, scores, true_labels = [], [], [], [], []

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Testing"):
            outputs = detr_with_dofa(torch.stack(images).to(model_device))

            # Remove the "no object" class (last) before taking the best class per query.
            class_probs = outputs.logits.softmax(-1)[..., :-1]
            batch_scores, batch_labels = class_probs.max(dim=-1)

            for i, target in enumerate(targets):
                predicted_boxes.append(outputs.pred_boxes[i].cpu().numpy())
                scores.append(batch_scores[i].cpu().numpy())
                predicted_labels.append(batch_labels[i].cpu().numpy())
                # Targets are only needed as numpy arrays, so they never go to the GPU.
                true_boxes.append(target['boxes'].cpu().numpy())
                true_labels.append(target['class_labels'].cpu().numpy())

    return compute_metrics(predicted_boxes, true_boxes, predicted_labels, scores, true_labels, class_names)


In [None]:
# Evaluate on the test split. The class list is identical for every split, so the
# training dataset's list names the per-class APs.
class_names = train_dataset.classes
metrics = test_model(detr_with_dofa, test_loader, class_names)

print(f"mAP@0.5: {metrics['mAP@0.5']:.8f}%")
print(f"mAP@0.5:0.95: {metrics['mAP@0.5:0.95']:.8f}%")

for name, ap_percent in metrics['AP_per_class'].items():
    print(f"{name}: {ap_percent:.8f}%")
