# Imports and Global Variables

In [3]:
import os
import sys
os.chdir('/Users/gil2rok/school/peer_research/Heracles/heracles/')

import pickle
import torch
import geoopt
import icecream as ic
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import mlflow
import cassiopeia

from mlflow import log_metric
from scipy.linalg import expm
from scipy.cluster.hierarchy import average, to_tree
import icecream as ic

from metrics import triplets_correct, dist_correlation, cas_triplets_correct
from logalike import Logalike
from util import char_to_dist, embed_tree, estimate_tree, generate_Q

In [10]:
%load_ext autoreload
%autoreload 2

sns.set_theme()
path = 'data' # relative path between parent directory (set above) to directory with saved data

mlflow.set_tracking_uri('http://127.0.0.1:5000')  # set up connection
mlflow.set_experiment('heracles training') # set the experiment

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


<Experiment: artifact_location='mlflow-artifacts:/823951667479544761', creation_time=1682531041860, experiment_id='823951667479544761', last_update_time=1682531041860, lifecycle_stage='active', name='heracles training', tags={}>

# Running Model

In [11]:
# load true tree
fname = os.path.join(path, 'true_tree')
with open(fname, 'rb') as file:
    true_tree = pickle.load(file)
    
# load parameters for infinitesimal generator Q
fname = os.path.join(path, 'params')
with open(fname, 'rb') as file:
    params = pickle.load(file)

In [18]:
# start mlflow run and preprocess character matrix
with mlflow.start_run():
    char_matrix = true_tree.character_matrix.drop_duplicates().to_numpy(dtype=int) # character matrix of leaf cells

    # data hyperparameters
    num_cells = char_matrix.shape[0] # number of cells
    num_sites = char_matrix.shape[1] # number of target sites
    num_states = 15
    
    deletion_rate = params['deletion_rate'] # global deletion rate
    mutation_rate = params['mutation_rate'] # site-specific mutation rate
    indel_distribution = list(params['transition_prob'].values())
    
    # model hyperparameters
    num_epochs = 30
    embedding_dim = 3
    lr = 5e-2
    tree_reconstriction = 'neighbor-joining'
    rho = torch.tensor([2], dtype=torch.float64)
    stabilize = 1

    # generate infinitesimal generator matrix Q and initial embedding X
    Q_list = [generate_Q(num_sites, num_states, deletion_rate, mutation_rate, indel_distribution)
              for _ in range(num_sites)]
    dist_matrix = char_to_dist(char_matrix) # compute distance matrix
    est_tree = estimate_tree(dist_matrix, method=tree_reconstriction) # estimate phylogenetic tree
    X = embed_tree(est_tree, torch.sqrt(rho), num_cells, local_dim=embedding_dim-1)
    
    # initalize logalike object and optimizer
    l = Logalike(X=X, Q_list=Q_list, character_matrix=char_matrix, num_states=num_states, rho=rho, priors=None,)
    opt = geoopt.optim.RiemannianSGD([l.X], lr=lr, stabilize=stabilize)

    # run optimization
    for epoch in range(num_epochs):
        epoch_loss = 0

        for i in range(num_cells):
            opt.zero_grad() # zero gradient
            loss = -l.forward(i) # negative log likelihood of tree configuration
            loss.backward() # gradient on manifold
            opt.step() # take opt step
            
            epoch_loss += loss.item()
            
        log_metric('epoch_loss', epoch_loss)
        log_metric('tc all', cas_triplets_correct(true_tree, l.X, rho, all_triplets=True))
        log_metric('tc resolved', cas_triplets_correct(true_tree, l.X, rho, all_triplets=False))
        log_metric('distance correlation', dist_correlation(true_tree, l.X, rho))

