In [1]:
import numpy as np
import torch

from simulate_data import simulate_data
from cassiopeia.data import CassiopeiaTree
from cassiopeia.solver import VanillaGreedySolver, ILPSolver, NeighborJoiningSolver
from cassiopeia.critique import triplets_correct as cas_triplets_correct
from heracles.metrics import heracles_triplets_correct
from heracles.main import main

In [2]:
def _cassiopeia_triplets_correct(true_tree, solver, label):
    """Reconstruct a tree with a Cassiopeia solver and score it against ``true_tree``.

    Builds a fresh ``CassiopeiaTree`` from ``true_tree.character_matrix``, runs
    ``solver.solve`` on it with mutationless edges collapsed, compares the result
    to ``true_tree`` with Cassiopeia's ``triplets_correct`` and prints ``label``
    followed by the score.

    Args:
        true_tree: ground-truth tree returned by ``simulate_data``.
        solver: a Cassiopeia solver instance (e.g. ``ILPSolver()``).
        label: text printed before the score.

    Returns:
        Mean of the values of the first dict returned by ``triplets_correct``.
        NOTE(review): presumably the per-depth triplets-correct fractions;
        confirm against the Cassiopeia docs.
    """
    # Read the character matrix from the true tree for every solver, exactly as
    # before, rather than sharing one object between reconstructions.
    reconstructed = CassiopeiaTree(character_matrix=true_tree.character_matrix, priors=None)
    solver.solve(reconstructed, collapse_mutationless_edges=True)
    triplets = cas_triplets_correct(true_tree, reconstructed)
    score = np.mean(list(triplets[0].values()))
    print(label, score)
    return score


def compare(num_extant=500, missing_data=0.0, num_sites=10, num_states=5,
            embedding_dim=3, rho=2, num_epochs=15):
    """Benchmark HERACLES against Cassiopeia's ILP, greedy and NJ solvers on one simulation.

    Simulates a lineage tree, reconstructs it with each Cassiopeia solver and
    with HERACLES embeddings, and prints the triplets-correct score of each.

    Args:
        num_extant, missing_data, num_sites, num_states: forwarded to
            ``simulate_data``.
        embedding_dim, rho, num_epochs: HERACLES settings, forwarded to
            ``heracles.main.main``; ``rho`` is also used for scoring.

    Returns:
        Tuple ``(greedy_tc, heracles_tc, ilp_tc, nj_tc)``. The order differs
        from the print order and is kept as-is for existing callers.

    NOTE(review): no random seed is set here, so results vary between calls
    unless ``simulate_data`` seeds internally; confirm.
    """
    # simulate data
    true_tree, params = simulate_data(num_extant=num_extant, missing_data=missing_data,
                                      num_sites=num_sites, num_states=num_states)
    char_matrix = true_tree.character_matrix
    print('Simulated Data')

    # Cassiopeia baselines, run in the original order: ILP, greedy, NJ
    ilp_tc = _cassiopeia_triplets_correct(true_tree, ILPSolver(), 'ILP Solved:\t')
    greedy_tc = _cassiopeia_triplets_correct(true_tree, VanillaGreedySolver(), 'Greedy Solved:\t')
    nj_tc = _cassiopeia_triplets_correct(true_tree, NeighborJoiningSolver(add_root=True), 'NJ Solved:\t')

    # compute embeddings with HERACLES
    mutation_rate, deletion_rate, transition_prob = params['mutation_rate'], params['deletion_rate'], params['transition_prob']
    best_embeddings = main(char_matrix, mutation_rate, deletion_rate, transition_prob,
                           embedding_dim=embedding_dim, rho=rho,
                           num_epochs=num_epochs, true_tree=true_tree)
    heracles_tc = heracles_triplets_correct(true_tree, best_embeddings, rho=rho)
    print('Heracles Solved:\t', heracles_tc)

    return greedy_tc, heracles_tc, ilp_tc, nj_tc

In [3]:
# Run one full benchmark (simulation, three Cassiopeia baselines, HERACLES).
# NOTE(review): the saved output below stops at 8/15 on the progress bar
# (presumably HERACLES epochs), so no return value was displayed.
compare()

[2023-04-30 18:33:02,998]    INFO [ILPSolver] Solving tree with the following parameters.
[2023-04-30 18:33:02,998]    INFO [ILPSolver] Convergence time limit: 12600
[2023-04-30 18:33:02,999]    INFO [ILPSolver] Convergence iteration limit: 0
[2023-04-30 18:33:02,999]    INFO [ILPSolver] Max potential graph layer size: 10000
[2023-04-30 18:33:02,999]    INFO [ILPSolver] Max potential graph lca distance: None
[2023-04-30 18:33:03,000]    INFO [ILPSolver] MIP gap: 0.01
[2023-04-30 18:33:03,002]    INFO [ILPSolver] Phylogenetic root: (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)


Simulated Data


[2023-04-30 18:33:03,235]    INFO [ILPSolver] (Process: 274a7adceaa6c5c0648e04e5daba4a40) Estimating a potential graph with a maximum layer size of 10000 and a maximum LCA distance of 9.
[2023-04-30 18:33:03,258]    INFO [ILPSolver] (Process: 274a7adceaa6c5c0648e04e5daba4a40) LCA distance 0 completed with a neighborhood size of 27.
[2023-04-30 18:33:03,282]    INFO [ILPSolver] (Process: 274a7adceaa6c5c0648e04e5daba4a40) LCA distance 1 completed with a neighborhood size of 27.
[2023-04-30 18:33:03,305]    INFO [ILPSolver] (Process: 274a7adceaa6c5c0648e04e5daba4a40) LCA distance 2 completed with a neighborhood size of 27.
[2023-04-30 18:33:03,331]    INFO [ILPSolver] (Process: 274a7adceaa6c5c0648e04e5daba4a40) LCA distance 3 completed with a neighborhood size of 29.
[2023-04-30 18:33:03,355]    INFO [ILPSolver] (Process: 274a7adceaa6c5c0648e04e5daba4a40) LCA distance 4 completed with a neighborhood size of 29.
[2023-04-30 18:33:03,380]    INFO [ILPSolver] (Process: 274a7adceaa6c5c0648e04

Set parameter Username
Academic license - for non-commercial use only - expires 2024-04-26


[2023-04-30 18:33:03,501]    INFO [ILPSolver] (Process 274a7adceaa6c5c0648e04e5daba4a40) Steiner tree solving tool 0 days, 0 hours, 0 minutes, and 0 seconds.


ILP Solved:	 0.20414285714285713
Greedy Solved:	 0.22047619047619046
NJ Solved:	 0.22061904761904763


 53%|█████▎    | 8/15 [02:55<02:32, 21.78s/it]

In [171]:
from geoopt.manifolds.lorentz.math import inner as geoopt_mdp
from heracles.hyperboloid_wilson import minkowski_dot as wilson_mdp

# wilson pt
# Two recorded 30-dimensional float64 points, presumably in Wilson's hyperboloid
# convention (per the "# wilson pt" comment): the last, largest entry is
# presumably the time-like coordinate. pt1's coordinates reach ~1e8 and pt2's
# reach ~5e4, which stresses float64 cancellation in the Minkowski dot product.
# NOTE(review): neither pt1 nor pt2 is used by any visible cell; confirm
# before deleting.
pt1 = torch.tensor([
    -2.82828054e+07,  1.61832847e+07, -1.30400228e+07, -1.01954093e+07,
    -3.41873788e+07,  9.26049530e+06,  1.87116367e+07,  4.31363021e+06,
    1.71363229e+07,  2.06445712e+07,  3.77446594e+07,  2.24614149e+07,
    1.11693457e+06,  3.84794967e+07,  1.68280638e+07,  1.27803679e+06,
    -5.10010376e+06,  1.09749058e+07, -5.76268200e+06,  3.47728168e+07,
    -2.03301088e+07, -1.97533264e+07, -1.93611119e+07, -7.70314356e+06,
    1.44952964e+07, -1.83956401e+07,  1.27952039e+07,  2.36982942e+07,
    -2.05705965e+07,  1.08158854e+08],  dtype=torch.float64
)

pt2 = torch.tensor([
    -12391.27881638,   7090.95770962,  -5692.52423663,  -4453.40727088,
    -14966.25533362,   4056.20724708,   8204.33313047,   1877.90689277,
    7475.52070145,   9051.06373339,  16542.94743079,   9824.35412299,
        507.52694717,  16875.88623624,   7347.17530655,    540.33063492,
    -2238.1457424,    4786.88108994,  -2514.7068528,   15230.65532572,
    -8908.95529633,  -8664.08879035,  -8502.61690185,  -3326.01432724,
    6347.94495746,  -8058.41865838,   5618.47676266,  10356.32921451,
    -8982.9349874,   47364.77636841], dtype=torch.float64
)

def wilson_to_geoopt(pt):
    """Convert a point from Wilson's coordinate order to geoopt's.

    Wilson's Minkowski dot product treats the last coordinate as time-like;
    geoopt treats the first one as time-like. The conversion swaps the first
    and last entries along the *last* dimension. It therefore also works for
    batched inputs of shape ``(..., dim)``, and applying it twice is the identity.

    Args:
        pt: torch tensor whose last dimension holds the coordinates.

    Returns:
        A new tensor with the first and last coordinates swapped; ``pt`` itself
        is not modified.
    """
    converted = pt.clone()
    # Read from the untouched input so the two assignments cannot interfere.
    converted[..., 0] = pt[..., -1]
    converted[..., -1] = pt[..., 0]
    return converted

In [216]:
# Further probe points for the Minkowski-dot-product precision check.
# pt3: 10 float64 coordinates; pt4: 3 float64 coordinates of magnitude ~1e6.
pt3 = torch.tensor([
    -36619.98635744,  13043.51925601, -9967.58759802,  28144.8334594,
  23399.59840025,  -2253.61275547,   9005.55508014, -51996.36019142,
  -9582.78151001,  76366.30407741], dtype=torch.float64)

pt4 = torch.tensor([-2306517.19097746, 
                    -1462074.1054088, 
                    2730875.72800282],
                   dtype=torch.float64)

# Self inner product of pt4 in Wilson's convention.
# NOTE(review): for a point on the hyperboloid this is presumably expected to be
# -1. The recorded 186.5625 is far from that, and the longdouble re-run in a
# later cell prints the same value. TODO confirm whether pt4 has drifted off the
# manifold rather than suffering from rounding.
print(wilson_mdp(pt4, pt4).item())

186.5625


In [236]:
# Scratch arithmetic: 23**2 + 14**2 (725) vs 27**2 (729), one value per line.
print(23**2 + 14**2, 27**2, sep='\n')

725
729


In [237]:
# Re-evaluate pt4's self inner product with numpy extended precision,
# presumably to check whether 186.5625 is a float64 rounding artifact.
# NOTE: the literals are Python floats (float64) before the longdouble cast,
# so only the arithmetic gains precision, not the inputs.
# NOTE(review): `wilson_mdp2` is defined in a later cell (In [217]). This cell
# only works after that one has run and fails under Restart & Run All.
arr4 = np.array(
    [-2306517.19097746, -1462074.1054088, 2730875.72800282],
    dtype=np.longdouble,
)
wilson_mdp2(arr4, arr4).item()

Pos:  7457682241981.495 	 Neg:  7457682241794.933


186.5625

In [217]:
def wilson_mdp2(u, v):
    """Minkowski dot product in Wilson's convention (time-like coordinate last).

    For 1-D inputs, returns ``u[:-1] . v[:-1] - u[-1] * v[-1]``. The two partial
    terms are printed as a diagnostic before returning. Accepts 1-D torch
    tensors or numpy arrays (anything providing ``.dot`` and ``.item``).
    """
    last = u.shape[-1] - 1
    spatial = u[:last].dot(v[:last])
    timelike = u[last] * v[last]
    print('Pos: ', spatial.item(), '\t Neg: ', timelike.item())
    return spatial - timelike

def geoopt_mdp2(u, v, keepdim: bool = False, dim: int = -1):
    """Minkowski dot product in geoopt's convention (time-like coordinate first).

    Computes ``-u_0 v_0 + sum_{i>=1} u_i v_i`` along ``dim``.

    Args:
        u, v: torch tensors of the same (broadcastable) shape.
        keepdim: if True, keep ``dim`` as a size-1 dimension in the result.
            Previously this argument was silently ignored.
        dim: dimension holding the coordinates.

    Returns:
        Tensor of inner products, reduced over ``dim``.
    """
    d = u.size(dim) - 1
    uv = u * v

    pos = uv.narrow(dim, 1, d).sum(dim=dim, keepdim=keepdim)
    neg = uv.narrow(dim, 0, 1).sum(dim=dim, keepdim=keepdim)

    return -neg + pos
        
# Same point, both conventions: geoopt needs the time-like coordinate moved
# from last to first, hence wilson_to_geoopt. end='\n\n' supplies the blank line.
print('Wilson MDP:\t', wilson_mdp2(pt4, pt4).item(), end='\n\n')
print('Geoopt MDP:\t', geoopt_mdp2(wilson_to_geoopt(pt4), wilson_to_geoopt(pt4)).item())

Pos:  7457682241981.495 	 Neg:  7457682241794.933
Wilson MDP:	 186.5625

Geoopt MDP:	 186.5625


In [175]:
# Compare the library implementations on pt3. geoopt expects the time-like
# coordinate first, so the point is converted with wilson_to_geoopt.
print('Wilson MDP:\t', wilson_mdp(pt3, pt3).item())
print('Geoopt MDP:\t', geoopt_mdp(wilson_to_geoopt(pt3), wilson_to_geoopt(pt3)).item())

Wilson MDP:	 22.609623908996582
Geoopt MDP:	 22.609623908996582


In [209]:
# float64 tops out near 1.8e308 (np.finfo(np.float64).max), so 10**308 is still finite.
a = np.full(1, 10.0, dtype=np.float64)
b = a ** 308
print(b)

[1.e+308]


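**Wilson** Minkowski dot product (time-like coordinate last, as in `wilson_mdp2`): $\langle u, v \rangle_{\text{Wilson}} = \big( u_0 v_0 + u_1 v_1 + \dots + u_{n-1} v_{n-1} \big) - u_n v_n$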

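**geoopt** Minkowski dot product (time-like coordinate first, as in `geoopt_mdp2`): $\langle u, v \rangle_{\text{geoopt}} = -u_0 v_0 + \big( u_1 v_1 + \dots + u_{n-1} v_{n-1} + u_n v_n \big)$. `wilson_to_geoopt` swaps $u_0$ and $u_n$ to move a point from the Wilson layout to the geoopt layout.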