# Prototype PyMC3

In [None]:
# %load lib/genotype_mixture
#!/usr/bin/env python3

import numpy as np
import pymc3 as pm
from pymc3.distributions.transforms import t_stick_breaking, logodds
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from lib.util import info
from lib.pileup import list_bases, get_pileup_dims
import tqdm
import theano.tensor as tt
import theano.sparse as ts
from itertools import product

from lib.pymc3 import trace_stat_plot
from arviz import ess
from lib.genotype_mixture import gamma_plot, pi_plot



# Shared stick-breaking transform, constructed with 1e-10 as its argument
# (presumably a numerical-stability epsilon -- confirm against PyMC3).  It is
# passed as ``transform=`` to the Dirichlet priors in the model builders.
stick_breaking = t_stick_breaking(1e-10)


def build_biallelic_model_fuzzy(g, n, s):
    """Build the biallelic mixture model with continuous ("fuzzy") genotypes.

    Every entry of ``gamma_`` is Uniform(0, 1) and the counts use a plain
    Binomial likelihood.

    Parameters
    ----------
    g, n, s : int
        Array sizes: ``pi`` is ``(n, s)``, ``gamma`` is ``(g, s, 2)`` and the
        ``observed`` placeholder is ``(g * n, 2)``.

    Returns
    -------
    pm.Model
        ``observed`` is an uninitialised placeholder that has to be filled
        with ``set_value`` before sampling.  ``pi_hyper``, ``rho_hyper`` and
        ``gamma_hyper`` default to 0.0 (their potentials vanish) and
        ``epsilon_hyper`` defaults to 100.
    """
    a = 2  # Two alleles per position.

    with pm.Model() as model:
        # Fraction
        pi = pm.Dirichlet(
            'pi',
            a=np.ones(s),
            shape=(n, s),
            transform=stick_breaking,
        )
        pi_hyper = pm.Data('pi_hyper', value=0.0)
        sqrt_pi_total = pm.math.sqrt(pi).sum(0).sum()
        pm.Potential('heterogeneity_penalty',
                     -(sqrt_pi_total**2) * pi_hyper)

        rho_hyper = pm.Data('rho_hyper', value=0.0)
        sqrt_column_total = pm.math.sqrt(pi.sum(0)).sum()
        pm.Potential('diversity_penalty',
                     -(sqrt_column_total**2) * rho_hyper)

        # Genotype: one free value per (g, s) entry, complemented to fill
        # the second allele.
        gamma_ = pm.Uniform('gamma_', 0, 1, shape=(g * s, 1))
        allele_pairs = pm.math.concatenate([gamma_, 1 - gamma_], axis=1)
        gamma = pm.Deterministic('gamma', allele_pairs.reshape((g, s, a)))
        gamma_hyper = pm.Data('gamma_hyper', value=0.0)
        ambiguity = (pm.math.sqrt(gamma).sum(2)**2).sum(0) * pi.sum(0)
        pm.Potential('ambiguity_penalty',
                     -ambiguity.sum(0) * gamma_hyper)

        # Product of fraction and genotype
        true_p = pm.Deterministic('true_p', pm.math.dot(pi, gamma))

        # Sequencing error: blend the error-free allele probabilities with a
        # uniform 1/a distribution, weighted by epsilon (one value per n).
        epsilon_hyper = pm.Data('epsilon_hyper', value=100)
        epsilon = pm.Beta('epsilon', alpha=2, beta=epsilon_hyper,
                          shape=n)
        epsilon_ = epsilon.reshape((n, 1, 1))
        err_base_prob = tt.ones((n, g, a)) / a
        without_error = true_p * (1 - epsilon_)
        only_error = err_base_prob * epsilon_
        p_with_error = without_error + only_error

        # Observation
        observed = pm.Data('observed', value=np.empty((g * n, a)))
        _p = p_with_error.reshape((-1, a))[:, 0]
        totals = observed.reshape((-1, a)).sum(1)

        # TODO: Figure out how to also fit an overdispersion term, e.g.
        #  alpha = pm.Gamma('alpha', mu=100, sigma=5).
        # FIXME: Do I want the default to be a valid value?
        #  Realistic or close to asymptotic?
        # FIXME: This Binomial may not work as well as a highly concentrated
        #  BetaBinomial (alpha=_p * conc, beta=(1 - _p) * conc, conc=1000).
        pm.Binomial('data',
                    p=_p,
                    n=totals,
                    observed=observed[:, 0])

    return model

