Skip to content

Commit

Permalink
- improved edm computation
Browse files Browse the repository at this point in the history
- fixed labelwise derivative computation.
- dydx iterator can return edm derivatives that are used in loss by distnet model
  • Loading branch information
jeanollion committed May 26, 2024
1 parent fd3a8ff commit eb8f6c6
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 51 deletions.
49 changes: 27 additions & 22 deletions distnet_2d/data/dydx_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): #
object_slices[(b, c)] = find_objects(labelIms[b,...,c])
edm = np.zeros(shape=labelIms.shape, dtype=np.float32)
for b,c in itertools.product(range(edm.shape[0]), range(edm.shape[-1])):
edm[b,...,c] = edt_antialiased(labelIms[b,...,c], object_slices[(b, c)])
edm[b,...,c] = edt_smooth(labelIms[b,...,c], object_slices[(b, c)])
#edm[b,...,c] = edt.edt(labelIms[b,...,c], black_border=False)
n_motion = 2 * frame_window if return_next else frame_window
if long_term:
Expand All @@ -244,7 +244,7 @@ def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): #
linkMultiplicityImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
centerIm = np.zeros(labelIms.shape, dtype=self.dtype) if self.return_center else None
if self.return_label_rank:
labelIm = np.zeros(labelIms.shape, dtype=np.int32)
rankIm = np.zeros(labelIms.shape, dtype=np.int32)
prevLabelArr = np.zeros(labelIms.shape[:1]+(n_motion, self.n_label_max), dtype=np.int32)
nextLabelArr = np.zeros(labelIms.shape[:1] + (n_motion, self.n_label_max), dtype=np.int32)
centerArr = np.zeros(labelIms.shape[:1]+labelIms.shape[-1:]+(self.n_label_max,2), dtype=np.float32)
Expand All @@ -268,13 +268,13 @@ def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): #
sel = [c, c+1]
l_c = [labels_and_centers[(i,s)] for s in sel]
o_s = [object_slices[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, gdcmImPrev=centerIm[i,...,c] if self.return_center else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=labelIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=labelIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_mode=self.center_mode)
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, gdcmImPrev=centerIm[i,...,c] if self.return_center else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=rankIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_mode=self.center_mode)
if return_next:
for c in range(frame_window, 2*frame_window):
sel = [c, c+1]
l_c = [labels_and_centers[(i, s)] for s in sel]
o_s = [object_slices[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,..., c + 1] if self.return_center else None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=labelIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_mode=self.center_mode)
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,..., c + 1] if self.return_center else None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_mode=self.center_mode)
if long_term:
off = 2*frame_window if return_next else frame_window
for c in range(0, frame_window-1):
Expand All @@ -288,9 +288,15 @@ def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): #
l_c = [labels_and_centers[(i, s)] for s in sel]
o_s = [object_slices[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c+off], o_s, dyIm[i,...,c+off], dxIm[i,...,c+off], dyImNext=dyImNext[i,...,c+off] if ndisp else None, dxImNext=dxImNext[i,...,c+off] if ndisp else None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_mode=self.center_mode)
other_output_channels = [chan_idx for chan_idx in self.output_channels if chan_idx!=1 and chan_idx!=2]
all_channels = [batch_by_channel[chan_idx] for chan_idx in other_output_channels]
channel_inc = 0

edm[edm == 0] = -1
if self.return_edm_derivatives:
der_y, der_x = np.zeros_like(edm), np.zeros_like(edm)
for b, c in itertools.product(range(edm.shape[0]), range(edm.shape[-1])):
derivatives_labelwise(edm[b, ..., c], -1, der_y[b, ..., c], der_x[b, ..., c], labelIms[b, ..., c], object_slices[(b, c)])
if self.return_central_only:
der_y = der_y[..., 1:2]
der_x = der_x[..., 1:2]
if self.return_central_only: # select only central frame for edm / center and only displacement / link multiplicity related to central frame
edm = edm[..., 1:2]
centerIm = centerIm[..., 1:2]
Expand All @@ -304,18 +310,16 @@ def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): #
if self.return_link_multiplicity:
linkMultiplicityImNext = linkMultiplicityImNext[..., 1:]
if self.return_label_rank:
labelIm = labelIm[..., 1:2]
rankIm = rankIm[..., 1:2]
centerArr = centerArr[: , 1:2]
prevLabelArr = prevLabelArr[:, :1]
if ndisp:
nextLabelArr = nextLabelArr[:, 1:]

