In [None]:
import os
import shutil
import random
from collections import defaultdict
from pathlib import Path

# ================= CONFIGURATION =================

# Parent of the current working directory, as an absolute path.
BASE_DIR = Path("..").resolve()

# Source directory (the current, very large dataset).
SOURCE_DIR = BASE_DIR.joinpath("data", "dataset", "garbage_dataset")

# Destination directory (created if it does not exist).
DEST_DIR = BASE_DIR.joinpath("data", "dataset", "garbage_dataset_100_reduction")

# Desired total number of images (e.g. 1000).
TARGET_TOTAL = 100

# Train / validation / test proportions.
# Note the key 'valid' (not 'val'), chosen to match the directory tree.
SPLIT_RATIOS = dict(train=0.8, valid=0.1, test=0.1)

# Accepted image file extensions (compared in lower case).
IMG_EXT = {".jpg", ".jpeg", ".png", ".bmp"}
# =================================================

def get_class_from_label(label_path):
    """Return the class ID found on the first line of a YOLO label file.

    Only the first line is inspected; its expected layout is
    ``class_id x y w h``.

    Args:
        label_path: Path to the ``.txt`` label file.

    Returns:
        The class ID as an ``int``, or ``None`` when the file is missing or
        unreadable, when its first line is blank, or when the first token
        is not an integer.
    """
    try:
        with open(label_path, 'r') as f:
            first_line = f.readline()
    except (OSError, UnicodeDecodeError):
        # Missing, unreadable or undecodable file: treat as "no label".
        # Only I/O and decoding failures are absorbed, so that
        # KeyboardInterrupt / SystemExit still propagate.
        return None

    tokens = first_line.split()
    if not tokens:
        return None

    try:
        return int(tokens[0])
    except ValueError:
        # First token is not an integer class ID (e.g. "cat" or "1.0").
        return None