def build_biallelic_model_discrete(g, n, s):
    """Build the biallelic mixture model with discrete (0/1) genotypes.

    Every entry of ``gamma_`` is Bernoulli(0.5) and the counts use a plain
    Binomial likelihood.

    Parameters
    ----------
    g, n, s : int
        Array sizes: ``pi`` is ``(n, s)``, ``gamma`` is ``(g, s, 2)`` and the
        ``observed`` placeholder is ``(g * n, 2)``.

    Returns
    -------
    pm.Model
        ``observed`` is an uninitialised placeholder that has to be filled
        with ``set_value`` before sampling.  ``pi_hyper``, ``rho_hyper`` and
        ``gamma_hyper`` default to 0.0 (their potentials vanish) and
        ``epsilon_hyper`` defaults to 100.
    """
    a = 2  # Two alleles per position.

    with pm.Model() as model:
        # Fraction
        pi = pm.Dirichlet(
            'pi',
            a=np.ones(s),
            shape=(n, s),
            transform=stick_breaking,
        )
        pi_hyper = pm.Data('pi_hyper', value=0.0)
        sqrt_pi_total = pm.math.sqrt(pi).sum(0).sum()
        pm.Potential('heterogeneity_penalty',
                     -(sqrt_pi_total**2) * pi_hyper)

        rho_hyper = pm.Data('rho_hyper', value=0.0)
        sqrt_column_total = pm.math.sqrt(pi.sum(0)).sum()
        pm.Potential('diversity_penalty',
                     -(sqrt_column_total**2) * rho_hyper)

        # Genotype: one binary value per (g, s) entry, complemented to fill
        # the second allele.
        gamma_ = pm.Bernoulli('gamma_', p=0.5, shape=(g * s, 1))
        allele_pairs = pm.math.concatenate([gamma_, 1 - gamma_], axis=1)
        gamma = pm.Deterministic('gamma', allele_pairs.reshape((g, s, a)))
        gamma_hyper = pm.Data('gamma_hyper', value=0.0)
        ambiguity = (pm.math.sqrt(gamma).sum(2)**2).sum(0) * pi.sum(0)
        pm.Potential('ambiguity_penalty',
                     -ambiguity.sum(0) * gamma_hyper)

        # Product of fraction and genotype
        true_p = pm.Deterministic('true_p', pm.math.dot(pi, gamma))

        # Sequencing error: blend the error-free allele probabilities with a
        # uniform 1/a distribution, weighted by epsilon (one value per n).
        epsilon_hyper = pm.Data('epsilon_hyper', value=100)
        epsilon = pm.Beta('epsilon', alpha=2, beta=epsilon_hyper,
                          shape=n)
        epsilon_ = epsilon.reshape((n, 1, 1))
        err_base_prob = tt.ones((n, g, a)) / a
        without_error = true_p * (1 - epsilon_)
        only_error = err_base_prob * epsilon_
        p_with_error = without_error + only_error

        # Observation
        observed = pm.Data('observed', value=np.empty((g * n, a)))
        _p = p_with_error.reshape((-1, a))[:, 0]
        totals = observed.reshape((-1, a)).sum(1)

        # TODO: Figure out how to also fit an overdispersion term, e.g.
        #  alpha = pm.Gamma('alpha', mu=100, sigma=5).
        # FIXME: Do I want the default to be a valid value?
        #  Realistic or close to asymptotic?
        # FIXME: This Binomial may not work as well as a highly concentrated
        #  BetaBinomial (alpha=_p * conc, beta=(1 - _p) * conc, conc=1000).
        pm.Binomial('data',
                    p=_p,
                    n=totals,
                    observed=observed[:, 0])

    return model


def dirichlet_process_rv(prefix, k):
    """Register truncated stick-breaking weights in the active PyMC3 model.

    Creates ``<prefix>_alpha`` ~ Gamma(1, 1), ``<prefix>_beta`` ~
    Beta(1, alpha) with ``shape=k``, and a Deterministic named ``prefix``
    holding ``beta[i] * prod(1 - beta[:i])``.

    Must be called inside a ``with pm.Model():`` block.  Returns the
    Deterministic weight vector.
    """
    alpha = pm.Gamma(prefix + '_alpha', 1., 1.)
    beta = pm.Beta(prefix + '_beta', 1., alpha, shape=k)
    # Stick length still unbroken before each piece: 1 for the first piece,
    # then the running product of (1 - beta) shifted by one position.
    remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return pm.Deterministic(prefix, beta * remaining)


