# Test using Newick

In [80]:
x = "(A,B,(C,D)E)F"

In [2]:
from newick import loads

trees = loads('(A,B,(C,D)E)F;')

print(trees[0].ascii_art())

    ┌─A
──F─┼─B
    │   ┌─C
    └─E─┤
        └─D


In [9]:
[n.name for n in trees[0].descendants]

['A', 'B', 'E']

In [14]:
len(trees)

1

In [35]:
from newick import loads

# Load the tree from the Newick string
trees = loads('((G,B(Y,Z))A,(C,D)E)F;')
tree = trees[0]  # Since we have only one tree in the list

print(trees[0].ascii_art())


# Function to get the leaf names under a node
def get_leaves(node):
    if node.is_leaf:
        return {node.name}
    leaves = set()
    if node.name:
        leaves = {node.name}
    for child in node.descendants:
        leaves.update(get_leaves(child))
    return leaves

# Function to compute all splits from the tree
def get_splits(node, all_leaves):
    splits = []
    if not node.is_leaf:
        # Get the leaves under this node
        subtree_leaves = get_leaves(node)

        # Remaining leaves are those not in the subtree
        remaining_leaves = all_leaves - subtree_leaves

        # Add the split (subtree_leaves, remaining_leaves)
        if remaining_leaves and subtree_leaves:
            splits.append((subtree_leaves, remaining_leaves))


        # Recursively compute splits for each child
        for child in node.descendants:
            splits.extend(get_splits(child, all_leaves))
            #splits.extend(get_splits(child, subtree_leaves))
            
        if node.is_leaf:
            single_leaf_split = ({node.name}, subtree_leaves - {node.name})
            splits.append(single_leaf_split)

    return splits

def return_splits(tree):
    # Get all leaves of the tree except the first one
    all_leaves = get_leaves(tree) - {tree.name}
    print(all_leaves)
    # Compute all the splits
    return get_splits(tree, all_leaves)

# Compute all the splits
splits = return_splits(tree)

# Output the splits
for i, split in enumerate(splits, 1):
    print(f"Split {i}: {split}")


        ┌─G
    ┌─A─┤
    │   │   ┌─Y
    │   └───┤
──F─┤       └─Z
    │   ┌─C
    └─E─┤
        └─D
{'G', 'C', 'Y', 'A', 'D', 'E', 'Z'}
Split 1: ({'A', 'G', 'Z', 'Y'}, {'D', 'E', 'C'})
Split 2: ({'Z', 'Y'}, {'G', 'C', 'A', 'D', 'E'})
Split 3: ({'D', 'E', 'C'}, {'A', 'G', 'Z', 'Y'})


In [37]:
from newick import loads

# Load the tree from the Newick string
trees = loads('((G,B(Y,Z(U,V)))A,(C,D)E)F;')
tree = trees[0]  # Since we have only one tree in the list


# Function to get the node names under a node
def get_nodes(node):
    if node.is_leaf:
        return {node.name}

    leaves = set()

    for child in node.descendants:
        leaves.update(get_nodes(child))
    if node.name:
        leaves.add(node.name)
    return leaves


# Function to compute splits for RF distance (excluding trivial splits)
def get_splits(node):
    splits = []

    if node.is_leaf:
        return splits

    current_leaves = []
    for descendant in node.descendants:
        current_leaves.append(get_nodes(descendant))
        splits += get_splits(descendant)
    splits.append(tuple(current_leaves))

    return splits


# Compute all the splits for RF
splits = get_splits(tree)

# Output the splits
for i, split in enumerate(splits, 1):
    print(f"Split {i}: {split}")

Split 1: ({'U'}, {'V'})
Split 2: ({'Y'}, {'V', 'U'})
Split 3: ({'G'}, {'V', 'U', 'Y'})
Split 4: ({'C'}, {'D'})
Split 5: ({'G', 'V', 'Y', 'A', 'U'}, {'D', 'E', 'C'})


In [38]:
def rf(tree_1, tree_2):
    """Count the number of different splits in a set of lists of splits.
    """
    splits_1 = get_splits(tree_1)
    splits_2 = get_splits(tree_2)
    
    diff_splits = 0
    for split in splits_1:
        if split not in splits_2:
            diff_splits += 1
    for split in splits_2:
        if split not in splits_1:
            diff_splits += 1
    return diff_splits/2


In [39]:
trees = loads('((G,(Y,Z))A,(C,D)Q)F;((G,(Y,Z))A,(C,D)E)F;')
tree_1 = trees[0]
tree_2 = trees[1]

splits_1 = get_splits(tree_1)
splits_2 = get_splits(tree_2)

print(tree_1.ascii_art())
print(tree_2.ascii_art())

splits_1

        ┌─G
    ┌─A─┤
    │   │   ┌─Y
    │   └───┤
──F─┤       └─Z
    │   ┌─C
    └─Q─┤
        └─D
        ┌─G
    ┌─A─┤
    │   │   ┌─Y
    │   └───┤
──F─┤       └─Z
    │   ┌─C
    └─E─┤
        └─D


[({'Y'}, {'Z'}),
 ({'G'}, {'Y', 'Z'}),
 ({'C'}, {'D'}),
 ({'A', 'G', 'Y', 'Z'}, {'C', 'D', 'Q'})]

In [40]:
import numpy as np


