In [None]:
import pandas as pd
from lib.util import info, idxwhere
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp

import pyro
import pyro.distributions as dist
import torch
from functools import partial
import arviz as az
from pyro.ops.contract import einsum
import seaborn as sns
from tqdm import tqdm
import xarray as xr

def rss(x, y):
    """Square root of the summed squared differences between `x` and `y`.

    Despite the name, this is the Euclidean (L2) distance, i.e. sqrt(RSS).
    """
    residual = x - y
    sum_sq = np.sum(residual ** 2)
    return np.sqrt(sum_sq)

def binary_entropy(p):
    """Binary (Bernoulli) entropy in bits, applied elementwise.

    Uses the convention 0 * log2(0) == 0, so p == 0 and p == 1 give 0 bits
    instead of NaN. Accepts scalars, numpy arrays and pandas objects; pandas
    inputs keep their index/columns.
    """
    from scipy.special import entr  # entr(x) = -x * ln(x), entr(0) == 0

    q = 1 - p
    # Convert from nats to bits.
    return (entr(p) + entr(q)) / np.log(2)

def plot_loss_history(loss_history):
    """Plot a loss trace relative to its minimum on a log-scaled y axis.

    A dashed grey diagonal from (0, len) to (len, 0) is overlaid as a visual
    reference, and the title records the minimum that was subtracted.

    Args:
        loss_history: 1-D numpy array of per-iteration losses.

    Returns:
        The matplotlib Axes that was drawn on (the current Axes).
    """
    n_iter = len(loss_history)
    offset = loss_history.min()
    ax = plt.gca()
    ax.plot(loss_history - offset)
    ax.plot(
        np.linspace(0, n_iter, num=1000),
        np.linspace(n_iter, 0, num=1000),
        lw=1, linestyle='--', color='grey',
    )
    ax.set_title(f'+{offset}')
    ax.set_yscale('log')
    return ax

def mean_residual_count(expect_frac, obs_count, m):
    """Depth-weighted mean absolute residual between observed and expected fractions.

    Computes sum(|obs_count / m - expect_frac| * m) / sum(m), i.e. the mean
    absolute count residual per unit of depth. Entries whose residual is NaN
    (e.g. zero depth, 0 / 0) contribute nothing.

    Args:
        expect_frac: Expected fractions (e.g. fitted allele frequencies).
        obs_count: Observed counts, broadcast-compatible with `m`.
        m: Total counts (depth).

    Returns:
        The weighted mean residual (reductions use each input's own `.sum()`).
    """
    # Zero-depth entries are expected here, so silence the 0 / 0 warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        frac_obs = obs_count / m
    resid = np.abs(frac_obs - expect_frac)
    # np.where rather than masked item assignment: the latter fails for 0-d
    # results and for xarray DataArrays (no 2-D boolean assignment).
    resid = np.where(np.isnan(resid), 0, resid)
    return (resid * m).sum() / m.sum()

In [None]:
def model(
    s,
    m,
    y=None,
    gamma_hyper=1.,
    pi0=1.,
    rho0=1.,
    epsilon0=0.01,
    alpha0=1000.,
    dtype=torch.float32,
    device='cpu',
):
    """Pyro model of alt-allele counts as a mixture of `s` strain genotypes.

    Sites:
        gamma (s, g): per-strain, per-position genotype,
            Beta(gamma_hyper, gamma_hyper).
        rho_hyper, rho (s,): strain fractions shared by all samples,
            rho ~ symmetric Dirichlet(rho_hyper).
        epsilon_hyper, alpha_hyper, pi_hyper: scalar hyperparameters.
        pi (n, s): per-sample strain fractions, Dirichlet with mean rho and
            total concentration s * pi_hyper.
        alpha, epsilon (n,): per-sample BetaBinomial concentration and noise
            level.
        p_noerr (n, g) = pi @ gamma; p (n, g) mixes p_noerr with its
            complement at rate epsilon / 2.
        y (n, g): BetaBinomial(alpha * p, alpha * (1 - p), total_count=m).

    Args:
        s: Number of strains.
        m: Total depth per (sample, position); array-like of shape (n, g).
        y: Observed alt-allele counts of shape (n, g), or None.
        gamma_hyper: Symmetric Beta concentration of the genotype prior.
        pi0, rho0, alpha0: Gamma shapes (rate 1) of the priors on pi_hyper,
            rho_hyper and alpha_hyper.
        epsilon0: epsilon_hyper ~ Beta(1, 1 / epsilon0).
        dtype, device: torch dtype and device for all tensors.

    Returns:
        The `y` tensor (sampled, or the observed value).
    """
    # Cast inputs onto `dtype`/`device`. torch.as_tensor does not warn (or
    # copy) when an argument is already a matching tensor, and `y` now gets
    # the same dtype/device as the distribution parameters; torch.tensor(y)
    # left it on the CPU, which fails against parameters on a GPU.
    m, gamma_hyper, pi0, rho0, epsilon0, alpha0 = [
        torch.as_tensor(v, dtype=dtype, device=device)
        for v in [m, gamma_hyper, pi0, rho0, epsilon0, alpha0]
    ]
    if y is not None:
        y = torch.as_tensor(y, dtype=dtype, device=device)

    n, g = m.shape

    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            gamma = pyro.sample(
                'gamma', dist.Beta(gamma_hyper, gamma_hyper)
            )
    # gamma.shape == (s, g)

    rho_hyper = pyro.sample('rho_hyper', dist.Gamma(rho0, 1.))
    rho = pyro.sample(
        'rho',
        dist.Dirichlet(torch.ones(s, dtype=dtype, device=device) * rho_hyper),
    )
    # rho.shape == (s,)

    epsilon_hyper = pyro.sample('epsilon_hyper', dist.Beta(1., 1 / epsilon0))
    alpha_hyper = pyro.sample('alpha_hyper', dist.Gamma(alpha0, 1.))
    pi_hyper = pyro.sample('pi_hyper', dist.Gamma(pi0, 1.))

    with pyro.plate('sample', n, dim=-1):
        # Concentration sums to s * pi_hyper (rho is on the simplex), so E[pi] == rho.
        pi = pyro.sample('pi', dist.Dirichlet(rho * s * pi_hyper))
        # Trailing unit dim so alpha/epsilon broadcast across positions.
        alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)
    # pi.shape == (n, s)
    # alpha.shape == epsilon.shape == (n, 1)  (after unsqueeze)

    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    p = pyro.deterministic(
        'p',
        (1 - epsilon / 2) * p_noerr + (epsilon / 2) * (1 - p_noerr),
    )
    # p.shape == (n, g)

    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ),
        obs=y,
    )
    # y.shape == (n, g)
    return y

