In [None]:
!pip install torchgeo torchvision torch

In [None]:
!pip install rasterio

In [None]:
from torchgeo.datasets import EuroSAT


# Charger le dataset EuroSAT
dataset = EuroSAT(root="data", download=True)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import rasterio
from torch.utils.data import Dataset, DataLoader
import glob
import os
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, mean_absolute_error, mean_squared_error, confusion_matrix
import seaborn as sns

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Chargement du modèle DOFA pré-entraîné
model = torch.hub.load('zhu-xlab/DOFA', 'vit_base_dofa', pretrained=True).to(device)


# Adapter la tête du modèle pour 10 classes (EuroSAT)
num_classes = 10
model.head = nn.Linear(model.head.in_features, num_classes).to(device)



# Longueurs d'onde Sentinel-2
wavelengths = torch.tensor(
    [0.443, 0.490, 0.560, 0.665, 0.705, 0.740, 0.783, 0.842, 0.865, 0.945, 1.375, 1.610, 2.190],
    dtype=torch.float32
).to(device)

# Classe Dataset pour charger et prétraiter les images Sentinel-2
class Sentinel2Dataset(Dataset):
    def __init__(self, file_list, root_dir, transform=None):
        self.file_list = [line.strip() for line in open(file_list, 'r')]
        self.root_dir = root_dir
        self.transform = transform
        self.label_map = {cls: idx for idx, cls in enumerate([
            "AnnualCrop", "Forest", "HerbaceousVegetation", "Highway",
            "Industrial", "Pasture", "PermanentCrop", "Residential",
            "River", "SeaLake"
        ])}

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_name = self.file_list[idx]
        class_name = img_name.split('_')[0]
        img_path = os.path.join(self.root_dir, class_name, img_name.replace(".jpg", ".tif"))

        with rasterio.open(img_path) as src:
            img = src.read().astype(np.float32)

        img = torch.from_numpy(img)
        label = self.label_map[class_name]

        if self.transform:
            img = self.transform(img)

        return img, label

# Calcul de la moyenne et de l'écart-type sur les 13 bandes spectrales
def compute_mean_std(train_loader):
    total_sum = torch.zeros(13)
    total_sum_sq = torch.zeros(13)
    total_pixels = 0

    for images, _ in tqdm(train_loader, desc="Calcul des stats"):
        batch_size, channels, H, W = images.shape
        images = images.view(batch_size, channels, -1)

        total_sum += images.sum(dim=(0, 2))
        total_sum_sq += (images ** 2).sum(dim=(0, 2))
        total_pixels += batch_size * H * W

    mean = total_sum / total_pixels
    variance = total_sum_sq / total_pixels - mean**2
    std = torch.sqrt(variance)

    return mean, std

# Définition des transformations
class DataAugmentation(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.transform = transforms.Compose([
            #Découpe aléatoire et redimensionnement de l'image à 224x224
            transforms.RandomResizedCrop(size=(224, 224), scale=(0.2, 1.0)),
            #Normalisation
            transforms.Normalize(mean=mean, std=std)
        ])

    def forward(self, x):
       # Applique les transformations à l'image
        return self.transform(x)


Using cache found in /root/.cache/torch/hub/zhu-xlab_DOFA_master


In [None]:
# Chargement des datasets
data_dir = "data/ds/images/remote_sensing/otherDatasets/sentinel_2/tif/"
dir_fic_txt = "data/"

trainset = Sentinel2Dataset(os.path.join(dir_fic_txt, "eurosat-train.txt"), data_dir)
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)

# Calcul automatique des statistiques
S2_MEAN_train, S2_STD_train = compute_mean_std(train_loader)

# Transformation avec valeurs calculées
transform = DataAugmentation(mean=S2_MEAN_train, std=S2_STD_train)

# Chargement des datasets avec transformations
trainset = Sentinel2Dataset(os.path.join(dir_fic_txt, "eurosat-train.txt"), data_dir, transform=transform)
valset = Sentinel2Dataset(os.path.join(dir_fic_txt, "eurosat-val.txt"), data_dir, transform=transform)
testset = Sentinel2Dataset(os.path.join(dir_fic_txt, "eurosat-test.txt"), data_dir, transform=transform)