def build_biallelic_model_fuzzy_dp(g, n, s):
    """Build the fuzzy-genotype model with stick-breaking weights on ``pi``.

    EXPERIMENTAL.  The Dirichlet concentration of ``pi`` is
    ``dirichlet_process_rv('pi_w', s) * 100``, genotypes are Uniform(0, 1),
    and counts are BetaBinomial with concentration ``alpha`` (a ``pm.Data``
    defaulting to 1000).

    Parameters
    ----------
    g, n, s : int
        Array sizes: ``pi`` is ``(n, s)``, ``gamma`` is ``(g, s, 2)`` and the
        ``observed`` placeholder is ``(g * n, 2)``.

    Returns
    -------
    pm.Model
        ``observed`` is an uninitialised placeholder that has to be filled
        with ``set_value`` before sampling.
    """
    a = 2  # Two alleles per position.

    with pm.Model() as model:
        # Fraction
        pi_w = dirichlet_process_rv('pi_w', s)
        pi_alpha = 100
        uniform_start = np.ones((n, s)) / s  # Uniform
        pi = pm.Dirichlet(
            'pi',
            a=pi_w * pi_alpha,
            shape=(n, s),
            transform=stick_breaking,
            testval=uniform_start,
        )
        pi_hyper = pm.Data('pi_hyper', value=0.0)
        sqrt_pi_total = pm.math.sqrt(pi).sum(0).sum()
        pm.Potential('heterogeneity_penalty',
                     -(sqrt_pi_total**2) * pi_hyper)

        rho_hyper = pm.Data('rho_hyper', value=0.0)
        sqrt_column_total = pm.math.sqrt(pi.sum(0)).sum()
        pm.Potential('diversity_penalty',
                     -(sqrt_column_total**2) * rho_hyper)

        # Genotype: one free value per (g, s) entry, complemented to fill
        # the second allele.
        gamma_ = pm.Uniform('gamma_', 0, 1, shape=(g * s, 1))
        allele_pairs = pm.math.concatenate([gamma_, 1 - gamma_], axis=1)
        gamma = pm.Deterministic('gamma', allele_pairs.reshape((g, s, a)))
        gamma_hyper = pm.Data('gamma_hyper', value=0.0)
        ambiguity = (pm.math.sqrt(gamma).sum(2)**2).sum(0) * pi.sum(0)
        pm.Potential('ambiguity_penalty',
                     -ambiguity.sum(0) * gamma_hyper)

        # Product of fraction and genotype
        true_p = pm.Deterministic('true_p', pm.math.dot(pi, gamma))

        # Sequencing error: blend the error-free allele probabilities with a
        # uniform 1/a distribution, weighted by epsilon (one value per n).
        epsilon_hyper = pm.Data('epsilon_hyper', value=100)
        epsilon = pm.Beta('epsilon', alpha=2, beta=epsilon_hyper,
                          shape=n)
        epsilon_ = epsilon.reshape((n, 1, 1))
        err_base_prob = tt.ones((n, g, a)) / a
        without_error = true_p * (1 - epsilon_)
        only_error = err_base_prob * epsilon_
        p_with_error = without_error + only_error

        # Observation
        _p = p_with_error.reshape((-1, a))[:, 0]

        # Overdispersion term
        # TODO: Figure out how to also fit this term, e.g.
        #  alpha = pm.Gamma('alpha', mu=100, sigma=5).
        # FIXME: Do I want the default to be a valid value?
        #  Realistic or close to asymptotic?
        alpha = pm.Data('alpha', value=1000)

        observed = pm.Data('observed', value=np.empty((g * n, a)))
        totals = observed.reshape((-1, a)).sum(1)
        pm.BetaBinomial('data',
                        alpha=_p * alpha,
                        beta=(1 - _p) * alpha,
                        n=totals,
                        observed=observed[:, 0])

    return model

def build_biallelic_model_discrete_dp(g, n, s):
    """Build the discrete-genotype model with stick-breaking weights on ``pi``.

    EXPERIMENTAL.  The Dirichlet concentration of ``pi`` is
    ``dirichlet_process_rv('pi_w', s) * 100``, genotypes are Bernoulli(0.5),
    and counts are BetaBinomial with concentration ``alpha`` (a ``pm.Data``
    defaulting to 1000).

    Parameters
    ----------
    g, n, s : int
        Array sizes: ``pi`` is ``(n, s)``, ``gamma`` is ``(g, s, 2)`` and the
        ``observed`` placeholder is ``(g * n, 2)``.

    Returns
    -------
    pm.Model
        ``observed`` is an uninitialised placeholder that has to be filled
        with ``set_value`` before sampling.
    """
    a = 2  # Two alleles per position.

    with pm.Model() as model:
        # Fraction
        pi_w = dirichlet_process_rv('pi_w', s)
        pi_alpha = 100
        uniform_start = np.ones((n, s)) / s  # Uniform
        pi = pm.Dirichlet(
            'pi',
            a=pi_w * pi_alpha,
            shape=(n, s),
            transform=stick_breaking,
            testval=uniform_start,
        )
        pi_hyper = pm.Data('pi_hyper', value=0.0)
        sqrt_pi_total = pm.math.sqrt(pi).sum(0).sum()
        pm.Potential('heterogeneity_penalty',
                     -(sqrt_pi_total**2) * pi_hyper)

        rho_hyper = pm.Data('rho_hyper', value=0.0)
        sqrt_column_total = pm.math.sqrt(pi.sum(0)).sum()
        pm.Potential('diversity_penalty',
                     -(sqrt_column_total**2) * rho_hyper)

        # Genotype: one binary value per (g, s) entry, complemented to fill
        # the second allele.
        gamma_ = pm.Bernoulli('gamma_', p=0.5, shape=(g * s, 1))
        allele_pairs = pm.math.concatenate([gamma_, 1 - gamma_], axis=1)
        gamma = pm.Deterministic('gamma', allele_pairs.reshape((g, s, a)))
        gamma_hyper = pm.Data('gamma_hyper', value=0.0)
        ambiguity = (pm.math.sqrt(gamma).sum(2)**2).sum(0) * pi.sum(0)
        pm.Potential('ambiguity_penalty',
                     -ambiguity.sum(0) * gamma_hyper)

        # Product of fraction and genotype
        true_p = pm.Deterministic('true_p', pm.math.dot(pi, gamma))

        # Sequencing error: blend the error-free allele probabilities with a
        # uniform 1/a distribution, weighted by epsilon (one value per n).
        epsilon_hyper = pm.Data('epsilon_hyper', value=100)
        epsilon = pm.Beta('epsilon', alpha=2, beta=epsilon_hyper,
                          shape=n)
        epsilon_ = epsilon.reshape((n, 1, 1))
        err_base_prob = tt.ones((n, g, a)) / a
        without_error = true_p * (1 - epsilon_)
        only_error = err_base_prob * epsilon_
        p_with_error = without_error + only_error

        # Observation
        _p = p_with_error.reshape((-1, a))[:, 0]

        # Overdispersion term
        # TODO: Figure out how to also fit this term, e.g.
        #  alpha = pm.Gamma('alpha', mu=100, sigma=5).
        # FIXME: Do I want the default to be a valid value?
        #  Realistic or close to asymptotic?
        alpha = pm.Data('alpha', value=1000)

        observed = pm.Data('observed', value=np.empty((g * n, a)))
        totals = observed.reshape((-1, a)).sum(1)
        pm.BetaBinomial('data',
                        alpha=_p * alpha,
                        beta=(1 - _p) * alpha,
                        n=totals,
                        observed=observed[:, 0])

    return model

