# tree HMM Simulations

Simulate a Markov process on binary trees to generate synthetic data for testing a tree hidden Markov model (tree HMM).

In [1]:
#import libraries

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats


In [2]:
class TreeNode:
    """A single cell (node) of a binary lineage tree.

    Each node stores one observed value, an optional hidden-state label,
    links to its parent and two daughters, and three slots (``gamma``,
    ``alpha``, ``beta``) that start out empty and can be filled in later.
    """

    def __init__(self, state, hidden_state=None, parent=None):
        # Observation carried by this node and its (optional) hidden label.
        self.x, self.hidden_state = state, hidden_state
        # Tree links; daughters are attached after construction.
        self.parent = parent
        self.left = self.right = None
        # Per-node quantities filled in after construction; None until then.
        self.gamma = self.alpha = self.beta = None

In [3]:
# Generate the two daughter nodes jointly, so that the hidden states of the two daughter cells are no longer independent given the parent's hidden state

def generate_node_pair(parent_node, hidden_states, conditional_probs, means, covars, prob_prune):
    """Sample the two daughters of ``parent_node``.

    Parameters
    ----------
    parent_node : node whose ``hidden_state`` selects the joint distribution.
    hidden_states : sequence of hidden-state labels; the labels are also used
        to index ``means`` and ``covars``.
    conditional_probs : indexable by the parent's hidden state; each entry is
        a (K, K) array of joint probabilities for (left, right) daughter states.
    means : per-state location of the Gaussian emission.
    covars : per-state scale passed to ``np.random.normal`` (i.e. a standard
        deviation, despite the name).
    prob_prune : probability of dropping one daughter (chosen with a fair coin).

    Returns
    -------
    (left_node, right_node); at most one of them is ``None``.
    """
    n_states = len(hidden_states)

    # Joint distribution of the daughters' hidden states given the parent's.
    joint_probs = conditional_probs[parent_node.hidden_state]

    # Draw one flat index over the K*K grid, then split it into (left, right).
    joint_state = np.random.choice(np.arange(n_states ** 2), p=joint_probs.flatten())
    left_index, right_index = np.unravel_index(joint_state, (n_states, n_states))

    left_hidden_state = hidden_states[left_index]
    right_hidden_state = hidden_states[right_index]

    # Sample the daughters' continuous observations. The spread comes from the
    # ``covars`` argument; the original read a global named ``vars`` instead,
    # silently ignoring the parameter.
    left_state = np.random.normal(means[left_hidden_state], covars[left_hidden_state])
    right_state = np.random.normal(means[right_hidden_state], covars[right_hidden_state])

    left_node = TreeNode(left_state, hidden_state=left_hidden_state, parent=parent_node)
    right_node = TreeNode(right_state, hidden_state=right_hidden_state, parent=parent_node)

    # With small probability remove one of the daughter nodes.
    if np.random.rand() < prob_prune:
        if np.random.rand() < 0.5:
            left_node = None
        else:
            right_node = None

    return left_node, right_node


def generate_tree(n_generations, hidden_states, initial_probs, conditional_probs, means, vars, prob_prune=0.2):
    """Simulate one lineage tree and return its root node.

    The root's hidden state is drawn from ``initial_probs`` and its observed
    value from a normal distribution with location ``means[state]`` and scale
    ``vars[state]``. The tree is then grown one generation at a time, left to
    right, by calling ``generate_node_pair`` on every existing node of the
    current generation.
    """
    founder_label = np.random.choice(hidden_states, p=initial_probs)
    founder_value = np.random.normal(means[founder_label], vars[founder_label])
    root = TreeNode(founder_value, hidden_state=founder_label)

    if n_generations == 1:
        return root

    # Any value other than 1 grows at least one generation of daughters.
    rounds = max(n_generations - 1, 1)

    frontier = [root]
    for _ in range(rounds):
        offspring = []
        for mother in frontier:
            daughters = generate_node_pair(mother, hidden_states, conditional_probs, means, vars, prob_prune)
            mother.left, mother.right = daughters
            # Pruned (None) daughters have no descendants to grow.
            offspring.extend(d for d in daughters if d is not None)
        frontier = offspring

    return root



In [64]:
# Example parameters for Figs 2 and 3: a two-state model in which both
# daughters always share the same hidden state (off-diagonal joint
# probabilities are zero).

n_trees = 150
# Placeholder; replaced by the generated list of trees at the end of the cell
trees = []

n_generations = 5
hidden_states = [0, 1]
initial_probs = [0.3, 0.7]

# No daughters are pruned in this example
prob_prune = 0.


# conditional_probs[i][j, k]: joint probability that the (left, right)
# daughters are in states (j, k) given a parent in state i
conditional_probs = [
    np.array([
        [0.8, 0.],  # Joint probabilities for child states given parent hidden state 0
        [0., 0.2]
    ]),
    np.array([
        [0.4, 0.],  # Joint probabilities for child states given parent hidden state 1
        [0., 0.6]
    ])
]

# Normalize each parent's joint table so that its entries sum to 1
for i in range(len(conditional_probs)):
    conditional_probs[i] = conditional_probs[i] / np.sum(conditional_probs[i])
  
means = [
    1.0,  # Mean value of x for hidden state 0
    2.0   # Mean value of x for hidden state 1
]

# Per-state standard deviation of x (used as the scale of np.random.normal).
# NOTE(review): the name `vars` shadows the Python builtin and the values are
# standard deviations, not variances.
vars = [
    0.5,  # Std for hidden state 0
    0.7 # Std for hidden state 1
]

# Fix the random seed so the simulated trees are reproducible
np.random.seed(421)

# Generate n_trees independent trees
trees = [generate_tree(n_generations, hidden_states, initial_probs, conditional_probs, means, vars, prob_prune) for _ in range(n_trees)]
    

In [None]:
# Example parameters for Figs 4 and 5: a three-state model in which the
# hidden state cycles deterministically 0 -> 1 -> 2 -> 0 from parent to both
# daughters, and every tree starts in state 0.
#
# NOTE(review): this cell rebinds the module-level hidden_states,
# initial_probs, conditional_probs, means and vars that the Figs 2/3 cell also
# defines, so later "actual value" printouts show whichever parameter cell ran
# last.

