# Prétraitement et augmentation des données

### Chargement des données prétraitées

In [3]:
import joblib
import numpy as np
from PIL import Image

# Load the preprocessed images and their labels from the joblib pickle.
# NOTE(review): joblib.load unpickles arbitrary objects — only load this file
# if it comes from a trusted source.
# Presumably `images` holds one array per image and `labels` the matching
# class names (both are indexed with the same position further down) — TODO confirm.
images, labels = joblib.load("../data/processed_data.pkl")

### Répartition aléatoire des images en dossiers train/val/test (70 % / 15 % / 15 %)

In [None]:
import os
import random
import shutil

# --- Split configuration -----------------------------------------------------
# Fractions of the full dataset assigned to each subset; they must sum to 1.
SEED = 42
split_root = "../data/data_splite"
train_split = 0.7
val_split = 0.15
test_split = 0.15

if abs(train_split + val_split + test_split - 1.0) > 1e-9:
    raise ValueError("train_split + val_split + test_split must equal 1")
if len(images) != len(labels):
    raise ValueError(
        f"images ({len(images)}) and labels ({len(labels)}) must have the same length"
    )

# --- Output directories --------------------------------------------------------
# Each subset folder is wiped first: file names are positions inside the subset,
# so files left over from an earlier run (or earlier augmentation) would otherwise
# stay on disk, pollute the class counts and leak samples between subsets.
classes = np.unique(labels)
for subset in ["train", "val", "test"]:
    subset_dir = os.path.join(split_root, subset)
    shutil.rmtree(subset_dir, ignore_errors=True)
    os.makedirs(subset_dir, exist_ok=True)
    for cls in classes:
        os.makedirs(os.path.join(subset_dir, cls), exist_ok=True)

# --- Random (seeded) partition of the sample indices -------------------------
data = list(range(len(images)))
random.seed(SEED)  # makes the split reproducible across kernel restarts
random.shuffle(data)

n_total = len(data)
n_train = int(train_split * n_total)
n_val = int(val_split * n_total)
n_test = n_total - n_train - n_val  # the test subset takes the rounding remainder

splits = {
    "train": data[:n_train],
    "val": data[n_train:n_train + n_val],
    "test": data[n_train + n_val:]
}

# --- Write every image as <split_root>/<subset>/<label>/<position>.png -------
image_count = {"train": 0, "val": 0, "test": 0}

for subset_name, subset_indices in splits.items():
    for idx_in_subset, idx in enumerate(subset_indices):
        img_array = images[idx]
        label = labels[idx]
        img = Image.fromarray(img_array)
        save_path = os.path.join(split_root, subset_name, label)
        img.save(os.path.join(save_path, f"{idx_in_subset}.png"))
        image_count[subset_name] += 1

for subset_name, count in image_count.items():
    print(f"{subset_name}: {count} images")

### Nombre d’images par classe dans le dossier d’entraînement

In [None]:
# Count the images in every class folder of the training split and remember the
# size of the largest class; that size is the balancing target for augmentation.
train_path = "../data/data_splite/train"
target_count = 0
for cls in os.listdir(train_path):
    path = os.path.join(train_path, cls)
    if not os.path.isdir(path):
        # Stray files are ignored entirely, so they can neither reuse the
        # previous class's count nor hit an undefined `count`.
        continue
    count = len(os.listdir(path))
    print(f"{cls} : {count}")
    target_count = max(target_count, count)
print(target_count)

Benign : 358
early Pre-B : 702
Pre-B : 678
Pro-B : 531
702


### Augmentation et équilibrage des données d’entraînement par transformations d’images

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Preprocessing applied to every image when it is read through ImageFolder:
# resize to a fixed size, convert to a tensor, then normalise each of the
# three channels with mean 0.5 and std 0.5.
image_size = (224, 224)
channel_mean = [0.5, 0.5, 0.5]
channel_std = [0.5, 0.5, 0.5]

base_transforms = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

# Pool of single-image augmentations; one of them is drawn at random for each
# synthetic image written during class balancing.
horizontal_flip = transforms.RandomHorizontalFlip(p=1.0)  # p=1.0: the flip is always applied
brightness_jitter = transforms.ColorJitter(brightness=0.2)
gaussian_blur = transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))
augmentations = [horizontal_flip, brightness_jitter, gaussian_blur]

# Tensor -> PIL converter (not used by the code visible in this cell).
to_pil = transforms.ToPILImage()

# Balance the training split: for every class folder holding fewer than
# `target_count` files, write augmented copies of randomly chosen original
# images until the folder reaches `target_count`.
for class_name in os.listdir(train_path):
    class_dir = os.path.join(train_path, class_name)
    if not os.path.isdir(class_dir):
        continue  # ignore stray files sitting next to the class folders

    # A dedicated name is used so the notebook-level `images` array loaded
    # earlier is not overwritten by a list of file names.
    source_names = os.listdir(class_dir)
    if not source_names:
        print(f"{class_name} : no source image, class skipped")
        continue

    # Track the file count locally instead of re-listing the folder on every pass.
    current_count = len(source_names)
    while current_count < target_count:
        img_name = random.choice(source_names)
        img_path = os.path.join(class_dir, img_name)

        # Close the file handle as soon as the RGB copy has been made.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        aug_img = random.choice(augmentations)(img)
        new_name = f"aug_{current_count}_{img_name}"
        aug_img.save(os.path.join(class_dir, new_name))
        current_count += 1


# Build one ImageFolder dataset per split; all three share `base_transforms`.
splits_root = "../data/data_splite"
train_dataset, val_dataset, test_dataset = (
    datasets.ImageFolder(f"{splits_root}/{split_dir}", transform=base_transforms)
    for split_dir in ("train", "val", "test")
)

# Report the size of each split.
size_report = (
    ("Nombre d'images d'entraînement :", train_dataset),
    ("Nombre d'images de validation :", val_dataset),
    ("Nombre d'images de test :", test_dataset),
)
for caption, split_dataset in size_report:
    print(caption, len(split_dataset))




Nombre d'images d'entraînement : 2808
Nombre d'images de validation : 486
Nombre d'images de test : 487


In [None]:
from collections import Counter

# Per-class image count of the (balanced) training dataset.
# The targets get their own name so the notebook-level `labels` loaded from the
# pickle is not replaced by integer class indices, which would break a re-run
# of the split cell.
train_targets = train_dataset.targets

count_per_class = Counter(train_targets)

# `targets` holds class indices; `classes` gives the name for each index.
for idx, cls_name in enumerate(train_dataset.classes):
    print(f"{cls_name} : {count_per_class[idx]}")

Benign : 702
Pre-B : 702
Pro-B : 702
early Pre-B : 702


### Sauvegarde des datasets PyTorch pour réutilisation

In [None]:
import os
# Persist the three dataset objects so later notebooks can reload them.
# NOTE(review): torch.save pickles the dataset objects themselves; reloading
# presumably requires the image folders to still exist at the same relative
# paths — confirm before moving the data.
output_dir = "../data/Data_Sets"
os.makedirs(output_dir, exist_ok=True)

datasets_to_save = (
    ("train", train_dataset),
    ("val", val_dataset),
    ("test", test_dataset),
)
for split_name, split_dataset in datasets_to_save:
    torch.save(split_dataset, f"{output_dir}/{split_name}_dataset.pt")