def conditioned_model(
    model,
    data=None,
    dtype=torch.float32,
    device='cpu',
    **kwargs,
):
    """Fix sites of `model` to given values and bind its keyword arguments.

    Args:
        model: Pyro model callable that accepts `dtype` and `device` kwargs.
        data: Mapping of site name -> value to condition on. Each value is
            converted to a tensor of `dtype` on `device`. None means no
            conditioning.
        dtype, device: Used for the conversion and also passed to `model`.
        **kwargs: Further keyword arguments bound to `model` (e.g. s, m).

    Returns:
        A `functools.partial` wrapping `pyro.condition(model, data=...)`.
    """
    # None default instead of a shared mutable `{}` default argument.
    if data is None:
        data = {}
    # as_tensor avoids the copy-construct warning for values that are
    # already tensors.
    data = {
        k: torch.as_tensor(v, dtype=dtype, device=device)
        for k, v in data.items()
    }
    return partial(
        pyro.condition(model, data=data),
        dtype=dtype,
        device=device,
        **kwargs,
    )

def find_map(
    model,
    lag=10,
    stop_at=1.0,
    max_iter=int(1e5),
    learning_rate=1e-0,
    clip_norm=100.,
):
    """Fit a point (MAP) estimate of `model` by stochastic variational inference.

    Uses an AutoLaplaceApproximation guide (its `laplace_approximation()` is
    never called here) optimized with clipped Adamax and a JIT-compiled
    Trace_ELBO loss.

    Args:
        model: Pyro model callable with no required arguments (e.g. the
            partial returned by `conditioned_model`).
        lag: Window (in steps) for the convergence check; must be >= 1.
        stop_at: Stop once the mean per-step loss decrease over the last
            `lag` steps falls below this value.
        max_iter: Maximum number of SVI steps.
        learning_rate: Adamax learning rate.
        clip_norm: Gradient-norm clipping threshold.

    Returns:
        (mapest, history): `mapest` maps each site returned by
        `pyro.infer.Predictive` to a squeezed numpy array; `history` is a
        numpy array of the per-step loss values returned by `svi.step()`.

    Raises:
        ValueError: if `lag` < 1.
    """
    if lag < 1:
        raise ValueError(f"lag must be >= 1, got {lag}")

    guide = pyro.infer.autoguide.AutoLaplaceApproximation(model)
    svi = pyro.infer.SVI(
        model,
        guide,
        pyro.optim.Adamax(
            optim_args={"lr": learning_rate},
            clip_args={"clip_norm": clip_norm},
        ),
        loss=pyro.infer.JitTrace_ELBO(),
    )

    pyro.clear_param_store()
    pbar = tqdm(range(max_iter), position=0, leave=True)
    history = []
    try:
        for i in pbar:
            elbo = svi.step()
            if np.isnan(elbo):
                # Stop on a NaN loss; the NaN is not recorded in `history`.
                info('Optimization stopped: loss is NaN')
                break
            history.append(elbo)

            # len(history) == i + 1 here, so history[-k - 1] is the loss k
            # steps ago.
            postfix = {'ELBO': history[-1]}
            if i >= 1:
                postfix['delta_1'] = history[-2] - history[-1]
            if i >= lag:
                # Mean per-step decrease over exactly `lag` steps
                # (history[-lag] would only span lag - 1 steps).
                delta_lag = (history[-lag - 1] - history[-1]) / lag
                postfix[f'delta_{lag}'] = delta_lag
            pbar.set_postfix(postfix)

            if i >= lag and delta_lag < stop_at:
                info('Optimization converged')
                break
    except KeyboardInterrupt:
        info('Optimization interrupted')
    finally:
        # Breaking out of the loop does not close the bar by itself.
        pbar.close()

    # Draw one sample from the fitted guide and run the model forward so that
    # deterministic sites (such as p_noerr and p) are collected too.
    # Predictive prepends a num_samples dim (and may add singleton plate
    # dims), hence the squeeze. NOTE: squeeze also drops genuine size-1 model
    # dimensions (e.g. a single sample or strain).
    predictive = pyro.infer.Predictive(model, guide=guide, num_samples=1)
    mapest = {
        k: v.detach().cpu().numpy().squeeze()
        for k, v in predictive().items()
    }
    return mapest, np.array(history)

  
def mean_residual_count(expect_frac, obs_count, m):
    """Depth-weighted mean absolute residual between observed and expected fractions.

    (Redefines the helper of the same name from the first cell; this later
    definition is the one in effect once both cells have run.)

    Computes sum(|obs_count / m - expect_frac| * m) / sum(m), i.e. the mean
    absolute count residual per unit of depth. Entries whose residual is NaN
    (e.g. zero depth, 0 / 0) contribute nothing.

    Args:
        expect_frac: Expected fractions (e.g. fitted allele frequencies).
        obs_count: Observed counts, broadcast-compatible with `m`.
        m: Total counts (depth).

    Returns:
        The weighted mean residual (reductions use each input's own `.sum()`).
    """
    # Zero-depth entries are expected here, so silence the 0 / 0 warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        frac_obs = obs_count / m
    resid = np.abs(frac_obs - expect_frac)
    # np.where rather than masked item assignment: the latter fails for 0-d
    # results and for xarray DataArrays (no 2-D boolean assignment).
    resid = np.where(np.isnan(resid), 0, resid)
    return (resid * m).sum() / m.sum()