def main():
    """Build a small, class-balanced subset of the source YOLO dataset.

    Steps:
      1. Walk ``SOURCE_DIR`` and pair every image with its label file. Each
         pair is filed under the class ID returned by
         ``get_class_from_label`` for that label.
      2. Randomly keep at most ``TARGET_TOTAL // number_of_classes`` pairs
         per class.
      3. Shuffle the selection, cut it into train/valid/test according to
         ``SPLIT_RATIOS`` and copy the files into
         ``DEST_DIR/<split>/images`` and ``DEST_DIR/<split>/labels``.
      4. Copy ``data.yaml`` from ``SOURCE_DIR`` when it exists.

    The ``images`` and ``labels`` output directories are emptied before the
    copy. Without that, running the script again (with a different random
    draw) would pile new files on top of the previous subset: the output
    would exceed the target size and the same image could end up in two
    different splits.

    Raises:
        ValueError: If one of the output directories overlaps
            ``SOURCE_DIR``; emptying it could delete source data.
    """
    print(f"üîç Analyse de {SOURCE_DIR}...")

    files_by_class = defaultdict(list)
    image_count = 0

    # 1. SCAN AND INDEXING
    for root, _, files in os.walk(SOURCE_DIR):
        for file in files:
            base_name, ext = os.path.splitext(file)
            if ext.lower() not in IMG_EXT:
                continue

            img_path = os.path.join(root, file)

            # First candidate: label stored next to the image.
            label_path = os.path.join(root, base_name + ".txt")
            if not os.path.exists(label_path):
                # Second candidate: parallel images/ and labels/ directories.
                label_path = img_path.replace(f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}")
                label_path = os.path.splitext(label_path)[0] + ".txt"

            if not os.path.exists(label_path):
                continue

            class_id = get_class_from_label(label_path)
            if class_id is None:
                continue

            files_by_class[class_id].append((img_path, label_path))
            image_count += 1

    classes = list(files_by_class.keys())
    if not classes:
        print("‚ùå Erreur : Aucune donn√©e trouv√©e.")
        return

    print(f"‚úÖ Trouv√© {image_count} images sur {len(classes)} classes.")

    # 2. BALANCED SELECTION
    quota = TARGET_TOTAL // len(classes)
    final_selection = []

    print(f"üìä Objectif : {quota} images/classe (total vis√©: {quota * len(classes)}).")

    # Number of pairs actually kept for each class.
    class_stats = {}

    for class_id in sorted(classes):
        pairs = files_by_class[class_id]
        available = len(pairs)

        if available < quota:
            print(f"‚ö†Ô∏è  Classe {class_id}: seulement {available} images disponibles (quota: {quota})")

        random.shuffle(pairs)
        # Slicing past the end is harmless: a short class keeps all its pairs.
        selected = pairs[:quota]
        final_selection.extend(selected)
        class_stats[class_id] = len(selected)

    # Per-class report
    print(f"\nüìä R√©partition par classe:")
    for class_id in sorted(class_stats):
        print(f"   Classe {class_id}: {class_stats[class_id]} images")

    print(f"\nüì¶ Total s√©lectionn√© : {len(final_selection)} images.")

    # Balance check: warn when the spread exceeds 10% of the quota.
    min_imgs = min(class_stats.values())
    max_imgs = max(class_stats.values())
    if max_imgs - min_imgs > quota * 0.1:
        print(f"‚ö†Ô∏è  D√©s√©quilibre d√©tect√©: min={min_imgs}, max={max_imgs}")

    random.shuffle(final_selection)

    # 3. SPLIT AND COPY
    n_train = int(len(final_selection) * SPLIT_RATIOS['train'])
    n_valid = int(len(final_selection) * SPLIT_RATIOS['valid'])

    # The test split receives whatever is left after train and valid.
    datasets = {
        'train': final_selection[:n_train],
        'valid': final_selection[n_train:n_train + n_valid],
        'test': final_selection[n_train + n_valid:]
    }

    # Output layout: <DEST_DIR>/<split>/images and <DEST_DIR>/<split>/labels
    split_dirs = {
        split_name: (
            os.path.join(DEST_DIR, split_name, 'images'),
            os.path.join(DEST_DIR, split_name, 'labels'),
        )
        for split_name in datasets
    }

    # Validate every output directory BEFORE deleting anything: a directory
    # that equals, contains or lies inside SOURCE_DIR must never be emptied.
    source_root = Path(SOURCE_DIR).resolve()
    for out_dirs in split_dirs.values():
        for out_dir in out_dirs:
            resolved = Path(out_dir).resolve()
            if (
                resolved == source_root
                or resolved in source_root.parents
                or source_root in resolved.parents
            ):
                raise ValueError(
                    f"Le dossier de sortie '{out_dir}' chevauche SOURCE_DIR."
                )

    print(f"\nüöÄ Cr√©ation de la structure dans '{DEST_DIR}'...")

    for split_name, pairs in datasets.items():
        split_img_dir, split_lbl_dir = split_dirs[split_name]

        for out_dir in (split_img_dir, split_lbl_dir):
            # Drop the files of a previous run so splits never accumulate.
            if os.path.isdir(out_dir):
                shutil.rmtree(out_dir)
            os.makedirs(out_dir, exist_ok=True)

        for img_src, lbl_src in pairs:
            shutil.copy2(img_src, os.path.join(split_img_dir, os.path.basename(img_src)))
            shutil.copy2(lbl_src, os.path.join(split_lbl_dir, os.path.basename(lbl_src)))

    # Copy the data.yaml file
    source_yaml = os.path.join(SOURCE_DIR, 'data.yaml')
    dest_yaml = os.path.join(DEST_DIR, 'data.yaml')

    if os.path.exists(source_yaml):
        shutil.copy2(source_yaml, dest_yaml)
        print(f"‚úÖ Fichier data.yaml copi√©.")
    else:
        print(f"‚ö†Ô∏è  Fichier data.yaml introuvable dans {SOURCE_DIR}")

    print(f"‚úÖ Termin√© ! L'arborescence respecte le format demand√©.")

# Run the reduction only when executed as a script / notebook cell.
if __name__ == "__main__":
    main()

üîç Analyse de C:\Users\jansc\OneDrive\Bureau\ECAM_local\ai_project_ma2\ia-llm-project\data\dataset\garbage_dataset...
‚úÖ Trouv√© 3846 images sur 7 classes.
üìä Objectif : ~14 images/classe.
üì¶ Total s√©lectionn√© : 98 images.

