In [5]:
import os
import torch
from torchvision import datasets
from torchvision.transforms import Compose, Resize, ToTensor
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np

## Étape 1 : Chargement des datasets

In [6]:
dataset_root = "ipeo_hurricane_for_students"

# Transform simple : juste resize et ToTensor (pas de normalisation)
transform = Compose([
    Resize((224, 224)),
    ToTensor()
])

# Charger tous les datasets
print("Chargement des datasets...")
train_ds = datasets.ImageFolder(os.path.join(dataset_root, "train"), transform=transform)
val_ds = datasets.ImageFolder(os.path.join(dataset_root, "validation"), transform=transform)
test_ds = datasets.ImageFolder(os.path.join(dataset_root, "test"), transform=transform)

print(f"Train : {len(train_ds)} images")
print(f"Validation : {len(val_ds)} images")
print(f"Test : {len(test_ds)} images")

# Fusionner tous les datasets
combined_dataset = ConcatDataset([train_ds, val_ds, test_ds])
print(f"\nTotal : {len(combined_dataset)} images")

Chargement des datasets...
Train : 19000 images
Validation : 2000 images
Test : 2000 images

Total : 23000 images


## Étape 2 : Calcul des statistiques

In [7]:
# Créer un DataLoader
batch_size = 32
loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print("Calcul des statistiques en cours...\n")

# Accumulator pour mean et std
mean_accumulator = torch.zeros(3)
std_accumulator = torch.zeros(3)
total_pixels = 0

# Première passe : calculer la moyenne
print("Passe 1 : calcul de la moyenne...")
for batch_idx, (images, _) in enumerate(loader):
    # images shape: (batch_size, 3, 224, 224)
    batch_size_actual = images.shape[0]
    height, width = images.shape[2], images.shape[3]
    
    # Remodeler : (batch_size * height * width, 3)
    images_reshaped = images.permute(0, 2, 3, 1).reshape(-1, 3)
    
    # Accumuler
    mean_accumulator += images_reshaped.sum(dim=0)
    total_pixels += images_reshaped.shape[0]
    
    if (batch_idx + 1) % 5 == 0:
        print(f"  Batch {batch_idx + 1}/{len(loader)}")

# Moyenne globale
mean = mean_accumulator / total_pixels
print(f"\nMoyenne calculée : {mean}")

# Deuxième passe : calculer l'écart-type
print("\nPasse 2 : calcul de l'écart-type...")
total_pixels = 0
for batch_idx, (images, _) in enumerate(loader):
    # Remodeler pour le calcul
    images_reshaped = images.permute(0, 2, 3, 1).reshape(-1, 3)
    
    # Calculer la variance
    diff_squared = (images_reshaped - mean.unsqueeze(0)) ** 2
    std_accumulator += diff_squared.sum(dim=0)
    total_pixels += images_reshaped.shape[0]
    
    if (batch_idx + 1) % 5 == 0:
        print(f"  Batch {batch_idx + 1}/{len(loader)}")

# Écart-type global
std = torch.sqrt(std_accumulator / total_pixels)

print(f"\nÉcart-type calculé : {std}")


Calcul des statistiques en cours...

Passe 1 : calcul de la moyenne...
  Batch 5/719
  Batch 10/719
  Batch 15/719
  Batch 20/719
  Batch 25/719
  Batch 30/719
  Batch 35/719
  Batch 40/719
  Batch 45/719
  Batch 50/719
  Batch 55/719
  Batch 60/719
  Batch 65/719
  Batch 70/719
  Batch 75/719
  Batch 80/719
  Batch 85/719
  Batch 90/719
  Batch 95/719
  Batch 100/719
  Batch 105/719
  Batch 110/719
  Batch 115/719
  Batch 120/719
  Batch 125/719
  Batch 130/719
  Batch 135/719
  Batch 140/719
  Batch 145/719
  Batch 150/719
  Batch 155/719
  Batch 160/719
  Batch 165/719
  Batch 170/719
  Batch 175/719
  Batch 180/719
  Batch 185/719
  Batch 190/719
  Batch 195/719
  Batch 200/719
  Batch 205/719
  Batch 210/719
  Batch 215/719
  Batch 220/719
  Batch 225/719
  Batch 230/719
  Batch 235/719
  Batch 240/719
  Batch 245/719
  Batch 250/719
  Batch 255/719
  Batch 260/719
  Batch 265/719
  Batch 270/719
  Batch 275/719
  Batch 280/719
  Batch 285/719
  Batch 290/719
  Batch 295/719
  Bat

## Résultats

In [4]:
print("="*60)
print("STATISTIQUES DE VOTRE DATASET")
print("="*60)
print(f"\nMean par canal (RGB) :")
print(f"  {mean.tolist()}")
print(f"\nStandard Deviation par canal (RGB) :")
print(f"  {std.tolist()}")

print(f"\n\n" + "="*60)
print("CODE À COPIER DANS VOTRE NOTEBOOK")
print("="*60)
print(f"""
import torch
import torchvision.transforms as T

mean = torch.tensor({mean.tolist()})
std = torch.tensor({std.tolist()})

normalize = T.Normalize(mean, std)
default_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    normalize
])
""")

STATISTIQUES DE VOTRE DATASET

Mean par canal (RGB) :
  [1.022279143333435, 1.022279143333435, 1.022279143333435]

Standard Deviation par canal (RGB) :
  [0.6852813363075256, 0.6502869129180908, 0.7472431659698486]


CODE À COPIER DANS VOTRE NOTEBOOK

import torch
import torchvision.transforms as T

mean = torch.tensor([1.022279143333435, 1.022279143333435, 1.022279143333435])
std = torch.tensor([0.6852813363075256, 0.6502869129180908, 0.7472431659698486])

normalize = T.Normalize(mean, std)
default_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    normalize
])