In [None]:
# Load the GT-Pro pileup (dimensions presumably library_id, position, read,
# allele as described in the script's help text — TODO confirm), drop
# length-1 dimensions, and sum counts over reads.
# NOTE(review): if 'read' has length 1, .squeeze() removes it and
# .sum('read') raises.
data = xr.open_dataarray('data/core/100022/gtpro.nc').squeeze().sum('read')
data.sizes

In [None]:
# Minor-allele incidence: per position, the fraction of libraries with a
# nonzero count of an allele, minimized over alleles.
minor_allele_incid = (data > 0).mean('library_id').min('allele')

# Positions whose minor allele appears in more than `thresh` of libraries are
# treated as informative.
thresh = 0.01

plt.hist(minor_allele_incid, bins=100)
plt.axvline(thresh, lw=1, linestyle='--', c='k')

informative_positions = idxwhere(minor_allele_incid.to_series() > thresh)

# Number of informative positions, and the fraction of all positions they are.
print(len(informative_positions), (minor_allele_incid.to_series() > thresh).mean())

In [None]:
np.random.seed(1)

# Samples with >5% of informative positions covered
# (a boolean mask over library_id, used directly in `.sel` below).
suff_cvrg_samples = (data.sel(position=informative_positions).sum(['allele']) > 0).mean('position') > 0.05
# Draw npos + npos_out distinct informative positions: the first npos are used
# for fitting; the remaining npos_out form a held-out set (`position_ss_out`
# is not used again in the cells shown here).
npos = 2000
npos_out = 2000
position_ss_ = np.random.choice(
    informative_positions,
    size=npos + npos_out,
    replace=False
    )
position_ss, position_ss_out = position_ss_[:npos], position_ss_[npos:]

In [None]:
# Build m, y matrices from data.
# m: total depth (summed over alleles); y_obs: alt-allele counts.
_data = data.sel(library_id=suff_cvrg_samples, position=position_ss)
m = _data.sum('allele')
# NOTE(review): assumes dims are ordered (library_id, position) — confirm.
n, g = m.shape
y_obs = _data.sel(allele='alt')

# Number of candidate strains.
s = 3000
# Fix alpha, epsilon_hyper, pi_hyper (scaled by 1 / s) and rho_hyper, and
# condition on the observed counts; the remaining sites are fit by find_map.
model_fit = conditioned_model(
    model,
    data=dict(
        alpha=np.ones(n) * 100,
        epsilon_hyper=0.01,
        pi_hyper=1e-1 / s,
        rho_hyper=1e0,
        y=y_obs.values,
    ),
    s=s,
    m=m.values,
    gamma_hyper=1e-2,
    dtype=torch.float32,
    device='cuda',  # requires a CUDA-capable GPU
)

# Uncomment to print per-site tensor shapes for debugging.
# trace = pyro.poutine.trace(model_fit).get_trace()
# trace.compute_log_prob()
# print(trace.format_shapes())

In [None]:
# MAP fit; `history` holds the per-step loss values returned by svi.step().
mapest, history = find_map(model_fit, lag=10, stop_at=10., learning_rate=2e-1, max_iter=int(1e4), clip_norm=100.)

## Script

In [None]:
%%writefile 'scripts/strain_facts.py'
#!/usr/bin/env python3

import sys
import pandas as pd
from lib.util import info, idxwhere
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp

import pyro
import pyro.distributions as dist
import torch
from functools import partial
import arviz as az
from pyro.ops.contract import einsum
import seaborn as sns
from tqdm import tqdm
import xarray as xr

import argparse

def rss(x, y):
    """Square root of the summed squared differences between `x` and `y`.

    Despite the name, this is the Euclidean (L2) distance, i.e. sqrt(RSS).
    """
    residual = x - y
    sum_sq = np.sum(residual ** 2)
    return np.sqrt(sum_sq)

def binary_entropy(p):
    """Binary (Bernoulli) entropy in bits, applied elementwise.

    Uses the convention 0 * log2(0) == 0, so p == 0 and p == 1 give 0 bits
    instead of NaN. Accepts scalars, numpy arrays and pandas objects; pandas
    inputs keep their index/columns.
    """
    from scipy.special import entr  # entr(x) = -x * ln(x), entr(0) == 0

    q = 1 - p
    # Convert from nats to bits.
    return (entr(p) + entr(q)) / np.log(2)

def plot_loss_history(loss_history):
    """Plot a loss trace relative to its minimum on a log-scaled y axis.

    A dashed grey diagonal from (0, len) to (len, 0) is overlaid as a visual
    reference, and the title records the minimum that was subtracted.

    Args:
        loss_history: 1-D numpy array of per-iteration losses.

    Returns:
        The matplotlib Axes that was drawn on (the current Axes).
    """
    n_iter = len(loss_history)
    offset = loss_history.min()
    ax = plt.gca()
    ax.plot(loss_history - offset)
    ax.plot(
        np.linspace(0, n_iter, num=1000),
        np.linspace(n_iter, 0, num=1000),
        lw=1, linestyle='--', color='grey',
    )
    ax.set_title(f'+{offset}')
    ax.set_yscale('log')
    return ax

