In [3]:
import os
from typing import List, Optional

import cv2
import pandas as pd
import PIL
from torchvision.transforms import ToPILImage
from tqdm import tqdm

In [None]:
def load_imagenet_val(folderdir='/Users/elior/Downloads/ILSVRC2012_img_val',
                      transform=ToPILImage(),
                      max_images: Optional[int] = None,
                      ) -> List["PIL.Image.Image"]:
    """
    Helper function to load the ILSVRC2012_img_val dataset.

    Files are read in sorted filename order. Files that OpenCV cannot decode
    (e.g. stray non-image files) are skipped, which shifts the position of every
    later image -- callers that pair images with labels by position should make
    sure the folder contains only images.

    Parameters
    ----------
    folderdir
        Path to the 'ILSVRC2012_img_val' folder.
    transform
        A callable (e.g. a torchvision transform) applied to every loaded image.
        It receives a 3-channel array in RGB channel order. Pass ``None`` to keep
        the raw arrays.
    max_images
        If given, stop once this many images have been successfully loaded.
        ``None`` (default) loads the whole folder.

    Returns
    -------
    List[PIL.Image.Image]
        The loaded images after ``transform`` (PIL images with the default
        ``ToPILImage``), or RGB arrays when ``transform`` is ``None``.
    """
    filenames = sorted(os.listdir(folderdir))

    images = []
    for filename in tqdm(filenames):
        # Check before reading so no file is decoded past the limit.
        if max_images is not None and len(images) >= max_images:
            break

        img = cv2.imread(os.path.join(folderdir, filename))
        if img is None:
            # Not something OpenCV can decode (e.g. .DS_Store); skip it.
            continue

        # cv2.imread returns channels in BGR order, whereas ToPILImage (and
        # torchvision in general) interprets a 3-channel array as RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        images.append(transform(img) if transform else img)

    return images

In [None]:
def stratified_sampling(filedir='/Users/elior/Downloads/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt',
                        pct: float = 0.02,
                        n_labels: Optional[int] = 20001,
                        random_state: Optional[int] = None,
                        ) -> pd.DataFrame:
    """
    Helper function to sample the ILSVRC2012_img_val dataset in a stratified way.

    Labels are read one per line from ``filedir`` and paired by position with the
    images returned by ``load_imagenet_val()`` (sorted filename order).

    Parameters
    ----------
    filedir
        Path to the 'ILSVRC2012_validation_ground_truth.txt' file.
    pct
        Fraction (between 0 and 1) of the (image, label) pairs kept from each
        class. The per-class sample size is rounded, so small classes may
        contribute zero rows.
    n_labels
        Number of leading label lines (and therefore images) to use.
        Defaults to 20001; ``None`` uses every line of the file.
    random_state
        Seed forwarded to the sampler for a reproducible subset. ``None``
        (default) draws a different sample on every call.

    Returns
    -------
    pd.DataFrame
        A DataFrame with an 'images' and a 'labels' column holding the sampled
        rows. The index keeps the original row positions, so use ``.iloc`` /
        ``.head`` for positional access.

    Raises
    ------
    ValueError
        If the image folder yields fewer images than there are labels.
    """
    # Collect the class label of each image (one label per line)
    with open(filedir, 'r') as f:
        labels = f.read().splitlines()
    if n_labels is not None:
        labels = labels[:n_labels]

    # Gather all images, sorted by filename like their associated label.
    # NOTE: every readable image in the folder is decoded before truncation.
    images = load_imagenet_val()
    if len(images) < len(labels):
        raise ValueError(
            f"Found {len(images)} images but {len(labels)} labels; "
            "cannot pair them by position."
        )
    # Keep only the images that have a label so both columns have equal
    # length (pd.DataFrame rejects lists of different lengths).
    images = images[:len(labels)]

    # Create a DataFrame with the images and labels
    df = pd.DataFrame({'images': images, 'labels': labels})

    # Return a stratified sample: the same fraction drawn from every class
    return df.groupby('labels').sample(frac=pct, replace=False, random_state=random_state)


In [None]:
# Create a stratified subset keeping 10% of each class
subset = stratified_sampling(pct=0.1)

# Open the first five sampled images. The sampled frame keeps the original
# row labels (not 0..n-1), so access rows positionally via .head() instead
# of .loc[i]; .head() also copes with subsets smaller than five rows.
for img in subset['images'].head(5):
    img.show()