# In [None]:
'''
This algorithm was constructed to estimate migration rates of cancer cells within a tumor, based on a phylogeographic 
reconstruction using single-cell DNA-seq or single-cell CNV data. It is a work in progress. Although it yields reliable
results when given simulated data and data from one actual cancer case, modifications to increase its accuracy and
flexibility are required. It is not yet clear whether the behavior of actual cancer cells matches the underlying model
sufficiently for this algorithm to generate useful results, and the limitations of the method have not yet been fully
explored.

(March 25, 2020)
'''

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from scipy.stats import binned_statistic
import pandas as pd
import numpy as np
import random
import math
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
import copy
from sys import exit
from statistics import mean, median
from scipy import misc
from scipy.special import comb as sp_comb


class Phylogeny:
    """
    A phylogeny together with the model parameters attached to it.

    All attributes are plain public fields that the reconstruction
    functions in this file read and overwrite directly.
    """

    def __init__(self, phylogeny_name):
        """
        phylogeny_name: string, returned by str() for this phylogeny.
        """
        self.name = phylogeny_name

        # lookup tables and naming counters
        self.n2n = {}          # maps names onto nodes
        self.n2e = {}          # maps names onto edges
        self.n_ctr = 0         # counter for naming nodes
        self.e_ctr = 0         # counter for naming edges

        # likelihood and migration model
        self.L = 0.0           # log-likelihood for this phylogeny
        self.psi = [1.0]       # migration rates
        self.psi_num = 1       # number of different migration rates (states)
        # transitions among migration states; get_L adds these entries to
        # log-likelihoods, so they are on a log scale
        self.em_trans = pd.DataFrame(np.array([[0.0, -50.0], [-50.0, 0.0]]))
        self.migr_priors = None

        # mutation model
        self.mutn_num = 1      # number of types of mutational changes possible
        self.mutn_rate = [0.0]     # mutation rates, one per type of mutation

        self.root = None
        self.angles = None     # filled in by initialize_angles_df
        self.omni = 0.0        # log of the omnibus p-value (see initialize_angles_df)

    def __str__(self):
        return self.name
    
class Node:
    """
    A vertex of the phylogeny.

    A freshly created node is detached: it has no ancestral edge and no
    descendant edges, and sits at the origin at time 0.
    """

    def __init__(self, node_name):
        """
        node_name: string, returned by str() for this node.
        """
        self.name = node_name

        # tree links
        self.anc_edge = None       # edge leading to the ancestor of this node
        self.desc_edges = []       # edges leading to the descendants of this node

        # position in space and time
        self.location = (0.0, 0.0)     # spatial x,y coordinates of this node
        self.time = 0.0                # time before present

        self.flag = False      # for tracking which nodes have been updated during a traversal
        self.duct = 'X'        # NOTE(review): placeholder label; its meaning is not visible in this file section

    def __str__(self):
        return self.name
    
class Edge:
    """
    A branch joining an ancestral node to a descendant node.

    The duration, length and direction of the branch are derived from the
    two nodes at construction time; they are not kept in sync automatically
    if the nodes move afterwards.
    """

    def __init__(self, p, edge_name, anc_node, desc_node):
        """
        p: phylogeny; only p.psi_num (number of migration states) is read
        edge_name: string, returned by str() for this edge
        anc_node, desc_node: the nodes at the two ends of the branch
        """
        n_states = p.psi_num
        anc_xy = anc_node.location
        desc_xy = desc_node.location

        self.name = edge_name
        self.anc = anc_node
        self.desc = desc_node

        # the duration of time spent on this branch
        self.time = anc_node.time - desc_node.time

        # straight-line distance between the two end points
        dx = anc_xy[0] - desc_xy[0]
        dy = anc_xy[1] - desc_xy[1]
        self.dist = (dx**2 + dy**2)**0.5

        # per-state working values, one entry per migration state
        self.bk_anc = [0.0] * n_states      # b-sub-k, backwards algorithm values, headed toward root
        self.bk_desc = [0.0] * n_states     # b-sub-k, backwards algorithm values, headed toward tips
        self.marginals = [0.0] * n_states   # marginal probabilities for each migration state for this edge
        self.normalized_marginals = [1.0/n_states for _ in range(n_states)]    # normalized to sum to 1

        self.mutations = 0

        # note that this is from the perspective of the ancestral node!
        self.angle = get_angle(anc_xy, desc_xy)

    def __str__(self):
        return self.name
    

class Control_Panel(object):
    """
    For passing settings and initial states to the reconstruction algorithms.
    """
    def __init__(self):
        # model parameters handed to a phylogeny by clear_phylogeny
        self.psi = [0.0, 0.0]  # migration rates
        self.em_trans = pd.DataFrame(np.array([[0.0, -50.0],[-50.0,0.0]]))  # transition probabilities
        self.migr_priors = [math.log(0.5), math.log(0.5)]
        self.mutation_rate = [0.0]

        # initialization settings
        self.null_location = [0.0, 0.0]
        self.null_time = 0.0
        self.initialize_locations = 'random'
        self.mutation_pseudocount = 1
        self.min_edge_time = 0.1

        # convergence thresholds and step sizes
        self.location_epsilon = 0.0001
        self.location_step = 0.5
        self.L_epsilon = 0.001
        self.initial_percentage_epithelial = 95
        self.psi_epsilon = 0.001
        self.time_epsilon = 0.001
        self.iter_limit = 500
        self.time_step = 0.005

        # standard deviations of the proposal steps drawn by get_step
        self.space_sigma = 200.0
        self.time_sigma = 0.1
        self.sigma_T = 1.0
        self.sigma_T_step = 0.95

        # schedule for the exponent used when weighting proposals
        self.exp_T_init = 1.0
        self.exp_T = self.exp_T_init
        self.exp_T_stop = -3.0
        self.exp_T_step = (self.exp_T_init - self.exp_T_stop)/self.iter_limit

        self.omni = False
        self.acceptance = 'top'
        self.correction = 0.0
        self.corr2 = 1.0
        self.now = False
        self.set_psi_num = 1

    def __str__(self):
        # __str__ must return a string; returning None made str(cp) and
        # print(cp) raise TypeError.
        return 'Control_Panel(acceptance={!r}, omni={!r}, iter_limit={!r})'.format(
            self.acceptance, self.omni, self.iter_limit)

    def get_step(self, axis):
        """
        Draw a non-negative random step size.

        axis: -1 requests a step in time (scale time_sigma*sigma_T); any
              other value requests a step in space (scale space_sigma*sigma_T).

        Returns the absolute value of a draw from a zero-mean normal
        distribution with the corresponding standard deviation.
        """
        if axis == -1:
            return (abs(np.random.normal(0, self.time_sigma*self.sigma_T)))

        else:
            return (abs(np.random.normal(0, self.space_sigma*self.sigma_T)))

    
class Proposal:
    """
    For proposing node movements in space or time during optimization.

    A proposal starts out as a snapshot of the node's current state and of
    the durations and lengths of the edges touching it; get_proposal then
    overwrites the fields that the move would change.
    """

    def __init__(self, ctr, p2, node):
        """
        ctr: counter used to build the proposal's name ('pr' + ctr)
        p2: phylogeny (not read here)
        node: node to be moved; it must have at least two descendant edges
        """
        self.name = 'pr' + str(ctr)
        self.node = node
        self.location = node.location
        self.time = node.time

        # snapshot of the two descendant edges
        first_desc = node.desc_edges[0]
        self.de0_time = first_desc.time
        self.de0_dist = first_desc.dist
        second_desc = node.desc_edges[1]
        self.de1_time = second_desc.time
        self.de1_dist = second_desc.dist

        # snapshot of the ancestral edge; the root has none
        above = node.anc_edge
        if above is None:
            self.ae_time = None
            self.ae_dist = None
        else:
            self.ae_time = above.time
            self.ae_dist = above.dist

        self.angles = []
        self.omni = 0.0

        self.L = 0.0
        

def text_to_array(FILENAME):
    """
    Read a text file and return its lines as a list of strings.

    FILENAME: path of the file to read.

    Lines are separated on '\n' and returned without the separator. Empty
    lines inside the file are kept; a single empty string produced by a
    trailing newline (or by an empty file) is dropped, so an empty file
    yields [].
    """
    # The context manager closes the file even if reading fails; the
    # original left the handle open.
    with open(FILENAME, 'r') as in_file:
        lines = in_file.read().split('\n')

    # split() always returns at least one element, so lines[-1] is safe.
    if lines[-1] == '':
        lines.pop()

    return lines
                
