In [1]:
import os
import xml.etree.ElementTree as ET
from copy import copy
from copy import copy

import community
import libtiff
import networkx as nx
import numpy as np
from matplotlib import cm
from matplotlib import pyplot as plt
from rtree import index
from scipy import ndimage as nd
from scipy import spatial
%matplotlib inline

In [2]:
# NOTE(review): module-level dict that nothing in this cell reads or writes;
# confirm whether a later cell still uses it, otherwise it can be removed.
Prox = {}

def spherical_trsf(pos):
    """Convert a Cartesian ``(x, y, z)`` triple to ``(r, teta, phi)``.

    r    -- Euclidean norm of ``pos``
    teta -- polar angle, ``arccos(z / r)``
    phi  -- ``arctan2(x, y)``; note the argument order (x first, then y),
            i.e. the angle is measured from the y axis.
    """
    x, y, z = pos
    radius = sum(c ** 2 for c in (x, y, z)) ** .5
    polar = np.arccos(z / radius)
    azimuth = np.arctan2(x, y)
    return radius, polar, azimuth

# Centre position at a few reference time points (key: time point).
# Every coordinate is multiplied by 4 — NOTE(review): presumably a conversion
# from a downsampled grid to the working units; confirm against acquisition settings.
center_position = {
    0: (400*4, 223*4, 129*4),
    50: (378*4, 239*4, 131*4),
    100: (406*4, 217*4, 129*4),
    150: (394*4, 236*4, 126*4),
    200: (377*4, 239*4, 131*4)
}

# Average centre over the reference time points. ``list(...)`` is required on
# Python 3, where ``dict.values()`` is a view numpy cannot turn into a 2-D array.
mean_center = np.mean(list(center_position.values()), axis=0)
# The y coordinate is overridden with a fixed value.
# NOTE(review): 1100. is a hand-picked constant; record where it comes from.
mean_center[1] = 1100.

class CellSS(object):
    """A single cell (node) of a lineage tree.

    Attributes:
        unique_id: id that is unique within the graph containing this cell.
        id: id taken from the input data.
        M: mother cell (a sentinel root cell or None for cells without one).
        time: time point of the cell.
        pos: position, stored as a numpy array.
        D: list of daughter cells (a fresh empty list when not given).
        N: list of spatial neighbour cells (a fresh empty list when not given).
    """
    def __init__(self, unique_id, id, M, time, pos, D=None, N=None):
        self.unique_id = unique_id
        self.id = id
        self.time = time
        self.M = M
        # Build new lists per instance so cells never share a default list.
        self.D = [] if D is None else D
        self.N = [] if N is None else N
        self.pos = np.array(pos)

