Skip to content

Commit

Permalink
metrics: improve objectwise computation speed using tf.TensorArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanollion committed Jun 4, 2024
1 parent d87b981 commit ba3ef40
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 24 deletions.
2 changes: 0 additions & 2 deletions distnet_2d/data/dydx_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,8 +494,6 @@ def _get_category(n_neigh):
def _compute_displacement(labels_map_centers, labelIm, labels_map_prev, object_slices, dyIm, dxIm, dyImNext=None, dxImNext=None, gdcmIm=None, gdcmImPrev=None, linkMultiplicityIm=None, linkMultiplicityImNext=None, rankIm=None, rankImPrev=None, prevLabelArr=None, nextLabelArr=None, centerArr=None, centerArrPrev=None, center_mode:str= "MEDOID"):
assert labelIm.shape[-1] == 2, f"invalid labelIm : {labelIm.shape[-1]} channels instead of 2"
assert (dxImNext is None) == (dyImNext is None)
if len(labels_map_centers[-1])==0: # no cells
return
curLabelIm = labelIm[...,-1]
labels_prev = labels_map_centers[0].keys()
labels_prev_rank = {l:r for r, l in enumerate(labels_prev)}
Expand Down
21 changes: 12 additions & 9 deletions distnet_2d/utils/metrics_tf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import tensorflow as tf
from .objectwise_computation_tf import get_max_by_object_fun, coord_distance_fun, get_argmax_2d_by_object_fun, get_mean_by_object_fun, get_label_size, IoU, objectwise_compute
from .objectwise_computation_tf import get_max_by_object_fun, coord_distance_fun, get_argmax_2d_by_object_fun, get_mean_by_object_fun, get_label_size, IoU, objectwise_compute, objectwise_compute_channel


def get_metrics_fun(center_scale: float, max_objects_number: int = 0):
Expand All @@ -22,8 +22,8 @@ def get_metrics_fun(center_scale: float, max_objects_number: int = 0):
spa_max_fun = get_argmax_2d_by_object_fun()
mean_fun = get_mean_by_object_fun()
max_fun = get_max_by_object_fun(nan=1., channel_axis=False)
mean_fun_lm = get_mean_by_object_fun(nan=-1.)

mean_fun_true_lm = get_mean_by_object_fun(nan=1., channel_axis=False)
mean_fun_lm = get_mean_by_object_fun(nan=0.)

def fun(args):
edm, gdcm, dY, dX, lm, true_edm, true_dY, true_dX, true_lm, labels, prev_labels, true_center_ob = args
Expand All @@ -36,6 +36,8 @@ def fun(args):
lm = tf.transpose(lm, perm=[2, 0, 1, 3]) # T, Y, X, 3
true_lm = tf.transpose(true_lm, perm=[2, 0, 1])
ids, sizes, N = get_label_size(labels, max_objects_number) # (1, N), (1, N)
ids = ids[0]
sizes = sizes[0]
true_center_ob = true_center_ob[:, :N]

center_values = tf.math.exp(-tf.math.square(tf.math.divide(gdcm, scale)))
Expand All @@ -55,25 +57,26 @@ def fun(args):
#contour_IoU = IoU(true_contours, pred_contours, tolerance=True)
#edm_IoU = 0.5 * (edm_IoU + contour_IoU)

labels = labels[0]
# CENTER compute center coordinates per objects: spatial softmax of predicted gaussian function of GDCM
center_coord = objectwise_compute(center_values, [0], spa_max_fun, labels, ids, sizes) # (N, 2)
center_coord = objectwise_compute(center_values[0], spa_max_fun, labels, ids, sizes) # (N, 2)
center_spa_l2 = coord_distance_function(true_center_ob, center_coord)
center_spa_l2 = tf.cond(tf.math.is_nan(center_spa_l2), lambda: zero, lambda: center_spa_l2)

# CENTER 2 : absolute value of center. Target is 1, min value is 0.
center_max_value = objectwise_compute(center_values, [0], max_fun, labels, ids, sizes)
center_max_value = objectwise_compute(center_values[0], max_fun, labels, ids, sizes)
center_max_value = tf.reduce_min(center_max_value) # worst case among all cells = further away from 1 = min
# center_max_value = tf.cond(tf.math.is_nan(center_max_value), lambda: zero, lambda: center_max_value)

