# Tracking cell division in Drosophila pupal wing  
2022/04/18 Created by Ryu Takayanagi (4th year, Department of Bioinformatics and Systems Biology, Faculty of Science, Univ. Tokyo)
2023/02/27 Add comments and minor modifications by Kaoru Sugimura 
2023/04 Add minor modifications by Kaoru Sugimura 
2023/08 Add modifications and extensions by Joseph Schull (2nd Year, Data Science, University of California, Berkeley)

### Environment (RT)
WSL2(Ubuntu 20.04.2)  
python(3.8.3)  
numpy(1.21.5)  
pandas(1.2.4)  
matplotlib(3.3.4 & brew install imagemagick)  
plotly(5.2.1)  
graphviz(3.0.0 & brew install graphviz)  
scipy(1.8.0)  
tqdm(4.62.3)

### Environment (KS)
mac (OS12.5.1)

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.colors as mcolors
import matplotlib
import glob, os
import plotly.graph_objects as go
from matplotlib.patches import Polygon
from matplotlib.collections import PolyCollection, LineCollection
from collections import deque, Counter
from graphviz import Digraph
from IPython.display import Image
from scipy.stats import gaussian_kde
from tqdm.auto import tqdm

from sklearn.cluster import KMeans
import math

import warnings

# Hide FutureWarning messages globally to keep notebook output readable.
# NOTE(review): this also hides library deprecation notices; consider narrowing it.
warnings.simplefilter(action="ignore", category=FutureWarning)