class LineageTree(object):
    """Lineage tree of tracked cells, optionally with spatial-neighbour edges.

    Main attributes (set by ``read_from_xml`` or by the vector-field builders):
        nodes: list of CellSS.
        edges: list of (mother, daughter) CellSS pairs.
        time: dict mapping a time point to its cells — a dict ``id -> CellSS``
            for trees read from XML, a list of CellSS for vector-field trees.
        R: sentinel root cell; cells without a known mother point to it.
        spatial_edges: list of (CellSS, CellSS) pairs closer than a threshold,
            created by ``compute_spatial_edges``.
    """

    def _dist_v(self, v1, v2):
        """Return the Euclidean distance between two position vectors."""
        v1 = np.array(v1)
        v2 = np.array(v2)
        return np.sum((v1 - v2) ** 2) ** (.5)

    def copy_cell(self, C, links=None):
        """Append a shallow copy of cell ``C`` to ``self.nodes`` and return it.

        ``links`` is accepted for backward compatibility and is not used.
        """
        C_tmp = copy(C)
        # The original appended ``C`` itself, leaving the copy unused.
        self.nodes.append(C_tmp)
        return C_tmp

    def to_tlp(self, fname, t_min=-1, t_max=np.inf, temporal=True, spatial=False, VF=False):
        """Write the lineage tree to a Tulip (``.tlp``) file.

        Args:
            fname: path of the Tulip file to create (overwritten if it exists).
            t_min: only nodes with ``t_min < time`` are written.
            t_max: only nodes with ``time <= t_max`` are written.
            temporal: if True, write the mother->daughter edges (``self.edges``).
            spatial: if True, write the spatial-neighbour edges
                (``self.spatial_edges``, which must already be computed).
            VF: unused; kept for backward compatibility.

        Written properties: ``id`` and ``time`` (int, per node), ``viewLayout``
        (per-node position) and ``distance`` (double: the length of each edge,
        which is also assigned to the edge's source node).
        """
        if t_max != np.inf or t_min > -1:
            nodes_to_use = [n for n in self.nodes if t_min < n.time <= t_max]
            edges_to_use = []
            if temporal:
                edges_to_use += [e for e in self.edges if t_min < e[0].time < t_max]
            if spatial:
                # Both ends of a spatial edge are at the same time point, so use the
                # node bounds; ``< t_max`` dropped edges between nodes written at t_max.
                edges_to_use += [e for e in self.spatial_edges if t_min < e[0].time <= t_max]
        else:
            nodes_to_use = self.nodes
            edges_to_use = []
            if temporal:
                edges_to_use += self.edges
            if spatial:
                edges_to_use += self.spatial_edges

        # ``with`` guarantees the file is closed even if a write fails.
        with open(fname, "w") as f:
            f.write("(tlp \"2.0\"\n")
            f.write("(nodes ")
            for n in nodes_to_use:
                f.write(str(n.unique_id) + " ")
            f.write(")\n")

            for i, e in enumerate(edges_to_use):
                f.write("(edge " + str(i) + " " + str(e[0].unique_id) + " " + str(e[1].unique_id) + ")\n")

            f.write("(property 0 int \"id\"\n")
            f.write("\t(default \"0\" \"0\")\n")
            for n in nodes_to_use:
                f.write("\t(node " + str(n.unique_id) + " \"" + str(n.id) + "\")\n")
            f.write(")\n")

            f.write("(property 0 int \"time\"\n")
            f.write("\t(default \"0\" \"0\")\n")
            for n in nodes_to_use:
                f.write("\t(node " + str(n.unique_id) + " \"" + str(n.time) + "\")\n")
            f.write(")\n")

            f.write("(property 0 layout \"viewLayout\"\n")
            f.write("\t(default \"(0, 0, 0)\" \"()\")\n")
            for n in nodes_to_use:
                # tolist() yields plain Python numbers: with NumPy >= 2, str() of a
                # tuple of numpy scalars would write "np.float64(...)" into the file.
                pos_str = str(tuple(np.asarray(n.pos).tolist()))
                f.write("\t(node " + str(n.unique_id) + " \"" + pos_str + "\")\n")
            f.write(")\n")

            f.write("(property 0 double \"distance\"\n")
            f.write("\t(default \"0\" \"0\")\n")
            for i, e in enumerate(edges_to_use):
                d_tmp = self._dist_v(e[0].pos, e[1].pos)
                f.write("\t(edge " + str(i) + " \"" + str(d_tmp) + "\")\n")
                f.write("\t(node " + str(e[0].unique_id) + " \"" + str(d_tmp) + "\")\n")
            f.write(")\n")

            f.write(")")

    def median_average(self, subset):
        """Return the medoid of the forward displacement vectors of ``subset``.

        For every cell with at least one daughter the displacement is
        ``mean(daughter positions) - cell position``; the returned vector is the
        displacement minimising the summed Euclidean distance to all the others.
        Returns ``[0, 0, 0]`` when no cell of ``subset`` has a daughter.
        """
        subset_dist = [np.mean([di.pos for di in c.D], axis=0) - c.pos
                       for c in subset if c.D]
        if not subset_dist:
            return [0, 0, 0]
        med_distance = spatial.distance.squareform(spatial.distance.pdist(subset_dist))
        return subset_dist[np.argmin(np.sum(med_distance, axis=0))]

    def median_average_bw(self, subset):
        """Return the medoid of the backward displacement vectors of ``subset``.

        For every cell whose mother is not the sentinel root ``self.R`` the
        displacement is ``mother position - cell position``. Returns
        ``[0, 0, 0]`` when no such cell exists.
        """
        subset_dist = [c.M.pos - c.pos for c in subset if c.M is not self.R]
        if not subset_dist:
            return [0, 0, 0]
        med_distance = spatial.distance.squareform(spatial.distance.pdist(subset_dist))
        return subset_dist[np.argmin(np.sum(med_distance, axis=0))]

    def build_median_vector(self, C, dist_th, delta_t=2):
        """Return the medoid displacement around cell ``C``.

        The neighbourhood is ``C`` plus its spatial neighbours ``C.N``, extended
        ``delta_t`` generations towards daughters and towards mothers. Spatial
        edges are computed with threshold ``dist_th`` if not done yet.
        """
        if not hasattr(self, 'spatial_edges'):
            self.compute_spatial_edges(dist_th)
        root = getattr(self, 'R', None)
        subset = [C] + list(C.N)
        added_D = added_M = list(subset)
        for _ in range(delta_t):
            added_D = [d for c in added_D for d in c.D]
            # Stop at the sentinel root: it is not a real cell and its "daughters"
            # are every mother-less cell, which would inject a bogus vector.
            added_M = [c.M for c in added_M if c.M is not None and c.M is not root]
            subset += added_M
            subset += added_D
        return self.median_average(subset)

    def build_vector_field(self, dist_th=50):
        """Set ``C.direction`` for every node to its local medoid displacement.

        Prints each new time point encountered while walking ``self.nodes``.
        """
        ruler = 0
        for C in self.nodes:
            if ruler != C.time:
                print(C.time)
            C.direction = self.build_median_vector(C, dist_th)
            ruler = C.time

    def read_from_xml(self, file_format, tb, te, z_mult=1.):
        """Build the tree from one XML file per time point.

        Args:
            file_format: path template; ``$TIME$`` is replaced by the time point
                zero-padded to 4 digits.
            tb, te: first and last time points to read (inclusive).
            z_mult: factor applied to the last coordinate of every position.

        Each child element of the XML root must carry ``parent``, ``m``
        (space-separated coordinates) and ``id`` attributes. A cell whose
        ``parent`` id exists at the previous time point is linked to it;
        otherwise it hangs from the sentinel root ``self.R``.
        """
        self.time = {}
        self.time_edges = {}
        unique_id = 0
        self.R = CellSS(-1, -1, None, -1, [-1] * 3)
        self.nodes = []
        self.edges = []
        for t in range(tb, te + 1):
            tree = ET.parse(file_format.replace('$TIME$', '%04d' % t))
            previous = self.time.get(t - 1, {})
            cells_t = {}
            self.time[t] = cells_t
            self.time_edges[t] = []
            # Iterating the element replaces getchildren(), removed in Python 3.9.
            for it in tree.getroot():
                M_id = int(it.attrib['parent'])
                pos = [float(v) for v in it.attrib['m'].split(' ') if v != '']
                cell_id = int(it.attrib['id'])
                pos[-1] = pos[-1] * z_mult
                if M_id in previous:
                    M = previous[M_id]
                    C = CellSS(unique_id, cell_id, M, t, pos)
                    M.D.append(C)
                    self.edges.append((M, C))
                    self.time_edges[t].append((M, C))
                else:
                    C = CellSS(unique_id, cell_id, self.R, t, pos)
                    self.R.D.append(C)
                self.nodes.append(C)
                cells_t[cell_id] = C
                unique_id += 1
        self.max_id = unique_id - 1

    def _propagate_VF(self, t_b, t_e, nb_max, dist_max, step, averager):
        """Shared core of the forward (step=+1) and backward (step=-1) VF builders.

        Seeds one VF node per cell of ``self`` at ``t_b``, then for each time
        point in ``range(t_b, t_e, step)`` moves every VF node by
        ``averager(cells)``, where ``cells`` are the (at most ``nb_max``) cells
        of ``self`` nearest to it and within ``dist_max``.
        """
        VF = LineageTree(None, None, None)
        VF.nodes = []
        VF.edges = []
        VF.R = CellSS(-1, -1, None, -1, None)
        VF.time = {t_b: []}
        unique_id = 0
        for C in list(self.time[t_b].values()):
            C_tmp = CellSS(unique_id=unique_id, id=unique_id, M=VF.R, time=t_b, pos=C.pos)
            VF.nodes.append(C_tmp)
            VF.time[t_b].append(C_tmp)
            unique_id += 1

        for t in range(t_b, t_e, step):
            p = index.Property()
            p.dimension = 3
            idx3d = index.Index(properties=p)
            to_check_self = list(self.time[t].values())
            for i, C in enumerate(to_check_self):
                idx3d.add(i, tuple(C.pos))
            # Loop-invariant: build the object array once per time point.
            cells_array = np.array(to_check_self)
            print(t)
            t_next = t + step
            VF.time[t_next] = []
            for C in VF.time[t]:
                closest_cells = cells_array[list(idx3d.nearest(tuple(C.pos), nb_max))]
                # Assumes nearest() returns cells by increasing distance: keep the
                # prefix up to the first one farther than dist_max (sentinel appended).
                distances = [self._dist_v(C.pos, ci.pos) for ci in closest_cells] + [dist_max + 1]
                max_value = np.min(np.where(np.array(distances) > dist_max))
                cells_to_keep = closest_cells[:max_value]
                med = averager(cells_to_keep)
                C_next = CellSS(unique_id, unique_id, M=C, time=t_next, pos=C.pos + med)
                VF.time[t_next].append(C_next)
                C.D.append(C_next)
                VF.edges.append((C, C_next))
                VF.nodes.append(C_next)
                unique_id += 1
        VF.t_b = t_b
        VF.t_e = t_e
        return VF

    def build_VF_propagation(self, t_b=0, t_e=200, nb_max=20, dist_max=200):
        """Build a forward vector-field tree from ``t_b`` up to ``t_e``.

        Each step moves every VF node by the medoid forward displacement
        (``median_average``) of the nearby cells of ``self``.
        """
        return self._propagate_VF(t_b, t_e, nb_max, dist_max, 1, self.median_average)

    def build_VF_propagation_backward(self, t_b=0, t_e=200, nb_max=20, dist_max=200):
        """Build a backward vector-field tree from ``t_b`` down to ``t_e``.

        Each step moves every VF node by the medoid backward displacement
        (``median_average_bw``) of the nearby cells of ``self``. When
        ``t_b <= t_e`` no step is taken and only the seed nodes are returned.
        """
        return self._propagate_VF(t_b, t_e, nb_max, dist_max, -1, self.median_average_bw)

    def compute_spatial_edges(self, th=50):
        """Link every pair of cells of a same time point closer than ``th``.

        Each close pair is stored in both orders in ``self.spatial_edges`` and
        each cell gets the other appended to its ``N`` list. Pairs at distance
        exactly 0 are not linked; time points without cells are skipped.
        """
        self.spatial_edges = []
        for Cs in self.time.values():
            cells = list(Cs.values()) if isinstance(Cs, dict) else list(Cs)
            if not cells:
                continue  # zip(*[]) would raise on an empty time point
            nodes_tmp = np.array(cells, dtype=object)
            distances = spatial.distance.squareform(spatial.distance.pdist([C.pos for C in cells]))
            rows, cols = np.where((0 < distances) & (distances < th))
            # list(): on Python 3 zip is a one-shot iterator and is used twice below.
            to_link = list(zip(nodes_tmp[rows], nodes_tmp[cols]))
            self.spatial_edges.extend(to_link)
            for C1, C2 in to_link:
                C1.N.append(C2)

    def __init__(self, file_format, tb, te, z_mult=.1):
        """Create a tree, reading it from XML when file_format, tb and te are all given.

        ``LineageTree(None, None, None)`` creates an empty shell whose attributes
        are filled in by the caller (as the vector-field builders do).
        """
        super(LineageTree, self).__init__()

        if not (file_format is None or tb is None or te is None):
            self.read_from_xml(file_format, tb, te, z_mult=z_mult)
            self.t_b = tb
            self.t_e = te