# DISPLACEMENT
dm = objectwise_compute(dYX, [0, 1], mean_fun, labels, ids, sizes, label_channels=[0, 0])
true_dm = objectwise_compute(true_dYX, [0, 1], mean_fun, labels, ids, sizes, label_channels=[0, 0])
dm = objectwise_compute_channel(dYX, mean_fun, labels, ids, sizes)
true_dm = objectwise_compute_channel(true_dYX, mean_fun, labels, ids, sizes)
dm_l2 = coord_distance_function(true_dm, dm)
dm_l2 = tf.cond(tf.math.is_nan(dm_l2), lambda: zero, lambda: dm_l2)

# Link Multiplicity
true_lm = tf.cast(objectwise_compute(true_lm[..., tf.newaxis], [0, 1], mean_fun_lm, labels, ids, sizes, label_channels=[0, 0]), tf.int32)[..., 0] - tf.cast(1, tf.int32)
lm = objectwise_compute(lm, [0, 1], mean_fun_lm, labels, ids, sizes, label_channels=[0, 0])
true_lm = tf.cast(objectwise_compute_channel(true_lm, mean_fun_true_lm, labels, ids, sizes), tf.int32) - tf.cast(1, tf.int32)
lm = objectwise_compute_channel(lm, mean_fun_lm, labels, ids, sizes)
lm = tf.math.argmax(lm, axis=-1, output_type=tf.int32)
errors = tf.math.not_equal(lm, true_lm)
lm_errors = tf.reduce_sum(tf.cast(errors, tf.float32))
Expand Down
38 changes: 25 additions & 13 deletions distnet_2d/utils/objectwise_computation_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,31 +113,43 @@ def fun(data, mask, size): # (Y, X), (Y, X), (1,)
return fun


def objectwise_compute(data, channels, fun, labels, ids, sizes, label_channels=None): # [(tensor, range, fun)], (T, Y, X (,C) ), (T, N), (T, N)
if label_channels is None:
label_channels = channels
else:
assert len(label_channels) == len(channels), "label_channels and channels must have same length"

def treat_im(args):
dc, lc = args
return _objectwise_compute_channel(data[dc], fun, labels[lc], ids[lc], sizes[lc])
return tf.map_fn(treat_im, (tf.convert_to_tensor(channels), tf.convert_to_tensor(label_channels)), fn_output_signature=data.dtype, parallel_iterations=len(channels))

def objectwise_compute(data, fun, labels, ids, sizes): # tensor (Y, X, ...) , fun, (Y, X), (N), ( N) -> (N, ...)

def _objectwise_compute_channel(data, fun, labels, ids, sizes): # tensor, fun, (Y, X), (N), ( N)
def non_null():
ta = tf.TensorArray(dtype=data.dtype, size=tf.shape(ids)[0])
for i in tf.range(tf.shape(ids)[0]):
mask = tf.cast(tf.math.equal(labels, ids[i]), tf.float32)
ta.write(i, fun(data, mask, sizes[i]))
ta = ta.write(i, fun(data, mask, sizes[i]))
return ta.stack()

def null():
return fun(data, tf.zeros_like(labels, dtype=tf.float32), 0)
return tf.cond(tf.math.equal(tf.size(ids), 0), null, non_null)


def objectwise_compute_channel(data, fun, labels, ids, sizes): # tensor (C, Y, X, ...) , fun, (Y, X), (N), ( N) -> (C, N, ...)

def non_null():
n_chan = tf.shape(data)[0]
n_obj = tf.shape(ids)[0]
ta = tf.TensorArray(dtype=data.dtype, size=n_obj * n_chan)
for i in tf.range(n_obj):
mask = tf.cast(tf.math.equal(labels, ids[i]), tf.float32)
for j in tf.range(n_chan):
ta = ta.write(j * n_obj + i, fun(data[j], mask, sizes[i]))
tensor = ta.stack()
return tf.reshape(tensor, shape=tf.concat([[n_chan, n_obj], tf.shape(tensor)[1:]], 0))

def null():
n_chan = tf.shape(data)[0]
ta = tf.TensorArray(dtype=data.dtype, size=n_chan)
mask = tf.zeros_like(labels, dtype=tf.float32)
for j in tf.range(n_chan):
ta = ta.write(j, fun(data[j], mask, 0))
tensor = ta.stack()
return tf.reshape(tensor, shape=tf.concat([[n_chan, 1], tf.shape(tensor)[1:]], 0))
return tf.cond(tf.math.equal(tf.size(ids), 0), null, non_null)


def coord_distance_fun(max:bool=True, sqrt:bool=False):
def loss(true, pred): # (C, N, 2)
Expand Down

0 comments on commit ba3ef40

Please sign in to comment.