class TrackingAnalyzer:
    def __init__(self, resource_dir, result_dir, XMIN, XMAX, YMIN, YMAX):
        
        if os.path.isfile(resource_dir+"Tracking_index.csv"):  
            self.Tracking_index = pd.read_csv(resource_dir+"Tracking_index.csv", header=None)
        else:
            raise FileNotFoundError("Tracking_index.csv is not found.")
        
        if os.path.isfile(resource_dir+"Tracking_division_pair.csv"):
            self.Tracking_division_pair = pd.read_csv(resource_dir+"Tracking_division_pair.csv",
                                                      names=["time", "before cell", "after cell 1", "after cell 2"])
            # Comment (KS) skiprows=1 should be removed? -> Done
        else:
            raise FileNotFoundError("Tracking_division_pair.csv is not found.")
        
        if os.path.isfile(resource_dir+"ori_pos_x.csv"):
            self.ori_pos_x = pd.read_csv(resource_dir+"ori_pos_x.csv", header=None)
        else:
            raise FileNotFoundError("ori_pos_x.csv is not found.")
            
        if os.path.isfile(resource_dir+"ori_pos_y.csv"):
            self.ori_pos_y = pd.read_csv(resource_dir+"ori_pos_y.csv", header=None)
        else:
            raise FileNotFoundError("ori_pos_y.csv is not found.")
        
        if os.path.isfile(resource_dir+"veinid_.txt"):
            self.veintxtfilename = resource_dir+"veinid_.txt"
        else:
            raise FileNotFoundError("veintxtfile is not found.")
        
        if os.path.isdir(resource_dir+"txtfile/"):
            self.txtfilename = glob.glob(resource_dir+"txtfile/"+"*_0000.txt")[0].split("_0000.txt")[0]
        else:
            raise FileNotFoundError("txtfile directory is not found.")
        
        if os.path.isdir(resource_dir+"force_Data/"):
            self.TPfilename = glob.glob(resource_dir+"force_Data/"+"*_0000.txt")[0].split("_0000.txt")[0]
        else:
            raise FileNotFoundError("force_Data directory is not found.")
        
        if not os.path.isdir(result_dir):
            os.makedirs(result_dir[:-1])
        
        self.RESOURCE_DIR = resource_dir
        self.RESULT_DIR = result_dir
        self.N_FRAMES = self.Tracking_index.shape[1]
        self.N_TRACKS = len(self.Tracking_index)
        self.XMIN = XMIN
        self.YMIN = YMIN
        self.XMAX = XMAX
        self.YMAX = YMAX
        self.wingroi_tracks = []
        self.calc_inarea = False
        self.calc_forest = False
        self.singledivision = []
        self.doubledivision = []
    
    
    def get_skeltonized_data(self, filepath, exclude_interior = True): #(modified by JS to include exclude_interior boolean, determines whether all cells or just exterior ones are included)
        '''
        Get the x, y coordinates of vertices, Get 2 vertices for each junction, Get n vertices for each n-polygon
        Ext -> External Vertices, Replace np.nan etc. for data of external vertices
        頂点のx座標, y座標、辺に対応する頂点の組, 細胞に対応する頂点の組を取得
        Extの頂点データはnp.nanなどに置き換え、辺・細胞は無視する
        '''    
        
        with open(filepath) as f:
            vertex_x = []
            vertex_y = []
            is_ext_vertex = []
            edges = []
            is_ext_edges = []
            cells = []
            is_ext_cells = []
            l = f.readline()
            while l:
                if l[0] == "V": #Comment(KS) I edited input txt files.
                    spl = l.split()
                    vertex_x.append(float(spl[1]))
                    vertex_y.append(float(spl[2]))         # X-corrdinates,  Left to Right, Y-corrdinates,  Top to Bottom 
                    is_ext_vertex.append(spl[-1] == "Ext")
                elif l[0] == "E":
                    spl = l.split()
                    edges.append([int(spl[1]), int(spl[2])])
                    is_ext_edges.append(spl[-1] == "Ext")
                elif l[0] == "C":
                    spl = l.split()
                    if exclude_interior and spl[-1] == "Ext":
                        cells.append(list(map(int, spl[3:-1]))) # Remove the last component, "Ext" 
                        is_ext_cells.append(True)
                    elif not exclude_interior:
                        cells.append(list(map(int, spl[3:-1]))) # Remove the last component, "Ext"
                    else:
                        cells.append(list(map(int, spl[3:])))
                        is_ext_cells.append(False)
                l = f.readline()
            return np.array(vertex_x), np.array(vertex_y), is_ext_vertex, edges, np.array(is_ext_edges), cells, np.array(is_ext_cells)


    def get_vein_data(self, filepath):
        '''
        Get the id and the x, y coordinates of center of vein cells at 32 h APF.
        '''    
        
        with open(filepath) as f:
            vid32 = []
            vcenter_x32 = []
            vcenter_y32 = []
            l = f.readline()
            while l:
                spl = l.split()
                vid32.append(int(spl[0])) 
                vcenter_x32.append(float(spl[1]))
                vcenter_y32.append(float(spl[2]))         # X-corrdinates,  Left to Right, Y-corrdinates,  Top to Bottom 
                l = f.readline()
            return vid32, np.array(vcenter_x32), np.array(vcenter_y32)

        
    def show_vein(self, y_offset=2000, frame=198):
        '''
        Draw every cell of one frame, highlighting the manually labelled vein cells, to check
        the quality of the labelling.

        y_offset : added to every y coordinate of the skeleton data before drawing.
        frame    : frame whose skeleton is drawn (default 198, i.e. 32 h APF, as before).
                   The image is saved to RESULT_DIR + "vein-T<frame+1>.png".

        Vein cells (ids listed in veinid_.txt) get colour value 1, all other cells 0.2.
        '''
        fig = plt.figure(figsize=(12,8))
        ax = fig.add_subplot(111)

        coll = PolyCollection([], cmap=matplotlib.cm.Blues, edgecolors='none')
        coll.set(clim=(0, 1))
        ax.add_collection(coll)
        ax.set_xlim(self.XMIN, self.XMAX)
        ax.set_ylim(self.YMIN, self.YMAX)

        vx, vy, _, _, _, cells, _ = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
        vid32, _, _ = self.get_vein_data(self.veintxtfilename)
        vein_ids = set(vid32)  # O(1) membership test instead of scanning the list per cell

        coll.set_verts([tuple(zip(vx[c], vy[c] + y_offset)) for c in cells])
        # Vein ids are compared against the position of the cell in `cells`.
        coll.set_array([1 if i in vein_ids else 0.2 for i in range(len(cells))])
        plt.savefig(self.RESULT_DIR+"vein-T"+str(frame+1)+".png", dpi=600)

        
    def get_TP_data(self, filepath):
        with open(filepath) as f:
            mu = 0
            T = []
            P = []
            l = f.readline()
            while l:
                spl = l.strip().split()
                if len(spl) >= 1:
                    if spl[1] == "mu=":
                        mu = float(spl[2])
                    elif spl[0].isdigit():
                        T.append(float(spl[1]))
                    elif spl[0] == "-" and spl[1] == "-":
                        P.append(float(spl[3]))
                l = f.readline()
            return mu, np.array(T), np.array(P)
   
    def calc_stress_tensor(self):
        '''
        Compute and save, for every frame, the Batchelor stress tensor and the cell shape
        tensor of each closed (non-"Ext") cell.

        For a cell with area A and pressure P, summing over its junction vectors l with tension t:
            stress = sum(t * outer(l, l) / |l|) / A - P * I
            shape  = sum(    outer(l, l) / |l|) / A
        Results are saved as RESULT_DIR/stress_tensor/<frame>.npy and
        RESULT_DIR/cellshape_tensor/<frame>.npy, arrays of shape (n_closed_cells, 2, 2) ordered
        like the non-"Ext" cells of that frame. Existing files are overwritten; the computation
        is always redone.
        NOTE(review): the original author noted this implementation may need validation.

        Raises KeyError if a cell references a pair of vertices that is not a listed junction.
        '''
        os.makedirs(self.RESULT_DIR+"stress_tensor", exist_ok=True)
        os.makedirs(self.RESULT_DIR+"cellshape_tensor", exist_ok=True)

        def edgefind(edgedict, x1, x2):
            # Junction index for the vertex pair; junctions may be stored in either vertex order.
            e = edgedict.get((x1, x2))
            if e is None:
                e = edgedict.get((x2, x1))
            if e is None:
                # Fail loudly: T[None] would silently add an axis instead of selecting a tension.
                raise KeyError("no junction between vertices {} and {}".format(x1, x2))
            return e

        def cell2edgevecs_and_tensions(cellnum, cells, edgedict, vx, vy, T):
            # (junction vector, tension) for each consecutive vertex pair of the cell polygon.
            c = cells[cellnum]
            n = len(c)
            return [(np.array([vx[c[(k+1) % n]]-vx[c[k]], vy[c[(k+1) % n]]-vy[c[k]]]),
                     T[edgefind(edgedict, c[k], c[(k+1) % n])]) for k in range(n)]

        def cell_area(cellnum, cells, vx, vy):
            # Shoelace formula over the cell's vertices (absolute value, so orientation-free).
            PV = [(vx[i], vy[i]) for i in cells[cellnum]]
            return abs(sum(PV[i][0]*PV[i-1][1] - PV[i][1]*PV[i-1][0] for i in range(len(cells[cellnum])))) / 2

        for frame in tqdm(range(self.N_FRAMES)):
            vx, vy, _, edges, _, cells, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
            _, T, P = self.get_TP_data(self.TPfilename+"_{:0>4}.txt".format(frame))

            edgedict = {tuple(edge): i for i, edge in enumerate(edges)}

            # Only closed cells (not flagged "Ext") have a well-defined area and pressure.
            cells_in = np.where(np.logical_not(c_ext))[0]

            stress = np.zeros((len(cells_in), 2, 2))
            shape = np.zeros((len(cells_in), 2, 2))
            for i in tqdm(range(len(cells_in)), leave=False):
                c = cells_in[i]
                # Edge vectors and the area are shared by both tensors, so compute them once.
                for l, t in cell2edgevecs_and_tensions(c, cells, edgedict, vx, vy, T):
                    ll = np.outer(l, l)/np.sqrt(l@l)
                    stress[i] += ll*t
                    shape[i] += ll
                area = cell_area(c, cells, vx, vy)
                stress[i] /= area
                shape[i] /= area
                stress[i][0, 0] -= P[c]
                stress[i][1, 1] -= P[c]
            np.save(self.RESULT_DIR+"stress_tensor/"+str(frame), stress)
            np.save(self.RESULT_DIR+"cellshape_tensor/"+str(frame), shape)


    def animate_pressure(self, title, y_offset=2000, tend=199):
        '''
        Save RESULT_DIR/pressure.gif: closed (non-"Ext") cells coloured by the inferred
        pressure P (cool colormap, colour range -0.05 to 0.1) for frames 0 .. tend-1.

        title    : drawn as the axes title.
        y_offset : added to every y coordinate; the skeleton text files store y values that
                   are lower by this amount.
        tend     : number of frames to animate (default 199, as before).
        '''
        fig = plt.figure(figsize=(12,8))
        ax = fig.add_subplot(111)
        ax.set_title(title)

        coll = PolyCollection([], cmap=matplotlib.cm.cool, edgecolors='none')
        coll.set(clim=(-0.05, 0.1))
        ax.add_collection(coll)
        # Give the colorbar its Axes explicitly: newer matplotlib refuses to guess it for a
        # mappable, and the collection gets an Axes only through add_collection.
        fig.colorbar(coll, ax=ax)
        ax.set_xlim(self.XMIN, self.XMAX)
        ax.set_ylim(self.YMIN, self.YMAX)

        def update(frame, pbar):
            vx, vy, _, _, _, cells, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
            _, _, P = self.get_TP_data(self.TPfilename+"_{:0>4}.txt".format(frame))
            closed = np.logical_not(c_ext)
            # Update data only; the axes and colour scale stay fixed.
            coll.set_verts([tuple(zip(vx[cells[i]], vy[cells[i]] + y_offset)) for i in np.where(closed)[0]])
            coll.set_array(P[closed])
            pbar.update()

        pbar = tqdm(total=tend)
        try:
            anim = animation.FuncAnimation(fig, update, frames=range(tend), fargs=(pbar,), interval=10)
            anim.save(self.RESULT_DIR+"pressure.gif", writer="imagemagick", dpi=150)
        finally:
            pbar.close()

        
    def animate_tension(self, title, y_offset=2000, tend=199):
        '''
        Save RESULT_DIR/tension.gif: non-"Ext" junctions coloured by the inferred tension T
        (jet colormap, colour range 0.4 to 1.6) for frames 0 .. tend-1.

        title    : drawn as the axes title.
        y_offset : added to every y coordinate of the skeleton data before drawing.
        tend     : number of frames to animate (default 199, as before).
        '''
        fig = plt.figure(figsize=(12,8))
        ax = fig.add_subplot(111)
        ax.set_title(title)

        coll = LineCollection([], cmap=matplotlib.cm.jet, lw=0.5)
        coll.set(clim=(0.4, 1.6))
        ax.add_collection(coll)
        # Give the colorbar its Axes explicitly: newer matplotlib refuses to guess it for a
        # mappable, and the collection gets an Axes only through add_collection.
        fig.colorbar(coll, ax=ax)
        ax.set_xlim(self.XMIN, self.XMAX)
        ax.set_ylim(self.YMIN, self.YMAX)

        def update(frame, pbar):
            vx, vy, _, edges, e_ext, _, _ = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
            _, T, _ = self.get_TP_data(self.TPfilename+"_{:0>4}.txt".format(frame))
            inner = np.logical_not(e_ext)
            coll.set_segments([tuple(zip(vx[edges[i]], vy[edges[i]] + y_offset)) for i in np.where(inner)[0]])
            coll.set_array(T[inner])
            pbar.update()

        pbar = tqdm(total=tend)
        try:
            anim = animation.FuncAnimation(fig, update, frames=range(tend), fargs=(pbar,), interval=10)
            anim.save(self.RESULT_DIR+"tension.gif", writer="imagemagick", dpi=150)
        finally:
            pbar.close()
  

    def animate_local_cell_stress(self, title, y_offset=2000, scale=50, tend=199):
        '''
        Save RESULT_DIR/local_cell_stress_<scale>.gif: closed cells outlined in black, each with
        a red bar along the eigenvector of the larger eigenvalue of its stress tensor.

        Each bar is centred on the mean of the cell's vertices and spans
        centre +/- eigenvector * eigenvalue * scale (eigenvectors from np.linalg.eig have unit length).
        Reads RESULT_DIR/stress_tensor/<frame>.npy written by calc_stress_tensor().

        title : drawn as the axes title.
        tend  : number of frames to animate (default 199, as before).

        Raises Exception if the stress_tensor directory does not exist.
        '''
        if not os.path.isdir(self.RESULT_DIR+"stress_tensor/"):
            raise Exception("calculate stress tensor before animate")

        fig = plt.figure(figsize=(12,8))
        ax = fig.add_subplot(111)
        ax.set_title(title)

        coll = PolyCollection([], color="white", lw=0.1, ec="black")
        ax.add_collection(coll)
        coll2 = LineCollection([], lw = 0.2, ec="red")
        ax.add_collection(coll2)
        ax.set_xlim(self.XMIN, self.XMAX)
        ax.set_ylim(self.YMIN, self.YMAX)

        def principal_segment(center, tensor):
            # np.linalg.eig does not sort its output: take the eigenpair with the larger
            # eigenvalue (the first one on ties). The decomposition is computed once per cell.
            w, v = np.linalg.eig(tensor)
            k = 0 if w[0] >= w[1] else 1
            half = v[:, k]*w[k]*scale
            return [center + half, center - half]

        def update(frame, pbar):
            vx, vy, _, _, _, cells, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
            N = np.load(self.RESULT_DIR+"stress_tensor/"+str(frame)+".npy")
            # Same ordering as used when the tensors were saved: N[i] belongs to cells_in[i].
            cells_in = np.where(np.logical_not(c_ext))[0]
            coll.set_verts([tuple(zip(vx[cells[c]], vy[cells[c]]+y_offset)) for c in cells_in])

            segments = []
            for i, c in enumerate(cells_in):
                center = np.array(list(zip(vx[cells[c]], vy[cells[c]]+y_offset))).mean(axis=0)
                segments.append(principal_segment(center, N[i]))
            coll2.set_segments(segments)
            pbar.update()

        pbar = tqdm(total=tend)
        try:
            anim = animation.FuncAnimation(fig, update, frames=range(tend), fargs=(pbar,), interval=10)
            anim.save(self.RESULT_DIR+"local_cell_stress_"+str(scale)+".gif", writer="imagemagick", dpi=150)
        finally:
            pbar.close()
   

    def animate_cellshape_tensor(self, title, y_offset=2000, scale=50, tend=199):
        '''
        Save RESULT_DIR/cellshape_tensor_<scale>.gif: closed cells outlined in black, each with
        a blue bar along the eigenvector of the larger eigenvalue of its cell shape tensor.

        Each bar is centred on the mean of the cell's vertices and spans
        centre +/- eigenvector * eigenvalue * scale (eigenvectors from np.linalg.eig have unit length).
        Reads RESULT_DIR/cellshape_tensor/<frame>.npy written by calc_stress_tensor().

        title : drawn as the axes title.
        tend  : number of frames to animate (default 199, as before).

        Raises Exception if the cellshape_tensor directory does not exist.
        '''
        if not os.path.isdir(self.RESULT_DIR+"cellshape_tensor/"):
            raise Exception("calculate cell shape tensor before animate")

        fig = plt.figure(figsize=(12,8))
        ax = fig.add_subplot(111)
        ax.set_title(title)

        coll = PolyCollection([], color="white", lw=0.1, ec="black")
        ax.add_collection(coll)
        coll2 = LineCollection([], lw = 0.2, ec="blue")
        ax.add_collection(coll2)
        ax.set_xlim(self.XMIN, self.XMAX)
        ax.set_ylim(self.YMIN, self.YMAX)

        def principal_segment(center, tensor):
            # np.linalg.eig does not sort its output: take the eigenpair with the larger
            # eigenvalue (the first one on ties). The decomposition is computed once per cell.
            w, v = np.linalg.eig(tensor)
            k = 0 if w[0] >= w[1] else 1
            half = v[:, k]*w[k]*scale
            return [center + half, center - half]

        def update(frame, pbar):
            vx, vy, _, _, _, cells, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
            N = np.load(self.RESULT_DIR+"cellshape_tensor/"+str(frame)+".npy")
            # Same ordering as used when the tensors were saved: N[i] belongs to cells_in[i].
            cells_in = np.where(np.logical_not(c_ext))[0]
            coll.set_verts([tuple(zip(vx[cells[c]], vy[cells[c]]+y_offset)) for c in cells_in])

            segments = []
            for i, c in enumerate(cells_in):
                center = np.array(list(zip(vx[cells[c]], vy[cells[c]]+y_offset))).mean(axis=0)
                segments.append(principal_segment(center, N[i]))
            coll2.set_segments(segments)
            pbar.update()

        pbar = tqdm(total=tend)
        try:
            anim = animation.FuncAnimation(fig, update, frames=range(tend), fargs=(pbar,), interval=10)
            anim.save(self.RESULT_DIR+"cellshape_tensor_"+str(scale)+".gif", writer="imagemagick", dpi=150)
        finally:
            pbar.close()
   

    def get_permanent_tracks(self):
        '''
        Return tracks that are present from time 0 to end timepoint
        最初のフレームから最後のフレームまで存在するトラッキング（永続トラック）番号のリストを返す
        '''
        return self.Tracking_index.dropna().index

    
    def show_permanent_tracks(self):
        '''
        Interactive plotly scatter of the track centroids at frame 0 (y drawn as YMAX - y).
        Gray: all tracks. Red: permanent tracks (see get_permanent_tracks).
        Hovering over a red marker in Jupyter shows its position in the permanent-track list,
        which is the index that set_wingroi_tracks() expects.
        '''
        permanent_ids = self.get_permanent_tracks()
        first_x = self.ori_pos_x[0]
        first_y_flipped = self.YMAX - self.ori_pos_y[0]

        all_tracks = go.Scatter(x=first_x, y=first_y_flipped, mode="markers",
                                marker=dict(size=2, color="lightgray"))
        permanent_tracks = go.Scatter(x=first_x[permanent_ids], y=first_y_flipped[permanent_ids],
                                      mode="markers", marker=dict(size=3, color="red"),
                                      hovertext=list(range(len(permanent_ids))), name="permanent")

        fig = go.Figure()
        fig.add_trace(all_tracks)
        fig.add_trace(permanent_tracks)
        fig.update_layout(width=900, height=600,
                          xaxis_range=[self.XMIN, self.XMAX], yaxis_range=[self.YMIN, self.YMAX],
                          template=dict(layout=go.Layout()))
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
        fig.show()

    def show_wingroi_tracks(self, time = -1):
        '''
        (JS) Interactive plot of the wing-ROI tracks, optionally with the proximal-distal
        region boundaries drawn as vertical blue lines.

        time : frame whose centroids of all tracks are drawn in gray. Negative values count
               from the last frame (-1, the default, is the last frame). When time >= 0 the
               boundaries returned by self.set_proximal_distal_regions(time) are also drawn.
        Red markers are the wing-ROI tracks at frame 0; hovering shows their position in
        wingroi_tracks.

        Returns the boundary x positions that were drawn, or an empty list when none were
        drawn (time < 0).
        '''
        wingroi = self.wingroi_tracks
        # Frame columns are the integer labels 0..N-1 (read with header=None), so select the
        # column by position; a label lookup with -1 would raise KeyError.
        gray_x = self.ori_pos_x.iloc[:, time]
        gray_y = self.YMAX - self.ori_pos_y.iloc[:, time]

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=gray_x, y=gray_y, mode="markers", marker=dict(size=2, color="lightgray")))
        fig.add_trace(go.Scatter(x=self.ori_pos_x[0][wingroi], y=self.YMAX - self.ori_pos_y[0][wingroi],
                                 mode="markers", marker=dict(size=3, color="red"), hovertext = list(range(len(wingroi))), name="wing roi track"))

        boundary_array = []  # stays empty when no boundaries are requested
        if time > -1:
            boundary_array = self.set_proximal_distal_regions(time)  # Specify the region boundaries
            for x_val in boundary_array:
                fig.add_shape(type="line",
                      x0=x_val, y0=self.YMIN,  # Starting point of the line (x0, y0)
                      x1=x_val, y1=self.YMAX,  # Ending point of the line (x1, y1)
                      line=dict(color="blue", width=1))  # Line properties

        fig.update_layout(width=900, height=600, xaxis_range=[self.XMIN,self.XMAX], yaxis_range=[self.YMIN,self.YMAX], template=dict(layout=go.Layout()))
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
        fig.show()

        return boundary_array

        
    def animate_permanent_tracks(self):
        '''
        Save RESULT_DIR/permanent_tracks.gif: for every frame, all track centroids in gray and
        the permanent tracks (see get_permanent_tracks) in red, with y drawn as YMAX - y.
        '''
        permanent_ids = self.get_permanent_tracks()

        fig = plt.figure(figsize=(12,8))
        ax = fig.add_subplot(111)

        def draw_frame(frame, progress):
            xs = self.ori_pos_x[frame]
            ys_flipped = self.YMAX - self.ori_pos_y[frame]
            ax.cla()  # redraw from scratch each frame
            ax.scatter(xs, ys_flipped, c="gray", s=1)
            ax.scatter(xs[permanent_ids], ys_flipped[permanent_ids], c="red", s=3)
            ax.set_xlim(self.XMIN, self.XMAX)
            ax.set_ylim(self.YMIN, self.YMAX)
            progress.update()

        progress = tqdm(total=self.N_FRAMES)
        movie = animation.FuncAnimation(fig, draw_frame, frames=range(self.N_FRAMES), fargs=(progress,), interval=10)
        movie.save(self.RESULT_DIR+"permanent_tracks.gif", writer="imagemagick", dpi=150)
        progress.close()
  

    def set_wingroi_tracks(self, wingroi_tracks):
        '''
        wingroi_tracks : list of track to define the outer boundary of wingROI. id is defined by permanent_tracks.
        wingroi_tracks : ポリゴン（解析領域）を構成するトラックの番号リスト（番号は永続トラックリストのもの）
        '''
        
        self.wingroi_tracks = self.get_permanent_tracks()[wingroi_tracks]

        
    def _inwingroi(self, sx, sy, px, py):
        ''' 
        Return if the cell is inside the ROI or not by using Crossing Number Algorithm. 
        May not work for a donuts-shaped ROI.
        px[:], py[:]: Vertex coordinates of polygon (i.e., wingroi)
        sx, sy: Coordinates of vertices of interest
        ポリゴンの内外判定を Crossing Number Algorithm に基づいて行う
        https://www.nttpc.co.jp/technology/number_algorithm.html
        ドーナツ型のポリゴンに対応できているかは不明 -> 内部に孔をあけたいときは検証・テストするべき
        px[:], py[:]: ポリゴンの座標リスト
        sx, sy: 調べたい頂点の座標
        '''   
        
        npoly = len(px)
        inside = False
        for i1 in range(npoly): 
            i2 = (i1 + 1) % npoly
            if min(px[i1], px[i2]) < sx < max(px[i1], px[i2]):
                if (py[i1] + (py[i2]-py[i1])/(px[i2]-px[i1])*(sx-px[i1]) - sy) > 0:
                    inside = not inside # 判定を裏返す（奇数回の反転でポリゴン内）, Overturn a decision (odd-number overturn  -> iside the ROI)
        return inside
     
        
    def calc_trackings_in_and_out_wingroi_area(self):
        '''
        Split the tracking tables into an inside-wing-ROI copy and an outside-wing-ROI copy.

        At each frame the wing ROI is the polygon through the centroids of self.wingroi_tracks
        at that frame. For a track whose centroid (ori_pos_x, ori_pos_y) is defined at a frame:
        it is set to NaN in the "_inarea" tables when it lies outside the polygon, and in the
        "_outarea" tables when it lies inside. Entries with an undefined centroid are left
        unchanged in both.

        Sets ori_pos_x_inarea, ori_pos_x_outarea, ori_pos_y_inarea, ori_pos_y_outarea,
        Tracking_index_inarea and Tracking_index_outarea, and writes them as headerless CSV
        files to RESULT_DIR. The computation is slow, so when all six CSV files already exist
        they are loaded instead.
        NOTE(review): the cache does not depend on wingroi_tracks; delete the CSV files after
        changing the ROI.

        Raises Exception if set_wingroi_tracks() has not been called.
        '''
        if len(self.wingroi_tracks) == 0:
            raise Exception("set wingroi_tracks by function: set_wingroi_tracks()")

        names = ["ori_pos_x_inarea", "ori_pos_x_outarea", "ori_pos_y_inarea",
                 "ori_pos_y_outarea", "Tracking_index_inarea", "Tracking_index_outarea"]
        paths = {name: self.RESULT_DIR + name + ".csv" for name in names}

        # Use the cache only if it is complete; a partial cache would fail on read.
        if all(os.path.isfile(path) for path in paths.values()):
            self.ori_pos_x_inarea = pd.read_csv(paths["ori_pos_x_inarea"], header=None)
            self.ori_pos_x_outarea = pd.read_csv(paths["ori_pos_x_outarea"], header=None)
            self.ori_pos_y_inarea = pd.read_csv(paths["ori_pos_y_inarea"], header=None)
            self.ori_pos_y_outarea = pd.read_csv(paths["ori_pos_y_outarea"], header=None)
            self.Tracking_index_inarea = pd.read_csv(paths["Tracking_index_inarea"], header=None)
            self.Tracking_index_outarea = pd.read_csv(paths["Tracking_index_outarea"], header=None)
            self.calc_inarea = True
            return

        # inside[i, f] / outside[i, f]: track i has a defined centroid at frame f lying
        # inside / outside the ROI polygon. Tracks without a centroid are in neither.
        inside = np.zeros((self.N_TRACKS, self.N_FRAMES), dtype=bool)
        outside = np.zeros((self.N_TRACKS, self.N_FRAMES), dtype=bool)
        for frame in tqdm(range(self.N_FRAMES)):
            xs = self.ori_pos_x[frame]
            ys = self.ori_pos_y[frame]
            # The ROI polygon depends only on the frame, so extract it once per frame.
            px = xs[self.wingroi_tracks].values
            py = ys[self.wingroi_tracks].values
            for i in tqdm(range(self.N_TRACKS), leave=False):
                sx, sy = xs[i], ys[i]
                if np.isnan(sx) or np.isnan(sy):
                    continue
                if self._inwingroi(sx, sy, px, py):
                    inside[i, frame] = True
                else:
                    outside[i, frame] = True

        def blank(table, cond):
            # Copy of `table` with NaN where `cond` is True. Frame/track labels are the
            # positions 0..N-1 (CSV read with header=None). Using DataFrame.mask avoids chained
            # assignment, which does not write back under pandas Copy-on-Write.
            full = np.zeros(table.shape, dtype=bool)
            full[:self.N_TRACKS, :self.N_FRAMES] = cond
            return table.mask(full)

        self.ori_pos_x_inarea = blank(self.ori_pos_x, outside)
        self.ori_pos_y_inarea = blank(self.ori_pos_y, outside)
        self.Tracking_index_inarea = blank(self.Tracking_index, outside)

        self.ori_pos_x_outarea = blank(self.ori_pos_x, inside)
        self.ori_pos_y_outarea = blank(self.ori_pos_y, inside)
        self.Tracking_index_outarea = blank(self.Tracking_index, inside)

        self.ori_pos_x_inarea.to_csv(paths["ori_pos_x_inarea"], header=False, index=False)
        self.ori_pos_x_outarea.to_csv(paths["ori_pos_x_outarea"], header=False, index=False)
        self.ori_pos_y_inarea.to_csv(paths["ori_pos_y_inarea"], header=False, index=False)
        self.ori_pos_y_outarea.to_csv(paths["ori_pos_y_outarea"], header=False, index=False)
        self.Tracking_index_inarea.to_csv(paths["Tracking_index_inarea"], header=False, index=False)
        self.Tracking_index_outarea.to_csv(paths["Tracking_index_outarea"], header=False, index=False)

        self.calc_inarea = True
    
    
    def calc_division_pair_tracks_in_area(self):
        '''
        Express each cell division inside the wing ROI with track ids instead of cell ids.

        For every row of Tracking_division_pair, the mother cell is looked up at frame
        time - 1 and the two daughter cells at frame time in Tracking_index_inarea. A track id
        is the row of that table holding the cell id at that frame. Divisions for which any of
        the three cells is not tracked inside the ROI are dropped.

        Sets Tracking_division_pair_inarea_trackid: an int DataFrame with columns
        ["time", "before track", "after track 1", "after track 2"], where time is the last
        frame of the mother track. Called by calc_Forest(); no need to call it beforehand.

        Raises Exception if calc_trackings_in_and_out_wingroi_area() has not been run.
        '''
        if not self.calc_inarea:
            raise Exception("calculate trackings in wingroi area")

        pairs = self.Tracking_division_pair
        result = np.zeros((len(pairs), 4), dtype=float)
        # The time in Tracking_division_pair.csv is the frame right after the division (birth
        # of the two daughters), so subtract 1 to obtain the mother's last frame.
        result[:, 0] = np.array(pairs["time"]) - 1

        def track_of(frame, cell_id):
            # Row (track id) of Tracking_index_inarea holding cell_id at `frame`, else NaN.
            rows = np.where(self.Tracking_index_inarea[frame] == cell_id)[0]
            return rows[0] if len(rows) > 0 else np.nan

        for i in tqdm(range(len(result))):
            # Frames are integer column labels; cast explicitly instead of indexing with a float.
            j = int(result[i, 0])
            result[i, 1] = track_of(j, pairs.iloc[i, 1])
            result[i, 2] = track_of(j + 1, pairs.iloc[i, 2])
            result[i, 3] = track_of(j + 1, pairs.iloc[i, 3])

        self.Tracking_division_pair_inarea_trackid = pd.DataFrame(result,
                                                                  columns=["time", "before track", "after track 1", "after track 2"]).dropna().reset_index(drop=True).astype(int)
    
    
    def calc_Forest(self):
        '''
        Build the division forest of the tracks inside the wing ROI and derive per-track features.

        Calls calc_division_pair_tracks_in_area(), then merges every division (mother track ->
        two daughter tracks), iterating from the last division to the first, into a Forest of
        N_TRACKS nodes (Forest is defined elsewhere in this project) and calls reduce_roots().

        Sets:
          TrackingForest_inarea     : the resulting Forest.
          Tracking_divcounts_inarea : float array shaped like Tracking_index_inarea. For every
                                      non-root node, its depth is written from the frame after
                                      its time onward, i.e. how many divisions the track has
                                      undergone by then (0 elsewhere).
          Tracking_ancestor_inarea  : per track, the int id of its forest root (its own id if it
                                      never appears as a non-root node); used e.g. for colouring.
          divide_and_initially_exist_inarea_tracks : root ids (str, sorted numerically) that
                                      are also tracked inside the ROI at frame 0.
          div_counts_inarea         : `height` of each of those roots (their number of divisions).
          calc_forest = True
        Prints a Counter of div_counts_inarea.
        '''
        self.calc_division_pair_tracks_in_area()

        self.TrackingForest_inarea = Forest(self.N_TRACKS)
        for i in reversed(range(len(self.Tracking_division_pair_inarea_trackid))):
            j, before, after1, after2 = self.Tracking_division_pair_inarea_trackid.iloc[i]  # j = time
            self.TrackingForest_inarea.merge(str(after1), str(after2), str(before), j)
        self.TrackingForest_inarea.reduce_roots()

        self.Tracking_divcounts_inarea = np.zeros(self.Tracking_index_inarea.shape)
        self.Tracking_ancestor_inarea = np.array(list(range(len(self.Tracking_index_inarea))))

        for root in self.TrackingForest_inarea.roots:
            # [1:] skips the first entry of enum_nodes_and_time (presumably the root itself — confirm in Forest).
            for (trackid, time, depth, sum_of_div) in self.TrackingForest_inarea.enum_nodes_and_time(root)[1:]:
                self.Tracking_divcounts_inarea[int(trackid), time+1:] = depth
                self.Tracking_ancestor_inarea[int(trackid)] = int(root)

        # Tracks that exist (inside the ROI) at frame 0 and undergo at least one division.
        # Sorted numerically: iterating a set of str follows per-process hash randomisation,
        # which made the order of this array (and everything derived from it) irreproducible.
        initial_tracks = set(map(str, self.Tracking_index_inarea[0].dropna().index))
        self.divide_and_initially_exist_inarea_tracks = np.array(
            sorted(set(self.TrackingForest_inarea.roots.keys()) & initial_tracks, key=int))

        # Number of divisions of each of those tracks.
        self.div_counts_inarea = np.array([self.TrackingForest_inarea.roots[i].height for i in self.divide_and_initially_exist_inarea_tracks])

        print("<division counts>")
        print(Counter(self.div_counts_inarea))

        self.calc_forest = True
    
    
    def show_inarea_division_times_density(self):
        '''
        Plot kernel density estimates (scipy gaussian_kde) of division times for the dividing
        tracks that exist inside the wing ROI from frame 0:
          green        - tracks that divided once,
          orange       - first division of tracks that divided twice,
          faded orange - second division of tracks that divided twice.
        A division time is the forest node time + 1. Each density is evaluated on 1000 points
        in [0, 198] and saved to RESULT_DIR/div_time/divtime_1-1.npy, divtime_2-1.npy and
        divtime_2-2.npy. The figure is saved to RESULT_DIR/divtimes.png, with the x axis labelled
        in h APF (frame 6 -> 16 h, then +2 h every 24 frames).

        Raises Exception if calc_Forest() has not been run.
        '''
        if not self.calc_forest:
            raise Exception("calculate Forest before plot")

        os.makedirs(self.RESULT_DIR+"div_time", exist_ok=True)

        tend = 198
        grid = np.linspace(0, tend, 1000)
        tracks = self.divide_and_initially_exist_inarea_tracks
        counts = self.div_counts_inarea

        # Start a fresh figure so that earlier plots are not overlaid.
        plt.figure()

        # Tracks that divided once: the division time is read from the root's `left` child.
        divtime = [self.TrackingForest_inarea.roots[root].left.time+1 for root in tracks[counts == 1]]
        G1_time = gaussian_kde(divtime)(grid)
        plt.plot(grid, G1_time, c="green", label="div time of 1div cell")
        np.save(self.RESULT_DIR+"div_time/divtime_1-1", G1_time)

        # Tracks that divided twice: split node times by depth (1st vs 2nd division).
        divtime_1st = []
        divtime_2nd = []
        for root in tracks[counts == 2]:
            for (trackid, time, depth, sum_of_div) in self.TrackingForest_inarea.enum_nodes_and_time(root)[1:]:
                if depth == 1:
                    divtime_1st.append(time+1)
                elif depth == 2:
                    divtime_2nd.append(time+1)

        Y1_time = gaussian_kde(divtime_1st)(grid)
        Y2_time = gaussian_kde(divtime_2nd)(grid)

        plt.plot(grid, Y1_time,
                 c=tuple(list(mcolors.to_rgb("orange")) + [1.0]), label="1st div time of 2div cell")
        plt.plot(grid, Y2_time,
                 c=tuple(list(mcolors.to_rgb("orange")) + [0.5]), label="2nd div time of 2div cell")

        np.save(self.RESULT_DIR+"div_time/divtime_2-1", Y1_time)
        np.save(self.RESULT_DIR+"div_time/divtime_2-2", Y2_time)

        plt.legend()
        plt.title("Division time")
        plt.xlabel("Time [h APF]")
        plt.ylabel("Frequency")
        plt.xlim(0, tend)
        plt.ylim(0, 0.04)
        plt.xticks([6, 30, 54, 78, 102, 126, 150, 174, 198], ['16', '18', '20', '22', '24', '26', '28', '30', '32'])
        plt.yticks([0, 0.01, 0.02, 0.03, 0.04])
        plt.savefig(self.RESULT_DIR+"divtimes.png")
    
    
    def show_stress_and_division_angle_plots(self, time_offset=1, tend=199):
        '''
        Plot the cell-division angle against the maximum-stress direction of the
        mother cell.

        For each division recorded at frame ``frame`` (``time_offset-1 <= frame < tend``):

        * The mother cell's stress tensor is loaded from
          ``RESULT_DIR/stress_tensor/{frame-time_offset+1}.npy``. The eigenvector
          whose eigenvalue has the largest absolute value is taken as the
          maximum-stress direction.
        * The division direction is the vector joining the centres of the two
          daughter cells at ``frame+1``.

        Both directions are converted to axial angles in [-90, 90) degrees.

        Outputs (under RESULT_DIR):

        * ``div_orientation/{cellstress,division,divtype,divtime}_offset_{time_offset}.npy``
        * ``angle_plot_cellstress_with_time_offset_{time_offset}.png``, coloured by division time.
        * ``angle_plot_cellstress_with_division_type_offset_{time_offset}.png``, coloured
          by division type: "1div", "2div-1" (1st division of a 2-division
          lineage), "2div-2" (its 2nd division), or "others".

        Parameters
        ----------
        time_offset : int
            The mother cell is sampled ``time_offset-1`` frames before the
            division frame.
        tend : int
            Divisions are processed for frames < tend. Defaults to 199, the
            previously hard-coded value.

        Raises
        ------
        Exception
            If the lineage forest or the stress tensors have not been computed.
        '''
        if not self.calc_forest:
            raise Exception("calculate Forest before plot")

        if not os.path.isdir(self.RESULT_DIR+"stress_tensor/"):
            raise Exception("calculate stress tensor before plot")

        if not os.path.isdir(self.RESULT_DIR+"div_orientation/"):
            os.mkdir(self.RESULT_DIR+"div_orientation")

        def axial_angle_deg(vec):
            # Angle of an undirected axis (vec and -vec give the same value),
            # mapped into [-90, 90) degrees.
            return ((np.angle(complex(vec[0], vec[1])) + 5/2*np.pi) % np.pi - np.pi/2) / np.pi * 180

        def division_type_of(before_track):
            # Classify a division by the total number of divisions of its lineage.
            ancestor = int(self.Tracking_ancestor_inarea[before_track])
            ancestor_node = self.TrackingForest_inarea.roots.get(str(ancestor))
            if ancestor_node is None:
                return "others"
            if ancestor_node.height == 1:
                return "1div"
            if ancestor_node.height == 2:
                # A root track is its own ancestor: Tracking_ancestor_inarea starts
                # as arange and only descendants are overwritten with their root.
                # So a mother track equal to its ancestor is the first division.
                # (The old test compared the root name with a float cell id and
                # never matched.)
                return "2div-1" if ancestor == before_track else "2div-2"
            return "others"

        before_maxstress = []
        aftercells_vecs = []
        div_time = []
        division_type = []
        divtype_color = {"1div":"green", "2div-1":"darkorange", "2div-2":"gold", "others":"lightgray"}

        pairs = self.Tracking_division_pair_inarea_trackid
        for frame in tqdm(range(time_offset-1, tend)):
            src_frame = frame - time_offset + 1  # frame at which the mother cell is sampled
            rows = pairs[pairs["time"] == frame]
            before_tracks = rows["before track"].values
            after1_tracks = rows["after track 1"].values
            after2_tracks = rows["after track 2"].values
            # Cell id of each mother track at src_frame (NaN if absent then).
            befores = [self.Tracking_index_inarea[src_frame][int(bt)] for bt in before_tracks]

            _, _, _, _, _, _, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(src_frame))
            cells_in = np.where(np.logical_not(c_ext))[0]
            N = np.load(self.RESULT_DIR+"stress_tensor/"+str(src_frame)+".npy")

            for bt, mother, a1, a2 in zip(before_tracks, befores, after1_tracks, after2_tracks):
                # `x == np.nan` is always False; pd.isna is the correct NaN test.
                if pd.isna(mother):
                    continue

                b = np.where(cells_in == mother)[0]
                if len(b) == 0:
                    continue

                eigvals, eigvecs = np.linalg.eig(N[b[0]])
                before_maxstress.append(eigvecs[:, 0] if abs(eigvals[0]) > abs(eigvals[1]) else eigvecs[:, 1])
                # y is handled as YMAX - y, so the y difference is negated to get
                # the vector in the plotted (relative) coordinate frame.
                aftercells_vecs.append(np.array([self.ori_pos_x_inarea[frame+1][a2]-self.ori_pos_x_inarea[frame+1][a1],
                                                 -self.ori_pos_y_inarea[frame+1][a2]+self.ori_pos_y_inarea[frame+1][a1]]))
                div_time.append(frame)
                division_type.append(division_type_of(int(bt)))

        division_type = np.array(division_type)
        division_time = np.array(div_time)

        before_angles = np.array([axial_angle_deg(v) for v in before_maxstress])
        after_angles = np.array([axial_angle_deg(v) for v in aftercells_vecs])

        np.save(self.RESULT_DIR+"div_orientation/cellstress_offset_"+str(time_offset), before_angles)
        np.save(self.RESULT_DIR+"div_orientation/division_offset_"+str(time_offset), after_angles)
        np.save(self.RESULT_DIR+"div_orientation/divtype_offset_"+str(time_offset), division_type)
        np.save(self.RESULT_DIR+"div_orientation/divtime_offset_"+str(time_offset), division_time)

        angle_ticks = [-90, -60, -30, 0, 30, 60, 90]
        xlabel = "The maximum stress direction of mother cell [degree]"
        ylabel = "The angle of the vector connecting \n the centers of two daughter cells [degree]"

        def format_angle_axes(ax):
            ax.set_xlabel(xlabel)
            ax.set_xlim(-90, 90)
            ax.set_xticks(angle_ticks)
            ax.set_ylabel(ylabel)
            ax.set_ylim(-90, 90)
            ax.set_yticks(angle_ticks)

        # Plot 1: coloured by division time (frames; tick labels in h APF).
        fig = plt.figure(figsize=(8,7), dpi=150)
        ax = fig.add_subplot(111)
        sc = ax.scatter(before_angles, after_angles, vmin=0, vmax=102, c=div_time, cmap=matplotlib.cm.jet, s=0.2)
        cbar = plt.colorbar(sc, ax=ax, label="Division time [h APF]", ticks=[6, 30, 54, 78, 102])
        cbar.ax.set_yticklabels(['16', '18', '20', '22', '24'])
        format_angle_axes(ax)
        plt.savefig(self.RESULT_DIR+"angle_plot_cellstress_with_time_offset_"+str(time_offset)+".png")

        # Plot 2: coloured by division type.
        fig = plt.figure(figsize=(8,7), dpi=150)
        ax = fig.add_subplot(111)
        for divtype, color in divtype_color.items():
            divtype_indices = np.where(division_type == divtype)[0]
            ax.scatter(before_angles[divtype_indices], after_angles[divtype_indices], c=color, label=divtype, s=0.2)
        ax.legend(bbox_to_anchor=(1.05, 1))
        format_angle_axes(ax)
        plt.tight_layout()
        plt.savefig(self.RESULT_DIR+"angle_plot_cellstress_with_division_type_offset_"+str(time_offset)+".png")
 

    def show_cellshape_and_division_angle_plots(self, time_offset=1, tend=198):
        '''
        Plot the cell-division angle against the orientation of the mother cell's
        shape tensor.

        For each division recorded at frame ``frame`` (``time_offset-1 <= frame < tend``):

        * The mother cell's shape tensor is loaded from
          ``RESULT_DIR/cellshape_tensor/{frame-time_offset+1}.npy``. The eigenvector
          whose eigenvalue has the largest absolute value is taken as the cell
          direction.
        * The division direction is the vector joining the centres of the two
          daughter cells at ``frame+1``.

        Both directions are converted to axial angles in [-90, 90) degrees.

        Outputs (under RESULT_DIR):

        * ``div_orientation/cellshape_offset_{time_offset}.npy`` (mother-cell angles)
        * ``angle_plot_cellshape_with_time_offset_{time_offset}.png``, coloured by division time.
        * ``angle_plot_cellshape_with_division_type_offset_{time_offset}.png``, coloured
          by division type: "1div", "2div-1", "2div-2", or "others".

        Parameters
        ----------
        time_offset : int
            The mother cell is sampled ``time_offset-1`` frames before the
            division frame.
        tend : int
            Divisions are processed for frames < tend. Defaults to 198, the
            previously hard-coded value.

        Raises
        ------
        Exception
            If the lineage forest or the cell-shape tensors have not been computed.
        '''
        if not self.calc_forest:
            raise Exception("calculate Forest before plot")

        if not os.path.isdir(self.RESULT_DIR+"cellshape_tensor/"):
            raise Exception("calculate cell shape tensor before plot")

        if not os.path.isdir(self.RESULT_DIR+"div_orientation/"):
            os.mkdir(self.RESULT_DIR+"div_orientation")

        def axial_angle_deg(vec):
            # Angle of an undirected axis (vec and -vec give the same value),
            # mapped into [-90, 90) degrees.
            return ((np.angle(complex(vec[0], vec[1])) + 5/2*np.pi) % np.pi - np.pi/2) / np.pi * 180

        def division_type_of(before_track):
            # Classify a division by the total number of divisions of its lineage.
            ancestor = int(self.Tracking_ancestor_inarea[before_track])
            ancestor_node = self.TrackingForest_inarea.roots.get(str(ancestor))
            if ancestor_node is None:
                return "others"
            if ancestor_node.height == 1:
                return "1div"
            if ancestor_node.height == 2:
                # A root track is its own ancestor: Tracking_ancestor_inarea starts
                # as arange and only descendants are overwritten with their root.
                # So a mother track equal to its ancestor is the first division.
                return "2div-1" if ancestor == before_track else "2div-2"
            return "others"

        before_major_axes = []
        aftercells_vecs = []
        div_time = []
        division_type = []
        divtype_color = {"1div":"green", "2div-1":"darkorange", "2div-2":"gold", "others":"lightgray"}

        pairs = self.Tracking_division_pair_inarea_trackid
        for frame in tqdm(range(time_offset-1, tend)):
            src_frame = frame - time_offset + 1  # frame at which the mother cell is sampled
            rows = pairs[pairs["time"] == frame]
            before_tracks = rows["before track"].values
            after1_tracks = rows["after track 1"].values
            after2_tracks = rows["after track 2"].values
            # Cell id of each mother track at src_frame (NaN if absent then).
            befores = [self.Tracking_index_inarea[src_frame][int(bt)] for bt in before_tracks]

            _, _, _, _, _, _, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(src_frame))
            cells_in = np.where(np.logical_not(c_ext))[0]
            N = np.load(self.RESULT_DIR+"cellshape_tensor/"+str(src_frame)+".npy")

            for bt, mother, a1, a2 in zip(before_tracks, befores, after1_tracks, after2_tracks):
                # `x == np.nan` is always False; pd.isna is the correct NaN test.
                if pd.isna(mother):
                    continue

                b = np.where(cells_in == mother)[0]
                if len(b) == 0:
                    continue

                eigvals, eigvecs = np.linalg.eig(N[b[0]])
                before_major_axes.append(eigvecs[:, 0] if abs(eigvals[0]) > abs(eigvals[1]) else eigvecs[:, 1])
                # y is handled as YMAX - y, so the y difference is negated to get
                # the vector in the plotted (relative) coordinate frame.
                aftercells_vecs.append(np.array([self.ori_pos_x_inarea[frame+1][a2]-self.ori_pos_x_inarea[frame+1][a1],
                                                 -self.ori_pos_y_inarea[frame+1][a2]+self.ori_pos_y_inarea[frame+1][a1]]))
                div_time.append(frame)
                division_type.append(division_type_of(int(bt)))

        division_type = np.array(division_type)

        before_angles = np.array([axial_angle_deg(v) for v in before_major_axes])
        after_angles = np.array([axial_angle_deg(v) for v in aftercells_vecs])

        np.save(self.RESULT_DIR+"div_orientation/cellshape_offset_"+str(time_offset), before_angles)

        angle_ticks = [-90, -60, -30, 0, 30, 60, 90]
        xlabel = "The direction of mother cell [degree]"

        def format_angle_axes(ax, ylabel):
            ax.set_xlabel(xlabel)
            ax.set_xlim(-90, 90)
            ax.set_xticks(angle_ticks)
            ax.set_ylabel(ylabel)
            ax.set_ylim(-90, 90)
            ax.set_yticks(angle_ticks)

        # Plot 1: coloured by division time (frames; tick labels in h APF).
        fig = plt.figure(figsize=(8,7), dpi=150)
        ax = fig.add_subplot(111)
        sc = ax.scatter(before_angles, after_angles, vmin=0, vmax=102, c=div_time, cmap=matplotlib.cm.jet, s=0.2)
        cbar = plt.colorbar(sc, ax=ax, label="Division time [h APF]", ticks=[6, 30, 54, 78, 102])
        cbar.ax.set_yticklabels(['16', '18', '20', '22', '24'])
        format_angle_axes(ax, "The angle of the vector connecting the centers of two daughter cells [degree]")
        plt.savefig(self.RESULT_DIR+"angle_plot_cellshape_with_time_offset_"+str(time_offset)+".png")

        # Plot 2: coloured by division type.
        fig = plt.figure(figsize=(8,7), dpi=150)
        ax = fig.add_subplot(111)
        for divtype, color in divtype_color.items():
            divtype_indices = np.where(division_type == divtype)[0]
            ax.scatter(before_angles[divtype_indices], after_angles[divtype_indices], c=color, label=divtype, s=0.2)
        ax.legend(bbox_to_anchor=(1.05, 1))
        format_angle_axes(ax, "The angle of the vector connecting \n the centers of two daughter cells [degree]")
        plt.tight_layout()
        plt.savefig(self.RESULT_DIR+"angle_plot_cellshape_with_division_type_offset_"+str(time_offset)+".png")

    def determine_groups_over_time(self, y_offset=2000, width=20, tend=199):
        '''
        Classify cell divisions into Group 1 / Group 2 and return the
        classification history per frame.

        (JS) Adapted from animate_group1and2. Instead of rendering a movie, it
        returns the Group-1 flags recorded up to each frame.

        Hertwig's rule: a cell divides along its longest axis. In the pupal wing,
        the longest axis and the maximum-stress direction are mostly close.

        For each division at frame ``frame``:

        * The maximum-stress direction of the mother cell is the eigenvector of its
          stress tensor (``RESULT_DIR/stress_tensor/{frame}.npy``) whose eigenvalue
          has the largest absolute value.
        * The division direction is the vector joining the two daughter-cell
          centres at ``frame+1``.

        Both are converted to axial angles in [-90, 90) degrees. The division is
        Group 1 when ``|aa - ba| <= width``, otherwise Group 2.

        Parameters
        ----------
        y_offset : number
            Unused; kept for signature compatibility with animate_group1and2.
        width : number
            Half-width in degrees of the band around y = x defining Group 1.
        tend : int
            Frames 0..tend-1 are processed. Defaults to 199, the previously
            hard-coded value.

        Returns
        -------
        dict
            ``{frame: list_of_bool}``. Each value is a snapshot of the Group-1
            flags of all divisions processed in frames 0..frame, in processing
            order. It is cumulative; the last entry contains every division.

        Raises
        ------
        Exception
            If the lineage forest, stress tensors or cell-shape tensors are missing.
        '''
        if not self.calc_forest:
            raise Exception("calculate Forest before plot")

        if not os.path.isdir(self.RESULT_DIR+"stress_tensor/"):
            raise Exception("calculate stress tensor before plot")

        if not os.path.isdir(self.RESULT_DIR+"cellshape_tensor/"):
            raise Exception("calculate cell shape tensor before plot")

        if not os.path.isdir(self.RESULT_DIR+"div_orientation/"):
            os.mkdir(self.RESULT_DIR+"div_orientation")

        is_group1s = []
        group1s_over_time = {}
        pairs = self.Tracking_division_pair_inarea_trackid

        for frame in tqdm(range(tend)):
            _, _, _, _, _, _, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
            rows = pairs[pairs["time"] == frame]
            before_tracks = rows["before track"].values
            befores = [self.Tracking_index_inarea[frame][int(bt)] for bt in before_tracks]
            after1_tracks = rows["after track 1"].values
            after2_tracks = rows["after track 2"].values
            N = np.load(self.RESULT_DIR+"stress_tensor/"+str(frame)+".npy")
            cells_in = np.where(np.logical_not(c_ext))[0]

            for mother, a1, a2 in zip(befores, after1_tracks, after2_tracks):
                b = np.where(cells_in == mother)[0]
                if len(b) == 0:
                    continue
                eigvals, eigvecs = np.linalg.eig(N[b[0]])
                bs = eigvecs[:, 0] if abs(eigvals[0]) > abs(eigvals[1]) else eigvecs[:, 1]
                # y is handled as YMAX - y, so the y difference is negated.
                av = np.array([self.ori_pos_x_inarea[frame+1][a2]-self.ori_pos_x_inarea[frame+1][a1],
                               -self.ori_pos_y_inarea[frame+1][a2]+self.ori_pos_y_inarea[frame+1][a1]])
                # Axial angles in [-90, 90) degrees.
                ba = ((np.angle(complex(bs[0], bs[1]))+5/2*np.pi) % np.pi - np.pi/2) / np.pi * 180
                aa = ((np.angle(complex(av[0], av[1]))+5/2*np.pi) % np.pi - np.pi/2) / np.pi * 180

                # Group 1: deviation from y = x is at most `width` degrees.
                # NOTE(review): the difference does not wrap around +/-90 deg
                # (e.g. ba=-89, aa=89 are 2 deg apart as axes but count as
                # Group 2). This is kept identical to animate_group1and2.
                is_group1s.append(ba - width <= aa <= ba + width)

            # Store a snapshot. Storing `is_group1s` itself made every frame alias
            # the same list, so all entries ended up equal to the final list.
            group1s_over_time[frame] = list(is_group1s)

        return group1s_over_time

    def animate_group1and2(self, y_offset=2000, width=20): ## edited by Joseph
        '''
        Split cell divisions into Group 1 and Group 2 and visualise them.

        Hertwig's rule: a cell divides along its longest axis. In the pupal wing,
        the longest axis and the maximum-stress direction are mostly close.

        For every division at frame ``frame``:

        * The maximum-stress direction of the mother cell is the eigenvector of
          its stress tensor (``RESULT_DIR/stress_tensor/{frame}.npy``) whose
          eigenvalue has the largest absolute value.
        * The division direction is the vector joining the two daughter-cell
          centres at ``frame+1``.

        Both are converted to axial angles in [-90, 90) degrees. A division is
        Group 1 if the daughter angle lies within ``width`` degrees of the
        stress angle, otherwise Group 2.

        Outputs (under self.RESULT_DIR):

        * ``div_orientation/group1_offset_1.npy``: Group-1 flag per division.
        * ``Group1-2_w{width}.gif``: movie of the cell polygons. Daughter tracks
          are coloured blue (Group 1) or orange (Group 2) from the frame after
          their division on; other cells are white.
        * ``angle_plot_Group1-2_w{width}.png``: angle scatter plot coloured by group.

        Parameters
        ----------
        y_offset : number
            Added to every polygon y coordinate in the movie.
        width : number
            Half-width in degrees of the band around y = x that defines Group 1.

        Raises
        ------
        Exception
            If the lineage forest, the stress tensors or the cell-shape tensors
            have not been computed. The cell-shape tensors are checked but not
            read by this method.
        '''
        
        if not self.calc_forest:
            raise Exception("calculate Forest before plot")
        
        if not os.path.isdir(self.RESULT_DIR+"stress_tensor/"):
            raise Exception("calculate stress tensor before plot")
        
        if not os.path.isdir(self.RESULT_DIR+"cellshape_tensor/"):
            raise Exception("calculate cell shape tensor before plot")
        
        if not os.path.isdir(self.RESULT_DIR+"div_orientation/"):
            os.mkdir(self.RESULT_DIR+"div_orientation")
        
        before_angles = []
        after_angles = []
        is_group1s = []
        # Colour per (track, frame) used by the movie; "white" = not classified.
        Tracking_colors = np.full(self.Tracking_index.shape, "white", dtype=object)
        
        # Number of frames processed (hard-coded).
        tend = 199
        
        for frame in tqdm(range(tend)):
            _, _, _, _, _, _, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
            # Divisions recorded at this frame: mother track and two daughter tracks.
            before_tracks = self.Tracking_division_pair_inarea_trackid[self.Tracking_division_pair_inarea_trackid["time"] == frame]["before track"].values
            befores = [self.Tracking_index_inarea[frame][int(bt)] for bt in before_tracks]
            after1_tracks = self.Tracking_division_pair_inarea_trackid[self.Tracking_division_pair_inarea_trackid["time"] == frame]["after track 1"].values
            after2_tracks = self.Tracking_division_pair_inarea_trackid[self.Tracking_division_pair_inarea_trackid["time"] == frame]["after track 2"].values
            N = np.load(self.RESULT_DIR+"stress_tensor/"+str(frame)+".npy")
            cells_in = np.where(np.logical_not(c_ext))[0]
            for i in range(len(befores)):
                # Position of the mother cell among cells with c_ext == False,
                # used as the row into N. It is empty when not found (e.g. NaN id).
                b = np.where(cells_in == befores[i])[0]
                if len(b) > 0:
                    e = np.linalg.eig(N[b[0]])
                    # Eigenvector with the largest |eigenvalue| = maximum-stress direction.
                    bs = e[1][:, 0] if abs(e[0][0]) > abs(e[0][1]) else e[1][:, 1]
                    # Daughter 1 -> daughter 2 vector at frame+1. The y difference is
                    # negated because y is handled as YMAX - y in this class.
                    av = np.array([self.ori_pos_x_inarea[frame+1][after2_tracks[i]]-self.ori_pos_x_inarea[frame+1][after1_tracks[i]],
                                   -self.ori_pos_y_inarea[frame+1][after2_tracks[i]]+self.ori_pos_y_inarea[frame+1][after1_tracks[i]]])
                    # Axial angles in [-90, 90) degrees (v and -v give the same angle).
                    ba = ((np.angle(complex(bs[0], bs[1]))+5/2*np.pi) % np.pi - np.pi/2) / np.pi * 180
                    aa = ((np.angle(complex(av[0], av[1]))+5/2*np.pi) % np.pi - np.pi/2) / np.pi * 180
                    
                    # is_group1: deviation from y = x is at most `width` degrees.
                    # NOTE(review): the difference does not wrap around +/-90 deg,
                    # so e.g. ba=-89, aa=89 (2 deg apart as axes) counts as Group 2.
                    is_group1 = (aa <= ba + width) and (aa >= ba - width)
                    # Colour both daughter tracks from the frame after the division on.
                    Tracking_colors[after1_tracks[i]][frame+1:] = "blue" if is_group1 else "orange"
                    Tracking_colors[after2_tracks[i]][frame+1:] = "blue" if is_group1 else "orange"
                    before_angles.append(ba)
                    after_angles.append(aa)
                    is_group1s.append(is_group1)


        before_angles = np.array(before_angles)
        after_angles = np.array(after_angles)
        is_group1s = np.array(is_group1s)
        

        # The "offset_1" name reflects that mothers are sampled at the division
        # frame itself (equivalent to time_offset=1 in the angle-plot methods).
        np.save(self.RESULT_DIR+"div_orientation/group1_offset_1",  is_group1s)
        
        fig = plt.figure(figsize=(12,8))
        ax = fig.add_subplot(111)
        
        coll = PolyCollection([], lw=0.1, edgecolors='black')
        ax.add_collection(coll)
        ax.set_xlim(self.XMIN, self.XMAX)
        ax.set_ylim(self.YMIN, self.YMAX)
        
        def facecolor(frame, c_ext, n_cells):
            # Start with every cell white, then paint cells belonging to a
            # classified track with that track's colour. Keep only cells with
            # c_ext == False, in the same order as the polygons set in update().
            fc = np.full(n_cells, "white", dtype=object)
            indices = self.Tracking_index[frame].dropna().index
            fc[self.Tracking_index[frame][indices].astype(int)] = Tracking_colors[indices, frame]
            return fc[np.where(np.logical_not(c_ext))[0].tolist()]

        def update(frame, pbar):
            # Redraw the polygons of all cells with c_ext == False for this frame.
            vx, vy, _, _, _, cells, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
            coll.set_verts([tuple(zip(vx[cells[i]], vy[cells[i]]+y_offset)) for i in np.where(np.logical_not(c_ext))[0]])
            coll.set_facecolor(facecolor(frame, c_ext, len(cells)))
            pbar.update()
        
        pbar=tqdm(total=tend)
        anim = animation.FuncAnimation(fig, update, frames=range(tend), fargs=(pbar,), interval=10)
        anim.save(self.RESULT_DIR+"Group1-2_w"+str(width)+".gif", writer="imagemagick", dpi=150)
        pbar.close()
        
        # Angle scatter plot coloured by group.
        fig = plt.figure(figsize=(8,7), dpi=150)
        ax = fig.add_subplot(111)
        
        ax.scatter(before_angles[np.logical_not(is_group1s)], after_angles[np.logical_not(is_group1s)], c="orange", label="Group 2", s=0.2)
        ax.scatter(before_angles[is_group1s], after_angles[is_group1s], c="blue", label="Group 1", s=0.2)
        ax.legend(bbox_to_anchor=(1.05, 1))
        ax.set_xlabel("The maximum stress direction of mother cell [degree]")
        ax.set_xlim(-90, 90)
        ax.set_xticks([-90, -60, -30, 0, 30, 60, 90]) 
        ax.set_ylabel("The angle of the vector connecting \n the centers of two dauther cells [degree]")
        ax.set_ylim(-90, 90)
        ax.set_yticks([-90, -60, -30, 0, 30, 60, 90]) 
        plt.tight_layout()
        plt.savefig(self.RESULT_DIR+"angle_plot_Group1-2_w"+str(width)+".png")

        
    def animate_inarea_divisions_normal(self, title, y_offset=2000, tend=199):
        '''
        Movie of cell divisions inside the wing ROI, drawn as cell polygons.

        Cells of lineages that exist at time 0 and divide once (green) or twice
        (orange) are coloured. The colour becomes more transparent after each
        division (alpha = 1.5 ** -divisions_so_far). All other cells are fully
        transparent. Saved as ``RESULT_DIR/divs_inarea_normal.gif``.

        Parameters
        ----------
        title : object
            Unused; kept for backward compatibility.
        y_offset : number
            Added to every polygon y coordinate.
        tend : int
            Number of frames rendered. Defaults to 199, the previously
            hard-coded value.

        Raises
        ------
        Exception
            If the lineage forest has not been calculated yet.
        '''
        if not self.calc_forest:
            raise Exception("calculate Forest before plot")

        # Track ids of every node of lineages (existing at time 0) that divide
        # exactly n times, for n = 1 (index 0) and n = 2 (index 1). They are
        # converted to lists once so each frame reuses them as pandas indexers.
        ndivide_tracks = [
            list(set().union(*[set(map(int, self.TrackingForest_inarea.get_treenodes(root)))
                               for root in self.divide_and_initially_exist_inarea_tracks[self.div_counts_inarea == n]]))
            for n in range(1, 3)
        ]
        color_ndivides = ["green", "orange"]

        fig = plt.figure(figsize=(12,8))
        ax = fig.add_subplot(111)

        # Empty scatters only create the legend entries.
        for n, color in enumerate(color_ndivides, start=1):
            ax.scatter([], [], s=2, c=color, label="div "+str(n)+" times")

        coll = PolyCollection([], lw=0.1, edgecolors='black')
        ax.add_collection(coll)
        ax.set_xlim(self.XMIN, self.XMAX)
        ax.set_ylim(self.YMIN, self.YMAX)
        ax.legend(fontsize=10, loc='upper right')

        def facecolor(frame, c_ext, n_cells):
            # RGBA per cell id. Default (0, 0, 0, 0) means fully transparent.
            fc = np.zeros((n_cells, 4))
            for tracks, color in zip(ndivide_tracks, color_ndivides):
                cell_ids = self.Tracking_index_inarea[frame][tracks].dropna()
                rows = cell_ids.values.astype(int)
                fc[rows, :3] = mcolors.to_rgb(color)
                # Increase the transparency upon each cell division.
                fc[rows, 3] = 1.5**(-self.Tracking_divcounts_inarea[cell_ids.index, frame])
            # Keep only cells with c_ext == False, in polygon order.
            return fc[np.where(np.logical_not(c_ext))[0].tolist()]

        def update(frame, pbar):
            vx, vy, _, _, _, cells, c_ext = self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame))
            coll.set_verts([tuple(zip(vx[cells[i]], vy[cells[i]]+y_offset)) for i in np.where(np.logical_not(c_ext))[0]])
            coll.set_facecolor(facecolor(frame, c_ext, len(cells)))
            pbar.update()

        pbar = tqdm(total=tend)
        anim = animation.FuncAnimation(fig, update, frames=range(tend), fargs=(pbar,), interval=10)
        anim.save(self.RESULT_DIR+"divs_inarea_normal.gif", writer="imagemagick", dpi=150)
        pbar.close()

    def animate_inarea_divisions_normal_center(self, title, tend=199):
        '''
        Movie of cell divisions inside the wing ROI, drawn as cell centres.

        Cells of lineages that exist at time 0 and divide once (green) or twice
        (orange) are coloured. The colour becomes more transparent after each
        division (alpha = 2 ** -divisions_so_far). Cells outside the ROI and
        in-ROI cells not in those lineages are drawn light grey. The ROI outline
        is filled in light blue. Saved as ``RESULT_DIR/divs_inarea_normal_.gif``.

        Parameters
        ----------
        title : object
            Unused; kept for backward compatibility.
        tend : int
            Number of frames rendered. Defaults to 199, the previously
            hard-coded value.

        Raises
        ------
        Exception
            If the lineage forest has not been calculated yet.
        '''
        if not self.calc_forest:
            raise Exception("calculate Forest before plot")

        # Track ids of every node of lineages (existing at time 0) that divide
        # exactly n times, for n = 1 and n = 2.
        ndivide_sets = [set().union(*[set(map(int, self.TrackingForest_inarea.get_treenodes(root))) for root in
                                      self.divide_and_initially_exist_inarea_tracks[self.div_counts_inarea == n]])
                        for n in range(1, 3)]
        # One list per group, used for BOTH the colours and the positions so
        # they stay aligned. Python sets are not valid NumPy/pandas indexers.
        ndivide_tracks = [list(s) for s in ndivide_sets]
        # Everything else: tracks that never divide, or divide 3+ times.
        no_divides = list(set(range(self.N_TRACKS)) - set().union(*ndivide_sets))

        color_ndivides = ["green", "orange"]

        fig = plt.figure(figsize=(12,8))
        ax = fig.add_subplot(111)

        def update(frame, pbar):
            ax.cla()
            ax.fill(self.ori_pos_x[frame][self.wingroi_tracks], self.YMAX - self.ori_pos_y[frame][self.wingroi_tracks], c=[0.6, 0.6, 1.0, 0.05])
            ax.scatter(self.ori_pos_x_outarea[frame], self.YMAX - self.ori_pos_y_outarea[frame], c="#EEEEEE", s=2, label="Out of ROI")
            ax.scatter(self.ori_pos_x_inarea[frame][no_divides], self.YMAX - self.ori_pos_y_inarea[frame][no_divides], c="#EEEEEE", s=2, label="NotAnalyzed")

            for n, (tracks, color) in enumerate(zip(ndivide_tracks, color_ndivides), start=1):
                cols = np.zeros((len(tracks), 4))
                cols[:, :3] = mcolors.to_rgb(color)
                # Float base: an integer base with negative integer exponents
                # raises ValueError in NumPy.
                cols[:, 3] = 2.0**(-self.Tracking_divcounts_inarea[tracks, frame])
                ax.scatter(self.ori_pos_x_inarea[frame][tracks], self.YMAX - self.ori_pos_y_inarea[frame][tracks], c=cols, s=2, label="Divide "+str(n)+" times")
            ax.set_xlim(self.XMIN, self.XMAX)
            ax.set_ylim(self.YMIN, self.YMAX)
            ax.legend(fontsize=10, loc='upper right')
            pbar.update()

        pbar = tqdm(total=tend)
        anim = animation.FuncAnimation(fig, update, frames=range(tend), fargs=(pbar,), interval=10)
        anim.save(self.RESULT_DIR+"divs_inarea_normal_.gif", writer="imagemagick", dpi=150)
        pbar.close()
    
    ### All of the following code was written by Joseph :)
        
        
    def set_proximal_distal_regions (self,hour): 
    # This function divides the wing ROI (range of interest) into 4 segments, separated by 3 boundaries. It returns an array of these boundaries (which is simply an array of x coordinate integers)
        
        hour = 0 
        x_vertices = [self.find_x_coordinate(int(track),hour) for track in self.wingroi_tracks] 
        regionsize = self.find_total_initial_cell_number()/4 
        min_x = min(x_vertices) 
        max_x = max(x_vertices)
        first_boundary = find_boundaries(regionsize, min_x, min_x+100, max_x) 
        second_boundary = find_boundaries(regionsize, first_boundary, first_boundary+100, max_x)
        third_boundary = find_boundaries(regionsize, second_boundary, second_boundary+100, max_x)
        boundary_array = [first_boundary, second_boundary, third_boundary]
        
        def find_boundaries(self, regionsize, x1, x2, max_x): 
        
        # helper function used by set_proximal_distal_regions, gradually increases boundary size until a certain cell number is reached, or the rightmost coordinate is reached. Note that this function sets boundaries based upon cell numbers at the start of the recording.
        
            while self.find_enclosed_cells(x1,x2) < regionsize or x2 == max_x:
                x2 += 20
            return x2
    
        return boundary_array
    
       
    def find_enclosed_cells(self,x1,x2): 
    
    # This function finds the number of enclosed cells within two x coordinates, at the beginning of the recording (it is used for the find_boundaries function within set_proximal_distal_regions)
    
        count = 0
        for track in self.divide_and_initially_exist_inarea_tracks:
            if self.find_x_coordinate(track,0)>x1 and self.find_x_coordinate(track,0)<=x2:
                count+=1
        return count
        

    def find_total_initial_cell_number(self): 
    
    # This function returns the total number of cells that exist within the wing ROI at the start of the recording
        
        return len(self.divide_and_initially_exist_inarea_tracks)
        

    def find_enclosed_cells_at_hour (self, boundary_array, hour): 
    
    # This function finds the number of enclosed cells within each boundary region (P1, P2, D1, D2) at the start of a specific hour
        
        time = hour * 12 
        P1total = 0
        P2total = 0
        D1total = 0
        D2total = 0 
        for track in self.Tracking_index.iloc[:,time].dropna():
            x_loc = self.find_x_coordinate(track,time)
            if x_loc < boundary_array[0]:
                P1total+=1
            elif x_loc >= boundary_array[0] and x_loc < boundary_array[1]:
                P2total+=1
            elif x_loc >= boundary_array[1] and x_loc < boundary_array[2]:
                D1total+=1
            elif x_loc >= boundary_array[2]:
                D2total+=1
        return [P1total,P2total,D1total,D2total]
        

    def find_regional_div_frequency(self, hour, count): 
    
    # This function finds the division frequency for the 4 regions of the wing. It takes in a count (e.g 1 division, 2 divisions, 3 divisions) and an hour, and returns an array of standardized counts representing the # of cells that dividided count # of times in given hour.

        P1 = 0 #first proximal region
        P2 = 0 #second proximal region
        D1 = 0 #first distal region
        D2 = 0 #second distal region

        boundary_array = self.show_wingroi_tracks(hour) # show_wingroi tracks returns the boundary array
        min_time = (hour-1)*12
        max_time = (hour)*12

        dividing_cells = self.Tracking_division_pair #CHECK STRUCTURE OF THIS DATA
        dividing_cells['Tracking ID'] = dividing_cells.apply(lambda row: self.find_trackID(row['before cell'], row['time']), axis=1)
        tracking_counts = dividing_cells['Tracking ID'].value_counts()

        if count == 1:
            single_occurrences = tracking_counts[tracking_counts== 1].index
            dividing_cells = dividing_cells[dividing_cells['before cell'].isin(single_occurrences)]
        elif count == 2:
            double_occurrences = tracking_counts[tracking_counts== 2].index
            dividing_cells = dividing_cells[dividing_cells['Tracking ID'].isin(double_occurrences)]
        elif count == 3:
            triple_occurrences = tracking_counts[tracking_counts== 3].index
            dividing_cells = dividing_cells[dividing_cells['Tracking ID'].isin(triple_occurrences)]


        dividing_cells = dividing_cells[(dividing_cells.iloc[:, 0] >= min_time) & (dividing_cells.iloc[:, 0] < max_time)] #data frame that filters original df to just rows within a given timerange
        dividing_cells = dividing_cells.iloc[:, [0, 1]] #further filters the dataframe to only include the first two columns

        total_array = self.find_enclosed_cells_at_hour(boundary_array,hour)
        
        P1total = total_array[0]
        P2total = total_array[1]
        D1total = total_array[2]
        D2total = total_array[3]
        
        for index, row in dividing_cells.iterrows():
            cell_time = row[0]
            cellID = row[1]

            trackID = self.find_trackID(cellID, cell_time)
            if trackID is not None:
                trackID = int(trackID)
                x_coordinate = self.find_x_coordinate(trackID, cell_time)
                if not math.isnan(x_coordinate):
                    if int(x_coordinate) < int(boundary_array[0]):
                        P1+=1
                    elif int(x_coordinate) >= int(boundary_array[0]) and int(x_coordinate) < int(boundary_array[1]):
                        P2+=1
                    elif int(x_coordinate) >= int(boundary_array[1]) and int(x_coordinate) < int(boundary_array[2]):
                        D1+=1
                    elif int(x_coordinate) > int(boundary_array[2]):
                        D2+=1
                        
        def standardize_counts(count,total):
            return count/total

        P1 = standardize_counts(P1,P1total)
        P2 = standardize_counts(P2,P2total)
        D1 = standardize_counts(D1,D1total)
        D2 = standardize_counts(D2,D2total)

        return [P1,P2,D1,D2]
            
    
    def plot_div_by_region(self, hour, count):
        """Bar-plot the per-region division frequency for one hour and save it.

        Values come from ``find_regional_div_frequency(hour, count)``. The figure is
        written to ``self.RESULT_DIR + "_division_count_by_region.png"``.
        NOTE(review): the file name does not include ``hour``/``count``, so each call
        overwrites the previous PNG.
        """
        count_values = self.find_regional_div_frequency(hour, count)
        group_labels = ['Proximal 1', 'Proximal 2', 'Distal 1', 'Distal 2']

        # Draw on a fresh figure so repeated calls do not stack bars on top of an
        # earlier plot, and the saved image only contains this call's data.
        fig, ax = plt.subplots()
        ax.bar(group_labels, count_values, color=['aquamarine', 'mediumseagreen', 'xkcd:sky blue', 'xkcd:eggshell'])
        ax.set_xlabel('Wing Region: Proximal --> Distal')
        ax.set_ylabel('Cell Division Count/Total Cells in  Region')
        ax.set_title("Cell Division Frequency During Hour " + str(hour) + ' [' + str(count) + ' division(s)]')
        fig.savefig(self.RESULT_DIR + "_division_count_by_region.png")


    def line_plot_all_regions(self, count):
        """Draw one line per wing region showing the division frequency per hour.

        For each hour ``0 .. floor(N_FRAMES / 12) - 1`` the four values returned by
        ``find_regional_div_frequency(hour, count)`` are collected. They are then
        plotted on the current matplotlib axes with the labels P1, P2, D1 and D2.
        Nothing is saved or shown here.
        """
        frames_per_hour = 60 / 5  # 5-minute frames
        n_hours = math.floor(self.N_FRAMES / frames_per_hour)
        hour_axis = range(0, n_hours)

        # One [P1, P2, D1, D2] entry per hour, gathered before any plotting.
        frequencies_by_hour = [self.find_regional_div_frequency(h, count) for h in hour_axis]

        for column, label in enumerate(('P1', 'P2', 'D1', 'D2')):
            plt.plot(hour_axis, [freqs[column] for freqs in frequencies_by_hour], label=label)

        plt.xlabel('Time (hours)')
        plt.ylabel('# Cell Divisions / # Cells in Region')
        plt.title('Regional Cell Division Frequency Over Time [' + str(count) + ' division(s)]')
        plt.xticks(hour_axis)
        plt.legend()
                        

    def find_neighbours (self, cellID, time): 
    
    # This function finds the neighbours (cells that share edges with given cell) of a cell at a given time 
        
        frame = (time-1)*(12)
        cells =  self.get_skeltonized_data(self.txtfilename+"_{:0>4}.txt".format(frame), False)[5]
        cell_edges = cells[int(cellID)]
        neighbour_cellIDs = []
        for edge in cell_edges:
            for i in range(0,len(cells)):  
                if edge in cells[i]:
                    neighbour_cellIDs.append(i)
        neighbour_cellIDs = [item for i, item in enumerate(neighbour_cellIDs) if item not in neighbour_cellIDs[:i]]
        return neighbour_cellIDs


    def find_group_1_and_2 (self, time, group_dictionary):

    # Given a dictionary of time : group1 and group2 ids, this function finds the group1 and group2 cellIDS at a particular timepoint

        boolean_array = group_dictionary[time]
        
        group1_ids = [self.find_cell_id(index, time) for index, value in enumerate(boolean_array) if value and not np.isnan(index)]
        group2_ids = [self.find_cell_id(index, time) for index, value in enumerate(boolean_array) if not value and not np.isnan(index)]

        return group1_ids, group2_ids

 

    def cells_that_divide_during_timerange(self, t1, t2): 
        
    # This function determines which cells divide between time t1 and time t2, not inclusive for t2. It returns an array of cellIDs.
       
        celldivs = self.Tracking_division_pair
        celldivs = celldivs.iloc[:, :2] #selects just the first two columns of the division dataframe
        celldivs = celldivs[(celldivs.iloc[:, 0] >= t1) & (celldivs.iloc[:, 0] < t2)] #selects just rows for time between t1 to t2 
        
        cellIDs = celldivs.iloc[:, 1].values
        return cellIDs
    

    def find_cell_id(self, track_id, time): 
        
    # This function finds the cellID of a cell, given its trackID and a timepoint

    # Check if the track ID exists in the DataFrame
        if track_id in self.Tracking_index.index:
        # Access the row corresponding to the track ID
            row = self.Tracking_index.loc[track_id]

        if time in row.index:
            # Retrieve the cell ID at the specified time
            cell_id = row[time]
            return cell_id

    # If no match is found, return None or an appropriate value
        return None
    
    
    def find_cluster_centres(self,t1,t2, group1, group2): 
    
    # This function returns arrays of group1 and group2 cells that divide between t1 and t2 (the cells that begin cluster formations)
        
        g1_cluster_centres = group1[::10]
        g1_cluster_centres = [cell for cell in group1 if cell in self.cells_that_divide_during_timerange(t1,t2)]
        g2_cluster_centres = group2[::10]
        g2_cluster_centres = [cell for cell in group2 if cell in self.cells_that_divide_during_timerange(t1,t2)]
        return g1_cluster_centres, g2_cluster_centres
    
    def find_clusters_during_timerange(self, t1, t2, g1_cluster_centres, g2_cluster_centres): 
    
    # This function finds the clusters surrounding given cluster centres at a particular timepoint
        
        group1_clusters = []
        group2_clusters = []

        def find_cluster_and_append(cell, t1, grouping):
            cluster = self.find_neighbours(cell, t1)
            if grouping:
                group1_clusters.append(cluster)
            elif not grouping: 
                group2_clusters.append(cluster)

        for cell in g1_cluster_centres:
            find_cluster_and_append(cell, t1, True)

        for cell in g2_cluster_centres:
            find_cluster_and_append(cell, t1, False)

        return group1_clusters, group2_clusters


    def find_cluster_proportions(self, cluster, groupby_group1, group1, group2):
        """Return the fraction of ``cluster`` made up of group-1 (or group-2) cells.

        If ``groupby_group1`` is truthy the reference group is ``group1``, otherwise
        ``group2``. The numerator is the number of distinct IDs shared by the cluster
        and the reference group; the denominator is ``len(cluster)``. An empty cluster
        raises ``ZeroDivisionError``.
        """
        cluster_size = len(cluster)
        reference_group = group1 if groupby_group1 else group2
        shared_ids = set(cluster).intersection(reference_group)
        return len(shared_ids) / cluster_size


    def plot_and_inspect_clusters(self):
        """Plot the mean cluster proportion of group 1 and group 2 for hours 1-16.

        For each hour ``i``:
        - split cells into the two groups with ``find_group_1_and_2``;
        - take the cells of each group that divide in ``[i, i + 1)`` as cluster
          centres;
        - expand each centre to its neighbours;
        - compute the fraction of each cluster that belongs to the centre's group.

        The per-hour mean of these fractions is plotted. Hours with no cluster get NaN
        (a gap in the line), so both series always have one value per hour tick. The
        previous version dropped such hours, which made ``plt.plot`` fail on
        mismatched lengths.

        NOTE(review): ``i`` is used as the group_dictionary key and as frame bounds for
        ``cells_that_divide_during_timerange``. ``find_neighbours`` instead converts it
        with ``(i - 1) * 12``. Confirm the intended time units.
        """
        hours = range(1, 17)
        group_dictionary = self.determine_groups_over_time()

        def mean_or_nan(values):
            # Empty hour -> NaN so the series stays aligned with `hours`.
            return sum(values) / len(values) if values else np.nan

        g1_proportions = []  # mean group-1 cluster proportion per hour
        g2_proportions = []  # mean group-2 cluster proportion per hour
        for i in hours:
            group1, group2 = self.find_group_1_and_2(i, group_dictionary)
            group1forcluster, group2forcluster = self.find_cluster_centres(i, i + 1, group1, group2)
            g1_clusters, g2_clusters = self.find_clusters_during_timerange(i, i + 1, group1forcluster, group2forcluster)

            g1_proportions.append(mean_or_nan(
                [self.find_cluster_proportions(cluster, True, group1, group2) for cluster in g1_clusters]))
            g2_proportions.append(mean_or_nan(
                [self.find_cluster_proportions(cluster, False, group1, group2) for cluster in g2_clusters]))

        time = np.arange(1, 17)
        plt.plot(time, g1_proportions, label='Group 1')
        plt.plot(time, g2_proportions, label='Group 2')

        plt.title('Cluster Proportions over Time for Group 1 and Group 2 Cells')
        plt.xlabel('Time')
        plt.ylabel('Cluster Proportion')

        plt.legend()

        plt.show()


    
    def find_trackID(self,cellID, time): 
        
    # This function finds the trackID of a cell given its cellID and a timepoint

    # Get the row indices where the cellID and time match
        matching_rows = self.Tracking_index[(self.Tracking_index[time] == cellID)].index

    # If there are no matching rows, return None
        if matching_rows.empty:
            return None

    # Return the first matching row index
        return matching_rows[0]
    
    def find_x_coordinate(self, trackID, time): 
    
    # This function finds the x coordinate of a specified trackID at a specified time
        
        df = self.ori_pos_x
        try:
            x_value = df.loc[trackID, time]
            return x_value
        except KeyError:
            return None
        
    def find_y_coordinate(self, trackID, time):
    
    # This function finds the y coordinate of a specified trackID at a specified time
        
        df = self.ori_pos_y
        try:
            y_value = df.loc[trackID, time]
            return y_value
        except KeyError:
            return None

    
    def get_cell_characteristic(self, dictionary, item_id, time_point):
        """Look up one cell's value in a per-frame table.

        Parameters
        ----------
        dictionary : dict
            Maps a frame index to a sequence indexed by cell ID.
        item_id : int or float
            Cell ID. Float IDs such as ``3.0`` are converted with ``int``.
        time_point : hashable
            Frame index used as the dictionary key.

        Returns
        -------
        The stored value, or ``None`` when the frame is missing or ``item_id`` is
        ``None``/NaN/negative/out of range. Negative IDs are rejected because list
        indexing would otherwise silently return a cell from the end of the list.
        """
        if time_point not in dictionary:
            return None
        item_list = dictionary[time_point]
        if item_id is None or pd.isna(item_id):
            return None
        if 0 <= item_id < len(item_list):
            return item_list[int(item_id)]
        return None
    
    def get_avg_edge_tension(self, dictionary, edges, timepoint):
        """Return the mean tension of the given edges at ``timepoint``.

        ``dictionary`` maps a frame index to a sequence of edge tensions indexed by
        edge ID. Edge IDs outside that sequence are ignored and excluded from the
        mean; previously they counted in the denominator as zero tension.

        Returns ``None`` when any of these holds:
        - ``timepoint`` is not in ``dictionary`` (previously ``UnboundLocalError``);
        - ``edges`` is ``None``;
        - none of the edges has a recorded tension (previously ``ZeroDivisionError``
          for an empty list).
        """
        if timepoint not in dictionary or edges is None:
            return None
        edge_tensions_list = dictionary[timepoint]
        tensions = [edge_tensions_list[edge] for edge in edges
                    if 0 <= edge < len(edge_tensions_list)]
        if not tensions:
            return None
        return sum(tensions) / len(tensions)

        

