Add PanNuke Dataloader #153

Merged
merged 15 commits into constantinpape:main from pannuke
Oct 11, 2023

Conversation

Contributor

@anwai98 anwai98 commented Oct 5, 2023

  • Adding the dataloader for the PanNuke dataset (histopathology domain)

Contributor Author

anwai98 commented Oct 5, 2023

@constantinpape At the current stage it looks like it's working. It would be great to get an overview of this.
In brief (the current stage): it does the download and all the formatting end-to-end. It only expects the path where the data is stored (or is supposed to be stored), and with download=True it takes care of all the rest.

There's a rather complicated label transformation that takes care of a lot of things (documented briefly in the respective function in the file itself). Let me know how it looks.

@anwai98 anwai98 marked this pull request as ready for review October 5, 2023 17:12
Owner

@constantinpape constantinpape left a comment


There are a couple of things that don't look correct to me; I left comments about them.
Another general note: I think it would make more sense to apply the transformations that make instance and semantic labels out of the masks when creating the data; this makes it easier to later use different kinds of label transforms on the data (e.g. connected components for sub-patches).
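That suggestion (precomputing labels at data-creation time) could be sketched roughly as follows. This is an illustration only: the shapes, dataset names, and the simple one-hot-to-semantic conversion are assumptions, not the layout the PR ends up with.

```python
import os
import tempfile

import h5py
import numpy as np

# Illustrative masks in a C x H x W layout, last channel = background.
masks = np.zeros((6, 256, 256), dtype="int64")
masks[0, 10:20, 10:20] = 1  # one object in the first class channel

# Semantic labels: class index per pixel (0 = background).
semantic = np.zeros(masks.shape[1:], dtype="int64")
for class_id, channel in enumerate(masks[:-1], start=1):
    semantic[channel > 0] = class_id

# Write the precomputed labels into the h5 file at creation time.
path = os.path.join(tempfile.mkdtemp(), "pannuke_sample.h5")
with h5py.File(path, "w") as f:
    f.create_dataset("labels/semantic", data=semantic, compression="gzip")
    # An instance segmentation (e.g. from per-channel connected components)
    # would be stored alongside it in the same way, e.g. as "labels/instances".
```

Storing both label variants next to the raw data means later experiments can pick whichever label transform they need without re-running the conversion.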

torch_em/data/datasets/pannuke.py (outdated)
tmp_name = tmp_fold.split("_")[0] + tmp_fold.split("_")[1] # name of a particular sub-directory (per fold)
with h5py.File(os.path.join(path, f"pannuke_{tmp_fold}.h5"), "w") as f:
img_path = glob(os.path.join(path, tmp_fold, "*", "images", tmp_name, "images.npy"))[0]
gt_path = glob(os.path.join(path, tmp_fold, "*", "masks", tmp_name, "masks.npy"))[0]
Owner


You only take a single one of the files here? Is this on purpose?
I would have assumed that the code should look something like this instead:

img_paths = sorted(glob(..., "images.npy"))  # important to do sorted here, so that you get the same order for masks and images
gt_paths = sorted(glob(..., "masks.npy"))  # important to do sorted here, so that you get the same order for masks and images
for i, (img_path, gt_path) in enumerate(zip(img_paths, gt_paths)):
    img = np.load(img_path)
    gt = np.load(gt_path)
    assert img.shape == gt.shape  # or similar; but make sure that the shapes match, you might need to account for a different number of channels
    out_path = os.path.join(..., f"pannuke_{tmp_fold}_{i}.h5")
    with h5py.File(out_path, "w") as f:
        ...

That way you would use all the images in a given fold.

Contributor Author


Ah, it's to provide the option of using only specific folds of the dataset (there are 3 in total; I took the inspiration for this from cremi). Do you think we should do the download and h5 conversion for all the folds at once instead?

Owner


Ah, it's to provide the option to (only) use specific folds of the dataset (there are 3 in total now) (I took the inspiration for this from cremi).

That's not what I mean. Having the separate folds is ok. But with the glob here, and then accessing only the first element, it looks like you're just selecting one image.

glob(os.path.join(path, tmp_fold, "*", "images", tmp_name, "images.npy"))[0]

Maybe I am also misunderstanding the data organization and it's many images stacked into one file. But if that is the case, we should save them in separate files to better match the data organization expected by torch_em.

Contributor Author

@anwai98 anwai98 Oct 6, 2023


Ahh yes, that's because we do the downloads first (for n folds) and only then convert them, once the download is done for all n folds.

For instance, if we take all 3 folds into account, we download all of them first and then handle the h5 conversions individually per fold.

(However, now that I think about it, I could just do sorted(glob(os.path.join(path, "**", "images.npy"), recursive=True)) (and the same for masks.npy), and that should do the trick for accessing the respective samples, since the downloads happen first and the conversion afterwards. Nice.)
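The recursive-glob idea above can be sketched in a self-contained way; the directory layout below is a toy stand-in for the real downloaded fold structure, not the actual PanNuke paths.

```python
import os
import tempfile
from glob import glob

# Toy stand-in for the downloaded fold layout; directory names are illustrative.
root = tempfile.mkdtemp()
for fold in ["fold_1", "fold_2"]:
    for kind in ["images", "masks"]:
        sub = os.path.join(root, fold, kind)
        os.makedirs(sub)
        open(os.path.join(sub, f"{kind}.npy"), "w").close()

# sorted + recursive glob pairs images and masks in the same (fold) order.
img_paths = sorted(glob(os.path.join(root, "**", "images.npy"), recursive=True))
gt_paths = sorted(glob(os.path.join(root, "**", "masks.npy"), recursive=True))
```

Because both lists are sorted over the same directory tree, the i-th image path and the i-th mask path always belong to the same fold.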

gt_path = glob(os.path.join(path, tmp_fold, "*", "masks", tmp_name, "masks.npy"))[0]

f.create_dataset("images", data=np.load(img_path).transpose(3, 0, 1, 2))
f.create_dataset("masks", data=np.load(gt_path).transpose(3, 0, 1, 2))
Owner


I am a bit confused by how many channels we have here. I would have assumed that the image data has 3 dimensions (2 spatial ones and 1 channel dimension) and that the segmentation has 2 dimensions (only the spatial ones).

Owner


Ah never mind about the comment on the segmentations / labels, I saw the label trafo now. Still, I would only expect 3 dimensions, not four.

Contributor Author


Just to be sure, the input image and input label dimensions look like this: (S x H x W x C)

where:

  • S is the number of slices
  • C is the number of channels
    • for the input images - it's RGB (3)
    • for the input labels - it's 6 (5 tissue types, and the last channel is the background)
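With that layout, the transpose(3, 0, 1, 2) used when writing the h5 files moves the channel axis to the front. A small numpy sketch with dummy arrays of the stated shapes (the slice count 10 is illustrative):

```python
import numpy as np

# Dummy arrays with the stated S x H x W x C layout (sizes are illustrative).
images = np.zeros((10, 256, 256, 3))  # RGB images
masks = np.zeros((10, 256, 256, 6))   # 6 label channels, last one is background

# transpose(3, 0, 1, 2) moves the channel axis to the front: C x S x H x W.
assert images.transpose(3, 0, 1, 2).shape == (3, 10, 256, 256)
assert masks.transpose(3, 0, 1, 2).shape == (6, 10, 256, 256)
```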

torch_em/data/datasets/pannuke.py (outdated)
return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)


def label_trafo(labels):
Owner


I don't think this gives you a correct instance segmentation. Even if it does, it's too complex, you don't need the np.where. I would write it like this:

segmentation = np.zeros(labels.shape[1:])
max_id = 0
for label_channel in labels[:-1]:  # from what I understand we can just ignore the last channel because it encodes background
    this_labels = vigra.analysis.labelImage(label_channel)  # connected components to make sure we have an instance segmentation
    foreground = this_labels > 0
    segmentation[foreground] = this_labels[foreground] + max_id
    max_id = segmentation.max()

Contributor Author


I refactored the snippet slightly, with some minor updates (the major fix is removing the np.where).
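The refactored version itself isn't shown inline. As an illustration of the same idea (per-channel connected components merged into one instance segmentation with unique ids), here is a sketch that uses scipy.ndimage.label in place of vigra; the function name and the C x H x W layout follow the discussion above, but this is not the code that was merged.

```python
import numpy as np
from scipy import ndimage

def label_trafo(labels):
    """Merge per-channel foreground masks into one instance segmentation.

    `labels` is expected as C x H x W, with the last channel encoding
    background. Uses scipy.ndimage.label for connected components
    (a stand-in for vigra.analysis.labelImage).
    """
    segmentation = np.zeros(labels.shape[1:], dtype="int64")
    max_id = 0
    for channel in labels[:-1]:  # skip the background channel
        this_labels, _ = ndimage.label(channel > 0)  # connected components
        foreground = this_labels > 0
        # Offset by max_id so instance ids stay unique across channels.
        segmentation[foreground] = this_labels[foreground] + max_id
        max_id = int(segmentation.max())
    return segmentation
```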


for img_path, gt_path in zip(img_paths, gt_paths):
with h5py.File(os.path.join(path, f"pannuke_{fold}.h5"), "w") as f:
f.create_dataset("images", data=np.load(img_path).transpose(3, 0, 1, 2))
Contributor Author


# chunks: (3, 1, 256, 256) - C x 1 x H x W
chunks = (data.shape[-1], 1) + data.shape[:2]
f.create_dataset(..., compression="gzip", chunks=chunks)
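That chunking choice can be checked with a self-contained sketch. The axis order and sizes below are illustrative: data is laid out as H x W x S x C here so that data.shape[:2] gives the spatial extent, matching the chunk computation above.

```python
import os
import tempfile

import h5py
import numpy as np

# Illustrative data with an H x W x S x C layout, so data.shape[:2] = (H, W).
data = np.zeros((256, 256, 10, 3), dtype="uint8")
chunks = (data.shape[-1], 1) + data.shape[:2]  # C x 1 x H x W -> (3, 1, 256, 256)

path = os.path.join(tempfile.mkdtemp(), "chunked.h5")
with h5py.File(path, "w") as f:
    f.create_dataset(
        "images", data=data.transpose(3, 2, 0, 1),  # -> C x S x H x W
        compression="gzip", chunks=chunks,
    )
```

Chunking per slice and per channel means reading one 2d patch from one channel touches only a single compressed chunk.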

Owner

@constantinpape constantinpape left a comment


This looks good now! I will also check it out later.

@constantinpape constantinpape merged commit abf2f50 into constantinpape:main Oct 11, 2023
2 checks passed
@anwai98 anwai98 deleted the pannuke branch November 7, 2023 23:59