n_trees = 150
# Placeholder for this cell's trees. The original reset `trees` here, which
# discarded the Figs 2/3 data even though this cell stores its result in
# `trees_3`.
trees_3 = []

n_generations = 5
hidden_states = [0, 1, 2]
initial_probs = [1.0, 0.0, 0.0]

# No daughters are pruned in this example
prob_prune = 0.


# conditional_probs[i][j, k]: joint probability that the (left, right)
# daughters are in states (j, k) given a parent in state i
conditional_probs = [
    np.array([
        [0., 0., 0.],  # Joint probabilities for child states given parent hidden state 0
        [0., 1.0, 0.],
        [0., 0., 0.]
    ]),
    np.array([
        [0., 0., 0.],  # Joint probabilities for child states given parent hidden state 1
        [0., 0., 0.],
        [0., 0., 1.0]
    ]),
    np.array([
        [1., 0., 0.],  # Joint probabilities for child states given parent hidden state 2
        [0., 0., 0.],
        [0., 0., 0.]
    ])
]

# Normalize each parent's joint table so that its entries sum to 1
for i in range(len(conditional_probs)):
    conditional_probs[i] = conditional_probs[i] / np.sum(conditional_probs[i])

means = [
    1.0,  # Mean value of x for hidden state 0
    1.0,   # Mean value of x for hidden state 1
    2.0   # Mean value of x for hidden state 2
]

# Per-state standard deviation of x (used as the scale of np.random.normal).
# NOTE(review): the name `vars` shadows the Python builtin and the values are
# standard deviations, not variances.
vars = [
    0.21,  # Std for hidden state 0
    0.21, # Std for hidden state 1
    0.21 # Std for hidden state 2
]

# Fix the random seed so the simulated trees are reproducible
np.random.seed(421)

# Generate n_trees independent trees
trees_3 = [generate_tree(n_generations, hidden_states, initial_probs, conditional_probs, means, vars, prob_prune) for _ in range(n_trees)]
    

In [66]:
from scipy.stats import norm

def forward(tree, conditional_probs, means, covars):
    """Upward (leaf-to-root) pass for the subtree rooted at ``tree``.

    Computes the probability of everything strictly below ``tree`` (all
    daughters of ``tree`` and their progeny) conditioned on the hidden state
    of ``tree``:  p(T_subtree | H_tree = h).  The emission of ``tree`` itself
    is not included.

    Parameters
    ----------
    tree : node with ``x``, ``left`` and ``right`` attributes.
    conditional_probs : (K, K, K) array-like; entry [i][j][k] is the joint
        probability of (left, right) daughter states (j, k) given parent
        state i.
    means, covars : per-state location and scale of the Gaussian emission
        (``covars`` is handed to ``norm.pdf`` as the scale).

    Returns
    -------
    numpy array of length K (all ones for a leaf).

    The number of hidden states is taken from ``len(means)`` rather than from
    a module-level ``hidden_states`` variable, so the result depends only on
    the arguments.
    """
    n_states = len(means)

    if tree.left is None and tree.right is None:
        return np.ones(n_states)

    locs = np.asarray(means, dtype=float)
    scales = np.asarray(covars, dtype=float)

    def _branch_weights(child):
        # Emission of the child times the probability of the child's own
        # subtree, per child state; a missing child contributes a factor of 1.
        if child is None:
            return np.ones(n_states)
        emission = norm.pdf(child.x, loc=locs, scale=scales)
        return emission * forward(child, conditional_probs, means, covars)

    transitions = np.asarray(conditional_probs, dtype=float)
    # alpha[i] = sum_{j,k} P(j, k | i) * left[j] * right[k]
    return np.einsum("ijk,j,k->i", transitions, _branch_weights(tree.left), _branch_weights(tree.right))

def backward(root, node, initial_probs, conditional_probs, means, covars):
    """Downward ("outside") quantity for ``node``.

    Starting from the root, computes the probability of the observed values
    of all nodes except the subtree below ``node``, but including ``node``
    itself, jointly with each hidden state of ``node``:
    p(T \\ T_subtree, H_node = h).

    Parameters
    ----------
    root : root of the tree containing ``node``.
    node : node of interest.
    initial_probs : prior over the root's hidden state.
    conditional_probs, means, covars : model parameters, as used by
        ``backward_sub`` (``covars`` is the scale handed to ``norm.pdf``).

    Returns
    -------
    numpy array of length K, one entry per hidden state of ``node``.

    The number of hidden states is taken from ``len(means)`` rather than from
    a module-level ``hidden_states`` variable.
    """
    n_states = len(means)

    # The root emission does not depend on the state assumed for ``node``,
    # so evaluate it once instead of once per state.
    emission_probs_root = norm.pdf(
        root.x,
        loc=np.asarray(means, dtype=float),
        scale=np.asarray(covars, dtype=float),
    )

    beta = np.zeros(n_states)
    for state in range(n_states):
        # Probability of the rest of the tree given each root state, with
        # ``node`` clamped to ``state``.
        outside = backward_sub(root, node, state, conditional_probs, means, covars)
        # Average over the hidden state of the root.
        beta[state] = np.sum(outside * emission_probs_root * initial_probs)

    return beta