def reduce_graph(VF, t_b=None, t_e=None, reduction_scale=5):
    """Subsample a vector-field tree, keeping one time point every ``reduction_scale``.

    Args:
        VF: a LineageTree whose ``time`` maps time points to lists of cells
            (as built by ``LineageTree.build_VF_propagation``).
        t_b, t_e: time range to reduce; default to ``VF.t_b`` / ``VF.t_e``.
        reduction_scale: step between kept time points (negative for a
            backward tree). Its absolute value is the number of first-daughter
            hops followed in ``VF`` between two kept nodes.

    Returns:
        A new LineageTree with one chain per starting cell. From each kept node
        the first daughter is followed up to ``abs(reduction_scale)`` times,
        stopping early at a cell without daughters. ``VF`` is only read.
        The result carries ``t_b``/``t_e`` so it can itself be reduced.
    """
    if t_b is None:
        t_b = VF.t_b
    if t_e is None:
        t_e = VF.t_e

    VF_reduced = LineageTree(None, None, None)
    VF_reduced.nodes = []
    VF_reduced.edges = []
    VF_reduced.R = CellSS(-1, -1, None, -1, None)
    VF_reduced.time = {t_b: []}
    unique_id = 0
    for C in VF.time[t_b]:
        C_tmp = CellSS(unique_id=unique_id, id=unique_id, M=VF_reduced.R, D=copy(C.D), time=t_b, pos=C.pos)
        VF_reduced.nodes.append(C_tmp)
        VF_reduced.time[t_b].append(C_tmp)
        unique_id += 1

    nb_hops = abs(reduction_scale)
    for t in range(t_b, t_e, reduction_scale):
        t_next = t + reduction_scale
        VF_reduced.time[t_next] = []
        for C in VF_reduced.time[t]:
            # The walk only reads ``.D`` and ``.pos``, so no per-hop copy is needed.
            current = C
            for _ in range(nb_hops):
                if current.D:
                    current = current.D[0]
            C_to_add = CellSS(unique_id, unique_id, M=C, D=current.D, time=t_next, pos=current.pos)
            VF_reduced.time[t_next].append(C_to_add)
            C.D = [C_to_add]
            VF_reduced.edges.append((C, C_to_add))
            VF_reduced.nodes.append(C_to_add)
            unique_id += 1
    VF_reduced.t_b = t_b
    VF_reduced.t_e = t_e
    return VF_reduced