def build_biallelic_model_fuzzy_hd(g, n, s):
    """Build the fuzzy-genotype model with a hierarchical Dirichlet on ``pi``.

    EXPERIMENTAL.  ``rho`` ~ Dirichlet(ones(s) * exp(-rho_hyper)) and
    ``pi`` ~ Dirichlet(rho * exp(-pi_hyper)), where both hyperparameters are
    Normal(sigma=1) random variables rather than ``pm.Data`` containers.
    Genotypes are Uniform(0, 1) and counts are BetaBinomial with
    concentration ``alpha`` (a ``pm.Data`` defaulting to 1000).

    Parameters
    ----------
    g, n, s : int
        Array sizes: ``pi`` is ``(n, s)``, ``gamma`` is ``(g, s, 2)`` and the
        ``observed`` placeholder is ``(g * n, 2)``.

    Returns
    -------
    pm.Model
        ``observed`` is an uninitialised placeholder that has to be filled
        with ``set_value`` before sampling.
    """
    a = 2  # Two alleles per position.

    with pm.Model() as model:
        # Fraction
        rho_hyper = pm.Normal('rho_hyper', sigma=1)  # was pm.Data('rho_hyper', value=0.0)
        rho_concentration = tt.ones(s) * pm.math.exp(-rho_hyper)
        rho = pm.Dirichlet(
            'rho',
            a=rho_concentration,
            transform=stick_breaking,
            shape=(s,),
        )
        pi_hyper = pm.Normal('pi_hyper', sigma=1)  # was pm.Data('pi_hyper', value=0.0)
        pi_concentration = rho * pm.math.exp(-pi_hyper)
        pi = pm.Dirichlet(
            'pi',
            a=pi_concentration,
            shape=(n, s),
            transform=stick_breaking,
        )

        # Genotype: one free value per (g, s) entry, complemented to fill
        # the second allele.
        gamma_ = pm.Uniform('gamma_', 0, 1, shape=(g * s, 1))
        allele_pairs = pm.math.concatenate([gamma_, 1 - gamma_], axis=1)
        gamma = pm.Deterministic('gamma', allele_pairs.reshape((g, s, a)))
        gamma_hyper = pm.Data('gamma_hyper', value=0.0)
        ambiguity = (pm.math.sqrt(gamma).sum(2)**2).sum(0) * pi.sum(0)
        pm.Potential('ambiguity_penalty',
                     -ambiguity.sum(0) * gamma_hyper)

        # Product of fraction and genotype
        true_p = pm.Deterministic('true_p', pm.math.dot(pi, gamma))

        # Sequencing error: blend the error-free allele probabilities with a
        # uniform 1/a distribution, weighted by epsilon (one value per n).
        epsilon_hyper = pm.Data('epsilon_hyper', value=100)
        epsilon = pm.Beta('epsilon', alpha=2, beta=epsilon_hyper,
                          shape=n)
        epsilon_ = epsilon.reshape((n, 1, 1))
        err_base_prob = tt.ones((n, g, a)) / a
        without_error = true_p * (1 - epsilon_)
        only_error = err_base_prob * epsilon_
        p_with_error = without_error + only_error

        # Observation
        _p = p_with_error.reshape((-1, a))[:, 0]

        # Overdispersion term
        # TODO: Figure out how to also fit this term, e.g.
        #  alpha = pm.Gamma('alpha', mu=100, sigma=5).
        # FIXME: Do I want the default to be a valid value?
        #  Realistic or close to asymptotic?
        alpha = pm.Data('alpha', value=1000)

        observed = pm.Data('observed', value=np.empty((g * n, a)))
        totals = observed.reshape((-1, a)).sum(1)
        pm.BetaBinomial('data',
                        alpha=_p * alpha,
                        beta=(1 - _p) * alpha,
                        n=totals,
                        observed=observed[:, 0])

    return model

