In [1]:
import sys

import pandas as pd
import numpy as np
from pathlib import Path

In [2]:
SPLITS = ["train", "val", "test"]
SIZE = 2
SOURCE = Path("/mnt/jbrockma/bachelor-thesis/data/oct/OCT")

In [3]:
split_info = pd.read_csv("/mnt/jbrockma/bachelor-thesis/medmnist_data_split/octmnist_split_info.csv", index_col="image_id")
split_info

Unnamed: 0_level_0,split,index
image_id,Unnamed: 1_level_1,Unnamed: 2_level_1
CNV-732516-77,train,0
NORMAL-3476775-5,train,1
NORMAL-1725461-13,train,2
NORMAL-5103899-4,train,3
CNV-7604641-38,train,4
...,...,...
DRUSEN-8350131-2,test,995
DME-859014-1,test,996
NORMAL-3063933-2,test,997
DME-943690-6,test,998


In [4]:
split_value_counts = split_info["split"].value_counts()

n_samples_of_split = {SPLIT: split_value_counts[SPLIT] for SPLIT in SPLITS}

# Prepare numpy arrays
images_of_split = {}
for SPLIT in SPLITS:
    n_samples = n_samples_of_split[SPLIT]
    images_of_split[SPLIT] = np.empty((n_samples, SIZE, SIZE), dtype=np.uint8)

In [5]:
info_of_image = {im_id: (split, idx) for im_id, split, idx in split_info.itertuples()}
info_of_image

{'CNV-732516-77': ('train', 0),
 'NORMAL-3476775-5': ('train', 1),
 'NORMAL-1725461-13': ('train', 2),
 'NORMAL-5103899-4': ('train', 3),
 'CNV-7604641-38': ('train', 4),
 'NORMAL-7081085-9': ('train', 5),
 'DME-1591159-85': ('train', 6),
 'CNV-3318854-209': ('train', 7),
 'NORMAL-4419154-8': ('train', 8),
 'NORMAL-907824-2': ('train', 9),
 'NORMAL-4556266-2': ('train', 10),
 'DRUSEN-4934663-36': ('train', 11),
 'CNV-172472-321': ('train', 12),
 'CNV-4219137-86': ('train', 13),
 'DRUSEN-7756213-4': ('train', 14),
 'CNV-6566667-25': ('train', 15),
 'CNV-9005459-75': ('train', 16),
 'NORMAL-9267111-53': ('train', 17),
 'CNV-2760476-3': ('train', 18),
 'CNV-6215140-26': ('train', 19),
 'CNV-6668596-145': ('train', 20),
 'NORMAL-5810801-1': ('train', 21),
 'NORMAL-6830356-3': ('train', 22),
 'CNV-1188386-762': ('train', 23),
 'NORMAL-3912695-7': ('train', 24),
 'CNV-4974377-25': ('train', 25),
 'NORMAL-4670940-2': ('train', 26),
 'NORMAL-7835584-9': ('train', 27),
 'NORMAL-5783190-8': ('tr

In [6]:
len(info_of_image)

109309

In [7]:
from tqdm import tqdm
import concurrent.futures
from PIL import Image
import random

In [8]:
with tqdm(desc="Processing", total=len(split_info), unit="pic") as pbar:
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = set()
    
        def preprocess(_fp, _info):
            try:
                _split, _index = _info
        
                with Image.open(_fp) as im:
                    if im.mode != "L":
                        im = im.convert("L")
                    im = im.resize((SIZE, SIZE), Image.BICUBIC)
        
                    images_of_split[_split][_index] = np.asarray(im)
            except Exception as e:
                print(e)
                raise e
    
        for child in SOURCE.rglob("*"):
            if child.is_file():
                try:
                    info = info_of_image.pop(child.stem)
                except KeyError:
                    pbar.write(f"Skipping {child.name}")
                    continue
                future = executor.submit(preprocess, child, info)
                future.add_done_callback(lambda _: pbar.update())
                futures.add(future)
        
        missing_ims = list(info_of_image.keys())
        
        if not len(missing_ims) == 0:
            pbar.write(f"The dataset at {SOURCE} is incomplete.")
            if len(missing_ims) == 1:
                pbar.write(f"Missing image with id {missing_ims[0]}")
            elif 1 < len(info_of_image) < 6:
                im_ids_str = ", ".join(missing_ims[:-1]) + f", and {missing_ims[-1]}"
                pbar.write(f"Missing files with ids {im_ids_str}")
            else:
                random_im_ids = random.sample(missing_ims, 5)
                im_ids_str = ", ".join(random_im_ids[:-1]) + f", and {random_im_ids[-1]}"
                pbar.write(f"Missing {len(missing_ims)} images such as {im_ids_str}")
            sys.exit(1)
    
        concurrent.futures.wait(futures)

Processing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 109309/109309 [03:09<00:00, 577.09pic/s]


In [9]:

exceptions = set()

for future in futures:
    exception = future.exception()
    if exception:
        exceptions.add(exception)

for exception in exceptions:
    print(exception)


In [12]:
mnist = np.load("/root/.medmnist/octmnist.npz")
labels_of_split = {SPLIT: mnist[f"{SPLIT}_labels"] for SPLIT in SPLITS}

name_to_array = {}
for data_name, data_of_split in [("images", images_of_split), ("labels", labels_of_split)]:
    for SPLIT in SPLITS:
        name_to_array[f"{SPLIT}_{data_name}"] = data_of_split[SPLIT]

np.savez_compressed("/root/test.npz", **name_to_array)