# Triplet-Based Phylogenetic Tree Reconstruction  
### Simulation, Accuracy Evaluation, and Benchmarking

This notebook evaluates a triplet-based tree reconstruction algorithm that uses Max-Cut
partitioning, and compares its performance against standard phylogenetic solvers
under both complete and missing-data settings.


## 1. Imports and Global Configuration

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import itertools
from typing import Tuple
import networkx as nx

import cassiopeia as cas
from collections import defaultdict
from tqdm.auto import tqdm
import pandas as pd


## 2. Simulation Parameters

We define the parameters used to simulate complete binary lineage trees and
character evolution under a mutation model.


In [None]:
# --- Character evolution model ---
k_cand = 40        # number of characters
lamb = 0.5         # mutation rate
num_states = 5     # number of discrete mutation states

# Uniform probability over the mutation states, keyed 1..num_states.
q_dist = {state: 1 / num_states for state in range(1, num_states + 1)}

# --- Tree shape ---
depth = 8          # tree depth

# --- Additional theoretical parameters (used elsewhere) ---
d_star = 0.9
l = 1 / 9

# --- Experiment size ---
num_simulations = 50   # number of simulation repetitions


## 3. Triplet Accuracy Metric

We evaluate reconstruction quality using the proportion of correctly resolved triplets,
estimated by sampling random triplets of leaves from the ground-truth tree.


In [None]:
def calculate_triplets_correct(
    ground_tree: cas.data.CassiopeiaTree,
    recon: nx.DiGraph,
    num_samples: int = 5000,
    rng: "np.random.Generator | None" = None,
) -> float:
    """
    Estimate triplet accuracy by random sampling.

    Each draw picks three distinct leaves of ``ground_tree`` uniformly at
    random. The draw counts as correct when ``find_triplet_structure`` reports
    the same structure for it in ``recon`` as in the ground-truth topology.

    Parameters
    ----------
    ground_tree : CassiopeiaTree
        Ground-truth tree; its leaves form the sampling population.
    recon : nx.DiGraph
        Reconstructed tree topology. Expected to contain the ground-truth
        leaves, since sampled triplets are looked up in it by name.
    num_samples : int
        Number of triplets sampled. Must be positive.
    rng : np.random.Generator, optional
        Generator used to draw triplets, for reproducible estimates. When
        None (default), the global ``np.random`` state is used.

    Returns
    -------
    float
        Fraction of sampled triplets resolved identically in both trees.

    Raises
    ------
    ValueError
        If ``num_samples`` is not positive, or the ground tree has fewer than
        three leaves.
    """
    # Without this guard, 0 divides by zero and negatives return a
    # meaningless (negative) fraction.
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    ground = ground_tree.get_tree_topology()
    leaves = list(ground_tree.leaves)
    if len(leaves) < 3:
        raise ValueError(
            f"Need at least 3 leaves to sample a triplet, got {len(leaves)}"
        )

    # Default to the global NumPy RNG so existing callers see the same
    # sampling behaviour.
    choose = np.random.choice if rng is None else rng.choice

    num_correct = 0
    for _ in range(num_samples):
        triplet = choose(leaves, 3, replace=False)

        recon_triplet = find_triplet_structure(triplet, recon)
        ground_triplet = find_triplet_structure(triplet, ground)

        num_correct += int(recon_triplet == ground_triplet)

    return num_correct / num_samples


## 4. Single-Tree Simulation and Logging

We simulate a single ground-truth tree, infer its triplets once, then reconstruct the tree
repeatedly from those triplets and log the triplet accuracy of each reconstruction.


In [None]:
output_file = "simulation_results_onetree.txt"

# The log is appended to across notebook runs and never truncated; rows from
# earlier runs come from different simulated trees. Write the header only when
# the file is new or empty. Continue the Simulation index after any existing
# rows so ids stay unique when this cell is re-run.
with open(output_file, "a+") as f:
    f.seek(0)
    num_nonblank_lines = sum(1 for line in f if line.strip())
    if num_nonblank_lines == 0:
        f.write("Simulation\tAccuracy\n")
# Existing data rows = non-blank lines minus the header line.
first_index = max(num_nonblank_lines - 1, 0) + 1

# Simulate a single ground-truth tree.
# NOTE(review): `utilities`, `find_recon_triplets` and
# `build_tree_from_triplet_partition` are not imported/defined in the cells
# visible here — confirm they are provided before running.
ground_truth_tree = utilities.complete_binary_tree_sim(
    k_cand, q_dist, lamb, depth
)

# Infer triplets once from the character matrix. Every reconstruction below
# reuses them, so run-to-run variation comes from the reconstruction step and
# the sampled accuracy estimate only.
triplets = find_recon_triplets(ground_truth_tree)

accuracy_list = []

for i in range(num_simulations):

    recon_tree = build_tree_from_triplet_partition(
        ground_truth_tree, triplets
    )

    accuracy = calculate_triplets_correct(
        ground_truth_tree, recon_tree
    )

    accuracy_list.append(accuracy)

    # Appending after every reconstruction keeps results already computed
    # on disk if the loop is interrupted.
    with open(output_file, "a") as f:
        f.write(f"{first_index + i}\t{accuracy}\n")


## 5. Visualization of Simulation Results

We plot the triplet accuracy of each logged reconstruction.


In [None]:
file_path = "simulation_results_onetree.txt"