def accept_proposal(p2, cp, pr):
    """
    Commit proposal pr to the phylogeny.

    Copies the proposed location and time onto pr.node, and the proposed
    durations and lengths onto its two descendant edges and (when present)
    its ancestral edge. When cp.omni is True the proposal's angle table
    (sliced copy) and omnibus score replace those stored on p2.

    Returns p2.
    """
    target = pr.node
    target.location, target.time = pr.location, pr.time

    first_desc = target.desc_edges[0]
    first_desc.time, first_desc.dist = pr.de0_time, pr.de0_dist

    second_desc = target.desc_edges[1]
    second_desc.time, second_desc.dist = pr.de1_time, pr.de1_dist

    above = target.anc_edge
    if above is not None:
        above.time, above.dist = pr.ae_time, pr.ae_dist

    if cp.omni == True:
        p2.angles = pr.angles[:]
        p2.omni = pr.omni

    return p2


def branch_dist(a_locn, b_locn):
    """
    Return the straight-line distance between two locations.

    a_locn, b_locn: sequences whose first two entries are x and y.
    """
    dx = a_locn[0] - b_locn[0]
    dy = a_locn[1] - b_locn[1]
    return (dx**2 + dy**2)**0.5
         

def clear_phylogeny (p, cp):
    """
    Return a reset deep copy of phylogeny p, named 'Ph2'.

    The copy gets a zero log-likelihood, takes its migration rates,
    transition table, priors and mutation rate from the control panel cp,
    and has the per-state working values of every edge zeroed. The input
    phylogeny p is left untouched.

    Note that only migr_priors is copied; psi, em_trans and mutn_rate refer
    to the very same objects held by cp.
    """
    fresh = copy.deepcopy(p)
    fresh.name, fresh.L = 'Ph2', 0.0

    # parameters supplied by the control panel
    fresh.psi = cp.psi
    fresh.em_trans = cp.em_trans
    fresh.migr_priors = cp.migr_priors[:]
    fresh.mutn_rate = cp.mutation_rate

    # wipe the per-state values on every edge
    n_states = fresh.psi_num
    for branch in fresh.n2e.values():
        branch.bk_anc = [0.0] * n_states
        branch.bk_desc = [0.0] * n_states
        branch.marginals = [0.0] * n_states

    return fresh


def emission(cp, e, psi):
    """
    Log-probability of the distance covered along edge e for one value of psi.

    cp: control panel (not read here)
    e: edge; e.dist is the distance between its ancestral and descendant
       nodes and e.time is its duration
    psi: a single migration rate

    Returns log(d) - log(t) - 2*log(psi) - d**2 / (2*t*psi**2), i.e. the log
    of (d / (t*psi**2)) * exp(-d**2 / (2*t*psi**2)), with d = e.dist and
    t = e.time. An edge of zero length returns 0.0 instead.
    """
    if e.dist == 0:
        return 0.0

    log_dist = math.log(e.dist)
    log_scale = math.log(e.time) + 2*math.log(psi)
    decay = (e.dist**2)/(2*e.time*(psi**2))

    return log_dist - log_scale - decay


def emission_hypothetical(cp, distance, time, psi):
    """
    Log-probability of covering `distance` in `time` for one value of psi.

    Same formula as emission(), but taking the distance and duration as
    plain numbers rather than reading them from an edge, and without the
    special case for a zero distance.

    cp: control panel (not read here)
    """
    log_dist = math.log(distance)
    log_scale = math.log(time) + 2*math.log(psi)
    decay = (distance**2)/(2*time*(psi**2))

    return log_dist - log_scale - decay


def estimate_parameters(cp, p2):
    """
    Re-estimate the migration parameters of phylogeny p2 from the current
    edge marginals.

    For every edge the (log) marginals are converted into
    normalized_marginals, psi is re-estimated through estimate_psi_v2, and
    new state priors and a new transition table are derived from the
    normalized marginals.

    Returns (p2, new_priors, em_trans_new):
        new_priors: list with one log-proportion per migration state
        em_trans_new: DataFrame of log transition values, each row
                      normalized before the log is taken

    The two-state conversion and the 2x2 accumulator below are written for
    at most two migration states.
    """
    # NOTE(review): sums and wts are initialized but not used again in this function.
    sums = [0.0 for i in range(p2.psi_num)]
    wts = [0.0 for i in range(p2.psi_num)]     # these are the sums across all edges for the normalized marginals, i.e., 
                         #    when normalized, they are the proportion of the time that the lineages
                         #    are probably in each state
    

    for e in p2.n2e.values():
        
        if p2.psi_num > 1:
            # mdiff is the ratio of the two marginals; it becomes the pair
            # [mdiff/(mdiff+1), 1/(mdiff+1)], which sums to 1
            try:
                mdiff = math.exp(e.marginals[0]-e.marginals[1])
                
            # NOTE(review): a bare except also traps unrelated errors, and the
            # interpreter is then terminated with exit status 0.
            except:
                print (e.marginals[0], e.marginals[1])
                exit(0)
                
            mpr = [mdiff/(mdiff+1), 1/(mdiff+1)]
            # keep both entries strictly positive (get_L_local takes their log)
            if (mpr[0]==0):
                mpr[0]=1/10**10
                mpr[1]=1 - mpr[0]

            if (mpr[1]==0):
                mpr[1]=1/10**10
                mpr[0]=1 - mpr[1]
                
            e.normalized_marginals = mpr     # note that these are likelihoods, not log-likelihoods!
                
        else:
            e.normalized_marginals = [1.0]
            
        
        
        # NOTE(review): this loop sits inside the loop over edges, so psi is
        # re-estimated once per edge while later edges still carry their old
        # normalized_marginals. Confirm whether it was meant to run once,
        # after the edge loop.
        for i in range(p2.psi_num):
            p2 = estimate_psi_v2(cp, p2, i)
   
    # new priors: share of the summed normalized marginals held by each state, as logs
    temp_np = [sum([e.normalized_marginals[i] for e in p2.n2e.values()]) for i in range(p2.psi_num)]
    new_priors = [math.log(temp_np[i]/sum(temp_np)) for i in range(p2.psi_num)]
    
    # accumulate, over every internal node, the product of the state weights
    # on the edge above the node (or the new priors at the root) and on each
    # of its two descendant edges
    em_trans_temp = pd.DataFrame(np.array([[0.0, 0.0],[0.0,0.0]]))
    em_trans_temp_wts = [0.0, 0.0]
    internal_nodes = [a for a in p2.n2n.values() if len(a.desc_edges) > 0]
    for n in internal_nodes:
        states_k = [de.normalized_marginals for de in n.desc_edges]
        if n.anc_edge == None:
            states_j = [math.exp(new_priors[i]) for i in range(len(new_priors))]
            
        else:
            states_j = n.anc_edge.normalized_marginals
            
        for anc_state in range(p2.psi_num):
            for desc_state in range(p2.psi_num):
                for de in range(2):
                    em_trans_temp.iloc[anc_state, desc_state]+=(states_j[anc_state]*states_k[de][desc_state])
                    em_trans_temp_wts[anc_state]+=states_j[anc_state]
                    
    # divide each row by its accumulated weight ...
    em_trans_temp2 = pd.DataFrame(np.array([[(em_trans_temp.loc[j, k]/em_trans_temp_wts[j]) 
                                           for k in range(p2.psi_num)] for j in range(p2.psi_num)]))
    
    # ... then normalize each row to sum to 1 and take logs
    em_trans_new = pd.DataFrame(np.array([[math.log(em_trans_temp2.loc[j,k]/sum(em_trans_temp2.loc[j,:]))
                                           for k in range(p2.psi_num)] for j in range(p2.psi_num)]))
    
    return (p2, new_priors, em_trans_new)


def estimate_psi_v2(cp, p2, i):
    """
    Update p2.psi[i] by a one-dimensional step-halving search.

    The score of a trial value is the mean of emission() over all edges,
    weighted by each edge's normalized_marginals[i]. Starting from the
    current p2.psi[i], the trial value first moves down by 1; it keeps
    moving in the same direction while the score improves, and otherwise
    reverses direction with half the step. The search stops as soon as two
    successive scores differ by less than 0.001.

    If the weights sum to zero, p2.psi[i] is set to 1.0.

    Returns p2.

    NOTE(review): nothing here keeps the trial value positive, and
    emission() takes its logarithm.
    """

    def weighted_totals(psi_value):
        # weighted sum of the edge scores, and the sum of the weights
        score_sum = 0.0
        weight_sum = 0.0
        for branch in p2.n2e.values():
            weight = branch.normalized_marginals[i]
            score_sum += weight*emission(cp, branch, psi_value)
            weight_sum += weight
        return score_sum, weight_sum

    tolerance = 0.001
    direction, step = -1, 1
    trial = p2.psi[i]

    score_sum, weight_sum = weighted_totals(trial)
    try:
        previous = score_sum/weight_sum
    except ZeroDivisionError:
        p2.psi[i] = 1.0
        return p2

    trial += direction*step
    while True:
        score_sum, weight_sum = weighted_totals(trial)
        try:
            current = score_sum/weight_sum
        except ZeroDivisionError:
            p2.psi[i] = 1.0
            return p2

        if abs(current - previous) < tolerance:
            p2.psi[i] = trial
            return p2

        # no improvement: turn around and halve the step
        if not current > previous:
            direction *= -1
            step /= 2

        trial += direction*step
        previous = current

