In [None]:
%load_ext autoreload
%load_ext line_profiler
%load_ext memory_profiler

In [None]:
import sfacts as sf
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

In [None]:
# %load ../haplo-benchmark/include/StrainFacts/sfacts/model_zoo/simple_ssdd2_with_error.py
# import sfacts as sf
from sfacts.model_zoo.components import (
    _mapping_subset,
    powerperturb_transformation_unit_interval,
    powerperturb_transformation,
    SHARED_DESCRIPTIONS,
    SHARED_DIMS,
)
import torch
import pyro
import pyro.distributions as dist


@sf.model.structure(
    dims=SHARED_DIMS,
    description=_mapping_subset(
        SHARED_DESCRIPTIONS,
        [
            "rho",
            "p",
            "m",
            "y",
            "epsilon",
            "alpha",
            "genotypes",
            "communities",
            "metagenotypes",
            "mu",
        ],
    ),
    default_hyperparameters={
        "gamma_hyper": 0.01,
        "rho_hyper": 5.0,
        "pi_hyper": 0.2,
        "m_hyper_concentration": 0.01,
        "m_hyper_mean": 2,
        "epsilon_hyper_mode": 0.01,
        "epsilon_hyper_spread": 1.5,
        "alpha_hyper_mean": 200,
        "alpha_hyper_scale": 1,
        "eps": 1e-20,
    },
)
def model0(
    n,
    g,
    s,
    a,
    gamma_hyper,
    rho_hyper,
    pi_hyper,
    m_hyper_concentration,
    m_hyper_mean,
    epsilon_hyper_mode,
    epsilon_hyper_spread,
    alpha_hyper_mean,
    alpha_hyper_scale,
    eps,
    _unit,
):
    """Pyro model built from explicit power-perturbation transforms.

    Sites are registered in this order:

    * ``_gamma`` (Beta) inside the "position" and "strain" plates, transformed
      into the deterministic sites ``gamma`` / ``genotypes``;
    * ``_rho`` (Dirichlet of length ``s``), transformed into ``rho`` /
      ``metacommunity``;
    * inside the "sample" plate: ``_pi`` (Dirichlet of length ``s``) transformed
      into ``pi`` / ``communities``, then ``epsilon`` (Beta) and ``alpha``
      (LogNormal);
    * ``m`` (GammaPoisson expanded to ``[n, g]``);
    * deterministic ``p_noerr`` and ``p``, then ``y`` (BetaBinomial with
      ``total_count=m``), and finally ``metagenotypes`` and ``mu``.

    Args:
        n: Size of the "sample" plate.
        g: Size of the "position" plate.
        s: Size of the "strain" plate and length of the Dirichlet vectors.
        a: Not referenced in this model body.
        gamma_hyper, rho_hyper, pi_hyper: Their reciprocals are passed as the
            second argument of the power-perturbation transforms.
        m_hyper_concentration, m_hyper_mean: Parameterize the GammaPoisson
            prior on ``m`` (rate = concentration / mean).
        epsilon_hyper_mode, epsilon_hyper_spread: Parameterize the Beta prior
            on ``epsilon``.
        alpha_hyper_mean, alpha_hyper_scale: Parameterize the LogNormal prior
            on ``alpha``; ``torch.log`` is applied to ``alpha_hyper_mean``, so
            it must arrive as a tensor (presumably converted by the decorator
            -- TODO confirm).
        eps: Not referenced in this model body.
        _unit: Used as the base concentration of the Beta/Dirichlet priors and
            as the last argument of two transforms (presumably a tensor equal
            to one supplied by ``sf.model.structure`` -- TODO confirm).
    """
    # Genotypes: one value per (strain, position).
    with pyro.plate("position", g, dim=-1), pyro.plate("strain", s, dim=-2):
        _gamma = pyro.sample("_gamma", dist.Beta(_unit, _unit))
        gamma = pyro.deterministic(
            "gamma",
            powerperturb_transformation_unit_interval(_gamma, 1 / gamma_hyper, _unit),
        )
    pyro.deterministic("genotypes", gamma)

    # Meta-community composition
    metacommunity_base = dist.Dirichlet(_unit.repeat(s))
    _rho = pyro.sample("_rho", metacommunity_base)
    rho = pyro.deterministic(
        "rho",
        powerperturb_transformation(_rho, 1 / rho_hyper, _unit),
    )
    pyro.deterministic("metacommunity", rho)

    with pyro.plate("sample", n, dim=-1):
        # Community composition, perturbed by the meta-community `rho`.
        community_base = dist.Dirichlet(_unit.repeat(s))
        _pi = pyro.sample("_pi", community_base)
        pi = pyro.deterministic(
            "pi",
            powerperturb_transformation(_pi, 1 / pi_hyper, rho),
        )

        # Per-sample error rate; the trailing axis lets it broadcast against `p_noerr`.
        epsilon_prior = dist.Beta(
            epsilon_hyper_spread,
            epsilon_hyper_spread / epsilon_hyper_mode,
        )
        epsilon = pyro.sample("epsilon", epsilon_prior).unsqueeze(-1)

        # Per-sample BetaBinomial concentration, same trailing axis.
        alpha_prior = dist.LogNormal(
            loc=torch.log(alpha_hyper_mean),
            scale=alpha_hyper_scale,
        )
        alpha = pyro.sample("alpha", alpha_prior).unsqueeze(-1)
    pyro.deterministic("communities", pi)

    # Total counts for every (sample, position).
    depth_prior = dist.GammaPoisson(
        concentration=m_hyper_concentration,
        rate=m_hyper_concentration / m_hyper_mean,
    )
    m = pyro.sample("m", depth_prior.expand([n, g]).to_event())

    # Expected fractions of each allele at each position
    p_noerr = pyro.deterministic("p_noerr", pi @ gamma)
    half_epsilon = epsilon / 2
    p = pyro.deterministic(
        "p",
        (1 - half_epsilon) * (p_noerr) + half_epsilon * (1 - p_noerr),
    )

    # Observation
    observation = dist.BetaBinomial(
        concentration1=alpha * p,
        concentration0=alpha * (1 - p),
        total_count=m,
    )
    y = pyro.sample("y", observation.to_event())

    # Last axis stacks `y` and its complement `m - y`.
    pyro.deterministic("metagenotypes", torch.stack([y, m - y], dim=-1))
    pyro.deterministic("mu", m.mean(axis=1))

