Skip to content

Commit

Permalink
fixed on hardsamplemining metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanollion committed May 16, 2024
1 parent f5c2dc9 commit 105a279
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 17 deletions.
23 changes: 11 additions & 12 deletions distnet_2d/utils/metrics_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .objectwise_computation_tf import get_max_by_object_fun, coord_distance_fun, get_argmax_2d_by_object_fun, get_mean_by_object_fun, get_label_size, IoU, objectwise_compute


def get_metrics_fun(center_scale: float, max_objects_number: int = 0, reduce: bool = False):
def get_metrics_fun(center_scale: float, max_objects_number: int = 0, reduce:bool=True):
"""
return metric function for disnet2D
assumes iterator in return_central_only= True mode (thus framewindow = 1 and next = true)
Expand All @@ -18,24 +18,23 @@ def get_metrics_fun(center_scale: float, max_objects_number: int = 0, reduce: bo
"""

scale = tf.cast(center_scale, tf.float32)
coord_distance_function = coord_distance_fun(max=True)
coord_distance_function_sqrt = coord_distance_fun(max=True, sqrt=True)
coord_distance_function = coord_distance_fun(max=True, sqrt=True)
spa_max_fun = get_argmax_2d_by_object_fun()
mean_fun = get_mean_by_object_fun()
max_fun = get_max_by_object_fun(nan=0., channel_axis=False)
mean_fun_lm = get_mean_by_object_fun(nan=-1)
max_fun = get_max_by_object_fun(nan=1., channel_axis=False)
mean_fun_lm = get_mean_by_object_fun(nan=-1.)

def fun(args):
edm, gdcm, dY, dX, lm, true_edm, true_dY, true_dX, true_lm, labels, prev_labels, true_center_ob = args
labels = tf.transpose(labels, perm=[2, 0, 1]) # (T, Y, X)
labels = tf.transpose(labels, perm=[2, 0, 1]) # (1, Y, X)
gdcm = tf.transpose(gdcm, perm=[2, 0, 1]) # (1, Y, X)
edm = tf.transpose(edm, perm=[2, 0, 1]) # (1, Y, X)
true_edm = tf.transpose(true_edm, perm=[2, 0, 1]) # (1, Y, X)
motion_shape = tf.shape(dY)
lm = tf.reshape(lm, shape=tf.concat([motion_shape[:2], motion_shape[-1:], [3]], 0))
lm = tf.transpose(lm, perm=[2, 0, 1, 3]) # T, Y, X, 3
true_lm = tf.transpose(true_lm, perm=[2, 0, 1])
ids, sizes, N = get_label_size(labels, max_objects_number) # (T, N), (T, N)
ids, sizes, N = get_label_size(labels, max_objects_number) # (1, N), (1, N)
true_center_ob = true_center_ob[:, :N]

# EDM : foreground/background IoU + contour IoU
Expand All @@ -57,13 +56,13 @@ def fun(args):
center_values = tf.math.exp(-tf.math.square(tf.math.divide(gdcm, scale)))
center_coord = objectwise_compute(center_values, [0], spa_max_fun, labels, ids, sizes) # (N, 2)
#print(f"center loc: {tf.concat([true_center_ob[0], center_coord[0]], -1).numpy()}")
center_spa_l2 = coord_distance_function_sqrt(true_center_ob, center_coord)
center_spa_l2 = coord_distance_function(true_center_ob, center_coord)
center_spa_l2 = tf.cond(tf.math.is_nan(center_spa_l2), lambda: zero, lambda: center_spa_l2)

# CENTER 2 : absolute value of center. Target is 1, min value is 0.
center_max_value = objectwise_compute(center_values, [0], max_fun, labels, ids, sizes)
center_max_value = tf.reduce_min(center_max_value) # worst case among all cells
center_max_value = tf.cond(tf.math.is_nan(center_max_value), lambda: zero, lambda: center_max_value)
#center_max_value = tf.cond(tf.math.is_nan(center_max_value), lambda: zero, lambda: center_max_value)
#print(f"center val: {center_max_value.numpy()}")