def estimate_psi_v3(cp, p2, i):
    """
    Update p2.psi[i] by a step-halving search on the full log-likelihood.

    Works on deep copies of p2. Each round scores the trial phylogeny with
    get_L and prints the previous and current log-likelihood; the trial
    psi[i] keeps moving in the same direction (starting upward, step 10)
    while the log-likelihood improves, and otherwise reverses with half the
    step. The search stops when two successive log-likelihoods differ by
    less than 0.1, and the trial psi[i] is then copied into p2.

    Returns p2.
    """
    tolerance = 0.1
    direction, step = 1, 10
    previous_L = -10000000
    trial = copy.deepcopy(p2)

    while True:
        scored = get_L(cp, trial)
        current_L = scored.L
        print (previous_L, current_L)

        if abs(current_L - previous_L) < tolerance:
            p2.psi[i] = scored.psi[i]
            return p2

        trial = copy.deepcopy(scored)

        # no improvement: turn around and halve the step
        if not current_L > previous_L:
            direction *= -1
            step /= 2

        trial.psi[i] += direction*step
        previous_L = current_L
        
            
def estimate_times (p2, cp):
    """
    Initialize node times and edge durations from mutation counts.

    Tips are placed at time 0. Walking up from the tips, each internal node
    gets the larger of (a) the latest descendant time plus
    cp.min_edge_time and (b) half of the sum, over its descendant edges, of
    the edge's mutations plus cp.mutation_pseudocount plus the descendant's
    time. All node times are then divided by the root's time, which is also
    stored in p2.mutn_rate, and every edge duration is recomputed as
    ancestor time minus descendant time.

    Returns p2.
    """

    def assign_time(node):
        # tips sit at the present
        if not node.desc_edges:
            node.time = 0.0
            return node.time

        total = 0.0
        for branch in node.desc_edges:
            total += float(branch.mutations + cp.mutation_pseudocount)
            total += float(assign_time(branch.desc))

        candidates = [branch.desc.time + cp.min_edge_time for branch in node.desc_edges]
        candidates.append(total/2.0)
        node.time = max(candidates)
        return node.time

    # start from a clean slate
    for node in p2.n2n.values():
        node.time = cp.null_time

    for branch in p2.n2e.values():
        branch.time = 0.0

    root_time = assign_time(p2.root)

    # rescale by the root's time
    for node in p2.n2n.values():
        node.time /= root_time

    p2.mutn_rate = root_time

    for branch in p2.n2e.values():
        branch.time = branch.anc.time - branch.desc.time

    return p2


def fbb2(cp, p):
    """
    Compute the marginal (log) state values of every edge of phylogeny p.

    get_L is run first (it sets p.L and, through to_root, the bk_anc values
    this function reads); the tree is then walked from the root toward the
    tips to fill in bk_desc and marginals on every edge.

    Returns p.
    """

    def to_tips(cp, p, edge, fd_node):
        """
        Fill in bk_desc and marginals for `edge`, then recurse into the
        edges below it.

        fd_node: per-state log values arriving at the ancestral node of
                 `edge` from the rest of the tree, before the contribution
                 of the sibling of `edge` is added.
        """
        # Work on a private copy. The original added the sibling's
        # contribution to the caller's list in place; because the same list
        # (the parent's bk_desc) was handed to both children, the second
        # child's values also contained the first child's sibling term,
        # i.e. its own subtree.
        fd_node = list(fd_node)

        sibling = get_sibling(edge, edge.anc.desc_edges)
        for j in range (p.psi_num):
            fd_node[j] += mult_logaddexp([(sibling.bk_anc[k]+p.em_trans.loc[j][k]) for k in range(p.psi_num)])

        edge.bk_desc = [0.0 for i in range(p.psi_num)]

        for l in range(p.psi_num):
            edge.bk_desc[l] += mult_logaddexp([(fd_node[j]+p.em_trans.loc[j][l]) for j in range(p.psi_num)])

        edge.marginals = [(edge.bk_anc[i] + edge.bk_desc[i] - p.L) for i in range(p.psi_num)]

        # a tip edge has nothing below it
        if edge.desc.desc_edges == []:
            return (p)

        # add this edge's own emission before passing the values down
        for k in range(p.psi_num):
            edge.bk_desc[k] += emission(cp, edge, p.psi[k])

        for new_edge in edge.desc.desc_edges:
            p = to_tips(cp, p, new_edge, edge.bk_desc)

        return (p)

    p = get_L (cp, p)

    for new_edge in p.root.desc_edges:
        p = to_tips(cp, p, new_edge, p.migr_priors[:])

    return (p)


def get_angle (a,b):
    """
    Direction of the line from point a to point b, seen from a.

    a, b: sequences whose first two entries are x and y.

    Returns the angle in radians, wrapped by modulo 2*pi so that it lies
    between 0 and 2*pi.
    """
    full_turn = 2*math.pi
    dx = b[0] - a[0]
    dy = b[1] - a[1]
    return np.arctan2(dy, dx) % full_turn


def angle_between (desc_angle,anc_angle):
    """
    Return desc_angle minus anc_angle, wrapped by modulo 2*pi.
    """
    difference = desc_angle - anc_angle
    return difference % (2*math.pi)
    

def get_distances (location, neighbor_locations):
    """
    Return the distance from `location` to each entry of
    `neighbor_locations`, in the same order, as a list.
    """
    distances = []
    for neighbor in neighbor_locations:
        distances.append(branch_dist(location, neighbor))

    return distances


def get_L(cp, p):
    """
    Compute the total log-likelihood of phylogeny p and store it in p.L.

    Starting from a copy of the state priors, the values returned by
    to_root for each edge below the root are combined with the transition
    table and added, state by state; p.L is the log-sum-exp of the result
    over the states.

    Returns p.
    """
    root_scores = p.migr_priors[:]

    for child_edge in p.root.desc_edges:
        # bk_anc for that edge
        p, child_scores = to_root(cp, p, child_edge)
        states = range(p.psi_num)

        # for each possible state j at the root, sum over the paths that
        # lead to starting the phylogeny in state j
        for j in states:
            root_scores[j] += mult_logaddexp(
                [(child_scores[k]+p.em_trans.loc[j][k]) for k in states])

    # total probability across all paths and states
    p.L = mult_logaddexp(root_scores)

    return p


def get_L_local (cp, p2, edges, distances, times, psi, to_optimize = 'location'):
    """
    Local log-likelihood of a set of edges under hypothetical lengths and durations.

    edges, distances, times: parallel sequences; distances[i] and times[i]
        are the hypothetical length and duration of edges[i]
    psi: migration rates, one per state
    to_optimize: accepted for the caller's convenience; not read here

    For each edge the per-state emission log-probability is weighted by the
    log of the edge's normalized marginal for that state and combined over
    the states by log-sum-exp; the per-edge values are summed.
    """
    total = 0.0
    for idx, branch in enumerate(edges):
        per_state = []
        for k in range(p2.psi_num):
            log_emit = emission_hypothetical(cp, distances[idx], times[idx], psi[k])
            per_state.append(log_emit + math.log(branch.normalized_marginals[k]))

        total += mult_logaddexp(per_state)

    return total