def backward_sub(tree, node, node_hidden_state, conditional_probs, means, covars):
    """Upward pass from ``tree`` that stops at ``node`` and clamps its state.

    Same recursion as the forward pass, except that
    * the recursion does not descend below ``node`` (its subtree contributes a
      factor of 1), and
    * wherever ``node`` appears as a daughter, only the term with hidden state
      ``node_hidden_state`` is kept instead of summing over its states.

    Parameters
    ----------
    tree : current subtree root (has ``x``, ``left``, ``right``).
    node : node whose subtree is excluded and whose state is clamped.
    node_hidden_state : integer index of the state assumed for ``node``.
    conditional_probs : (K, K, K) array-like, [i][j][k] = P(left=j, right=k | parent=i).
    means, covars : per-state location and scale of the Gaussian emission.

    Returns
    -------
    numpy array of length K indexed by the hidden state of ``tree``.

    The number of hidden states is taken from ``len(means)`` rather than from
    a module-level ``hidden_states`` variable.
    """
    n_states = len(means)

    # Stop at the leaves or at ``node`` itself.
    if (tree.left is None and tree.right is None) or tree is node:
        return np.ones(n_states)

    locs = np.asarray(means, dtype=float)
    scales = np.asarray(covars, dtype=float)

    def _branch_weights(child):
        # Per-state weight contributed by one daughter: emission times the
        # recursive value for the daughter's subtree. A missing daughter
        # contributes 1 for every state.
        if child is None:
            return np.ones(n_states)
        weights = norm.pdf(child.x, loc=locs, scale=scales) * backward_sub(
            child, node, node_hidden_state, conditional_probs, means, covars
        )
        if child is node:
            # Clamp: keep only the assumed state of ``node``.
            clamped = np.zeros(n_states)
            clamped[node_hidden_state] = weights[node_hidden_state]
            return clamped
        return weights

    transitions = np.asarray(conditional_probs, dtype=float)
    # alpha[i] = sum_{j,k} P(j, k | i) * left[j] * right[k]
    return np.einsum("ijk,j,k->i", transitions, _branch_weights(tree.left), _branch_weights(tree.right))





In [67]:
def mutual_information(conditional_probs):
    """Mutual information (in bits) between the two daughters' hidden states.

    For each parent state ``i`` the mutual information I(j; k | i) of the
    joint table ``conditional_probs[i]`` is computed, and the values are added
    up over all parent states (unweighted sum).

    Parameters
    ----------
    conditional_probs : (K, K, K) array-like, [i][j][k] = P(left=j, right=k | parent=i).

    Returns
    -------
    float

    The marginal of the second daughter is obtained by summing over the first
    daughter's axis. The original reused the first daughter's marginal for
    both, which is only correct when every joint table is symmetric.
    """
    joint_prob = np.asarray(conditional_probs, dtype=float)

    # Marginals P(j | i) and P(k | i).
    marginal_prob_j = joint_prob.sum(axis=2)
    marginal_prob_k = joint_prob.sum(axis=1)

    # Product of marginals, broadcast to the shape of the joint table.
    independent = marginal_prob_j[:, :, None] * marginal_prob_k[:, None, :]

    # Terms with zero probability contribute nothing (0 * log 0 := 0).
    support = (joint_prob > 0) & (marginal_prob_j[:, :, None] > 0) & (marginal_prob_k[:, None, :] > 0)

    total = np.sum(joint_prob[support] * np.log2(joint_prob[support] / independent[support]))
    return float(total)

# Example usage:
# mi = mutual_information(conditional_probs)
# print("Mutual information:", mi)


In [68]:
def get_all_nodes(tree):
    """Return every node of ``tree`` in pre-order (node, left subtree, right subtree).

    ``None`` (an empty tree or a missing daughter) yields no entries.
    """
    ordered = []
    pending = [tree]
    while pending:
        current = pending.pop()
        if current is None:
            continue
        ordered.append(current)
        # Push right first so that the left subtree is visited first.
        pending.append(current.right)
        pending.append(current.left)
    return ordered

def expectation_maximization(trees, hidden_states, initial_probs, conditional_probs, means, covars, num_iterations=100):
    """Fit the tree HMM parameters by iterating E- and M-steps.

    Parameters
    ----------
    trees : list of tree roots; every node's ``gamma``, ``alpha`` and ``beta``
        attributes are overwritten on each iteration.
    hidden_states : sequence whose length gives the number of hidden states.
    initial_probs : starting guess for the root hidden-state distribution.
    conditional_probs : starting guess for the joint daughter-state
        probabilities, indexed [parent][left][right].
    means, covars : starting guess for the per-state Gaussian emission
        parameters. ``covars`` is used as the scale of ``norm.pdf`` and is
        re-estimated with a square root, i.e. it holds standard deviations.
    num_iterations : number of EM iterations to run (no convergence test).

    Returns
    -------
    (initial_probs, conditional_probs, means, covars, all_saved_parameters)
    where ``all_saved_parameters`` is
    ``[all_initial_probs, all_means, all_covars, all_conditional_probs]``,
    each a list with one entry per iteration.

    Side effects: prints the parameters after every iteration and calls
    ``np.set_printoptions(suppress=True)``.
    """

    # Per-iteration history of the estimated parameters
    all_initial_probs = []
    all_means = []
    all_covars = []
    all_conditional_probs = []

    for ctr in range(num_iterations):
        print("Iteration %d" % ctr)
        # E-step: compute the posterior over hidden states for every node
        for tree in trees:
            #print("Tree %d" % trees.index(tree))
            nodes = get_all_nodes(tree)
            for node in nodes:
                forward_probs = forward(node, conditional_probs, means, covars)
                if node == tree: # root node
                    # For the root, the "outside" term is the prior times the root's own emission
                    emission_probs_root = np.array([norm.pdf(tree.x, means[j], covars[j]) for j in range(len(hidden_states))])
                    backward_probs = initial_probs * emission_probs_root
                else:
                    backward_probs = backward(tree, node, initial_probs, conditional_probs, means, covars)
                # Posterior over the node's hidden state, normalized to sum to 1
                gamma = forward_probs * backward_probs
                gamma /= np.sum(gamma)
                # Store the quantities on the node for use in the M-step
                node.gamma = gamma
                node.alpha = forward_probs
                node.beta = backward_probs
        
        # M-step: update model parameters
        # New root distribution = average root posterior over all trees
        initial_probs = np.mean([tree.gamma for tree in trees], axis=0)

        # Update conditional_probs: accumulate expected (parent, left, right) counts
        new_conditional_probs = np.zeros_like(conditional_probs)
        for tree in trees:
            #update_conditional_probs(tree, new_conditional_probs)
            update_conditional_probs_Rev2(tree, new_conditional_probs, conditional_probs, means, covars)
        

        # Normalize each parent's joint table so that it sums to 1
        new_conditional_probs /= np.sum(new_conditional_probs, axis=(1, 2), keepdims=True)

        # Symmetrize in (left, right): the two daughters are treated as exchangeable
        new_conditional_probs = (new_conditional_probs + new_conditional_probs.transpose(0, 2, 1)) / 2

        conditional_probs = new_conditional_probs

        # Update means and covars of emission distributions
        numerator = np.zeros((len(hidden_states)))
        denominator = np.zeros(len(hidden_states))

        for tree in trees:
            nodes = get_all_nodes(tree)
            for node in nodes:
                for i in range(len(hidden_states)):
                    # Posterior-weighted sum of observations for state i
                    numerator[i] += node.gamma[i] * node.x
                    # Total posterior weight of state i
                    denominator[i] += np.sum(node.gamma[i])

        # Posterior-weighted mean of the observations for each state
        means = numerator / denominator

        covars_numerators = np.zeros((len(hidden_states)))
        denominator = np.zeros(len(hidden_states))

        for tree in trees:
            nodes = get_all_nodes(tree)
            for node in nodes:
                for i in range(len(hidden_states)):
                    # Posterior-weighted squared deviation from the updated mean
                    diff = node.x - means[i]
                    covars_numerators[i] += node.gamma[i] * diff * diff
                    denominator[i] += np.sum(node.gamma[i])

        # Square root of the weighted variance: a standard deviation per state
        covars = [np.sqrt(covars_numerators[i] / denominator[i]) for i in range(len(hidden_states))]

        # Print without scientific notation
        np.set_printoptions(suppress=True)
        # Progress report for this iteration
        print("--------------------------------------------------")
        print("Iteration %d" % ctr)
        print("Initial probs: %s" % initial_probs)
        print("Conditional probs: %s" % conditional_probs)
        print("Mutual information: %s" % mutual_information(conditional_probs))
        print("Means: %s" % means)
        print("Covars: %s" % covars)
        print("--------------------------------------------------")

        # Record this iteration's parameters
        all_initial_probs.append(initial_probs)
        all_means.append(means)
        all_covars.append(covars)
        all_conditional_probs.append(conditional_probs)

        
    # Combine the per-iteration histories into one list
    all_saved_parameters = [all_initial_probs, all_means, all_covars, all_conditional_probs]
        

    return initial_probs, conditional_probs, means, covars, all_saved_parameters