def splits_to_binary_matrix(splits, all_tips):
    """
    Convert splits into a binary matrix.

    Parameters:
    splits: list of sets representing splits (each set is a group of tips)
    all_tips: list of all tips (leaves)

    Returns:
    binary_matrix: Binary matrix representing splits
    """
    # Number of splits and number of tips
    n_splits = len(splits)
    n_tips = len(all_tips)
    
    # Create a tip index mapping (tip to column index)
    tip_index = {tip: i for i, tip in enumerate(all_tips)}
    
    # Initialize binary matrix with zeros
    binary_matrix = np.zeros((n_splits, n_tips), dtype=int)
    
    # Fill in the matrix
    for i, split in enumerate(splits):
        for tip in split:
            binary_matrix[i][tip_index[tip]] = 1  # Mark '1' for tips in the split
    
    return binary_matrix

# Example usage
all_tips = ['A', 'B', 'C', 'D', 'E']  # List of all tips
splits = [{'A', 'B'}, {'A', 'C', 'D'}]  # Example splits

binary_matrix = splits_to_binary_matrix(splits, all_tips)
print(binary_matrix)


[[1 1 0 0 0]
 [1 0 1 1 0]]


In [41]:
import numpy as np
from scipy.optimize import linear_sum_assignment



def jaccard_similarity(split_1, split_2, n_tips, exponent, allow_conflict):
    """
    Perform the matching of two sets of splits and compute their Jaccard similarity.

    Args:
        split_1 (numpy.ndarray): The first set of splits.
        split_2 (numpy.ndarray): The second set of splits.
        n_tips (list): The number of tips in the tree.
        exponent (int): The exponent to use in the Jaccard similarity calculation.
        allow_conflict (bool): Whether to allow conflicting splits.
    """
    if split_1.shape[1] != split_2.shape[1]:
        raise ValueError("Input splits must address the same number of tips.")
    
    n_splits_1 = len(split_1)
    n_splits_2 = len(split_2)
    most_splits = max(n_splits_1, n_splits_2)
    
    max_score = 1e6

    # Initialize the score matrix with the maximum score
    score = np.full((most_splits, most_splits), max_score)

    # Iterate over all splits in tree 1
    for clade_a_ix in range(n_splits_1):
        # Count number of tips in the split
        n_tips_a = np.sum(split_1[clade_a_ix])
        # Count number of tips not in this split
        n_tips_not_a = n_tips - n_tips_a
        
        for clade_b_ix in range(n_splits_2):
            clade_a_b = np.sum(np.bitwise_and(split_1[clade_b_ix],
                                              split_2[clade_b_ix]))
            
            n_tips_b = np.sum(split_2[clade_b_ix])
            n_tips_not_b = n_tips - n_tips_b
            a_and_B = n_tips_a - clade_a_b
            A_and_b = n_tips_b - clade_a_b
            A_and_B = n_tips_not_b - a_and_B

            if not allow_conflict and not (
                clade_a_b == n_tips_a or a_and_B == n_tips_b or A_and_b == n_tips_not_a or A_and_B == n_tips_not_a):
                score[clade_a_ix][clade_b_ix] = max_score
            else:
                A_or_b = n_tips - a_and_B
                a_or_B = n_tips - A_and_b
                a_or_b = n_tips - A_and_B
                A_or_B = n_tips - clade_a_b

                ars_ab = clade_a_b / a_or_b
                ars_Ab = A_and_b / A_or_b
                ars_aB = a_and_B / a_or_B
                ars_AB = A_and_B / A_or_B

                min_ars_both = min(ars_ab, ars_AB)
                min_ars_either = min(ars_aB, ars_Ab)

                if exponent == 1:
                    score[clade_a_ix][clade_b_ix] = max_score - (max_score * max(min_ars_both, min_ars_either))
                elif exponent == float('inf'):
                    score[clade_a_ix][clade_b_ix] = 0 if min_ars_both == 1 or min_ars_either == 1 else max_score
                else:
                    score[clade_a_ix][clade_b_ix] = max_score - (max_score * (max(min_ars_both, min_ars_either) ** exponent))
    
    # Filling in extra rows/columns with max_score
    score[n_splits_1:, :] = max_score
    score[:, n_splits_2:] = max_score
    
    # Perform linear assignment to minimize the score
    row_ind, col_ind = linear_sum_assignment(score)
    
    # Adjust by 2 - 2 * to retrieve zero in case of perfect match for RF
    final_score = 2 - (max_score * most_splits - score[row_ind, col_ind].sum()) / max_score

    # Prepare final matching
    final_matching = np.full(n_splits_1, np.nan)
    for i, match in enumerate(row_ind):
        if match < n_splits_2:
            final_matching[i] = match

    return {
        'score': final_score,
        'matching': final_matching
    }



all_tips = ['A', 'B', 'C', 'D', 'E']  # List of all tips
splits_1 = [{'A', 'B', 'C'},{'D','E','C'}]  # Example splits
splits_2 = [{'A', 'B', 'C'},{'D','E','C'}]  # Example splits

binary_matrix_1 = splits_to_binary_matrix(splits_1, all_tips)
binary_matrix_2 = splits_to_binary_matrix(splits_2, all_tips)

x = np.array([[1, 0], [0, 1]])  # Example binary matrices for splits
y = np.array([[1, 0], [0, 1]])
n_tips = 5  # Example number of tips
k = 1  # Example exponent
allow_conflict = True

result = jaccard_similarity(binary_matrix_1, binary_matrix_2, n_tips, k, allow_conflict)


In [42]:
splits_1 = get_splits(tree_1)
splits_2 = get_splits(tree_2)

In [6]:
trees = loads('((G,(Y,Z))A,(C,D)Q)F;((G,(Y,Z))A,(C,D)Q)F;')
tree_1 = trees[0]
tree_2 = trees[1]

