Misc updates to transforms, loss and UNETR #181

Merged 7 commits on Dec 10, 2023
Changes from 3 commits
13 changes: 13 additions & 0 deletions scripts/load_mae_vit.py
@@ -0,0 +1,13 @@
from collections import OrderedDict

import torch
from torch_em.model import UNETR

checkpoint = "imagenet.pth"
encoder_state = torch.load(checkpoint, map_location="cpu")["model"]
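# Keep only the ViT encoder parameters: drop the mask token and the decoder state.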
encoder_state = OrderedDict({
    k: v for k, v in encoder_state.items()
    if (k != "mask_token" and not k.startswith("decoder"))
})

unetr_model = UNETR(backbone="mae", encoder="vit_l", encoder_checkpoint=encoder_state)
26 changes: 26 additions & 0 deletions torch_em/loss/dice.py
@@ -84,6 +84,32 @@ def forward(self, input_, target):
)


class BCEDiceLoss(nn.Module):

    def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}

    def forward(self, input_, target):
        # Note: the input is expected to contain probabilities, since the BCE term
        # uses binary_cross_entropy; use BCEDiceLossWithLogits below for logits.
        loss_dice = dice_score(
            input_,
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps
        )
        loss_bce = nn.functional.binary_cross_entropy(
            input_, target
        )
        return self.alpha * loss_dice + self.beta * loss_bce


# TODO think about how to handle combined losses like this for mixed precision training
class BCEDiceLossWithLogits(nn.Module):

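For illustration, a minimal usage sketch of the new loss (assuming BCEDiceLoss is importable from torch_em.loss.dice; the shapes are arbitrary):

import torch
from torch_em.loss.dice import BCEDiceLoss

loss_fn = BCEDiceLoss(alpha=1.0, beta=1.0)

# Predictions must be probabilities in [0, 1], since the loss uses
# binary_cross_entropy rather than the logits variant.
pred = torch.sigmoid(torch.randn(4, 1, 64, 64, requires_grad=True))
target = (torch.rand(4, 1, 64, 64) > 0.5).float()

loss = loss_fn(pred, target)  # alpha * dice loss + beta * bce loss
loss.backward()

Storing self.init_kwargs keeps the class consistent with the other torch_em losses, which record their init call for easy re-creation.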
4 changes: 3 additions & 1 deletion torch_em/loss/distance_based.py
@@ -16,7 +16,7 @@ class DistanceLoss(nn.Module):
"""
def __init__(
self,
mask_distances_in_bg: bool,
mask_distances_in_bg: bool = True,
foreground_loss: nn.Module = DiceLoss(),
distance_loss: nn.Module = nn.MSELoss(reduction="mean")
) -> None:
@@ -26,6 +26,8 @@ def __init__(
        self.distance_loss = distance_loss
        self.mask_distances_in_bg = mask_distances_in_bg

        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}

    def forward(self, input_, target):
        assert input_.shape == target.shape, input_.shape
        assert input_.shape[1] == 3, input_.shape
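With the new default, the loss can be constructed without arguments. A minimal sketch (the random tensors are placeholders for real predictions and distance targets; the three-channel shape follows the assertions in forward):

import torch
from torch_em.loss.distance_based import DistanceLoss

loss_fn = DistanceLoss()  # mask_distances_in_bg now defaults to True

# forward asserts matching shapes and exactly three channels.
pred = torch.rand(2, 3, 64, 64, requires_grad=True)
target = torch.rand(2, 3, 64, 64)
loss = loss_fn(pred, target)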
61 changes: 33 additions & 28 deletions torch_em/model/unetr.py
@@ -1,9 +1,10 @@
+from collections import OrderedDict
+from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

-from typing import Tuple

from .unet import Decoder, ConvBlock2d, Upsampler2d
from .vit import get_vision_transformer

@@ -22,37 +23,41 @@ class UNETR(nn.Module):

    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):

-        if backbone == "sam":
-            # If we have a SAM encoder, then we first try to load the full SAM Model
-            # (using micro_sam) and otherwise fall back on directly loading the encoder state
-            # from the checkpoint
-            try:
-                _, model = get_sam_model(
-                    model_type=encoder,
-                    checkpoint_path=checkpoint,
-                    return_sam=True
-                )
-                encoder_state = model.image_encoder.state_dict()
-            except Exception:
-                # If we have a MAE encoder, then we directly load the encoder state
-                # from the checkpoint.
-                encoder_state = torch.load(checkpoint)
-
-        elif backbone == "mae":
-            encoder_state = torch.load(checkpoint)
+        if isinstance(checkpoint, str):
+            if backbone == "sam":
+                # If we have a SAM encoder, then we first try to load the full SAM Model
+                # (using micro_sam) and otherwise fall back on directly loading the encoder state
+                # from the checkpoint.
+                try:
+                    _, model = get_sam_model(
+                        model_type=encoder,
+                        checkpoint_path=checkpoint,
+                        return_sam=True
+                    )
+                    encoder_state = model.image_encoder.state_dict()
+                except Exception:
+                    # If we have a MAE encoder, then we directly load the encoder state
+                    # from the checkpoint.
+                    encoder_state = torch.load(checkpoint)
+
+            elif backbone == "mae":
+                encoder_state = torch.load(checkpoint)
+        else:
+            encoder_state = checkpoint

        self.encoder.load_state_dict(encoder_state)

    def __init__(
        self,
-        backbone="sam",
-        encoder="vit_b",
-        decoder=None,
-        out_channels=1,
-        use_sam_stats=False,
-        use_mae_stats=False,
-        encoder_checkpoint_path=None,
-        final_activation=None,
+        backbone: str = "sam",
+        encoder: str = "vit_b",
+        decoder: Optional[nn.Module] = None,
+        out_channels: int = 1,
+        use_sam_stats: bool = False,
+        use_mae_stats: bool = False,
+        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
+        final_activation: Optional[Union[str, nn.Module]] = None,
    ) -> None:
        super().__init__()

@@ -61,8 +66,8 @@ def __init__(

print(f"Using {encoder} from {backbone.upper()}")
self.encoder = get_vision_transformer(backbone=backbone, model=encoder)
if encoder_checkpoint_path is not None:
self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint_path)
if encoder_checkpoint is not None:
self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

        # parameters for the decoder network
        depth = 3
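The renamed encoder_checkpoint argument now accepts either a checkpoint path or an already loaded state dict, matching the isinstance(checkpoint, str) branch above. A sketch of both call styles (the file names are placeholders):

import torch
from torch_em.model import UNETR

# 1) Pass a checkpoint path; for the "sam" backbone the full SAM model is
#    loaded via micro_sam, with a fallback to torch.load.
model = UNETR(backbone="sam", encoder="vit_b", encoder_checkpoint="sam_vit_b.pth")

# 2) Pass a pre-loaded state dict directly, as done in scripts/load_mae_vit.py
#    (which also filters out the mask token and decoder parameters first).
encoder_state = torch.load("imagenet.pth", map_location="cpu")["model"]
model = UNETR(backbone="mae", encoder="vit_l", encoder_checkpoint=encoder_state)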
31 changes: 26 additions & 5 deletions torch_em/transform/label.py
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np
import skimage.measure
import skimage.segmentation
@@ -192,16 +194,25 @@ def __call__(self, labels):


class DistanceTransform:
-    """Compute distances to foreground.
+    """Compute distances to foreground in the labels.

    Args:
        distances: Whether to compute the absolute distances.
        directed_distances: Whether to compute the directed distances (vector distances).
        normalize: Whether to normalize the computed distances.
        max_distance: Maximal distance at which to threshold the distances.
        foreground_id: Label id to which the distance is computed.
        invert: Whether to invert the distances.
        func: Normalization function for the distances.
    """
    eps = 1e-7

    def __init__(
        self,
-        distances=True,
-        directed_distances=False,
-        normalize=True,
-        max_distance=None,
+        distances: bool = True,
+        directed_distances: bool = False,
+        normalize: bool = True,
+        max_distance: Optional[float] = None,
        foreground_id=1,
        invert=False,
        func=None
@@ -272,6 +283,16 @@ def __call__(self, labels):

class PerObjectDistanceTransform:
    """Compute normalized distances per object in a segmentation.

    Args:
        distances: Whether to compute the undirected distances.
        boundary_distances: Whether to compute the distances to the object boundaries.
        directed_distances: Whether to compute the directed distances (vector distances).
        foreground: Whether to return a foreground channel.
        apply_label: Whether to apply connected components to the labels before computing distances.
        correct_centers: Whether to correct centers that are not in the objects.
        min_size: Minimal size of objects for distance calculation.
        distance_fill_value: Fill value for the distances outside of objects.
    """
    eps = 1e-7

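To illustrate the documented options, a small sketch applying both transforms to a toy segmentation (the exact layout of the returned channels is an assumption based on the argument descriptions):

import numpy as np
from torch_em.transform.label import DistanceTransform, PerObjectDistanceTransform

# Toy segmentation with two objects.
labels = np.zeros((64, 64), dtype="uint32")
labels[8:24, 8:24] = 1
labels[36:52, 30:50] = 2

# Normalized distances to the foreground.
dist_trafo = DistanceTransform(distances=True, normalize=True)
distance_target = dist_trafo(labels)

# Per-object distances plus a foreground channel.
per_object_trafo = PerObjectDistanceTransform(
    distances=True, boundary_distances=True, foreground=True
)
per_object_target = per_object_trafo(labels)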