def update_conditional_probs_Rev2(tree, conditional_probs, old_conditional_probs, means, covars):
    """Accumulate expected (parent, left, right) state counts over a tree.

    For every node that has both daughters, the posterior over the triplet of
    hidden states is formed from the node's ``beta``, the daughters'
    ``alpha``, the daughters' emission densities and the previous transition
    estimate, normalized to sum to 1, and added in place to
    ``conditional_probs``.

    Parameters
    ----------
    tree : subtree root (or ``None``); nodes need ``x``, ``left``, ``right``,
        ``alpha`` and ``beta``.
    conditional_probs : (K, K, K) numpy array, updated in place.
    old_conditional_probs : (K, K, K) array-like, previous estimate.
    means, covars : per-state location and scale of the Gaussian emission.

    Changes relative to the previous version:
    * the recursion continues below a node that has only one daughter, so
      triplets deeper in that branch are no longer dropped;
    * a triplet whose total weight is zero (or not finite) is skipped instead
      of adding NaNs to the accumulator;
    * the number of states comes from ``len(means)`` rather than from a
      module-level ``hidden_states`` variable.
    """
    if tree is None:
        return

    if tree.left is not None and tree.right is not None:
        locs = np.asarray(means, dtype=float)
        scales = np.asarray(covars, dtype=float)

        parent_weights = np.asarray(tree.beta, dtype=float)
        # Daughter weight per state: subtree probability times emission density.
        left_weights = np.asarray(tree.left.alpha, dtype=float) * norm.pdf(tree.left.x, loc=locs, scale=scales)
        right_weights = np.asarray(tree.right.alpha, dtype=float) * norm.pdf(tree.right.x, loc=locs, scale=scales)

        joint_probs = (
            np.asarray(old_conditional_probs, dtype=float)
            * parent_weights[:, None, None]
            * left_weights[None, :, None]
            * right_weights[None, None, :]
        )

        total = np.sum(joint_probs)
        if total > 0:
            # In-place so the caller's accumulator is updated.
            conditional_probs += joint_probs / total

    update_conditional_probs_Rev2(tree.left, conditional_probs, old_conditional_probs, means, covars)
    update_conditional_probs_Rev2(tree.right, conditional_probs, old_conditional_probs, means, covars)




In [None]:
# Run the expectation maximization algorithm for Fig. 3

# Hidden-state labels assumed by the fit in this cell (two states)
hidden_states = [0, 1]
from sklearn.cluster import KMeans

# Function to assign data points to clusters based on likelihood maximization
def assign_clusters(nodes, means, covars):
    """Hard-assign each node to the cluster with the largest Gaussian density.

    For every node, ``norm.pdf(node.x, mean, covar)`` is evaluated for each
    (mean, covar) pair and the index of the largest value is kept.
    Returns one index per node, in input order.
    """
    components = list(zip(means, covars))
    return [
        np.argmax([norm.pdf(node.x, loc, scale) for loc, scale in components])
        for node in nodes
    ]

# Initial guess for the emission means and standard deviations of the hidden
# states, obtained by k-means clustering of the observed node values

# Collect the observed value of every node in every tree
data = []
for tree in trees:
    nodes = get_all_nodes(tree)
    # Append the observation stored on each node
    for node in nodes:
        data.append(node.x)

# Reshape to a single-column 2-D array before handing the values to KMeans
data = np.array(data).reshape(-1, 1)

# Cluster the observations into one cluster per hidden state
kmeans = KMeans(n_clusters=len(hidden_states), random_state=420).fit(data)

# The cluster centers serve as the initial means
guess_means = [center[0] for center in kmeans.cluster_centers_]


# Standard deviation of the members of each cluster (named "covars", but
# computed with np.std)
guess_covars = [np.std(data[kmeans.labels_ == i]) for i in range(len(hidden_states))]

# Print without scientific notation
np.set_printoptions(suppress=True)
# Report the initial emission parameters
print("The initial mean guesss was: %s" % guess_means)
print("The initial vars guesss was: %s" % guess_covars)

