Skip to content

Commit

Permalink
Add raw roi columns (#72)
Browse files Browse the repository at this point in the history
* Change roi_names to roi_indices

* Add raw_roi_names

* Add raw_roi_names
  • Loading branch information
LaiaPomar committed Feb 19, 2023
1 parent 52c3a8d commit 6327578
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 37 deletions.
20 changes: 11 additions & 9 deletions imgtools/autopipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def __init__(self,
# output ops
self.output = ImageAutoOutput(self.output_directory, self.output_streams, self.nnunet_info, self.is_nnunet_inference)

self.existing_roi_names = {"background": 0}
self.existing_roi_indices = {"background": 0}
if nnunet or nnunet_inference:
self.total_modality_counter = {}
self.patients_with_missing_labels = set()
Expand Down Expand Up @@ -492,7 +492,7 @@ def process_one_subject(self, subject_id):
img = pet

mask = self.make_binary_mask(structure_set, img,
self.existing_roi_names,
self.existing_roi_indices,
self.ignore_missing_regex,
roi_select_first=self.roi_select_first,
roi_separate=self.roi_separate)
Expand Down Expand Up @@ -522,10 +522,10 @@ def process_one_subject(self, subject_id):
else:
break

for name in mask.roi_names.keys():
if name not in self.existing_roi_names.keys():
self.existing_roi_names[name] = len(self.existing_roi_names)
mask.existing_roi_names = self.existing_roi_names
for name in mask.roi_indices.keys():
if name not in self.existing_roi_indices.keys():
self.existing_roi_indices[name] = len(self.existing_roi_indices)
mask.existing_roi_indices = self.existing_roi_indices


if self.v:
Expand All @@ -549,7 +549,7 @@ def process_one_subject(self, subject_id):
if self.v:
print(mask_arr.shape)

roi_names_list = list(mask.roi_names.keys())
roi_names_list = list(mask.roi_indices.keys())
for i in range(mask_arr.shape[0]):
new_mask = sitk.GetImageFromArray(np.transpose(mask_arr[i]))
new_mask.CopyInformation(mask)
Expand All @@ -563,6 +563,8 @@ def process_one_subject(self, subject_id):
metadata.update(structure_set.metadata)

metadata[f"metadata_{colname}"] = [structure_set.roi_names]
for roi, labels in mask.raw_roi_names.items():
metadata[f"raw_labels_{roi}"] = labels

print(subject_id, "SAVED MASK ON", conn_to)

Expand Down Expand Up @@ -625,12 +627,12 @@ def save_data(self):
if self.is_nnunet: #dataset.json for nnunet and .sh file to run to process it
imagests_path = pathlib.Path(self.output_directory, "imagesTs").as_posix()
images_test_location = imagests_path if os.path.exists(imagests_path) else None
# print(self.existing_roi_names)
# print(self.existing_roi_indices)
generate_dataset_json(pathlib.Path(self.output_directory, "dataset.json").as_posix(),
pathlib.Path(self.output_directory, "imagesTr").as_posix(),
images_test_location,
tuple(self.nnunet_info["modalities"].keys()),
{v:k for k, v in self.existing_roi_names.items()},
{v:k for k, v in self.existing_roi_indices.items()},
os.path.split(self.input_directory)[1])
_, child = os.path.split(self.output_directory)
shell_path = pathlib.Path(self.output_directory, child.split("_")[1]+".sh").as_posix()
Expand Down
46 changes: 25 additions & 21 deletions imgtools/modules/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def wrapper(img, *args, **kwargs):
result = f(img, *args, **kwargs)
if isinstance(img, Segmentation):
result = sitk.Cast(result, sitk.sitkVectorUInt8)
return Segmentation(result, roi_names=img.roi_names)
return Segmentation(result, roi_indices=img.roi_indices, raw_roi_names=img.raw_roi_names)
else:
return result
return wrapper
Expand All @@ -30,32 +30,36 @@ def map_over_labels(segmentation, f, include_background=False, return_segmentati
res = [f(segmentation.get_label(label=label), **kwargs) for label in labels]
if return_segmentation and isinstance(res[0], sitk.Image):
res = [sitk.Cast(r, sitk.sitkUInt8) for r in res]
res = Segmentation(sitk.Compose(*res), roi_names=segmentation.roi_names)
res = Segmentation(sitk.Compose(*res), roi_indices=segmentation.roi_indices, raw_roi_names=segmentation.raw_roi_names)
return res


class Segmentation(sitk.Image):
def __init__(self, segmentation, roi_names=None, existing_roi_names=None):
def __init__(self, segmentation, roi_indices=None, existing_roi_indices=None, raw_roi_names=None):
super().__init__(segmentation)
self.num_labels = self.GetNumberOfComponentsPerPixel()
if not roi_names:
self.roi_names = {f"label_{i}": i for i in range(1, self.num_labels+1)}
if not roi_indices:
self.roi_indices = {f"label_{i}": i for i in range(1, self.num_labels+1)}
else:
self.roi_names = roi_names
if 0 in self.roi_names.values():
self.roi_names = {k : v+1 for k, v in self.roi_names.items()}
if len(self.roi_names) != self.num_labels:
self.roi_indices = roi_indices
if 0 in self.roi_indices.values():
self.roi_indices = {k : v+1 for k, v in self.roi_indices.items()}
if not raw_roi_names:
raw_roi_names={}
else:
self.raw_roi_names = raw_roi_names
if len(self.roi_indices) != self.num_labels:
for i in range(1, self.num_labels+1):
if i not in self.roi_names.values():
self.roi_names[f"label_{i}"] = i
self.existing_roi_names = existing_roi_names
if i not in self.roi_indices.values():
self.roi_indices[f"label_{i}"] = i
self.existing_roi_indices = existing_roi_indices

def get_label(self, label=None, name=None, relabel=False):
if label is None and name is None:
raise ValueError("Must pass either label or name.")

if label is None:
label = self.roi_names[name]
label = self.roi_indices[name]

if label == 0:
# background is stored implicitly and needs to be computed
Expand All @@ -81,11 +85,11 @@ def to_label_image(self):
def __getitem__(self, idx):
res = super().__getitem__(idx)
if isinstance(res, sitk.Image):
res = Segmentation(res, self.roi_names)
res = Segmentation(res, roi_indices=self.roi_indices, raw_roi_names=self.raw_roi_names)
return res

def __repr__(self):
return f"<Segmentation with ROIs: {self.roi_names!r}>"
return f"<Segmentation with ROIs: {self.roi_indices!r}>"

def generate_sparse_mask(self, verbose=False) -> SparseMask:
"""
Expand All @@ -101,11 +105,11 @@ def generate_sparse_mask(self, verbose=False) -> SparseMask:
SparseMask
The sparse mask object.
"""
# print("asdlkfjalkfsjg", self.roi_names)
# print("asdlkfjalkfsjg", self.roi_indices)
mask_arr = np.transpose(sitk.GetArrayFromImage(self))
for name in self.roi_names.keys():
self.roi_names[name] = self.existing_roi_names[name]
# print(self.roi_names)
for name in self.roi_indices.keys():
self.roi_indices[name] = self.existing_roi_indices[name]
# print(self.roi_indices)

sparsemask_arr = np.zeros(mask_arr.shape[1:])

Expand All @@ -115,7 +119,7 @@ def generate_sparse_mask(self, verbose=False) -> SparseMask:
if len(mask_arr.shape) == 4:
for i in range(mask_arr.shape[0]):
slice = mask_arr[i, :, :, :]
slice *= list(self.roi_names.values())[i] # everything is 0 or 1, so this is fine to convert filled voxels to label indices
slice *= list(self.roi_indices.values())[i] # everything is 0 or 1, so this is fine to convert filled voxels to label indices
if verbose:
res = self._max_adder(sparsemask_arr, slice)
sparsemask_arr = res[0]
Expand All @@ -126,7 +130,7 @@ def generate_sparse_mask(self, verbose=False) -> SparseMask:
else:
sparsemask_arr = mask_arr

sparsemask = SparseMask(sparsemask_arr, self.roi_names)
sparsemask = SparseMask(sparsemask_arr, self.roi_indices)

if verbose:
if len(voxels_with_overlap) != 0:
Expand Down
11 changes: 6 additions & 5 deletions imgtools/modules/structureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_mask(self, reference_image, mask, label, idx, continuous):
def to_segmentation(self, reference_image: sitk.Image,
roi_names: Dict[str, str] = None,
continuous: bool = True,
existing_roi_names: Dict[str, int] = None,
existing_roi_indices: Dict[str, int] = None,
ignore_missing_regex: bool = False,
roi_select_first: bool = False,
roi_separate: bool = False) -> Segmentation:
Expand Down Expand Up @@ -202,22 +202,23 @@ def to_segmentation(self, reference_image: sitk.Image,
size = reference_image.GetSize()[::-1] + (len(labels),)
mask = np.zeros(size, dtype=np.uint8)

seg_roi_names = {}
seg_roi_indices = {}
if roi_names != {} and isinstance(roi_names, dict):
for i, (name, label_list) in enumerate(labels.items()):
for label in label_list:
self.get_mask(reference_image, mask, label, i, continuous)
seg_roi_names[name] = i
seg_roi_indices[name] = i

else:
for name, label in labels.items():
self.get_mask(reference_image, mask, name, label, continuous)
seg_roi_names = {"_".join(k): v for v, k in groupby(labels, key=lambda x: labels[x])}
seg_roi_indices = {"_".join(k): v for v, k in groupby(labels, key=lambda x: labels[x])}


mask[mask > 1] = 1
mask = sitk.GetImageFromArray(mask, isVector=True)
mask.CopyInformation(reference_image)
mask = Segmentation(mask, roi_names=seg_roi_names, existing_roi_names=existing_roi_names) #in the segmentation, pass all the existing roi names and then process is in the segmentation class
mask = Segmentation(mask, roi_indices=seg_roi_indices, existing_roi_indices=existing_roi_indices, raw_roi_names=labels) #in the segmentation, pass all the existing roi names and then process is in the segmentation class

return mask

Expand Down
4 changes: 2 additions & 2 deletions imgtools/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ def __init__(self,
def __call__(self,
structure_set: StructureSet,
reference_image: sitk.Image,
existing_roi_names: Dict[str, int],
existing_roi_indices: Dict[str, int],
ignore_missing_regex: bool,
roi_select_first: bool = False,
roi_separate: bool = False) -> Segmentation:
Expand All @@ -1496,7 +1496,7 @@ def __call__(self,
return structure_set.to_segmentation(reference_image,
roi_names=self.roi_names,
continuous=self.continuous,
existing_roi_names=existing_roi_names,
existing_roi_indices=existing_roi_indices,
ignore_missing_regex=ignore_missing_regex,
roi_select_first=roi_select_first,
roi_separate=roi_separate)
Expand Down

0 comments on commit 6327578

Please sign in to comment.