def dist_v(v1, v2):
    """Euclidean distance between two position vectors (any array-likes)."""
    delta = np.asarray(v1) - np.asarray(v2)
    return np.sum(delta ** 2) ** .5

def conn_G_to_tlp(fname, b_g):
    """Write a connectivity graph to a Tulip (``.tlp``) file.

    Args:
        fname: path of the file to create (overwritten if it exists).
        b_g: dict mapping a node to an iterable of the nodes it is linked to.
            Nodes need ``unique_id``, ``id`` and ``pos`` attributes; an optional
            ``P`` attribute is written as an int property.

    Only the keys of ``b_g`` are declared as nodes. NOTE(review): confirm that
    callers make every linked node a key too, otherwise edges reference
    undeclared nodes.
    """
    nodes_to_use = list(b_g.keys())
    tmp_edges = [(k, vi) for k, v in b_g.items() for vi in v]

    # ``with`` guarantees the file is closed even if a write fails.
    with open(fname, "w") as f:
        f.write("(tlp \"2.0\"\n")

        f.write("(nodes ")
        for n in nodes_to_use:
            f.write(str(n.unique_id) + " ")
        f.write(")\n")

        for i, e in enumerate(tmp_edges):
            f.write("(edge " + str(i) + " " + str(e[0].unique_id) + " " + str(e[1].unique_id) + ")\n")

        f.write("(property 0 int \"id\"\n")
        f.write("\t(default \"0\" \"0\")\n")
        for n in nodes_to_use:
            f.write("\t(node " + str(n.unique_id) + " \"" + str(n.id) + "\")\n")
        f.write(")\n")

        f.write("(property 0 layout \"viewLayout\"\n")
        f.write("\t(default \"(0, 0, 0)\" \"()\")\n")
        for n in nodes_to_use:
            # tolist() yields plain Python numbers: with NumPy >= 2, str() of a
            # tuple of numpy scalars would write "np.float64(...)" into the file.
            pos_str = str(tuple(np.asarray(n.pos).tolist()))
            f.write("\t(node " + str(n.unique_id) + " \"" + pos_str + "\")\n")
        f.write(")\n")

        f.write("(property 0 double \"distance\"\n")
        f.write("\t(default \"0\" \"0\")\n")
        for i, e in enumerate(tmp_edges):
            f.write("\t(edge " + str(i) + " \"" + str(dist_v(e[0].pos, e[1].pos)) + "\")\n")
        f.write(")\n")

        f.write("(property 0 int \"P\"\n")
        f.write("\t(default \"0\" \"0\")\n")
        for n in nodes_to_use:
            if hasattr(n, 'P'):
                f.write("\t(node " + str(n.unique_id) + " \"" + str(n.P) + "\")\n")
        f.write(")\n")

        f.write(")")

In [3]:
# Raw string: in a plain literal '\U' starts a unicode escape, which is a
# SyntaxError on Python 3.
path_to_files = r'D:\Users\Leo\FernandoTGMMRuns\GMEMtracking3D_2015_3_26_15_14_22_Zebrafish_12_09_24_trainCDWT_v2_iter2\XML_finalResult_lht'
# '$TIME$' is replaced by the zero-padded time point when each frame is read.
xml_file_format = os.path.join(path_to_files, 'GMEMfinalResult_frame$TIME$.xml')

# Frames 0..730 (inclusive); z_mult rescales the last coordinate of every position.
# NOTE(review): the recorded run of this cell was interrupted (KeyboardInterrupt);
# consider caching the parsed tree so Restart & Run All stays practical.
LT = LineageTree(file_format=xml_file_format, tb=0, te=730, z_mult=8.)



KeyboardInterrupt: 