# Initial root distribution: count the maximum-likelihood cluster of each root
guess_initial_probs = np.zeros(len(hidden_states))
for tree in trees:
    cluster_assignments = assign_clusters([tree], guess_means, guess_covars)
    guess_initial_probs[cluster_assignments[0]] += 1

# Turn the counts into probabilities
guess_initial_probs = guess_initial_probs / np.sum(guess_initial_probs)


# Initial joint daughter-state probabilities: count hard-assigned
# (parent, left, right) triplets over all nodes that have both daughters
guess_conditional_probs = np.zeros((len(hidden_states), len(hidden_states), len(hidden_states)))
for tree in trees:
    nodes = get_all_nodes(tree)
    for node in nodes:
        if node.left is not None and node.right is not None: # only nodes with both daughters
            # Cluster assignments of the node and its two daughters
            node_assignments = assign_clusters([node, node.left, node.right], guess_means, guess_covars)
            # Split each count between (left, right) and (right, left) so the table is symmetric
            guess_conditional_probs[node_assignments[0]][node_assignments[1], node_assignments[2]] += 0.5
            guess_conditional_probs[node_assignments[0]][node_assignments[2], node_assignments[1]] += 0.5

# Normalize each parent's table to sum to 1.
# NOTE(review): a parent state that received no counts gives a 0/0 division here.
guess_conditional_probs = [guess_conditional_probs[i] / np.sum(guess_conditional_probs[i]) for i in range(len(guess_conditional_probs))]

# Debugging alternative: start from the true parameters instead of the guesses
#guess_conditional_probs = conditional_probs

#guess_means = means
#guess_covars = covars

# Report the initial root distribution
print("The initial guesss was: %s" % guess_initial_probs)

# Report the initial joint daughter-state probabilities
print("The initial guesss was: %s" % guess_conditional_probs)

# Run the expectation maximization algorithm
final_initial_probs, final_conditional_probs, final_means, final_covars, all_saved_parameters = expectation_maximization(trees, hidden_states, guess_initial_probs, guess_conditional_probs, guess_means, guess_covars, num_iterations=20)

# The "actual value" lines below print the module-level simulation parameters
# (initial_probs, conditional_probs, means, vars) as currently defined.

# Root distribution: initial guess, fitted value, simulation value
print("--------------------------------------------------")

print("The initial guesss was: %s" % guess_initial_probs)
print("The final guess is: %s" % final_initial_probs)
print("The actual value is %s" % initial_probs)

# Joint daughter-state probabilities: initial guess, fitted value, simulation value
print("--------------------------------------------------")

print("The initial guesss was: %s" % guess_conditional_probs)
print("The final guess is: %s" % final_conditional_probs)
print("The actual value is %s" % conditional_probs)

# Emission means: initial guess, fitted value, simulation value
print("--------------------------------------------------")


print("The initial guesss was: %s" % guess_means)
print("The final guess is: %s" % final_means)
print("The actual value is %s" % means)

# Emission standard deviations: initial guess, fitted value, simulation value
print("--------------------------------------------------")


print("The initial guesss was: %s" % guess_covars)
print("The final guess is: %s" % final_covars)
print("The actual value is %s" % vars)

In [None]:
# Run the expectation maximization algorithm for Figs 4 and 5

# Hidden-state labels assumed by the fit in this cell (three states)
hidden_states = [0, 1, 2]
from sklearn.cluster import KMeans

# Function to assign data points to clusters based on likelihood maximization
def assign_clusters(nodes, means, covars):
    """Hard-assign each node to the cluster with the largest Gaussian density.

    For every node, ``norm.pdf(node.x, mean, covar)`` is evaluated for each
    (mean, covar) pair and the index of the largest value is kept.
    Returns one index per node, in input order.
    """
    components = list(zip(means, covars))
    return [
        np.argmax([norm.pdf(node.x, loc, scale) for loc, scale in components])
        for node in nodes
    ]

# Initial guess for the emission means and standard deviations of the hidden
# states, obtained by k-means clustering of the observed node values

# This cell fits the trees stored in trees_3
trees = trees_3

# Collect the observed value of every node in every tree
data = []
for tree in trees:
    nodes = get_all_nodes(tree)
    # Append the observation stored on each node
    for node in nodes:
        data.append(node.x)

# Reshape to a single-column 2-D array before handing the values to KMeans
data = np.array(data).reshape(-1, 1)

# Cluster the observations into one cluster per hidden state
kmeans = KMeans(n_clusters=len(hidden_states), random_state=420).fit(data)

# The cluster centers serve as the initial means
guess_means = [center[0] for center in kmeans.cluster_centers_]


# Standard deviation of the members of each cluster (named "covars", but
# computed with np.std)
guess_covars = [np.std(data[kmeans.labels_ == i]) for i in range(len(hidden_states))]

# Print without scientific notation
np.set_printoptions(suppress=True)
# Report the initial emission parameters
print("The initial mean guesss was: %s" % guess_means)
print("The initial vars guesss was: %s" % guess_covars)

# Initial root distribution: count the maximum-likelihood cluster of each root
guess_initial_probs = np.zeros(len(hidden_states))
for tree in trees:
    cluster_assignments = assign_clusters([tree], guess_means, guess_covars)
    guess_initial_probs[cluster_assignments[0]] += 1

# Turn the counts into probabilities
guess_initial_probs = guess_initial_probs / np.sum(guess_initial_probs)


# Initial joint daughter-state probabilities: count hard-assigned
# (parent, left, right) triplets over all nodes that have both daughters
guess_conditional_probs = np.zeros((len(hidden_states), len(hidden_states), len(hidden_states)))
for tree in trees:
    nodes = get_all_nodes(tree)
    for node in nodes:
        if node.left is not None and node.right is not None: # only nodes with both daughters
            # Cluster assignments of the node and its two daughters
            node_assignments = assign_clusters([node, node.left, node.right], guess_means, guess_covars)
            # Split each count between (left, right) and (right, left) so the table is symmetric
            guess_conditional_probs[node_assignments[0]][node_assignments[1], node_assignments[2]] += 0.5
            guess_conditional_probs[node_assignments[0]][node_assignments[2], node_assignments[1]] += 0.5

