Skip to content

Commit

Permalink
- option to predict edm and gcdm derivatives
Browse files Browse the repository at this point in the history
- fixed gcdm computation at interface between cells
  • Loading branch information
jeanollion committed May 25, 2024
1 parent 7793431 commit 8538aa1
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 50 deletions.
68 changes: 49 additions & 19 deletions distnet_2d/data/dydx_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import edt
from random import random
from .medoid import get_medoid
import time

class DyDxIterator(TrackingIterator):
def __init__(self,
Expand Down Expand Up @@ -217,9 +218,13 @@ def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): #
if return_next:
for j in range(0, frame_window):
self._erase_small_objects_at_edges(labelIms[i,...,frame_window+1+j], i, mask_to_erase_next, [m+j for m in mask_to_erase_chan_next], batch_by_channel)
object_slices = {}
for b, c in itertools.product(range(labelIms.shape[0]), range(labelIms.shape[-1])):
object_slices[(b, c)] = find_objects(labelIms[b,...,c])
edm = np.zeros(shape=labelIms.shape, dtype=np.float32)
for b,c in itertools.product(range(edm.shape[0]), range(edm.shape[-1])):
edm[b,...,c] = edt.edt(labelIms[b,...,c], black_border=False)
#_compute_edm(edm[b,...,c], labelIms[b,...,c], object_slices[(b, c)])
n_motion = 2 * frame_window if return_next else frame_window
if long_term:
n_motion = n_motion + (2 * ( frame_window - 1 ) if return_next else frame_window -1)
Expand Down Expand Up @@ -255,23 +260,27 @@ def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): #
for c in range(0, frame_window):
sel = [c, c+1]
l_c = [labels_and_centers[(i,s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, gdcmImPrev=centerIm[i,...,c] if self.return_center else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=labelIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=labelIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_mode=self.center_mode)
o_s = [object_slices[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, gdcmImPrev=centerIm[i,...,c] if self.return_center else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=labelIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=labelIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_mode=self.center_mode)
if return_next:
for c in range(frame_window, 2*frame_window):
sel = [c, c+1]
l_c = [labels_and_centers[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,..., c + 1] if self.return_center else None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=labelIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_mode=self.center_mode)
o_s = [object_slices[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, gdcmIm=centerIm[i,..., c + 1] if self.return_center else None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=labelIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_mode=self.center_mode)
if long_term:
off = 2*frame_window if return_next else frame_window
for c in range(0, frame_window-1):
sel = [c, frame_window]
l_c = [labels_and_centers[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c+off], dyIm[i,...,c+off], dxIm[i,...,c+off], dyImNext=dyImNext[i,...,c+off] if ndisp else None, dxImNext=dxImNext[i,...,c+off] if ndisp else None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_mode=self.center_mode)
o_s = [object_slices[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c+off], o_s, dyIm[i,...,c+off], dxIm[i,...,c+off], dyImNext=dyImNext[i,...,c+off] if ndisp else None, dxImNext=dxImNext[i,...,c+off] if ndisp else None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_mode=self.center_mode)
if return_next:
for c in range(frame_window-1, 2*(frame_window-1)):
sel = [frame_window, c+3]
l_c = [labels_and_centers[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c+off], dyIm[i,...,c+off], dxIm[i,...,c+off], dyImNext=dyImNext[i,...,c+off] if ndisp else None, dxImNext=dxImNext[i,...,c+off] if ndisp else None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_mode=self.center_mode)
o_s = [object_slices[(i, s)] for s in sel]
_compute_displacement(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c+off], o_s, dyIm[i,...,c+off], dxIm[i,...,c+off], dyImNext=dyImNext[i,...,c+off] if ndisp else None, dxImNext=dxImNext[i,...,c+off] if ndisp else None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_mode=self.center_mode)
other_output_channels = [chan_idx for chan_idx in self.output_channels if chan_idx!=1 and chan_idx!=2]
all_channels = [batch_by_channel[chan_idx] for chan_idx in other_output_channels]
channel_inc = 0
Expand Down Expand Up @@ -371,7 +380,7 @@ def _get_labels_and_centers(labelIm, edm, center_mode = "GEOMETRICAL"):
assert edm is not None and edm.shape == labelIm.shape
centers = []
for label in labels:
edm_label = ma.array(edm, mask = labelIm != label)
edm_label = ma.array(edm, mask=labelIm != label)
center = ma.argmax(edm_label, fill_value=0)
center = np.unravel_index(center, edm_label.shape)
centers.append(center)
Expand All @@ -381,7 +390,7 @@ def _get_labels_and_centers(labelIm, edm, center_mode = "GEOMETRICAL"):
elif center_mode == "SKELETON":
assert edm is not None and edm.shape == labelIm.shape
mass_centers = np.array(center_of_mass(labelIm, labelIm, labels))[np.newaxis] # 1, N_ob, 2
lm_coords = peak_local_max(edm, labels = labelIm) # N_lm, 2
lm_coords = peak_local_max(edm, labels=labelIm) # N_lm, 2
lm_coords_l = labelIm[lm_coords[:,0], lm_coords[:,1]] # N_lm
# labels in labelIm are not necessarily continuous -> replace by rank
label_rank = np.zeros(shape=(max(labels)+1,), dtype=np.int32)
Expand Down Expand Up @@ -480,7 +489,7 @@ def _get_category(n_neigh):
else:
return 2

def _compute_displacement(labels_map_centers, labelIm, labels_map_prev, dyIm, dxIm, dyImNext=None, dxImNext=None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=None, linkMultiplicityImNext=None, rankIm=None, rankImPrev=None, prevLabelArr=None, nextLabelArr=None, centerArr=None, centerArrPrev=None, center_mode:str= "MEDOID"):
def _compute_displacement(labels_map_centers, labelIm, labels_map_prev, object_slices, dyIm, dxIm, dyImNext=None, dxImNext=None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=None, linkMultiplicityImNext=None, rankIm=None, rankImPrev=None, prevLabelArr=None, nextLabelArr=None, centerArr=None, centerArrPrev=None, center_mode:str= "MEDOID"):
assert labelIm.shape[-1] == 2, f"invalid labelIm : {labelIm.shape[-1]} channels instead of 2"
assert (dxImNext is None) == (dyImNext is None)
if len(labels_map_centers[-1])==0: # no cells
Expand Down Expand Up @@ -526,10 +535,10 @@ def _compute_displacement(labels_map_centers, labelIm, labels_map_prev, dyIm, dx
rankImPrev[mask] = rank + 1
if gdcmIm is not None:
assert gdcmIm.shape == dyIm.shape, "invalid shape for center image"
_draw_centers(gdcmIm, labels_map_centers[-1], labelIm[...,1], geometrical_distance=center_mode == "GEOMETRICAL")
_draw_centers(gdcmIm, labels_map_centers[-1], labelIm[...,1], object_slices[1], geometrical_distance=center_mode == "GEOMETRICAL")
if gdcmImPrev is not None:
assert gdcmImPrev.shape == dyIm.shape, "invalid shape for center image prev"
_draw_centers(gdcmImPrev, labels_map_centers[0], labelIm[...,0], geometrical_distance=center_mode == "GEOMETRICAL")
_draw_centers(gdcmImPrev, labels_map_centers[0], labelIm[...,0], object_slices[0], geometrical_distance=center_mode == "GEOMETRICAL")
if centerArr is not None:
for rank, (label, center) in enumerate(labels_map_centers[-1].items()):
centerArr[rank,0] = center[0]
Expand All @@ -539,20 +548,27 @@ def _compute_displacement(labels_map_centers, labelIm, labels_map_prev, dyIm, dx
centerArrPrev[rank,0] = center[0]
centerArrPrev[rank,1] = center[1]

def _draw_centers(centerIm, labels_map_centers, labelIm, geometrical_distance:bool=False):
def _draw_centers(centerIm, labels_map_centers, labelIm, object_slices, geometrical_distance:bool=False):
if len(labels_map_centers)==0:
return
# geodesic distance to center
if not geometrical_distance:
count = 0
m = np.ones_like(labelIm)
for center in labels_map_centers.values():
if not (isnan(center[0]) or isnan(center[1])):
m[int(round(center[0])), int(round(center[1]))] = 0
count+=1
if count>0:
m = ma.masked_array(m, ~labelIm.astype(bool))
centerIm[:] = skfmm.distance(m)
shape = centerIm.shape
labelIm_dil = maximum_filter(labelIm, size=5)
non_zero = labelIm>0
labelIm_dil[non_zero] = labelIm[non_zero]
for (i, sl) in enumerate(object_slices):
if sl is not None:
center = labels_map_centers.get(i+1)
if not (isnan(center[0]) or isnan(center[1])):
sl = tuple( [slice(max(0, s.start - 2), min(s.stop + 2, ax - 1), s.step) for s, ax in zip(sl, shape)])
mask = labelIm_dil == i + 1
sub_m = mask[sl]
m = np.ones_like(sub_m)
#print(f"label: {i+1} slice: {sl}, center: {center}, sub_m {sub_m.shape}, coord: {(int(round(center[0]))-sl[0].start, int(round(center[1]))-sl[1].start)}", flush=True)
m[int(round(center[0]))-sl[0].start, int(round(center[1]))-sl[1].start] = 0
m = ma.masked_array(m, ~sub_m)
centerIm[sl][sub_m] = skfmm.distance(m)[sub_m]
else:
Y, X = centerIm.shape
Y, X = np.meshgrid(np.arange(Y, dtype = np.float32), np.arange(X, dtype = np.float32), indexing = 'ij')
Expand All @@ -565,3 +581,17 @@ def _draw_centers(centerIm, labels_map_centers, labelIm, geometrical_distance:bo
if mask.sum()>0:
d = np.sqrt(np.square(Y-center[0])+np.square(X-center[1]))
centerIm[mask] = d[mask]

def _compute_edm(edmIm, labelIm, object_slices):
shape = edmIm.shape
for (i, sl) in enumerate(object_slices):
if sl is not None:
sl = tuple([slice(s.start - 1 if s.start>0 else 0, s.stop+1 if s.stop<ax-1 else s.stop, s.step) for s, ax in zip(sl, shape)])
mask = labelIm == i+1
sub_m = mask[sl]
sub_edm = np.copy(edmIm[sl])
sub_edm[np.logical_not(sub_m)] = 0 # remove neighbor cells
m = np.zeros_like(sub_m)
m[sub_m] = 1
edmIm[sl][sub_m] = skfmm.distance(m)[sub_m]

Loading

0 comments on commit 8538aa1

Please sign in to comment.