# motion: l2 of pred vs true center coordinates
Expand All @@ -78,10 +77,10 @@ def fun(args):
#print(f"dM: {tf.concat([true_dm[0], dm[0]], -1).numpy()}")
# print(f"NEXT: dM: {tf.concat([true_dm[1], dm[1]], -1).numpy()}")
dm_l2 = coord_distance_function(true_dm, dm)
dm_l2 = tf.cond(tf.math.is_nan(dm_l2), lambda: tf.cast(0, dm_l2.dtype), lambda: dm_l2)
dm_l2 = tf.cond(tf.math.is_nan(dm_l2), lambda: zero, lambda: dm_l2)

# Link Multiplicity
true_lm = tf.cast(objectwise_compute(true_lm[..., tf.newaxis], [0, 1], mean_fun_lm, labels, ids, sizes, label_channels=[0, 0]), tf.int32)[..., 0] - 1
true_lm = tf.cast(objectwise_compute(true_lm[..., tf.newaxis], [0, 1], mean_fun_lm, labels, ids, sizes, label_channels=[0, 0]), tf.int32)[..., 0] - tf.cast(1, tf.int32)
lm = objectwise_compute(lm, [0, 1], mean_fun_lm, labels, ids, sizes, label_channels=[0, 0])
lm = tf.math.argmax(lm, axis=-1, output_type=tf.int32)
#print(f"lm: {tf.stack([true_lm[0], lm[0]], -1).numpy()}")
Expand All @@ -93,7 +92,7 @@ def fun(args):
def metrics_fun(edm, gcdm, dY, dX, lm, true_edm, true_dY, true_dX, true_lm, labels, prev_labels, true_center_array):
metrics = tf.map_fn(fun, (edm, gcdm, dY, dX, lm, true_edm, true_dY, true_dX, true_lm, labels, prev_labels, true_center_array), fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.float32, tf.float32))
if reduce:
metrics = [tf.reduce_mean(m) for m in metrics]
metrics = [tf.reduce_max(m) for m in metrics]
return metrics

return metrics_fun
6 changes: 3 additions & 3 deletions distnet_2d/utils/objectwise_computation_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

def get_label_size(labels, max_objects_number:int=0): # C, Y, X
N = max_objects_number if max_objects_number>0 else tf.math.reduce_max(labels)
N = max_objects_number if max_objects_number>0 else tf.math.maximum(tf.cast(1, labels.dtype), tf.math.reduce_max(labels))

def treat_image(im):

Expand Down Expand Up @@ -179,8 +179,8 @@ def _generate_kernel(sizeY, sizeX, C=1, O=0):

def IoU(true, pred, tolerance:bool=False):
true_inter = _dilate_mask(true) if tolerance else true
intersection = tf.math.count_nonzero(tf.math.logical_and(true_inter, pred))
union = tf.math.count_nonzero(tf.math.logical_or(true, pred))
intersection = tf.math.count_nonzero(tf.math.logical_and(true_inter, pred), keepdims=False)
union = tf.math.count_nonzero(tf.math.logical_or(true, pred), keepdims=False)
return tf.cond(tf.math.equal(union, tf.cast(0, union.dtype)), lambda: tf.cast(1., tf.float32), lambda: tf.math.divide(tf.cast(intersection, tf.float32), tf.cast(union, tf.float32)))


Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

setuptools.setup(
name="DiSTNet2D",
version="0.1.5",
version="0.1.6",
author="Jean Ollion",
author_email="jean.ollion@polytechnique.org",
description="tensorflow/keras implementation of DiSTNet 2D",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/jeanollion/distnet2d",
download_url='https://github.com/jeanollion/distnet2d/releases/download/v0.1.5/distnet2d-0.1.5.tar.gz',
download_url='https://github.com/jeanollion/distnet2d/releases/download/v0.1.6/distnet2d-0.1.6.tar.gz',
packages=setuptools.find_packages(),
keywords=['Segmentation', 'Tracking', 'Cell', 'Tensorflow', 'Keras'],
classifiers=[
Expand Down

0 comments on commit 105a279

Please sign in to comment.