In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict, Image
import PIL

In [None]:
def prep_dataset(sample):
    """Record the image's dimensions and store it back in encoded form.

    Reads the image under sample['image'] (used here through the PIL image
    attributes .width, .height, .mode and .size), writes its dimensions to
    sample['width'] and sample['height'], and pastes images whose mode is
    'L', 'CMYK' or 'RGBA' onto a new RGB image of the same size. Images in
    any other mode are left untouched. The result is encoded with the
    `datasets` Image feature and written back to sample['image'].

    Returns the same sample dict, which is modified in place.
    """
    picture = sample['image']
    sample['width'], sample['height'] = picture.width, picture.height

    # Only these three modes are converted to RGB.
    if picture.mode in ('L', 'CMYK', 'RGBA'):
        canvas = PIL.Image.new("RGB", picture.size)
        canvas.paste(picture)
        picture = canvas

    sample['image'] = Image().encode_example(picture)
    return sample

In [None]:
# Number of same-sized images grouped into each row of the batched datasets.
batch_size = 64

# Training split

Only image sizes (width × height pairs) that occur in at least 2048 samples of the training split are kept.

In [None]:
%%time
dataset = load_dataset("imagenet-1k",split='train)
dataset = dataset.map(prep_dataset)

In [None]:
# Start from an empty dataset with the four output columns; one row per batch
# is appended to it by the batching loop.
train_dataset = Dataset.from_dict(
    {column: [] for column in ("img_batch", "label_batch", "width", "height")}
)

In [None]:
# Per-sample image dimensions, as recorded by prep_dataset.
width = torch.tensor(dataset['width'])
height = torch.tensor(dataset['height'])

# A single torch.unique call returns every distinct (width, height) row together
# with how often it occurs, instead of re-scanning all samples once per size.
unique_pairs, unique_counts = torch.unique(
    torch.stack([width, height], dim=1), dim=0, return_counts=True
)
pair_counts = {
    (w, h): c
    for (w, h), c in zip(unique_pairs.tolist(), unique_counts.tolist())
}

# Most frequent sizes first; keep only sizes with at least 2048 samples.
# Filtering (rather than indexing until the count drops below the cutoff)
# cannot run past the end of the list when every size meets the cutoff.
sizes = sorted(pair_counts.items(), key=lambda x: x[1], reverse=True)
sizes = [entry for entry in sizes if entry[1] >= 2048]
N = len(sizes)

In [None]:
%%time
for size, count in sizes:
    w = size[0]; h = size[1]
    filtered = dataset.filter(lambda x: x['width']==w and x['height']==h)
    for i_batch in range(len(filtered)//batch_size):
        ind = range(i_batch * batch_size, (i_batch + 1) * batch_size)
        img_batch = filtered[ind]['image']
        img_batch = [Image().encode_example(pil_img) for pil_img in img_batch] 
        label_batch = filtered[ind]['label']

        train_dataset = train_dataset.add_item({
            "img_batch" : img_batch,
            "label_batch" : label_batch,
            "width" : w,
            "height": h,
        })

# Test split

Only image sizes (width × height pairs) that occur in at least 256 samples of the test split are kept.

In [None]:
%%time
dataset = load_dataset("imagenet-1k",split='test')
dataset = dataset.map(prep_dataset)

In [None]:
# Start from an empty dataset with the four output columns; one row per batch
# is appended to it by the batching loop.
test_dataset = Dataset.from_dict(
    {column: [] for column in ("img_batch", "label_batch", "width", "height")}
)

In [None]:
# Per-sample image dimensions, as recorded by prep_dataset.
width = torch.tensor(dataset['width'])
height = torch.tensor(dataset['height'])

# A single torch.unique call returns every distinct (width, height) row together
# with how often it occurs, instead of re-scanning all samples once per size.
unique_pairs, unique_counts = torch.unique(
    torch.stack([width, height], dim=1), dim=0, return_counts=True
)
pair_counts = {
    (w, h): c
    for (w, h), c in zip(unique_pairs.tolist(), unique_counts.tolist())
}

# Most frequent sizes first; keep only sizes with at least 256 samples.
# Filtering (rather than indexing until the count drops below the cutoff)
# cannot run past the end of the list when every size meets the cutoff.
sizes = sorted(pair_counts.items(), key=lambda x: x[1], reverse=True)
sizes = [entry for entry in sizes if entry[1] >= 256]
N = len(sizes)

In [None]:
%%time
for size, count in sizes:
    w = size[0]; h = size[1]
    filtered = dataset.filter(lambda x: x['width']==w and x['height']==h)
    for i_batch in range(len(filtered)//batch_size):
        ind = range(i_batch * batch_size, (i_batch + 1) * batch_size)
        img_batch = filtered[ind]['image']
        img_batch = [Image().encode_example(pil_img) for pil_img in img_batch] 
        label_batch = filtered[ind]['label']

        test_dataset = test_dataset.add_item({
            "img_batch" : img_batch,
            "label_batch" : label_batch,
            "width" : w,
            "height": h,
        })

# Validation split

Only image sizes (width × height pairs) that occur in at least 128 samples of the validation split are kept.

In [None]:
%%time
dataset = load_dataset("imagenet-1k",split='validation')
dataset = dataset.map(prep_dataset)

In [None]:
# Start from an empty dataset with the four output columns; one row per batch
# is appended to it by the batching loop.
val_dataset = Dataset.from_dict(
    {column: [] for column in ("img_batch", "label_batch", "width", "height")}
)

In [None]:
# Per-sample image dimensions, as recorded by prep_dataset.
width = torch.tensor(dataset['width'])
height = torch.tensor(dataset['height'])

# A single torch.unique call returns every distinct (width, height) row together
# with how often it occurs, instead of re-scanning all samples once per size.
unique_pairs, unique_counts = torch.unique(
    torch.stack([width, height], dim=1), dim=0, return_counts=True
)
pair_counts = {
    (w, h): c
    for (w, h), c in zip(unique_pairs.tolist(), unique_counts.tolist())
}

# Most frequent sizes first; keep only sizes with at least 128 samples.
# Filtering (rather than indexing until the count drops below the cutoff)
# cannot run past the end of the list when every size meets the cutoff.
sizes = sorted(pair_counts.items(), key=lambda x: x[1], reverse=True)
sizes = [entry for entry in sizes if entry[1] >= 128]
N = len(sizes)

In [None]:
%%time
for size, count in sizes:
    w = size[0]; h = size[1]
    filtered = dataset.filter(lambda x: x['width']==w and x['height']==h)
    for i_batch in range(len(filtered)//batch_size):
        ind = range(i_batch * batch_size, (i_batch + 1) * batch_size)
        img_batch = filtered[ind]['image']
        img_batch = [Image().encode_example(pil_img) for pil_img in img_batch] 
        label_batch = filtered[ind]['label']

        val_dataset = val_dataset.add_item({
            "img_batch" : img_batch,
            "label_batch" : label_batch,
            "width" : w,
            "height": h,
        })

In [None]:
# Bundle the three batched splits under their split names.
dataset = DatasetDict(
    train=train_dataset,
    test=test_dataset,
    validation=val_dataset,
)

In [None]:
# Upload the combined train/test/validation DatasetDict to the Hub repository named below.
dataset.push_to_hub("danjacobellis/imagenet_batched_64")