def get_proposal(p2, cp, pr_ctr, node, axis, direction, step):
    """
    Build and score a proposed move of one internal node.

    p2: phylogeny being optimized
    cp: control panel; cp.omni switches the angle-uniformity term on
    pr_ctr: counter used to name the Proposal
    node: node to move; it must have two descendant edges
    axis: -1 shifts the node in time, 0 shifts it along x, 1 along y
    direction: sign of the shift (+1 or -1)
    step: size of the shift

    Returns the Proposal, whose L field holds the local log-likelihood of
    the edges touching the node (plus the log omnibus p-value when cp.omni
    is True), or None when a time shift would make the duration of an
    adjacent edge negative.

    Edges whose current length is zero are left out of the local
    likelihood.
    """
    pr = Proposal(pr_ctr, p2, node)
    shift = direction*step
    has_anc = node.anc_edge is not None

    # Edges touching the node, always ordered
    # [descendant 0, descendant 1, ancestral edge (if any)].
    adjacent = [node.desc_edges[0], node.desc_edges[1]]
    if has_anc:
        adjacent.append(node.anc_edge)

    # Positions (within `adjacent`) of the edges that enter the likelihood.
    # Selecting edges, distances and times through the same positions keeps
    # the three lists aligned; the original filtered the edges and distances
    # but not the times, and read the new distances by fixed position, so a
    # zero-length edge paired edges with the wrong values or raised
    # IndexError.
    scored = [idx for idx, branch in enumerate(adjacent) if branch.dist != 0]
    edges = [adjacent[idx] for idx in scored]

    if (axis == -1):
        # moving the node in time lengthens/shortens its descendant edges
        # and does the opposite to its ancestral edge
        pr.de0_time += shift
        pr.de1_time += shift
        new_times = [pr.de0_time, pr.de1_time]
        if has_anc:
            pr.ae_time -= shift
            new_times.append(pr.ae_time)

        if (min(new_times) < 0):
            return None

        distances = [adjacent[idx].dist for idx in scored]
        times = [new_times[idx] for idx in scored]
        L_temp = get_L_local (cp, p2, edges, distances, times, p2.psi, to_optimize = 'time')
        if cp.omni == True:
            # angles are unaffected by a move in time
            L_temp += p2.omni
            pr.omni = p2.omni
            pr.angles = p2.angles

        pr.time = node.time + shift

    else:
        # axis 0 moves x only, axis 1 moves y only
        x_new = node.location[0] + (shift*abs(axis-1))
        y_new = node.location[1] + (shift*axis)

        neighbor_locations = [node.desc_edges[0].desc.location,
                              node.desc_edges[1].desc.location]
        if has_anc:
            neighbor_locations.append(node.anc_edge.anc.location)

        # new lengths of all adjacent edges, in the order of `adjacent`
        all_distances = get_distances((x_new, y_new), neighbor_locations)

        pr.location = (x_new, y_new)
        pr.de0_dist, pr.de1_dist = all_distances[0], all_distances[1]
        if has_anc:
            pr.ae_dist = all_distances[2]

        new_distances = [all_distances[idx] for idx in scored]
        times = [adjacent[idx].time for idx in scored]
        L_temp = get_L_local (cp, p2, edges, new_distances, times, p2.psi)

        if cp.omni == True:
            # replace the stored angles of every node affected by the move
            involved = [de.desc.name for de in node.desc_edges]
            involved.append(node.name)
            if has_anc:
                involved.append(node.anc_edge.anc.name)

            temp_angles = p2.angles.loc[~p2.angles['node'].isin(involved)][:]

            if has_anc:
                anc_temp = node.anc_edge.anc

            else:
                anc_temp = None

            angles_add = get_proposed_angles(node, pr.location, anc_temp,
                                             node.desc_edges[0].desc, node.desc_edges[1].desc)

            pr.angles = pd.concat([temp_angles, angles_add])
            pval, m = omnibus(np.array(pr.angles.loc[:,'angle']))
            pr.omni = math.log(pval)
            L_temp += pr.omni

    pr.L = L_temp
    return (pr)


def get_proposed_angles(node, proposed_loc, anc, de0, de1):
    """
    Recompute the angle-table rows affected by moving `node` to `proposed_loc`.

    node: the node being moved
    proposed_loc: its proposed x,y location
    anc: the node at the other end of node.anc_edge (only used when
         node.anc_edge is not None)
    de0, de1: the two descendant nodes of `node` (get_proposal passes
              node.desc_edges[i].desc)

    Returns a DataFrame with columns 'node' (name of the node an angle is
    attributed to) and 'angle'.
    """
    # directions from the proposed location to the two descendants
    nde0 = get_angle(proposed_loc, de0.location)
    nde1 = get_angle(proposed_loc, de1.location)
    nodes = [node.name]
    angles = [angle_between(nde0, nde1)]
    if node.anc_edge != None:
        # direction of the ancestral edge, from the ancestor's perspective
        nanc = get_angle(anc.location, proposed_loc)
        angles.append(angle_between(nde0, nanc))
        angles.append(angle_between(nde1, nanc))
        nodes *= 3     # three angles are now attributed to `node`
        # angles at the ancestor
        temp = [(nanc+math.pi)%(2*math.pi)]
        temp.append(get_sibling(node.anc_edge, node.anc_edge.anc.desc_edges).angle)
        # NOTE(review): the pairwise loop below is nested under this `if`, so
        # when the ancestor has no ancestral edge of its own no angles are
        # recorded for it. Confirm that this is intended.
        if anc.anc_edge != None:
            temp.append((anc.anc_edge.angle+math.pi)%(2*math.pi))
            for i in range(len(temp)-1):
                for j in range(i+1, len(temp)):
                    angles.append(angle_between(temp[j],temp[i]))
                    nodes.append(anc.name)
        
    # angles at each descendant: its own descendant edges plus the direction
    # back to the proposed location (a descendant without descendant edges
    # contributes no pairs)
    for de in [de0, de1]:
        temp = [dde.angle for dde in de.desc_edges]
        temp.append (get_angle(de.location, proposed_loc))
        for i in range(len(temp)-1):
            for j in range(i+1, len(temp)):
                # NOTE(review): argument order here is (temp[j], temp[i]);
                # initialize_angles_df calls angle_between(temp[i], temp[j]).
                angles.append(angle_between(temp[j],temp[i]))
                nodes.append(de.name)
                
    new_angles = pd.DataFrame({'node':nodes,'angle':angles})
    return (new_angles)       
        

def get_sibling(edge, desc_edges):
    """
    Return the first entry of desc_edges that remains once one occurrence
    of `edge` has been removed.

    desc_edges itself is not modified. Raises ValueError if `edge` is not
    in desc_edges and IndexError if nothing remains.
    """
    remaining = list(desc_edges)
    remaining.remove(edge)
    return remaining[0]


def initialize_angles_df(p2):
    """
    Build the table of angles used by the omnibus test and store it on p2.

    For every internal node, the directions of the edges meeting at the node
    are collected (the ancestral edge reversed by pi, then each descendant
    edge) and the angle between every pair is recorded against the node's
    name. The table is stored in p2.angles and the log of the omnibus
    p-value of all recorded angles in p2.omni.

    Returns p2.
    """
    node_labels = []
    pair_angles = []

    for node in p2.n2n.values():
        # tips have no angles of their own
        if not len(node.desc_edges) > 0:
            continue

        directions = []
        if node.anc_edge is not None:
            directions.append((node.anc_edge.angle+math.pi)%(2*math.pi))

        directions.extend(de.angle for de in node.desc_edges)

        # every unordered pair, earlier direction first
        for i, first in enumerate(directions[:-1]):
            for second in directions[i+1:]:
                pair_angles.append(angle_between(first, second))
                node_labels.append(node.name)

    p2.angles = pd.DataFrame({'node':node_labels,'angle':pair_angles})
    pval, m = omnibus(np.array(p2.angles.loc[:,'angle']))
    p2.omni = math.log(pval)
    return p2


def initialize_locations_descendants (p2, cp):
    """
    Place every internal node on the line between its two descendants.

    The position along the line is set by the square roots of the durations
    of the two descendant edges: the node sits at the fraction
    sqrt(t0) / (sqrt(t0) + sqrt(t1)) of the way from descendant 0 to
    descendant 1. Edge lengths and angles are then recomputed from the new
    locations, and the angle table is rebuilt when cp.omni is True.

    Returns p2.
    """

    def place(node):
        # returns the node's ancestral edge together with its location
        if node.desc_edges == []:
            return (node.anc_edge, node.location)

        edge_a, locn_a = place(node.desc_edges[0].desc)
        edge_b, locn_b = place(node.desc_edges[1].desc)

        root_a = edge_a.time**0.5
        root_b = edge_b.time**0.5
        fraction = root_a/(root_a + root_b)

        x_new = locn_a[0] + (fraction*(locn_b[0] - locn_a[0]))
        y_new = locn_a[1] + (fraction*(locn_b[1] - locn_a[1]))

        node.location = [x_new, y_new]
        return (node.anc_edge, node.location)

    # reset the internal nodes before placing them
    for candidate in p2.n2n.values():
        if len(candidate.desc_edges) > 0:
            candidate.location = cp.null_location

    place(p2.root)

    # refresh the geometry of every edge
    for branch in p2.n2e.values():
        dx = branch.anc.location[0] - branch.desc.location[0]
        dy = branch.anc.location[1] - branch.desc.location[1]
        branch.dist = (dx**2 + dy**2)**0.5
        branch.angle = get_angle(branch.anc.location, branch.desc.location)

    if cp.omni == True:
        p2 = initialize_angles_df(p2)

    return p2