def build_biallelic_model_discrete_hd(g, n, s):
    """Build the discrete-genotype model with a hierarchical Dirichlet on ``pi``.

    ``rho`` ~ Dirichlet(ones(s) * exp(-rho_hyper)) and
    ``pi`` ~ Dirichlet(rho * exp(-pi_hyper)), where both hyperparameters are
    Normal(sigma=1) random variables rather than ``pm.Data`` containers.
    Genotypes are Bernoulli(0.5) and counts are BetaBinomial with
    concentration ``alpha`` (a ``pm.Data`` defaulting to 1000).

    Parameters
    ----------
    g, n, s : int
        Array sizes: ``pi`` is ``(n, s)``, ``gamma`` is ``(g, s, 2)`` and the
        ``observed`` placeholder is ``(g * n, 2)``.

    Returns
    -------
    pm.Model
        ``observed`` is an uninitialised placeholder that has to be filled
        with ``set_value`` before sampling.
    """
    a = 2  # Two alleles per position.

    with pm.Model() as model:
        # Fraction
        rho_hyper = pm.Normal('rho_hyper', sigma=1)  # was pm.Data('rho_hyper', value=0.0)
        rho_concentration = tt.ones(s) * pm.math.exp(-rho_hyper)
        rho = pm.Dirichlet(
            'rho',
            a=rho_concentration,
            transform=stick_breaking,
            shape=(s,),
        )
        pi_hyper = pm.Normal('pi_hyper', sigma=1)  # was pm.Data('pi_hyper', value=0.0)
        pi_concentration = rho * pm.math.exp(-pi_hyper)
        pi = pm.Dirichlet(
            'pi',
            a=pi_concentration,
            shape=(n, s),
            transform=stick_breaking,
        )

        # Genotype: one binary value per (g, s) entry, complemented to fill
        # the second allele.
        gamma_ = pm.Bernoulli('gamma_', p=0.5, shape=(g * s, 1))
        allele_pairs = pm.math.concatenate([gamma_, 1 - gamma_], axis=1)
        gamma = pm.Deterministic('gamma', allele_pairs.reshape((g, s, a)))
        gamma_hyper = pm.Data('gamma_hyper', value=0.0)
        ambiguity = (pm.math.sqrt(gamma).sum(2)**2).sum(0) * pi.sum(0)
        pm.Potential('ambiguity_penalty',
                     -ambiguity.sum(0) * gamma_hyper)

        # Product of fraction and genotype
        true_p = pm.Deterministic('true_p', pm.math.dot(pi, gamma))

        # Sequencing error: blend the error-free allele probabilities with a
        # uniform 1/a distribution, weighted by epsilon (one value per n).
        epsilon_hyper = pm.Data('epsilon_hyper', value=100)
        epsilon = pm.Beta('epsilon', alpha=2, beta=epsilon_hyper,
                          shape=n)
        epsilon_ = epsilon.reshape((n, 1, 1))
        err_base_prob = tt.ones((n, g, a)) / a
        without_error = true_p * (1 - epsilon_)
        only_error = err_base_prob * epsilon_
        p_with_error = without_error + only_error

        # Observation
        _p = p_with_error.reshape((-1, a))[:, 0]

        # Overdispersion term
        # TODO: Figure out how to also fit this term, e.g.
        #  alpha = pm.Gamma('alpha', mu=100, sigma=5).
        # FIXME: Do I want the default to be a valid value?
        #  Realistic or close to asymptotic?
        alpha = pm.Data('alpha', value=1000)

        observed = pm.Data('observed', value=np.empty((g * n, a)))
        totals = observed.reshape((-1, a)).sum(1)
        pm.BetaBinomial('data',
                        alpha=_p * alpha,
                        beta=(1 - _p) * alpha,
                        n=totals,
                        observed=observed[:, 0])

    return model

In [None]:
from lib.genotype_mixture import simulate_pileup, pileup_to_model_input
from scripts.infer_strain_fractions import find_MAP_loop
import numpy as np
import scipy as sp
import pymc3 as pm

# `import scipy as sp` alone does not guarantee that the `stats` submodule is
# loaded; import it explicitly so `sp.stats` never depends on another package
# having imported it first.
import scipy.stats

# Draw the simulated fraction matrix: `n` rows, each a Dirichlet sample over
# two components whose mean is (dom_frac, 1 - dom_frac) and whose total
# concentration is `frac_conc`.
n = 10
dom_frac = 0.98
avg_frac = np.array([dom_frac, 1 - dom_frac])
frac_conc = 10
frac = sp.stats.dirichlet.rvs(alpha=avg_frac * frac_conc, size=n)

In [None]:
# Plot the simulated fraction matrix (pwr=1/2 is forwarded to pi_plot), then
# leave `frac` as the cell's last expression so the notebook echoes it.
pi_plot(frac, pwr=1/2)
frac

In [None]:
# Simulation sizes.
m = 30  # Constant used to fill the (g, n) array handed to simulate_pileup
a = 2
g0 = 5  # Noisy positions
g1 = 50
g = g1 + g0
error_rate = 0.01

# First-allele value per position (rows) for two strains (columns): the first
# g1 positions are 1 vs. 0, the last g0 positions are 0.5 for both.
haplotype = np.column_stack([
    np.concatenate([np.ones(g1), np.full(g0, 0.5)]),
    np.concatenate([np.zeros(g1), np.full(g0, 0.5)]),
])
# Append the complementary allele as a trailing axis -> shape (g, 2, 2).
haplotype = np.stack([haplotype, 1 - haplotype], axis=-1)

depth = np.full((g, n), float(m))
pileup = simulate_pileup(haplotype, frac, depth, error_rate)