# Création des DataLoaders
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
val_loader = DataLoader(valset, batch_size=128, shuffle=False)
test_loader = DataLoader(testset, batch_size=128, shuffle=False)

In [11]:
# Définition de la fonction de perte et de l’optimiseur
criterion = nn.CrossEntropyLoss()

# Geler tous les paramètres sauf la tête du modèle
for param in model.parameters():
    param.requires_grad = False
for param in model.head.parameters():
    param.requires_grad = True

optimizer = optim.Adam(model.head.parameters(), lr=1e-3, weight_decay=1e-4)

In [None]:
# Entraînement du modèle
num_epochs = 26
train_losses, val_losses = [], []

for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad() # Réinitialiser les gradients avant la rétropropagation

        outputs = model(images, wave_list=wavelengths) # DOFA
        loss = criterion(outputs, labels) # Calculer la perte (difference entre prédictions et vraies étiquettes)
        loss.backward() #rétropropagation pour calculer les gradients
        optimizer.step() # Mettre à jour les poids du modèle avec les gradients calculés


        running_loss += loss.item() * images.size(0) # Prendre la classe avec la probabilité maximale
        correct += (torch.max(outputs, 1)[1] == labels).sum().item() # Compter les prédictions correctes
        total += labels.size(0)

    train_losses.append(running_loss / len(train_loader.dataset))
    train_acc = correct / total

    #  Évaluation sur validation
    model.eval()
    running_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images, wave_list=wavelengths)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            correct += (torch.max(outputs, 1)[1] == labels).sum().item()
            total += labels.size(0)

    val_losses.append(running_loss / len(val_loader.dataset))
    val_acc = correct / total

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_losses[-1]:.6f}, Train Acc: {train_acc:.4f}, Val Loss: {val_losses[-1]:.6f}, Val Acc: {val_acc:.4f}")

# Affichage des métriques
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.legend()
plt.show()

Epoch 1/26 [Training]: 100%|██████████| 254/254 [05:00<00:00,  1.18s/it]


Epoch 1/26, Train Loss: 0.814686, Train Acc: 0.7792, Val Loss: 0.539640, Val Acc: 0.8487


Epoch 2/26 [Training]: 100%|██████████| 254/254 [04:59<00:00,  1.18s/it]


Epoch 2/26, Train Loss: 0.446896, Train Acc: 0.8747, Val Loss: 0.403905, Val Acc: 0.8819


Epoch 3/26 [Training]: 100%|██████████| 254/254 [04:59<00:00,  1.18s/it]


Epoch 3/26, Train Loss: 0.372563, Train Acc: 0.8960, Val Loss: 0.365590, Val Acc: 0.8965


Epoch 4/26 [Training]: 100%|██████████| 254/254 [05:00<00:00,  1.18s/it]


Epoch 4/26, Train Loss: 0.328978, Train Acc: 0.9069, Val Loss: 0.331916, Val Acc: 0.9035


Epoch 5/26 [Training]: 100%|██████████| 254/254 [05:01<00:00,  1.19s/it]


Epoch 5/26, Train Loss: 0.301908, Train Acc: 0.9129, Val Loss: 0.302846, Val Acc: 0.9131


Epoch 6/26 [Training]: 100%|██████████| 254/254 [05:04<00:00,  1.20s/it]


In [None]:
# Évaluation sur le test set
all_preds, all_labels = [], []
with torch.no_grad():
    model.eval()
    for images, labels in tqdm(test_loader, desc="Test"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images, wave_list=wavelengths)
        # Enregistrer les prédictions et les labels pour les métriques
        all_preds.extend(torch.max(outputs, 1)[1].cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Matrice de confusion
cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, cmap="Blues", xticklabels=trainset.label_map.keys(), yticklabels=trainset.label_map.keys())
plt.show()

# Affichage des scores
print(f"Accuracy: {np.sum(np.diag(cm)) / np.sum(cm):.4f}")
print(f"Precision: {precision_score(all_labels, all_preds, average='weighted'):.4f}")
print(f"Recall: {recall_score(all_labels, all_preds, average='weighted'):.4f}")
print(f"F1 Score: {f1_score(all_labels, all_preds, average='weighted'):.4f}")