print(tree_1.ascii_art())
print(tree_2.ascii_art())


splits_1 = get_splits(tree_1)
splits_2 = get_splits(tree_2)

splits_1

        ┌─G
    ┌─A─┤
    │   │   ┌─Y
    │   └───┤
──F─┤       └─Z
    │   ┌─C
    └─Q─┤
        └─D
        ┌─G
    ┌─A─┤
    │   │   ┌─Y
    │   └───┤
──F─┤       └─Z
    │   ┌─C
    └─Q─┤
        └─D


[({'Y'}, {'Z'}),
 ({'G'}, {'Y', 'Z'}),
 ({'C'}, {'D'}),
 ({'A', 'G', 'Y', 'Z'}, {'C', 'D', 'Q'})]

In [43]:
def compute_intersection_proportion(set_1, set_2):
    """
    Compute the proportion of the intersection of two sets.

    Parameters:
    set_1: First set
    set_2: Second set

    Returns:
    proportion: Proportion of the intersection
    """
    intersection = set_1.intersection(set_2)
    union = set_1.union(set_2)
    return len(intersection) / len(union)


scores = []
for split_1 in splits_1:
    for split_2 in splits_2:
        score_split_1 = []
        # Iterate over every split and take the maximum score
        score_split_1.append(max(compute_intersection_proportion(split_1[0], split_2[0]),
                                 compute_intersection_proportion(split_1[1], split_2[1])))


        # tips = list(set([value for inner_set in split_1 for value in inner_set] + [value for inner_set in split_2 for value in inner_set]))
        # binary_matrix_1 = splits_to_binary_matrix(list(split_1), tips)
        # binary_matrix_2 = splits_to_binary_matrix(list(split_2), tips)
        # score = jaccard_similarity(binary_matrix_1, binary_matrix_2, len(tips), 1, False)["score"]
        # if score != 0:
        #     print(f"Split {split_1} and {split_2} have a Jaccard similarity of {score}")

In [45]:
split_1_right, split_1_left = splits_1[0]

In [44]:
scores = {}
for ix_1, (split_1_left, split_1_right) in enumerate(splits_1):
    print("---")
    print(split_1_left)
    print(split_1_right)
    print("----")
    scores[ix_1] = {}
    for ix, (split_2_left, split_1_right) in enumerate(splits_2):
        scores[ix_1][ix] = max(
            min(compute_intersection_proportion(split_1_left, split_2_left),
                compute_intersection_proportion(split_1_right, split_1_right)),
            min(compute_intersection_proportion(split_1_left, split_1_right),
                                compute_intersection_proportion(split_1_right, split_2_left)))



---
{'Y'}
{'Z'}
----
---
{'G'}
{'Z', 'Y'}
----
---
{'C'}
{'D'}
----
---
{'A', 'G', 'Z', 'Y'}
{'D', 'C', 'Q'}
----


In [46]:
scores = np.array([[scores[i][j] for j in range(len(scores))] for i in range(len(scores))])


In [47]:
from scipy.optimize import linear_sum_assignment

row_ind, col_ind = linear_sum_assignment(scores)

In [683]:
scores[row_ind, col_ind].sum()

np.float64(0.0)

In [48]:
from scipy.optimize import linear_sum_assignment
from newick import loads
import numpy as np 


def load_tree(newick_str):
    """Load the first tree of the newick string.
    Careful, if there are several trees, only the first one is loaded.
    """
    return loads(newick_str)[0]


def get_nodes(node):
    """Get the nodes under a leaf.
    """
    if node.is_leaf:
        return {node.name}

    leaves = set()

    for child in node.descendants:
        leaves.update(get_nodes(child))
    if node.name:
        leaves.add(node.name)
    return leaves


def get_splits(node):
    """
    Get the splits recursively at each node.
    """
    splits = []

    if node.is_leaf:
        return splits

    current_leaves = []
    for descendant in node.descendants:
        current_leaves.append(get_nodes(descendant))
        splits += get_splits(descendant)
    splits.append(tuple(current_leaves))

    return splits


def compute_intersection_proportion(set_1, set_2, order=1):
    """
    Compute the proportion of the intersection of two sets.

    Parameters:
    set_1: First set
    set_2: Second set

    Returns:
    proportion: Proportion of the intersection
    """
    intersection = set_1.intersection(set_2)
    union = set_1.union(set_2)
    if order == 1:
        # Computation for the Nye similarity
        return len(intersection) / len(union)
    else:
        # Computation for the Böcker similarity
        return 2 - 2 * (len(intersection) / len(union))**order


def jaccard_similarity(split_1, split_2, order=1):
    """Compute the score between two splits.
    """
    split_1_left, split_1_right = split_1
    split_2_left, split_2_right = split_2
    return max(
        min(compute_intersection_proportion(split_1_left, split_2_left, order=order),
            compute_intersection_proportion(split_1_right, split_2_right, order=order)),
        min(compute_intersection_proportion(split_1_left, split_2_right, order=order),
            compute_intersection_proportion(split_1_right, split_2_left, order=order)))


def split_similarity(splits_1, splits_2, score_computer, *args, **kwargs):
    """Given two slides, define the generalized Robinson-Foulds distance
    using the similarity given as input in the score_computer function.
    """
    scores = {}
    for ix_1, split_1 in enumerate(splits_1):
        scores[ix_1] = {}
        for ix, split_2 in enumerate(splits_2):
            scores[ix_1][ix] = score_computer(split_1,
                                             split_2,
                                             *args, **kwargs)
    # This can be optimized
    scores_array = np.array([[scores[i][j] for j in range(len(scores))] for i in range(len(scores))])
    row_ind, col_ind = linear_sum_assignment(scores_array)
    return scores_array[row_ind, col_ind].sum()
    