def initialize_locations_random (p2, cp):
    """
    Give every internal node a random starting location.

    Locations are drawn uniformly between the minimum and maximum of the
    tip nodes in each dimension. Edge lengths and angles are then recomputed
    from the new locations, and the angle table is rebuilt when cp.omni is
    True.

    Returns p2.
    """
    all_nodes = list(p2.n2n.values())
    internal_nodes = [a for a in all_nodes if len(a.desc_edges) > 0]
    tips = [a for a in all_nodes if len(a.desc_edges) == 0]

    # The bounds come from the tips only, as documented. The original also
    # included the internal nodes, whose coordinates at this point are only
    # placeholders and could stretch the sampling box. Falling back to all
    # nodes keeps the function usable on a phylogeny without tips.
    bounding_nodes = tips if tips else all_nodes
    x_range = [a.location[0] for a in bounding_nodes]
    y_range = [a.location[1] for a in bounding_nodes]
    x_min, x_max = min(x_range), max(x_range)
    y_min, y_max = min(y_range), max(y_range)

    for a in internal_nodes:
        a.location = (random.uniform(x_min, x_max), random.uniform(y_min, y_max))

    # refresh the geometry of every edge
    for e in p2.n2e.values():
        e.dist = (((e.anc.location[0]-e.desc.location[0])**2 +
                      (e.anc.location[1]-e.desc.location[1])**2)**0.5)

        e.angle = get_angle(e.anc.location, e.desc.location)

    if cp.omni == True:
        p2 = initialize_angles_df(p2)

    return (p2)


def initialize_psi (p2, cp):
    """
    Initialize p2.psi from the edge lengths and durations.

    The rate of each edge is taken as dist / sqrt(time) (handled on a log
    scale). With one migration state, psi is the mean rate. With two, the
    rates are split at the cp.initial_percentage_epithelial percentile of
    the log rates and psi holds the mean rate below the cutoff followed by
    the mean rate at or above it.

    Returns p2 when p2.psi_num is 1 or 2; for any other value nothing is
    changed and None is returned.
    """
    log_rates = sorted(math.log(e.dist/(e.time**0.5)) for e in p2.n2e.values())

    if p2.psi_num == 2:
        cutoff = np.percentile(log_rates, cp.initial_percentage_epithelial)
        below = [math.exp(value) for value in log_rates if value < cutoff]
        at_or_above = [math.exp(value) for value in log_rates if value >= cutoff]
        p2.psi = [mean(below), mean(at_or_above)]
        return p2

    if p2.psi_num == 1:
        p2.psi = [mean([math.exp(value) for value in log_rates])]
        return p2

def initialize_sigmas (p2, cp):
    """
    Set the proposal step scales on the control panel.

    cp.space_sigma becomes one twentieth of the summed x and y extents of
    all node locations in p2, and cp.time_sigma is set to 0.01.

    Returns (p2, cp).
    """
    xs = []
    ys = []
    for node in p2.n2n.values():
        xs.append(node.location[0])
        ys.append(node.location[1])

    x_extent = max(xs) - min(xs)
    y_extent = max(ys) - min(ys)

    cp.space_sigma = (x_extent + y_extent)/20
    cp.time_sigma = 0.01

    return (p2, cp)


def insert_node(p, node_to_add, tip_node, edge_to_split, e_counter):
    """
    Split an edge with a new internal node and hang a tip from that node.

    p: phylogeny; new edges are registered in p.n2e under the keys
       e_counter and e_counter+1
    node_to_add: new internal node inserted into edge_to_split
    tip_node: node attached below node_to_add
    edge_to_split: existing edge; afterwards it runs from node_to_add to its
                   original descendant
    e_counter: next free edge number

    Both new edges start with the marginals of the edge being split. The
    lower part of the split edge and the new tip edge may each, by a random
    draw against the transition table, receive the swapped marginals
    instead.

    Returns e_counter + 2.

    NOTE(review): the time, dist and angle of edge_to_split are not
    recomputed here after its ancestral node changes.
    """

    def switches_state(marginals):
        # one random draw (on a log scale) against the entry of the
        # transition table for leaving the most probable state
        top = np.argmax(marginals)
        draw = math.log(np.random.uniform(0.0,1.0))
        return draw < p.em_trans.iloc[top][abs(top - 1)]

    # new edge from the original ancestor down to the inserted node
    upper = Edge(p, 'e'+str(e_counter+1), edge_to_split.anc, node_to_add)
    p.n2e[e_counter+1] = upper
    upper.marginals = list(edge_to_split.marginals)   # copy: do not share the list
    edge_to_split.anc.desc_edges.remove(edge_to_split)
    edge_to_split.anc.desc_edges.append(upper)
    node_to_add.anc_edge = upper

    # the split edge now starts at the inserted node
    edge_to_split.anc = node_to_add
    if switches_state(upper.marginals):
        edge_to_split.marginals = [upper.marginals[1], upper.marginals[0]]     # toggle migr_state

    # new edge from the inserted node to the tip
    tip_edge = Edge(p, 'e'+str(e_counter), node_to_add, tip_node)
    p.n2e[e_counter] = tip_edge
    tip_node.anc_edge = tip_edge
    tip_edge.marginals = list(upper.marginals)
    if switches_state(upper.marginals):
        # The original assigned to edge_to_split.marginals here as well (a
        # copy of the statement above), so the tip edge could never switch
        # state.
        tip_edge.marginals = [upper.marginals[1], upper.marginals[0]]     # toggle migr_state

    node_to_add.desc_edges = [tip_edge, edge_to_split]

    return (e_counter + 2)


def mult_logaddexp(P_array):
    """
    Return log(sum(exp(x) for x in P_array)), folding the entries together
    from left to right with np.logaddexp.

    A one-element sequence returns its only element unchanged.
    """
    running = P_array[0]
    for term in P_array[1:]:
        running = np.logaddexp(running, term)

    return running


def omnibus(alpha, w=None, sz=np.radians(1), axis=None):
    """
    Computes omnibus test for non-uniformity of circular data. The test is also
    known as Hodges-Ajne test.
    H0: the population is uniformly distributed around the circle
    HA: the population is not distributed uniformly around the circle
    Alternative to the Rayleigh and Rao's test. Works well for unimodal,
    bimodal or multimodal data. If requirements of the Rayleigh test are
    met, the latter is more powerful.
    :param alpha: sample of angles in radian
    :param w:      number of incidences in case of binned angle data
    :param sz:    step size for evaluating distribution, default 1 deg
    :param axis:  compute along this dimension, default is None
                  if axis=None, array is raveled
    :return pval: two-tailed p-value
    :return m:    minimum number of samples falling in one half of the circle
    References: [Fisher1995]_, [Jammalamadaka2001]_, [Zar2009]_
    """

    if w is None:
        w = np.ones_like(alpha)

    assert w.shape == alpha.shape, "Dimensions of alpha and w must match"

    alpha = alpha % (2 * np.pi)
    n = np.sum(w, axis=axis)

    # Orientations of the dividing diameter to try. The step is the `sz`
    # argument; the original always used 1 degree and ignored `sz`.
    dg = np.arange(0, np.pi, sz)

    m1 = np.zeros((len(dg),) + alpha.shape[1:])
    m2 = np.zeros((len(dg),) + alpha.shape[1:])

    # m1: weight falling in the half circle starting at dg_val; m2: the rest
    for i, dg_val in enumerate(dg):
        m1[i, ...] = np.sum(
            w * ((alpha > dg_val) & (alpha < np.pi + dg_val)), axis=axis)
        m2[i, ...] = n - m1[i, ...]

    m = np.concatenate((m1, m2), axis=0).min(axis=axis)

    # Work in floating point from here on: with integer weights the
    # original raised on 2 ** (negative integer) and would have stored the
    # p-value in an integer array.
    n = np.atleast_1d(n).astype(float)
    m = np.atleast_1d(m)
    A = np.empty_like(n, dtype=float)
    pval = np.empty_like(n, dtype=float)
    idx50 = (n > 50)

    # large samples: approximation
    if np.any(idx50):
        A[idx50] = np.pi * np.sqrt(n[idx50]) / 2 / (n[idx50] - 2 * m[idx50])
        pval[idx50] = np.sqrt(2 * np.pi) / A[idx50] * \
                      np.exp(-np.pi ** 2 / 8 / A[idx50] ** 2)

    # small samples: exact formula
    if np.any(~idx50):
        pval[~idx50] = 2 ** (1 - n[~idx50]) * (n[~idx50] - \
                                               2 * m[~idx50]) * sp_comb(n[~idx50], m[~idx50])

    return pval.squeeze(), m


def optimize_node (cp, p2, node):
    """
    Move one node to the best of five spatial proposals.

    The proposals are: no change, and one random-sized step in each
    direction along x and along y. Each proposal's likelihood is raised to
    the power 1/cp.exp_T and normalized, and the proposal with the largest
    value is applied to p2. (No time moves are proposed here; see
    optimize_node_v2.)

    Returns p2.
    """
    pr_ctr = 0
    axes = [0,1]
    dirs = [1,-1]
    # get_proposal takes the control panel as its second argument; the
    # original calls omitted it, so every call raised TypeError.
    proposals = [get_proposal(p2, cp, pr_ctr, node, axis = 1, direction = 1, step = 0)]   # no change
    for a in axes:
        for d in dirs:
            pr_ctr += 1
            proposals.append(get_proposal(p2, cp, pr_ctr, node, axis = a,
                                          direction = d, step = cp.get_step(a)))

    L_temp = [((math.exp(pr.L))**(1/cp.exp_T)) for pr in proposals]
    L_total = sum(L_temp)
    Ls = [a/L_total for a in L_temp]

    accept = np.argmax(Ls)

    p2 = accept_proposal(p2, cp, pr = proposals[accept])
    return (p2)