class Node:
    '''
    Node of a lineage tree: one vertex per track, with the two daughter
    tracks as children after a division.
    '''
    def __init__(self, ID):
        self.ID = ID  # unique ID of this node
        self.name = ID  # track name (several nodes may share the same name)
        self.height = 0  # height of the subtree rooted at this node
        self.time = -1  # time of the division at which this node branched off (-1 = unset)
        self.left = None  # left child
        self.right = None  # right child
        self.sum_of_div = 0  # number of divisions in the subtree rooted at this node

    def copy(self):
        '''
        Return a shallow copy whose ID gets a trailing "_" (so IDs stay unique);
        name, height, time, children and sum_of_div are carried over unchanged.
        '''
        new = Node(self.ID + "_")
        new.name = self.name
        new.height = self.height
        new.time = self.time
        new.left = self.left
        new.right = self.right
        new.sum_of_div = self.sum_of_div
        return new

    def __str__(self):
        '''
        For debugging: "(left, [ID, dep:height, time:time, divs:sum_of_div], right)".
        The format now has a placeholder for every value (sum_of_div previously had
        none, which made str() raise TypeError).
        '''
        return '(%s, [%s, dep:%d, time:%d, divs:%d], %s)' % (
            self.left, self.ID, self.height, self.time, self.sum_of_div, self.right)

  
