In [None]:
import torch
from torch.utils.data import Dataset
from sklearn.datasets import fetch_lfw_people
import torchvision.transforms as transforms

from gorillatracker.transform_utils import SquarePad


class LFWPytorchDataset(Dataset):
    """Map-style PyTorch dataset wrapping scikit-learn's LFW "people" faces.

    Each item is ``(image, label)``:
      * ``image`` is the stored array converted with ``ToPILImage`` and, when a
        ``transform`` was supplied, passed through it;
      * ``label`` is the integer class id as a ``torch.long`` scalar tensor.
    The person's name for a label is ``target_names[label]``.

    Args:
        resize: forwarded unchanged to ``fetch_lfw_people(resize=...)``.
        transform: optional callable applied to the PIL image of each sample.
    """

    def __init__(self, resize=None, transform=None):
        # color=True asks sklearn for colour images instead of grayscale ones.
        people = fetch_lfw_people(resize=resize, color=True)
        self.images = people.images
        self.targets = people.target
        self.target_names = people.target_names
        self.transform = transform
        self.topil = transforms.ToPILImage()

    def __len__(self):
        """Number of face images held by the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        pil_image = self.topil(self.images[idx])
        sample = self.transform(pil_image) if self.transform else pil_image
        return sample, torch.tensor(self.targets[idx], dtype=torch.long)


# Per-sample transform: LFWPytorchDataset.__getitem__ hands over a PIL image, which
# ToTensor turns back into a tensor. The dataset is loaded with color=True, so the
# images are colour, not grayscale.
transform = transforms.Compose([transforms.ToTensor()])

# resize=1.0 is forwarded to fetch_lfw_people as its resize ratio.
lfw_dataset = LFWPytorchDataset(resize=1.0, transform=transform)

In [None]:
import matplotlib.pyplot as plt
import random

# Show two distinct, randomly chosen images of Michael Jackson (top row) and two of
# Gerhard Schroeder (bottom row), then save the grid as a PDF.

# Fixed seed so the saved figure is reproducible; set to None for a fresh draw each run.
PLOT_SEED = 0
sample_rng = random.Random(PLOT_SEED)

topil = transforms.ToPILImage()
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

target_names = lfw_dataset.target_names.tolist()
michael_jackson_idx = target_names.index("Michael Jackson")
gerhard_schroeder_idx = target_names.index("Gerhard Schroeder")

michael_jackson_indices = [i for i, x in enumerate(lfw_dataset.targets) if x == michael_jackson_idx]
gerhard_schroeder_indices = [i for i, x in enumerate(lfw_dataset.targets) if x == gerhard_schroeder_idx]

# sample() draws without replacement, so the two pictures in a row are always different
# (calling choice() twice could return the same index).
for row_axes, person_indices in zip(axs, [michael_jackson_indices, gerhard_schroeder_indices]):
    for ax, idx in zip(row_axes, sample_rng.sample(person_indices, 2)):
        image, label = lfw_dataset[idx]
        print(f"Image shape: {image.shape}, Label: {label}")
        ax.imshow(topil(image))  # convert the tensor back to PIL for display
        ax.set_title(lfw_dataset.target_names[int(label)])
        ax.axis("off")  # Hide axes

# Show the plot
plt.tight_layout()
plt.savefig("lfw_dataset.pdf")
plt.show()

In [None]:
import numpy as np

# Read labels directly from the dataset's target array: indexing the dataset would
# decode and transform every image just to obtain its label.
labels = np.asarray(lfw_dataset.targets)
unique, counts = np.unique(labels, return_counts=True)

# group into 1, 2-4, 5-9, 10-99, 100-1000 images per class
# (np.histogram bins are half-open [a, b) except the last, which is closed: [100, 1000])
bins = [1, 2, 5, 10, 100, 1000]
hist = np.histogram(counts, bins=bins)

plt.bar(range(len(hist[0])), hist[0], tick_label=["1", "2-4", "5-9", "10-99", "100-1000"])
plt.xlabel("Number of images per class")
plt.ylabel("Number of classes")
plt.title("Distribution of images per class")
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

plt.rc("font", family="serif", size=12)

# Keep only identities with at least `min_images_per_class` images and summarise what remains.
min_images_per_class = 3
# Labels come straight from the target array: indexing the dataset would decode and
# transform every image just to read its label.
labels = np.asarray(lfw_dataset.targets)
unique, counts = np.unique(labels, return_counts=True)
filtered_labels_ = unique[counts >= min_images_per_class]  # class ids that pass the filter
filtered_labels = labels[np.isin(labels, filtered_labels_)]  # one entry per kept image
print(f"Number of classes with at least {min_images_per_class} images: {len(np.unique(filtered_labels))}")

unique, counts = np.unique(filtered_labels, return_counts=True)

# group into 3-4, 5-9, 10-99, 100-1000 images per class
# (np.histogram bins are half-open except the last one, which also includes 1000)
bins = [3, 5, 10, 100, 1000]
tick_labels = ["3-4", "5-9", "10-99", "100-1000"]

# Number of classes per bin.
hist, bin_edges = np.histogram(counts, bins=bins)
# Number of images per bin. Weighting each class by its own size reuses exactly the same
# bin edges as `hist`, so both series agree on where every class falls. In particular, a
# class with exactly 1000 images lands in the last bin rather than being dropped.
bin_count, _ = np.histogram(counts, bins=bins, weights=counts)
for count in counts[counts > bins[-1]]:
    print(f"Class with more than 1000 images: {count}")

# Create the combined plot
fig, ax1 = plt.subplots(figsize=(10, 6))

# Plot the number of classes on the primary y-axis
bars = ax1.bar(range(len(hist)), hist, tick_label=tick_labels, color="firebrick", alpha=1.0)
ax1.set_xlabel("Number of Images per Class", fontsize=11)
ax1.set_ylabel("Number of Classes", color="firebrick", fontsize=11)
ax1.tick_params(axis="y", labelcolor="firebrick")
ax1.spines["left"].set_color("firebrick")
ax1.spines["left"].set_linewidth(2)

# Annotate bar values (the +2 offset is in data units of the class-count axis)
for bar in bars:
    yval = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width() / 2, yval + 2, int(yval), ha="center", fontsize=10)

# Create a secondary y-axis to plot the number of images
ax2 = ax1.twinx()
points = ax2.plot(
    range(len(bin_count)), bin_count, color="orange", marker="o", linestyle="--", linewidth=2, markersize=8
)
ax2.set_ylabel("Number of Images", color="orange", fontsize=11)
ax2.tick_params(axis="y", labelcolor="orange")
ax2.spines["right"].set_color("orange")
ax2.spines["right"].set_linewidth(2)

# remove the top spines for ax1 and ax2
ax1.spines["top"].set_visible(False)
ax2.spines["top"].set_visible(False)

# Annotate each point (the +100 offset is in data units of the image-count axis)
for i, val in enumerate(bin_count):
    ax2.text(i, val + 100, int(val), ha="center", fontsize=10)

plt.grid(False)
plt.tight_layout()
plt.savefig("lfw_distribution.pdf", dpi=500)
plt.show()

print(f"Number of images left: {len(filtered_labels)}")