Add PanNuke Dataloader #153
Conversation
anwai98
commented
Oct 5, 2023
- Adding the dataloader for the PanNuke dataset (histopathology domain)
@constantinpape At the current stage, it looks like it's working. It would be great to have a look over this. There's a rather complicated label transformation taking care of a lot of stuff (briefed in the respective function in the file itself). Let me know how it looks.
There are a couple of things that don't look correct to me, I left comments about that.
Another general note: I think it would make more sense to apply the transformation that makes instance and semantic labels out of the masks when creating the data; this makes it easier to later use different kinds of label transforms on the data (e.g. connected components for sub-patches).
torch_em/data/datasets/pannuke.py
Outdated
tmp_name = tmp_fold.split("_")[0] + tmp_fold.split("_")[1]  # name of a particular sub-directory (per fold)
with h5py.File(os.path.join(path, f"pannuke_{tmp_fold}.h5"), "w") as f:
    img_path = glob(os.path.join(path, tmp_fold, "*", "images", tmp_name, "images.npy"))[0]
    gt_path = glob(os.path.join(path, tmp_fold, "*", "masks", tmp_name, "masks.npy"))[0]
You only take a single one of the files here? Is this on purpose?
I would have assumed that the code should look something like this instead:
img_paths = sorted(glob(..., "images.npy"))  # important to do sorted here, so that you get the same order for masks and images
gt_paths = sorted(glob(..., "masks.npy"))
for i, (img_path, gt_path) in enumerate(zip(img_paths, gt_paths)):
    img = np.load(img_path)
    gt = np.load(gt_path)
    assert img.shape == gt.shape  # or similar; but make sure that the shapes match, you might need to account for a different number of channels
    out_path = os.path.join(..., f"pannuke_{tmp_fold}_{i}.h5")
    with h5py.File(out_path, "w") as f:
        ...
That way you would use all the images in a given fold
Ah, it's to provide the option to (only) use specific folds of the dataset (there are 3 in total now; I took the inspiration for this from cremi). Do you think we should do the download and h5 conversion at once for all the folds already?
> Ah, it's to provide the option to (only) use specific folds of the dataset (there are 3 in total now; I took the inspiration for this from cremi).

That's not what I meant. Having the separate folds is ok. But with the glob here, accessing only the first match looks like you're selecting just a single image:
glob(os.path.join(path, tmp_fold, "*", "images", tmp_name, "images.npy"))[0]
Maybe I am also understanding the data organization wrong and it's many images stacked. But if that is the case we should save them in separate files to match the data organization expected by torch_em better.
Ahha yes, because we do the downloads first (for n folds) and then convert them once the download is done for all n folds. For instance, if we take all 3 folds into account, we download all of them first and then individually take care of the h5 conversions.
(However, now that I think about it, I could just do sorted(glob(os.path.join(path, "**", "images.npy"), recursive=True)) (and the same for masks.npy), and that should do the trick for accessing the respective samples, since it does the downloads first and the conversion afterwards. Nice.)
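The recursive-glob pairing described above can be sketched as a small helper. This is a hypothetical illustration (the function name find_image_label_pairs does not exist in the PR); it only assumes the data layout discussed here, with one images.npy and one masks.npy per downloaded fold directory.

```python
import os
from glob import glob


def find_image_label_pairs(path):
    """Pair up images.npy and masks.npy files across all downloaded folds.

    Sorting both glob results guarantees that images and masks line up
    in the same order, as suggested in the review.
    """
    img_paths = sorted(glob(os.path.join(path, "**", "images.npy"), recursive=True))
    gt_paths = sorted(glob(os.path.join(path, "**", "masks.npy"), recursive=True))
    # every image file needs a corresponding mask file
    assert len(img_paths) == len(gt_paths)
    return list(zip(img_paths, gt_paths))
```

The `**` pattern with `recursive=True` descends into arbitrarily nested sub-directories, so the exact per-fold folder structure does not matter.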
torch_em/data/datasets/pannuke.py
Outdated
gt_path = glob(os.path.join(path, tmp_fold, "*", "masks", tmp_name, "masks.npy"))[0]

f.create_dataset("images", data=np.load(img_path).transpose(3, 0, 1, 2))
f.create_dataset("masks", data=np.load(gt_path).transpose(3, 0, 1, 2))
I am a bit confused by how many channels we have here. I would have assumed that the image data has 3 dimensions (2 spatial ones and 1 channel dimension) and that the segmentation has 2 dimensions (only the spatial ones).
Ah never mind about the comment on the segmentations / labels, I saw the label trafo now. Still, I would only expect 3 dimensions, not four.
Just to be sure, the input image and input label dimensions look like this (S x H x W x C)
where,
- S is the number of slices
- C is the number of channels
- for the input images - it's RGB (3)
- for the input labels - it's 6 (5 tissue types and last channel is the background)
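Given that layout, the transpose(3, 0, 1, 2) in the diff just moves the channel axis to the front, turning S x H x W x C into C x S x H x W. A minimal sketch with dummy arrays (the shapes are illustrative, assuming 256 x 256 patches as in the dataset):

```python
import numpy as np

# Dummy arrays in the layout described above: S x H x W x C.
images = np.zeros((10, 256, 256, 3))  # RGB images, 3 channels
masks = np.zeros((10, 256, 256, 6))   # 5 class channels + 1 background channel

# transpose(3, 0, 1, 2) moves the channel axis first: C x S x H x W.
images_t = images.transpose(3, 0, 1, 2)
masks_t = masks.transpose(3, 0, 1, 2)
```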
torch_em/data/datasets/pannuke.py
Outdated
return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)


def label_trafo(labels):
I don't think this gives you a correct instance segmentation. Even if it does, it's too complex; you don't need the np.where. I would write it like this:
segmentation = np.zeros(labels.shape[1:])
max_id = 0
for label_channel in labels[:-1]: # from what I understand we can just ignore the last channel because it encodes background
this_labels = vigra.analysis.labelImage(label_channel) # connected components to make sure we have an instance segmentation
foreground = this_labels > 0
segmentation[foreground] = this_labels[foreground] + max_id
max_id = segmentation.max()
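A self-contained, runnable variant of that suggestion, with scipy.ndimage.label swapped in for vigra.analysis.labelImage (the function name merge_channel_instances is an illustrative choice, not the name used in the PR) — the idea is identical: connected components per channel, with each channel's instance ids offset so they never collide:

```python
import numpy as np
from scipy import ndimage


def merge_channel_instances(labels):
    """Merge per-channel binary masks into one instance segmentation.

    labels has shape C x H x W; the last channel encodes background
    and is skipped, as discussed in the review.
    """
    segmentation = np.zeros(labels.shape[1:], dtype="int64")
    max_id = 0
    for label_channel in labels[:-1]:
        # connected components ensure a proper instance segmentation
        this_labels, _ = ndimage.label(label_channel)
        foreground = this_labels > 0
        # offset this channel's ids past everything assigned so far
        segmentation[foreground] = this_labels[foreground] + max_id
        max_id = segmentation.max()
    return segmentation
```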
I refactored the snippets a slight bit with some minor updates (the major fix is on np.where).
torch_em/data/datasets/pannuke.py
Outdated
for img_path, gt_path in zip(img_paths, gt_paths):
    with h5py.File(os.path.join(path, f"pannuke_{fold}.h5"), "w") as f:
        f.create_dataset("images", data=np.load(img_path).transpose(3, 0, 1, 2))
# chunks: (3, 1, 256, 256) - C x 1 x H x W
chunks = (data.shape[-1], 1) + data.shape[1:3]
f.create_dataset(..., compression="gzip", chunks=chunks)
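Filling in the elided pieces, a runnable sketch of that chunking suggestion might look like this (the file name and array sizes are illustrative, assuming the S x H x W x C layout discussed earlier):

```python
import numpy as np
import h5py

# Hypothetical fold data in the layout discussed above: S x H x W x C.
data = np.random.rand(4, 256, 256, 3)

# Chunk per channel and per slice: C x 1 x H x W, matching the
# transposed on-disk layout so single-slice reads stay cheap.
chunks = (data.shape[-1], 1) + data.shape[1:3]  # (3, 1, 256, 256)

with h5py.File("pannuke_example.h5", "w") as f:
    f.create_dataset(
        "images",
        data=data.transpose(3, 0, 1, 2),
        compression="gzip",
        chunks=chunks,
    )
```

Gzip compression plus slice-sized chunks is a reasonable trade-off here: torch_em loaders typically read one 2D patch at a time, so each read touches only a few small chunks.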
This looks good now! I will also check it out later.