def _tempered_weights(log_likelihoods, temperature):
    """
    Return selection weights proportional to exp(L / temperature), summing to 1.

    The largest scaled log-likelihood is subtracted before exponentiating, so
    very negative log-likelihoods no longer underflow to zero.  If no finite
    maximum exists (e.g. every L is -inf) the weights are uniform.
    """
    scaled = np.asarray(log_likelihoods, dtype=float) / temperature
    peak = np.max(scaled)
    if not np.isfinite(peak):
        return np.full(len(scaled), 1.0 / len(scaled))
    weights = np.exp(scaled - peak)
    return weights / weights.sum()


def optimize_node_v2 (cp, p2, node):
    """
    Propose moves for one node along every axis and apply one of them.

    Candidates are: "no move" (step 0), +/- one step along axes 0 and 1, and
    +/- one step along axis -1 (presumably the time axis -- confirm against
    get_proposal).  For axis -1 the step is halved until get_proposal returns
    something other than None.

    Each candidate's log-likelihood pr.L is converted to a weight
    exp(pr.L / exp(cp.exp_T)).  With cp.acceptance == 'top' the heaviest
    candidate is applied; with 'proportional' one is drawn at random with
    probability equal to its weight.

    :param cp: control panel providing get_step(), exp_T and acceptance
    :param p2: phylogeny being optimised
    :param node: node whose placement is being adjusted
    :return: (True, cp, p2) -- the flag is always True; an earlier
             convergence test based on cp.sigma_T is currently disabled
    :raises ValueError: if cp.acceptance is neither 'top' nor 'proportional'
    """
    pr_ctr = 0
    axes = [0, 1]
    dirs = [1, -1]
    proposals = [get_proposal(p2, cp, pr_ctr, node, axis = 1, direction = 1, step = 0)]   # no change
    for a in axes:
        for d in dirs:
            pr_ctr += 1
            proposals.append(get_proposal(p2, cp, pr_ctr, node, axis = a,
                                          direction = d, step = cp.get_step(a)))

    for d in dirs:
        pr_ctr += 1
        st = cp.get_step(-1)
        while True:
            temp = get_proposal(p2, cp, pr_ctr, node, axis = -1, direction = d, step = st)
            if temp is not None:
                proposals.append(temp)
                break
            # proposal rejected by get_proposal: retry with half the step
            st /= 2

    # exp(L) ** (1 / T) == exp(L / T); computing it in log space avoids the
    # underflow that previously collapsed all weights to zero.
    Ls = _tempered_weights([pr.L for pr in proposals], math.exp(cp.exp_T))

    if cp.acceptance == 'top':
        accept = np.argmax(Ls)
    elif cp.acceptance == 'proportional':
        accept = np.random.choice(a = len(proposals), p = Ls)
    else:
        raise ValueError("cp.acceptance must be 'top' or 'proportional', got %r"
                         % (cp.acceptance,))

    p2 = accept_proposal(p2, cp, pr = proposals[accept])
    return (True, cp, p2)
def plot_history (p):
    """
    Plot every edge of phylogeny p as a red line between its two end nodes.

    Axis limits are the extent of all node locations padded by 5 units.
    If an edge end lacks the attributes needed for plotting, the edge's
    name, ends and length are printed and the program exits.

    :param p: phylogeny whose n2n values have .location and whose n2e
              values have .anc and .desc nodes
    :return: None
    """
    nodes = list(p.n2n.values())
    xs = [nd.location[0] for nd in nodes]
    ys = [nd.location[1] for nd in nodes]

    margin = 5
    plt.xlim(min(xs) - margin, max(xs) + margin)
    plt.ylim(min(ys) - margin, max(ys) + margin)

    for edge in p.n2e.values():
        try:
            x_from, y_from = edge.anc.location
            x_to, y_to = edge.desc.location
            plt.plot([x_from, x_to], [y_from, y_to], 'ro-')
        except AttributeError:
            print (edge.name, edge.desc, edge.anc, edge.length)
            exit(0)

    plt.show()
    return None


def Ramanujan(n):
    """
    Approximate log(n!) with the formula given by Srinivasa Ramanujan:

        n*log(n) - n + log(n*(1 + 4n*(1 + 2n)))/6 + log(pi)/2

    :param n: non-negative number
    :return: approximation of log(n!); exactly 0.0 for n == 0, where the
             formula itself is undefined (log(0)) but log(0!) = 0
    :raises ValueError: if n is negative
    """
    if n == 0:
        return 0.0

    if n < 0:
        raise ValueError("log(n!) is undefined for negative n (got %r)" % (n,))

    return (n*math.log(n) - n + (math.log(n*(1+4*n*(1+2*n))))/6
            + math.log(math.pi)/2)


def simulate_evolution(i, psi, em_trans, fraction_mesenchymal, mutn_rate):
    """
    Simulate a random tree of i tips together with the spatial positions of
    its lineages, and return it as a Phylogeny.

    Steps performed:
      1. Draw i-1 cumulative coalescence times and rescale them so the oldest
         equals 1.
      2. Create the tip nodes (keys 0..i-1, time 0) and a root node (key -1),
         and join tips 0 and 1 directly to the root.
      3. Working from the root towards the present, displace every existing
         lineage in a random direction, then attach the next tip as sibling of
         a randomly chosen existing lineage through insert_node.
      4. Apply one final displacement over the most recent interval, store the
         resulting positions on the tips, and fill in time, mutations and dist
         on every edge.
      5. Plot the result with plot_history.

    :param i: number of tips; must be at least 2 (with fewer, no coalescence
              times are drawn and max() of the empty list raises ValueError)
    :param psi: migration rates, stored as p.psi and indexed by the migration
                state of a lineage's ancestral edge
    :param em_trans: stored as p.em_trans; not otherwise used in this function
    :param fraction_mesenchymal: probability that each of the two initial root
                edges gets marginals [-50.0, 0.0] (state index 1) instead of
                [0.0, -50.0] (state index 0)
    :param mutn_rate: multiplied by an edge's time to give the Poisson mean of
                its mutation count
    :return: the simulated Phylogeny

    Side effects: consumes random numbers from both np.random and random, and
    shows a matplotlib figure via plot_history.

    NOTE(review): p.root.location is never assigned here, so the dist of edges
    attached to the root relies on whatever default Node provides --
    presumably the origin; confirm.
    """
    p = Phylogeny('Ph')
    root_num = -1
    i_temp = i     # remember the requested number of tips; i is reused as a counter below
    t_curr = 0
    t = []   # list of coalescence times, from youngest to oldest
    while (i>1):
        ETi = (i*(i-1))/2.0   # used as the rate for i lineages: the waiting time drawn below has mean 1/ETi
        t.append(t_curr+(np.random.exponential(1/ETi)))
        i-=1
        t_curr = t[-1]

    # rescale so that the oldest coalescence (the last, since t is cumulative) is at time 1
    t_max = max(t)
    t = [t[i]/t_max for i in range(len(t))]
    i = i_temp
    p.psi = psi
    p.em_trans = em_trans
    p.migr_priors = [math.log(0.5),math.log(0.5)]
    anc = [(0.0, 0.0), (0.0,0.0)]   # location of the ancestor to each tip so far
    for x in range(i):   # create tip nodes at time 0; their edges are attached later
        p.n2n[x]=Node('n'+str(x))
        p.n2n[x].time = 0.0
    p.n2n[root_num]=Node('n'+str(root_num))   # create root node
    p.root = p.n2n[root_num]
    p.root.time = t[-1]

    # the first two tips hang directly off the root
    for x in range(2):
        new_name = 'e'+str(x)
        p.n2e[x]=Edge(p, new_name, p.root, p.n2n[x])
        p.n2n[x].anc_edge = p.n2e[x]
        if np.random.uniform(0.0,1.0)<fraction_mesenchymal:
            p.n2e[x].marginals = [-50.0, 0.0]    # using these temporarily to keep track of the migration
                                               # state of each edge
            
        else:
            p.n2e[x].marginals = [0.0, -50.0]

    p.root.desc_edges = [p.n2e[0],p.n2e[1]]

    n_counter = int(i)     # internal nodes are keyed from i upwards, after the tips
    e_counter = 2     # two edges exist so far; insert_node returns the updated count

    # NOTE(review): Newick, hist and g are built up below but never returned or read again.
    Newick = '(0,1)'
    hist = [anc]
    i = 2   # the name of the tip to be added next
    t_ct = -2   # index into t of the coalescence being placed; starts at the second-oldest
    g = []
    while (i<i_temp):
        b = t[t_ct+1]-t[t_ct] # time elapsed since last coalescent event
        direction = [np.random.uniform(0,2*math.pi) for x in range(len(anc))]
        try:
            # migration state of each lineage = index of the larger marginal on its ancestral edge
            # (presumably insert_node assigns marginals to the edges it creates -- confirm)
            m_state = [np.argmax(p.n2n[n_temp].anc_edge.marginals) for n_temp in range(i)]
            # step length: absolute value of a normal draw with standard deviation psi[state]*sqrt(b)
            dist = [abs(np.random.normal(0.0,p.psi[m_state[x]]*(b**0.5))) for x in range(len(anc))]
        except ValueError:
            # terminates the program without a message if either line above raises ValueError
            exit(0)
            
        anc_new = [(anc[x][0]+(math.cos(direction[x])*dist[x]),
                    anc[x][1]+(math.sin(direction[x])*dist[x])) for x in range(len(anc))]
        # anc_new is the positions of all lineages at this next coalescent event

        g_new = random.randint(0,len(anc)-1)   # randomly select existing tip as sibling of new tip
        g.append(g_new)
        new_clade = '('+str(g_new)+','+str(len(anc))+')'
        # NOTE(review): str.replace substitutes every occurrence of the digits, so e.g. '1' also
        # matches inside '10'; the string is not used afterwards, so this has no effect on p.
        Newick = Newick.replace(str(g_new),new_clade)   # put this new clade into the Newick string

        anc_new.append(anc_new[g_new])   # location of coalescent = where new edge starts
        p.n2n[n_counter]=Node('n'+str(n_counter))   # new internal node
        p.n2n[n_counter].location = anc_new[g_new]
        p.n2n[n_counter].time = t[t_ct]
        # place the new internal node on the sibling's ancestral edge, with tip i as its other descendant
        e_counter = insert_node(p, p.n2n[n_counter], p.n2n[i], p.n2n[g_new].anc_edge, e_counter)
        n_counter+=1
        anc = anc_new
        hist.append(anc) # keep track of locations of lineages from previous coal. times

        i+=1
        t_ct-=1

    # final displacement, from the youngest coalescence (t[0]) to the present
    b = t[0] 
    direction = [np.random.uniform(0,2*math.pi) for x in range(len(anc))]
    try:
        m_state = [np.argmax(p.n2n[n_temp].anc_edge.marginals) for n_temp in range(i)]
        dist = [abs(np.random.normal(0.0,p.psi[m_state[x]]*(b**0.5))) for x in range(len(anc))]
    except ValueError:
        # terminates the program without a message if either line above raises ValueError
        exit(0)

    anc_new = [(anc[x][0]+(math.cos(direction[x])*dist[x]),
                anc[x][1]+(math.sin(direction[x])*dist[x])) for x in range(len(anc))]
    
    hist.append(anc_new)
    # the final lineage positions are the observed tip locations
    for a in range(len(anc_new)):
        p.n2n[a].location = anc_new[a]

    # derive per-edge quantities (here i is rebound to each Edge in turn)
    for i in p.n2e.values():
        i.time = i.anc.time - i.desc.time
        i.mutations = np.random.poisson(mutn_rate*i.time)
        # straight-line distance between the two ends of the edge
        i.dist = (((i.anc.location[0]-i.desc.location[0])**2 +
                          (i.anc.location[1]-i.desc.location[1])**2)**0.5)
        
    _ = plot_history (p)
    
    return (p)


