In [None]:
from init_notebook import *
from multiprocessing import Pool
import pyarrow.parquet as pq


In [None]:
class Imagenet1kIterableDataset(BaseIterableDataset):
    """
    Stream ImageNet-1k images from the local parquet shards of the HuggingFace repo.

    Reads ``<repo_path>/data/<type>-*-of-*.parquet`` in sorted file order and
    yields one decoded image per row.

    :param type: split to read, one of "train", "validation", "test"
    :param image_type: "pil" yields ``PIL.Image`` objects, "tensor" yields the
        result of ``VF.to_tensor`` on the RGB-converted image
    :param with_label: if True, yield ``(image, label)`` where ``label`` is the
        row's "label" value converted to a plain Python object
    :param repo_path: root directory of the downloaded dataset repository
    :raises ValueError: for an unknown ``type`` or ``image_type``
    """
    # Number of images per split; only used to report __len__.
    _NUM_IMAGES = {
        "train": 1281167,
        "validation": 50000,
        "test": 100000,
    }
    # Accepted values for `image_type`.
    _IMAGE_TYPES = ("pil", "tensor")
    # Rows read per parquet batch; each row carries a whole encoded image,
    # so batches are kept small.
    _PARQUET_BATCH_SIZE = 10

    def __init__(
            self,
            type: str = "train",  # "train", "validation", "test"
            image_type: str = "pil",  # "pil" or "tensor"
            with_label: bool = False,
            repo_path: Union[str, Path] = config.BIG_DATASETS_PATH / "hug" / "imagenet-1k",
    ):
        if type not in self._NUM_IMAGES:
            raise ValueError(f"'type' needs to be one of {', '.join(self._NUM_IMAGES)}, got '{type}'")
        if image_type not in self._IMAGE_TYPES:
            raise ValueError(f"'image_type' needs to be one of {', '.join(self._IMAGE_TYPES)}, got '{image_type}'")
        super().__init__()
        self._type = type
        self._image_type = image_type
        self._with_label = with_label
        self._repo_path = Path(repo_path)

    def __len__(self):
        return self._NUM_IMAGES[self._type]

    def __iter__(self):
        from torch.utils.data import get_worker_info

        files = sorted((self._repo_path / "data").glob(f"{self._type}-*-of-*.parquet"))

        # Every DataLoader worker runs its own copy of this iterator; without
        # sharding, each worker would yield the complete split. Give each worker
        # a disjoint subset of the files (no-op in the main process).
        worker_info = get_worker_info()
        if worker_info is not None:
            files = files[worker_info.id::worker_info.num_workers]

        for file in files:
            parquet_file = pq.ParquetFile(file)
            for batch in parquet_file.iter_batches(
                    batch_size=self._PARQUET_BATCH_SIZE,
                    columns=["image", "label"],
            ):
                for image_value, label_value in zip(batch.column("image"), batch.column("label")):
                    # The "image" struct holds the encoded image file in its "bytes" field.
                    image = PIL.Image.open(io.BytesIO(image_value["bytes"].as_buffer()))
                    if self._image_type == "tensor":
                        image = VF.to_tensor(image.convert("RGB"))
                    if self._with_label:
                        # Convert the pyarrow scalar to a Python value so that
                        # e.g. DataLoader's default collate can batch it.
                        yield image, label_value.as_py()
                    else:
                        yield image


image_ds = Imagenet1kIterableDataset()
# Histogram of image sizes: PIL's `.size` is (width, height).
# Ctrl-C (KeyboardInterrupt) stops early and keeps the partial counts.
size_map = {}
try:
    with tqdm(image_ds) as progress_bar:
        for img in progress_bar:
            dims = img.size
            size_map[dims] = size_map.get(dims, 0) + 1
            progress_bar.set_postfix({"num_res": len(size_map)})
except KeyboardInterrupt:
    pass