# Parse the tab-separated log: the first line is a header, and each remaining
# line is "<index>\t<accuracy fraction>".
with open(file_path, "r") as file:
    next(file)  # skip header
    accuracies = [
        float(acc) for _, acc in (row.strip().split("\t") for row in file)
    ]

# Fractions -> percentages.
accuracies = np.array(accuracies) * 100

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(accuracies, lw=1.2, label="Triplet Accuracy (%)")
# NOTE(review): the origin of the 64% reference line is not documented in
# this notebook — confirm what it represents.
ax.axhline(y=64, linestyle="--", color="red", label="64% Threshold")

ax.set_xlabel("Simulation Index")
ax.set_ylabel("Accuracy (%)")
ax.set_title("Triplet Accuracy over Repeated Reconstructions")
ax.legend()
ax.grid(False)
plt.show()


## 6. Benchmarking Against Standard Algorithms (No Missing Data)

We compare the Max-Cut triplet-based method against classical phylogenetic solvers.


In [None]:
algorithms = {
    "Vanilla Greedy": cas.solver.VanillaGreedySolver(),
    "UPGMA": cas.solver.UPGMASolver(
        dissimilarity_function=cas.solver.dissimilarity.weighted_hamming_distance
    ),
    "Neighbor Joining": cas.solver.NeighborJoiningSolver(
        dissimilarity_function=cas.solver.dissimilarity.weighted_hamming_distance,
        add_root=True,
    ),
}

# Overrides the repetition count used in the single-tree experiment.
num_simulations = 100
results_no_missing = defaultdict(list)

for _ in tqdm(range(num_simulations)):

    ground_truth_tree = utilities.complete_binary_tree_sim(
        k_cand, q_dist, lamb, depth
    )

    # --- MAX-Cut ---
    triplets = find_recon_triplets(ground_truth_tree)
    recon_tree = build_tree_from_triplet_partition(
        ground_truth_tree, triplets
    )

    # Score MAX-Cut with the same metric as the baselines below.
    # calculate_triplets_correct samples leaf triplets uniformly. The baseline
    # score averages Cassiopeia's per-depth accuracies, which weights depths
    # differently, so mixing the two would not compare like with like.
    # NOTE(review): assumes recon_tree is an nx.DiGraph whose leaf names match
    # the ground-truth leaves — confirm against build_tree_from_triplet_partition.
    max_cut_tree = cas.data.CassiopeiaTree(tree=recon_tree)
    max_cut_scores = cas.critique.compare.triplets_correct(
        ground_truth_tree, max_cut_tree, number_of_trials=1000
    )
    results_no_missing["MAX-Cut"].append(
        np.mean(list(max_cut_scores[0].values()))
    )

    # --- Baseline solvers ---
    for name, solver in algorithms.items():

        recon = cas.data.CassiopeiaTree(
            character_matrix=ground_truth_tree.character_matrix,
            missing_state_indicator=-1,
        )

        solver.solve(recon)
        recon.collapse_mutationless_edges(
            infer_ancestral_characters=True
        )

        triplet_scores = cas.critique.compare.triplets_correct(
            ground_truth_tree, recon, number_of_trials=1000
        )

        # triplet_scores[0] is a dict of accuracies (keyed by depth); report
        # their unweighted mean.
        results_no_missing[name].append(
            np.mean(list(triplet_scores[0].values()))
        )


## 7. Benchmarking with Missing Data

We repeat the benchmark on trees simulated with missing data in the character matrix.


In [None]:
results_with_missing = defaultdict(list)

for _ in tqdm(range(num_simulations)):

    ground_truth_tree = utilities.complete_binary_missing_tree_sim(
        k_cand, q_dist, lamb, depth
    )

    # --- MAX-Cut ---
    triplets = find_recon_triplets(ground_truth_tree)
    recon_tree = build_tree_from_triplet_partition(
        ground_truth_tree, triplets
    )

    # Score MAX-Cut with the same metric as the baselines below.
    # calculate_triplets_correct samples leaf triplets uniformly. The baseline
    # score averages Cassiopeia's per-depth accuracies, which weights depths
    # differently, so mixing the two would not compare like with like.
    # NOTE(review): assumes recon_tree is an nx.DiGraph whose leaf names match
    # the ground-truth leaves — confirm against build_tree_from_triplet_partition.
    max_cut_tree = cas.data.CassiopeiaTree(tree=recon_tree)
    max_cut_scores = cas.critique.compare.triplets_correct(
        ground_truth_tree, max_cut_tree, number_of_trials=1000
    )
    results_with_missing["MAX-Cut"].append(
        np.mean(list(max_cut_scores[0].values()))
    )

    # --- Baseline solvers (`algorithms` is defined in the previous cell) ---
    for name, solver in algorithms.items():

        # -1 marks missing entries in the character matrix.
        recon = cas.data.CassiopeiaTree(
            character_matrix=ground_truth_tree.character_matrix,
            missing_state_indicator=-1,
        )

        solver.solve(recon)
        recon.collapse_mutationless_edges(
            infer_ancestral_characters=True
        )

        triplet_scores = cas.critique.compare.triplets_correct(
            ground_truth_tree, recon, number_of_trials=1000
        )

        # triplet_scores[0] is a dict of accuracies (keyed by depth); report
        # their unweighted mean.
        results_with_missing[name].append(
            np.mean(list(triplet_scores[0].values()))
        )