üöÄ Cr√©ation de la structure dans 'C:\Users\jansc\OneDrive\Bureau\ECAM_local\ai_project_ma2\ia-llm-project\data\dataset\garbage_dataset_100_reduction'...
‚úÖ Termin√© ! L'arborescence respecte le format demand√©.


In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict
import os

# Configuration
# Parent of the current working directory, as an absolute path.
BASE_DIR = Path("..").resolve()
# Reduced dataset whose class distribution is inspected below.
DEST_DIR = BASE_DIR.joinpath("data", "dataset", "garbage_dataset_100_reduction")

def get_class_from_label(label_path):
    """Return the class ID found on the first line of a YOLO label file.

    Only the first line is inspected; its expected layout is
    ``class_id x y w h``.

    Args:
        label_path: Path to the ``.txt`` label file.

    Returns:
        The class ID as an ``int``, or ``None`` when the file is missing or
        unreadable, when its first line is blank, or when the first token
        is not an integer.
    """
    try:
        with open(label_path, 'r') as f:
            first_line = f.readline()
    except (OSError, UnicodeDecodeError):
        # Missing, unreadable or undecodable file: treat as "no label".
        # Only I/O and decoding failures are absorbed, so that
        # KeyboardInterrupt / SystemExit still propagate.
        return None

    tokens = first_line.split()
    if not tokens:
        return None

    try:
        return int(tokens[0])
    except ValueError:
        # First token is not an integer class ID (e.g. "cat" or "1.0").
        return None

# Count, for every split, how many label files fall under each class ID
splits = ['train', 'valid', 'test']
class_distribution = {split_name: defaultdict(int) for split_name in splits}

for split in splits:
    labels_dir = DEST_DIR / split / 'labels'
    if not labels_dir.exists():
        # Split directory missing: its counts simply stay empty.
        continue
    per_class = class_distribution[split]
    for label_file in labels_dir.glob('*.txt'):
        class_id = get_class_from_label(str(label_file))
        if class_id is None:
            continue
        per_class[class_id] += 1

# Chart data: sorted union of the class IDs seen in any split, then one
# count list per split aligned on that order (absent classes count as 0).
all_classes = sorted({cid for counts in class_distribution.values() for cid in counts})
train_counts, valid_counts, test_counts = (
    [class_distribution[split_name][cid] for cid in all_classes]
    for split_name in splits
)

# Build the figure: two panels side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Panel 1: grouped bars, one group per class and one bar per split
x = range(len(all_classes))
width = 0.25

# (horizontal offset, counts, legend label) for each split, left to right
split_series = (
    (-width, train_counts, 'Train'),
    (0, valid_counts, 'Valid'),
    (width, test_counts, 'Test'),
)
for offset, counts, legend_label in split_series:
    ax1.bar([pos + offset for pos in x], counts, width, label=legend_label, alpha=0.8)

ax1.set_xlabel('Classe')
ax1.set_ylabel('Nombre d\'images')
ax1.set_title('Distribution des classes par split')
ax1.set_xticks(x)
ax1.set_xticklabels(all_classes)
ax1.legend()
ax1.grid(axis='y', alpha=0.3)

# Panel 2: total number of images per class, all splits combined
total_counts = [
    n_train + n_valid + n_test
    for n_train, n_valid, n_test in zip(train_counts, valid_counts, test_counts)
]
ax2.bar(all_classes, total_counts, color='steelblue', alpha=0.8)
ax2.set_xlabel('Classe')
ax2.set_ylabel('Nombre total d\'images')
ax2.set_title('Distribution totale des classes')
ax2.grid(axis='y', alpha=0.3)

# Write each total just above its bar
for class_id, class_total in zip(all_classes, total_counts):
    ax2.text(class_id, class_total + 0.5, str(class_total), ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Text summary: one row per class with its train / valid / test / total counts
print(f"\nüìä R√©sum√© de la distribution:")
print(f"{'Classe':<10} {'Train':<10} {'Valid':<10} {'Test':<10} {'Total':<10}")
print("-" * 50)
# Walk the aligned lists in lockstep instead of calling all_classes.index(c)
# several times per row, which rescans the list on every lookup.
for c, n_train, n_valid, n_test in zip(all_classes, train_counts, valid_counts, test_counts):
    total = n_train + n_valid + n_test
    print(f"{c:<10} {n_train:<10} {n_valid:<10} {n_test:<10} {total:<10}")