def mean_residual_count(expect_frac, obs_count, m):
    """Depth-weighted mean absolute residual between observed and expected fractions.

    Computes sum(|obs_count / m - expect_frac| * m) / sum(m), i.e. the mean
    absolute count residual per unit of depth. Entries whose residual is NaN
    (e.g. zero depth, 0 / 0) contribute nothing.

    Args:
        expect_frac: Expected fractions (e.g. fitted allele frequencies).
        obs_count: Observed counts, broadcast-compatible with `m`.
        m: Total counts (depth).

    Returns:
        The weighted mean residual (reductions use each input's own `.sum()`).
    """
    # Zero-depth entries are expected here, so silence the 0 / 0 warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        frac_obs = obs_count / m
    resid = np.abs(frac_obs - expect_frac)
    # np.where rather than masked item assignment: the latter fails for 0-d
    # results and for xarray DataArrays (no 2-D boolean assignment).
    resid = np.where(np.isnan(resid), 0, resid)
    return (resid * m).sum() / m.sum()

def model(
    s,
    m,
    y=None,
    gamma_hyper=1.,
    pi0=1.,
    rho0=1.,
    epsilon0=0.01,
    alpha0=1000.,
    dtype=torch.float32,
    device='cpu',
):
    """Pyro model of alt-allele counts as a mixture of `s` strain genotypes.

    Sites:
        gamma (s, g): per-strain, per-position genotype,
            Beta(gamma_hyper, gamma_hyper).
        rho_hyper, rho (s,): strain fractions shared by all samples,
            rho ~ symmetric Dirichlet(rho_hyper).
        epsilon_hyper, alpha_hyper, pi_hyper: scalar hyperparameters.
        pi (n, s): per-sample strain fractions, Dirichlet with mean rho and
            total concentration s * pi_hyper.
        alpha, epsilon (n,): per-sample BetaBinomial concentration and noise
            level.
        p_noerr (n, g) = pi @ gamma; p (n, g) mixes p_noerr with its
            complement at rate epsilon / 2.
        y (n, g): BetaBinomial(alpha * p, alpha * (1 - p), total_count=m).

    Args:
        s: Number of strains.
        m: Total depth per (sample, position); array-like of shape (n, g).
        y: Observed alt-allele counts of shape (n, g), or None.
        gamma_hyper: Symmetric Beta concentration of the genotype prior.
        pi0, rho0, alpha0: Gamma shapes (rate 1) of the priors on pi_hyper,
            rho_hyper and alpha_hyper.
        epsilon0: epsilon_hyper ~ Beta(1, 1 / epsilon0).
        dtype, device: torch dtype and device for all tensors.

    Returns:
        The `y` tensor (sampled, or the observed value).
    """
    # Cast inputs onto `dtype`/`device`. torch.as_tensor does not warn (or
    # copy) when an argument is already a matching tensor, and `y` now gets
    # the same dtype/device as the distribution parameters; torch.tensor(y)
    # left it on the CPU, which fails against parameters on a GPU.
    m, gamma_hyper, pi0, rho0, epsilon0, alpha0 = [
        torch.as_tensor(v, dtype=dtype, device=device)
        for v in [m, gamma_hyper, pi0, rho0, epsilon0, alpha0]
    ]
    if y is not None:
        y = torch.as_tensor(y, dtype=dtype, device=device)

    n, g = m.shape

    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            gamma = pyro.sample(
                'gamma', dist.Beta(gamma_hyper, gamma_hyper)
            )
    # gamma.shape == (s, g)

    rho_hyper = pyro.sample('rho_hyper', dist.Gamma(rho0, 1.))
    rho = pyro.sample(
        'rho',
        dist.Dirichlet(torch.ones(s, dtype=dtype, device=device) * rho_hyper),
    )
    # rho.shape == (s,)

    epsilon_hyper = pyro.sample('epsilon_hyper', dist.Beta(1., 1 / epsilon0))
    alpha_hyper = pyro.sample('alpha_hyper', dist.Gamma(alpha0, 1.))
    pi_hyper = pyro.sample('pi_hyper', dist.Gamma(pi0, 1.))

    with pyro.plate('sample', n, dim=-1):
        # Concentration sums to s * pi_hyper (rho is on the simplex), so E[pi] == rho.
        pi = pyro.sample('pi', dist.Dirichlet(rho * s * pi_hyper))
        # Trailing unit dim so alpha/epsilon broadcast across positions.
        alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)
    # pi.shape == (n, s)
    # alpha.shape == epsilon.shape == (n, 1)  (after unsqueeze)

    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    p = pyro.deterministic(
        'p',
        (1 - epsilon / 2) * p_noerr + (epsilon / 2) * (1 - p_noerr),
    )
    # p.shape == (n, g)

    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ),
        obs=y,
    )
    # y.shape == (n, g)
    return y

def conditioned_model(
    model,
    data=None,
    dtype=torch.float32,
    device='cpu',
    **kwargs,
):
    """Fix sites of `model` to given values and bind its keyword arguments.

    Args:
        model: Pyro model callable that accepts `dtype` and `device` kwargs.
        data: Mapping of site name -> value to condition on. Each value is
            converted to a tensor of `dtype` on `device`. None means no
            conditioning.
        dtype, device: Used for the conversion and also passed to `model`.
        **kwargs: Further keyword arguments bound to `model` (e.g. s, m).

    Returns:
        A `functools.partial` wrapping `pyro.condition(model, data=...)`.
    """
    # None default instead of a shared mutable `{}` default argument.
    if data is None:
        data = {}
    # as_tensor avoids the copy-construct warning for values that are
    # already tensors.
    data = {
        k: torch.as_tensor(v, dtype=dtype, device=device)
        for k, v in data.items()
    }
    return partial(
        pyro.condition(model, data=data),
        dtype=dtype,
        device=device,
        **kwargs,
    )

def find_map(
    model,
    lag=10,
    stop_at=1.0,
    max_iter=int(1e5),
    learning_rate=1e-0,
    clip_norm=100.,
):
    """Fit a point (MAP) estimate of `model` by stochastic variational inference.

    Uses an AutoLaplaceApproximation guide (its `laplace_approximation()` is
    never called here) optimized with clipped Adamax and a JIT-compiled
    Trace_ELBO loss.

    Args:
        model: Pyro model callable with no required arguments (e.g. the
            partial returned by `conditioned_model`).
        lag: Window (in steps) for the convergence check; must be >= 1.
        stop_at: Stop once the mean per-step loss decrease over the last
            `lag` steps falls below this value.
        max_iter: Maximum number of SVI steps.
        learning_rate: Adamax learning rate.
        clip_norm: Gradient-norm clipping threshold.

    Returns:
        (mapest, history): `mapest` maps each site returned by
        `pyro.infer.Predictive` to a squeezed numpy array; `history` is a
        numpy array of the per-step loss values returned by `svi.step()`.

    Raises:
        ValueError: if `lag` < 1.
    """
    if lag < 1:
        raise ValueError(f"lag must be >= 1, got {lag}")

    guide = pyro.infer.autoguide.AutoLaplaceApproximation(model)
    svi = pyro.infer.SVI(
        model,
        guide,
        pyro.optim.Adamax(
            optim_args={"lr": learning_rate},
            clip_args={"clip_norm": clip_norm},
        ),
        loss=pyro.infer.JitTrace_ELBO(),
    )

    pyro.clear_param_store()
    pbar = tqdm(range(max_iter), position=0, leave=True)
    history = []
    try:
        for i in pbar:
            elbo = svi.step()
            if np.isnan(elbo):
                # Stop on a NaN loss; the NaN is not recorded in `history`.
                info('Optimization stopped: loss is NaN')
                break
            history.append(elbo)

            # len(history) == i + 1 here, so history[-k - 1] is the loss k
            # steps ago.
            postfix = {'ELBO': history[-1]}
            if i >= 1:
                postfix['delta_1'] = history[-2] - history[-1]
            if i >= lag:
                # Mean per-step decrease over exactly `lag` steps
                # (history[-lag] would only span lag - 1 steps).
                delta_lag = (history[-lag - 1] - history[-1]) / lag
                postfix[f'delta_{lag}'] = delta_lag
            pbar.set_postfix(postfix)

            if i >= lag and delta_lag < stop_at:
                info('Optimization converged')
                break
    except KeyboardInterrupt:
        info('Optimization interrupted')
    finally:
        # Breaking out of the loop does not close the bar by itself.
        pbar.close()

    # Draw one sample from the fitted guide and run the model forward so that
    # deterministic sites (such as p_noerr and p) are collected too.
    # Predictive prepends a num_samples dim (and may add singleton plate
    # dims), hence the squeeze. NOTE: squeeze also drops genuine size-1 model
    # dimensions (e.g. a single sample or strain).
    predictive = pyro.infer.Predictive(model, guide=guide, num_samples=1)
    mapest = {
        k: v.detach().cpu().numpy().squeeze()
        for k, v in predictive().items()
    }
    return mapest, np.array(history)

  
def mean_residual_count(expect_frac, obs_count, m):
    """Depth-weighted mean absolute residual between observed and expected fractions.

    (Redefines the helper of the same name from earlier in this script; this
    later definition is the one bound once the module has been executed.)

    Computes sum(|obs_count / m - expect_frac| * m) / sum(m), i.e. the mean
    absolute count residual per unit of depth. Entries whose residual is NaN
    (e.g. zero depth, 0 / 0) contribute nothing.

    Args:
        expect_frac: Expected fractions (e.g. fitted allele frequencies).
        obs_count: Observed counts, broadcast-compatible with `m`.
        m: Total counts (depth).

    Returns:
        The weighted mean residual (reductions use each input's own `.sum()`).
    """
    # Zero-depth entries are expected here, so silence the 0 / 0 warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        frac_obs = obs_count / m
    resid = np.abs(frac_obs - expect_frac)
    # np.where rather than masked item assignment: the latter fails for 0-d
    # results and for xarray DataArrays (no 2-D boolean assignment).
    resid = np.where(np.isnan(resid), 0, resid)
    return (resid * m).sum() / m.sum()


def parse_args(argv):
    """Parse command-line arguments for fitting the strain mixture model.

    Args:
        argv: Argument strings, excluding the program name
            (e.g. ``sys.argv[1:]``).

    Returns:
        argparse.Namespace. When ``--outpath`` is omitted it is set to
        ``<pileup>_strain-facts.nc``.
    """
    p = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Input
    p.add_argument(
        "pileup",
        help="""
Single, fully processed, pileup table in NetCDF format with the following dimensions:
    * library_id
    * position
    * read
    * allele
                        """,
    )

    # Shape of the model
    p.add_argument(
        "--nstrains",
        metavar="INT",
        type=int,
        default=1000,
        help=("Number of candidate strains in the model."),
    )
    p.add_argument(
        "--npos",
        metavar="INT",
        default=2000,
        type=int,
        help=("Number of positions to sample for model fitting."),
    )

    # Data filtering
    p.add_argument(
        "--incid-thresh",
        metavar="FLOAT",
        type=float,
        default=0.02,
        help=("Minimum fraction of samples that must have the minor allele "
              "for the position to be considered 'informative'."),
    )
    p.add_argument(
        "--cvrg-thresh",
        metavar="FLOAT",
        type=float,
        default=0.5,
        help=("Minimum fraction of 'informative' positions with counts "
              "necessary for sample to be included."),
    )

    # Regularization
    p.add_argument(
        "--pi-hyper",
        metavar="FLOAT",
        default=1e-1,
        type=float,
        help=("Heterogeneity regularization parameter (will be scaled by 1 / s)."),
    )
    p.add_argument(
        "--rho-hyper",
        metavar="FLOAT",
        default=1e0,
        type=float,
        help=("Diversity regularization parameter."),
    )
    p.add_argument(
        "--gamma-hyper",
        metavar="FLOAT",
        default=1e-2,
        type=float,
        help=("Ambiguity regularization parameter."),
    )
    p.add_argument(
        "--epsilon-hyper",
        metavar="FLOAT",
        default=0.01,
        type=float,
        help=("Fixed value of epsilon_hyper; each library's error parameter "
              "has a Beta(1, 1 / epsilon_hyper) prior."),
    )
    p.add_argument(
        "--alpha",
        metavar="FLOAT",
        default=100.,
        type=float,
        help=('Concentration parameter of BetaBinomial observation.'),
    )

    # Fitting
    p.add_argument(
        "--random-seed", default=0, type=int,
        help=("Seed for numpy's global RNG (used to subsample positions)."),
    )
    p.add_argument(
        "--max-iter", default=10000, type=int,
        help=("Maximum number of optimization steps."),
    )
    p.add_argument(
        "--lag", default=50, type=int,
        help=("Window, in steps, used for the convergence check."),
    )
    p.add_argument(
        "--stop-at", default=5., type=float,
        help=("Stop once the mean per-step loss decrease over --lag steps "
              "falls below this value."),
    )
    p.add_argument(
        "--learning-rate", default=1e-0, type=float,
        help=("Adamax learning rate."),
    )
    p.add_argument(
        "--clip-norm", default=100., type=float,
        help=("Gradient-norm clipping threshold."),
    )

    # Hardware
    p.add_argument("--device", default='cpu', help=("PyTorch device name."))

    # Output
    p.add_argument(
        "--outpath",
        metavar="PATH",
        help=("Path for genotype fraction output."),
    )

    args = p.parse_args(argv)

    if args.outpath is None:
        args.outpath = args.pileup + '_strain-facts.nc'

    return args

if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    info(args)

    info(f'Setting random seed: {args.random_seed}')
    np.random.seed(args.random_seed)

    info('Loading input data.')
    data = xr.open_dataarray(args.pileup).squeeze()
    info(f'Input data shape: {data.sizes}.')
    # .squeeze() drops *every* length-1 dimension, including 'read' when the
    # pileup has a single read; only reduce over it if it is still present.
    if 'read' in data.dims:
        data = data.sum('read')

    info('Filtering positions.')
    # Per position: fraction of libraries with a nonzero count of an allele,
    # minimized over alleles.
    minor_allele_incid = (data > 0).mean('library_id').min('allele')
    informative_positions = idxwhere(
        minor_allele_incid.to_series() > args.incid_thresh
    )
    npos_available = len(informative_positions)
    info(f'Found {npos_available} informative positions with minor '
         f'allele incidence of >{args.incid_thresh}')
    if npos_available == 0:
        sys.exit('No informative positions found; consider lowering --incid-thresh.')
    npos = min(args.npos, npos_available)
    info(f'Randomly sampling {npos} positions.')
    position_ss = np.random.choice(
        informative_positions,
        size=npos,
        replace=False,
    )

    info('Filtering libraries.')
    suff_cvrg_samples = idxwhere(
        ((data.sel(position=informative_positions).sum(['allele']) > 0)
         .mean('position') > args.cvrg_thresh).to_series()
    )
    nlibs = len(suff_cvrg_samples)
    info(f'Found {nlibs} libraries with >{args.cvrg_thresh:0.1%} '
         f'of informative positions covered.')
    if nlibs == 0:
        sys.exit('No libraries pass the coverage filter; consider lowering --cvrg-thresh.')

    info('Building conditioned model.')
    data_fit = data.sel(library_id=suff_cvrg_samples, position=position_ss)
    m = data_fit.sum('allele')
    n, g = m.shape
    y_obs = data_fit.sel(allele='alt')
    s = args.nstrains
    model_fit = conditioned_model(
        model,
        data=dict(
            alpha=np.ones(n) * args.alpha,
            epsilon_hyper=args.epsilon_hyper,
            pi_hyper=args.pi_hyper / s,
            rho_hyper=args.rho_hyper,
            y=y_obs.values,
        ),
        s=s,
        m=m.values,
        gamma_hyper=args.gamma_hyper,
        dtype=torch.float32,
        device=args.device,
    )

    info('Fitting model.')
    mapest, history = find_map(
        model_fit,
        lag=args.lag,
        stop_at=args.stop_at,
        learning_rate=args.learning_rate,
        max_iter=args.max_iter,
        clip_norm=args.clip_norm,
    )

    # y and m are passed as plain arrays: recent xarray versions raise a
    # TypeError when a variable is built from a (dims, DataArray) tuple.
    result = xr.Dataset(
        {
            'gamma': (['strain', 'position'], mapest['gamma']),
            'rho': (['strain'], mapest['rho']),
            'alpha_hyper': ([], mapest['alpha_hyper']),
            'pi': (['library_id', 'strain'], mapest['pi']),
            'epsilon': (['library_id'], mapest['epsilon']),
            'rho_hyper': ([], mapest['rho_hyper']),
            'epsilon_hyper': ([], mapest['epsilon_hyper']),
            'pi_hyper': ([], mapest['pi_hyper']),
            'alpha': (['library_id'], mapest['alpha']),
            'p_noerr': (['library_id', 'position'], mapest['p_noerr']),
            'p': (['library_id', 'position'], mapest['p']),
            'y': (['library_id', 'position'], y_obs.values),
            'm': (['library_id', 'position'], m.values),
            'elbo_trace': (['iteration'], history),
        },
        coords=dict(
            strain=np.arange(s),
            position=data_fit.position,
            library_id=data_fit.library_id,
        ),
    )

    result.to_netcdf(
        args.outpath,
        encoding=dict(
            gamma=dict(dtype='float32', zlib=True, complevel=6),
            pi=dict(dtype='float32', zlib=True, complevel=6),
            p_noerr=dict(dtype='float32', zlib=True, complevel=6),
            p=dict(dtype='float32', zlib=True, complevel=6),
            y=dict(dtype='uint16', zlib=True, complevel=6),
            m=dict(dtype='uint16', zlib=True, complevel=6),
            elbo_trace=dict(dtype='float32', zlib=True, complevel=6),
        )
    )
    

In [None]:
# Run the CLI version on ./gtpro.nc on the GPU; output goes to the default
# --outpath, gtpro.nc_strain-facts.nc. --stop-at 10000 stops once the mean
# per-step loss decrease over --lag steps falls below 10000.
%run scripts/strain_facts.py --device cuda gtpro.nc --learning-rate 2e-1 --stop-at 10000

In [None]:
# Load the CLI result written by the previous cell (default --outpath).
# `result2` is not used by the later cells shown here.
result2 = xr.load_dataset('gtpro.nc_strain-facts.nc')

In [None]:
# Wrap the in-notebook MAP estimates (from find_map(model_fit, ...), not the
# CLI result) as DataFrames: pi_fit is libraries x strains, gamma_fit is
# strains x positions.
pi_fit = pd.DataFrame(mapest['pi'], index=_data.library_id)
gamma_fit = pd.DataFrame(mapest['gamma'], columns=_data.position)

In [None]:
# Per library: largest strain fraction, sorted descending (1.0 = one strain).
plt.plot(pi_fit.max(1).sort_values(ascending=False).values)
plt.axhline(1.0, c='k', lw=1, linestyle='--')

In [None]:
# Per strain: largest fraction it reaches in any library, sorted descending.
plt.plot(pi_fit.max(0).sort_values(ascending=False).values)
plt.axhline(1.0, c='k', lw=1, linestyle='--')

In [None]:
# Per strain: fraction summed over libraries, and the number of libraries in
# which it exceeds 0.15; each sorted descending.
plt.plot(pi_fit.sum(0).sort_values(ascending=False).values)
plt.plot((pi_fit > 0.15).sum(0).sort_values(ascending=False).values)

In [None]:
# Per-library alpha (fixed to 100 by conditioning when model_fit was built).
plt.hist(mapest['alpha'], bins=100)
None

In [None]:
# Per-library noise parameter epsilon.
plt.hist(mapest['epsilon'], bins=50)
None

In [None]:
# Per strain: mean genotype entropy vs. estimated total coverage (strain
# fraction times each library's mean depth, summed over libraries).
# NOTE(review): pi_fit.T is a DataFrame and m.mean('position') an xarray
# DataArray; confirm this product aligns on library_id as intended.
plt.scatter((pi_fit.T * m.mean('position')).sum(1), binary_entropy(gamma_fit).mean(1), s=1)
plt.ylabel('strain-entropy')
plt.xlabel('estimated-total-coverage')
plt.xlim(-1, 10)

In [None]:
# High-coverage libraries: >50% of informative positions have nonzero depth.
high_cvrg_samples = idxwhere(((data.sel(position=informative_positions).sum(['allele']) > 0).mean('position') > 0.5).to_series())

In [None]:
# Among high-coverage libraries: those dominated by one strain (max strain
# fraction > 0.98) and those with no strain above 0.5.
low_diversity_samples = idxwhere(pi_fit.max(1).loc[high_cvrg_samples].sort_values() > 0.98)
high_diversity_samples = idxwhere(pi_fit.max(1).loc[high_cvrg_samples].sort_values() < 0.5)

len(low_diversity_samples), len(high_diversity_samples)

In [None]:
# Major-allele frequency histograms for the five high-coverage libraries with
# the lowest maximum strain fraction (top panel) and the five with the
# highest (bottom panel).
fig, axs = plt.subplots(2, sharey=True, sharex=True)
axs[1].set_yscale('log')  # sharey=True, so both panels are log-scaled

high_coverage_libraries_sorted_by_max_strain_fraction = (
    pi_fit.max(1).loc[high_cvrg_samples].sort_values().index
)
major_freq_bins = np.linspace(0.5, 0.999, num=11)
panel_libraries = [
    high_coverage_libraries_sorted_by_max_strain_fraction[:5],
    high_coverage_libraries_sorted_by_max_strain_fraction[-5:],
]
for ax, library_ids in zip(axs, panel_libraries):
    for library_id in library_ids:
        d = data.sel(library_id=library_id)
        # Per covered position: frequency of the more common allele.
        d = (d / d.sum('allele')).dropna('position').max('allele')
        ax.hist(d, bins=major_freq_bins, density=False, alpha=0.5, color='black')
    


In [None]:
# Depth-weighted mean absolute residual between observed alt-allele fractions
# and the fitted noise-free mixture (p_noerr). np.asarray accepts both the
# xarray DataArrays that `y_obs` and `m` are after the model-building cell
# above and CPU torch tensors; DataArrays have no `.numpy()` method.
# Zero-depth positions contribute nothing.
y_obs_arr = np.asarray(y_obs, dtype=float)
m_arr = np.asarray(m, dtype=float)
with np.errstate(divide='ignore', invalid='ignore'):
    frac_obs = y_obs_arr / m_arr  # NaN at zero-depth positions

frac_expect = mapest['p_noerr'].squeeze()

print(mean_residual_count(frac_expect, y_obs_arr, m_arr))

#fig = plt.figure(figsize=(10, 10))
#sns.heatmap(frac_obs[:,:], cmap='coolwarm', cbar=False, vmin=0, vmax=1)

In [None]:
# Strains whose fraction never reaches 0.01 in any library.
drop_taxa = (pi_fit.max(0) < 0.01)
drop_taxa.sum()

In [None]:
# Genotypes (positions x strains) of those negligible strains.
sns.heatmap(gamma_fit.loc[drop_taxa].T, vmin=0, vmax=1, cmap='coolwarm')

In [None]:
# Build m, y matrices for the high-coverage libraries over *all* positions.
# `data` was already summed over reads when it was loaded, so only 'allele'
# is reduced here; `.sel` is used because `high_cvrg_samples` holds
# library_id labels, not integer positions.
_data = data.sel(library_id=high_cvrg_samples).astype('float32')
m = torch.tensor(_data.sum('allele').values)
n, g = m.shape
y_obs = torch.tensor(_data.sel(allele='alt').values)


# Build fully-conditioned model: strain fractions are fixed at the MAP
# estimate for these same libraries, so only genotypes (gamma) are refit.
# The strain count must equal the number of columns of the fixed pi.
pi_geno = torch.tensor(pi_fit.loc[high_cvrg_samples].values, dtype=torch.float32)
s = pi_geno.shape[1]
model_geno = partial(
    pyro.condition(
        model,
        data={
#           'alpha_hyper': torch.tensor(300.),
            'alpha': torch.ones(n) * 10.,
            'epsilon_hyper': torch.tensor(0.01),
#           'pi_hyper': torch.tensor(1e-1 / s),
#           'rho_hyper': torch.tensor(1e0),
#           'epsilon': torch.ones(n) * 0.001,
#           'rho': torch.ones(s) / s,
            'pi': pi_geno,
            'y': y_obs,
        }
    ),
    s=s,
    m=m,
    gamma_hyper=torch.tensor(1e-0),
#     pi0=torch.tensor(1e-1),
#    rho0=torch.tensor(1e-1),
#    alpha0=torch.tensor(100.),  # These two params have no effect IF we condition
#    epsilon0=torch.tensor(0.01),  #  on epsilon_hyper and alpha_hyper
)

# trace = pyro.poutine.trace(model_geno).get_trace()
# trace.compute_log_prob()
# print(trace.format_shapes())

In [None]:
# MAP fit of the genotype-refit model (default find_map settings).
mapest_geno, history_geno = find_map(model_geno)

In [None]:
plot_loss_history(history_geno)

In [None]:
# Refit genotypes as a strains x positions DataFrame (positions of `_data`
# from the genotype-refit cell).
gamma_geno = pd.DataFrame(mapest_geno['gamma'], columns=_data.position) 

In [None]:
# Refit genotypes of the strains not flagged by `drop_taxa`.
# NOTE(review): drop_taxa is indexed by pi_fit's strains; this relies on
# gamma_geno having the same strain index.
sns.heatmap(gamma_geno.loc[~drop_taxa].T, vmin=0, vmax=1, cmap='coolwarm')

In [None]:
# Per library: sum over strains of the binary entropy of pi.
# Per strain: mean over positions of the binary entropy of the refit genotype.
sample_h = binary_entropy(pi_fit).sum(1)
strain_h = binary_entropy(gamma_geno).mean(1)

In [None]:
# NOTE(review): pi_fit rows are the libraries of the first fit, while `m` at
# this point (top-to-bottom order) comes from the genotype-refit cell
# (high-coverage libraries, all positions, torch tensor); confirm they align.
plt.scatter((pi_fit.T * m.mean(1)).sum(1), strain_h, s=1)
plt.ylabel('strain-entropy')
plt.xlabel('estimated-total-coverage')
#plt.xlim(-1, 10)

In [None]:
# NOTE(review): m.mean(1) has one entry per library of the genotype-refit
# data, sample_h one per library of pi_fit; lengths may differ.
plt.scatter(m.mean(1), sample_h, s=1)
plt.ylabel('sample-entropy')
plt.xlabel('sample-mean-coverage')
plt.yscale('log')
plt.xscale('log')

In [None]:
plt.hist(sample_h, bins=np.linspace(0, 10, num=50))
None

In [None]:
plt.hist(strain_h, bins=np.linspace(0, 1, num=50))
None

In [None]:
# Build m, y matrices for all libraries and positions. `data` was already
# summed over reads when it was loaded, so only 'allele' is reduced here.
_data = data.astype('float32')
m = torch.tensor(_data.sum('allele').values)
n, g = m.shape
y_obs = torch.tensor(_data.sel(allele='alt').values)


# Build fully-conditioned model: genotypes are fixed at the genotype-refit
# estimate, so only strain fractions are fit, now for every library.
# The strain count must equal the number of rows of the fixed gamma.
gamma_frac = torch.tensor(mapest_geno['gamma'], dtype=torch.float32)
s = gamma_frac.shape[0]
model_frac = partial(
    pyro.condition(
        model,
        data={
#           'alpha_hyper': torch.tensor(300.),
            'alpha': torch.ones(n) * 10.,
            'epsilon_hyper': torch.tensor(0.01),
            'pi_hyper': torch.tensor(1e-1 / s),
            'rho_hyper': torch.tensor(1e0),
#           'epsilon': torch.ones(n) * 0.001,
#           'rho': torch.ones(s) / s,
            'gamma': gamma_frac,
            'y': y_obs,
        }
    ),
    s=s,
    m=m,
#     gamma_hyper=torch.tensor(1e-0),
#     pi0=torch.tensor(1e-1),
#    rho0=torch.tensor(1e-1),
#    alpha0=torch.tensor(100.),  # These two params have no effect IF we condition
#    epsilon0=torch.tensor(0.01),  #  on epsilon_hyper and alpha_hyper
)

# trace = pyro.poutine.trace(model_frac).get_trace()
# trace.compute_log_prob()
# print(trace.format_shapes())

In [None]:
# MAP fit of the strain-fraction model with genotypes fixed (default settings).
mapest_frac, history_frac = find_map(model_frac)