# Imports and Global Variables

In [65]:
import os
os.chdir('/Users/gil2rok/school/crispr-phylogeny2/code/Heracles')

import pickle
import torch
import geoopt
import icecream as ic
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import mlflow

from mlflow import log_metric
from scipy.linalg import expm
from scipy.cluster.hierarchy import average, to_tree
import icecream as ic

from metrics2 import triplets_correct, dist_correlation, cas_triplets_correct
from logalike import Logalike
from util.util import char_matrix_to_dist_matrix, embed_tree, estimate_tree, generate_Q

In [66]:
%load_ext autoreload
%autoreload 2

sns.set_theme()
path = 'data' # relative path between parent directory (set above) to directory with saved data

mlflow.set_tracking_uri('http://127.0.0.1:5000')  # set up connection
mlflow.set_experiment('heracles') # set the experiment

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


<Experiment: artifact_location='mlflow-artifacts:/0', creation_time=1681143470700, experiment_id='0', last_update_time=1681958091218, lifecycle_stage='active', name='heracles', tags={}>

# Running Model

In [67]:
# load true tree
fname = os.path.join(path, 'true_tree')
with open(fname, 'rb') as file:
    true_tree = pickle.load(file)
    
# load parameters for infinitesimal generator Q
fname = os.path.join(path, 'params')
with open(fname, 'rb') as file:
    params = pickle.load(file)

In [109]:
# start mlflow run and preprocess character matrix
with mlflow.start_run():
    char_matrix = true_tree.character_matrix.drop_duplicates().to_numpy(dtype=int) # character matrix of leaf cells

    # hyperparameters
    num_cells = char_matrix.shape[0] # number of cells
    num_sites = char_matrix.shape[1] # number of target sites
    num_states = 15
    num_epochs = 60
    embedding_dim = 3
    rho = torch.tensor([2], dtype=torch.float64)

    # generate infinitesimal generator matrix Q and initial embedding X
    Q_list = [generate_Q(num_sites, num_states) for _ in range(num_sites)]
    dist_matrix = char_matrix_to_dist_matrix(char_matrix) # compute distance matrix
    est_tree = estimate_tree(dist_matrix, method='neighbor-joining') # estimate phylogenetic tree
    X = embed_tree(est_tree, torch.sqrt(rho), num_cells, local_dim=embedding_dim-1)

    # initalize logalike object and optimizer
    l = Logalike(X=X, priors=None, Q_list=Q_list, character_matrix=char_matrix, num_states=num_states, rho=rho,)
    opt = geoopt.optim.RiemannianSGD([l.X], lr=5e-2, stabilize=1)

    # run optimization
    for epoch in range(num_epochs):
        epoch_loss = 0

        for i in range(num_cells):
            opt.zero_grad() # zero gradient
            loss = -l.forward(i) # negative log likelihood of tree configuration
            loss.backward() # gradient on manifold
            opt.step() # take opt step
            
            epoch_loss += loss.item()
            
        log_metric('epoch_loss', epoch_loss)
        log_metric('tc all', cas_triplets_correct(true_tree, l.X, rho, all_triplets=True))
        log_metric('tc resolved', cas_triplets_correct(true_tree, l.X, rho, all_triplets=False))
        log_metric('distance correlation', dist_correlation(true_tree, l.X, rho))

In [118]:
from cassiopeia.data import CassiopeiaTree
from cassiopeia.solver import VanillaGreedySolver, ILPSolver
from cassiopeia import critique

cas_tree = CassiopeiaTree(character_matrix=true_tree.character_matrix, priors=None)
vanilla_greedy = VanillaGreedySolver()
vanilla_greedy.solve(cas_tree, collapse_mutationless_edges=True)
triplets = critique.triplets_correct(true_tree, cas_tree)
ans = np.mean(list(triplets[0].values()))
print(ans)

0.065


In [114]:
cas_tree = CassiopeiaTree(character_matrix=true_tree.character_matrix, priors=None)
vanilla_greedy = ILPSolver()
vanilla_greedy.solve(cas_tree, collapse_mutationless_edges=True)
triplets = critique.triplets_correct(true_tree, cas_tree)
ans = np.mean(list(triplets[0].values()))
print(ans)

[2023-04-25 15:23:31,524]    INFO [ILPSolver] Solving tree with the following parameters.
[2023-04-25 15:23:31,525]    INFO [ILPSolver] Convergence time limit: 12600
[2023-04-25 15:23:31,526]    INFO [ILPSolver] Convergence iteration limit: 0
[2023-04-25 15:23:31,526]    INFO [ILPSolver] Max potential graph layer size: 10000
[2023-04-25 15:23:31,527]    INFO [ILPSolver] Max potential graph lca distance: None
[2023-04-25 15:23:31,527]    INFO [ILPSolver] MIP gap: 0.01
[2023-04-25 15:23:31,528]    INFO [ILPSolver] Phylogenetic root: (0, 3, 0, 0, 0)
[2023-04-25 15:23:31,530]    INFO [ILPSolver] (Process: 77d8e698ed4584d6a2b61608b34a10bd) Estimating a potential graph with a maximum layer size of 10000 and a maximum LCA distance of 3.
[2023-04-25 15:23:31,531]    INFO [ILPSolver] (Process: 77d8e698ed4584d6a2b61608b34a10bd) LCA distance 0 completed with a neighborhood size of 8.
[2023-04-25 15:23:31,533]    INFO [ILPSolver] (Process: 77d8e698ed4584d6a2b61608b34a10bd) LCA distance 1 completed

ILPSolverError: Gurobi not found. You must install Gurobi & gurobipy from source.