edm[edm==0] = -1
if self.return_edm_derivatives:
der_y, der_x = np.zeros_like(edm)
for b, c in itertools.product(range(edm.shape[0]), range(edm.shape[-1])):
derivatives_labelwise(edm[b,...,c], -1, der_y[b,...,c], der_x[b,...,c], labelIm[b,...,c], object_slices[(b, c)])
edm = np.concatenate([edm, der_y, der_x], -1)
other_output_channels = [chan_idx for chan_idx in self.output_channels if chan_idx != 1 and chan_idx != 2]
all_channels = [batch_by_channel[chan_idx] for chan_idx in other_output_channels]
channel_inc = 0
all_channels.insert(channel_inc, edm)
if self.return_center:
channel_inc+=1
Expand Down Expand Up @@ -350,7 +354,7 @@ def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): #
if self.return_label_rank:
if ndisp:
prevLabelArr = np.concatenate([prevLabelArr, nextLabelArr], 1)
all_channels.append(labelIm)
all_channels.append(rankIm)
all_channels.append(prevLabelArr)
all_channels.append(centerArr)
return all_channels
Expand Down Expand Up @@ -580,7 +584,7 @@ def _draw_centers(centerIm, labels_map_centers, labelIm, object_slices, geometri
centerIm[mask] = d[mask]


def edt_antialiased(labelIm, object_slices):
def edt_smooth(labelIm, object_slices):
shape = labelIm.shape
upsampled = np.kron(labelIm, np.ones((2, 2))) # upsample by factor 2
w=np.ones(shape=(3, 3), dtype=np.int8)
Expand All @@ -592,20 +596,21 @@ def edt_antialiased(labelIm, object_slices):
new_mask = convolve(mask.astype(np.int8), weights=w, mode="nearest") > 4 # smooth borders
sub_labelIm[mask] = 0 # replace mask by smoothed
sub_labelIm[new_mask] = i + 1
edm = np.divide(edt.edt(upsampled), 2)
return edm.reshape((shape[0], 2, shape[1], 2)).mean(-1).mean(1) # downsample (bin) by factor 2

edm = edt.edt(upsampled)
edm = edm.reshape((shape[0], 2, shape[1], 2)).mean(-1).mean(1) # downsample (bin) by factor 2
edm = np.divide(edm + 0.5, 2) #convert to pixel unit
edm[edm <= 0.5] = 0
return edm

def derivatives_labelwise(image, bck_value, der_y, der_x, labelIm, object_slices):
    """Compute per-object y/x first derivatives of ``image``, writing them in place.

    For each labeled object, a 1-pixel-padded sub-window is extracted, every
    pixel outside the object is set to ``bck_value`` so that neighboring cells
    do not leak into the finite differences, and the resulting derivatives are
    copied back only onto the object's own pixels.

    Args:
        image: 2D array the derivatives are computed from (e.g. an EDM whose
            background was set to ``bck_value``).
        bck_value: value assigned to non-object pixels of the sub-window
            before differentiating (callers pass -1 for background).
        der_y: pre-allocated 2D output array for the y-derivative; modified in place.
        der_x: pre-allocated 2D output array for the x-derivative; modified in place.
        labelIm: 2D integer label image; object ``i`` has value ``i + 1``.
        object_slices: per-label bounding slices as returned by
            ``scipy.ndimage.find_objects`` (``None`` entries are skipped).

    Returns:
        The ``(der_y, der_x)`` arrays, for caller convenience.
    """
    shape = labelIm.shape
    for (i, sl) in enumerate(object_slices):
        if sl is not None:
            # Grow the bounding box by one pixel on each side so the finite
            # differences at the object border see a defined neighborhood.
            # NOTE(review): the upper clamp uses ax - 1, which excludes the
            # image's last row/column — confirm this is intentional.
            sl = tuple([slice(max(s.start - 1, 0), min(s.stop + 1, ax - 1), s.step) for s, ax in zip(sl, shape)])
            mask = labelIm[sl] == i + 1
            sub_im = np.copy(image[sl])
            sub_im[np.logical_not(mask)] = bck_value  # erase neighboring cells
            sub_der_y, sub_der_x = der.der_2d(sub_im, 0, 1)
            # Basic slicing yields views, so these masked writes reach der_y / der_x.
            der_y[sl][mask] = sub_der_y[mask]
            der_x[sl][mask] = sub_der_x[mask]
    return der_y, der_x
22 changes: 12 additions & 10 deletions distnet_2d/model/distnet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def __init__(self, *args, spatial_dims,
center_loss_weight:float=1,
displacement_loss_weight:float=1,
category_loss_weight:float=1,
edm_loss=PseudoHuber(1), edm_derivatives:bool=False,
gcdm_loss=PseudoHuber(1), gcdm_derivatives:bool=False,
edm_loss=PseudoHuber(1), edm_derivative_loss:bool=False,
gcdm_loss=PseudoHuber(1), gcdm_derivative_loss:bool=False,
displacement_loss=PseudoHuber(1),
category_weights=None, # array of weights: [normal, division, no previous cell] or None = auto
category_class_frequency_range=[1/50, 50],
Expand All @@ -38,9 +38,9 @@ def __init__(self, *args, spatial_dims,
self.predict_next_displacement=predict_next_displacement
self.frame_window = frame_window
self.edm_loss = edm_loss
self.edm_derivatives = edm_derivatives
self.edm_derivative_loss = edm_derivative_loss
self.gcdm_loss = gcdm_loss
self.gcdm_derivatives = gcdm_derivatives
self.gcdm_derivative_loss = gcdm_derivative_loss
self.displacement_loss = displacement_loss
self.predict_gcdm_derivatives = predict_gcdm_derivatives
self.predict_edm_derivatives = predict_edm_derivatives
Expand Down Expand Up @@ -94,23 +94,25 @@ def train_step(self, data):
gcdm, gcdm_dy, gcdm_dx = tf.split(y_pred[1], num_or_size_splits=3, axis=-1)
else:
gcdm, gcdm_dy, gcdm_dx = y_pred[1], None, None

if self.predict_edm_derivatives or self.edm_derivative_loss:
true_edm, true_edm_dy, true_edm_dx = tf.split(y[0], num_or_size_splits=3, axis=-1)
else:
true_edm, true_edm_dy, true_edm_dx = y[0], None, None
# compute loss
losses = dict()
loss_weights = dict()

cell_mask = tf.math.greater(y[0], 0.5)
cell_mask_interior = tf.math.greater(y[0], 1.5) if self.gcdm_derivatives or self.predict_gcdm_derivatives else None
cell_mask = tf.math.greater(true_edm, 0.5)
cell_mask_interior = tf.math.greater(true_edm, 1.5) if self.gcdm_derivative_loss or self.predict_gcdm_derivatives else None
# edm
if edm_weight>0:
#edm_loss = self.edm_loss(y[0], edm)
edm_loss = compute_loss_derivatives(y[0], edm, self.edm_loss, pred_dy=edm_dy, pred_dx=edm_dx, derivative_loss=self.edm_derivatives, laplacian_loss=self.edm_derivatives)
edm_loss = compute_loss_derivatives(true_edm, edm, self.edm_loss, true_dy=true_edm_dy, true_dx=true_edm_dx, pred_dy=edm_dy, pred_dx=edm_dx, derivative_loss=self.edm_derivative_loss, laplacian_loss=self.edm_derivative_loss)
losses["edm"] = edm_loss
loss_weights["edm"] = edm_weight

# center
if center_weight>0:
center_loss = compute_loss_derivatives(y[1], gcdm, self.gcdm_loss, pred_dy=gcdm_dy, pred_dx=gcdm_dx, mask=cell_mask, mask_interior=cell_mask_interior, derivative_loss=self.gcdm_derivatives)
center_loss = compute_loss_derivatives(y[1], gcdm, self.gcdm_loss, pred_dy=gcdm_dy, pred_dx=gcdm_dx, mask=cell_mask, mask_interior=cell_mask_interior, derivative_loss=self.gcdm_derivative_loss)
losses["center"] = center_loss
loss_weights["center"] = center_weight

Expand Down
32 changes: 13 additions & 19 deletions distnet_2d/utils/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,43 +28,37 @@ def call(self, y_true, y_pred):
return tf.multiply(self.delta_sq, tf.sqrt(1. + tf.square((y_true - y_pred)/self.delta)) - 1.)


def compute_loss_derivatives(true, pred, loss_fun, true_dy=None, true_dx=None, pred_dy=None, pred_dx=None, pred_lap=None, mask=None, mask_interior=None, derivative_loss: bool = False, laplacian_loss: bool = False):
    """Base regression loss plus optional spatial-derivative / laplacian terms.

    The base term is ``loss_fun(true, pred)`` (prediction zeroed outside
    ``mask`` when a mask is given). Additional terms compare ground-truth and
    predicted derivatives: either derivatives predicted directly by the model
    (``pred_dy`` / ``pred_dx`` / ``pred_lap``) or derivatives computed from
    ``pred`` itself (``derivative_loss`` / ``laplacian_loss``).

    Args:
        true: ground-truth tensor.
        pred: predicted tensor.
        loss_fun: elementwise loss callable, e.g. ``PseudoHuber``.
        true_dy, true_dx: optional precomputed ground-truth derivatives (as
            produced by the data iterator); computed from ``true`` when None.
        pred_dy, pred_dx: optional derivative maps predicted by the model;
            each adds a loss term against the ground-truth derivative.
        pred_lap: optional laplacian map predicted by the model; adds a loss
            term against the ground-truth laplacian.
        mask: optional boolean mask restricting the base term.
        mask_interior: optional boolean mask restricting derivative terms;
            defaults to ``mask`` when None.
        derivative_loss: if True, penalize derivatives computed from ``pred``.
        laplacian_loss: if True, penalize the laplacian computed from ``pred``.

    Returns:
        Scalar/tensor loss: the base term plus all requested penalty terms.
    """
    loss = loss_fun(true, tf.where(mask, pred, 0) if mask is not None else pred)
    if derivative_loss or laplacian_loss or pred_dy is not None or pred_dx is not None or pred_lap is not None:
        if mask_interior is None:
            mask_interior = mask
        # Ground-truth derivatives: reuse precomputed ones when supplied,
        # otherwise derive them from the ground truth.
        if true_dy is None:
            true_dy = der.der_2d(true, 1)
        if true_dx is None:
            true_dx = der.der_2d(true, 2)
        if derivative_loss or laplacian_loss:
            dy_pred, dx_pred = der.der_2d(pred, 1), der.der_2d(pred, 2)
        if laplacian_loss or pred_lap is not None:
            true_lap = der.laplacian_2d(None, true_dy, true_dx)
            if pred_lap is not None:
                pred_lap = tf.where(mask_interior, pred_lap, 0) if mask_interior is not None else pred_lap
                loss = loss + loss_fun(true_lap, pred_lap)
            if laplacian_loss:
                lap_pred = der.laplacian_2d(None, dy_pred, dx_pred)
                lap_pred = tf.where(mask_interior, lap_pred, 0) if mask_interior is not None else lap_pred
                loss = loss + loss_fun(true_lap, lap_pred)
        if pred_dy is not None:
            pred_dy = tf.where(mask_interior, pred_dy, 0) if mask_interior is not None else pred_dy
            loss = loss + loss_fun(true_dy, pred_dy)
        if pred_dx is not None:
            pred_dx = tf.where(mask_interior, pred_dx, 0) if mask_interior is not None else pred_dx
            loss = loss + loss_fun(true_dx, pred_dx)
        if derivative_loss:
            dy_pred = tf.where(mask_interior, dy_pred, 0) if mask_interior is not None else dy_pred
            dx_pred = tf.where(mask_interior, dx_pred, 0) if mask_interior is not None else dx_pred
            loss = loss + loss_fun(true_dy, dy_pred) + loss_fun(true_dx, dx_pred)
    return loss


Expand Down

0 comments on commit eb8f6c6

Please sign in to comment.