Skip to content

Commit

Permalink
metrics: improve objectwise computation speed using tf.TensorArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanollion committed Jun 4, 2024
1 parent a3606e4 commit d87b981
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 34 deletions.
43 changes: 16 additions & 27 deletions distnet_2d/utils/metrics_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,54 +38,43 @@ def fun(args):
ids, sizes, N = get_label_size(labels, max_objects_number) # (1, N), (1, N)
true_center_ob = true_center_ob[:, :N]

# EDM : foreground/background IoU + contour IoU
center_values = tf.math.exp(-tf.math.square(tf.math.divide(gdcm, scale)))
dYX = tf.stack([dY, dX], -1) # Y, X, T, 2
dYX = tf.transpose(dYX, perm=[2, 0, 1, 3]) # T, Y, X, 2
true_dYX = tf.stack([true_dY, true_dX], -1) # Y, X, T, 2
true_dYX = tf.transpose(true_dYX, perm=[2, 0, 1, 3]) # T, Y, X, 2
zero = tf.cast(0, edm.dtype)
one = tf.cast(1, edm.dtype)
pred_foreground = tf.math.greater(edm, zero)
true_foreground = tf.math.greater(labels, 0)
edm_IoU = IoU(true_foreground, pred_foreground, tolerance=True)

pred_contours = tf.math.logical_and(tf.math.greater(edm, zero), tf.math.less_equal(edm, tf.cast(1.5, edm.dtype)))
true_contours = tf.math.logical_and(tf.math.greater(true_edm, zero), tf.math.less_equal(true_edm, one))
contour_IoU = IoU(true_contours, pred_contours, tolerance=True)
edm_IoU = 0.5 * (edm_IoU + contour_IoU)
# EDM : foreground/background IoU #+ contour IoU
pred_foreground = tf.math.greater(edm, tf.cast(0.5, edm.dtype))
true_foreground = tf.math.greater(labels, tf.cast(0, labels.dtype))
edm_IoU = IoU(true_foreground, pred_foreground, tolerance=True)

# mask = tf.where(labels > 0, one, zero)
# edm_L2 = tf.math.divide_no_nan(tf.math.reduce_sum(mask * (edm - true_edm) ** 2), tf.cast(tf.reduce_sum(sizes), edm.dtype))
#pred_contours = tf.math.logical_and(tf.math.greater(edm, tf.cast(0.5, edm.dtype)), tf.math.less_equal(edm, tf.cast(1.5, edm.dtype)))
#true_contours = tf.math.logical_and(tf.math.greater(true_edm, tf.cast(0.5, edm.dtype)), tf.math.less_equal(true_edm, tf.cast(1.5, edm.dtype)))
#contour_IoU = IoU(true_contours, pred_contours, tolerance=True)
#edm_IoU = 0.5 * (edm_IoU + contour_IoU)

# CENTER compute center coordinates per objects: spatial softmax of predicted gaussian function of GDCM
center_values = tf.math.exp(-tf.math.square(tf.math.divide(gdcm, scale)))
center_coord = objectwise_compute(center_values, [0], spa_max_fun, labels, ids, sizes) # (N, 2)
#print(f"center loc: {tf.concat([true_center_ob[0], center_coord[0]], -1).numpy()}")
center_spa_l2 = coord_distance_function(true_center_ob, center_coord)
center_spa_l2 = tf.cond(tf.math.is_nan(center_spa_l2), lambda: zero, lambda: center_spa_l2)

# CENTER 2 : absolute value of center. Target is 1, min value is 0.
center_max_value = objectwise_compute(center_values, [0], max_fun, labels, ids, sizes)
center_max_value = tf.reduce_min(center_max_value) # worst case among all cells
#center_max_value = tf.cond(tf.math.is_nan(center_max_value), lambda: zero, lambda: center_max_value)
#print(f"center val: {center_max_value.numpy()}")
center_max_value = tf.reduce_min(center_max_value) # worst case among all cells = further away from 1 = min
# center_max_value = tf.cond(tf.math.is_nan(center_max_value), lambda: zero, lambda: center_max_value)

# motion: l2 of pred vs true center coordinates
dYX = tf.stack([dY, dX], -1) # Y, X, T, 2
dYX = tf.transpose(dYX, perm=[2, 0, 1, 3]) # T, Y, X, 2
# print(f"dXY shape: {dYX.shape} dY: {dY.shape}")
# DISPLACEMENT
dm = objectwise_compute(dYX, [0, 1], mean_fun, labels, ids, sizes, label_channels=[0, 0])

true_dYX = tf.stack([true_dY, true_dX], -1) # Y, X, T, 2
true_dYX = tf.transpose(true_dYX, perm=[2, 0, 1, 3]) # T, Y, X, 2
true_dm = objectwise_compute(true_dYX, [0, 1], mean_fun, labels, ids, sizes, label_channels=[0, 0])
#print(f"dM: {tf.concat([true_dm[0], dm[0]], -1).numpy()}")
# print(f"NEXT: dM: {tf.concat([true_dm[1], dm[1]], -1).numpy()}")
dm_l2 = coord_distance_function(true_dm, dm)
dm_l2 = tf.cond(tf.math.is_nan(dm_l2), lambda: zero, lambda: dm_l2)