def rf_generalized(tree_1_str,
                   tree_2_str, 
                   score_computer,
                   *args, **kwargs):
    """Return the generalized Robinson's Foulds jaccard distance."""
    tree_1 = load_tree(tree_1_str)
    tree_2 = load_tree(tree_2_str)
    
    splits_1 = get_splits(tree_1)
    splits_2 = get_splits(tree_2)
    
    return split_similarity(splits_1, splits_2, score_computer, *args, **kwargs)


from functools import reduce


def get_unique_values_splits(splits_1, splits_2):
    """
    Get the number of distinct leaves in two splits.
    """
    # Combine the two lists into one
    combined_data = splits_1 + splits_2
    
    # Flatten the list of tuples into one list of sets, then apply union across all sets
    all_sets = [set1.union(set2) for set1, set2 in combined_data]
    
    # Use reduce to apply union across all sets
    unique_values = reduce(set.union, all_sets)
    
    return len(unique_values)


def matching_split(split_1, split_2):
    """Compute the score between two splits.
    """
    split_1_left, split_1_right = split_1
    split_2_left, split_2_right = split_2
    
    len_tree = get_unique_values_splits(
        [split_1],
        [split_2]
    )
    return len_tree - max(
        len(split_1_left.intersection(split_2_left)) + len(split_2_right.intersection(split_2_right)),
        len(split_1_left.intersection(split_2_right)) + len(split_1_right.intersection(split_2_left))
    )            

In [49]:
rf_generalized('((G,(Y,Z))A,(C,D)Q)F;', 
               "((G,(Y,Z))A,(C,D)Q)F;",
               score_computer=matching_split)

np.int64(0)

In [50]:
import math

def double_factorial(n):
    """Define the double factorial."""
    if n <= 0:
        return 1
    return math.prod(range(n, 0, -2))


def P_Phy(split):
    """Compute the probability of apparition of a split."""
    split_1_left, split_1_right = split

    X_len = get_unique_values_splits(splits_1, [])

    numerator = (double_factorial(2 * len(split_1_left) - 3) * 
                 double_factorial(2 * len(split_1_right) - 3))
    
    denominator = double_factorial(2 * X_len - 5)

    return numerator / denominator


def phylogenetic_information_content(split):
    """Compute the phylogenetic information content."""
    return -math.log(P_Phy(split))


# Function to compute P_Phy for given sets S1 and S2
def probability_splits(split_1, split_2):
    """Compute the probability of apparition of two splits.
    """
    split_1_left, split_1_right = split_1
    split_2_left, _ = split_2
    X_len = get_unique_values_splits([split_1], [split_2])

    numerator = (double_factorial(2 * (len(split_1_right) + 1) - 5) * 
                 double_factorial(2 * (len(split_2_left) + 1) - 5) * 
                 double_factorial(2 * (len(split_1_left) - len(split_2_left) + 2) - 5))

    denominator = double_factorial(2 * X_len - 5)

    return -math.log(numerator / denominator)


def shared_phylogenetic_information_score(split_1, split_2):
    """Compute the shared phylogenetic information of two splits.
    """
    if are_splits_incompatible(split_1, split_2):
        return 0
    return phylogenetic_information_content(split_1) + \
        phylogenetic_information_content(split_2) - \
        probability_splits(split_1, split_2)

def shared_phylogenetic_information(tree_1, tree_2):
    return rf_generalized(
        tree_1,
        tree_2,
        score_computer=shared_phylogenetic_information_score)


def mutual_clustering_information(tree_1, tree_2):
    """Given two trees, compute the shared phylogenetic information between the trees.
    """
    splits_1 = get_splits(tree_1)
    splits_2 = get_splits(tree_2)
    
    return 



In [15]:
split_1_left, split_1_right = splits_1[0]
split_2_left, split_1_right = splits_2[0]

In [32]:
def matching_split_information(split_1, split_2):
    """Define the matching split information score (MSI) between two splits."""
    split_1_left, split_1_right = split_1
    split_2_left, split_2_right = split_2
    return max(
    phylogenetic_information_content((split_1_left.intersection(split_2_left), split_1_right.intersection(split_2_right))),
    phylogenetic_information_content((split_1_left.intersection(split_2_right), split_1_right.intersection(split_2_left)))
)

def shared_phylogenetic_information(split_1, split_2):
    """Compute the shared phylogenetic information of two splits.
    """
    if are_splits_incompatible(split_1, split_2):
        return 0
    return phylogenetic_information_content(split_1) + \
        phylogenetic_information_content(split_2) - \
        probability_splits(split_1, split_2)

In [33]:
rf_generalized('(((E,M)Y,Z)A,(C,D)Q)F;',
                "((G,(Y,Z))A,(C,E)Q)F;",
                score_computer=matching_split)

np.float64(27.404739709974972)

In [34]:
rf_generalized('(((E,M)Y,Z)A,(C,D)Q)F;',
                "((G,(Y,Z))A,(C,E)Q)F;",
                score_computer=shared_phylogenetic_information)

np.float64(24.696689508872762)

In [1]:
trees = loads('((G,(Y,Z))A,(C,D)Q)F;((G,(Y,Z))A,(C,D)Q)F;')
tree_1 = trees[0]
tree_2 = trees[1]

NameError: name 'loads' is not defined

In [72]:
trees = loads('(((E,M)Y,Z)A,(C,D)Q)F;((G,(Y,Z))A,(C,D)Q)F;')
tree_1 = trees[0]
tree_2 = trees[1]

