# Creating Dataloaders using `torch-em`

This tutorial focuses on creating dataloaders with `torch-em` for various segmentation tasks. Let's get started.
The first step is to make sure that `torch-em` is installed and accessible in the kernel.

In [3]:
URL = "https://github.com/constantinpape/torch-em#installation"
try:
    import torch_em
    print("Yayy, we found 'torch-em'. Start creating your dataloaders already. Skip to Step 1.")
except ModuleNotFoundError:
    print(f"'torch-em' was not found. Please install it from {URL}")

Yayy, we found 'torch-em'. Start creating your dataloaders already


If the script above asks you to install `torch-em`, please go ahead and install it first.

NOTE: If you are using Google Colab or Kaggle, the installation is described below. We recommend installing the repositories from [source](https://github.com/constantinpape/torch-em?tab=readme-ov-file#from-source) for best results.

In [5]:
# Now let's try and check again if `torch-em` is installed
import torch_em
# TODO: need to check how this works on kaggle / google colab and what's missing here

Ideally, this should not throw any errors. If there are some modules missing, please go ahead and install them.
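
To see at a glance which modules are missing, a small helper based on the standard library's `importlib.util.find_spec` can be used. This is only a sketch: the helper name `find_missing_modules` and the list of module names are assumptions for illustration, not an official list of `torch-em` requirements.

```python
import importlib.util


def find_missing_modules(module_names):
    """Return the names from `module_names` that cannot be found in this environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Hypothetical list of top-level modules used in this tutorial; adjust as needed.
required = ["torch_em", "h5py", "imageio", "matplotlib"]
missing = find_missing_modules(required)
print("Missing modules:", missing if missing else "none")
```

Any name reported as missing can then be installed with the package manager of your choice before continuing.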

## Step 0: Let's explore our datasets

We will work with three open-source datasets and see how to create dataloaders for training a UNet architecture. The datasets are the following:

1. DSB (Nuclei Segmentation in Light Microscopy: Caicedo et al. - https://doi.org/10.1038/s41592-019-0612-7)
2. Covid IF (Nuclei and Cell Segmentation in Immunofluorescence: Pape et al. - https://doi.org/10.1002/bies.202000257)
3. PlantSeg (Cell Segmentation in Confocal and Light-Sheet Microscopy: Wolny et al. - https://doi.org/10.7554/eLife.57613)

In [None]:
def _fetch_datasets(dataset_name, path):
    if dataset_name == "dsb":
        from torch_em.data.datasets.dsb import _download_dsb
        _download_dsb(path, "reduced", download=True)
        data_path = path

    elif dataset_name == "covid_if":
        from torch_em.data.datasets.covid_if import _download_covid_if
        _download_covid_if(path, download=True)
        data_path = path

    elif dataset_name == "plantseg":
        # let's test for root
        from torch_em.data.datasets.plantseg import _require_plantseg_data
        data_path = _require_plantseg_data(path, download=True, name="root", split="train")

    else:
        raise ValueError(f"Oops, download is not enabled for {dataset_name}.")

    return data_path


def plot_samples(image, label):
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(image)
    ax[1].imshow(label)
    plt.show()

Once the datasets are downloaded, let's explore them quickly before proceeding to create the dataset objects.

In [None]:
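# For DSB
import os
import imageio

image_paths = ...
label_paths = ...

# os.path.splitext returns (root, extension); the last element is the extension
print("The image extension seems to be:", os.path.splitext(image_paths[0])[-1])
print("The label extension seems to be:", os.path.splitext(label_paths[0])[-1])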

# It appears that the images are in tif format. It's a supported data format. Now let's check the data structure.

image_shapes = [imageio.imread(path).shape for path in image_paths]
label_shapes = [imageio.imread(path).shape for path in label_paths]

# It appears that the images have one channel. It's a supported data structure as well. Now let's visualize one image to understand our data better.

for image_path, label_path in zip(image_paths, label_paths):
    image = imageio.imread(image_path)
    label = imageio.imread(label_path)

    plot_samples(image, label)

    break  # it's enough to check a few samples, feel free to explore the entire dataset
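
The extension check above looks at the first file only. A minimal sketch for checking all files at once is to count the extensions with `collections.Counter`. The helper name `count_extensions` and the example file names below are made up for illustration.

```python
import os
from collections import Counter


def count_extensions(paths):
    """Count how often each (lower-cased) file extension occurs in `paths`."""
    return Counter(os.path.splitext(path)[-1].lower() for path in paths)


# Hypothetical file names, for illustration only.
example_paths = ["images/a.tif", "images/b.tif", "images/c.TIF", "notes.txt"]
print(count_extensions(example_paths))
```

If more than one extension shows up for the images or the labels, it is worth checking the folder before building a dataset from it.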

In [None]:
# For Covid IF
volume_paths = ...

print("The volume extension seems to be:", os.path.splitext(volume_paths[0])[-1])

# It appears that the images are in hdf5 format. It's a supported data format. Now let's check the data structure.

# Let's try to open one image first
import h5py
with h5py.File(volume_paths[0], "r") as f:
    image = f["..."][:]
    label = f["..."][:]

    print(image.shape, label.shape)

    plot_samples(image, label)
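
When the names of the datasets stored inside an HDF5 file are not known in advance, they can be listed before any data is read. The following is a minimal sketch using `h5py`'s `visititems`; the helper name `list_hdf5_datasets` is an assumption for illustration and is not part of `torch-em`.

```python
import h5py


def list_hdf5_datasets(path):
    """Return a dict mapping each dataset name in the file to its (shape, dtype)."""
    contents = {}

    def _visit(name, obj):
        # Groups are skipped; only datasets carry a shape and a dtype.
        if isinstance(obj, h5py.Dataset):
            contents[name] = (obj.shape, str(obj.dtype))

    with h5py.File(path, "r") as f:
        f.visititems(_visit)
    return contents


# Usage (the path is a placeholder):
# print(list_hdf5_datasets("/path/to/volume.h5"))
```

The printed names are the keys to use when indexing the file, and the shapes show whether the data is 2D or 3D and whether it has a channel axis.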

In [None]:
# For PlantSeg
volume_paths = ...

print("The volume extension seems to be:", os.path.splitext(volume_paths[0])[-1])

# It appears that the images...

# Let's try to open one image first
import h5py
with h5py.File(volume_paths[0], "r") as f:
    image = f["..."][:]
    label = f["..."][:]

    print(image.shape, label.shape)

    plot_samples(image, label)

## Step 1: Let's create the dataset

In [None]:
# DSB dataset

In [None]:
# Covid IF dataset

In [None]:
# PlantSeg dataset
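
As a starting point for the cells above, here is a hedged sketch of how a dataset could be created with `torch_em.default_segmentation_dataset`. The wrapper name `make_segmentation_dataset` is hypothetical, and the paths, keys and patch shape in the usage comment are placeholders that have to be replaced with the values found during the exploration in Step 0.

```python
def make_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, patch_shape):
    """Create a torch-em dataset that samples patches of size `patch_shape`."""
    import torch_em  # imported lazily, so the function can be defined without torch-em

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
    )


# Usage (all values are placeholders, not the actual dataset layout):
# dataset = make_segmentation_dataset(
#     raw_paths="/path/to/images", raw_key="*.tif",
#     label_paths="/path/to/labels", label_key="*.tif",
#     patch_shape=(256, 256),
# )
```

For container formats such as HDF5, the keys are expected to be the names of the datasets inside the file. For folders of image files, a glob pattern such as `*.tif` is, to our understanding, the usual choice. Please check both against the `torch-em` documentation for your installed version.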

## Step 2: Let's create the dataloader

In [None]:
# DSB dataloader

In [None]:
# Covid IF dataloader

In [None]:
# PlantSeg dataloader
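
Once a dataset object exists, it can be wrapped into a dataloader. The sketch below assumes that `torch_em.get_data_loader` is used for this and that additional keyword arguments such as `shuffle` are forwarded to the PyTorch `DataLoader`. The wrapper name `make_loader` and the batch size in the usage comment are hypothetical.

```python
def make_loader(dataset, batch_size, shuffle=True):
    """Wrap a dataset into a dataloader that yields batches of `batch_size` samples."""
    import torch_em  # imported lazily, so the function can be defined without torch-em

    return torch_em.get_data_loader(dataset, batch_size=batch_size, shuffle=shuffle)


# Usage (placeholder values):
# loader = make_loader(dataset, batch_size=4)
```

The same pattern applies to all three datasets, since only the dataset object changes.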

## Step 3: Let's check our dataloaders

In [None]:
def look_at_chosen_loader(chosen_loader, do_run=False):
    if do_run:
        save_path = "./loader.png"
        print("Let's check how the samples look first. We store the images here:", save_path)
        from torch_em.util.debug import check_loader
        check_loader(chosen_loader, 8, plt=True, save_path=save_path)
    else:
        print(f"There are {len(chosen_loader)} samples generated from the loader. Please pass 'do_run=True' to 'look_at_chosen_loader' function.")
        return

    for x, y in chosen_loader:
        print(x.shape, y.shape)
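
Iterating over the full loader only to print shapes can take a while for large datasets. A minimal sketch that inspects just the first few batches is shown here. The helper name `collect_batch_shapes` is an assumption for illustration, and it works with any iterable that yields `(x, y)` pairs whose elements have a `shape` attribute.

```python
from itertools import islice


def collect_batch_shapes(loader, max_batches=2):
    """Return the (input shape, target shape) of the first `max_batches` batches."""
    return [(tuple(x.shape), tuple(y.shape)) for x, y in islice(loader, max_batches)]


# Usage (assumes `chosen_loader` was created in Step 2):
# for x_shape, y_shape in collect_batch_shapes(chosen_loader):
#     print(x_shape, y_shape)
```

If the input and target shapes do not match the batch size and patch shape chosen in the previous steps, the dataset settings are the first place to look.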