In [None]:
# === Reda - Dataset Exploration ===

import os
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.datasets import ImageFolder
from torchvision.io import read_image
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import random

# --- Path to cleaned dataset ---
data_dir = "data/cleaned/"

# --- Load dataset using ImageFolder ---
# No transform is passed, so indexing the dataset returns whatever the
# default torchvision loader produces (PIL images), not tensors.
dataset = ImageFolder(root=data_dir)

# --- Count samples per class ---
# Tally labels in a single pass over (path, label) pairs instead of rescanning
# the full sample list once per class. Seeding from class_to_idx keeps classes
# that end up with zero samples in the result.
label_totals = dict.fromkeys(dataset.class_to_idx.values(), 0)
for _, label in dataset.samples:
    label_totals[label] += 1

class_counts = {
    cls_name: label_totals[idx] for cls_name, idx in dataset.class_to_idx.items()
}

# Display counts
print("Number of samples per class:")
for cls, count in class_counts.items():
    print(f"{cls}: {count}")

# Bar plot of the per-class counts computed above
plt.figure(figsize=(8,5))
sns.barplot(x=list(class_counts.keys()), y=list(class_counts.values()))
plt.title("Number of samples per class")
plt.ylabel("Count")
plt.xlabel("Classes")
plt.show()

# --- Show a few sample images per class ---
def show_samples_per_class(dataset, samples_per_class=3):
    """Plot a grid showing the first few images of every class.

    The grid has one row per class and ``samples_per_class`` columns. Each
    row shows the first images of that class in ``dataset.samples`` order.
    Each row is titled with its class name over the middle image. Classes
    with fewer images leave the remaining cells empty.

    Args:
        dataset: An ImageFolder-style dataset exposing ``classes``,
            ``samples`` (a list of ``(path, label)`` pairs) and
            ``__getitem__`` returning ``(image, label)``. The image may be a
            PIL image (no transform) or a C,H,W tensor (e.g. after ToTensor).
        samples_per_class: Maximum number of images to show per class.
    """
    n_classes = len(dataset.classes)
    class_to_indices = {cls_idx: [] for cls_idx in range(n_classes)}

    # Record the first `samples_per_class` dataset indices for each label.
    for i, (_, label) in enumerate(dataset.samples):
        if len(class_to_indices[label]) < samples_per_class:
            class_to_indices[label].append(i)

    plt.figure(figsize=(samples_per_class * 3, n_classes * 3))

    for cls_idx, indices in class_to_indices.items():
        # Put the title over the middle image of the row. For a single image
        # this is column 0, so every non-empty row still gets a title.
        title_col = len(indices) // 2
        for i, idx in enumerate(indices):
            img, label = dataset[idx]
            cmap = None
            # Without a transform ImageFolder yields PIL images, which have no
            # .permute and can be passed to imshow directly. Only tensors need
            # C,H,W -> H,W,C reordering for matplotlib.
            if hasattr(img, "permute"):
                img = img.permute(1, 2, 0)
                if img.shape[-1] == 1:
                    # imshow rejects (H, W, 1); drop the channel axis instead.
                    img = img.squeeze(-1)
                    cmap = "gray"
            plt.subplot(n_classes, samples_per_class, cls_idx * samples_per_class + i + 1)
            plt.imshow(img, cmap=cmap)
            plt.axis('off')
            if i == title_col:
                plt.title(dataset.classes[label])
    plt.tight_layout()
    plt.show()

# Preview the first three images of every class (3 is the function's default).
show_samples_per_class(dataset)