In [None]:
from pyro.distributions import TorchDistribution
from torch.distributions import constraints
import torch


def ssd_loglik(alpha, p, a, x):
    """Log-density of the shifted-scaled Dirichlet distribution at ``x``.

    With ``D`` the size of the last axis, this evaluates::

        lgamma(sum_i alpha_i) - sum_i lgamma(alpha_i)
        - (D - 1) * log(a)
        + sum_i [(alpha_i / a) * (log x_i - log p_i) - log x_i]
        - (sum_i alpha_i) * log(sum_i (x_i / p_i) ** (1 / a))

    The last term is computed in log space with ``torch.logsumexp``.
    Exponentiating ``(x / p) ** (1 / a)`` directly overflows to ``inf`` (or
    underflows to 0) when ``a`` is small, which turned the result into
    ``-inf``/NaN.

    Args:
        alpha: Concentration tensor; reduced over its last axis.
        p: Tensor broadcastable against ``x`` (the perturbation/location).
        a: Positive scale tensor broadcastable against the other arguments
            (the caller in this file passes it with a trailing axis of size 1).
        x: Point at which to evaluate the density; its last axis has length ``D``.

    Returns:
        Tensor of log-densities whose last axis is kept with size 1.
    """
    D = x.shape[-1]
    log_x = torch.log(x)
    # (log x_i - log p_i) / a is the logarithm of (x_i / p_i) ** (1 / a).
    scaled_log_ratio = (log_x - torch.log(p)) / a

    sum_alpha = torch.sum(alpha, dim=-1, keepdim=True)
    # Dirichlet normalising constant.
    termA = torch.lgamma(sum_alpha) - torch.sum(
        torch.lgamma(alpha), dim=-1, keepdim=True
    )
    # Scale term.
    termB = -(D - 1) * torch.log(a)
    # Algebraically equal to sum(-(alpha / a) * log p + (alpha / a - 1) * log x).
    termC_num = torch.sum(alpha * scaled_log_ratio - log_x, dim=-1, keepdim=True)
    # Stable log(sum_i (x_i / p_i) ** (1 / a)).
    termC_den = sum_alpha * torch.logsumexp(scaled_log_ratio, dim=-1, keepdim=True)
    return termA + termB + termC_num - termC_den