print(tree_1.ascii_art())
print(tree_2.ascii_art())

            ┌─E
        ┌─Y─┤
    ┌─A─┤   └─M
──F─┤   └─Z
    │   ┌─C
    └─Q─┤
        └─D
        ┌─G
    ┌─A─┤
    │   │   ┌─Y
    │   └───┤
──F─┤       └─Z
    │   ┌─C
    └─Q─┤
        └─D


In [19]:
def are_splits_incompatible(split_1, split_2):
    """
    Determines if two splits are incompatible.
    A split is represented as a tuple of two sets, where each set is a group of taxa.
    For example, split1 = ({'A', 'B'}, {'C', 'D'}) represents the split {A, B} | {C, D}.
    """
    
    # Unpack the splits
    split_1_left, split_1_right = split_1
    split_2_left, split_2_right = split_2
    
    # Check for incompatibility: There should not be overlap between a group in one split
    # and both groups in the other split.
    
    # Group 1 of split 1 should not overlap with both groups of split 2
    if (split_1_left & split_2_left and split_1_left & split_2_right):
        return True
    # Group 2 of split 1 should not overlap with both groups of split 2
    if (split_1_right & split_2_left and split_1_right & split_2_right):
        return True
    
    # If no conflicts were found, the splits are compatible
    return False

# Example splits
split1 = ({'A', 'B'}, {'C', 'D'})
split2 = ({'A', 'C'}, {'B', 'D'})

# Check if the splits are incompatible
if are_splits_incompatible(split1, split2):
    print("The splits are incompatible.")
else:
    print("The splits are compatible.")




The splits are incompatible.


## Computing the graph edit distance

In [42]:
import newick

def tree_edit_distance(tree1, tree2):
    """
    Recursively calculates the tree edit distance between two trees,
    accounting for node names and structure.
    """
    if tree1.name != tree2.name:
        # Node substitution cost (if names differ)
        cost = 1
    else:
        cost = 0

    # Get children of both nodes
    children1 = tree1.descendants
    children2 = tree2.descendants

    # Calculate insertion or deletion costs for unbalanced children
    ins_del_cost = abs(len(children1) - len(children2))

    # Recursive structure comparison for all pairs of children
    sub_cost = 0
    for c1, c2 in zip(children1, children2):
        sub_cost += tree_edit_distance(c1, c2)

    # Total cost includes substitution, insertion/deletion, and recursive structure costs
    total_cost = cost + ins_del_cost + sub_cost
    return total_cost

# Example usage
tree_newick_1 = "((A,B),C)F;"
tree_newick_2 = "((A,C),F)B;"

# Load trees from Newick strings
tree1 = newick.loads(tree_newick_1)[0]
tree2 = newick.loads(tree_newick_2)[0]

# Calculate tree edit distance
distance = tree_edit_distance(tree1, tree2)
print("Tree Edit Distance:", distance)


Tree Edit Distance: 3


In [43]:
print(tree1.ascii_art())

print(tree2.ascii_art())

        ┌─A
    ┌───┤
──F─┤   └─B
    └─C
        ┌─A
    ┌───┤
──B─┤   └─C
    └─F


## Computing affinity matrixes and the corresponding euclidean distance

In [72]:
import numpy as np
import newick
from newick import loads


def load_tree(newick_str):
    """Load the first tree of the newick string.
    Careful, if there are several trees, only the first one is loaded.
    """
    return loads(newick_str)[0]


def build_edges_from_tree(node):
    """
    Recursively build edges from the tree structure.
    """
    edges = []
    if node.name is not None:  # Ensure the node has a name
        for child in node.descendants:
            if child.name is not None:  # Ensure the child has a name
                edges.append((node.name, child.name))  # Add edge
                edges.extend(build_edges_from_tree(child))  # Recur for child
    return edges


def create_adjacency_matrix(tree):
    """Create an adjacency matrix from a tree.
    """
    # Build edges from the tree
    edges = build_edges_from_tree(tree)

    # Extract unique nodes
    nodes = list(set([node for edge in edges for node in edge]))

    # Initialize the adjacency matrix
    n = len(nodes)
    matrix = np.zeros((n, n), dtype=int)

    # Fill the adjacency matrix
    for i in range(n):
        for j in range(n):
            if (nodes[i], nodes[j]) in edges or (nodes[j], nodes[i]) in edges:
                matrix[i, j] = 1  # Mark as connected

    return nodes, matrix


def compute_distance_adjacency_matrix(matrix_1, matrix_2, degree=2):
    """Compute the distance between two adjacency matrices.
    Possible distances include the l1 and the l2 norm.
    """
    return ((matrix_1 - matrix_2) ** degree).mean() ** (1/degree)


def adjacency_matrix_distance(tree_1_str, tree_2_str):
    """Compute the distance between two trees using adjacency matrices.
    Careful, all nodes of the tree must be named for this distance to be computed,
    including the root node.
    """
    # Load the trees from the Newick strings
    tree_1 = load_tree(tree_1_str)
    tree_2 = load_tree(tree_2_str)

    # Create adjacency matrices from the trees
    nodes_1, matrix_1 = create_adjacency_matrix(tree_1)
    nodes_2, matrix_2 = create_adjacency_matrix(tree_2)

    # Compute the distance between the adjacency matrices
    distance = compute_distance_adjacency_matrix(matrix_1, matrix_2)
    return distance



# # Example Newick string
# newick_str = "((A,B)G, (C, (D, E)X))F;"

# # Create adjacency matrix from the Newick string
# nodes, adjacency_matrix = create_adjacency_matrix(newick_str)