[autoreload of util failed: Traceback (most recent call last):
  File "/Users/gil2rok/mambaforge/envs/crispr-phylogeny/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 261, in check
    superreload(m, reload, self.old_objects)
  File "/Users/gil2rok/mambaforge/envs/crispr-phylogeny/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 459, in superreload
    module = reload(module)
  File "/Users/gil2rok/mambaforge/envs/crispr-phylogeny/lib/python3.10/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 619, in _exec
  File "<frozen importlib._bootstrap_external>", line 879, in exec_module
  File "<frozen importlib._bootstrap_external>", line 1017, in get_code
  File "<frozen importlib._bootstrap_external>", line 947, in source_to_code
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/Users/gil2rok/school/peer_research/Heracles/heracles2/util.p

-2.000000000392815 -2.0
-2.000000205839253 -2.0
-2.0000002259820917 -2.0
-1.9999998627637297 -2.0
-2.000012965673193 -2.0
-1.9999998534453256 -2.0
-1.9999998534453256 -2.0
-1.9999999315429164 -2.0
-1.9999998549168936 -2.0
-2.000000000392815 -2.0
-2.0 -2.0
-1.9999999999999982 -2.0
-2.0 -2.0
-1.9999999999999716 -2.0
-2.0000000000000004 -2.0
-2.0 -2.0
-1.9999999999999998 -2.0
-1.9999999999999987 -2.0
-1.9999999999999996 -2.0
-2.0 -2.0
-2.0000000000000036 -2.0
-2.0000000000000036 -2.0
-1.9999999999999716 -2.0
-1.9999999999999996 -2.0
-2.0 -2.0
-1.9999999999999993 -2.0
-2.0 -2.0
-2.0 -2.0
-2.0 -2.0
-2.0000000000000036 -2.0
-1.9999999999999964 -2.0
-2.0 -2.0
-1.9999999999999996 -2.0
-2.0 -2.0
-1.9999999999999996 -2.0
-1.9999999999999964 -2.0
-1.9999999999999991 -2.0
-2.0000000000000036 -2.0
-1.9999999999999947 -2.0
-1.9999999999999964 -2.0
-2.000000000000014 -2.0
-2.0 -2.0
-2.0000000000001137 -2.0
-1.9999999999999996 -2.0
-2.000000000000014 -2.0
-1.9999999999999996 -2.0
-2.0 -2.0
-1.99999999

In [118]:
from cassiopeia.data import CassiopeiaTree
from cassiopeia.solver import VanillaGreedySolver, ILPSolver
from cassiopeia import critique

cas_tree = CassiopeiaTree(character_matrix=true_tree.character_matrix, priors=None)
vanilla_greedy = VanillaGreedySolver()
vanilla_greedy.solve(cas_tree, collapse_mutationless_edges=True)
triplets = critique.triplets_correct(true_tree, cas_tree)
ans = np.mean(list(triplets[0].values()))
print(ans)

0.065


In [114]:
cas_tree = CassiopeiaTree(character_matrix=true_tree.character_matrix, priors=None)
vanilla_greedy = ILPSolver()
vanilla_greedy.solve(cas_tree, collapse_mutationless_edges=True)
triplets = critique.triplets_correct(true_tree, cas_tree)
ans = np.mean(list(triplets[0].values()))
print(ans)

[2023-04-25 15:23:31,524]    INFO [ILPSolver] Solving tree with the following parameters.
[2023-04-25 15:23:31,525]    INFO [ILPSolver] Convergence time limit: 12600
[2023-04-25 15:23:31,526]    INFO [ILPSolver] Convergence iteration limit: 0
[2023-04-25 15:23:31,526]    INFO [ILPSolver] Max potential graph layer size: 10000
[2023-04-25 15:23:31,527]    INFO [ILPSolver] Max potential graph lca distance: None
[2023-04-25 15:23:31,527]    INFO [ILPSolver] MIP gap: 0.01
[2023-04-25 15:23:31,528]    INFO [ILPSolver] Phylogenetic root: (0, 3, 0, 0, 0)
[2023-04-25 15:23:31,530]    INFO [ILPSolver] (Process: 77d8e698ed4584d6a2b61608b34a10bd) Estimating a potential graph with a maximum layer size of 10000 and a maximum LCA distance of 3.
[2023-04-25 15:23:31,531]    INFO [ILPSolver] (Process: 77d8e698ed4584d6a2b61608b34a10bd) LCA distance 0 completed with a neighborhood size of 8.
[2023-04-25 15:23:31,533]    INFO [ILPSolver] (Process: 77d8e698ed4584d6a2b61608b34a10bd) LCA distance 1 completed

ILPSolverError: Gurobi not found. You must install Gurobi & gurobipy from source.