class ShiftedScaledDirichlet(TorchDistribution):
    """Distribution obtained by power-perturbing a Dirichlet draw.

    A sample is drawn as ``y ~ Dirichlet(alpha)`` and mapped through
    ``powerperturb_transformation(y, a, p)``; log-densities are delegated to
    ``ssd_loglik``.

    Args:
        alpha: Concentration of the underlying Dirichlet.
        p: Third argument of the power-perturbation transform; broadcast
            against ``alpha``.
        a: Second argument of the power-perturbation transform. Must be a
            tensor (``unsqueeze`` is called on it); it receives a trailing
            axis so that it holds one value per batch element.
        validate_args: Forwarded unchanged to the base-class constructor.
    """

    # The support is borrowed from the Dirichlet distribution.
    support = pyro.distributions.Dirichlet.support
    has_rsample = False
    arg_constraints = {
        "alpha": constraints.positive,
        "p": constraints.unit_interval,
        "a": constraints.positive,
    }

    def __init__(self, alpha, p, a, validate_args=None):
        # Give `a` a trailing axis so it lines up with the last axis of alpha / p.
        alpha_full, p_full, a_full = torch.distributions.utils.broadcast_all(
            alpha, p, a.unsqueeze(dim=-1)
        )
        base = pyro.distributions.Dirichlet(concentration=alpha_full)

        self._dirichlet = base
        self.p = p_full
        # Keep a single entry along the last axis after broadcasting.
        self.a = a_full[..., [0]]

        # Batch and event shapes are taken from the wrapped Dirichlet. The
        # two-argument super() starts the MRO lookup after TorchDistribution.
        super(TorchDistribution, self).__init__(
            base.batch_shape,
            base.event_shape,
            validate_args=validate_args,
        )

    @property
    def alpha(self):
        """Concentration of the wrapped Dirichlet."""
        return self._dirichlet.concentration

    def sample(self, sample_shape=torch.Size()):
        """Draw from the wrapped Dirichlet, then apply the power-perturbation."""
        dirichlet_draw = self._dirichlet.sample(sample_shape)
        transform = sf.model_zoo.components.powerperturb_transformation
        return transform(dirichlet_draw, self.a, self.p)

    def log_prob(self, value):
        """Log-density of ``value``, with the kept last axis of ``ssd_loglik`` removed."""
        log_density = ssd_loglik(self.alpha, self.p, self.a, value)
        return log_density.squeeze(dim=-1)

In [None]:
# Smoke test of ShiftedScaledDirichlet: one concentration vector, a uniform
# `p`, and four different values of `a`.
s = 3
d = ShiftedScaledDirichlet(
    alpha=torch.tensor([0.2, 0.3, 0.5]),
    p=torch.tensor([1.0] * 3) / s,
    a=torch.tensor([[5.0], [1.0], [0.1], [0.2]]),
)
x = d.sample()

# Print the draw followed by a blank line; the log-density is the cell output.
print(x, end="\n\n")
d.log_prob(x)

In [None]:
# %load ../haplo-benchmark/include/StrainFacts/sfacts/model_zoo/simple_ssdd2_with_error.py
# import sfacts as sf
from sfacts.model_zoo.components import (
    _mapping_subset,
    powerperturb_transformation_unit_interval,
    powerperturb_transformation,
    SHARED_DESCRIPTIONS,
    SHARED_DIMS,
)
import torch
import pyro
import pyro.distributions as dist


@sf.model.structure(
    dims=SHARED_DIMS,
    description=_mapping_subset(
        SHARED_DESCRIPTIONS,
        [
            "rho",
            "p",
            "m",
            "y",
            "epsilon",
            "alpha",
            "genotypes",
            "communities",
            "metagenotypes",
            "mu",
        ],
    ),
    default_hyperparameters={
        "gamma_hyper": 0.01,
        "rho_hyper": 5.0,
        "pi_hyper": 0.2,
        "m_hyper_concentration": 0.01,
        "m_hyper_mean": 2,
        "epsilon_hyper_mode": 0.01,
        "epsilon_hyper_spread": 1.5,
        "alpha_hyper_mean": 200,
        "alpha_hyper_scale": 1,
        "eps": 1e-20,
    },
)
def model1(
    n,
    g,
    s,
    a,
    gamma_hyper,
    rho_hyper,
    pi_hyper,
    m_hyper_concentration,
    m_hyper_mean,
    epsilon_hyper_mode,
    epsilon_hyper_spread,
    alpha_hyper_mean,
    alpha_hyper_scale,
    eps,
    _unit,
):
    """Pyro model that samples directly from ``ShiftedScaledDirichlet`` priors.

    Sites are registered in this order:

    * ``_gamma`` (ShiftedScaledDirichlet of length ``a``) inside the "position"
      and "strain" plates; its first component along the last axis is recorded
      as the deterministic site ``genotypes``;
    * ``rho`` (ShiftedScaledDirichlet of length ``s``), also recorded as
      ``metacommunity``;
    * inside the "sample" plate: ``pi`` (ShiftedScaledDirichlet with ``rho`` as
      its ``p`` argument), ``epsilon`` (Beta) and ``alpha`` (LogNormal);
      ``pi`` is also recorded as ``communities``;
    * ``m`` (GammaPoisson expanded to ``[n, g]``);
    * deterministic ``p_noerr`` and ``p``, then ``y`` (BetaBinomial with
      ``total_count=m``), and finally ``metagenotypes`` and ``mu``.

    Args:
        n: Size of the "sample" plate.
        g: Size of the "position" plate.
        s: Size of the "strain" plate and length of the ``rho`` / ``pi`` vectors.
        a: Length of the ``_gamma`` vectors.
        gamma_hyper, rho_hyper, pi_hyper: Their reciprocals are passed as the
            ``a`` argument of the ShiftedScaledDirichlet priors.
        m_hyper_concentration, m_hyper_mean: Parameterize the GammaPoisson
            prior on ``m`` (rate = concentration / mean).
        epsilon_hyper_mode, epsilon_hyper_spread: Parameterize the Beta prior
            on ``epsilon``.
        alpha_hyper_mean, alpha_hyper_scale: Parameterize the LogNormal prior
            on ``alpha``; ``torch.log`` is applied to ``alpha_hyper_mean``, so
            it must arrive as a tensor (presumably converted by the decorator
            -- TODO confirm).
        eps: Not referenced in this model body.
        _unit: Repeated to build the concentration vectors (presumably a
            tensor equal to one supplied by ``sf.model.structure`` -- TODO
            confirm).
    """
    # Genotypes: a length-`a` vector per (strain, position); only the first
    # component is carried forward.
    with pyro.plate("position", g, dim=-1), pyro.plate("strain", s, dim=-2):
        allele_ones = _unit.repeat(a)
        _gamma = pyro.sample(
            "_gamma",
            ShiftedScaledDirichlet(allele_ones, _unit.repeat(a) / a, 1 / gamma_hyper),
        )
        gamma = _gamma[..., 0]
    pyro.deterministic("genotypes", gamma)

    # Meta-community composition
    metacommunity_prior = ShiftedScaledDirichlet(
        _unit.repeat(s), _unit.repeat(s) / s, 1 / rho_hyper
    )
    rho = pyro.sample("rho", metacommunity_prior)
    pyro.deterministic("metacommunity", rho)

    with pyro.plate("sample", n, dim=-1):
        # Community composition, with `rho` as the `p` argument.
        community_prior = ShiftedScaledDirichlet(_unit.repeat(s), rho, 1 / pi_hyper)
        pi = pyro.sample("pi", community_prior)

        # Per-sample error rate; the trailing axis lets it broadcast against `p_noerr`.
        epsilon_prior = dist.Beta(
            epsilon_hyper_spread,
            epsilon_hyper_spread / epsilon_hyper_mode,
        )
        epsilon = pyro.sample("epsilon", epsilon_prior).unsqueeze(-1)

        # Per-sample BetaBinomial concentration, same trailing axis.
        alpha_prior = dist.LogNormal(
            loc=torch.log(alpha_hyper_mean),
            scale=alpha_hyper_scale,
        )
        alpha = pyro.sample("alpha", alpha_prior).unsqueeze(-1)
    pyro.deterministic("communities", pi)

    # Total counts for every (sample, position).
    depth_prior = dist.GammaPoisson(
        concentration=m_hyper_concentration,
        rate=m_hyper_concentration / m_hyper_mean,
    )
    m = pyro.sample("m", depth_prior.expand([n, g]).to_event())

    # Expected fractions of each allele at each position
    p_noerr = pyro.deterministic("p_noerr", pi @ gamma)
    half_epsilon = epsilon / 2
    p = pyro.deterministic(
        "p",
        (1 - half_epsilon) * (p_noerr) + half_epsilon * (1 - p_noerr),
    )

    # Observation
    observation = dist.BetaBinomial(
        concentration1=alpha * p,
        concentration0=alpha * (1 - p),
        total_count=m,
    )
    y = pyro.sample("y", observation.to_event())

    # Last axis stacks `y` and its complement `m - y`.
    pyro.deterministic("metagenotypes", torch.stack([y, m - y], dim=-1))
    pyro.deterministic("mu", m.mean(axis=1))

In [None]:
# Dimensions of the simulated world: positions, samples, strains.
g, n, s = 500, 50, 10

sim1 = sf.model.ParameterizedModel(
    model1,
    coords={
        "sample": range(n),
        "position": range(g),
        "strain": range(s),
        "allele": ["alt", "ref"],
    },
    hyperparameters={
        "gamma_hyper": 1e-5,
        "rho_hyper": 1.0,
        "pi_hyper": 0.1,
        # m_hyper_concentration / m_hyper_mean are not overridden here
        # (previously tried: m_hyper_concentration=1000., m_hyper_mean=0.5).
        "epsilon_hyper_mode": 0.01,
        "epsilon_hyper_spread": 1.5,
        "alpha_hyper_mean": 200,
        "alpha_hyper_scale": 1,
    },
    # Fixed values supplied for the `epsilon`, `alpha` and `m` sites
    # (presumably conditioning the model on them -- confirm against sfacts).
    data={
        "epsilon": np.full(n, 0.01),
        "alpha": np.full(n, 1e6),
        "m": np.ones((n, g)),
    },
).simulate_world()

In [None]:
# Visual overview of the simulated world `sim1`, in display order:
# metagenotypes, community composition, a histogram, then genotypes.
sf.plot.plot_metagenotype(sim1)
sf.plot.plot_community(sim1)
# Histogram of `sim1.communities` reduced with max over the "strain" dimension.
fig, ax = plt.subplots()
ax.hist(sim1.communities.max("strain"))
# No row colors are requested for the genotype plot.
sf.plot.plot_genotype(sim1, row_colors_func=None)

In [None]:
# %%mprun -f sf.workflow.fit_metagenotypes_simple

# Fit `model1` to the simulated metagenotypes with nposition=50.
nposition = 50

fit1, (history1, *_) = sf.workflow.fit_subsampled_metagenotypes_then_collapse_and_iteratively_refit_genotypes(
    model1,
    sim1.metagenotypes,
    nposition=nposition,
    nstrain=20,
    hyperparameters={
        "gamma_hyper": 1e-7,
        # rho_hyper (0.2) and pi_hyper (0.15) are not fixed here; both are
        # listed under anneal_hyperparameters instead.
    },
    stage2_hyperparameters={"gamma_hyper": 1.0},
    anneal_hyperparameters={
        # A gamma_hyper schedule (log, 1e-4 -> 1e-7, wait_steps=1000) is disabled.
        "rho_hyper": {"name": "log", "start": 1.0, "end": 0.2, "wait_steps": 1000},
        "pi_hyper": {"name": "log", "start": 1.0, "end": 0.4, "wait_steps": 1000},
    },
    annealiter=2000,
    estimation_kwargs={
        "jit": True,
        "catch_keyboard_interrupt": True,
        "ignore_jit_warnings": True,
    },
    diss_thresh=0.02,
    frac_thresh=1e-3,
)

In [None]:
# %%mprun -f sf.workflow.fit_metagenotypes_simple

# Fit `model1` to the simulated metagenotypes with nposition=100.
nposition = 100

fit2, (history2, *_) = sf.workflow.fit_subsampled_metagenotypes_then_collapse_and_iteratively_refit_genotypes(
    model1,
    sim1.metagenotypes,
    nposition=nposition,
    nstrain=20,
    hyperparameters={
        "gamma_hyper": 1e-7,
        # rho_hyper (0.2) and pi_hyper (0.15) are not fixed here; both are
        # listed under anneal_hyperparameters instead.
    },
    stage2_hyperparameters={"gamma_hyper": 1.0},
    anneal_hyperparameters={
        # A gamma_hyper schedule (log, 1e-4 -> 1e-7, wait_steps=1000) is disabled.
        "rho_hyper": {"name": "log", "start": 1.0, "end": 0.2, "wait_steps": 1000},
        "pi_hyper": {"name": "log", "start": 1.0, "end": 0.4, "wait_steps": 1000},
    },
    annealiter=2000,
    estimation_kwargs={
        "jit": True,
        "catch_keyboard_interrupt": True,
        "ignore_jit_warnings": True,
    },
    diss_thresh=0.02,
    frac_thresh=1e-3,
)

In [None]:
# %%mprun -f sf.workflow.fit_metagenotypes_simple

# Fit `model1` to the simulated metagenotypes with nposition=200.
nposition = 200

fit3, (history3, *_) = sf.workflow.fit_subsampled_metagenotypes_then_collapse_and_iteratively_refit_genotypes(
    model1,
    sim1.metagenotypes,
    nposition=nposition,
    nstrain=20,
    hyperparameters={
        "gamma_hyper": 1e-7,
        # rho_hyper (0.2) and pi_hyper (0.15) are not fixed here; both are
        # listed under anneal_hyperparameters instead.
    },
    stage2_hyperparameters={"gamma_hyper": 1.0},
    anneal_hyperparameters={
        # A gamma_hyper schedule (log, 1e-4 -> 1e-7, wait_steps=1000) is disabled.
        "rho_hyper": {"name": "log", "start": 1.0, "end": 0.2, "wait_steps": 1000},
        "pi_hyper": {"name": "log", "start": 1.0, "end": 0.4, "wait_steps": 1000},
    },
    annealiter=2000,
    estimation_kwargs={
        "jit": True,
        "catch_keyboard_interrupt": True,
        "ignore_jit_warnings": True,
    },
    diss_thresh=0.02,
    frac_thresh=1e-3,
)

In [None]:
# %%mprun -f sf.workflow.fit_metagenotypes_simple

# Fit `model1` to the simulated metagenotypes with nposition=500.
nposition = 500

fit4, (history4, *_) = sf.workflow.fit_subsampled_metagenotypes_then_collapse_and_iteratively_refit_genotypes(
    model1,
    sim1.metagenotypes,
    nposition=nposition,
    nstrain=20,
    hyperparameters={
        "gamma_hyper": 1e-7,
        # rho_hyper (0.2) and pi_hyper (0.15) are not fixed here; both are
        # listed under anneal_hyperparameters instead.
    },
    stage2_hyperparameters={"gamma_hyper": 1.0},
    anneal_hyperparameters={
        # A gamma_hyper schedule (log, 1e-4 -> 1e-7, wait_steps=1000) is disabled.
        "rho_hyper": {"name": "log", "start": 1.0, "end": 0.2, "wait_steps": 1000},
        "pi_hyper": {"name": "log", "start": 1.0, "end": 0.4, "wait_steps": 1000},
    },
    annealiter=2000,
    estimation_kwargs={
        "jit": True,
        "catch_keyboard_interrupt": True,
        "ignore_jit_warnings": True,
    },
    diss_thresh=0.02,
    frac_thresh=1e-3,
)

In [None]:
# `s` is the simulated reference world; `f` is rebound to each fit by the loop below.
s, f = sim1, fit1

# Reference community plot; columns use the linkage of the simulated metagenotypes.
sf.plot.plot_community(
    s,
    col_linkage_func=lambda w: s.metagenotypes.linkage("sample"),
)

for f in (fit1, fit2, fit3, fit4):
    # One line per fit: first element of each metric's result, space-separated.
    # The two genotype metrics are evaluated in both argument orders.
    print(
        *(
            metric(*worlds)[0]
            for metric, worlds in (
                (sf.evaluation.braycurtis_error, (s, f)),
                (sf.evaluation.unifrac_error, (s, f)),
                (sf.evaluation.unifrac_error2, (s, f)),
                (sf.evaluation.genotype_error, (s, f)),
                (sf.evaluation.discretized_genotype_error, (s, f)),
                (sf.evaluation.genotype_error, (f, s)),
                (sf.evaluation.discretized_genotype_error, (f, s)),
            )
        )
    )
    sf.plot.plot_community(
        f,
        col_linkage_func=lambda w: s.metagenotypes.linkage("sample"),
    )