Skip to content

Commit

Permalink
Squash all commits from PR #787 for clarity
Browse files Browse the repository at this point in the history
This commit contains all commits between:
  - 2b0b5ba (Start)
  - 8d7d45b (End)

This commit was also rebased onto ccd82d2
from the master branch, because that commit was the last point at which
master was merged into Reza's branch. (See: a502b37)

All former commits are preserved on the branch 'ra/intervertebral-disc-labeling'.

Co-authored-by: Reza Azad <rezazad68@gmail.com>
  • Loading branch information
rezazad68 authored and joshuacwnewton committed Jul 6, 2021
1 parent ccd82d2 commit d727fd9
Show file tree
Hide file tree
Showing 9 changed files with 478 additions and 26 deletions.
94 changes: 94 additions & 0 deletions ivadomed/config/config_pose.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
{
"command": "train",
"gpu_ids": [0],
"path_output": "/data/example/labeling_t0test1",
"model_name": "pose_model",
"debugging": false,
"loader_parameters": {
"path_data": ["/data/data-multi-subject"],
"target_suffix": ["_heatmap0"],
"extensions": [".nii.gz"],
"roi_params": {
"suffix": null,
"slice_filter_roi": null
},
"contrast_params": {
"training_validation": ["T1w"],
"testing": [ "T1w"],
"balance": {}
},
"slice_filter_params": {
"filter_empty_mask": true,
"filter_empty_input": true
},
"slice_axis": "sagittal",
"multichannel": false,
"soft_gt": true
},
"split_dataset": {
"fname_split": null,
"random_seed": 8,
"split_method" : "participant_id",
"data_testing": {"data_type": null, "data_value":[]},
"balance": null,
"train_fraction": 0.6,
"test_fraction": 0.2
},
"training_parameters": {
"batch_size": 1,
"loss": {
"name": "JointsMSELoss"
},
"training_time": {
"num_epochs": 10,
"early_stopping_patience": 100,
"early_stopping_epsilon": 0.001
},
"scheduler": {
"initial_lr":0.0005,
"lr_scheduler": {
"name": "CosineAnnealingLR",
"base_lr": 1e-6,
"max_lr": 1e-4
}
},
"balance_samples": false,
"mixup_alpha": null,
"transfer_learning": null
},
"default_model": {
"name": "HourglassNet"
},
"FiLMedUnet": {
"applied": false,
"metadata": "contrasts",
"film_layers": [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
},
"uncertainty": {
"epistemic": false,
"aleatoric": false,
"n_it": 0
},
"postprocessing": {
"remove_noise": {"thr": -1},
"binarize_prediction": {"thr": 0.5},
"uncertainty": {"thr": -1, "suffix": "_unc-vox.nii.gz"},
"remove_small": {"unit": "vox", "thr": 3}
},
"evaluation_parameters": {
"target_size": {"unit": "vox", "thr": [20, 100]},
"overlap": {"unit": "vox", "thr": 3}
},
"transformation": {
"Resample":
{
"wspace": 256,
"hspace": 256,
"dspace": 1,
"flag_pixel": true
},
"VertebralSplitting": {"max_joint": 11, "applied_to": ["gt"]},
"NumpyToTensor": {},
"NormalizeInstance": {"applied_to": ["im"]}
}
}
2 changes: 0 additions & 2 deletions ivadomed/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def load_dataset(bids_df, data_list, transforms_params, model_params, target_suf
else:
# Task selection
task = imed_utils.get_task(model_params["name"])

dataset = BidsDataset(bids_df=bids_df,
subject_file_lst=data_list,
target_suffix=target_suffix,
Expand Down Expand Up @@ -642,7 +641,6 @@ def load_filenames(self):
soft_gt=self.soft_gt)

input_data_shape, _ = seg_pair.get_pair_shapes()

for idx_pair_slice in range(input_data_shape[-1]):
slice_seg_pair = seg_pair.get_pair_slice(idx_pair_slice, gt_type=self.task)
self.has_bounding_box = imed_obj_detect.verify_metadata(slice_seg_pair, self.has_bounding_box)
Expand Down
47 changes: 47 additions & 0 deletions ivadomed/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,53 @@ def forward(self, input, target):

return mean_loss

class JointsMSELoss(nn.Module):
"""
Joint MSE loss for pose estimation method.
.. seealso::
Alejandro Newell et al. "Stacked Hourglass Networks for Human Pose Estimation."
Proceedings of the European Conference on Computer Vision. 2016.
Args:
output (tensor): prediction mask mask
target_and_weights (list): list as follows:
-- ground truth mask (Tensor): estimated mask by the pose model
-- target_weight (Tensor): visibility of the intervertebral disc to control the loss value for the training process
returns:
tensor: sum of losses computed on (mask, target) with the params
"""
def __init__(self):
super(JointsMSELoss, self).__init__()
self.criterion = nn.MSELoss(reduction='mean')

def forward(self, output, target_and_weights):
if len(target_and_weights) == 2 and all(isinstance(n, torch.Tensor) for n in target_and_weights):
target, target_weight = target_and_weights
use_target_weight = True
elif isinstance(target_and_weights, torch.Tensor):
target, target_weight = target_and_weights, None
use_target_weight = False
else:
raise ValueError("Input must either a Tensor (target) or a list of 2 Tensors (target, weights).")
target = target.float()
batch_size = output.size(0)
num_joints = output.size(1)
heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
loss = 0

for idx in range(num_joints):
heatmap_pred = heatmaps_pred[idx].squeeze()
heatmap_gt = heatmaps_gt[idx].squeeze()
if use_target_weight:
loss += 0.5 * self.criterion(
heatmap_pred.mul(target_weight[:, idx:idx+1]),
heatmap_gt.mul(target_weight[:, idx:idx+1])
)
else:
loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

return loss / num_joints

class LossCombination(nn.Module):
"""
Expand Down
3 changes: 1 addition & 2 deletions ivadomed/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
cudnn.benchmark = True

# List of not-default available models i.e. different from Unet
MODEL_LIST = ['Modified3DUNet', 'HeMISUnet', 'FiLMedUnet', 'resnet18', 'densenet121', 'Countception']
MODEL_LIST = ['Modified3DUNet', 'HeMISUnet', 'FiLMedUnet', 'resnet18', 'densenet121', 'Countception', 'HourglassNet']


def get_parser():
Expand Down Expand Up @@ -535,7 +535,6 @@ def create_dataset_and_ivadomed_version_log(context):

def run_main():
imed_utils.init_ivadomed()

parser = get_parser()
args = parser.parse_args()

Expand Down
207 changes: 207 additions & 0 deletions ivadomed/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,213 @@ def forward(self, x):

return net

class HGBottleneck(nn.Module):
"""Bottleneck of the Hourglass Network.
Hourglass network is an auto-encoder decoder architecture for the task of pose estimation. In this architecture,
Features are processed across all scales and consolidated to best capture the various spatial relationships associated with the pose.
code from: https://github.com/bearpaw/pytorch-pose/blob/master/pose/models/hourglass.py
.. seealso::
Alejandro Newell et al. "Stacked Hourglass Networks for Human Pose Estimation."
Proceedings of the European Conference on Computer Vision. 2016.
Args:
inplanes (int): number of input channels of the convolution kernel
planes (int): number of output channels of the convolution kernel
stride (int): stride value in the convolution operation
"""
expansion = 2

def __init__(self, inplanes, planes, stride=1, downsample=None):
super(HGBottleneck, self).__init__()

self.bn1 = nn.BatchNorm2d(inplanes)
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=True)
self.bn3 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=True)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x):
residual = x

out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)

out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)

out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)

if self.downsample is not None:
residual = self.downsample(x)

out += residual

return out

class Hourglass(nn.Module):
"""Stacked Hourglass Networks.
Hourglass network is an auto-encoder decoder architecture for the task of pose estimation. In this architecture,
Features are processed across all scales and consolidated to best capture the various spatial relationships associated with the pose.
code from: https://github.com/bearpaw/pytorch-pose/blob/master/pose/models/hourglass.py
.. seealso::
Alejandro Newell et al. "Stacked Hourglass Networks for Human Pose Estimation."
Proceedings of the European Conference on Computer Vision. 2016.
Args:
block (nn.Module): Deep bottleneck structure
num_blocks (int): number of blocks in the encoder structure
depth (int): hourglass depth
name (str): model's name used for call in configuration file.
"""
def __init__(self, block, num_blocks, planes, depth):
super(Hourglass, self).__init__()
self.depth = depth
self.block = block
self.hg = self._make_hour_glass(block, num_blocks, planes, depth)

def _make_residual(self, block, num_blocks, planes):
layers = []
for i in range(num_blocks):
layers.append(block(planes*block.expansion, planes))
return nn.Sequential(*layers)

def _make_hour_glass(self, block, num_blocks, planes, depth):
branch = 3
hg = []
for i in range(depth):
res = []
for j in range(branch):
res.append(self._make_residual(block, num_blocks, planes))
if i == 0:
res.append(self._make_residual(block, num_blocks, planes))
hg.append(nn.ModuleList(res))
return nn.ModuleList(hg)

def _hour_glass_forward(self, n, x):
up1 = self.hg[n-1][0](x)
low1 = F.max_pool2d(x, 2, stride=2)
low1 = self.hg[n-1][1](low1)

if n > 1:
low2 = self._hour_glass_forward(n-1, low1)
else:
low2 = self.hg[n-1][3](low1)
low3 = self.hg[n-1][2](low2)
up2 = F.interpolate(low3, scale_factor=2)
out = up1 + up2
return out

def forward(self, x):
return self._hour_glass_forward(self.depth, x)


class HourglassNet(nn.Module):
"""Stacked Hourglass Networks.
Hourglass network is an auto-encoder decoder architecture for the task of pose estimation. In this architecture,
Features are processed across all scales and consolidated to best capture the various spatial relationships associated with the pose.
code from: https://github.com/bearpaw/pytorch-pose/blob/master/pose/models/hourglass.py
.. seealso::
Alejandro Newell et al. "Stacked Hourglass Networks for Human Pose Estimation."
Proceedings of the European Conference on Computer Vision. 2016.
Args:
num_stacks (int): number of stacked hourglass networks
num_blocks (int): number of blocks in the encoder structure
num_classes (int): number of joints to learn the pose structure
name (str): model's name used for call in configuration file.
"""
def __init__(self, num_stacks=2, num_blocks=4, num_classes=11, block = HGBottleneck, **kwargs):
super(HourglassNet, self).__init__()

self.inplanes = 64
self.num_feats = 128
self.num_stacks = num_stacks
self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=True)
self.bn1 = nn.BatchNorm2d(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_residual(block, self.inplanes, 1)
self.layer2 = self._make_residual(block, self.inplanes, 1)
self.layer3 = self._make_residual(block, self.num_feats, 1)
self.maxpool = nn.MaxPool2d(2, stride=2)
self.scale_score = nn.Upsample(scale_factor=4, mode='bilinear')

# build hourglass modules
ch = self.num_feats*block.expansion
hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
for i in range(num_stacks):
hg.append(Hourglass(block, num_blocks, self.num_feats, depth= 4))
res.append(self._make_residual(block, self.num_feats, num_blocks))
fc.append(self._make_fc(ch, ch))
score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True))
if i < num_stacks-1:
fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True))
self.hg = nn.ModuleList(hg)
self.res = nn.ModuleList(res)
self.fc = nn.ModuleList(fc)
self.score = nn.ModuleList(score)
self.fc_ = nn.ModuleList(fc_)
self.score_ = nn.ModuleList(score_)

def _make_residual(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=True),
)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))

return nn.Sequential(*layers)

def _make_fc(self, inplanes, outplanes):
bn = nn.BatchNorm2d(inplanes)
conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True)
return nn.Sequential(
conv,
bn,
self.relu,
)

def forward(self, x):
out = []
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)

x = self.layer1(x)
x = self.maxpool(x)
x = self.layer2(x)
x = self.layer3(x)

for i in range(self.num_stacks):
y = self.hg[i](x)
y = self.res[i](y)
y = self.fc[i](y)
score = self.score[i](y)
out.append(self.scale_score(score))
if i < self.num_stacks-1:
fc_ = self.fc_[i](y)
score_ = self.score_[i](score)
x = x + fc_ + score_

return out


def set_model_for_retrain(model_path, retrain_fraction, map_location, reset=True):
"""Set model for transfer learning.
Expand Down

0 comments on commit d727fd9

Please sign in to comment.