In [None]:
# Optional: preview some samples
import matplotlib.pyplot as plt
import numpy as np

def to_hwc(x):
    """Convert an image-like object to a NumPy array suitable for ``imshow``.

    A 3-D array whose first axis has length 1 or 3 is treated as channel-first
    and transposed to (height, width, channel). Singleton axes are then dropped,
    so single-channel images come back 2-D.
    """
    arr = np.asarray(x)
    if arr.ndim == 3 and arr.shape[0] in (1, 3):
        arr = np.transpose(arr, (1, 2, 0))
    return arr.squeeze()


def preview_indices(n_items, n_show=16):
    """Return the indices of the first ``n_show`` items, capped at ``n_items``.

    Capping avoids repeated indices when fewer than ``n_show`` items exist.
    """
    return list(range(max(0, min(n_items, n_show))))


def label_text(lab):
    """Return the title text for a label.

    Array-like labels are shown by their first element converted to ``int``;
    anything that cannot be indexed or converted is shown via ``str`` as-is.
    """
    try:
        return str(int(lab[0]))
    except (TypeError, ValueError, LookupError):
        # Scalars are not indexable, 0-d/empty arrays raise IndexError, and
        # non-numeric content raises ValueError/TypeError in int().
        return str(lab)


if 'train_ds' in globals() and len(train_ds) > 0:
    n_show = 16
    idxs = preview_indices(len(train_ds), n_show)

    # Index each sample once and split it into image and label.
    samples = [train_ds[i] for i in idxs]
    imgs = [sample[0] for sample in samples]
    labels = [sample[1] for sample in samples]

    # Size the grid from the number of samples actually shown (4x4 for 16).
    n_cols = 4
    n_rows = -(-len(idxs) // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    for ax, img, lab in zip(axes.flat, imgs, labels):
        arr = to_hwc(img)  # convert once; reused for both the data and the cmap choice
        ax.imshow(arr, cmap='gray' if arr.ndim == 2 else None)
        ax.set_title(label_text(lab), fontsize=8)
    for ax in axes.flat:
        # Also hides the frames of any unused cells in the last row.
        ax.axis('off')
    fig.tight_layout()
    plt.show()
else:
    print("No training dataset found to preview. Run the download cell above first.")

In [None]:
# Configure what to download
from pathlib import Path
import os

# Datasets to fetch, given as MedMNIST flags (one or several).
# Example flags: 'pathmnist', 'chestmnist', 'dermamnist', 'octmnist', 'pneumoniamnist', 'retinamnist',
#                'breastmnist', 'bloodmnist', 'tissuemnist', 'organamnist', 'organcmnist', 'organsmnist'
DATA_FLAGS = ["pathmnist"]  # edit this list to taste

# Image size: 28 is the classic MedMNIST size; 64/128/224 select MedMNIST+ where available.
IMG_SIZE: int = 28
AS_RGB: bool = False  # True expands grayscale images to 3 channels

# Target folder for the downloads; created (with parents) if it does not exist yet.
DOWNLOAD_DIR = Path("data/medmnist").resolve()
DOWNLOAD_DIR.mkdir(exist_ok=True, parents=True)
print(f"Download dir: {DOWNLOAD_DIR}")

# Download using the official API
import medmnist
from medmnist import INFO

# Download every requested dataset and collect per-split sample counts.
# `datasets` keeps all downloaded splits keyed by flag and then by split name, so
# nothing is lost when DATA_FLAGS holds several entries. train_ds/val_ds/test_ds
# still refer to the most recently downloaded dataset.
SPLITS = ('train', 'val', 'test')
datasets = {}
summary = []
for flag in DATA_FLAGS:
    if flag not in INFO:
        print(f"[WARN] Unknown dataset flag: {flag}")
        continue
    info = INFO[flag]
    DataClass = getattr(medmnist, info['python_class'])

    print(f"\n==> Downloading {flag} (size={IMG_SIZE}) ...")
    # Constructor keywords used here: split, download, size, as_rgb, root.
    datasets[flag] = {
        split: DataClass(split=split, download=True, size=IMG_SIZE, as_rgb=AS_RGB, root=str(DOWNLOAD_DIR))
        for split in SPLITS
    }
    train_ds, val_ds, test_ds = (datasets[flag][split] for split in SPLITS)

    summary.append((flag, len(train_ds), len(val_ds), len(test_ds)))

if summary:
    print("\nDownloaded datasets:")
    for flag, ntr, nv, nt in summary:
        print(f"- {flag}: train={ntr}, val={nv}, test={nt}")
else:
    # Every flag was unknown (or the list was empty): say so instead of printing
    # an empty "Downloaded datasets:" section.
    print("\n[WARN] No datasets were downloaded; check DATA_FLAGS.")

print("Done.")

In [None]:
# Install MedMNIST (and torch if missing)
import sys, subprocess

def pip_install(package):
    """Install ``package`` with pip, using the interpreter that runs this code.

    ``package`` is passed to ``pip install`` as a single argument. A non-zero
    pip exit status raises ``subprocess.CalledProcessError``.
    """
    command = [sys.executable, "-m", "pip", "install", package]
    print(f"Installing {package}...")
    subprocess.check_call(command)

import importlib

MEDMNIST_REQUIREMENT = "medmnist>=3.0.0"
MEDMNIST_MIN_MAJOR = 3

# Install MedMNIST only when it cannot be imported. Errors other than
# ImportError are left to propagate: reinstalling would not address them.
try:
    import medmnist  # noqa: F401
except ImportError:
    pip_install(MEDMNIST_REQUIREMENT)
else:
    # An older release that is already installed passes the import check but
    # does not satisfy the pinned requirement, so request the upgrade explicitly.
    # If the version attribute is missing or not numeric, leave the install alone.
    installed_version = str(getattr(medmnist, "__version__", ""))
    major_text = installed_version.split(".", 1)[0]
    if major_text.isdigit() and int(major_text) < MEDMNIST_MIN_MAJOR:
        pip_install(MEDMNIST_REQUIREMENT)
        # The old module is already loaded in this process.
        print("[WARN] MedMNIST was upgraded; restart the kernel before continuing.")

# Optional: ensure torch for some utilities
try:
    import torch  # noqa: F401
except ImportError:
    pip_install("torch")

# Let the import system notice packages installed after the interpreter started.
importlib.invalidate_caches()

print("Setup complete.")

# MedMNIST downloader

This notebook installs the official MedMNIST package, downloads one or more datasets into this workspace, and previews a few samples.

- Change the dataset list (`DATA_FLAGS`) and image size (`IMG_SIZE`) in the configuration cell.
- Files are saved under `data/medmnist` inside this project (the path is resolved relative to the notebook's working directory).