class Forest:
    '''
    Forest: a collection of lineage trees.
    Trees are built by tracing divisions backwards in time (the ``merge`` method):
    seen backwards, a division looks like two tracks fusing into one.
    Each track is treated as one vertex (Node).
    ``roots`` holds the current tree roots as a dict of Node objects.
    '''
    
    def __init__(self, n_tracks):
        '''
        n_tracks: number of tracks; one single-node tree is created per track.
        roots: dict keyed by track name (str(i)). A key can come to refer to a
        different Node over time (see ``merge``), whereas Node.ID values are kept unique.
        '''
        self.roots = {str(i):Node(str(i)) for i in range(n_tracks)}
    
    
    def merge(self, after1_name, after2_name, before_name, time):
        '''
        Record one cell division: track ``before_name`` divides into ``after1_name``
        and ``after2_name`` at ``time``.
        Because trees are built backwards in time, the division is handled as a
        fusion: the two daughter roots become the children of the mother root and
        are removed from ``roots``.
        '''
        before = self.roots.get(before_name)
        after1 = self.roots.get(after1_name)
        after2 = self.roots.get(after2_name)
        
        # Ignore the division if any of the three tracks is no longer a root
        # (e.g. when the parent is ambiguous).
        # Per the original note, this means the division candidate listed last is
        # the one adopted (presumably because divisions are fed in reverse time
        # order; confirm in calc_Forest).
        if after1 != None and after2 != None and before != None:
        
            # The mother track has the same name as one of the daughter tracks
            if after1 == before:
                # Same time as the previous division = the division is ambiguous and
                # several candidate pairings were listed
                if after1.left != None and after1.left.time == time:
                    # Another candidate has already been adopted, so ignore this one
                    return
                # Same name, but the mother must be handled as a separate node
                before = after1.copy()
            elif after2 == before:
                if after2.left != None and after2.left.time == time:
                    return
                before = after2.copy()
            
            after1.time = time
            after2.time = time
            before.left = after1
            before.right = after2
            
            before.height = max(after1.height, after2.height) + 1
            before.sum_of_div = after1.sum_of_div + after2.sum_of_div + 1 #added by Joseph
            
            self.roots.pop(after1_name)
            self.roots.pop(after2_name)
            
            # When the mother shares its name with a daughter (and similar cases),
            # before_name was popped just above, so put the mother node back
            if not before_name in self.roots:
                self.roots[before_name] = before
  

    def reduce_roots(self):
        '''
        Remove trees of height 0 (roots that never became a mother in ``merge``)
        to slim down ``roots``.
        '''
        rootlist = list(self.roots.keys())
        for r in rootlist:
            if self.roots[r].height == 0:
                self.roots.pop(r)
 

    def render_tree(self, rootname, outdir):
        '''
        Visualize the tree rooted at ``roots[rootname]`` with graphviz.
        Edges are drawn between Node IDs (unique), the PNG is rendered into
        ``outdir``, its path is printed, and the image is returned as an IPython Image.
        '''   
        rootp = self.roots[rootname]
        T = Digraph("Tree_of_track"+rootname, format="png")
        
        # breadth-first search (BFS)
        q = deque()
        q.append(rootp)
        while q:
            v = q.popleft()
            if v.left != None:
                T.edge(v.ID, v.left.ID)
                q.append(v.left)
            if v.right != None:
                T.edge(v.ID, v.right.ID)
                q.append(v.right)
        path = T.render(directory=outdir).replace('\\', '/')
        print(path)
        return Image(path)
  

    def get_treenodes(self, rootname):
        '''
        List the vertices of the tree rooted at ``roots[rootname]`` using
        breadth-first search (BFS).
        Returns a set containing ``rootname`` and the ``name`` (not ID) of every
        descendant, so copies of the same track collapse into one entry.
        '''
        rootp = self.roots[rootname]
        treenodes = {rootname}
        
        # breadth-first search (BFS)
        q = deque()
        q.append(rootp)
        while q:
            v = q.popleft()
            if v.left != None:
                treenodes.add(v.left.name)
                q.append(v.left)
            if v.right != None:
                treenodes.add(v.right.name)
                q.append(v.right)
        return treenodes
    

    def enum_nodes_and_time(self, rootname): 
        '''
        List (name, time, depth, sum_of_div) for every vertex of the tree rooted at
        ``roots[rootname]`` using depth-first search (DFS).
        The root entry is fixed to (rootname, -1, 0, 0). Children are appended when
        their parent is popped, and the right child is expanded first because it is
        pushed last.
        '''
        rootp = self.roots[rootname]
        treenodes = [(rootname, -1, 0,0)]

        # depth-first search (DFS)
        s = []
        s.append((rootp, 0))
        while s:
            v, d = s.pop()
            if v.left != None:
                treenodes.append((v.left.name, v.left.time, d+1, v.left.sum_of_div))
                s.append((v.left, d+1))
            if v.right != None:
                treenodes.append((v.right.name, v.right.time, d+1, v.right.sum_of_div))
                s.append((v.right, d+1))
        return treenodes