In [None]:
# One row per distinct (width, height), ordered by ascending frequency.
# Ties are broken by size order: sort_index() first, then a *stable*
# sort on count (the default quicksort would not preserve that order).
df = (
    pd.DataFrame(size_map.values(), index=list(size_map), columns=["count"])
    .sort_index()
    .sort_values("count", kind="stable")
)
# Render each size as one "(width, height)" label instead of a MultiIndex.
df.index = df.index.map(lambda size: str(tuple(size)))
df["%"] = (df["count"] / df["count"].sum() * 100).round(2)
df.tail(50)

In [None]:
# Plot only the absolute counts: plotting "%" on the same axis would be
# meaningless (different scale). 100 is an arbitrary cut-off to hide rare sizes.
ax = df.loc[df["count"] > 100, "count"].plot.bar(title="Image sizes occurring more than 100 times")
ax.set(xlabel="(width, height)", ylabel="number of images");

In [None]:
# Mosaic of the first 5000 images, each passed through image_resize_crop
# with target (16, 16), laid out 50 per row.
grid = [
    image_resize_crop(img, (16, 16))
    for img in tqdm(image_ds.limit(5000))
]
display(VF.to_pil_image(make_grid(grid, nrow=50)))

In [None]:
# Throughput check: stream the whole split through a multi-worker DataLoader.
# The default collate function cannot batch PIL images, so read tensors;
# batch_size=1 because images have different sizes and cannot be stacked.
tensor_ds = Imagenet1kIterableDataset(image_type="tensor")
with tqdm(total=len(tensor_ds)) as progress:
    for batch in DataLoader(tensor_ds, batch_size=1, num_workers=4):
        progress.update(batch.shape[0])

# Extract images of fixed sizes into per-shape uint8 tensor storage

In [None]:
import ast

# The 20 most frequent sizes as flat (C, H, W) shapes, ready for extract_shapes.
# The index holds "(width, height)" strings: parse them without eval() and
# reverse to (height, width).
df.index[-20:].map(lambda size: (3, *ast.literal_eval(size)[::-1]))

In [None]:
from src.util.tensor_storage import TensorStorage

In [None]:
def extract_shapes(
        shapes: List[Tuple[int, int, int]],
        output_path: Union[str, Path] = config.BIG_DATASETS_PATH / "imagenet1k-uint8-by-shape",
        max_bytes: int = 250 * 1024**2,
):
    """
    Store every default-split ImageNet-1k image whose shape is in ``shapes`` as uint8 tensors.

    One ``TensorStorage`` is created per shape, with ``filename_part``
    ``<output_path>/<C>x<H>x<W>/batch``. Matching images are converted to RGB
    and added as ``(3, height, width)`` uint8 tensors; all other images are skipped.

    Ctrl-C (KeyboardInterrupt) stops the scan early. Pending buffers are
    flushed in every case, including when another exception aborts the loop.

    :param shapes: shapes to keep, each ``(3, height, width)``
    :param output_path: root directory for the per-shape storages
    :param max_bytes: passed through to each ``TensorStorage``
    """
    output_path = Path(output_path)
    shape_storage = {
        shape: TensorStorage(
            filename_part=output_path / "x".join(str(s) for s in shape) / "batch",
            max_bytes=max_bytes,
        )
        for shape in shapes
    }

    try:
        for image in tqdm(Imagenet1kIterableDataset()):
            storage = shape_storage.get((3, image.height, image.width))
            if storage is not None:
                # pil_to_tensor keeps the exact uint8 pixel values. The former
                # (to_tensor(...) * 255).to(uint8) round-trip truncates the float
                # result and can come out one below the original value.
                storage.add(VF.pil_to_tensor(image.convert("RGB")))
    except KeyboardInterrupt:
        pass
    finally:
        for storage in shape_storage.values():
            storage.store_buffer()

# Image sizes to extract, given as (height, width); each becomes a (3, H, W) shape.
extract_shapes([
    (3, height, width)
    for height, width in (
        (375, 500),
        (333, 500),
        (500, 375),
        (334, 500),
        (500, 333),
        (500, 500),
    )
])

In [None]:
# Load one stored batch back to sanity-check the extraction output.
t = torch.load(
    config.BIG_DATASETS_PATH
    / "imagenet1k-uint8-by-shape"
    / "3x375x500"
    / "batch-000001.pt"
)
t.shape