# Imports and Global Variables

In [1]:
import os
# Work from the Heracles code directory; presumably this is what lets the project-local
# imports and the relative data path used by this notebook resolve.
# Set the HERACLES_DIR environment variable to point at your own checkout; when it is
# unset, the original author's location is used as the fallback.
os.chdir(os.environ.get('HERACLES_DIR', '/Users/gil2rok/school/crispr-phylogeny2/code/Heracles'))

import pickle
import torch
import geoopt
import icecream as ic
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import mlflow

from mlflow import log_metric
from scipy.linalg import expm
from scipy.cluster.hierarchy import average, to_tree
import icecream as ic

from metrics import dist_correlation, rf, tc2
from metrics2 import triplets_correct, dist_correlation
from logalike import Logalike
from util.util import char_matrix_to_dist_matrix, embed_tree, estimate_tree, generate_Q

In [2]:
# Notebook-wide configuration: module auto-reloading, plot theme, data location and MLflow target.
# autoreload mode 2 picks up edits to imported .py modules without restarting the kernel.
%load_ext autoreload
%autoreload 2

sns.set_theme()
path = 'data' # directory holding the saved data, relative to the working directory set by os.chdir above

# NOTE(review): assumes an MLflow tracking server is already listening on this local address - confirm before running
mlflow.set_tracking_uri('http://127.0.0.1:5000')  # address of the MLflow tracking server
mlflow.set_experiment('heracles_official') # experiment under which subsequent runs are recorded

<Experiment: artifact_location='mlflow-artifacts:/0', creation_time=1681143470700, experiment_id='0', last_update_time=1681175451897, lifecycle_stage='active', name='heracles_official', tags={}>

# Running Model

In [3]:
def _load_pickle(name):
    """Unpickle and return the object stored in the file `name` inside the `path` directory."""
    # pickle can execute arbitrary code while loading; only read files you produced yourself
    with open(os.path.join(path, name), 'rb') as handle:
        return pickle.load(handle)

# ground-truth tree
true_tree = _load_pickle('true_tree')

# parameters for the infinitesimal generator Q
params = _load_pickle('params')

In [25]:
# Train the embedding inside a single MLflow run:
#   1. build the character matrix of the leaf cells from the true tree,
#   2. estimate a tree from pairwise distances and embed it to get the initial X,
#   3. optimize X with Riemannian SGD on the negative log likelihood, one cell at a time,
#   4. log the loss and tree-quality metrics once per epoch.
with mlflow.start_run():
    # character matrix of leaf cells, duplicate rows dropped
    char_matrix = true_tree.character_matrix.drop_duplicates().to_numpy(dtype=int)

    # hyperparameters
    num_cells = char_matrix.shape[0] # number of cells (rows of the character matrix)
    num_sites = char_matrix.shape[1] # number of target sites (columns of the character matrix)
    num_states = 15
    num_epochs = 40
    embedding_dim = 5
    learning_rate = 5e-2
    rho = torch.tensor([2], dtype=torch.float64)

    # record the configuration with the run so logged metrics can be traced back to it
    mlflow.log_params({
        'num_cells': num_cells,
        'num_sites': num_sites,
        'num_states': num_states,
        'num_epochs': num_epochs,
        'embedding_dim': embedding_dim,
        'learning_rate': learning_rate,
        'rho': rho.item(),
    })

    # generate infinitesimal generator matrix Q (one per target site) and initial embedding X
    # NOTE(review): Q_list is freshly generated here; the `params` object loaded earlier in
    # the notebook is not used in this cell - confirm that is intended.
    Q_list = [generate_Q(num_sites, num_states) for _ in range(num_sites)]
    dist_matrix = char_matrix_to_dist_matrix(char_matrix) # compute distance matrix
    est_tree = estimate_tree(dist_matrix, method='neighbor-joining') # estimate phylogenetic tree
    X = embed_tree(est_tree, torch.sqrt(rho), num_cells, local_dim=embedding_dim-1)

    # initialize logalike object and optimizer
    l = Logalike(X=X, priors=None, Q_list=Q_list, character_matrix=char_matrix, num_states=num_states, rho=rho,)
    opt = geoopt.optim.RiemannianSGD([l.X], lr=learning_rate, stabilize=1)

    # run optimization
    for epoch in range(num_epochs):
        epoch_loss = 0

        # one optimizer step per cell
        for i in range(num_cells):
            opt.zero_grad() # zero gradient
            loss = -l.forward(i) # negative log likelihood of tree configuration
            loss.backward() # gradient on manifold
            opt.step() # take opt step

            epoch_loss += loss.item()

        # pass the epoch as the MLflow step; without it every value is recorded at step 0
        log_metric('epoch_loss', epoch_loss, step=epoch)
        log_metric('triplets correct', triplets_correct(true_tree, l.X, rho), step=epoch)
        log_metric('distance correlation', dist_correlation(true_tree, l.X, rho), step=epoch)