Skip to content

Commit

Permalink
Loader file addapted for pose estimation method
Browse files Browse the repository at this point in the history
  • Loading branch information
rezazad68 committed May 12, 2021
1 parent c3dfd30 commit 2b0b5ba
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions ivadomed/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def load_dataset(bids_df, data_list, transforms_params, model_params, target_suf
else:
# Task selection
task = imed_utils.get_task(model_params["name"])

dataset = BidsDataset(bids_df=bids_df,
subject_file_lst=data_list,
target_suffix=target_suffix,
Expand Down Expand Up @@ -415,7 +414,7 @@ def get_pair_slice(self, slice_index, gt_type="segmentation"):
else:
gt_slices = []
for gt_obj in gt_dataobj:
if gt_type == "segmentation":
if gt_type in ["segmentation", "pose_estimation"]:
if not isinstance(gt_obj, list): # annotation from only one rater
gt_slices.append(np.asarray(gt_obj[..., slice_index],
dtype=np.float32))
Expand Down Expand Up @@ -619,6 +618,7 @@ def __init__(self, filename_pairs, slice_axis=2, cache=True, transform=None, sli

def load_filenames(self):
"""Load preprocessed pair data (input and gt) in handler."""

for input_filenames, gt_filenames, roi_filename, metadata in self.filename_pairs:
roi_pair = SegmentationPair(input_filenames, roi_filename, metadata=metadata, slice_axis=self.slice_axis,
cache=self.cache, prepro_transforms=self.prepro_transforms)
Expand All @@ -628,7 +628,6 @@ def load_filenames(self):
soft_gt=self.soft_gt)

input_data_shape, _ = seg_pair.get_pair_shapes()

for idx_pair_slice in range(input_data_shape[-1]):
slice_seg_pair = seg_pair.get_pair_slice(idx_pair_slice, gt_type=self.task)
self.has_bounding_box = imed_obj_detect.verify_metadata(slice_seg_pair, self.has_bounding_box)
Expand Down Expand Up @@ -696,14 +695,20 @@ def __getitem__(self, index):
# Update metadata_input with metadata_roi
metadata_gt = imed_loader_utils.update_metadata(metadata_input, metadata_gt)

if self.task == "segmentation":



if self.task in ["segmentation", "pose_estimation"]:
# Run transforms on images
stack_gt, metadata_gt = self.transform(sample=seg_pair_slice["gt"],
metadata=metadata_gt,
data_type="gt")
# Make sure stack_gt is binarized
if stack_gt is not None and not self.soft_gt:
stack_gt = imed_postpro.threshold_predictions(stack_gt, thr=0.5).astype(np.uint8)
if self.task == "pose_estimation":
stack_gt = torch.moveaxis(torch.squeeze(stack_gt), -1, 0)


else:
# Force no transformation on labels for classification task
Expand Down

0 comments on commit 2b0b5ba

Please sign in to comment.