def to_root(cp, p, edge):
    """
    Compute, in log space, the total probability of all paths that lead to
    all descendant tips when this edge is crossed in each migration state.

    The result is stored on edge.bk_anc (one entry per state) and a copy is
    returned.  For an edge ending in a tip, entry k is
    p.migr_priors[k] + emission(cp, edge, p.psi[k]).  For any other edge,
    entry j is the sum over descendant edges of
    mult_logaddexp over k of (bk_anc of the descendant in state k +
    p.em_trans.loc[j, k]), plus the emission of this edge in state j.

    The edge's own emission is added exactly once per state.  Previously it
    was added inside the loop over descendant edges, so an edge with two
    descendants contributed its emission twice.

    :param cp: control panel, passed to emission; cp.now enables a debug print
    :param p: phylogeny providing psi, psi_num, migr_priors and em_trans
    :param edge: edge at which to evaluate the recursion
    :return: (p, list of per-state log values for this edge)
    """
    if not edge.desc.desc_edges:
        # tip edge: prior on the state plus the emission of this edge
        edge.bk_anc = [(p.migr_priors[k] + emission(cp, edge, p.psi[k])) for k in range(p.psi_num)]
        return (p, edge.bk_anc[:])

    edge.bk_anc = [0.0 for _ in range(p.psi_num)]
    for desc_edge in edge.desc.desc_edges:     # for each descendant edge:
        p, bk_temp = to_root(cp, p, desc_edge)     # get bk_anc for that edge
        for j in range(p.psi_num):     # for each possible state at this edge:
            # sum over the possible states of the descendant edge, given state j here
            edge.bk_anc[j] += mult_logaddexp([(bk_temp[k] + p.em_trans.loc[j, k])
                                              for k in range(p.psi_num)])

    for j in range(p.psi_num):
        # multiply (once) by the probability that this edge was crossed in state j
        edge.bk_anc[j] += emission(cp, edge, p.psi[j])

    if cp.now == True:
        print (edge.name, edge.bk_anc, edge.bk_desc, edge.marginals)

    return (p, edge.bk_anc[:])
        
def verify_psi(cp, p2):
    """
    Refine p2.psi[0] with a one-dimensional step-halving search.

    The objective is the sum of emission(cp, e, psi) over every edge of p2.
    Starting from the current p2.psi[0] with a step of 1, the candidate keeps
    moving in the same direction while the objective improves; whenever it
    does not improve, the direction is reversed and the step is halved.  The
    search stops as soon as two consecutive objective values differ by less
    than 0.001, and the candidate reached at that point is written back.

    :param cp: control panel, passed through to emission
    :param p2: phylogeny whose psi[0] is refined in place
    :return: p2
    """
    tolerance = 0.001
    sign = 1
    stride = 1
    previous = -10000     # objective value of the preceding iteration
    candidate = p2.psi[0]

    while True:
        current = sum(emission(cp, branch, candidate) for branch in p2.n2e.values())
        if abs(current - previous) < tolerance:
            break

        if not (current > previous):
            # no improvement: turn around and take smaller steps
            sign = -sign
            stride /= 2

        candidate += (sign*stride)
        previous = current

    p2.psi[0] = candidate
    return (p2)
                            
def _read_tab_rows(path):
    """Read a text file through text_to_array and split every line on tabs."""
    return [line.split('\t') for line in text_to_array(path)]


def build_phylogeny(locations_file = "C:/Users/Brian/Desktop/tempdata/seg_locations.txt",
                    heights_file = "C:/Users/Brian/Desktop/tempdata/seg_height_2.txt",
                    merge_file = "C:/Users/Brian/Desktop/tempdata/seg_merge.txt"):
    """
    Build a Phylogeny from three tab-separated files (first row = header).

    :param locations_file: one row per tip: label, x, y, duct
    :param heights_file: one row per merge; column 1 is the height, used as
                         the time of the node created by that merge
    :param merge_file: one row per merge; columns 1 and 2 name the two
                       clusters being joined.  A negative value -j refers to
                       the j-th data row of locations_file, a positive value
                       m to the cluster produced by merge row m (1-based).
                       (This matches the layout of R's hclust merge matrix,
                       presumably the source of these files -- confirm.)
    :return: the assembled Phylogeny, with pn.root set

    The defaults are the paths that used to be hard-coded, so calling
    build_phylogeny() without arguments behaves as before.

    When only one side of a merge is known, no node is created and the merged
    cluster is simply an alias of the known side.  If the number of nodes
    without an ancestral edge is not exactly one, 'multiple roots' is printed
    and the program exits.
    """
    # --- tip locations -----------------------------------------------------
    rows = _read_tab_rows(locations_file)
    table = pd.DataFrame(np.array([[int(r[1]), int(r[2]), r[3]] for r in rows[1:]]))
    labels = {i-1: rows[i][0] for i in range(1, len(rows))}
    locations = table.rename(index = labels, columns = {0:'x', 1:'y', 2:'duct'})

    # --- merge heights and merge order ------------------------------------
    heights = [float(r[1]) for r in _read_tab_rows(heights_file)[1:]]
    merge = [[int(r[1]), int(r[2])] for r in _read_tab_rows(merge_file)[1:]]

    # --- tips: the j-th data row becomes node key -j ------------------------
    pn = Phylogeny('pn')
    for i, label in labels.items():
        tip_key = -(i+1)
        try:
            xy = (int(locations.loc[label,'x']), int(locations.loc[label,'y']))
            duct = locations.loc[label,'duct']
        except KeyError:
            # no location available for this label: leave the tip out
            continue

        tip = Node('n'+str(tip_key))
        tip.location = xy
        tip.duct = duct
        pn.n2n[tip_key] = tip

    # cluster id -> node currently representing that cluster
    merge_nodes = dict(pn.n2n)

    # --- internal nodes ----------------------------------------------------
    for i, (left, right) in enumerate(merge):
        cluster = i+1
        has_left = left in merge_nodes
        has_right = right in merge_nodes

        if has_left and has_right:
            pn.n2n[cluster] = Node('n'+str(cluster))
            n = pn.n2n[cluster]
            d0 = merge_nodes[left]
            d1 = merge_nodes[right]
            n.time = heights[i]
            e0 = Edge(pn, 'e'+str(pn.e_ctr), n, d0)
            pn.n2e[pn.e_ctr] = e0
            e1 = Edge(pn, 'e'+str(pn.e_ctr+1), n, d1)
            pn.n2e[pn.e_ctr+1] = e1
            d0.anc_edge = e0
            d1.anc_edge = e1
            n.desc_edges = [e0, e1]
            pn.e_ctr += 2
            merge_nodes[cluster] = n

        elif has_left:
            merge_nodes[cluster] = merge_nodes[left]

        elif has_right:
            merge_nodes[cluster] = merge_nodes[right]

    # --- root = the single node without an ancestral edge -------------------
    roots = [nd for nd in pn.n2n.values() if nd.anc_edge is None]
    if len(roots) != 1:
        print ('multiple roots')
        exit(0)

    pn.root = roots[0]
    return (pn)