# # Display the results
# if nodes and adjacency_matrix is not None:
#     print("Nodes:", nodes)
#     print("\nAdjacency Matrix:")
#     print("  " + " ".join(nodes))
#     for i, row in enumerate(adjacency_matrix):
#         print(nodes[i] + " " + " ".join(map(str, row.astype(int))))


In [82]:
tree_newick_1 = "((A,E)D,B)F;"
tree_newick_2 = "((A,E)D,B)F;"

adjacency_matrix_distance(tree_newick_1, tree_newick_2)

np.float64(0.0)

In [83]:
print(loads(tree_newick_1)[0].ascii_art())
print(loads(tree_newick_2)[0].ascii_art())

        ┌─A
    ┌─D─┤
──F─┤   └─E
    └─B
        ┌─A
    ┌─D─┤
──F─┤   └─E
    └─B


In [None]:
import numpy as np
import newick
from newick import loads


def load_tree(newick_str):
    """Load the first tree of the newick string.
    Careful, if there are several trees, only the first one is loaded.
    """
    return loads(newick_str)[0]


def build_edges_from_tree(node):
    """
    Recursively build edges from the tree structure.
    """
    edges = []
    if node.name is not None:  # Ensure the node has a name
        for child in node.descendants:
            if child.name is not None:  # Ensure the child has a name
                edges.append((node.name, child.name))  # Add edge
                edges.extend(build_edges_from_tree(child))  # Recur for child
    return edges


def create_adjacency_matrix(tree):
    """Create an adjacency matrix from a tree.
    """
    # Build edges from the tree
    edges = build_edges_from_tree(tree)

    # Extract unique nodes
    nodes = list(set([node for edge in edges for node in edge]))

    # Initialize the adjacency matrix
    n = len(nodes)
    matrix = np.zeros((n, n), dtype=int)

    # Fill the adjacency matrix
    for i in range(n):
        for j in range(n):
            if (nodes[i], nodes[j]) in edges or (nodes[j], nodes[i]) in edges:
                matrix[i, j] = 1  # Mark as connected

    return nodes, matrix


def compute_distance_adjacency_matrix(matrix_1, matrix_2, degree=2):
    """Compute the distance between two adjacency matrices.
    Possible distances include the l1 and the l2 norm.
    """
    return ((matrix_1 - matrix_2) ** degree).mean() ** (1/degree)


def adjacency_matrix_distance(tree_1_str, tree_2_str):
    """Compute the distance between two trees using adjacency matrices.
    Careful, all nodes of the tree must be named for this distance to be computed,
    including the root node.
    """
    # Load the trees from the Newick strings
    tree_1 = load_tree(tree_1_str)
    tree_2 = load_tree(tree_2_str)

    # Create adjacency matrices from the trees
    nodes_1, matrix_1 = create_adjacency_matrix(tree_1)
    nodes_2, matrix_2 = create_adjacency_matrix(tree_2)

    # Compute the distance between the adjacency matrices
    distance = compute_distance_adjacency_matrix(matrix_1, matrix_2)
    return distance



# # Example Newick string
# newick_str = "((A,B)G, (C, (D, E)X))F;"

# # Create adjacency matrix from the Newick string
# nodes, adjacency_matrix = create_adjacency_matrix(newick_str)

# # Display the results
# if nodes and adjacency_matrix is not None:
#     print("Nodes:", nodes)
#     print("\nAdjacency Matrix:")
#     print("  " + " ".join(nodes))
#     for i, row in enumerate(adjacency_matrix):
#         print(nodes[i] + " " + " ".join(map(str, row.astype(int))))


In [60]:
print(newick.loads(newick_str)[0].ascii_art())

        ┌─A
    ┌─G─┤
    │   └─B
──F─┤
    │   ┌─C
    └───┤
        │   ┌─D
        └─X─┤
            └─E


## Implement the signed similarity between two trees (see Roos and Hekkila)


In [23]:
import itertools
import newick

def build_node_paths(tree, path=None, paths=None):
    """
    Recursively build the paths to each node in the tree.
    """
    if paths is None:
        paths = {}
    if path is None:
        path = []
    
    if tree.name:
        paths[tree.name] = path + [tree.name]
    
    for child in tree.descendants:
        build_node_paths(child, path + [tree.name], paths)
    return paths


def calculate_distance(paths, node1, node2):
    """
    Compute the distances in terms of edges between two nodes.
    """
    path1 = paths.get(node1, [])
    path2 = paths.get(node2, [])

    # Find the nearest common ancestor
    common_ancestor = None
    for u, v in zip(path1, path2):
        if u == v:
            common_ancestor = u
        else:
            break
    
    # Compute the distance as the sum of the distances to the common ancestor
    # Total path - path to reach ancestor - 1 (to avoid counting the common ancestor twice)                                                                                                                                                                                                                                                                                                                                                                                                                     
    distance = (len(path1) - path1.index(common_ancestor) - 1) + (len(path2) - path2.index(common_ancestor) - 1)
    return distance

def compute_u(tree_paths_tree1,
                tree_paths_tree2, 
                A, B, C):
    """
    Compute the u value for a triplet of nodes taken from tree 1 and tree 2.
    """
    # Distances dans l'arbre testé
    d_tree1_AB = calculate_distance(tree_paths_tree1, A, B)
    d_tree1_AC = calculate_distance(tree_paths_tree1, A, C)
    # Distances dans l'arbre de référence
    d_tree2_AB = calculate_distance(tree_paths_tree2, A, B)
    d_tree2_AC = calculate_distance(tree_paths_tree2, A, C)
    
    # Calcul de la fonction u
    sign_test = int(d_tree1_AB - d_tree1_AC > 0) - int(d_tree1_AB - d_tree1_AC < 0)
    sign_ref = int(d_tree2_AB - d_tree2_AC > 0) - int(d_tree2_AB - d_tree2_AC < 0)
    u_value = 1 - 0.5 * abs(sign_test - sign_ref)
    return u_value


