In [16]:
import os
import pickle
import numpy as np
import torch
import heracles

from heracles.main2 import main
from heracles.metrics import cas_triplets_correct
from simulate_data import simulate_data
from cassiopeia.data import CassiopeiaTree
from cassiopeia.solver import VanillaGreedySolver
from cassiopeia.critique import triplets_correct

# automatically reload modules
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Simulation settings, forwarded positionally to simulate_data below.
num_states = 15
num_sites = 5
mutation_rate = 0.4
deletion_rate = 9e-4

# Every state index 0..num_states-1 is assigned the same probability 1/num_states.
transition_prob = dict.fromkeys(range(num_states), 1 / num_states)

# Directory handed to simulate_data; the load cell later reads 'true_tree' and
# 'params' from this same directory (presumably written by simulate_data -- confirm).
path = '../heracles/data'

simulate_data(
    transition_prob,
    num_sites,
    num_states,
    mutation_rate,
    deletion_rate,
    path,
)

In [7]:
# Load the pickled artifacts ('true_tree' and 'params') from the data directory.
def _load_pickle(directory, name):
    """Unpickle and return the object stored in file `name` inside `directory`.

    Args:
        directory: Directory containing the pickle file.
        name: File name of the pickle within `directory`.

    Returns:
        The object produced by `pickle.load`.

    Raises:
        OSError: If the file cannot be opened.
    """
    # NOTE: pickle.load can execute arbitrary code during deserialization;
    # only point this at files produced by a trusted source.
    with open(os.path.join(directory, name), 'rb') as handle:
        return pickle.load(handle)


true_tree = _load_pickle(path, 'true_tree')
params = _load_pickle(path, 'params')

In [9]:
# Baseline: solve a tree with VanillaGreedySolver from true_tree's character
# matrix (no priors), then compare it against true_tree.
cas_tree = CassiopeiaTree(
    character_matrix=true_tree.character_matrix,
    priors=None,
)
vanilla_greedy = VanillaGreedySolver()
vanilla_greedy.solve(cas_tree, collapse_mutationless_edges=True)

# Only the first element of the triplets_correct result is used; cas_ans is the
# mean over the values of that mapping.
triplets = triplets_correct(true_tree, cas_tree)
cas_ans = np.mean([score for score in triplets[0].values()])

In [17]:
# Inputs taken from the loaded artifacts. Note that mutation_rate, deletion_rate
# and transition_prob rebind the names set in the simulation cell, now using the
# values stored in `params`.
char_matrix = true_tree.character_matrix
mutation_rate = params['mutation_rate']
deletion_rate = params['deletion_rate']
transition_prob = params['transition_prob']

# Run settings passed positionally to heracles.main2.main. The exact meaning of
# rho and stabilize is defined by that function (not visible here).
seed = 0
num_epochs = 30
lr = 5e-2
embedding_dim = 3
rho = 2
stabilize = 1
est_tree_method = 'neighbor-joining'

# true_tree is passed through unchanged as the final argument; argument order
# must match main's signature, so the call is kept positional.
best_embeddings = main(char_matrix, mutation_rate, deletion_rate, transition_prob,
                       seed, num_epochs, lr, embedding_dim, rho, stabilize, est_tree_method, true_tree)

# Score the returned embeddings against true_tree using the same rho.
heracles_ans = cas_triplets_correct(true_tree, best_embeddings, rho)

In [18]:
# Report both scores, one per line, as "<label> <value>".
for label, value in (('Cassiopeia: ', cas_ans), ('Heracles: ', heracles_ans)):
    print(label, value)

Cassiopeia:  0.378625
Heracles:  0.5585
