# Imports and Global Variables

In [1]:
import torch
import geoopt
import numpy as np
import pickle
import os
import sys

from scipy.linalg import expm
from logalike import Logalike

In [2]:
# Re-import edited modules (e.g. logalike) automatically before running each
# cell, so changes to the Logalike source are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Directory holding the pickled inputs ('true_tree', 'params') read in the next cell.
# Relative paths resolve against the kernel's current working directory.
path = '../data'

# Running Model

In [3]:
def load_pickle(name, directory=None):
    """Unpickle and return the object stored in the file ``directory/name``.

    Parameters
    ----------
    name : str
        File name inside ``directory``.
    directory : str, optional
        Directory to read from; defaults to the notebook-level ``path``
        (looked up at call time, so later changes to ``path`` are honoured).

    Notes
    -----
    ``pickle.load`` can execute arbitrary code while deserialising. Only use
    this on trusted, locally generated files such as the ones under ``path``.
    """
    if directory is None:
        directory = path
    with open(os.path.join(directory, name), 'rb') as fh:
        return pickle.load(fh)


# ground-truth tree
true_tree = load_pickle('true_tree')

# parameters for the infinitesimal generator Q
params = load_pickle('params')

# TODO: fix path to call extract_compact_Q
# Q_compact = torch.tensor(extract_compact_Q(params['mutation_rate'], params['deletion_rate']))

In [172]:
def generate_Q(num_sites, num_states, deletion_rate=1e-5, mutation_rate=0.1,
               indel_distribution=None):
    """Build the infinitesimal generator (rate matrix) Q.

    Q is square with ``n = num_sites + num_states + 1`` rows/columns:

    * row ``i < num_sites`` moves to column ``num_sites + j`` at rate
      ``mutation_rate[i] * indel_distribution[j]`` and to the last column at
      rate ``deletion_rate``;
    * rows ``num_sites .. n - 2`` move only to the last column, at rate
      ``deletion_rate``;
    * the last row is all zeros, so the last index is absorbing.

    Each diagonal entry is minus the row's total outflow, so every row sums
    to zero (given that ``indel_distribution`` sums to one, which is checked).

    Parameters
    ----------
    num_sites : int
        Number of site rows at the top of the matrix.
    num_states : int
        Number of mutation outcome states.
    deletion_rate : float, default 1e-5
        Global deletion rate, shared by every non-absorbing row.
    mutation_rate : float or sequence of float, default 0.1
        One rate shared by all sites, or one rate per site (length
        ``num_sites``).
    indel_distribution : sequence of float, optional
        Probability of each of the ``num_states`` mutation outcomes; must sum
        to one. Defaults to the uniform distribution.

    Returns
    -------
    numpy.ndarray
        ``(n, n)`` float64 generator matrix.

    Raises
    ------
    ValueError
        If ``mutation_rate`` / ``indel_distribution`` has the wrong length,
        or ``indel_distribution`` does not sum to one.
    ZeroDivisionError
        If ``num_states`` is 0 and no ``indel_distribution`` is given.
    """
    rates = np.asarray(mutation_rate, dtype=np.float64)
    if rates.ndim == 0:
        rates = np.full(num_sites, float(rates), dtype=np.float64)
    elif rates.shape != (num_sites,):
        raise ValueError(
            f"mutation_rate must be a scalar or have length {num_sites}, "
            f"got shape {rates.shape}")

    if indel_distribution is None:
        probs = np.full(num_states, 1 / num_states, dtype=np.float64)
    else:
        probs = np.asarray(indel_distribution, dtype=np.float64)
        if probs.shape != (num_states,):
            raise ValueError(
                f"indel_distribution must have length {num_states}, "
                f"got shape {probs.shape}")
        # Otherwise mutation outflow != mutation_rate and rows stop summing to 0.
        if not np.isclose(probs.sum(), 1.0):
            raise ValueError(
                f"indel_distribution must sum to 1, got {probs.sum()}")

    n = num_sites + num_states + 1
    Q = np.zeros((n, n), dtype=np.float64)

    # Mutation: site i -> state j at rate mutation_rate[i] * indel_distribution[j].
    Q[:num_sites, num_sites:num_sites + num_states] = np.outer(rates, probs)

    # Deletion: every row except the absorbing last one -> last index.
    Q[:-1, -1] = deletion_rate

    # Diagonal = -(total outflow) for all non-absorbing rows; last row stays 0.
    outflow = np.full(n - 1, deletion_rate, dtype=np.float64)
    outflow[:num_sites] += rates
    idx = np.arange(n - 1)
    Q[idx, idx] = -outflow

    return Q