# Normalize each parent's table to sum to 1.
# NOTE(review): a parent state that received no counts gives a 0/0 division here.
guess_conditional_probs = [guess_conditional_probs[i] / np.sum(guess_conditional_probs[i]) for i in range(len(guess_conditional_probs))]

# Debugging alternative: start from the true parameters instead of the guesses
#guess_conditional_probs = conditional_probs

#guess_means = means
#guess_covars = covars

# Report the initial root distribution
print("The initial guesss was: %s" % guess_initial_probs)

# Report the initial joint daughter-state probabilities
print("The initial guesss was: %s" % guess_conditional_probs)

# Run the expectation maximization algorithm
final_initial_probs, final_conditional_probs, final_means, final_covars, all_saved_parameters = expectation_maximization(trees, hidden_states, guess_initial_probs, guess_conditional_probs, guess_means, guess_covars, num_iterations=20)

# Keep the per-iteration parameter history of the three-state fit
all_initial_probs_3, all_means_3, all_covars_3, all_conditional_probs_3 = all_saved_parameters

# The "actual value" lines below print the module-level simulation parameters
# (initial_probs, conditional_probs, means, vars) as currently defined.

# Root distribution: initial guess, fitted value, simulation value
print("--------------------------------------------------")

print("The initial guesss was: %s" % guess_initial_probs)
print("The final guess is: %s" % final_initial_probs)
print("The actual value is %s" % initial_probs)

# Joint daughter-state probabilities: initial guess, fitted value, simulation value
print("--------------------------------------------------")

print("The initial guesss was: %s" % guess_conditional_probs)
print("The final guess is: %s" % final_conditional_probs)
print("The actual value is %s" % conditional_probs)

# Emission means: initial guess, fitted value, simulation value
print("--------------------------------------------------")


print("The initial guesss was: %s" % guess_means)
print("The final guess is: %s" % final_means)
print("The actual value is %s" % means)

# Emission standard deviations: initial guess, fitted value, simulation value
print("--------------------------------------------------")


print("The initial guesss was: %s" % guess_covars)
print("The final guess is: %s" % final_covars)
print("The actual value is %s" % vars)



In [None]:
# Run the expectation maximization algorithm for Fig 5

# Hidden-state labels assumed by the fit in this cell (two states; the data
# fitted below is trees_3)
hidden_states = [0, 1]
from sklearn.cluster import KMeans

# Function to assign data points to clusters based on likelihood maximization
def assign_clusters(nodes, means, covars):
    """Hard-assign each node to the cluster with the largest Gaussian density.

    For every node, ``norm.pdf(node.x, mean, covar)`` is evaluated for each
    (mean, covar) pair and the index of the largest value is kept.
    Returns one index per node, in input order.
    """
    components = list(zip(means, covars))
    return [
        np.argmax([norm.pdf(node.x, loc, scale) for loc, scale in components])
        for node in nodes
    ]

# Initial guess for the emission means and standard deviations of the hidden
# states, obtained by k-means clustering of the observed node values

# This cell fits the trees stored in trees_3
trees = trees_3

# Collect the observed value of every node in every tree
data = []
for tree in trees:
    nodes = get_all_nodes(tree)
    # Append the observation stored on each node
    for node in nodes:
        data.append(node.x)

# Reshape to a single-column 2-D array before handing the values to KMeans
data = np.array(data).reshape(-1, 1)

# Cluster the observations into one cluster per hidden state
kmeans = KMeans(n_clusters=len(hidden_states), random_state=420).fit(data)

# The cluster centers serve as the initial means
guess_means = [center[0] for center in kmeans.cluster_centers_]


# Standard deviation of the members of each cluster (named "covars", but
# computed with np.std)
guess_covars = [np.std(data[kmeans.labels_ == i]) for i in range(len(hidden_states))]

# Print without scientific notation
np.set_printoptions(suppress=True)
# Report the initial emission parameters
print("The initial mean guesss was: %s" % guess_means)
print("The initial vars guesss was: %s" % guess_covars)

# Initial root distribution: count the maximum-likelihood cluster of each root
guess_initial_probs = np.zeros(len(hidden_states))
for tree in trees:
    cluster_assignments = assign_clusters([tree], guess_means, guess_covars)
    guess_initial_probs[cluster_assignments[0]] += 1

# Turn the counts into probabilities
guess_initial_probs = guess_initial_probs / np.sum(guess_initial_probs)


# Initial joint daughter-state probabilities: count hard-assigned
# (parent, left, right) triplets over all nodes that have both daughters
guess_conditional_probs = np.zeros((len(hidden_states), len(hidden_states), len(hidden_states)))
for tree in trees:
    nodes = get_all_nodes(tree)
    for node in nodes:
        if node.left is not None and node.right is not None: # only nodes with both daughters
            # Cluster assignments of the node and its two daughters
            node_assignments = assign_clusters([node, node.left, node.right], guess_means, guess_covars)
            # Split each count between (left, right) and (right, left) so the table is symmetric
            guess_conditional_probs[node_assignments[0]][node_assignments[1], node_assignments[2]] += 0.5
            guess_conditional_probs[node_assignments[0]][node_assignments[2], node_assignments[1]] += 0.5

# Normalize each parent's table to sum to 1.
# NOTE(review): a parent state that received no counts gives a 0/0 division here.
guess_conditional_probs = [guess_conditional_probs[i] / np.sum(guess_conditional_probs[i]) for i in range(len(guess_conditional_probs))]

# Debugging alternative: start from the true parameters instead of the guesses
#guess_conditional_probs = conditional_probs

#guess_means = means
#guess_covars = covars

# Report the initial root distribution
print("The initial guesss was: %s" % guess_initial_probs)

# Report the initial joint daughter-state probabilities
print("The initial guesss was: %s" % guess_conditional_probs)

# Run the expectation maximization algorithm
final_initial_probs, final_conditional_probs, final_means, final_covars, all_saved_parameters = expectation_maximization(trees, hidden_states, guess_initial_probs, guess_conditional_probs, guess_means, guess_covars, num_iterations=20)

# Keep the per-iteration parameter history of the two-state fit
all_initial_probs_2, all_means_2, all_covars_2, all_conditional_probs_2 = all_saved_parameters