# Visualize
gamma_plot(haplotype)

# Normalise the model-input counts over their last axis before plotting.
y = pileup_to_model_input(pileup).swapaxes(0, 1)
y = y / y.sum(axis=2, keepdims=True)
gamma_plot(y)

In [None]:
# Build the discrete-genotype model with three columns in `pi` and load the
# simulated counts, flattened to two columns.
model_discrete = build_biallelic_model_discrete(g, n, 3)
_counts = pileup_to_model_input(pileup).reshape((-1, a))
model_discrete.observed.set_value(_counts)
#model_discrete.alpha.set_value(1e5)

# Hyperparameter settings for this run.
for _hyper, _setting in (
    ('gamma_hyper', 1),
    ('pi_hyper', 0),
    ('rho_hyper', 1),
    ('epsilon_hyper', 200),
):
    getattr(model_discrete, _hyper).set_value(_setting)


with model_discrete:
    # Only the binary genotype variable gets an explicit step method.
    # Previously tried: tune=2500, draws=1000, discard_tuned_samples=False,
    # and an explicit NUTS step with max_treedepth=6.
    _gibbs_step = pm.step_methods.BinaryGibbsMetropolis(
        vars=[model_discrete.gamma_],
        transit_p=0.5,
    )
    trace_discrete = pm.sample(chains=1, step=[_gibbs_step])

In [None]:
# Build the fuzzy-genotype model with three columns in `pi` and load the
# simulated counts, flattened to two columns.
model_fuzzy = build_biallelic_model_fuzzy(g, n, 3)
_counts = pileup_to_model_input(pileup).reshape((-1, a))
model_fuzzy.observed.set_value(_counts)
#model_fuzzy.alpha.set_value(1e5)

# Hyperparameter settings for this run.
for _hyper, _setting in (
    ('gamma_hyper', 1),
    ('pi_hyper', 0),
    ('rho_hyper', 1),
    ('epsilon_hyper', 200),
):
    getattr(model_fuzzy, _hyper).set_value(_setting)

with model_fuzzy:
    # Default sampler settings, single chain.  Previously tried: tune=2500,
    # draws=1000, discard_tuned_samples=False, max_treedepth=7.
    trace_fuzzy = pm.sample(chains=1)

In [None]:
# Posterior-mean epsilon from each fitted model (mean over the draw axis).
# The fuzzy MAP estimate and the *_dp traces were previously overlaid the
# same way.
for _trace, _label, _color in (
    (trace_discrete, 'discrete', 'green'),
    (trace_fuzzy, 'fuzzy', 'blue'),
):
    plt.plot(_trace.epsilon.mean(0), label=_label, color=_color)

# Dashed reference line at 0.01.
plt.axhline(0.01, lw=1, linestyle='--', color='k')

plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# Plot the 'model_logp' sampler statistic for each trace, then show the
# effective sample size of that statistic as (fuzzy, discrete).
# The *_dp traces were previously included in both steps.
for _trace in (trace_fuzzy, trace_discrete):
    trace_stat_plot(_trace, 'model_logp')

tuple(ess(_trace.model_logp) for _trace in (trace_fuzzy, trace_discrete))

In [None]:
def select_permutation_greedy(data, ref, axis, key=None):
    """Greedily match slices of ``data`` to slices of ``ref`` along ``axis``.

    The slices of ``ref`` are visited in order; each one claims the
    not-yet-claimed slice of ``data`` with the smallest loss (ties go to the
    lowest index).  Slices of ``data`` that nobody claimed are appended in
    ascending order, so the result is always a permutation of
    ``range(data.shape[axis])``.

    Parameters
    ----------
    data, ref : array-like
        Arrays to align.  They must agree on every axis except ``axis``.
    axis : int
        Axis along which slices are matched.
    key : callable, optional
        ``key(ref_slice, data_slice)`` returning a scalar loss.  Defaults to
        the sum of squared absolute differences.

    Returns
    -------
    list of int
        ``perm`` such that ``data`` indexed by ``perm`` along ``axis`` lines
        up with ``ref`` for the first ``min(n_ref, n_data)`` slices.
    """
    if key is None:
        def key(x, y):
            return (np.abs(x - y)**2).sum()

    # Bring the matching axis to the front so slices are simply rows.
    data = np.swapaxes(data, 0, axis)
    ref = np.swapaxes(ref, 0, axis)

    unused = list(range(data.shape[0]))
    perm = []
    for ref_slice in ref:
        if not unused:
            # More reference slices than data slices: stop instead of
            # re-using an index, which would break the permutation.
            break
        # Default to the first unused index so that a loss that never
        # compares below infinity (e.g. NaN) cannot yield a duplicate.
        best_idx = unused[0]
        best_loss = np.inf
        for j in unused:
            loss = key(ref_slice, data[j])
            if loss < best_loss:
                best_idx, best_loss = j, loss
        perm.append(best_idx)
        unused.remove(best_idx)

    perm.extend(unused)
    return perm

In [None]:
# Posterior median of `pi` over the draw axis for each model.
# (The fuzzy MAP estimate and the *_dp traces were previously summarised and
# plotted the same way.)
_expect_fuzzy = np.median(trace_fuzzy['pi'], axis=0)
_expect_discrete = np.median(trace_discrete['pi'], axis=0)