def compute_roos_similarity(tree_1, tree_2):
    """Compute the Roos similarity between two trees.
    """
    # Compute the paths for every node in the tree
    tree_paths_1 = build_node_paths(tree_1)
    tree_paths_2 = build_node_paths(tree_2)
    
    # List the common nodes in the tree
    common_nodes = set(tree_paths_1.keys()).intersection(tree_paths_2.keys())
    
    # Compute for each triplet the u value
    S = 0
    for A, B, C in itertools.combinations(common_nodes, 3):
        S += compute_u(tree_paths_1, tree_paths_2, A, B, C)
    
    # Adjust the score for identical trees to zero
    # Perfect score is the case were each triplet has a u value of 1
    max_score = len(list(itertools.combinations(common_nodes, 3)))
    similarity_score = max_score - S  # Subtract to reflect dissimilarity
    
    return similarity_score

def roos_similarity(tree_1_str, tree_2_str):
    """
    Compute the Roos similarity between two trees.
    """
    # Charger les arbres depuis les chaînes Newick
    tree_1 = newick.loads(tree_1_str)[0]
    tree_2 = newick.loads(tree_2_str)[0]
    
    return compute_roos_similarity(tree_1, tree_2)


# Exemples de chaînes Newick pour l'arbre testé et l'arbre de référence
newick_str_test = "((A,B)X, (C, (D, E)Z)U)F;"
newick_str_ref = "((A,X)B, (C, (D, E)Z)U)F;"

# Calculer la similarité de Roos
similarity_score = roos_similarity(newick_str_test, newick_str_ref)

# Afficher le score de similarité
print("Score de similarité de Roos:", similarity_score)


Score de similarité de Roos: 14.0


In [20]:
print(newick.loads(newick_str_test)[0].ascii_art())

        ┌─A
    ┌─X─┤
    │   └─B
──F─┤
    │   ┌─C
    └─U─┤
        │   ┌─D
        └─Z─┤
            └─E


In [16]:
print(newick.loads(newick_str_ref)[0].ascii_art())

        ┌─A
    ┌─X─┤
    │   └─B
──F─┤
    │   ┌─C
    └─U─┤
        │   ┌─D
        └─Z─┤
            └─E


In [26]:
tree_1 = "(((A,B)X,(E,G)H)I,((C,D)Y,(M,N)O)Z);"
tree_2 = "(((A,C)X,(P,Q)H)I,((F,L)Y,(R,S)O)Z);"


In [30]:
print(newick.loads("(((A,B)X,C)Y,D)Z;")[0].ascii_art())


            ┌─A
        ┌─X─┤
    ┌─Y─┤   └─B
──Z─┤   └─C
    └─D


In [33]:
print(newick.loads(tree_2)[0].ascii_art())

            ┌─A
        ┌─X─┤
        │   └─C
    ┌─I─┤
    │   │   ┌─P
    │   └─H─┤
    │       └─Q
────┤
    │       ┌─F
    │   ┌─Y─┤
    │   │   └─L
    └─Z─┤
        │   ┌─R
        └─O─┤
            └─S


In [31]:
from stemmadist.dists.rf.rf import rf_distance

rf_distance(tree_1, tree_2)

6

In [1]:
%pip install networkx

