In [None]:

import os
from collections import defaultdict
import matplotlib.pyplot as plt

# === Config ===
# Root folder that holds one sub-directory per class label, and the file
# suffixes (matched case-insensitively) that count as images.
IMG_DIR = "train_images"
IMG_EXT = (".jpg", ".jpeg", ".png")

# === Count images per class ===
# Maps class label (sub-directory name) -> number of image files directly
# inside that sub-directory. Nested folders are not descended into.
class_counts = defaultdict(int)

# Pair every entry of IMG_DIR with its full path, in sorted name order so the
# resulting mapping (and everything derived from it) has a stable ordering.
label_dirs = [
    (entry, os.path.join(IMG_DIR, entry))
    for entry in sorted(os.listdir(IMG_DIR))
]

for label, label_path in label_dirs:
    # Only directories are class folders; stray files in IMG_DIR are ignored.
    if os.path.isdir(label_path):
        image_names = [
            fname
            for fname in os.listdir(label_path)
            if fname.lower().endswith(IMG_EXT)
        ]
        class_counts[label] = len(image_names)

# === Print distribution ===
# List the image count of every class, then check the smallest class against
# the 5-image threshold used here as a sanity check before stratified splitting.
print("📊 Class distribution:")
for label, count in class_counts.items():
    print(f"{label:30s}: {count} images")

# min() raises ValueError on an empty sequence, which happens when no class
# folders were found; default to 0 and report that situation explicitly
# instead of crashing.
min_count = min(class_counts.values(), default=0)
if not class_counts:
    print("\n WARNING: No class folders were found — there is nothing to split!")
elif min_count < 5:
    print(f"\n WARNING: At least one class has fewer than 5 images (min = {min_count}) — stratified split may fail!")
else:
    print("\n All classes have sufficient images for stratified splitting.")

# === Plot bar chart ===
# One bar per class label, in the iteration order of class_counts.
labels = list(class_counts)
counts = list(class_counts.values())

plt.figure(figsize=(10, 6))
bars = plt.bar(labels, counts)
plt.xticks(rotation=45, ha="right")
# Derive the title from the configured folder so it stays correct if IMG_DIR
# is changed.
plt.title(f"Class Distribution in {IMG_DIR}/")
plt.xlabel("Class Label")
plt.ylabel("Number of Images")

# Optional: annotate bars with values (centered horizontally, just above
# the top edge of each bar).
for bar, count in zip(bars, counts):
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height(),
        str(count),
        ha="center",
        va="bottom",
        fontsize=9,
    )

# Run the layout pass only after every artist (including the value labels)
# has been added, so the computed padding is based on the final figure.
plt.tight_layout()

plt.show()