# Column order that best lines each estimate up with the simulated `frac`.
permute_discrete = select_permutation_greedy(_expect_discrete, frac, axis=1)
permute_fuzzy = select_permutation_greedy(_expect_fuzzy, frac, axis=1)

# Truth first, then the discrete and fuzzy estimates with aligned columns.
pi_plot(frac, pwr=0.5)
pi_plot(_expect_discrete[:, permute_discrete], pwr=0.5)
pi_plot(_expect_fuzzy[:, permute_fuzzy], pwr=0.5)



In [None]:
# Posterior mean of `gamma` over the draw axis for each model.  Note that this
# rebinds the `_expect_*` names used for `pi` in the previous cell.
# (The fuzzy MAP estimate and the *_dp traces were previously summarised and
# plotted the same way.)
_expect_discrete = trace_discrete['gamma'].mean(axis=0)
_expect_fuzzy = trace_fuzzy['gamma'].mean(axis=0)

# Truth first, then each estimate with its second axis reordered by the
# permutation found from `pi`.
gamma_plot(haplotype)
gamma_plot(_expect_discrete[:, permute_discrete])
gamma_plot(_expect_fuzzy[:, permute_fuzzy])

In [None]:
# Root-mean-square error of the posterior-mean genotype against the truth.
#
# `permute_*` (from select_permutation_greedy) places the columns matched to
# the simulated strains FIRST and any surplus model columns LAST.  The truth
# therefore has to be padded with zeros AFTER the real haplotypes; padding in
# front would shift every true strain by one column and make index 1 below
# compare the estimate for strain 1 against the truth for strain 0.
_n_extra = trace_fuzzy['gamma'].shape[2] - haplotype.shape[1]
_true = np.concatenate([
    haplotype,
    np.zeros((g, _n_extra, a)),
], axis=1)


def _loss_f(d):
    return np.sqrt((np.square(d)).mean())


# Index [:,1,0] selects, for every position, column 1 and allele 0.
print('discrete', _loss_f((np.mean(trace_discrete.gamma[:,:,permute_discrete], axis=0) - _true)[:,1,0]))
print('fuzzy', _loss_f((np.mean(trace_fuzzy.gamma[:,:,permute_fuzzy], axis=0) - _true)[:,1,0]))
# print('fuzzy_discretized', _loss_f((np.mean(trace_fuzzy.gamma[:,:,permute_fuzzy] > 0.5, axis=0) - _true)[:,1,0]))
# print('fuzzy_mapest', _loss_f((mapest_fuzzy['gamma'][:,permute_fuzzy] - _true)[:,1,0]))
# print('fuzzy_dp', _loss_f((np.mean(trace_fuzzy_dp.gamma[:,:,permute_fuzzy_dp], axis=0) - _true)[:,1,0]))
# print('discrete_dp', _loss_f((np.mean(trace_discrete_dp.gamma[:,:,permute_discrete_dp], axis=0) - _true)[:,1,0]))


In [None]:
# Mean absolute error of the posterior-mean genotype against the truth.
#
# `permute_*` (from select_permutation_greedy) places the columns matched to
# the simulated strains FIRST and any surplus model columns LAST.  The truth
# therefore has to be padded with zeros AFTER the real haplotypes; padding in
# front would shift every true strain by one column and make index 1 below
# compare the estimate for strain 1 against the truth for strain 0.
_n_extra = trace_fuzzy['gamma'].shape[2] - haplotype.shape[1]
_true = np.concatenate([
    haplotype,
    np.zeros((g, _n_extra, a)),
], axis=1)


def _loss_f(d):
    return np.abs(d).mean()


# Index [:,1,0] selects, for every position, column 1 and allele 0.
print('discrete', _loss_f((np.mean(trace_discrete.gamma[:,:,permute_discrete], axis=0) - _true)[:,1,0]))
print('fuzzy', _loss_f((np.mean(trace_fuzzy.gamma[:,:,permute_fuzzy], axis=0) - _true)[:,1,0]))
# print('fuzzy_discretized', _loss_f((np.mean(trace_fuzzy.gamma[:,:,permute_fuzzy] > 0.5, axis=0) - _true)[:,1,0]))
# print('fuzzy_mapest', _loss_f((mapest_fuzzy['gamma'][:,permute_fuzzy] - _true)[:,1,0]))
# print('fuzzy_dp', _loss_f((np.mean(trace_fuzzy_dp.gamma[:,:,permute_fuzzy_dp], axis=0) - _true)[:,1,0]))
# print('discrete_dp', _loss_f((np.mean(trace_discrete_dp.gamma[:,:,permute_discrete_dp], axis=0) - _true)[:,1,0]))


In [None]:
# Posterior quantiles of `pi` for column j of each model, against the truth.
j = 1
_true = frac[:, j]

fig, ax = plt.subplots(figsize=(5, 10))
ax.scatter(_true, range(n), color='k', marker='x', label='true')
# (The fuzzy MAP estimate and the *_dp traces were previously drawn too.)

