# Topological clustering algorithm

In [2]:
import tensorflow as tf
print("Tensorflow version : {}".format(tf.__version__))

Tensorflow version : 2.2.0-rc1


In [3]:
# Moliere radius for lead (Wigmans2017, Appendix B).
# Expressed in units of the smallest irreducible (cell) unit; finalize_splitter
# divides cell-to-centroid distances by this value.
moliereRadius = tf.constant(16.0 / 3) # moliere radius in units of smallest irreducible unit

## Cluster maker

In [11]:
from skimage.measure import label

def get_seedlist(m, tseed):
    """Return the seed list of image ``m``.

    Seeds are cells of ``m`` whose signal is strictly above the seed
    threshold ``tseed``. ``skimage.measure.label`` groups them into
    connected components (connectivity=2, i.e. the 8 surrounding cells are
    neighbours); each component is one proto-cluster. Label 0 (background)
    is dropped by ``tf.where``. Each row of the result is
    ``[proto_cluster_id, i, j]``.

    Args:
      m: A 2D `Tensor` with shape [height, width] of numeric type.
      tseed: A scalar with same type as m.

    Returns:
      A `Tensor` with shape [None, 3] and type `int64`.
    """
    mask = tf.math.greater(m, tseed)
    # skimage returns the platform default int (int32 on Windows), while
    # numpy_function requires exactly the declared Tout, so cast explicitly.
    f = lambda x: label(x, connectivity=2).astype('int64')
    labels = tf.numpy_function(f, [mask], Tout=tf.int64)
    # numpy_function drops static shape information; restore it so that
    # tf.where yields a [None, 2] tensor and downstream shapes stay known.
    labels.set_shape(mask.shape)
    sij = tf.where(labels)
    sid = tf.expand_dims(tf.gather_nd(labels, sij), axis=1)
    return tf.concat([sid, sij], axis=-1)

In [12]:
def sort_seedlist(m, s):
    """Order the seed list ``s`` by decreasing cell signal in ``m``.

    Args:
      m: A 2D `Tensor` with shape [height, width] of numeric type.
      s: A `Tensor` with shape [None, 3] and type `int64`; each row is
        ``[proto_cluster_id, i, j]``.

    Returns:
      A `Tensor` with the same shape and type as s, with its rows reordered
      so that the seed with the highest signal comes first.
    """
    seed_values = tf.gather_nd(m, s[:, 1:])
    n_seeds = tf.shape(s)[0]
    _, order = tf.math.top_k(seed_values, k=n_seeds)
    return tf.gather(s, order)

In [None]:
def get_mask_from_indices(indices, dense_shape):
    """Scatter a list of cell indices into a dense boolean mask.

    Args:
      indices: A `Tensor` of shape [None, 2] and type `int64`.
      dense_shape: A `Tensor` with shape [2,] and type `int64`.

    Returns:
      A `Tensor` with shape [height, width] and type `bool` that is True
      exactly at the given indices.
    """
    n = tf.shape(indices)[0]
    sparse_mask = tf.SparseTensor(indices=indices,
                                  values=tf.ones([n], dtype=tf.bool),
                                  dense_shape=dense_shape)
    # to_dense expects the indices in canonical (row-major) order.
    sparse_mask = tf.sparse.reorder(sparse_mask)
    return tf.sparse.to_dense(sparse_mask)

In [5]:
def get_neighbours(idx, dense_shape, incl=None, excl=None):
    """Return the indices of the (up to) 8 neighbours of cell ``idx``.

    Neighbours outside the image are dropped. The result can be further
    restricted to an allow-list (``incl``) and pruned by a deny-list
    (``excl``).

    Args:
      idx: A `Tensor` with shape [2,] and type `int64`.
      dense_shape: A `Tensor` with shape [2,] and type `int64`.
      incl: Optional `Tensor` with shape [None, 2] and the same type as idx.
        When given, only neighbours listed in ``incl`` are kept; an empty
        ``incl`` therefore keeps no neighbour. ``None`` (default) applies no
        restriction.
      excl: Optional `Tensor` with shape [None, 2] and the same type as idx.
        Neighbours listed here are removed. ``None`` or an empty tensor
        removes nothing.

    Returns:
      A `Tensor` with shape [None, 2] and type `int64`, in row-major order.
    """
    # 3x3 block of candidate cells centred on idx, minus idx itself
    i, j = tf.unstack(idx)
    ii, jj = tf.meshgrid(tf.range(i - 1, i + 2), tf.range(j - 1, j + 2),
                         indexing='ij')
    indices = tf.concat([tf.reshape(ii, [-1, 1]), tf.reshape(jj, [-1, 1])],
                        axis=1)
    not_centre = tf.reduce_any(tf.not_equal(indices, idx), axis=1)
    indices = tf.boolean_mask(indices, not_centre)

    # handle boundaries: drop rows equal to -1 or height and columns equal
    # to -1 or width
    height, width = tf.unstack(dense_shape)
    bound = tf.concat([[[height, -1]], [[-1, width]]], axis=0)
    mask_bound = tf.not_equal(tf.expand_dims(indices, axis=-1), bound)
    mask_bound = tf.reduce_all(mask_bound, axis=[1, 2])
    indices = tf.boolean_mask(indices, mask_bound)

    # apply the deny-list and, if one was given, the allow-list. The None
    # checks are static Python checks, so an empty allow-list is honoured
    # instead of being mistaken for "no restriction".
    mask_final = get_mask_from_indices(indices, dense_shape)
    if excl is not None:
        mask_excl = get_mask_from_indices(excl, dense_shape)
        mask_final = tf.logical_and(mask_final, tf.logical_not(mask_excl))
    if incl is not None:
        mask_incl = get_mask_from_indices(incl, dense_shape)
        mask_final = tf.logical_and(mask_final, mask_incl)
    return tf.where(mask_final)

In [14]:
def merge_adjacent_proto(nj, siid, m, p, s, l):
    """Merge adjacent proto-clusters into proto-cluster ``siid``.

    Looks at the neighbours of cell ``nj`` that are not seed cells, already
    belong to a proto-cluster and have a signal above the module-level
    global ``tneighbour`` (set by cluster_maker). Every such proto-cluster
    whose id differs from ``siid`` is relabelled to ``siid`` in the proto
    list, the seed list and the neighbour list.

    Args:
      nj: `Tensor` [2,] with the index of the cell being added.
      siid: id of the growing proto-cluster (reshaped to [1,] below).
      m: 2D image `Tensor` [height, width].
      p: proto list, `Tensor` [None, 3] of ``[id, i, j]`` rows.
      s: seed list, same layout as ``p``.
      l: neighbour list, same layout as ``p``.

    Returns:
      Tuple ``(pnew, snew, lnew)`` with the same layouts. ``pnew`` is rebuilt
      from a dense id map, so it comes back in row-major cell order; ``snew``
      and ``lnew`` keep their original row order.
    """
    siid = tf.reshape(siid, [-1,])
    sid, sidx = tf.split(s, [1,2], axis=1)
    lid, lidx = tf.split(l, [1,2], axis=1)
    proto_id, proto_idx = tf.split(p, [1,2], axis=1)
    dense_shape = tf.shape(m, tf.int64)
    # neighbours of nj, excluding cells that are in the seed list
    nnj = get_neighbours(nj, dense_shape, excl=sidx)
    nnj_d = get_mask_from_indices(nnj, dense_shape)
    # dense map of proto-cluster ids (0 = cell not in any proto-cluster)
    proto_id = tf.reshape(proto_id, [-1,])
    proto_sp = tf.SparseTensor(proto_idx, proto_id, dense_shape)
    proto_d = tf.sparse.to_dense(tf.sparse.reorder(proto_sp))
    mask_proto = tf.not_equal(proto_d, 0)
    mask_neigh = tf.logical_and(nnj_d, mask_proto)
    # only clustered neighbours above tneighbour can trigger a merge
    indices = tf.where(tf.logical_and(mask_neigh, tf.greater(m, tneighbour)))
#     indices = tf.where(tf.logical_and(mask_neigh))
    # ids of the other proto-clusters touching nj; repeats are possible but
    # harmless, since relabelling the same id twice is a no-op
    values = tf.gather_nd(proto_d, indices)
    values = tf.boolean_mask(values, tf.not_equal(values, siid))
    
    # loop over values
    k0 = tf.constant(0)
    ck = lambda k, p, s, l: tf.less(k, tf.shape(values)[0])
    def bk(k, p, s, l):
        # here p is the dense id map, s the seed ids and l the neighbour ids
        neigh_id = tf.gather(values, k)
        # cells of the neighbouring cluster are located in the original
        # dense map proto_d and relabelled to siid
        pnew = tf.where(tf.equal(proto_d, neigh_id), siid*tf.ones_like(p), p)
        snew = tf.where(tf.equal(sid, neigh_id), siid*tf.ones_like(s), s)
        lnew = tf.where(tf.equal(lid, neigh_id), siid*tf.ones_like(l), l)
        return [tf.add(k, 1), pnew, snew, lnew]
    p, sid, lid = tf.while_loop(
        ck, bk, loop_vars=[k0, proto_d, sid, lid],
        shape_invariants=[k0.get_shape(), proto_d.get_shape(),
                          sid.get_shape(), lid.get_shape()])[1:]               
    # convert the dense id map back to a [id, i, j] list
    indices = tf.where(tf.not_equal(p, 0))
    values = tf.expand_dims(tf.gather_nd(p, indices), axis=1)
    pnew = tf.concat([values, indices], axis=1)
    snew = tf.concat([sid, sidx], axis=1)
    lnew = tf.concat([lid, lidx], axis=1)
    return pnew, snew, lnew

In [None]:
def bj_maker(j, m, p, s, l, n, siid):
    """Body of while loop in bi_maker: process the j-th neighbour of a seed.

    A neighbour above the global ``tneighbour`` first merges adjacent
    proto-clusters. It is then appended to both the proto list and the
    neighbour list, so it becomes a seed in the next pass. A neighbour above
    the global ``tcell`` only is appended to the proto list. Anything else
    is ignored.
    """
    cell = tf.gather(n, j)
    cell_value = tf.gather_nd(m, cell)
    row = tf.expand_dims(tf.concat([siid, cell], axis=0), axis=0)

    def _grow_and_seed():
        merged_p, merged_s, merged_l = merge_adjacent_proto(
            cell, siid, m, p, s, l)
        return [tf.concat([merged_p, row], axis=0),
                merged_s,
                tf.concat([merged_l, row], axis=0)]

    def _grow_only():
        grown_p = tf.cond(tf.greater(cell_value, tcell),
                          lambda: tf.concat([p, row], axis=0),
                          lambda: p)
        return [grown_p, s, l]

    pnew, snew, lnew = tf.cond(tf.greater(cell_value, tneighbour),
                               true_fn=_grow_and_seed,
                               false_fn=_grow_only)
    return tf.add(j, 1), m, pnew, snew, lnew, n, siid

In [None]:
def bi_maker(i, m, p, s, l):
    """Body of while loop in find_neighbours_maker: expand the i-th seed.

    Collects the neighbours of seed ``s[i]`` that are not yet in the proto
    list ``p`` and processes each of them with ``bj_maker``.

    Returns:
      ``(i + 1, m, p, s, l)`` with the updated proto, seed and neighbour
      lists.
    """
    seed = tf.gather(s, i)
    seed_id, seed_idx = tf.split(seed, [1, 2], axis=0)
    proto_idx = tf.split(p, [1, 2], axis=1)[1]
    neighbours = get_neighbours(seed_idx, tf.shape(m, tf.int64),
                                excl=proto_idx)

    def _more(j, m, p, s, l, n, siid):
        return tf.less(j, tf.shape(n)[0])

    j0 = tf.constant(0)
    results = tf.while_loop(
        _more, bj_maker,
        loop_vars=[j0, m, p, s, l, neighbours, seed_id],
        shape_invariants=[j0.get_shape(), m.get_shape(),
                          tf.TensorShape([None, 3]),
                          s.get_shape(), tf.TensorShape([None, 3]),
                          neighbours.get_shape(), seed_id.get_shape()])
    _, _, p, s, l, _, _ = results
    return tf.add(i, 1), m, p, s, l

In [None]:
def find_neighbours_maker(m, p, s):
    """Run one pass of neighbour finding over the current seed list.

    Seeds are processed in descending order of their cell signal. Neighbours
    already in the proto list are skipped, so the most energetic seed claims
    contested cells first.

    Args:
      m: A 2D `Tensor` with shape [height, width] of numeric type.
      p: A `Tensor` with shape [None, 3] and type `int64` (proto list).
      s: Same as for p (current seed list).

    Returns:
      A tuple of Tensor objects (pnew, lnew).
      pnew: updated proto list, same layout as p.
      lnew: neighbour list collected in this pass; it becomes the seed list
        of the next pass.
    """
    # sort current seed list in descending order
    ssort = sort_seedlist(m, s)

    # loop over the *sorted* seed list: bi_maker reads seeds from the loop
    # variable, so the sorted list must be what is passed in.
    i0 = tf.constant(0)
    l0 = tf.zeros([0, 3], tf.int64)
    ci = lambda i, m, p, s, l: tf.less(i, tf.shape(s)[0])
    pnew, _, lnew = tf.while_loop(
        ci, bi_maker, loop_vars=[i0, m, p, ssort, l0],
        shape_invariants=[i0.get_shape(),
                          m.get_shape(),
                          tf.TensorShape([None, 3]),
                          ssort.get_shape(),
                          tf.TensorShape([None, 3])])[2:]

    # neighbor seed list becomes the new seed list
    return pnew, lnew

In [15]:
def finalize_maker(m, p):
    """Drop proto-clusters whose summed signal is below the global ``tenergy``.

    Args:
      m: 2D image `Tensor` [height, width].
      p: proto list, `Tensor` [None, 3] of ``[id, i, j]`` rows.

    Returns:
      The surviving proto list, same layout as ``p``, in row-major cell order.
    """
    ids, cells = tf.split(p, [1, 2], axis=1)
    ids = tf.reshape(ids, [-1])
    shape = tf.shape(m, out_type=tf.int64)
    id_map = tf.sparse.to_dense(
        tf.sparse.reorder(tf.SparseTensor(cells, ids, shape)))

    unique_ids = tf.unique(ids, out_idx=ids.dtype)[0]

    def _keep_going(k, current):
        return tf.less(k, tf.size(unique_ids))

    def _drop_if_weak(k, current):
        cluster_mask = tf.equal(current, tf.gather(unique_ids, k))
        energy = tf.reduce_sum(tf.gather_nd(m, tf.where(cluster_mask)))
        weak = tf.logical_and(cluster_mask, tf.less(energy, tenergy))
        return [tf.add(k, 1), tf.where(weak, tf.zeros_like(current), current)]

    k0 = tf.constant(0)
    filtered = tf.while_loop(
        _keep_going, _drop_if_weak, loop_vars=[k0, id_map],
        shape_invariants=[k0.get_shape(), id_map.shape])[1]

    # back to an [id, i, j] list
    kept_cells = tf.where(filtered)
    kept_ids = tf.reshape(tf.gather_nd(filtered, kept_cells), [-1, 1])
    return tf.concat([kept_ids, kept_cells], axis=1)

In [16]:
def cluster_maker(parsed, im):
    """Form topological clusters from the cells of image ``parsed[im]``.

    Reads the thresholds ``tseed``, ``tneighbour``, ``tcell`` and
    ``tenergy`` from ``parsed``. Seeds are grown pass by pass until a pass
    produces no new neighbour seeds; clusters below ``tenergy`` are then
    dropped.

    Returns:
      A shallow copy of ``parsed`` with the proto list under ``'proto'``.
    """
    global tneighbour, tcell, tenergy
    # bj_maker, merge_adjacent_proto and finalize_maker read these
    # thresholds as module-level globals.
    tneighbour = parsed['tneighbour']
    tcell = parsed['tcell']
    tenergy = parsed['tenergy']
    image = parsed[im]

    seeds = get_seedlist(image, parsed['tseed'])

    def _has_seeds(proto, current_seeds):
        return tf.not_equal(tf.size(current_seeds), 0)

    def _one_pass(proto, current_seeds):
        return find_neighbours_maker(image, proto, current_seeds)

    proto = tf.while_loop(
        _has_seeds, _one_pass, loop_vars=[seeds, seeds],
        shape_invariants=[tf.TensorShape([None, 3]),
                          tf.TensorShape([None, 3])])[0]

    result = parsed.copy()
    result['proto'] = finalize_maker(image, proto)
    return result

## Cluster splitter

The cluster maker is sufficient for isolated signals, but not for overlapping showers. If individual particles form local maxima, their showers may still be separable. The cluster splitter acts on the clusters produced by the cluster maker and separates them in the following steps:

* **Finding local maxima**: a local maximum is defined as a cell with: a) $E>t_\text{locmax}$, b) energy greater than that of any neighbouring cell, and c) a number of neighbouring cells within the parent cluster $N>t_\text{num}$ (by default $N\geq4$). Each local maximum seeds a new cluster; parent clusters without any local-maximum cell are not split.
* **Finding neighbors**: the local maxima now become the initial seed list, much like in the cluster maker. The difference is that only originally clustered cells are used, with no thresholding and no merging. Instead of being merged, cells shared between two proto-clusters are added to a shared-cell list to be handled separately.
* **Shared cells**: the shared-cell list is expanded iteratively by adding neighbouring cells that were originally clustered but are not yet assigned to any proto-cluster. Each of these cells is then added to both adjoining proto-clusters with the weights $w_1=\frac{E_1}{E_1+rE_2}$, $w_2=1-w_1$, $r=\exp(d_1-d_2)$. Here $E_{1,2}$ are the energies of the two proto-clusters and $d_{1,2}$ are the distances of the shared cell to the proto-cluster centroids, in units of a typical electromagnetic shower scale (the Molière radius).
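* **Finalize**: the energy $E$ and the energy-weighted centroid of each proto-cluster are computed. The shared cells are then appended to both adjoining clusters with their weighted energies, and the clusters are renumbered in descending order of energy.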

### Finding local maxima

In [17]:
def finding_local_maxima(m, p):
    """Find the local-maximum cells of the proto list ``p``.

    A cell of ``p`` is a local maximum when:

    (a) its signal exceeds the global ``tlocmax``;
    (b) its signal exceeds that of every neighbour with ``m > 0``;
    (c) it has more than the global ``tnum`` neighbours with ``m > 0``.

    ``tlocmax`` and ``tnum`` are set by cluster_splitter.

    Args:
      m: filtered image `Tensor` [height, width]. cluster_splitter passes
        the image restricted to the clustered cells.
      p: proto list, `Tensor` [None, 3] of ``[id, i, j]`` rows.

    Returns:
      `Tensor` [None, 3] with the rows of ``p`` that are local maxima, in the
      order they appear in ``p``.
    """
    dense_shape = tf.shape(m, tf.int64)
    # loop-invariant: cells with a positive signal. Computing this once
    # avoids a full-image tf.where for every proto cell.
    occupied = tf.where(tf.greater(m, 0))

    i0 = tf.constant(0)
    x0 = tf.zeros([0, 3], p.dtype)
    c = lambda i, p, x: tf.less(i, tf.shape(p)[0])

    def b(i, p, x):
        pi = tf.gather(p, i)
        piidx = tf.split(pi, [1, 2], axis=0)[1]
        pval = tf.gather_nd(m, piidx)
        nidx = get_neighbours(piidx, dense_shape, incl=occupied)
        m_val = tf.gather_nd(m, nidx)

        pred = tf.logical_and(
            tf.logical_and(tf.greater(pval, tlocmax),
                           tf.greater(pval, tf.math.reduce_max(m_val))),
            tf.greater(tf.size(m_val), tnum))
        r = tf.cond(pred,
                    lambda: tf.expand_dims(pi, axis=0),
                    lambda: tf.zeros([0, 3], pi.dtype))
        return [tf.add(i, 1), p, tf.concat([x, r], axis=0)]

    return tf.while_loop(
        c, b, loop_vars=[i0, p, x0],
        shape_invariants=[i0.get_shape(), p.get_shape(),
                          tf.TensorShape([None, 3])])[2]

### Finding neighbors

In [18]:
def is_shared(m, nj, p, siid):
    """Check whether cell ``nj`` touches exactly two proto-clusters.

    Args:
      m: 2D image `Tensor`; only its shape is used.
      nj: `Tensor` [2,] with the cell index.
      p: proto list, `Tensor` [None, 3] of ``[id, i, j]`` rows.
      siid: unused; kept for interface compatibility.

    Returns:
      ``[ids, pred]``, where ``ids`` are the distinct proto-cluster ids found
      among the neighbours of ``nj`` and ``pred`` is True iff there are
      exactly two of them.
    """
    shape = tf.shape(m, out_type=tf.int64)
    neighbour_mask = get_mask_from_indices(get_neighbours(nj, shape), shape)
    ids, cells = tf.split(p, num_or_size_splits=[1, 2], axis=1)
    id_map = tf.sparse.to_dense(tf.sparse.reorder(
        tf.SparseTensor(cells, tf.reshape(ids, [-1]), shape)))
    touching = tf.where(
        tf.logical_and(neighbour_mask, tf.not_equal(id_map, 0)))
    distinct_ids = tf.unique(tf.gather_nd(id_map, touching))[0]
    return [distinct_ids, tf.equal(tf.size(distinct_ids), 2)]

def bj_splitter(j, m, p, s, l, o, n, siid):
    """Body of the while loop over the neighbours of one seed (splitter).

    A neighbour touching exactly two proto-clusters is appended to the
    shared list ``o`` as ``[id_a, id_b, i, j]``. Any other neighbour joins
    the seed's proto-cluster ``siid`` and is appended to both ``p`` and the
    neighbour list ``l``.
    """
    cell = tf.gather(n, j)
    pair_ids, shared = is_shared(m, cell, p, siid)

    def _to_shared():
        row = tf.expand_dims(tf.concat([pair_ids, cell], axis=0), axis=0)
        return [p, s, l, tf.concat([o, row], axis=0)]

    def _to_proto():
        row = tf.expand_dims(tf.concat([siid, cell], axis=0), axis=0)
        return [tf.concat([p, row], axis=0), s,
                tf.concat([l, row], axis=0), o]

    p_out, s_out, l_out, o_out = tf.cond(shared, _to_shared, _to_proto)
    return [tf.add(j, 1), m, p_out, s_out, l_out, o_out, n, siid]

def bi_splitter(i, m, p, s, l, o):
    """Body of the while loop over the current seed list (splitter).

    Expands seed ``s[i]`` into its neighbours that have ``m > 0`` and are in
    neither the proto list ``p`` nor the shared list ``o``. Each neighbour
    is handled by ``bj_splitter``.
    """
    seed = tf.gather(s, i)
    seed_id, seed_idx = tf.split(seed, [1, 2], axis=0)
    taken = tf.concat([tf.split(p, [1, 2], axis=1)[1],
                       tf.split(o, num_or_size_splits=[2, 2], axis=1)[1]],
                      axis=0)
    allowed = tf.where(tf.greater(m, 0))
    neighbours = get_neighbours(seed_idx, tf.shape(m, tf.int64),
                                incl=allowed, excl=taken)

    def _more(j, m, p, s, l, o, n, siid):
        return tf.less(j, tf.shape(n)[0])

    j0 = tf.constant(0)
    out = tf.while_loop(
        _more, bj_splitter,
        loop_vars=[j0, m, p, s, l, o, neighbours, seed_id],
        shape_invariants=[j0.get_shape(), m.get_shape(),
                          tf.TensorShape([None, 3]),
                          s.get_shape(), tf.TensorShape([None, 3]),
                          tf.TensorShape([None, 4]),
                          neighbours.get_shape(), seed_id.get_shape()])
    _, m, p, s, l, o, _, _ = out
    return [tf.add(i, 1), m, p, s, l, o]

def finding_neighbors_splitter(m, p, s, o):
    """Run one pass of neighbour finding for the cluster splitter.

    Seeds are processed in descending order of their cell signal.

    Args:
      m: filtered image (cluster_splitter passes the image restricted to
        clustered cells).
      p: proto list [None, 3] of ``[id, i, j]`` rows.
      s: current seed list, same layout as ``p``.
      o: shared-cell list [None, 4] of ``[id_a, id_b, i, j]`` rows.

    Returns:
      ``(pnew, lnew, onew)``: the updated proto list, the neighbour list
      (the next pass's seed list) and the updated shared list.
    """
    # sort current seed list in descending order
    ssort = sort_seedlist(m, s)

    # loop over the *sorted* seed list: bi_splitter reads seeds from the
    # loop variable, so the sorted list must be what is passed in.
    i0 = tf.constant(0)
    l0 = tf.zeros([0, 3], tf.int64)
    ci = lambda i, m, p, s, l, o: tf.less(i, tf.shape(s)[0])
    _, m, pnew, _, lnew, onew = tf.while_loop(
        ci, bi_splitter, loop_vars=[i0, m, p, ssort, l0, o],
        shape_invariants=[i0.get_shape(), m.get_shape(),
                          tf.TensorShape([None, 3]), ssort.get_shape(),
                          tf.TensorShape([None, 3]), tf.TensorShape([None, 4])])

    # neighbor seed list becomes the new seed list
    return pnew, lnew, onew

### Shared cells

In [19]:
def bj_shared(j, p, l, n, siid):
    """Body of the while loop over the neighbours of one shared cell.

    Appends neighbour ``n[j]``, tagged with the id pair ``siid``, to both
    the shared list ``p`` and the new neighbour list ``l``.
    """
    cell = tf.gather(n, j)
    row = tf.expand_dims(tf.concat([siid, cell], axis=0), axis=0)
    grown_p = tf.concat([p, row], axis=0)
    grown_l = tf.concat([l, row], axis=0)
    return [tf.add(j, 1), grown_p, grown_l, n, siid]

def bi_shared(i, m, p, s, l):
    """Body of the while loop over the current shared-cell seed list.

    ``m`` is a mask of the cells that may still be claimed (cluster_splitter
    passes a boolean mask). Neighbours of shared cell ``s[i]`` that lie
    inside ``m`` and are not yet in the shared list ``p`` inherit its pair
    of proto-cluster ids.
    """
    seed = tf.gather(s, i)
    pair_ids, seed_idx = tf.split(seed, num_or_size_splits=[2, 2], axis=0)
    known_idx = tf.split(p, num_or_size_splits=[2, 2], axis=1)[1]
    neighbours = get_neighbours(seed_idx, tf.shape(m, out_type=tf.int64),
                                incl=tf.where(m), excl=known_idx)

    def _more(j, p, l, n, siid):
        return tf.less(j, tf.shape(n)[0])

    j0 = tf.constant(0)
    _, p, l, _, _ = tf.while_loop(
        _more, bj_shared, loop_vars=[j0, p, l, neighbours, pair_ids],
        shape_invariants=[j0.get_shape(), tf.TensorShape([None, 4]),
                          tf.TensorShape([None, 4]), neighbours.get_shape(),
                          pair_ids.get_shape()])
    return [tf.add(i, 1), m, p, s, l]

def expand_sharedlist(m, p, s):
    """Run one pass of shared-list expansion.

    Args:
      m: mask of cells that may still be claimed.
      p: shared list [None, 4] of ``[id_a, id_b, i, j]`` rows.
      s: current shared seed list, same layout as ``p``.

    Returns:
      ``(pnew, lnew)``: the updated shared list and the cells added in this
      pass (the next pass's seed list).
    """
    def _more(i, m, p, s, l):
        return tf.less(i, tf.shape(s)[0])

    i0 = tf.constant(0)
    l0 = tf.zeros([0, 4], s.dtype)
    results = tf.while_loop(
        _more, bi_shared, loop_vars=[i0, m, p, s, l0],
        shape_invariants=[i0.get_shape(), m.get_shape(),
                          tf.TensorShape([None, 4]),
                          s.get_shape(), tf.TensorShape([None, 4])])
    return results[2], results[4]

### Finalize cluster splitter

In [20]:
def finalize_splitter(m, p, o):
    """Assemble the final list of split clusters.

    Args:
      m: filtered image (cluster_splitter passes the image restricted to
        clustered cells).
      p: proto list [None, 3] of ``[id, i, j]`` rows.
      o: shared-cell list [None, 4] of ``[id_a, id_b, i, j]`` rows.

    Returns:
      `Tensor` [None, 4] with the dtype of ``m`` and rows
      ``[cluster, i, j, value]``.

      * Rows from ``p`` carry the cell signal.
      * Each shared cell contributes two rows, one per adjoining cluster,
        carrying the cell signal weighted by ``w1 = E1 / (E1 + r * E2)`` and
        ``w2 = 1 - w1``, with ``r = exp(d1 - d2)``. Here ``E`` is the summed
        signal of a cluster's proto cells and ``d`` is the distance of the
        shared cell to that cluster's signal-weighted centroid, divided by
        ``moliereRadius``.
      * Cluster ids are renumbered 1, 2, ... in descending order of ``E``.
    """
    # proto cells: ids, indices and signal values
    pid, pidx = tf.split(p, num_or_size_splits=[1,2], axis=1)
    pval = tf.gather_nd(m, pidx)
    dense_shape = tf.shape(m, out_type=tf.int64)
    # dense map of proto-cluster ids (0 = cell not in any proto-cluster)
    psp = tf.SparseTensor(pidx, tf.reshape(pid, [-1,]), dense_shape)
    pd = tf.sparse.to_dense(tf.sparse.reorder(psp))
    # rows [id, i, j, value], all cast to the dtype of m
    pnew = tf.concat([tf.cast(pid, pval.dtype), tf.cast(pidx, pval.dtype), tf.expand_dims(pval, axis=1)], axis=1)
    
    # distinct proto-cluster ids, in order of first appearance in p
    y = tf.unique(tf.reshape(pid, [-1,]))[0]
#     print("unique pid", y)
    
    # for each cluster y[i]: signal-weighted centroid com[i] (in cell-index
    # units) and summed signal E[i]
    i0, com0, E0 = tf.constant(0), tf.zeros([0,2], m.dtype), tf.zeros([0, 1], m.dtype)
    ci = lambda i, com, E: tf.less(i, tf.size(y))
    def bi(i, com, E):
        yi = tf.gather(y, i)
        indices = tf.where(tf.equal(pd, yi))
        values = tf.expand_dims(tf.gather_nd(m, indices), axis=1)
        indices = tf.cast(indices, values.dtype)
        comi = tf.reduce_sum(tf.multiply(indices, values), axis=0)
        Ei = tf.reduce_sum(values)
        comi = tf.divide(comi, Ei)
        comi = tf.expand_dims(comi, axis=0)
        Ei = tf.reshape(Ei, [-1,1])
        return [tf.add(i, 1), tf.concat([com, comi], axis=0), tf.concat([E, Ei], axis=0)]
    _, com, E = tf.while_loop(
        ci, bi, loop_vars=[i0, com0, E0],
        shape_invariants=[i0.get_shape(), tf.TensorShape([None,2]), tf.TensorShape([None, 1])])
#     print(com, E)
    
    # each shared cell o[j] = [a, b, i, j] is split between clusters a and b
    j0, l0 = tf.constant(0), tf.zeros([0, 4], m.dtype)
    cj = lambda j, l: tf.less(j, tf.shape(o)[0])
    def bj(j, l):
        oj = tf.gather(o, j)
        a, b, ojidx = tf.split(oj, num_or_size_splits=[1,1,2], axis=0)
        oval = tf.expand_dims(tf.gather_nd(m, ojidx), axis=0)
        ojidx = tf.cast(ojidx, com.dtype)
        # locate clusters a and b in y to fetch their centroids and energies
        amask = tf.equal(y, a)
        bmask = tf.equal(y, b)
        acom = tf.reshape(tf.boolean_mask(com, amask), [-1,])
        bcom = tf.reshape(tf.boolean_mask(com, bmask), [-1,])
        d1 = tf.sqrt(tf.reduce_sum(tf.math.squared_difference(ojidx, acom))) / moliereRadius
        d2 = tf.sqrt(tf.reduce_sum(tf.math.squared_difference(ojidx, bcom))) / moliereRadius
        # d1, d2: Euclidean cell-index distances divided by moliereRadius
        r = tf.exp(tf.subtract(d1, d2))
        E1 = tf.boolean_mask(E, amask)
        E2 = tf.boolean_mask(E, bmask)
        w1 = tf.reshape(tf.divide(E1, tf.add(E1, tf.multiply(r, E2))), [-1,])
        w2 = tf.subtract(1., w1)
        # one row per adjoining cluster, carrying the weighted cell signal
        la = tf.expand_dims(tf.concat([tf.cast(a, com.dtype), ojidx, tf.multiply(w1, oval)], axis=0), axis=0)
        lb = tf.expand_dims(tf.concat([tf.cast(b, com.dtype), ojidx, tf.multiply(w2, oval)], axis=0), axis=0)
        return [tf.add(j, 1), tf.concat([l, la, lb], axis=0)]
    _, lnew = tf.while_loop(
        cj, bj, loop_vars=[j0, l0],
        shape_invariants=[j0.get_shape(), tf.TensorShape([None, 4])])

    pnew = tf.concat([pnew, lnew], axis=0)

    # positions in y of the clusters sorted by descending energy E
    indices = tf.reshape(tf.cast(tf.math.top_k(tf.reshape(E, [1, -1]), k=tf.shape(E)[0])[1], o.dtype), [-1,])
#     print(y, indices)
    
    # relabel every row with the 1-based energy rank of its cluster
    k0, p0 = tf.constant(0), tf.zeros([0, 4], pnew.dtype)
    ck = lambda k, p: tf.less(k, tf.shape(pnew)[0])
    def bk(k, p):
        pk = tf.gather(pnew, k)
        pkid, pkidx = tf.split(pk, num_or_size_splits=[1,3], axis=0)
        # position of this row's cluster id in y
        yidx = tf.reshape(tf.where(tf.equal(y, tf.cast(pkid, y.dtype))), [-1,])
        # rank of that cluster in the energy ordering, shifted to start at 1
        pkid = tf.cast(tf.add(tf.where(tf.equal(indices, yidx)), 1), pkidx.dtype)
        pknew = tf.concat([pkid, tf.expand_dims(pkidx, axis=0)], axis=1)
        return [tf.add(k, 1), tf.concat([p, pknew], axis=0)]
    pnew = tf.while_loop(ck, bk, loop_vars=[k0, p0],
                        shape_invariants=[k0.get_shape(), tf.TensorShape([None, 4])])[1]
    
    return pnew

### Cluster splitter

In [1]:
def cluster_splitter(parsed, im):
    """Split the topological clusters in ``parsed['proto']`` at local maxima.

    Reads ``tlocmax`` and ``tnum`` from ``parsed`` and publishes them as
    module-level globals, which ``finding_local_maxima`` reads. Steps:

    1. Restrict image ``parsed[im]`` to the clustered cells.
    2. Find the local maxima and grow one new cluster from each. Cells that
       touch two new clusters go to a shared list.
    3. Expand the shared list into originally clustered cells that were not
       assigned to any new cluster.
    4. Keep parent clusters without a local maximum unchanged.
    5. Build the final cluster list with ``finalize_splitter``.

    Returns:
      A shallow copy of ``parsed`` with the result under ``'cluster'``. The
      input dict is not modified, matching ``cluster_maker``.
    """
    global tlocmax, tnum
    proto = parsed['proto']
    image = parsed[im]
    tlocmax = parsed['tlocmax']
    tnum = parsed['tnum']

    # filter image with protolist
    proto_id, proto_idx = tf.split(proto, [1, 2], axis=1)
    proto_val = tf.gather_nd(image, proto_idx)
    dense_shape = tf.shape(image, out_type=tf.int64)
    image_sp = tf.SparseTensor(proto_idx, proto_val, dense_shape)
    image_d = tf.sparse.to_dense(tf.sparse.reorder(image_sp))

    locmax = finding_local_maxima(image_d, proto)

    # create seed list from local maxima; the new ids start above the
    # largest parent id so that they never collide with parent ids
    locmax_idx = tf.split(locmax, [1, 2], axis=1)[1]
    proto_sp = tf.SparseTensor(proto_idx, tf.reshape(proto_id, [-1]), dense_shape)
    proto_d = tf.sparse.to_dense(tf.sparse.reorder(proto_sp))
    start = tf.add(tf.reduce_max(proto_id), 1)
    limit = start + tf.shape(locmax, out_type=proto_id.dtype)[0]
    locmax_id = tf.expand_dims(tf.range(start, limit, dtype=locmax.dtype), axis=1)
    seedlist = tf.concat([locmax_id, locmax_idx], axis=1)

    sharedlist = tf.zeros([0, 4], tf.int64)

    # finding neighbours recursively until current seed list is empty
    c = lambda pi, si, oi: tf.not_equal(tf.size(si), 0)
    b = lambda pi, si, oi: finding_neighbors_splitter(image_d, pi, si, oi)
    protolist, _, sharedlist = tf.while_loop(
        c, b, loop_vars=[seedlist, seedlist, sharedlist],
        shape_invariants=[tf.TensorShape([None, 3]),
                          tf.TensorShape([None, 3]),
                          tf.TensorShape([None, 4])])

    # originally clustered cells not in protolist
    protolist_id, protolist_idx = tf.split(protolist, [1, 2], axis=1)
    protolist_sp = tf.SparseTensor(protolist_idx, tf.reshape(protolist_id, [-1]), dense_shape)
    protolist_d = tf.sparse.to_dense(tf.sparse.reorder(protolist_sp))
    unassigned = tf.logical_and(tf.cast(proto_d, tf.bool),
                                tf.logical_not(tf.cast(protolist_d, tf.bool)))

    # expand the shared list into unassigned cells until nothing is added
    c = lambda pi, si: tf.not_equal(tf.size(si), 0)
    b = lambda pi, si: expand_sharedlist(unassigned, pi, si)
    sharedlist = tf.while_loop(
        c, b, loop_vars=[sharedlist, sharedlist],
        shape_invariants=[tf.TensorShape([None, 4]),
                          tf.TensorShape([None, 4])])[0]

    # add parent clusters without a local maximum
    assigned = tf.logical_and(tf.cast(proto_d, tf.bool), tf.cast(protolist_d, tf.bool))
    proto_y = tf.expand_dims(tf.unique(tf.reshape(proto_id, [-1]))[0], axis=0)
    proto_masked_y = tf.expand_dims(tf.unique(tf.boolean_mask(proto_d, assigned))[0], axis=0)
    other = tf.sparse.to_dense(tf.sets.difference(proto_y, proto_masked_y))
    other = tf.reshape(other, [-1, 1])

    i0 = tf.constant(0)
    c = lambda i, pd: tf.less(i, tf.shape(other)[0])

    def b(i, pd):
        # copy every cell of parent cluster other[i] into the dense map
        otheri = tf.gather(other, i)
        pd = tf.where(tf.equal(proto_d, otheri), proto_d, pd)
        return [tf.add(i, 1), pd]

    protolist_d = tf.while_loop(c, b, [i0, protolist_d])[1]
    pidx = tf.where(protolist_d)
    pid = tf.expand_dims(tf.gather_nd(protolist_d, pidx), axis=1)
    protolist = tf.concat([pid, pidx], axis=1)

    cluster = finalize_splitter(image_d, protolist, sharedlist)

    parsed = parsed.copy()
    parsed['cluster'] = cluster
    return parsed

## Discriminating variables

In [22]:
def scalar_features(parsed):
    """Compute per-cluster scalar features from ``parsed['cluster']``.

    ``parsed['cluster']`` holds rows ``[cluster_id, i, j, value]``. One row
    ``[com_i, com_j, S_sum, S_rad_mean, S_hot]`` is produced for each
    distinct cluster id, in order of first appearance:

    * ``com_i, com_j``: signal-weighted centroid of the cluster cells.
    * ``S_sum``: summed signal.
    * ``S_rad_mean``: signal-weighted mean distance of the cells to the
      centroid (0 when ``S_sum <= 0``).
    * ``S_hot``: largest cell signal divided by ``S_sum`` (0 when
      ``S_sum <= 0``).

    Returns:
      A shallow copy of ``parsed`` with the [None, 5] feature tensor under
      ``'feature'``.
    """
    cluster = parsed['cluster']

    cid, cidx, cval = tf.split(cluster, [1, 2, 1], axis=1)
    cid = tf.reshape(cid, [-1])
    y = tf.unique(cid)[0]

    i0 = tf.constant(0)
    z0 = tf.zeros([0, 5], cluster.dtype)
    c = lambda i, z: tf.less(i, tf.size(y))

    def bi(i, z):
        yi = tf.gather(y, i)
        rows = tf.where(tf.equal(cid, yi))
        values = tf.gather_nd(cval, rows)                       # [N, 1]
        indices = tf.reshape(tf.gather(cidx, rows), [-1, 2])    # [N, 2]
        total = tf.reduce_sum(values)                           # scalar
        # tf.cond needs a scalar predicate, and both branches must share the
        # same dtype and shape ([1, 1]) as the cluster values.
        positive = tf.greater(total, 0.)
        zero = tf.zeros([1, 1], values.dtype)

        com = tf.reduce_sum(tf.multiply(indices, values), axis=0, keepdims=True)
        com /= total

        signal_sum = tf.reshape(total, [1, 1])
        signal_hot = tf.cond(
            positive,
            lambda: tf.reduce_max(values, keepdims=True) / signal_sum,
            lambda: zero)
        signal_rad = tf.sqrt(tf.reduce_sum(tf.pow(indices - com, 2), axis=1, keepdims=True))
        signal_rad_mean = tf.cond(
            positive,
            lambda: tf.reduce_sum(tf.multiply(signal_rad, values), axis=0, keepdims=True) / total,
            lambda: zero)

        zi = tf.concat([com, signal_sum, signal_rad_mean, signal_hot], axis=1)
        return [tf.add(i, 1), tf.concat([z, zi], axis=0)]

    z = tf.while_loop(
        c, bi, loop_vars=[i0, z0],
        shape_invariants=[i0.get_shape(), tf.TensorShape([None, 5])])[1]

    parsed = parsed.copy()
    parsed['feature'] = z
    return parsed

In [3]:
def _float_feature(value):
    """Returns a float_list from a float / double, or a collection of them.

    ``value`` may be a Python scalar, a (nested) list, a NumPy array or an
    eager `Tensor` of any shape; it is flattened to a list of Python floats.
    Converting explicitly means a scalar no longer fails with "not
    iterable". It also avoids relying on the protobuf backend accepting
    tensor elements directly.
    """
    import numpy as np
    flat = np.asarray(value, dtype=np.float32).ravel().tolist()
    return tf.train.Feature(float_list=tf.train.FloatList(value=flat))

def _int64_feature(value):
    """Returns an int64_list holding the single bool / enum / int / uint ``value``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def serialize_example(f0,f1,f2,f3,f4,f5):
    """Create a serialized tf.Example message ready to be written to a file.

    Args:
      f0: event id, stored as an int64 feature under ``'eventId'``.
      f1, f2, f3, f4, f5: feature columns stored as float lists under
        ``'cluster_comi'``, ``'cluster_comj'``, ``'S_sum'``,
        ``'S_rad_mean'`` and ``'S_hot'`` respectively.

    Returns:
      The serialized ``tf.train.Example``.
    """
    float_names = ('cluster_comi', 'cluster_comj', 'S_sum', 'S_rad_mean', 'S_hot')
    feature = {'eventId': _int64_feature(f0)}
    for name, column in zip(float_names, (f1, f2, f3, f4, f5)):
        feature[name] = _float_feature(column)

    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()

def tf_serialize_example(example):
    """Serialize one example (``'eventId'`` plus a 5-column ``'feature'``) to a scalar tf.string."""
    event_id = example['eventId']
    columns = tf.split(example['feature'], num_or_size_splits=5, axis=1)
    serialized = tf.py_function(serialize_example,
                                [event_id] + list(columns),
                                tf.string)
    return tf.reshape(serialized, ())