In [1]:
import os
from pathlib import Path, PosixPath

In [18]:
from dotenv import load_dotenv
import json
import numpy as np
from datetime import datetime

In [16]:
from PIL import Image

In [14]:
# Dataset and output locations come from an env file so no machine-specific paths live in the notebook.
load_dotenv("../envs/mednist.env")

# Fail fast with a readable message: Path(None) would otherwise raise an opaque TypeError.
missing_env_vars = [name for name in ("DATASET_DIR", "DATA_DIR") if not os.environ.get(name)]
if missing_env_vars:
    raise RuntimeError(
        f"Environment variable(s) {', '.join(missing_env_vars)} not set; "
        "expected them in ../envs/mednist.env"
    )
root_dir = Path(os.environ["DATASET_DIR"])  # root holding one sub-directory per class
data_dir = Path(os.environ["DATA_DIR"])  # where hyperparam.yml and random_split.json are written

In [None]:
def strip_root(id_path: PosixPath, root_dir: PosixPath = root_dir) -> str:
    """
    Replace a leading dataset-root prefix with the ``<DATASET_DIR>`` placeholder.

    Keeps machine-specific directory information out of saved split files.

    Args:
        id_path: Path to anonymise (any object is accepted; it is passed through ``str``).
        root_dir: Root to strip. Defaults to the module-level ``root_dir`` as it was
            when this function was defined.

    Returns:
        ``str(id_path)`` with its root prefix replaced by ``<DATASET_DIR>`` when
        ``id_path`` is the root or lies under it; otherwise ``str(id_path)`` unchanged.
    """
    path_str = str(id_path)
    # Drop a trailing separator so "/data/" and "/data" behave alike (a bare "/" becomes "").
    root_str = str(root_dir).rstrip(os.sep)
    if path_str == root_str:
        return "<DATASET_DIR>"
    # Only rewrite a whole-component prefix: str.replace would also rewrite later
    # occurrences (e.g. "/data" inside "/data/x/data/y") and would treat
    # "/database/..." as if it lay under "/data".
    if path_str.startswith(root_str + os.sep):
        return "<DATASET_DIR>" + path_str[len(root_str):]
    return path_str

In [12]:
import yaml

# Hyperparameters for the experiment; also dumped to <DATA_DIR>/hyperparam.yml
# (presumably read back by the training code — confirm).
hparams_dict = {
    # Fractions of the full dataset held out for finetuning and testing; the remainder trains.
    "finetune_frac": 0.1,
    "test_frac": 0.1,
    "train_batchsize": 1024,
    "ftune_batchsize": 1024,
    "num_workers": 2,
    "device": "cpu",
    "spatial_dims": 2,
    "in_channels": 1,
    # Must equal the number of class folders under DATASET_DIR (num_classes below), because
    # labels range over 0..num_classes-1. MedNIST has 6 classes (58,954 images, matching the
    # split counts printed below), so a 5-way head would reject label 5.
    "out_channels": 6,
    "loss": "CrossEntropyLoss",
    "optimizer": "AdamW",
    "epochs": 4,
    "val_interval": 1,
    "lr": 1e-5,
}

with open(data_dir / 'hyperparam.yml', 'w') as outfile:
    yaml.dump(hparams_dict, outfile, default_flow_style=False)

In [17]:
# Each sub-directory of root_dir is one class; its position in class_names is the integer
# label saved in the split. Listings are sorted because directory iteration order is
# filesystem-dependent, and hidden entries (.DS_Store, .ipynb_checkpoints, ...) are skipped
# so they never become classes or images.
class_names = sorted(
    entry.name
    for entry in root_dir.iterdir()
    if entry.is_dir() and not entry.name.startswith(".")
)
num_classes = len(class_names)
image_dict = {
    class_name: sorted(
        entry
        for entry in (root_dir / class_name).iterdir()
        if entry.is_file() and not entry.name.startswith(".")
    )
    for class_name in class_names
}

# Flatten into parallel lists of file paths and integer class labels.
image_fdirs, image_labels = [], []
for idx, class_name in enumerate(class_names):
    image_fdirs.extend(image_dict[class_name])
    image_labels.extend([idx] * len(image_dict[class_name]))

num_total = len(image_labels)
if num_total == 0:
    raise RuntimeError(f"No images found under {root_dir}")
# Assumes every image shares the size of the first one — TODO confirm for the whole dataset.
# The context manager closes the file handle Image.open keeps open.
with Image.open(image_fdirs[0]) as first_image:
    image_width, image_height = first_image.size

In [19]:
## Generate a random train / finetune / test split and save it for future use.
finetune_frac, test_frac = hparams_dict['finetune_frac'], hparams_dict['test_frac']
if not (finetune_frac >= 0 and test_frac >= 0 and finetune_frac + test_frac < 1):
    raise ValueError(
        f"Invalid split fractions: finetune_frac={finetune_frac}, test_frac={test_frac}"
    )

# Seeded so that re-running the notebook regenerates exactly the same split.
SPLIT_SEED = 42
rng = np.random.default_rng(SPLIT_SEED)
ixs = rng.permutation(num_total).tolist()

# Cut points over the shuffled indices: [0, train_end) -> train,
# [train_end, ftune_end) -> finetune, [ftune_end, num_total) -> test.
train_end = int(num_total * (1 - (finetune_frac + test_frac)))
ftune_end = int(num_total * (1 - test_frac))
train_ixs, ftune_ixs, test_ixs = ixs[:train_end], ixs[train_end:ftune_end], ixs[ftune_end:]
assert len(train_ixs) + len(ftune_ixs) + len(test_ixs) == num_total

trainX = [image_fdirs[idx] for idx in train_ixs]
trainY = [image_labels[idx] for idx in train_ixs]

ftuneX = [image_fdirs[idx] for idx in ftune_ixs]
ftuneY = [image_labels[idx] for idx in ftune_ixs]

testX = [image_fdirs[idx] for idx in test_ixs]
testY = [image_labels[idx] for idx in test_ixs]

print(
    "Training count =",len(trainX),"Validation count =", len(ftuneX), "Test count =",len(testX))

Training count = 47163 Validation count = 5895 Test count = 5896


In [25]:
# Preview of the date-stamp format ("%b%d,%Y", e.g. 'Oct01,2024') used in the split metadata.
datetime.now().strftime("%b%d,%Y")

'Oct01,2024'

In [52]:
# Replace the machine-specific dataset root with "<DATASET_DIR>" so the saved split is
# portable. Only the image paths need this: trainY / ftuneY / testY hold integer class
# indices and must stay ints — passing them through strip_root would turn them into
# strings such as "3" in the saved JSON.
trainX = [strip_root(id_path=x) for x in trainX]
ftuneX = [strip_root(id_path=x) for x in ftuneX]
testX = [strip_root(id_path=x) for x in testX]

In [54]:
# Persist the split so later runs reuse exactly the same partition.
now = datetime.now().strftime("%b%d,%Y")
data_split = {
    'info': f"Date: {now}. MedNIST Dataset. Training: {len(trainX)} - Finetune: {len(ftuneX)} - Test: {len(testX)}",
    # The "label" entries below are positions in this list; without it the saved labels
    # cannot be mapped back to class names.
    'class_names': list(class_names),
    'train': {"image": trainX, "label": trainY},
    'ftune': {"image": ftuneX, "label": ftuneY},
    'test': {"image": testX, "label": testY},
}
with open(data_dir / "random_split.json", "w", encoding="utf-8") as fp:
    json.dump(data_split, fp)