
Commit fd3a8ff

- dydx_iterator.py: option to return edm derivatives
- fixed numpy derivatives
jeanollion committed May 26, 2024
1 parent 57b1274 commit fd3a8ff
Showing 2 changed files with 42 additions and 20 deletions.
30 changes: 26 additions & 4 deletions distnet_2d/data/dydx_iterator.py
@@ -33,6 +33,7 @@ def __init__(self,
                 array_keywords:list=['/linksPrev'],
                 elasticdeform_parameters:dict = None,
                 downscale_displacement_and_link_multiplicity=1,
+                return_edm_derivatives: bool = False,
                 return_center:bool = True,
                 center_mode:str = "MEDOID", # GEOMETRICAL, "EDM_MAX", "EDM_MEAN", "SKELETON", "MEDOID"
                 return_label_rank = False,
@@ -49,6 +50,7 @@ def __init__(self,
         self.aug_frame_subsampling=aug_frame_subsampling
         self.allow_frame_subsampling_direct_neigh=allow_frame_subsampling_direct_neigh
         self.output_float16=output_float16
+        self.return_edm_derivatives=return_edm_derivatives
         self.return_center=return_center
         self.center_mode=center_mode.upper()
         self.return_label_rank=return_label_rank
@@ -309,6 +311,11 @@ def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): #
             nextLabelArr = nextLabelArr[:, 1:]

         edm[edm==0] = -1
+        if self.return_edm_derivatives:
+            der_y, der_x = np.zeros_like(edm), np.zeros_like(edm)
+            for b, c in itertools.product(range(edm.shape[0]), range(edm.shape[-1])):
+                derivatives_labelwise(edm[b,...,c], -1, der_y[b,...,c], der_x[b,...,c], labelIm[b,...,c], object_slices[(b, c)])
+            edm = np.concatenate([edm, der_y, der_x], -1)
         all_channels.insert(channel_inc, edm)
         if self.return_center:
             channel_inc+=1
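
For context, a minimal sketch (shapes hypothetical) of what this hunk changes in the output layout: with return_edm_derivatives enabled, the y- and x-derivative planes are appended to the EDM tensor along the last (channel) axis, so one EDM channel becomes three.

import numpy as np

# Hypothetical batch: 2 frames of 32x32 with a single EDM channel.
edm = np.random.rand(2, 32, 32, 1).astype(np.float32)
der_y, der_x = np.zeros_like(edm), np.zeros_like(edm)

# As in the hunk above: derivatives become extra channels, [edm, d/dy, d/dx].
edm = np.concatenate([edm, der_y, der_x], -1)
assert edm.shape == (2, 32, 32, 3)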
@@ -575,15 +582,30 @@ def _draw_centers(centerIm, labels_map_centers, labelIm, object_slices, geometri

 def edt_antialiased(labelIm, object_slices):
     shape = labelIm.shape
-    upsampled = np.kron(labelIm, np.ones((2, 2)))
+    upsampled = np.kron(labelIm, np.ones((2, 2))) # upsample by factor 2
     w=np.ones(shape=(3, 3), dtype=np.int8)
     for (i, sl) in enumerate(object_slices):
         if sl is not None:
             sl = tuple([slice(max(s.start*2 - 1, 0), min(s.stop*2 + 1, ax*2 - 1), s.step) for s, ax in zip(sl, shape)])
             sub_labelIm = upsampled[sl]
             mask = sub_labelIm == i + 1
-            new_mask = convolve(mask.astype(np.int8), weights=w, mode="nearest") > 4
-            sub_labelIm[mask] = 0
+            new_mask = convolve(mask.astype(np.int8), weights=w, mode="nearest") > 4 # smooth borders
+            sub_labelIm[mask] = 0 # replace mask by smoothed
             sub_labelIm[new_mask] = i + 1
     edm = np.divide(edt.edt(upsampled), 2)
-    return edm.reshape((shape[0], 2, shape[1], 2)).mean(-1).mean(1)
+    return edm.reshape((shape[0], 2, shape[1], 2)).mean(-1).mean(1) # downsample (bin) by factor 2
+
+
+def derivatives_labelwise(image, bck_value, der_y, der_x, labelIm, object_slices):
+    shape = labelIm.shape
+    for (i, sl) in enumerate(object_slices):
+        if sl is not None:
+            sl = tuple([slice(max(s.start - 1, 0), min(s.stop, ax - 1), s.step) for s, ax in zip(sl, shape)])
+            sub_labelIm = labelIm[sl]
+            mask = sub_labelIm == i + 1
+            sub_im = np.copy(image[sl])
+            sub_im[np.logical_not(mask)] = bck_value # erase neighboring cells
+            sub_der_y, sub_der_x = der.der_2d(sub_im, 0, 1)
+            der_y[sl][mask] = sub_der_y[mask]
+            der_x[sl][mask] = sub_der_x[mask]
+    return der_y, der_x
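
Two quick checks of the helpers above. First, the bin-down on the last line of edt_antialiased exactly inverts the np.kron upsampling (illustrative values only):

import numpy as np

a = np.arange(12, dtype=float).reshape(3, 4)
up = np.kron(a, np.ones((2, 2)))                  # each pixel -> one 2x2 block
down = up.reshape((3, 2, 4, 2)).mean(-1).mean(1)  # average each block back
assert np.allclose(down, a)

Second, a hedged usage sketch for derivatives_labelwise, assuming object_slices comes from scipy.ndimage.find_objects and that der aliases this package's image_derivatives_np module (neither import is shown in this diff):

import numpy as np
from scipy.ndimage import find_objects

labelIm = np.zeros((8, 8), dtype=np.int32)
labelIm[2:6, 2:6] = 1                        # one square object, label 1
edm = (labelIm > 0).astype(np.float32)       # stand-in for a real EDM
edm[labelIm == 0] = -1                       # background marker, as in the iterator

der_y, der_x = np.zeros_like(edm), np.zeros_like(edm)
derivatives_labelwise(edm, -1, der_y, der_x, labelIm, find_objects(labelIm))
# der_y / der_x now hold per-object central differences; background stays 0.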
32 changes: 16 additions & 16 deletions distnet_2d/utils/image_derivatives_np.py
@@ -1,33 +1,37 @@
 import numpy as np

-def der_2d(image, axis:int):
+def der_2d(image, *axis:int):
     """
     Compute the partial derivative (central difference approximation) of source in a particular dimension: d_f( x ) = ( f( x + 1 ) - f( x - 1 ) ) / 2.
-    Output tensors has the same shape as the input: [B, Y, X, C].
+    Output tensor has the same shape as the input: [Y, X].
     Args:
-        image: Tensor with shape [B, Y, X, C].
-        axis: axis to compute gradient on (1 = dy or 2 = dx)
+        image: Tensor with shape [Y, X].
+        axis: axis (or axes) to compute the gradient on (0 = dy, 1 = dx)
     Returns:
         tensor dy or dx holding the vertical or horizontal partial derivative
         gradients (1-step finite difference).
     Raises:
-        ValueError: If `image` is not a 4D tensor.
+        ValueError: If `image` is not a 2D tensor.
     """
+    if len(axis) > 1:
+        return [der_2d(image, ax) for ax in axis]
+    else:
+        axis = axis[0]
     assert image.ndim == 2, f'image_gradients expects a 2D tensor [Y, X], not {image.shape}'
     assert axis in [0, 1], "axis must be in [0, 1]"
     Y, X = image.shape
-    if axis == 1:
+    if axis == 0:
         dy = np.divide(image[2:] - image[:-2], 2)
         zeros = np.zeros(np.stack([1, X]), image.dtype)
-        dy = np.concatenate([zeros, dy, zeros], 0)
+        dy = np.concatenate([zeros, dy, zeros], axis)
         return np.reshape(dy, image.shape)
     else:
-        dx = np.divide(image[:, 2:] - image[:, :-2], 1)
+        dx = np.divide(image[:, 2:] - image[:, :-2], 2)
         zeros = np.zeros(np.stack([Y, 1]), image.dtype)
-        dx = np.concatenate([zeros, dx, zeros], 1)
+        dx = np.concatenate([zeros, dx, zeros], axis)
         return np.reshape(dx, image.shape)
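
A sanity check of the fixed central difference (a sketch: np.gradient uses the same interior formula, while der_2d zero-pads the border instead of using one-sided differences):

import numpy as np

img = np.random.rand(5, 7)
dy, dx = der_2d(img, 0, 1)     # variadic form: one array per requested axis

gy, gx = np.gradient(img)      # same (f[i+1] - f[i-1]) / 2 in the interior
assert np.allclose(dy[1:-1, :], gy[1:-1, :])
assert np.allclose(dx[:, 1:-1], gx[:, 1:-1])
assert np.all(dy[0] == 0) and np.all(dy[-1] == 0)  # zero-padded rim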


@@ -36,9 +40,7 @@ def gradient_magnitude_2d(image=None, dy=None, dx=None, sqrt:bool=True):
         assert dy is not None and dx is not None, "provide either image or partial derivatives"
         assert dy.shape == dx.shape, "partial derivatives must have same shape"
     else:
-        dy = der_2d(image, 0)
-        dx = der_2d(image, 1)
-
+        dy, dx = der_2d(image, 0, 1)
     grad = dx * dx + dy * dy
     if sqrt:
         grad = np.sqrt(grad)
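
A quick check of the refactored branch on a linear ramp, whose analytic gradient magnitude is constant (sketch; the zero-padded border rows and columns are excluded):

import numpy as np

y, x = np.mgrid[0:6, 0:6].astype(float)
g = gradient_magnitude_2d(image=3 * y + 4 * x)  # analytic |grad| = 5
assert np.allclose(g[1:-1, 1:-1], 5.0)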
@@ -50,9 +52,7 @@ def laplacian_2d(image=None, dy=None, dx=None):
         assert dy is not None and dx is not None, "provide either image or partial derivatives"
         assert dy.shape == dx.shape, "partial derivatives must have same shape"
     else:
-        dy = der_2d(image, 0)
-        dx = der_2d(image, 1)
-
+        dy, dx = der_2d(image, 0, 1)
     ddy = der_2d(dy, 0)
     ddx = der_2d(dx, 1)
-    return ddy + ddx
\ No newline at end of file
+    return ddy + ddx
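
Likewise for laplacian_2d: composing two central differences yields a 5-pixel-wide stencil, exact for quadratics at least two pixels from the border (sketch):

import numpy as np

y, x = np.mgrid[0:9, 0:9].astype(float)
lap = laplacian_2d(image=y**2 + x**2)     # analytic Laplacian = 4 everywhere
assert np.allclose(lap[2:-2, 2:-2], 4.0)  # exact away from the padded rim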