# The "actual value" lines below print the module-level simulation parameters
# (initial_probs, conditional_probs, means, vars) as currently defined.
# NOTE(review): if those are the three-state parameters used to create
# trees_3, their shapes will not match the two-state fit printed next to them.

# Root distribution: initial guess, fitted value, simulation value
print("--------------------------------------------------")

print("The initial guesss was: %s" % guess_initial_probs)
print("The final guess is: %s" % final_initial_probs)
print("The actual value is %s" % initial_probs)

# Joint daughter-state probabilities: initial guess, fitted value, simulation value
print("--------------------------------------------------")

print("The initial guesss was: %s" % guess_conditional_probs)
print("The final guess is: %s" % final_conditional_probs)
print("The actual value is %s" % conditional_probs)

# Emission means: initial guess, fitted value, simulation value
print("--------------------------------------------------")


print("The initial guesss was: %s" % guess_means)
print("The final guess is: %s" % final_means)
print("The actual value is %s" % means)

# Emission standard deviations: initial guess, fitted value, simulation value
print("--------------------------------------------------")


print("The initial guesss was: %s" % guess_covars)
print("The final guess is: %s" % final_covars)
print("The actual value is %s" % vars)



In [None]:
from collections import defaultdict

def store_ancestors(root, max_depth):
    """
    Traverse the tree and store the ancestors of each node up to max_depth.

    Returns a ``defaultdict(list)`` mapping every visited node (depth counted
    from ``root``, at most ``max_depth``) to the list of its ancestors ordered
    from the top of the tree down to its direct parent.

    A node without a parent is registered explicitly with an empty list.
    Previously such a node only became a key as a side effect of one of its
    daughters looking it up, so a childless root (or ``max_depth == 0``)
    produced an empty mapping.
    """
    ancestors = defaultdict(list)  # node: [ancestors, topmost first]
    stack = [(root, 0)]  # (node, depth below root)

    while stack:
        node, depth = stack.pop()
        if node is None or depth > max_depth:
            continue
        if node.parent is not None:
            # Parents are always processed before their daughters, so the
            # parent's list is already complete here.
            ancestors[node] = ancestors[node.parent] + [node.parent]
        else:
            ancestors[node] = []
        stack.append((node.left, depth + 1))
        stack.append((node.right, depth + 1))

    return ancestors

def find_sibling(parent, node):
    """
    Return the other daughter of ``parent``.

    If ``node`` is ``parent.right`` the left daughter is returned; in every
    other case the right daughter is returned.
    """
    if parent.right == node:
        return parent.left
    return parent.right

def explore_sibling(node, n, pairs, reference_node):
    """
    Collect the descendants of ``node`` that lie exactly ``n`` steps below it.

    Each hit is appended to ``pairs`` as ``(reference_node, descendant)``, in
    left-to-right order. Nothing is appended when ``node`` is None or ``n`` is
    negative; ``n == 0`` pairs the reference with ``node`` itself.
    """
    todo = [(node, n)]
    while todo:
        current, remaining = todo.pop()
        if current is None or remaining < 0:
            continue
        if remaining == 0:
            pairs.append((reference_node, current))
            continue
        # Right is pushed first so the left branch is handled first.
        todo.append((current.right, remaining - 1))
        todo.append((current.left, remaining - 1))
    
def find_pairs(root, m, n, max_depth):
    """
    Find all pairs of nodes that are m and n steps away from their common
    ancestor, handling the cases where m or n is 0.

    Parameters
    ----------
    root : tree node
        Root of the tree to search.
    m, n : int
        Number of generations separating each member of a pair from their
        closest common ancestor. Negative values yield no pairs.
    max_depth : int
        Only nodes down to this depth (root = 0) are used as the first
        member of a pair; it is forwarded to ``store_ancestors``.

    Returns
    -------
    list of tuple
        ``(reference_node, other_node)`` tuples.

        * m >= 1 and n >= 1: ``reference_node`` is m steps below the common
          ancestor and ``other_node`` is n steps below it, inside the
          subtree of the ancestor's other child.
        * m == 0: ``reference_node`` is the ancestor itself and
          ``other_node`` a descendant n steps below it.
        * n == 0 (m > 0): ``reference_node`` is again the ancestor and
          ``other_node`` a descendant m steps below it.
    """
    if m < 0 or n < 0:
        return []

    ancestors = store_ancestors(root, max_depth)
    pairs = []

    for node in ancestors:
        # When m is 0, the node itself can be the common ancestor
        if m == 0:
            explore_tree(node, n, pairs, node)
        # When n is 0, the node itself can be the common ancestor
        elif n == 0:
            explore_tree(node, m, pairs, node)
        else:
            # Path from the top-most ancestor down to the node itself.
            lineage = ancestors[node] + [node]
            if len(lineage) <= m:
                # The node is too close to the root to have an ancestor
                # m generations above it.
                continue
            common_ancestor = lineage[-m - 1]
            # Child of the common ancestor that leads down to ``node``;
            # partners must come from the *other* child so that
            # ``common_ancestor`` really is the closest shared ancestor.
            branch_to_node = lineage[-m]
            other_branch = find_sibling(common_ancestor, branch_to_node)
            # ``other_branch`` is already one step below the ancestor.
            explore_sibling(other_branch, n - 1, pairs, node)

    return pairs

def explore_tree(node, n, pairs, reference_node):
    """
    Collect every node lying exactly ``n`` levels below ``node``.

    Each hit is appended to ``pairs`` as ``(reference_node, hit)``, in
    left-to-right order. With n == 0 the hit is ``node`` itself. A None
    ``node`` or a negative ``n`` appends nothing. Returns None.
    """
    if n < 0:
        return

    pending = [(node, n)]  # (subtree root, levels still to descend)
    while pending:
        current, remaining = pending.pop()
        if current is None:
            continue
        if remaining == 0:
            pairs.append((reference_node, current))
            continue
        # Push the right child first so the left subtree is handled first.
        pending.append((current.right, remaining - 1))
        pending.append((current.left, remaining - 1))
    

In [None]:
#function to compute the pearson correlation coefficient with bootstrapping