In [173]:
# Model inputs derived from the true tree's character matrix.
cm = true_tree.character_matrix.to_numpy()

num_cells = cm.shape[0] # number of cells
num_sites = cm.shape[1] # number of target sites
num_states = 50
embedding_dim = 3

# generate_Q has no randomness, so build the per-site matrix once and give
# each site its own copy (independent arrays, as before).
Q_site = generate_Q(num_sites, num_states)
Q_list = [Q_site.copy() for _ in range(num_sites)]

rho = torch.tensor(1, dtype=torch.float64)

# Initial guess for the cell embeddings on the Lorentz manifold.
# Seed torch's global RNG so the random initialisation (and the losses printed
# later) is reproducible; assumes geoopt samples from torch's default generator.
SEED = 0
torch.manual_seed(SEED)
manifold = geoopt.Lorentz(k=rho)
X = manifold.random_normal(num_cells, embedding_dim,
                           mean=2, std=15, dtype=torch.float64)

In [185]:
# Likelihood model over the current embeddings X.
l = Logalike(X=X,
             priors=None,
             Q_list=Q_list,
             character_matrix=cm,
             num_states=num_states,
             rho=rho,
            )

opt = geoopt.optim.RiemannianAdam(l.parameters(), lr=1e-3)

# Debugging: take a single optimisation step (cell 0) for now.
# Set NUM_STEPS = num_cells to take one step per cell.
NUM_STEPS = 1
for i in range(min(NUM_STEPS, num_cells)):
    opt.zero_grad()
    loss = l.forward(i)
    print('LOSS: ',loss)
    # NOTE(review): the saved traceback shows no "LOSS:" line, so the
    # "does not require grad" error may be raised inside Logalike.forward
    # itself; this guard only catches a detached loss returned from it.
    if not (torch.is_tensor(loss) and loss.requires_grad):
        raise RuntimeError(
            f"loss for cell {i} is not attached to the autograd graph; "
            "Logalike.forward presumably computes it outside torch "
            "(e.g. numpy/scipy) or detaches it - TODO confirm")
    loss.backward()
    opt.step()

			Original Indicies: 0 0 tensor([0])
0.01936979159409727 	 0.02 tensor(0.9841, dtype=torch.float64) tensor(0.9841, dtype=torch.float64)
			Original Indicies: 27 27 tensor([ 0, 27])
0.019999937988718 	 0.02 tensor(1.0000, dtype=torch.float64) tensor(1.0000, dtype=torch.float64)
			Original Indicies: 34 34 tensor([ 0, 34])
0.019999937988718 	 0.02 tensor(1.0000, dtype=torch.float64) tensor(1.0000, dtype=torch.float64)
			Original Indicies: 0 0 tensor([0])
0.01936979159409727 	 0.02 tensor(0.9841, dtype=torch.float64) tensor(0.9841, dtype=torch.float64)
			Original Indicies: 0 45 tensor([0])
6.251010686975459e-06 	 0.02 tensor(0.9841, dtype=torch.float64) tensor(0.0003, dtype=torch.float64)
			Original Indicies: 0 0 tensor([0])
0.01936979159409727 	 0.02 tensor(0.9841, dtype=torch.float64) tensor(0.9841, dtype=torch.float64)
			Original Indicies: 0 0 tensor([0])
0.01936979159409727 	 0.02 tensor(0.9841, dtype=torch.float64) tensor(0.9841, dtype=torch.float64)
			Original Indicies: 10 41 

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn