In [None]:
# %load lib/genotype_mixture
#!/usr/bin/env python3


import numpy as np
import pymc3 as pm
from pymc3.distributions.transforms import t_stick_breaking, logodds
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from lib.util import info
from lib.pileup import list_bases, get_pileup_dims
import tqdm
import theano.tensor as tt
import theano.sparse as ts
from itertools import product

from lib.pymc3 import trace_stat_plot
from arviz import ess
from lib.genotype_mixture import gamma_plot, pi_plot

# Stick-breaking transform shared by the Dirichlet `pi` priors in the model builders below.
# NOTE(review): 1e-10 is passed as t_stick_breaking's argument (presumably its eps
# tolerance) — confirm against the installed pymc3 version.
stick_breaking = t_stick_breaking(1e-10)

from lib.genotype_mixture import pileup_to_model_input
from infer_strain_fractions import find_MAP_loop
import numpy as np
import scipy as sp
import pymc3 as pm

import logging
# Handle on pymc3's logger; sample_both_models sets it to ERROR while sampling.
pymc3_logger = logging.getLogger("pymc3")


def build_2strain_p(minor_frac):
    """Expected strain fractions for a two-strain mixture, ordered [major, minor]."""
    major_frac = 1 - minor_frac
    return np.array([major_frac, minor_frac])

def simulate_frac(n, p, alpha, random_state=None):
    """Draw `n` strain-fraction vectors from a Dirichlet centred on `p`.

    The Dirichlet parameter is ``p * alpha``, so `alpha` scales how tightly
    the draws concentrate around `p`. `random_state` is forwarded to
    ``scipy.stats.dirichlet.rvs``.
    """
    concentration = p * alpha
    return sp.stats.dirichlet.rvs(
        alpha=concentration, size=n, random_state=random_state)

def simulate_pileup(genotype, frac, cvrg, err_rate, random_state=None):
    """Simulate allele-count pileups for a mixture of strains.

    Modified from lib.genotype_mixture.simulate_pileup to include a
    random_state option.

    Parameters
    ----------
    genotype : ndarray, shape (g, s, a)
        Per-position, per-strain allele weights.
    frac : ndarray, shape (n, s)
        Per-sample strain fractions.
    cvrg : ndarray, shape (g, n)
        Number of reads drawn at each position in each sample.
    err_rate : float
        Probability that a read reports a different allele; the error mass
        is split evenly over the other ``a - 1`` alleles.
    random_state : optional
        Forwarded to ``scipy.stats.multinomial.rvs``.

    Returns
    -------
    The result of ``pileup_to_model_input`` applied to a DataFrame indexed by
    'position' with ('sample_id', 'base') column levels.

    Raises
    ------
    AssertionError
        If the input shapes disagree or there are fewer than two alleles.
    """
    g, s, a = genotype.shape
    n, s_frac = frac.shape
    assert s == s_frac, f"frac has {s_frac} strains but genotype has {s}"
    # Check cvrg against genotype/frac instead of silently re-binding `g`
    # from cvrg.shape (which only failed later, on an unrelated assertion).
    assert cvrg.shape == (g, n), (
        f"cvrg has shape {cvrg.shape}, expected {(g, n)} "
        "(positions from genotype, samples from frac)")
    # The error model below divides the error mass by (a - 1).
    assert a >= 2, f"need at least two alleles, got {a}"

    # (n, s) @ (g, s, a) broadcasts over positions -> (g, n, a).
    frac_allele = frac @ genotype
    assert frac_allele.shape == (g, n, a)
    prob_allele = (frac_allele * (1 - err_rate)
                   + (1 - frac_allele) * (err_rate / (a - 1)))
    assert np.allclose(prob_allele.sum(2), 1)

    tally = np.empty_like(prob_allele)
    for i in range(g):
        for j in range(n):
            tally[i, j, :] = sp.stats.multinomial.rvs(
                cvrg[i, j],
                prob_allele[i, j, :],
                random_state=random_state,
            )
    assert (tally.sum(2) == cvrg).all()

    # One Series per sample indexed by (position, allele); concatenate and
    # pivot to rows=position, columns=(sample_id, base).
    pileup = [pd.DataFrame(tally[:, i, :]).stack() for i in range(n)]
    pileup = (pd.concat(pileup, axis=1)
              .unstack(1)
              .rename_axis(columns=('sample_id', 'base'), index='position'))
    return pileup_to_model_input(pileup)

def build_2strain_semi_informative_haplo(g1, g0):
    """Two-strain haplotypes: `g1` informative positions followed by `g0` ambiguous ones.

    Returns an array of shape (g1 + g0, 2, 2) indexed [position, strain, allele].
    At informative positions strain 0 carries allele 0 and strain 1 carries
    allele 1; at the trailing `g0` positions both strains are 0.5/0.5.
    """
    rows = [[1, 0]] * g1 + [[0.5, 0.5]] * g0
    first_allele = np.array(rows).reshape(-1, 2)
    return np.stack((first_allele, 1 - first_allele), axis=-1)

def simulate_simple_pileup(haplo, frac, m, error_rate, random_state=None):
    """Simulate a pileup with uniform coverage `m` at every position of every sample."""
    n_positions, n_samples = haplo.shape[0], frac.shape[0]
    coverage = m * np.ones((n_positions, n_samples))
    return simulate_pileup(haplo, frac, coverage, error_rate,
                           random_state=random_state)

def simulate_simple_2strain_semi_informative(n, minor_frac, frac_conc, g1, g0, cvrg, error_rate, random_state):
    """Simulate a two-strain dataset end to end.

    Draws `n` sample fractions around [1 - minor_frac, minor_frac], builds
    `g1` informative + `g0` ambiguous haplotype positions, and simulates a
    pileup with coverage `cvrg`. The same `random_state` drives both random
    steps, in that order.

    Returns (frac, haplotype, pileup).
    """
    frac = simulate_frac(
        n, build_2strain_p(minor_frac), frac_conc, random_state=random_state)
    haplotype = build_2strain_semi_informative_haplo(g1, g0)
    pileup = simulate_simple_pileup(
        haplotype, frac, cvrg, error_rate, random_state=random_state)
    return frac, haplotype, pileup

def _build_biallelic_model(g, n, s, gamma_prior):
    """Build the biallelic strain-mixture model shared by the fuzzy and discrete variants.

    The public builders differ only in the prior on the flat genotype
    variable 'gamma_'. That prior is supplied as ``gamma_prior(name, shape)``,
    which is called inside the model context. It must create and return a RV
    of shape (g * s, 1) holding the first allele's weight per
    (position, strain).

    Model variables:
      pi      -- Dirichlet(1, ..., 1) strain fractions, shape (n, s).
      gamma   -- Deterministic genotypes [gamma_, 1 - gamma_], shape (g, s, 2).
      true_p  -- Deterministic dot(pi, gamma), shape (n, g, 2).
      epsilon -- Beta(2, epsilon_hyper) per-sample error, shape (n,).
      data    -- Binomial count of the first allele per (sample, position).

    pm.Data placeholders to fill with ``.set_value(...)`` before sampling:
      observed      -- allele counts, shape (g * n, 2). Rows are ordered like
                       ``p_with_error.reshape((-1, 2))``, i.e. an (n, g, 2)
                       array reshaped to (-1, 2). Default: zeros.
      pi_hyper, rho_hyper, gamma_hyper -- penalty weights (default 0.0).
      epsilon_hyper -- beta parameter of the epsilon prior (default 100).
    """
    a = 2  # biallelic: two alleles per position

    with pm.Model() as model:
        # Fraction: one Dirichlet(1, ..., 1) composition per sample.
        pi = pm.Dirichlet('pi', a=np.ones(s), shape=(n, s),
                          transform=stick_breaking,
                          )
        # -pi_hyper * (sum over samples and strains of sqrt(pi))**2
        pi_hyper = pm.Data('pi_hyper', value=0.0)
        pm.Potential('heterogeneity_penalty',
                     -(pm.math.sqrt(pi).sum(0).sum()**2) * pi_hyper)

        # -rho_hyper * (sum over strains of sqrt(strain total over samples))**2
        rho_hyper = pm.Data('rho_hyper', value=0.0)
        pm.Potential('diversity_penalty',
                     -(pm.math.sqrt(pi.sum(0)).sum()**2)
                     * rho_hyper)

        # Genotype: gamma_ is the first allele's weight; the second allele is
        # its complement.
        gamma_ = gamma_prior('gamma_', (g * s, 1))
        gamma = pm.Deterministic('gamma', (pm.math.concatenate([gamma_, 1 - gamma_], axis=1)
                                           .reshape((g, s, a))))
        # -gamma_hyper * sum_s [sum_g (sum_a sqrt(gamma))**2] * (sum_n pi).
        # For 0/1 genotypes sum_a sqrt(gamma) == 1, so this reduces to the
        # constant -gamma_hyper * g * n and does not affect the discrete model.
        gamma_hyper = pm.Data('gamma_hyper', value=0.0)
        pm.Potential('ambiguity_penalty',
                     -((pm.math.sqrt(gamma).sum(2)**2).sum(0) * pi.sum(0)).sum(0)
                     * gamma_hyper)

        # Product of fraction and genotype: shape (n, g, a).
        true_p = pm.Deterministic('true_p', pm.math.dot(pi, gamma))

        # Sequencing error: per-sample epsilon mixes true_p with a uniform
        # distribution over the a alleles.
        epsilon_hyper = pm.Data('epsilon_hyper', value=100)
        epsilon = pm.Beta('epsilon', alpha=2, beta=epsilon_hyper,
                          shape=n)
        epsilon_ = epsilon.reshape((n, 1, 1))
        err_base_prob = tt.ones((n, g, a)) / a
        p_with_error = (true_p * (1 - epsilon_)) + (err_base_prob * epsilon_)

        # Observation. Zeros are a valid (zero-coverage) placeholder, unlike
        # np.empty's arbitrary contents; callers overwrite it via set_value.
        observed = pm.Data('observed', value=np.zeros((g * n, a)))

        # Probability of the first allele for each (sample, position) row.
        _p = p_with_error.reshape((-1, a))[:, 0]

        # Disabled alternative: overdispersed BetaBinomial likelihood.
        # TODO: Figure out how to also fit alpha
        #   (e.g. alpha = pm.Gamma('alpha', mu=100, sigma=5)).
        # FIXME: Do I want the default to be a valid value?
        #   Realistic or close to asymptotic?
        #     alpha = pm.Data('alpha', value=1000)
        #     pm.BetaBinomial('data',
        #                     alpha=_p * alpha,
        #                     beta=(1 - _p) * alpha,
        #                     n=observed.reshape((-1, a)).sum(1),
        #                     observed=observed[:, 0])

        # FIXME: This plain Binomial may not work as well as the highly
        # concentrated BetaBinomial above.
        pm.Binomial('data',
                    p=_p,
                    n=observed.reshape((-1, a)).sum(1),
                    observed=observed[:, 0])

    return model


def build_biallelic_model_fuzzy(g, n, s):
    """Biallelic mixture model with continuous genotypes: gamma_ ~ Uniform(0, 1).

    EXPERIMENTAL. Observations use a plain Binomial likelihood; an
    overdispersed BetaBinomial alternative is kept, disabled, in the shared
    builder. See _build_biallelic_model for variables and data placeholders.
    """
    return _build_biallelic_model(
        g, n, s,
        gamma_prior=lambda name, shape: pm.Uniform(name, 0, 1, shape=shape),
    )


def build_biallelic_model_discrete(g, n, s):
    """Biallelic mixture model with discrete haplotypes: gamma_ ~ Bernoulli(0.5).

    See _build_biallelic_model for variables and data placeholders. The
    'ambiguity_penalty' is constant here, so gamma_hyper has no effect.
    """
    return _build_biallelic_model(
        g, n, s,
        gamma_prior=lambda name, shape: pm.Bernoulli(name, p=0.5, shape=shape),
    )

def select_permutation_greedy(data, ref, axis, key=None):
    """Greedily match the components of `data` to those of `ref` along `axis`.

    Each index i of `ref` along `axis` is processed in order. It takes the
    not-yet-used index j of `data` minimizing ``key(ref[i], data[j])``; on
    ties, or if no loss is below +inf (e.g. all NaN), the first unused index
    wins. Matching stops once every `data` index is used. Any unmatched
    `data` indices are then appended in ascending order.

    The result is therefore always a permutation of
    ``range(data.shape[axis])``. Indexing `data` with it along `axis` lines
    its components up with `ref`.

    Parameters
    ----------
    data, ref : array_like
        Arrays whose slices along `axis` are compared.
    axis : int
        Axis holding the components (e.g. strains) to permute.
    key : callable(ref_slice, data_slice) -> float, optional
        Dissimilarity; defaults to the sum of squared differences.

    Returns
    -------
    list of int
    """
    data = np.swapaxes(data, 0, axis)
    ref = np.swapaxes(ref, 0, axis)

    if key is None:
        def key(x, y):
            return (np.abs(x - y)**2).sum()

    unused = list(range(data.shape[0]))
    perm = []
    for i in range(ref.shape[0]):
        if not unused:
            # More ref components than data components: nothing left to match.
            break
        # Start from the first *unused* index (not 0) so a NaN/inf loss can
        # never produce a duplicate entry.
        best_j = unused[0]
        best_loss = np.inf
        for j in unused:
            loss = key(ref[i], data[j])
            if loss < best_loss:
                best_j, best_loss = j, loss
        perm.append(best_j)
        unused.remove(best_j)

    perm.extend(unused)
    return perm

def sample_both_models(pileup, s, gamma_hyper, pi_hyper, rho_hyper, epsilon_hyper, chains=5, **kwargs):
    """Fit the discrete and fuzzy biallelic models to the same pileup.

    Parameters
    ----------
    pileup : ndarray, shape (n, g, a)
        Allele counts per sample and position.
    s : int
        Number of strains to fit.
    gamma_hyper, pi_hyper, rho_hyper, epsilon_hyper : float
        Written into each model's pm.Data hyperparameters (gamma_hyper does
        not affect the discrete model).
    chains : int
        Number of chains; also used as the number of cores.
    **kwargs
        Forwarded to both pm.sample calls.

    Returns
    -------
    (trace_disc, trace_fuzz)

    The "pymc3" logger is set to ERROR while sampling. It is restored to its
    previous level afterwards, even if sampling raises.
    """
    n, g, a = pileup.shape
    observed = pileup.reshape((-1, a))
    hyper_values = (
        ('gamma_hyper', gamma_hyper),
        ('pi_hyper', pi_hyper),
        ('rho_hyper', rho_hyper),
        ('epsilon_hyper', epsilon_hyper),
    )

    def configure(model):
        # Fill the model's pm.Data placeholders with this run's inputs.
        model.observed.set_value(observed)
        for name, value in hyper_values:
            getattr(model, name).set_value(value)
        return model

    model_disc = configure(build_biallelic_model_discrete(g, n, s))

    previous_level = pymc3_logger.level
    pymc3_logger.setLevel(logging.ERROR)
    try:
        with model_disc:
            trace_disc = pm.sample(
                chains=chains, cores=chains,
                step=[
                    # Gibbs updates for the binary genotypes; pm.sample
                    # assigns step methods to the remaining variables.
                    pm.step_methods.BinaryGibbsMetropolis(
                        vars=[model_disc.gamma_],
                        transit_p=0.75
                    ),
                ],
                **kwargs
            )

        model_fuzz = configure(build_biallelic_model_fuzzy(g, n, s))
        with model_fuzz:
            trace_fuzz = pm.sample(
                chains=chains, cores=chains,
                **kwargs
            )
    finally:
        pymc3_logger.setLevel(previous_level)

    return trace_disc, trace_fuzz

def score_sampled_frac_estimate(trace, true, i):
    """RMSE across samples between the posterior-mean fraction of strain `i` and `true[:, i]`.

    Fitted strains are first aligned to `true` with select_permutation_greedy,
    so `i` refers to the strain index in `true`.
    """
    posterior_mean = np.mean(trace['pi'], axis=0)
    order = select_permutation_greedy(posterior_mean, true, axis=1)
    residual = posterior_mean[:, order[i]] - true[:, i]
    return np.mean(np.square(residual)) ** (1 / 2)

In [None]:
# NOTE(review): uses `trace_disc` from the simulation-sweep cell further down,
# so run that cell first (or move this cell below it).
trace_stat_plot(trace_disc, 'model_logp')
# Chain-0 posterior draws of pi, stacked to shape (draws, n, s).
pi_chain = np.stack([point['pi'] for point in trace_disc.points(chains=[0])])
pi_chain

In [None]:
n = 10  # samples per simulated dataset
frac_conc = 10  # multiplies the expected fractions to form the Dirichlet alpha (simulate_frac)
g1 = 50  # Informative positions
error_rate = 0.01  # per-read error rate used to simulate pileups

# Grid of simulation/fitting settings; product() iterates with s_fit
# outermost and sim_seed innermost.
# NOTE: frac, haplo, pileup, trace_disc and trace_fuzz are left holding the
# LAST setting's values, which the cells below inspect.
results = []
for s_fit, g0, minor_frac, m, sim_seed in product(
    [2, 3],                        # strains fitted by the models
    [0, 1, 5],                     # positions where both strains are 0.5/0.5
    [0.4, 0.2, 0.05, 0.02, 0.01],  # expected minor-strain fraction
    [50, 10, 2],                   # reads per position per sample
    [5],                           # simulation seed
):
    random_state = np.random.RandomState(sim_seed)
    frac, haplo, pileup = simulate_simple_2strain_semi_informative(
        n=n,
        minor_frac=minor_frac,
        frac_conc=frac_conc,
        g1=g1,
        g0=g0,
        cvrg=m,
        error_rate=error_rate,
        random_state=random_state,
    )
    trace_disc, trace_fuzz = sample_both_models(
        pileup,
        s=s_fit,
        gamma_hyper=1,
        pi_hyper=0,
        rho_hyper=1,
        epsilon_hyper=200,
        random_seed=1,
        chains=8,
        progressbar=False,
        compute_convergence_checks=False,
    )
    results.append((
        s_fit,
        g0,
        minor_frac,
        m,
        sim_seed,
        # RMSE of the minor strain's (index 1) posterior-mean fraction.
        score_sampled_frac_estimate(trace_disc, frac, 1),
        score_sampled_frac_estimate(trace_fuzz, frac, 1),
    ))
    print(results[-1])

In [None]:
# Diagnostic: lib.pymc3.trace_stat_plot on the 'model_logp' stat of the most recent discrete-model trace.
trace_stat_plot(trace_disc, 'model_logp')

def get_single_chain_values(trace, chains, var):
    """Stack the per-draw values of `var` from the requested chains.

    The first axis of the result indexes draws, in the order
    ``trace.points(chains=chains)`` yields them.
    """
    draws = [point[var] for point in trace.points(chains=chains)]
    return np.stack(draws)

# Draw-by-draw trace of pi[sample 0, fitted strain 1] for the selected chain(s).
for chain in range(1):
    plt.plot(get_single_chain_values(trace_disc, chains=[chain], var='pi')[:, 0, 1],
             label=f'chain {chain}')
plt.xlabel('draw')
plt.ylabel('pi[sample 0, strain 1]')
plt.legend()

In [None]:
# Raw 'model_logp' sampler stat of the most recent discrete-model trace
# (presumably all chains concatenated, pymc3's default — confirm).
trace_disc.get_sampler_stats('model_logp')

In [None]:
def pick_chain_with_best_terminus(trace):
    """Return the chain whose final draw has the highest 'model_logp' sampler stat.

    Parameters
    ----------
    trace : pymc3 MultiTrace (or anything exposing ``.chains`` and
        ``.get_sampler_stats(name, chains=[...])``).

    Returns
    -------
    The chain id (an element of ``trace.chains``). Ties go to the first
    chain in ``trace.chains``.

    Raises
    ------
    ValueError
        If the trace has no chains.
    """
    terminal_logp = {}
    for chain in trace.chains:
        logp = np.asarray(trace.get_sampler_stats('model_logp', chains=[chain]))
        terminal_logp[chain] = logp.ravel()[-1]
    return max(terminal_logp, key=terminal_logp.get)

In [None]:
import inspect

from infer_strain_fractions import find_MAP_loop_retry

# Show the helper's source (plain-Python equivalent of the IPython
# `find_MAP_loop_retry??` introspection left over from development).
print(inspect.getsource(find_MAP_loop_retry))

In [None]:
# Posterior-mean sequencing error per sample for the discrete and fuzzy
# models, against the simulated error rate (dashed line).
plt.plot(trace_disc.epsilon.mean(0), label='discrete', color='green')
plt.plot(trace_fuzz.epsilon.mean(0), label='fuzzy', color='blue')

plt.axhline(error_rate, lw=1, linestyle='--', color='k', label='simulated')

plt.xlabel('sample')
plt.ylabel('posterior mean epsilon')
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
trace_stat_plot(trace_fuzz, 'model_logp')
trace_stat_plot(trace_disc, 'model_logp')

# ArviZ reads an ndarray as (chain, draw); a 1-D array is treated as ONE
# chain. `trace.model_logp` concatenates all chains, so stack the per-chain
# stats instead to get a proper multi-chain ESS.
(
    ess(np.stack(trace_fuzz.get_sampler_stats('model_logp', combine=False, squeeze=False))),
    ess(np.stack(trace_disc.get_sampler_stats('model_logp', combine=False, squeeze=False))),
)

In [None]:
pi_plot(frac, pwr=1/2)
gamma_plot(haplo)
# Observed allele frequencies, swapped from (n, g, a) to (g, n, a) so positions
# come first as in `haplo`. Positions with zero coverage in a sample are left
# as NaN explicitly rather than via a warning-raising 0/0.
counts = pileup.swapaxes(0, 1)
coverage = counts.sum(2, keepdims=True)
y = np.divide(counts, coverage,
              out=np.full(counts.shape, np.nan), where=coverage > 0)
gamma_plot(y)

In [None]:
# Posterior-median strain fractions (shape (n, s)), with fitted strains
# reordered to best match the simulated `frac`.
pi_median_fuzz = np.median(trace_fuzz['pi'], axis=0)
pi_median_disc = np.median(trace_disc['pi'], axis=0)

permute_disc = select_permutation_greedy(pi_median_disc, frac, axis=1)
permute_fuzz = select_permutation_greedy(pi_median_fuzz, frac, axis=1)

pi_plot(frac, pwr=1/2)
pi_plot(pi_median_disc[:, permute_disc], pwr=1/2)
pi_plot(pi_median_fuzz[:, permute_fuzz], pwr=1/2)

In [None]:
# Posterior-mean genotypes (shape (g, s, a)), strains reordered with the
# permutations computed from the fractions in the previous cell.
gamma_mean_fuzz = trace_fuzz['gamma'].mean(0)
gamma_mean_disc = trace_disc['gamma'].mean(0)

gamma_plot(haplo)
gamma_plot(gamma_mean_disc[:, permute_disc])
gamma_plot(gamma_mean_fuzz[:, permute_fuzz])

In [None]:
j = 1  # strain (column of `frac`) to compare
frac_true_j = frac[:, j]
quantile_levels = [0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99]

fig, ax = plt.subplots(figsize=(5, 10))

ax.scatter(frac_true_j, range(n), color='k', marker='x', label='true')

for pi_samples, name, color, offset in [
    (trace_fuzz.pi[:, :, permute_fuzz], 'fuzz', 'blue', +0.1),
    (trace_disc.pi[:, :, permute_disc], 'disc', 'green', -0.1),
]:
    # Quantiles over draws for strain j, computed once; each has shape (n,).
    q01, q05, q25, q50, q75, q95, q99 = np.quantile(
        pi_samples[:, :, j], quantile_levels, axis=0)
    y_pos = np.arange(n) + offset
    ax.scatter(q50, y_pos, color=color, marker='^', label=name)
    # RMSE of the posterior median against the true fraction.
    print(name, np.sqrt(np.mean(np.square(q50 - frac_true_j))))
    ax.hlines(y_pos, q25, q75, lw=1, color=color)
    ax.hlines(y_pos, q05, q95, lw=0.5, color=color)
    ax.hlines(y_pos, q01, q99, lw=0.25, color=color)
ax.set_xlabel(f'fraction of strain {j}')
ax.set_ylabel('sample')
ax.legend(bbox_to_anchor=(1.25, 1))
# ax.set_xscale('log')