def pearsonr_bs(x, y, per_cutoffs, n_bs=1000, rng=None):
    """
    Bootstrap estimate of the Pearson correlation between x and y.

    Parameters
    ----------
    x, y : sequence of numbers
        Paired observations; both must have the same length.
    per_cutoffs : sequence of two numbers
        Lower and upper percentile (0-100) of the bootstrap distribution to
        report, e.g. ``[2.5, 97.5]``.
    n_bs : int, optional
        Number of bootstrap resamples (default 1000).
    rng : optional
        Object providing ``choice(a, size=..., replace=...)``, such as a
        ``numpy.random.Generator``. When None, the global ``np.random``
        state is used, so ``np.random.seed`` still controls the result.

    Returns
    -------
    r_mean : float
        Mean of the bootstrap correlation coefficients.
    (r_lo, r_hi) : tuple of float
        The requested percentiles of the bootstrap coefficients.

    Raises
    ------
    ValueError
        If x and y differ in length or ``n_bs`` is smaller than 1.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    #get the number of samples
    n = len(x)
    if len(y) != n:
        raise ValueError(
            "x and y must have the same length, got %d and %d" % (n, len(y))
        )
    if n_bs < 1:
        raise ValueError("n_bs must be at least 1, got %s" % (n_bs,))

    sampler = np.random if rng is None else rng

    #correlation coefficient of every bootstrap sample
    r_bs = []
    for _ in range(n_bs):
        #resample the pairs with replacement, keeping x and y aligned
        sample = sampler.choice(n, size=n, replace=True)
        r_bs.append(stats.pearsonr(x[sample], y[sample])[0])

    #compute the confidence intervals (e.g. the 2.5 and 97.5 percentile)
    r_lo = np.percentile(r_bs, per_cutoffs[0])
    r_hi = np.percentile(r_bs, per_cutoffs[1])

    #compute the mean
    r_mean = np.mean(r_bs)

    return r_mean, (r_lo, r_hi)
    


In [None]:
# generate data for Fig. 5

# Three-state simulations.
hidden_states = [0, 1, 2]

# Trees drawn from the parameters initial_probs / conditional_probs / means / vars.
trees_3 = [
    generate_tree(
        n_generations,
        hidden_states,
        initial_probs,
        conditional_probs,
        means,
        vars,
        prob_prune,
    )
    for _ in range(n_trees)
]

# Trees drawn from the last entry of each all_*_3 list
# (presumably the final estimates of the 3-state fit -- confirm).
trees_counter_3 = [
    generate_tree(
        n_generations,
        hidden_states,
        all_initial_probs_3[-1],
        all_conditional_probs_3[-1],
        all_means_3[-1],
        all_covars_3[-1],
        prob_prune,
    )
    for _ in range(n_trees)
]

# Two-state simulations, using the last entry of each all_*_2 list.
# Note: hidden_states stays [0, 1] after this cell.
hidden_states = [0, 1]
trees_counter_2 = [
    generate_tree(
        n_generations,
        hidden_states,
        all_initial_probs_2[-1],
        all_conditional_probs_2[-1],
        all_means_2[-1],
        all_covars_2[-1],
        prob_prune,
    )
    for _ in range(n_trees)
]


In [None]:
# generate cors for Fig. 5

#compute various correlation coefficients

n_generations = 5

# define the percentiles for the confidence intervals
per_cutoffs = [2.5, 97.5]

# One (n_generations x n_generations) matrix per statistic, indexed [m, n]:
# mean correlation, lower percentile and upper percentile.
# No suffix: trees_3, suffix _3: trees_counter_3, suffix _2: trees_counter_2.
all_r, all_r_lo, all_r_hi = (
    np.zeros((n_generations, n_generations)) for _ in range(3)
)
all_r_3, all_r_lo_3, all_r_hi_3 = (
    np.zeros((n_generations, n_generations)) for _ in range(3)
)
all_r_2, all_r_lo_2, all_r_hi_2 = (
    np.zeros((n_generations, n_generations)) for _ in range(3)
)

# The (0, 0) entry is set to 1 in every matrix; the loop that fills the
# matrices skips m == n == 0.
for _matrix in (
    all_r, all_r_lo, all_r_hi,
    all_r_3, all_r_lo_3, all_r_hi_3,
    all_r_2, all_r_lo_2, all_r_hi_2,
):
    _matrix[0, 0] = 1

#go through all distances in the pairs


# Each entry: (trees, matrix of mean r, matrix of lower bound, matrix of upper bound)
_datasets = (
    (trees_3, all_r, all_r_lo, all_r_hi),
    (trees_counter_3, all_r_3, all_r_lo_3, all_r_hi_3),
    (trees_counter_2, all_r_2, all_r_lo_2, all_r_hi_2),
)

for m in range(n_generations):
    for n in range(n_generations):
        # the (0, 0) entry is initialised by hand
        if m == 0 and n == 0:
            continue

        for _trees, _r, _r_lo, _r_hi in _datasets:
            # one list of (node, node) pairs per tree
            pairs = [find_pairs(tree, m, n, max(m, n)) for tree in _trees]

            # Flatten in plain Python: the per-tree lists generally differ
            # in length (e.g. after pruning), so they cannot be stacked into
            # a rectangular numpy array.
            array1 = [first.x for tree_pairs in pairs for first, _ in tree_pairs]
            array2 = [second.x for tree_pairs in pairs for _, second in tree_pairs]

            r_mean, (r_lo, r_hi) = pearsonr_bs(array1, array2, per_cutoffs)

            _r[m, n] = r_mean
            _r_lo[m, n] = r_lo
            _r_hi[m, n] = r_hi


def _combined_half_width(lower, upper):
    """
    Half the width of the [lower, upper] interval at every [m, n] entry,
    combined in quadrature with the transposed ([n, m]) entry and halved.
    """
    half_width = (upper - lower) / 2
    return np.sqrt(half_width**2 + half_width.transpose()**2) / 2


# The bare names below are kept as in the original cell; in a notebook only
# the last expression of a cell is displayed.
all_r_std = _combined_half_width(all_r_lo, all_r_hi)
all_r_std

all_r_std_3 = _combined_half_width(all_r_lo_3, all_r_hi_3)
all_r_std_3

all_r_std_2 = _combined_half_width(all_r_lo_2, all_r_hi_2)
all_r_std_2