In [None]:
# Analysis of the 140725_tp7 recording.
# The input/result folders and the wing-ROI tracks are named once here;
# edit them to analyse another recording.
input_dir_tp7 = "/Users/josephschull/Desktop/R_Joseph/InputData/140725_tp7/"
result_dir_tp7 = "/Users/josephschull/Desktop/R_Joseph/Results/140725_tp7/"
wingroi_tracks_tp7 = [5, 1, 23, 52, 50, 37, 86, 139, 220, 231, 248, 251, 239, 213, 117, 64]

TA = TrackingAnalyzer(input_dir_tp7, result_dir_tp7, XMIN=0, XMAX=4000, YMIN=0, YMAX=2000)
# TA.show_permanent_tracks()
# TA.animate_permanent_tracks()
TA.set_wingroi_tracks(wingroi_tracks_tp7)

# TA.calc_stress_tensor()
# TA.animate_local_cell_stress("")
# TA.animate_cellshape_tensor("")

# Classify tracks relative to the wing ROI, then build the lineage forest.
TA.calc_trackings_in_and_out_wingroi_area()
TA.calc_Forest()
# TA.plot_and_inspect_clusters()
print(len(TA.Tracking_index))
# TA.animate_group1and2()

# Further optional analyses (uncomment as needed):
# TA.plot_and_inspect_clusters()
# TA.show_inarea_division_times_density()
# TA.show_stress_and_division_angle_plots(time_offset=1)
# TA.show_stress_and_division_angle_plots(time_offset=2)
# TA.show_stress_and_division_angle_plots(time_offset=6)
# TA.show_stress_and_division_angle_plots(time_offset=12)
# TA.animate_inarea_divisions_normal("")
# TA.animate_inarea_divisions_normal_center("")

