Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Intervertebral disc labeling using pose estimation #787

Closed
wants to merge 56 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
2b0b5ba
Loader file addapted for pose estimation method
rezazad68 May 12, 2021
eb84e99
Joint MSE loss for pose estimation method is added
rezazad68 May 12, 2021
c2675e4
Hourglass model name added to the MODEL_LIST
rezazad68 May 12, 2021
d1e0bf8
Hourglass network added
rezazad68 May 12, 2021
e2792c2
Strategy for training and evaluating the pose model is added
rezazad68 May 12, 2021
ee90ed4
Transforms for Resize and Vertebral disc splitting are added
rezazad68 May 12, 2021
647d51e
get task type slighly modified
rezazad68 May 12, 2021
b1ef47d
cv2 added
rezazad68 May 12, 2021
29a8c39
Config file for pose model
rezazad68 May 12, 2021
a502b37
Merge branch 'master' into master
rezazad68 May 13, 2021
434c53f
Update ivadomed/training.py
rezazad68 May 14, 2021
42175a9
cv2 changed to skimage and some extra comments are removed
rezazad68 May 18, 2021
83df56c
Merge branch 'master' of https://github.com/rezazad68/ivadomed
rezazad68 May 18, 2021
eb1a21c
opencv dependency is removed
rezazad68 May 23, 2021
7999937
Resize class is removed and Resample class slightly modified to resiz…
rezazad68 May 23, 2021
05fedf8
Pose config file is updated for resize section using Resample class
rezazad68 May 23, 2021
8984d48
channel order fixed for vertebral spitting to avoid move axis on the …
rezazad68 May 23, 2021
a7211e5
move axiis removed for gt data
rezazad68 May 23, 2021
9901ce0
converting to float shifted inside the loss function in line 184
rezazad68 May 23, 2021
2079bc5
float coversion is added
rezazad68 May 23, 2021
826f503
code link is added and variables used for constant vaues
rezazad68 May 23, 2021
0a5d8a3
Update ivadomed/transforms.py
rezazad68 May 24, 2021
8cf98d2
Update ivadomed/transforms.py
rezazad68 May 24, 2021
c6c568b
test for hourglass mode is added
rezazad68 May 29, 2021
e4787c0
intervertebral undo-transform is fixed
rezazad68 May 29, 2021
a8542c5
JointMSEloss revised to solve the batch issue
rezazad68 May 29, 2021
3002eea
JointMSEloss is update
rezazad68 May 30, 2021
373ed73
code is slightly refactored to be more clear
rezazad68 May 30, 2021
aab1039
evaluation part of the training modified for the pose method
rezazad68 May 30, 2021
1d4a751
VertebralSplitting updated to solve the extra dimension issue
rezazad68 May 31, 2021
5a9abfe
task type pose is removed, it was similar to the segmentation task so…
rezazad68 May 31, 2021
bc4e2dd
squeeze is removed and condition on task type pose is replaced with …
rezazad68 May 31, 2021
59ad92b
extra line is removed from transforms
rezazad68 May 31, 2021
f705f09
Update ivadomed/models.py
rezazad68 Jun 1, 2021
bfedc3a
3 empty lines are removed
rezazad68 Jun 1, 2021
eab08a3
modified for pose model
rezazad68 Jun 2, 2021
5376a0c
Default values for prepare_dataset_vertebral_labeling are removed
rezazad68 Jun 9, 2021
1cc3064
print added to show the progress
rezazad68 Jun 13, 2021
8d96972
new line added to the requirements.txt file to remove the conflict
rezazad68 Jun 17, 2021
6b94b1e
Empty lines are removed from the loader
rezazad68 Jun 25, 2021
2601274
Empty lines are removed
rezazad68 Jun 25, 2021
8a807c6
Empty lines are removed
rezazad68 Jun 25, 2021
ff7a244
Empty lines are removed
rezazad68 Jun 25, 2021
6939eaf
Empty lines are removed
rezazad68 Jun 25, 2021
796d1bb
space removed from loader
rezazad68 Jun 27, 2021
14aaa76
description added
rezazad68 Jun 28, 2021
a1aa7de
space removed
rezazad68 Jun 28, 2021
b262370
space added in line 37 utils
rezazad68 Jul 5, 2021
bf24b0d
empty lines added
rezazad68 Jul 5, 2021
e12d50d
slight modification
rezazad68 Jul 5, 2021
8d7d45b
reversed to previous format
rezazad68 Jul 5, 2021
fd825a8
empty line is removed
rezazad68 Jul 6, 2021
6487dd7
empty line is added
rezazad68 Jul 6, 2021
8de3c1e
line is added
rezazad68 Jul 6, 2021
fb12fe2
empty lines added
rezazad68 Jul 6, 2021
fb997d0
line added
rezazad68 Jul 6, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
88 changes: 88 additions & 0 deletions ivadomed/config/config_pose.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
{
"command": "train",
"gpu_ids": [0],
"path_output": "/data/example/labeling_t0test1",
"model_name": "pose_model",
"debugging": false,
"loader_parameters": {
"path_data": ["/data/data-multi-subject"],
"target_suffix": ["_heatmap0"],
"extensions": [".nii.gz"],
"roi_params": {
"suffix": null,
"slice_filter_roi": null
},
"contrast_params": {
"training_validation": ["T1w"],
"testing": [ "T1w"],
"balance": {}
},
"slice_filter_params": {
"filter_empty_mask": true,
"filter_empty_input": true
},
"slice_axis": "sagittal",
"multichannel": false,
"soft_gt": true
},
"split_dataset": {
"fname_split": null,
"random_seed": 8,
"split_method" : "participant_id",
"data_testing": {"data_type": null, "data_value":[]},
"balance": null,
"train_fraction": 0.6,
"test_fraction": 0.2
},
"training_parameters": {
"batch_size": 1,
"loss": {
"name": "JointsMSELoss"
},
"training_time": {
"num_epochs": 100,
"early_stopping_patience": 100,
"early_stopping_epsilon": 0.001
},
"scheduler": {
"initial_lr":0.0005,
"lr_scheduler": {
"name": "CosineAnnealingLR",
"base_lr": 1e-6,
"max_lr": 1e-4
}
},
"balance_samples": false,
"mixup_alpha": null,
"transfer_learning": null
},
"default_model": {
"name": "HourglassNet"
},
"FiLMedUnet": {
"applied": false,
"metadata": "contrasts",
"film_layers": [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
},
"uncertainty": {
"epistemic": false,
"aleatoric": false,
"n_it": 0
},
"postprocessing": {
"remove_noise": {"thr": -1},
"binarize_prediction": {"thr": 0.5},
"uncertainty": {"thr": -1, "suffix": "_unc-vox.nii.gz"},
"remove_small": {"unit": "vox", "thr": 3}
},
"evaluation_parameters": {
"target_size": {"unit": "vox", "thr": [20, 100]},
"overlap": {"unit": "vox", "thr": 3}
},
"transformation": {
"Resize": {"height":256, "width":256},
"VertebralSplitting": {"max_joint": 11, "applied_to": ["gt"]},
"NumpyToTensor": {},
"NormalizeInstance": {"applied_to": ["im"]}
}
}
13 changes: 9 additions & 4 deletions ivadomed/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def load_dataset(bids_df, data_list, transforms_params, model_params, target_suf
else:
# Task selection
task = imed_utils.get_task(model_params["name"])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please revert this change. There should be no change here.

dataset = BidsDataset(bids_df=bids_df,
subject_file_lst=data_list,
target_suffix=target_suffix,
Expand Down Expand Up @@ -418,7 +417,7 @@ def get_pair_slice(self, slice_index, gt_type="segmentation"):
else:
gt_slices = []
for gt_obj in gt_dataobj:
if gt_type == "segmentation":
if gt_type in ["segmentation", "pose_estimation"]:
if not isinstance(gt_obj, list): # annotation from only one rater
gt_slices.append(np.asarray(gt_obj[..., slice_index],
dtype=np.float32))
Expand Down Expand Up @@ -633,6 +632,7 @@ def __init__(self, filename_pairs, length=[], stride=[], slice_axis=2, cache=Tru

def load_filenames(self):
"""Load preprocessed pair data (input and gt) in handler."""

for input_filenames, gt_filenames, roi_filename, metadata in self.filename_pairs:
roi_pair = SegmentationPair(input_filenames, roi_filename, metadata=metadata, slice_axis=self.slice_axis,
cache=self.cache, prepro_transforms=self.prepro_transforms)
Expand All @@ -642,7 +642,6 @@ def load_filenames(self):
soft_gt=self.soft_gt)

input_data_shape, _ = seg_pair.get_pair_shapes()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please revert this change. There should be no change here.

for idx_pair_slice in range(input_data_shape[-1]):
slice_seg_pair = seg_pair.get_pair_slice(idx_pair_slice, gt_type=self.task)
self.has_bounding_box = imed_obj_detect.verify_metadata(slice_seg_pair, self.has_bounding_box)
Expand Down Expand Up @@ -757,14 +756,20 @@ def __getitem__(self, index):
# Update metadata_input with metadata_roi
metadata_gt = imed_loader_utils.update_metadata(metadata_input, metadata_gt)

if self.task == "segmentation":



if self.task in ["segmentation", "pose_estimation"]:
# Run transforms on images
stack_gt, metadata_gt = self.transform(sample=seg_pair_slice["gt"],
metadata=metadata_gt,
data_type="gt")
# Make sure stack_gt is binarized
if stack_gt is not None and not self.soft_gt:
stack_gt = imed_postpro.threshold_predictions(stack_gt, thr=0.5).astype(np.uint8)
if self.task == "pose_estimation":
stack_gt = torch.moveaxis(torch.squeeze(stack_gt), -1, 0)
rezazad68 marked this conversation as resolved.
Show resolved Hide resolved
joshuacwnewton marked this conversation as resolved.
Show resolved Hide resolved


else:
# Force no transformation on labels for classification task
Expand Down
41 changes: 41 additions & 0 deletions ivadomed/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,47 @@ def forward(self, input, target):

return mean_loss

class JointsMSELoss(nn.Module):
"""
Joint MSE loss for pose estimation method.
.. seealso::
Alejandro Newell et al. "Stacked Hourglass Networks for Human Pose Estimation."
Proceedings of the European Conference on Computer Vision. 2016.
Args:
output (tensor): prediction mask mask
target_and_weights (list): list as follows:
-- ground truth mask (Tensor): estimated mask by the pose model
-- target_weight (Tensor): visibility of the intervertebral disc to control the loss value for the training process
joshuacwnewton marked this conversation as resolved.
Show resolved Hide resolved

returns:
tensor: sum of losses computed on (mask, target) with the params
"""
def __init__(self, use_target_weight=True):
super(JointsMSELoss, self).__init__()
self.criterion = nn.MSELoss(reduction='mean')
self.use_target_weight = use_target_weight

def forward(self, output, target_and_weights):
target, target_weight = target_and_weights

batch_size = output.size(0)
num_joints = output.size(1)
heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
loss = 0

for idx in range(num_joints):
heatmap_pred = heatmaps_pred[idx].squeeze()
heatmap_gt = heatmaps_gt[idx].squeeze()
if self.use_target_weight:
loss += 0.5 * self.criterion(
heatmap_pred.mul(target_weight[:, idx]),
heatmap_gt.mul(target_weight[:, idx])
)
else:
loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

return loss / num_joints
rezazad68 marked this conversation as resolved.
Show resolved Hide resolved

class LossCombination(nn.Module):
"""
Expand Down
8 changes: 6 additions & 2 deletions ivadomed/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
cudnn.benchmark = True

# List of not-default available models i.e. different from Unet
MODEL_LIST = ['Modified3DUNet', 'HeMISUnet', 'FiLMedUnet', 'resnet18', 'densenet121', 'Countception']
MODEL_LIST = ['Modified3DUNet', 'HeMISUnet', 'FiLMedUnet', 'resnet18', 'densenet121', 'Countception', 'HourglassNet']


def get_parser():
Expand Down Expand Up @@ -113,6 +113,9 @@ def film_normalize_data(context, model_params, ds_train, ds_valid, path_output):
return model_params, ds_train, ds_valid, train_onehotencoder





rezazad68 marked this conversation as resolved.
Show resolved Hide resolved
def get_dataset(bids_df, loader_params, data_lst, transform_params, cuda_available, device, ds_type):
ds = imed_loader.load_dataset(bids_df, **{**loader_params, **{'data_list': data_lst,
'transforms_params': transform_params,
Expand Down Expand Up @@ -535,7 +538,6 @@ def create_dataset_and_ivadomed_version_log(context):

def run_main():
imed_utils.init_ivadomed()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please revert this change. There should be no change here.

parser = get_parser()
args = parser.parse_args()

Expand All @@ -555,4 +557,6 @@ def run_main():


if __name__ == "__main__":


rezazad68 marked this conversation as resolved.
Show resolved Hide resolved
run_main()