# Link Multiplicity
true_lm = tf.cast(objectwise_compute(true_lm[..., tf.newaxis], [0, 1], mean_fun_lm, labels, ids, sizes, label_channels=[0, 0]), tf.int32)[..., 0] - tf.cast(1, tf.int32)
lm = objectwise_compute(lm, [0, 1], mean_fun_lm, labels, ids, sizes, label_channels=[0, 0])
lm = tf.math.argmax(lm, axis=-1, output_type=tf.int32)
#print(f"lm: {tf.stack([true_lm[0], lm[0]], -1).numpy()}")
# print(f"NEXT lm: {tf.stack([true_lm[1], lm[1]], -1).numpy()}")
errors = tf.math.not_equal(lm, true_lm)
lm_errors = tf.reduce_sum(tf.cast(errors, tf.float32))
return tf.stack([edm_IoU, -center_spa_l2, center_max_value, -dm_l2, -lm_errors])
Expand Down
23 changes: 16 additions & 7 deletions distnet_2d/utils/objectwise_computation_tf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tensorflow as tf
import numpy as np


def get_label_size(labels, max_objects_number:int=0): # C, Y, X
N = max_objects_number if max_objects_number>0 else tf.math.maximum(tf.cast(1, labels.dtype), tf.math.reduce_max(labels))

Expand Down Expand Up @@ -64,6 +65,7 @@ def non_null():
return tf.cond(tf.math.equal(size, 0), lambda:tf.stack([nan, nan]), non_null) # when no values should return nan
return sam


def get_argmax_2d_by_object_fun(nan=float('NaN')):
nan = tf.cast(nan, tf.float32)
def fun(data, mask, size): # (Y, X)
Expand All @@ -76,6 +78,7 @@ def non_null():
return tf.cond(tf.math.equal(size, 0), lambda:tf.stack([nan, nan]), non_null) # when no values should return nan
return fun


def get_mean_by_object_fun(nan=float('NaN'), channel_axis:bool=True):
nan = tf.cast(nan, tf.float32)
if channel_axis:
Expand Down Expand Up @@ -119,15 +122,21 @@ def objectwise_compute(data, channels, fun, labels, ids, sizes, label_channels=N
def treat_im(args):
dc, lc = args
return _objectwise_compute_channel(data[dc], fun, labels[lc], ids[lc], sizes[lc])
return tf.map_fn(treat_im, (tf.convert_to_tensor(channels), tf.convert_to_tensor(label_channels)), fn_output_signature=data.dtype)
return tf.map_fn(treat_im, (tf.convert_to_tensor(channels), tf.convert_to_tensor(label_channels)), fn_output_signature=data.dtype, parallel_iterations=len(channels))


def _objectwise_compute_channel(data, fun, labels, ids, sizes): # tensor, fun, (Y, X), (N), ( N)
def treat_ob(args):
id, size = args
mask = tf.cond(tf.math.equal(id, 0), lambda:tf.zeros_like(labels, dtype=tf.float32), lambda:tf.cast(tf.math.equal(labels, id), tf.float32))
return fun(data, mask, size)
return tf.map_fn(treat_ob, (ids, sizes), fn_output_signature=data.dtype)
def non_null():
ta = tf.TensorArray(dtype=data.dtype, size=tf.shape(ids)[0])
for i in tf.range(tf.shape(ids)[0]):
mask = tf.cast(tf.math.equal(labels, ids[i]), tf.float32)
ta.write(i, fun(data, mask, sizes[i]))
return ta.stack()

def null():
return fun(data, tf.zeros_like(labels, dtype=tf.float32), 0)
return tf.cond(tf.math.equal(tf.size(ids), 0), null, non_null)



def coord_distance_fun(max:bool=True, sqrt:bool=False):
Expand Down Expand Up @@ -181,7 +190,7 @@ def IoU(true, pred, tolerance:bool=False):
true_inter = _dilate_mask(true) if tolerance else true
intersection = tf.math.count_nonzero(tf.math.logical_and(true_inter, pred), keepdims=False)
union = tf.math.count_nonzero(tf.math.logical_or(true, pred), keepdims=False)
return tf.cond(tf.math.equal(union, tf.cast(0, union.dtype)), lambda: tf.cast(1., tf.float32), lambda: tf.math.divide(tf.cast(intersection, tf.float32), tf.cast(union, tf.float32)))
return tf.cond(tf.math.equal(union, tf.cast(0, union.dtype)), lambda: tf.cast(1., tf.float32), lambda: tf.math.divide(tf.cast(intersection, tf.float32), tf.cast(union, tf.float32))) # if union is null -> metric is 1


def _dilate_mask(maskBYX):
Expand Down

0 comments on commit d87b981

Please sign in to comment.