def _log_prior(probability):
    """
    Return log(probability); non-positive input maps to -50.0, the value this
    file uses elsewhere to stand for log(0).
    """
    return math.log(probability) if probability > 0 else -50.0


def reconstruct (p, cp):
    """
    Estimate node locations and migration parameters for a phylogeny.

    After initialising a cleared copy of p (number of migration states,
    locations, optional angles, psi, sigmas), the function runs
    cp.iter_limit rounds.  Each round visits the internal nodes in random
    order with optimize_node_v2, shrinks cp.sigma_T, re-runs fbb2 and
    estimate_parameters, and lowers cp.exp_T.  From the third round on, a deep
    copy of the phylogeny is kept whenever its log-likelihood p2.L exceeds
    the best seen so far.

    :param p: phylogeny to reconstruct (passed to clear_phylogeny)
    :param cp: Control_Panel object containing parameters for reconstruction;
               it is modified in place (corr2, sigma_T, exp_T)
    :return: the best recorded phylogeny, or the current phylogeny if no
             round recorded one (e.g. cp.iter_limit <= 2)

    Exits the program if cp.set_psi_num is not 1 or 2, or if
    cp.initialize_locations is not 'random' or 'descendants'.
    """
    p2 = clear_phylogeny (p, cp)
    if cp.set_psi_num == 1:
        p2.psi_num = 1
        p2.psi = [1.0]
        p2.em_trans = cp.em_trans_1
        for e in p2.n2e.values():
            e.normalized_marginals = [1.0]

    elif cp.set_psi_num == 2:
        p2.psi_num = 2
        p2.psi = [1.0, 1.0]
        # Priors are kept in log space, like em_trans and the priors set in
        # simulate_evolution; previously raw probabilities were stored here.
        frac_epithelial = cp.initial_percentage_epithelial/100
        p2.migr_priors = [_log_prior(frac_epithelial), _log_prior(1-frac_epithelial)]
        p2.em_trans = cp.em_trans_2
        for e in p2.n2e.values():
            e.normalized_marginals = [0.5, 0.5]

    else:
        print ('please set psi_num to 1 or 2')
        exit(0)

    print (p2.psi_num)
    #p2 = estimate_times(p2, cp)
    if cp.initialize_locations == 'random':
        p2 = initialize_locations_random (p2, cp)

    elif cp.initialize_locations == 'descendants':
        p2 = initialize_locations_descendants (p2, cp)

    else:
        print ('mode for initializing locations not recognized')
        exit(0)

    if cp.omni == True:
        p2 = initialize_angles_df(p2)

    p2 = initialize_psi (p2, cp)
    p2 = fbb2(cp, p2)
    p2, b, c = estimate_parameters(cp, p2)
    p2, cp = initialize_sigmas (p2, cp)
    cp.corr2 = 0.9
    internal_nodes = [a for a in p2.n2n.values() if len(a.desc_edges) > 0]
    iterations = 0
    limit = cp.iter_limit
    p_record = None     # best phylogeny seen so far (None until one is recorded)
    L_record = float('-inf')
    while (True):
        iterations += 1
        if (iterations > limit):
            # fall back to the current state if nothing was ever recorded
            return p_record if p_record is not None else p2

        random.shuffle(internal_nodes)
        for a in internal_nodes:
            _, cp, p2 = optimize_node_v2 (cp, p2, a)
        cp.sigma_T *= cp.sigma_T_step
        p2 = fbb2(cp, p2)
        p2, b, c = estimate_parameters(cp, p2)
        if p2.L > L_record and iterations > 2:
            p_record = copy.deepcopy(p2)
            L_record = p2.L

        if (iterations%20 == 0):
            print ('psi:',"{:.2f}".format(p2.psi[0]),
                   'L: ',"{:.2f}".format(p2.L))
        # the re-estimated priors and transitions take effect from the next round
        p2.migr_priors = b
        p2.em_trans = c
        cp.exp_T -= cp.exp_T_step

def initialize_control_panel (psi_n, iters):
    """
    Create a Control_Panel populated with the default reconstruction settings.

    :param psi_n: number of migration states, stored as cp.set_psi_num
    :param iters: number of optimisation rounds, stored as cp.iter_limit;
                  must be non-zero because the per-round step sizes divide
                  by it
    :return: the configured Control_Panel

    cp.exp_T_step is chosen so that cp.exp_T goes from exp_T_init to
    exp_T_stop in iters equal steps; cp.sigma_T_step is the per-round factor
    that multiplies out to 0.01 after iters rounds.
    """
    cp = Control_Panel()
    cp.psi = [1000.0]
    cp.em_trans_2 = pd.DataFrame(np.array([[math.log(0.975), math.log(0.025)],[math.log(0.025),math.log(0.975)]]))
    cp.em_trans_1 = pd.DataFrame(np.array([math.log(1.000)]))
    cp.migr_priors = [0.0]
    cp.mutation_rate = [100.0]
    cp.null_location = [0.0, 0.0]
    cp.null_time = 0.0
    cp.initialize_locations = 'random'
    cp.L_epsilon = 0.1
    cp.location_epsilon = 0.01
    cp.initial_percentage_epithelial = 95
    cp.psi_epsilon = 0.0001
    cp.mutation_pseudocount = 0.001
    cp.sigma_T = 1.0
    cp.opt_converge = 0.01
    cp.omni = False
    cp.acceptance = 'proportional'
    cp.exp_T_init = 1.0
    cp.exp_T = cp.exp_T_init
    cp.exp_T_stop = -3.0
    # iter_limit must be set before the step sizes below are derived from it
    # (it used to be assigned only after they had already been computed)
    cp.iter_limit = iters
    cp.exp_T_step = (cp.exp_T_init - cp.exp_T_stop)/cp.iter_limit
    cp.sigma_T_step = math.exp((math.log(0.01))/cp.iter_limit)
    cp.correction = 0.0
    cp.corr2 = 1.0
    cp.set_psi_num = psi_n

    return (cp)

# ---------------------------------------------------------------------------
# Driver: simulate a 250-tip phylogeny with known migration rates, then
# reconstruct it and print the estimated migration rate(s).
# ---------------------------------------------------------------------------
psi_true = [50.0, 50.0]     # migration rate used for each migration state in the simulation
psi_e_to_mutn_ratio = 0.01     # psi_true[0] divided by this gives the simulated mutation rate
print_flag = 0



# log-probabilities of staying in / switching between the two migration states
em = pd.DataFrame(np.array([[math.log(0.98), math.log(0.02)],[math.log(0.02),math.log(0.98)]]))

p = simulate_evolution(i = 250, psi = psi_true, em_trans = em, fraction_mesenchymal = 0.05, 
                     mutn_rate = psi_true[0]/psi_e_to_mutn_ratio)

for rep in range(1):
    cp = initialize_control_panel (1, 750)
    # Reconstruct the phylogeny simulated above.  The call used to pass `pn`,
    # a name with no visible module-level binding in this script; to analyse
    # the real dataset instead, pass the return value of build_phylogeny().
    p_record= reconstruct (p, cp)
    print ('final estimate: ', p_record.psi)