In [1]:
import matplotlib.pyplot as plt
import torchvision
from datasets import load_dataset, load_from_disk
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# MedMNIST subsets handled by this notebook.
# NOTE(review): this list shares its name with the Hugging Face `datasets`
# package; nothing is shadowed today because only individual functions are
# imported from that package by name.
datasets = [
    "organamnist",
    "pathmnist",
    "retinamnist",
    "dermamnist",
    "bloodmnist",
]

# Resolution selector, kept as a string. Options noted by the author: 28, 64, 128
image_size = "64"

In [None]:
# Root folder that receives one sub-directory per downloaded dataset.
dataset_dir = "medmnist_datasets"

# For every configured dataset: load its "val" split (using `image_size` as the
# configuration argument) and persist a copy under `<dataset_dir>/<name>`.
# NOTE(review): the variable is named `test_dataset` although the "val" split
# is requested — confirm which split is intended.
for dataset_name in datasets:
    hub_repo = f"jafermarq/{dataset_name}"
    local_path = f"{dataset_dir}/{dataset_name}"
    test_dataset = load_dataset(hub_repo, image_size, split="val")
    test_dataset.save_to_disk(local_path)

In [None]:
# Image pipeline: convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5 (i.e. (x - 0.5) / 0.5 per channel).
pytorch_transforms = Compose(
    [
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

def _ensure_rgb(img):
    """Return `img` in RGB mode if it is a PIL-style image in another mode."""
    mode = getattr(img, "mode", None)
    # Only PIL-style images expose a string `mode`; any other input type is
    # handed to the transform pipeline unchanged.
    if isinstance(mode, str) and mode != "RGB":
        return img.convert("RGB")
    return img


def apply_transforms(batch):
    """Apply `pytorch_transforms` to every image in the batch.

    The pipeline normalizes with three-channel statistics, so images that are
    not already RGB (for example single-channel grayscale images) are first
    converted to RGB. Without this the normalization step fails on a channel
    mismatch. Images that are already RGB are passed through untouched.

    Args:
        batch: Mapping with an "image" entry holding an iterable of images.

    Returns:
        The same `batch` object, with "image" replaced by the list of
        transformed images. Other entries are left as they are.
    """
    batch["image"] = [pytorch_transforms(_ensure_rgb(img)) for img in batch["image"]]
    return batch

In [None]:
# Use "pathmnist" as an example to test the dataloader
test_dataset = load_from_disk(f"{dataset_dir}/pathmnist")

# `with_transform` installs `apply_transforms` as the dataset's on-the-fly
# format and replaces any format set earlier, so a preceding
# `with_format("torch")` has no effect and is left out.
dataset = test_dataset.with_transform(apply_transforms)

# Iterate over the transformed dataset in batches of 8 samples.
dataloader = DataLoader(dataset, batch_size=8)

In [None]:
# Fetch the first batch and pull out its image and label fields
batch = next(iter(dataloader))
images = batch["image"]
labels = batch["label"]

# Tile the images four per row; each image is rescaled on its own for display
grid = torchvision.utils.make_grid(images, nrow=4, normalize=True, scale_each=True)

# Draw the grid on an explicit figure/axes pair with the axes hidden
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(grid.permute(1, 2, 0))  # channels-first -> channels-last for imshow
ax.axis("off")
plt.show()

print("Labels:", labels.tolist())