for pi, name, color, offset in (
    (trace_fuzzy.pi[:, :, permute_fuzzy], 'fuzzy', 'blue', +0.1),
    (trace_discrete.pi[:, :, permute_discrete], 'discrete', 'green', -0.1),
):
    # Nudge each model vertically so the intervals do not overlap.
    _rows = np.arange(n) + offset
    _median = np.quantile(pi, 0.5, axis=0)[:, j]
    ax.scatter(_median, _rows, color=color, marker='^', label=name)
    # RMSE of the posterior median against the truth.
    print(name, np.sqrt(np.mean(np.square(_median - _true))))
    # Nested central intervals: 50%, 90% and 98%, drawn progressively thinner.
    for (_lo, _hi), _width in (
        ((0.25, 0.75), 1),
        ((0.05, 0.95), 0.5),
        ((0.01, 0.99), 0.25),
    ):
        ax.hlines(
            _rows,
            np.quantile(pi, _lo, axis=0)[:, j],
            np.quantile(pi, _hi, axis=0)[:, j],
            lw=_width,
            color=color,
        )
ax.legend(bbox_to_anchor=(1.25, 1))
#ax.set_xscale('log')

In [None]:
# Posterior quantiles of `pi` for column j of each model, against the truth.
j = 0
_true = frac[:, j]

fig, ax = plt.subplots(figsize=(5, 10))
ax.scatter(_true, range(n), color='k', marker='x', label='true')
# (The fuzzy MAP estimate and the *_dp traces were previously drawn too.)

for pi, name, color, offset in (
    (trace_fuzzy.pi[:, :, permute_fuzzy], 'fuzzy', 'blue', +0.1),
    (trace_discrete.pi[:, :, permute_discrete], 'discrete', 'green', -0.1),
):
    # Nudge each model vertically so the intervals do not overlap.
    _rows = np.arange(n) + offset
    _median = np.quantile(pi, 0.5, axis=0)[:, j]
    ax.scatter(_median, _rows, color=color, marker='^', label=name)
    # RMSE of the posterior median against the truth.
    print(name, np.sqrt(np.mean(np.square(_median - _true))))
    # Nested central intervals: 50%, 90% and 98%, drawn progressively thinner.
    for (_lo, _hi), _width in (
        ((0.25, 0.75), 1),
        ((0.05, 0.95), 0.5),
        ((0.01, 0.99), 0.25),
    ):
        ax.hlines(
            _rows,
            np.quantile(pi, _lo, axis=0)[:, j],
            np.quantile(pi, _hi, axis=0)[:, j],
            lw=_width,
            color=color,
        )
ax.legend(bbox_to_anchor=(1.25, 1))
#ax.set_xscale('log')

# Prototype Desman

In [None]:
import desman

desman??

# Prototype Model Comparison

In [None]:
from dataclasses import dataclass


@dataclass
class SimulationReplicate:
    """Record grouping the arrays, model and trace of one simulation run.

    The field meanings below are inferred from how the same names are used in
    the earlier cells of this notebook -- confirm before relying on them.
    """
    # Fraction matrix (cf. `frac` above).
    frac: np.ndarray
    # Haplotype array (cf. `haplotype` above).
    haplo: np.ndarray
    # Simulated pileup (cf. `pileup` above).
    pileup: np.ndarray
    # PyMC3 model fitted to the pileup.
    model: pm.Model
    # Samples drawn from `model`.
    trace: pm.sampling.MultiTrace

In [None]:
# NOTE(review): `dm_pileup` is only assigned in the DESMAN sampler cell further
# down, so this cell raises NameError unless that cell has been run first.
dm_pileup.shape

In [None]:
# This cell (and specifically the last line) results in a segmentation fault.

from desman import HaploSNP_Sampler, Init_NMFT

# Fixed seed so the initialisation and the sampler share one reproducible RNG.
prng = np.random.RandomState(1)

# Build the DESMAN input from the simulated pileup: append an all-zero copy of
# the model-input array along the last axis (doubling that axis), swap the
# first two axes and cast to int.
# NOTE(review): presumably this pads two alleles out to the four base counts
# DESMAN expects and reorders to DESMAN's axis convention -- confirm.
dm_pileup = (
    np.concatenate(
        [pileup_to_model_input(pileup),
         np.zeros_like(pileup_to_model_input(pileup))],
        axis=-1
    )
    .swapaxes(0, 1)
    .astype(int)
)

# Initialise with DESMAN's Init_NMFT, passing 2 as the second argument
# (presumably the number of haplotypes -- confirm).
dm_init = Init_NMFT.Init_NMFT(dm_pileup, 2, prng)
dm_init.factorize()
# NOTE(review): `tau_est` is not used again in this cell; the sampler below
# calls `dm_init.get_tau()` a second time instead.
tau_est = dm_init.get_tau()

# Seed the sampler with the initial tau and gamma, copied as C-ordered arrays.
dm_smplr = HaploSNP_Sampler.HaploSNP_Sampler(dm_pileup, 2, prng)
dm_smplr.tau = np.copy(dm_init.get_tau(), order='C')
dm_smplr.updateTauIndices()
dm_smplr.gamma = np.copy(dm_init.get_gamma(), order='C')
#dm_smplr.eta = np.copy(dm_init.eta, order='C')
# This is the call reported above to segfault.
dm_smplr.update()