# Data Reorganisation
### The following code organises cell and tracking data into a large dataframe, 'ModelFrame', with one row per cell division

In [None]:
import numpy as np
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt


TA_practice = TrackingAnalyzer("/Users/josephschull/Desktop/R_Joseph/InputData/141007_tp4/", "/Users/josephschull/Desktop/R_Joseph/Results/141007_tp4/", 0, 4000, 0, 2000)
##### replace the above line with the desired data and filepath to be used
TA_practice.set_wingroi_tracks([5, 1, 23, 52, 50, 37, 86, 139, 220, 231, 248, 251, 239, 213, 117, 64])
### replace the above line with the desired wingroi tracks
TA_practice.calc_trackings_in_and_out_wingroi_area()
forest = TA_practice.calc_Forest()
TA_practice.calc_stress_tensor()

# One row per cell division: the division time and the dividing ("before") cell.
# .copy() so that adding columns below never modifies TA_practice.Tracking_division_pair.
ModelFrame = TA_practice.Tracking_division_pair.iloc[:, [0, 1]].copy()
ModelFrame = ModelFrame.rename(columns={ModelFrame.columns[0]: 'Time', ModelFrame.columns[1]: 'Cell ID'})

# Tracking ID of the dividing cell at the division time (None if the cell is not tracked)
ModelFrame['Tracking ID'] = ModelFrame.apply(lambda row: TA_practice.find_trackID(row['Cell ID'], row['Time']), axis=1)