Collecting networkx
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Downloading networkx-3.4.2-py3-none-any.whl (1.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: networkx
Successfully installed networkx-3.4.2

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
import networkx as nx
import newick

def newick_to_networkx(newick_str):
    # Parse the Newick tree
    tree = newick.loads(newick_str)[0]  # loads returns a list of trees; we use the first
    print(tree.ascii_art())
    # Initialize an empty directed NetworkX graph
    G = nx.DiGraph()

    # Recursive function to add nodes and edges to the graph
    def add_edges(node, parent=None):
        current_node = node.name or id(node)
        # Add the current node
        G.add_node(current_node, label=node.name)  # Internal nodes without names get a unique object
        if parent is not None:
            G.add_edge(parent, current_node)

        # Recursively add children
        for child in node.descendants:
            add_edges(child, node)

    # Start building the graph from the root of the tree
    add_edges(tree)
    
    return G

# Example usage
newick_str = "((A,B)M, (D, L)X)F;"
G_1 = newick_to_networkx(newick_str)
print("Nodes:", G_1.nodes())
print("Edges:", G_1.edges())

G_2 = newick_to_networkx("((A,B)X, (D, L)M)F;")
print("Nodes:", G_2.nodes())
print("Edges:", G_2.edges())

def label_based_cost(u, v):
    # if u.get("label") and v.get("label"):
    #     return 0 if G_1.nodes[u['label']] == G_2.nodes[v['label']] else 1
    # return 0
    label_u = G_1.nodes[u['label']] if u.get("label") else u
    label_v = G_2.nodes[v['label']] if v.get("label") else v
    print(u.get("label"), v.get("label"))
    return 0 if label_u == label_v else 1


nx.graph_edit_distance(G_1, G_2,
                             node_subst_cost=label_based_cost,
                             node_del_cost=lambda n: 1,
                             node_ins_cost=lambda n: 1)

        ┌─A
    ┌─M─┤
    │   └─B
──F─┤
    │   ┌─D
    └─X─┤
        └─L
Nodes: ['F', 'M', Node("F"), 'A', Node("M"), 'B', 'X', 'D', Node("X"), 'L']
Edges: [(Node("F"), 'M'), (Node("F"), 'X'), (Node("M"), 'A'), (Node("M"), 'B'), (Node("X"), 'D'), (Node("X"), 'L')]
        ┌─A
    ┌─X─┤
    │   └─B
──F─┤
    │   ┌─D
    └─M─┤
        └─L
Nodes: ['F', 'X', Node("F"), 'A', Node("X"), 'B', 'M', 'D', Node("M"), 'L']
Edges: [(Node("F"), 'X'), (Node("F"), 'M'), (Node("X"), 'A'), (Node("X"), 'B'), (Node("M"), 'D'), (Node("M"), 'L')]
F F
F X
F None
F A
F None
F B
F M
F D
F None
F L
M F
M X
M None
M A
M None
M B
M M
M D
M None
M L
None F
None X
None None
None A
None None
None B
None M
None D
None None
None L
A F
A X
A None
A A
A None
A B
A M
A D
A None
A L
None F
None X
None None
None A
None None
None B
None M
None D
None None
None L
B F
B X
B None
B A
B None
B B
B M
B D
B None
B L
X F
X X
X None
X A
X None
X B
X M
X D
X None
X L
D F
D X
D None
D A
D None
D B
D M
D D
D None
D L
None F
None X
No

0.0

In [5]:
import networkx as nx
import newick

def newick_to_networkx(newick_str):
    # Parse the Newick tree
    tree = newick.loads(newick_str)[0]  # loads returns a list of trees; we use the first
    print(tree.ascii_art())
    
    # Initialize an empty directed NetworkX graph
    G = nx.DiGraph()
    unique_id = 0  # Counter for unnamed internal nodes

    # Recursive function to add nodes and edges to the graph
    def add_edges(node, parent=None):
        nonlocal unique_id
        current_node = node.name if node.name else f"Internal_{unique_id}"
        if not node.name:
            unique_id += 1

        # Add the current node
        G.add_node(current_node, label=node.name)  # Internal nodes without names get a unique ID
        if parent is not None:
            G.add_edge(parent, current_node)

        # Recursively add children
        for child in node.descendants:
            add_edges(child, current_node)

    # Start building the graph from the root of the tree
    add_edges(tree)
    return G

# Graphs from Newick strings
newick_1 = "((A,B)M, (D,L)X)F;"
newick_2 = "((A,B)X, (D,L)M)F;"

G_1 = newick_to_networkx(newick_1)
G_2 = newick_to_networkx(newick_2)

print("Nodes (G_1):", G_1.nodes(data=True))
print("Edges (G_1):", G_1.edges())
print("Nodes (G_2):", G_2.nodes(data=True))
print("Edges (G_2):", G_2.edges())

# Cost functions for GED
def label_based_cost(u, v):
    # Retrieve labels and compare
    label_u = u.get("label", "")
    label_v = v.get("label", "")
    return 0 if label_u == label_v else 1/2

# Deletion and insertion cost
def node_deletion_cost(node):
    return 1

def node_insertion_cost(node):
    return 1

# Compute GED
ged = nx.graph_edit_distance(
    G_1, G_2,
    node_subst_cost=label_based_cost,
    node_del_cost=node_deletion_cost,
    node_ins_cost=node_insertion_cost,
)

print("Graph Edit Distance (GED):", ged)


        ┌─A
    ┌─M─┤
    │   └─B
──F─┤
    │   ┌─D
    └─X─┤
        └─L
        ┌─A
    ┌─X─┤
    │   └─B
──F─┤
    │   ┌─D
    └─M─┤
        └─L
Nodes (G_1): [('F', {'label': 'F'}), ('M', {'label': 'M'}), ('A', {'label': 'A'}), ('B', {'label': 'B'}), ('X', {'label': 'X'}), ('D', {'label': 'D'}), ('L', {'label': 'L'})]
Edges (G_1): [('F', 'M'), ('F', 'X'), ('M', 'A'), ('M', 'B'), ('X', 'D'), ('X', 'L')]
Nodes (G_2): [('F', {'label': 'F'}), ('X', {'label': 'X'}), ('A', {'label': 'A'}), ('B', {'label': 'B'}), ('M', {'label': 'M'}), ('D', {'label': 'D'}), ('L', {'label': 'L'})]
Edges (G_2): [('F', 'X'), ('F', 'M'), ('X', 'A'), ('X', 'B'), ('M', 'D'), ('M', 'L')]
Graph Edit Distance (GED): 1.0


In [55]:
G1 = nx.cycle_graph(6)
G2 = nx.wheel_graph(7)
nx.graph_edit_distance(G1, G2)

7.0

In [57]:
print(G1.edges())
print(G1.nodes())


print(G2.edges())
print(G2.nodes())

[(0, 1), (0, 5), (1, 2), (2, 3), (3, 4), (4, 5)]
[0, 1, 2, 3, 4, 5]
[(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (1, 2), (1, 6), (2, 3), (3, 4), (4, 5), (5, 6)]
[0, 1, 2, 3, 4, 5, 6]


In [35]:
def constant_cost(u, v):
    return 1 if u != v else 0  # Cost is 1 for differing nodes, 0 for identical ones


ged = nx.graph_edit_distance(G_1, G_2,
                             node_subst_cost=constant_cost,
                             node_del_cost=lambda n: 1,
                             node_ins_cost=lambda n: 1)

print(ged)

0.0


In [None]:
G1