In [None]:
import numpy as np
import importlib
import sys, os
import matplotlib.pyplot as plt
import time
import random
from numba import njit
sys.path.append(os.path.abspath('../../lib'))
import pairs_tensor_constructor
import pairs_tensor_util
import util
import tree_sampler_PT_to_anc
import numba
from util import make_adj_from_anc
from tree_sampler import _calc_tree_llh

from common import Models
from data_simulator_full_auto import generate_simulated_data
from tree_plotter import plot_tree

In [None]:
# Simulation settings.
n_muts = 200
n_cells = 1000
fpr = 0.001   # passed to the simulator as FPR=
ado = 0.1     # passed to the simulator as ADO=
seed = 0

# Seed the global RNGs so the simulated dataset is reproducible on Restart & Run All.
# NOTE(review): assumes generate_simulated_data draws from NumPy's global RNG and/or
# the `random` module — confirm; if it builds its own Generator this has no effect.
np.random.seed(seed)
random.seed(seed)

# NOTE(review): n_clust is set equal to n_muts (one cluster per mutation?) — confirm intent.
data, true_tree = generate_simulated_data(n_clust=n_muts,
                                          n_cells=n_cells,
                                          n_muts=n_muts,
                                          FPR=fpr,
                                          ADO=ado,
                                          cell_alpha=1,
                                          mut_alpha=1,
                                          drange=1
                                          )
# NOTE(review): element 1 of true_tree is treated as the adjacency matrix — confirm
# against generate_simulated_data's return layout.
adj_mat = true_tree[1]
anc_mat = util.make_ancestral_from_adj(adj_mat)

In [None]:
# Build the pairs tensor with the same error rates used to simulate `data`, so that
# changing fpr/ado in the simulation cell cannot silently desynchronise the two.
# NOTE(review): the meaning of the positional `1` is not visible here (possibly a
# range parameter matching drange=1 in the simulator) — confirm against the signature.
pairs_tensor = pairs_tensor_constructor.construct_pairs_tensor(data, fpr, ado, 1, verbose=False)

In [None]:
# Reload so edits to the sampler module are picked up without restarting the kernel.
importlib.reload(tree_sampler_PT_to_anc)

n_samples = 10
times = []
# One sampled matrix per draw, stacked along the last axis.
# NOTE(review): assumes _sample_tree returns an (n_muts+1, n_muts+1) array — confirm.
samples = np.zeros((n_muts + 1, n_muts + 1, n_samples))
for i in range(n_samples):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    s = time.perf_counter()
    samples[:, :, i] = tree_sampler_PT_to_anc._sample_tree(pairs_tensor)
    times.append(time.perf_counter() - s)

# NOTE(review): if _sample_tree is numba-jitted, the first entry of `times` includes
# compilation time and inflates this mean — consider excluding it.
print(np.mean(times))

In [None]:
from torch import norm


def _make_selection(selection_probs):
    choice_array = np.exp(selection_probs - np.max(selection_probs))
    choice_array = choice_array.flatten() / np.sum(choice_array)
    rng = np.random.default_rng()
    choice = rng.choice(len(choice_array), size=1, p=choice_array)
    # choice = np.random.choice(len(choice_array), size=(1,), p=choice_array)
    i,j,rel = np.unravel_index(choice, shape=selection_probs.shape)
    return int(i),int(j),int(rel)

@numba.njit()
def _make_selection_numba(selection_probs):
    """Numba-compiled sampler of one (i, j, rel) cell from a 3-D log-weight array.

    Each cell is chosen with probability proportional to
    ``exp(selection_probs[i, j, rel])``. Weights are shifted by the array
    maximum before exponentiating. A single uniform draw from numba's
    ``np.random`` state is then compared against the running cumulative
    weight, scanning cells in C order (i outermost, rel innermost).

    Returns
    -------
    tuple of int
        ``(i, j, rel)`` of the selected cell. If round-off lets the draw run
        past the end of the scan, the last cell is returned.
    """
    peak = np.max(selection_probs)
    # C-order flattening visits cells in the same order as nested i, j, rel loops.
    log_w = selection_probs.flatten()
    n_total = log_w.size

    # Max-shifted, exponentiated weights plus their total, summed in scan order.
    # empty_like keeps the input dtype, matching a copy of the input array.
    weights = np.empty_like(log_w)
    total = 0
    for k in range(n_total):
        weights[k] = np.exp(log_w[k] - peak)
        total = total + weights[k]

    # Inverse-CDF draw: take the first cell whose cumulative weight exceeds the threshold.
    threshold = np.random.rand() * total
    picked = n_total - 1
    running = 0
    for k in range(n_total):
        if threshold < running + weights[k]:
            picked = k
            break
        running = running + weights[k]

    # Map the flat C-order position back to (i, j, rel).
    n_j = selection_probs.shape[1]
    n_rel = selection_probs.shape[2]
    i, rem = divmod(picked, n_j * n_rel)
    j, rel = divmod(rem, n_rel)
    return i, j, rel

In [None]:
def _plot_slices(volume, title_prefix):
    """Open a new figure and show each slice volume[:, :, k] in a 3-column grid."""
    n_slices = volume.shape[2]
    n_rows = -(-n_slices // 3)  # ceil division; 5 slices -> 2 rows, as before
    plt.figure(figsize=(10, 5))
    for k in range(n_slices):
        plt.subplot(n_rows, 3, k + 1)
        plt.imshow(volume[:, :, k])
        plt.title(title_prefix + str(k))


def _count_selections(results, shape):
    """Histogram a list of (i, j, rel) index tuples into a zero-initialised array of `shape`."""
    counts = np.zeros(shape)
    for i_idx, j_idx, rel_idx in results:
        counts[i_idx, j_idx, rel_idx] += 1
    return counts


A = np.log(np.random.rand(10, 10, 5))
n_trials = 10000

# Call the jitted sampler once before timing so numba's compilation cost is not
# charged to the first timed iteration (it would dominate njit_times otherwise).
_make_selection_numba(A)

norm_times = []
njit_times = []
norm_res = []
njit_res = []
for _ in range(n_trials):
    s = time.perf_counter()
    norm_res.append(_make_selection(A))
    norm_times.append(time.perf_counter() - s)
    s = time.perf_counter()
    njit_res.append(_make_selection_numba(A))
    njit_times.append(time.perf_counter() - s)

print(np.sum(norm_times))
print(np.sum(njit_times))

# Target weights, then empirical selection counts from each sampler; the count
# images should resemble the weight images if both samplers are correct.
_plot_slices(np.exp(A), "A_mat: ")
_plot_slices(_count_selections(norm_res, A.shape), "norm_res ")
toplt = _count_selections(njit_res, A.shape)
_plot_slices(toplt, "njit_res: ")