# Number of earlier divisions of the same track (rows with the same Tracking ID and a
# strictly smaller Time). Within a track, rank(method='min') - 1 is exactly that count;
# rows without a Tracking ID get 0.
ModelFrame['# Prior Track Divisions'] = (
    ModelFrame.groupby('Tracking ID')['Time']
    .rank(method='min')
    .sub(1)
    .fillna(0)
    .astype(int)
)

# NOTE(review): for each track this counts the divisions later than the track's first
# listed division, i.e. (number of divisions - 1) when division times are distinct.
# Confirm this is the intended "total".
# transform() returns one value per row. groupby(...).apply(...) returned one value per
# Tracking ID, which was then aligned on the row index and landed on the wrong rows.
ModelFrame['# of Total Divisions'] = ModelFrame.groupby('Tracking ID')['Time'].transform(lambda x: (x > x.iloc[0]).sum())


def _division_position(finder, track_id, time):
    # Position of a dividing track at the division time; untracked divisions have none.
    if pd.isna(track_id):
        return None
    return finder(int(track_id), int(time))

# The track ID comes from the 'Tracking ID' column (positional row[4] was
# '# of Total Divisions').
# NOTE(review): find_x_value/find_y_value are not defined in the visible TrackingAnalyzer
# code, so the position lookups find_x_coordinate/find_y_coordinate are used instead.
ModelFrame['X coordinate'] = [_division_position(TA_practice.find_x_coordinate, track, t)
                              for track, t in zip(ModelFrame['Tracking ID'], ModelFrame['Time'])]
ModelFrame['Y coordinate'] = [_division_position(TA_practice.find_y_coordinate, track, t)
                              for track, t in zip(ModelFrame['Tracking ID'], ModelFrame['Time'])]


# The following are dictionaries, with key == timepoint and values == desired quantity (e.g array of cell edges, or value of cell pressure)

Pressure = {}
Cell_Edges = {} 
Tensions = {}
CellShapeTensor = {}
CellStressTensor = {}

# this for loop iterates through all frames and fills in the dictionary with the desired values

for frame in range(TA_practice.N_FRAMES):
    TPdata = TA_practice.get_TP_data(TA_practice.TPfilename+"_{:0>4}.txt".format(frame))
    
    cell_pressure = TPdata[2]
    Pressure[frame] = cell_pressure

    edge_tensions = TPdata[1]
    Tensions[frame] = edge_tensions

    cell_edges = TA_practice.get_skeltonized_data(TA_practice.txtfilename+"_{:0>4}.txt".format(frame), False)[5]
    Cell_Edges[frame] = cell_edges

    cellshape_tensor = np.load(TA_practice.RESULT_DIR + '/cellshape_tensor/' + str(frame) + '.npy')
    CellShapeTensor[frame] = cellshape_tensor

    cellstress_tensor = np.load(TA_practice.RESULT_DIR + '/stress_tensor/' + str(frame) + '.npy')
    CellStressTensor[frame] = cellstress_tensor


# Keep only divisions of tracked cells; .copy() so the assignments below act on an
# independent frame rather than a slice.
ModelFrame = ModelFrame[ModelFrame['Tracking ID'].notna()].copy()

ModelFrame['Cell Pressure'] = ModelFrame.apply(lambda row: TA_practice.get_cell_characteristic(Pressure, row['Cell ID'], row['Time']), axis=1)
ModelFrame['Cell Edges'] = ModelFrame.apply(lambda row: TA_practice.get_cell_characteristic(Cell_Edges, row['Cell ID'], row['Time']), axis=1)
ModelFrame['Cell Edges Avg Tension'] = ModelFrame.apply(lambda row: TA_practice.get_avg_edge_tension(Tensions, row['Cell Edges'], row['Time']), axis=1)
ModelFrame['Cell Shape Tensor'] = ModelFrame.apply(lambda row: TA_practice.get_cell_characteristic(CellShapeTensor, row['Cell ID'], row['Time']), axis=1)
ModelFrame['Cell Stress Tensor'] = ModelFrame.apply(lambda row: TA_practice.get_cell_characteristic(CellStressTensor, row['Cell ID'], row['Time']), axis=1)

# Number of edges of the dividing cell (NaN when its edge list could not be looked up)
ModelFrame['# of Edges'] = ModelFrame['Cell Edges'].apply(lambda x: len(x) if x is not None else np.nan)

get_first_item = lambda arr: arr[0][0]
get_second_item = lambda arr: arr[0][1]
get_third_item = lambda arr: arr[1][0]
get_fourth_item = lambda arr: arr[1][1]

ModelFrame = ModelFrame.dropna(subset=['Cell Shape Tensor'])
ModelFrame = ModelFrame.dropna(subset=['Cell Stress Tensor'])

folder_path = '/Users/josephschull/Desktop/Sugimura Lab/ML model' # change to desired filepath
ModelFrame.to_csv(folder_path + '/ModelFrame', index = False) # change to desired name
print(ModelFrame)

# Extract the four components of each 2x2 Cell Shape Tensor / Cell Stress Tensor into separate columns

ModelFrame['Cell Shape Tensor [0,0]'] = ModelFrame['Cell Shape Tensor'].apply(get_first_item)
ModelFrame['Cell Shape Tensor [0,1]'] = ModelFrame['Cell Shape Tensor'].apply(get_second_item)
ModelFrame['Cell Shape Tensor [1,0]'] = ModelFrame['Cell Shape Tensor'].apply(get_third_item)
ModelFrame['Cell Shape Tensor [1,1]'] = ModelFrame['Cell Shape Tensor'].apply(get_fourth_item)

ModelFrame['Cell Stress Tensor [0,0]'] = ModelFrame['Cell Stress Tensor'].apply(get_first_item)
ModelFrame['Cell Stress Tensor [0,1]'] = ModelFrame['Cell Stress Tensor'].apply(get_second_item)
ModelFrame['Cell Stress Tensor [1,0]'] = ModelFrame['Cell Stress Tensor'].apply(get_third_item)
ModelFrame['Cell Stress Tensor [1,1]'] = ModelFrame['Cell Stress Tensor'].apply(get_fourth_item)

ModelFrame.to_csv(folder_path + '/ModelFrame_2D_extracted', index = False) # change to desired name
ModelFrame


# Multiple Linear Regression Model
### The following code uses the 'ModelFrame' data to predict '# Total Divisions' with a multiple linear regression (MLR) model

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt

# Multiple linear regression predicting the total number of cell divisions.
model_input_path = '/Users/josephschull/Desktop/Sugimura Lab/ML model/model_input_dataframe'  # change to the desired input file
mlr_features = ['Time', 'x coordinate', 'y coordinate',
                'Cell Edges Avg Tension', 'Cell Pressure']
mlr_target = '# Total Divisions'

ModelFrame = pd.read_csv(model_input_path)
X = ModelFrame[mlr_features]
y = ModelFrame[mlr_target]

# 60 % train / 40 % test with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)
regression_model = LinearRegression().fit(X_train, y_train)

y_pred = regression_model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

# Predicted vs. actual on the held-out set
plt.scatter(y_test, y_pred)
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Multiple Linear Regression Model')
plt.show()

print('Model Performance:')
print('Mean Squared Error (MSE):', mse)
print('Coefficient of Determination (R^2):', r2)

# Random Forest Model
### The following code uses the 'ModelFrame' data to predict '# Total Divisions' with a random forest regressor whose hyperparameters are tuned by grid search

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

ModelFrame = pd.read_csv('/Users/josephschull/Desktop/Sugimura Lab/ML model/model_input_dataframe')  

# Define the independent variables (X) and dependent variable (y)
X = ModelFrame[['Time', 'x coordinate', 'y coordinate', '# of Prior Divisions',
                'Cell Edges Avg Tension', 'Cell Pressure']]
y = ModelFrame['# Total Divisions']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)

# Fixed seed so the grid search and the reported scores are reproducible.
rf = RandomForestRegressor(random_state=42)

param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}

grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, scoring='r2')
grid_search.fit(X_train, y_train)
best_params = grid_search.best_params_

# With refit=True (the default), GridSearchCV has already refitted the best
# configuration on the whole training set. Reuse that model instead of training a
# new, unseeded forest whose random trees would differ from the one selected.
rf_best = grid_search.best_estimator_

y_pred = rf_best.predict(X_test)

r2 = r2_score(y_test, y_pred)
mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)

plt.scatter(y_test, y_pred)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('Actual vs. Predicted')
plt.show()

print("R^2 Score: {:.2f}".format(r2))
print("Mean Squared Error (MSE): {:.2f}".format(mse))
print("Mean Absolute Error (MAE): {:.2f}".format(mae))
