## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf

In [None]:
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from tqdm import tqdm
import xarray as xr
import warnings
from torch.jit import TracerWarning

In [None]:
# Silence the torch.jit TracerWarning about ``torch.tensor`` results being
# registered as constants in the trace.  The filter is restricted to that one
# message text and to the TracerWarning category.
warnings.filterwarnings(
    "ignore",
    message="torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.",
    category=torch.jit.TracerWarning,
#     module="trace_elbo",  # FIXME: What is the correct regex for module?
#     lineno=5,
)

## Library

In [None]:
!tree sfacts -I __pycache__

### `__init__.py`

In [None]:
%%writefile sfacts/logging_util.py
from datetime import datetime
import sys


def info(*msg, quiet=False):
    """Print a timestamped message to stderr unless *quiet* is true.

    The positional arguments are forwarded to ``print`` after a
    ``[<timestamp>]`` prefix, and the stream is flushed immediately.
    """
    timestamp = datetime.now()
    if quiet:
        return
    print(f"[{timestamp}]", *msg, file=sys.stderr, flush=True)


### pyro_util.py

### data.py

In [None]:
%%writefile sfacts/data.py
from sfacts.logging_util import info
from sfacts.pandas_util import idxwhere
from sfacts.math import binary_entropy
import xarray as xr
import numpy as np
from tqdm import tqdm
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage
import scipy as sp
from functools import partial


def _on_2_simplex(d):
    return (d.min() >= 0) and (d.max() <= 1.0)


def _strictly_positive(d):
    return d.min() > 0


def _positive_counts(d):
    return (d.astype(int) == d).all()


class WrappedDataArrayMixin:
    """Thin wrapper around an ``xr.DataArray`` stored on ``self.data``.

    Subclasses declare ``dims`` (the exact dimension tuple the wrapped array
    must have), ``constraints`` (a mapping of constraint name to a predicate
    that is called on the wrapped array) and ``variable_name``.
    """

    constraints = {}

    # The following are all white-listed and
    # transparently passed through to self.data, but with
    # different semantics for the return value.
    dims = ()
    safe_unwrapped = [
        "shape",
        "sizes",
        "to_pandas",
        "to_dataframe",
        "min",
        "max",
        "sum",
        "mean",
        "median",
        "values",
        "pipe",
        "to_series",
        "isel",
        "sel",
    ]
    # safe_lifted = []
    variable_name = None

    @classmethod
    def from_ndarray(cls, x, coords=None):
        """Wrap array *x*, filling in integer-range coordinates where needed.

        *coords* maps dimension names to coordinate values; an entry whose
        value is ``None`` is replaced by ``range(size)`` for that dimension.
        When *coords* itself is ``None`` every dimension in ``cls.dims`` gets
        such a range.  The caller's mapping is never modified.
        """
        if coords is None:
            coords = dict.fromkeys(cls.dims)
        shapes = {k: x.shape[i] for i, k in enumerate(cls.dims)}
        # Build a fresh dict so the mapping passed in by the caller is left
        # untouched (previously the defaults were written back into it).
        filled_coords = {
            k: (range(shapes[k]) if v is None else v) for k, v in coords.items()
        }
        data = xr.DataArray(
            x,
            dims=cls.dims,
            coords=filled_coords,
        )
        return cls(data)

    @classmethod
    def stack(cls, mapping, dim, prefix=False, validate=True):
        """Concatenate the wrapped arrays in *mapping* along *dim*.

        With ``prefix=True`` the labels along *dim* of each entry are renamed
        to ``"{key}_{label}"`` before concatenation.  Only implemented for
        wrappers with exactly two dimensions.

        Raises:
            NotImplementedError: if ``cls.dims`` does not have length 2.
        """
        if len(cls.dims) != 2:
            raise NotImplementedError(
                "Generic stacking has only been implemented for 2D wrapped DataArrays"
            )
        axis = cls.dims.index(dim)
        data = []
        for k, d in mapping.items():
            if prefix:
                # Round-trip through pandas to relabel the chosen axis.
                d = (
                    d.to_pandas()
                    .rename(lambda s: f"{k}_{s}", axis=axis)
                    .stack()
                    .to_xarray()
                )
            else:
                d = d.data
            data.append(d)
        out = cls(xr.concat(data, dim=dim))
        if validate:
            out.validate_constraints()
        return out

    def __init__(self, data):
        """Store *data* and run the cheap dimension checks."""
        self.data = data
        self.validate_fast()

    def __getattr__(self, name):
        """Forward dimension names and white-listed attributes to ``self.data``.

        Raises:
            AttributeError: for any other name.
        """
        if name in self.dims or name in self.safe_unwrapped:
            return getattr(self.data, name)
        # elif name in self.safe_lifted:
        #     return lambda *args, **kwargs: self.__class__(
        #         getattr(self.data, name)(*args, **kwargs)
        #     )
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}' "
            f"and this name is not found in '{self.__class__.__name__}.dims', "
            f"'{self.__class__.__name__}.safe_unwrapped', "
            f"or '{self.__class__.__name__}.safe_lifted'. "
            f"Consider working with the '{self.__class__.__name__}.data' "
            f"xr.DataArray object directly."
        )

    def validate_fast(self):
        """Assert that the wrapped array has exactly the declared dims."""
        assert len(self.data.shape) == len(self.dims)
        assert self.data.dims == self.dims

    def validate_constraints(self):
        """Run ``validate_fast`` and then assert every declared constraint."""
        self.validate_fast()
        for name in self.constraints:
            assert self.constraints[name](self.data), f"Failed constraint: {name}"

    def lift(self, func, *args, **kwargs):
        """Apply *func* to the wrapped array and re-wrap the result."""
        return self.__class__(func(self.data, *args, **kwargs))

    def mlift(self, name, *args, **kwargs):
        """Call the attribute *name* of this wrapper and re-wrap the result."""
        func = getattr(self, name)
        return self.__class__(func(*args, **kwargs))

    def __repr__(self):
        return f"{self.__class__.__name__}({self.data})"

    @classmethod
    def concat(cls, data, dim):
        """Concatenate a mapping of wrappers along *dim*.

        The labels along *dim* are replaced by ``"{key}_{label}"`` so that the
        source of each entry stays identifiable.
        """
        out_data = []
        new_coords = []
        for name in data:
            d = data[name].data
            out_data.append(d)
            new_coords.extend([f"{name}_{i}" for i in d[dim].values])
        out_data = xr.concat(out_data, dim)
        out_data[dim] = new_coords
        return cls(out_data)

    def to_world(self):
        """Return a ``World`` built from this array converted to a dataset."""
        return World(self.data.to_dataset())

    def random_sample(self, n, dim, replace=False, keep_order=True):
        """Return a wrapper holding *n* randomly chosen entries along *dim*.

        Indices are drawn with ``np.random.choice``; with ``keep_order=True``
        they are sorted before selection.
        """
        dim_n = self.data.sizes[dim]
        ii = np.random.choice(np.arange(dim_n), size=n, replace=replace)
        if keep_order:
            ii = sorted(ii)
        return self.__class__(data=self.data.isel(**{dim: ii}))


class Metagenotypes(WrappedDataArrayMixin):
    """Wrapped allele tallies with dimensions (sample, position, allele)."""

    dims = ("sample", "position", "allele")
    constraints = dict(positive_counts=_positive_counts)
    variable_name = "metagenotypes"

    @staticmethod
    def _complementary_dim(dim):
        """Return the dimension to reduce over when summarising along *dim*.

        Raises:
            ValueError: if *dim* is neither ``"sample"`` nor ``"position"``.
        """
        if dim == "sample":
            return "position"
        if dim == "position":
            return "sample"
        raise ValueError(f"dim must be 'sample' or 'position', got {dim!r}")

    @classmethod
    def load(cls, filename_or_obj, validate=True):
        """Open a data array, rename ``library_id`` to ``sample`` and wrap it.

        Size-one dimensions are squeezed out and the array is named
        ``"metagenotypes"``.  Constraints are checked when *validate* is true.
        """
        data = (
            xr.open_dataarray(filename_or_obj)
            .rename({"library_id": "sample"})
            .squeeze(drop=True)
        )
        data.name = "metagenotypes"
        result = cls(data)
        if validate:
            result.validate_constraints()
        return result

    @classmethod
    def from_counts_and_totals(cls, y, m, coords=None):
        """Build metagenotypes from counts *y* and totals *m*.

        The allele axis is ``[y, m - y]`` and is labelled ``["alt", "ref"]``
        unless *coords* already provides an ``"allele"`` entry.  The caller's
        *coords* mapping is not modified.
        """
        # Copy so the default allele labels are not written into the
        # caller's dict.
        coords = {} if coords is None else dict(coords)
        coords.setdefault("allele", ["alt", "ref"])
        x = np.stack([y, m - y], axis=-1)
        return cls.from_ndarray(x, coords=coords)

    def dump(self, path, validate=True):
        """Write the tallies to a compressed netCDF file at *path*.

        NOTE(review): values are cast to ``np.uint8`` before writing, so
        tallies above 255 will not survive the round trip - confirm this is
        acceptable for the data being stored.
        """
        if validate:
            self.validate_constraints()
        self.data.astype(np.uint8).to_dataset(name="tally").to_netcdf(
            path, encoding=dict(tally=dict(zlib=True, complevel=6))
        )

    def select_variable_positions(self, thresh):
        """Keep positions whose mean per-sample allele ``argmin`` is not extreme.

        For each (sample, position) the index of the minimum-count allele is
        taken; positions are kept when the mean of that index over samples is
        strictly between *thresh* and ``1 - thresh``.
        """
        # TODO: Consider using .lift() to do this.
        variable_positions = (
            self.data.argmin("allele", skipna=False)
            .mean("sample")
            .pipe(lambda x: (x > thresh) & (x < (1 - thresh)))
        )
        return self.mlift("sel", position=variable_positions)

    def select_samples_with_coverage(self, cvrg_thresh):
        """Keep samples whose fraction of covered positions exceeds the threshold.

        A position counts as covered when its total tally is greater than zero.
        """
        # TODO: Consider using .lift() to do this.
        x = self.data
        covered_samples = (x.sum("allele") > 0).mean("position") > cvrg_thresh
        return self.mlift("sel", sample=covered_samples)

    def frequencies(self, pseudo=0.0):
        "Convert metagenotype counts to a frequency with optional pseudocount."
        return (self.data + pseudo) / (
            self.data.sum("allele") + pseudo * self.sizes["allele"]
        )

    def dominant_allele_fraction(self, pseudo=0.0):
        "Return the largest allele frequency at each sample and position."
        return self.frequencies(pseudo=pseudo).max("allele")

    def alt_allele_fraction(self, pseudo=0.0):
        "Return the frequency of the allele labelled 'alt'."
        return self.frequencies(pseudo=pseudo).sel(allele="alt")

    def to_estimated_genotypes(self, pseudo=1.0):
        """Wrap the alt-allele fraction as ``Genotypes``, one 'strain' per sample."""
        return Genotypes(
            self.alt_allele_fraction(pseudo=pseudo).rename({"sample": "strain"})
        )

    def total_counts(self):
        "Return the tallies summed over alleles."
        return self.data.sum("allele")

    def allele_counts(self, allele="alt"):
        "Return the tallies for a single *allele*."
        return self.sel(allele=allele)

    def mean_depth(self, dim="sample"):
        """Return the mean total count for each entry along *dim*.

        Raises:
            ValueError: if *dim* is neither ``"sample"`` nor ``"position"``.
        """
        over = self._complementary_dim(dim)
        return self.total_counts().mean(over)

    def to_counts_and_totals(self, binary_allele="alt"):
        """Return ``dict(y=..., m=...)`` of allele counts and total counts."""
        return dict(
            y=self.allele_counts(allele=binary_allele).values,
            m=self.total_counts().values,
        )

    def pdist(self, dim="sample", pseudo=1.0, **kwargs):
        """Pairwise distances computed on the estimated genotypes.

        ``dim="sample"`` is translated to the ``"strain"`` dimension of the
        estimated genotypes; the resulting frame's axes are named *dim*.
        """
        _dim = "strain" if dim == "sample" else dim
        return (
            self.to_estimated_genotypes(pseudo=pseudo)
            .pdist(dim=_dim, **kwargs)
            .rename_axis(columns=dim, index=dim)
        )

    def cosine_pdist(self, dim="sample"):
        """Pairwise cosine distances between samples as a square DataFrame.

        Raises:
            NotImplementedError: for any *dim* other than ``"sample"``.
        """
        if dim != "sample":
            raise NotImplementedError("Only dim 'sample' has been implemented.")
        d = self.to_dataframe().unstack(dim).T
        return pd.DataFrame(
            squareform(pdist(d.values, metric="cosine")), index=d.index, columns=d.index
        )

    def linkage(self, dim="sample", pseudo=1.0, **kwargs):
        """Hierarchical linkage computed on the estimated genotypes."""
        _dim = "strain" if dim == "sample" else dim
        return self.to_estimated_genotypes(pseudo=pseudo).linkage(dim=_dim, **kwargs)

    def cosine_linkage(
        self,
        dim="sample",
        method="complete",
        optimal_ordering=True,
        **kwargs,
    ):
        """Hierarchical linkage built from ``cosine_pdist``."""
        dmat = self.cosine_pdist(dim=dim)
        cdmat = squareform(dmat)
        return linkage(
            cdmat, method=method, optimal_ordering=optimal_ordering, **kwargs
        )

    def entropy(self, dim="sample"):
        """Sum the binary entropy of the dominant allele fraction along *dim*.

        Raises:
            ValueError: if *dim* is neither ``"sample"`` nor ``"position"``.
        """
        over = self._complementary_dim(dim)
        p = self.dominant_allele_fraction()
        ent = binary_entropy(p)
        return ent.sum(over).rename("entropy")


class Genotypes(WrappedDataArrayMixin):
    """Wrapped (strain, position) array of genotype values in ``[0, 1]``."""

    dims = ("strain", "position")
    constraints = dict(on_2_simplex=_on_2_simplex)
    variable_name = "genotypes"

    def softmask_missing(self, missingness, eps=1e-10):
        """Shrink genotypes according to *missingness*.

        Both arrays are clipped to ``[eps, 1 - eps]``; the logit of the
        genotype is multiplied by the clipped missingness value and mapped
        back with ``expit``.
        """
        clip = partial(np.clip, a_min=eps, a_max=(1 - eps))
        return self.lift(
            lambda g, m: sp.special.expit(sp.special.logit(clip(g)) * clip(m)),
            m=missingness.data,
        )

    def discretized(self):
        "Return the genotypes rounded with ``np.round``."
        return self.lift(np.round)

    def fuzzed(self, eps=1e-5):
        "Return ``(x + eps) / (1 + 2 * eps)``, moving values off the boundaries."
        return self.lift(lambda x: (x + eps) / (1 + 2 * eps))

    # TODO: Move distance metrics to a new module?
    @staticmethod
    def _convert_to_sign_representation(p):
        "Alternative representation of binary genotype on a [-1, 1] interval."
        return p * 2 - 1

    @staticmethod
    def _genotype_sign_representation_dissimilarity(x, y, pseudo=0.0):
        "Dissimilarity between 1D genotypes, accounting for fuzzyness."
        dist = ((x - y) / 2) ** 2
        weight = (x * y) ** 2
        wmean_dist = ((weight * dist).mean()) / ((weight.mean() + pseudo))
        return wmean_dist

    @staticmethod
    def _genotype_dissimilarity(x, y, pseudo=0.0):
        """Dissimilarity between two 1D genotypes given on the ``[0, 1]`` scale.

        Converts both inputs to the sign representation and delegates to
        ``_genotype_sign_representation_dissimilarity``.
        """
        # A staticmethod has no ``self``; go through the class explicitly.
        to_sign = Genotypes._convert_to_sign_representation
        return Genotypes._genotype_sign_representation_dissimilarity(
            to_sign(x), to_sign(y), pseudo=pseudo
        )

    @staticmethod
    def _genotype_dissimilarity_cdmat(unwrapped_values, pseudo=0.0, quiet=True):
        """Condensed matrix of pairwise dissimilarities between rows.

        Rows of *unwrapped_values* are compared pairwise in the order
        ``(0, 1), (0, 2), ..., (s - 2, s - 1)``.
        """
        g_sign = Genotypes._convert_to_sign_representation(unwrapped_values)
        s, _ = g_sign.shape
        cdmat = np.empty((s * (s - 1)) // 2)
        k = 0
        with tqdm(total=len(cdmat), disable=quiet) as pbar:
            for i in range(0, s - 1):
                for j in range(i + 1, s):
                    cdmat[k] = Genotypes._genotype_sign_representation_dissimilarity(
                        g_sign[i], g_sign[j], pseudo=pseudo
                    )
                    k = k + 1
                    pbar.update()
        return cdmat

    def pdist(self, dim="strain", pseudo=0.0, quiet=True):
        """Square DataFrame of pairwise distances along *dim*.

        Strains use the weighted genotype dissimilarity; positions use the
        cosine distance between sign representations.

        Raises:
            ValueError: if *dim* is neither ``"strain"`` nor ``"position"``.
        """
        index = getattr(self, dim)
        if dim == "strain":
            cdmat = self._genotype_dissimilarity_cdmat(
                self.values, quiet=quiet, pseudo=pseudo
            )
        elif dim == "position":
            cdmat = pdist(
                self._convert_to_sign_representation(self.values.T), metric="cosine"
            )
        else:
            raise ValueError(f"dim must be 'strain' or 'position', got {dim!r}")
        # Reboxing
        dmat = pd.DataFrame(squareform(cdmat), index=index, columns=index)
        return dmat

    def linkage(
        self,
        dim="strain",
        pseudo=0.0,
        quiet=True,
        method="complete",
        optimal_ordering=True,
        **kwargs,
    ):
        """Hierarchical linkage built from ``pdist`` along *dim*."""
        dmat = self.pdist(dim=dim, pseudo=pseudo, quiet=quiet)
        cdmat = squareform(dmat)
        return linkage(
            cdmat, method=method, optimal_ordering=optimal_ordering, **kwargs
        )

    def cosine_pdist(self, dim="strain"):
        """Square DataFrame of cosine distances between sign representations.

        Raises:
            ValueError: if *dim* is neither ``"strain"`` nor ``"position"``.
        """
        if dim == "strain":
            d = self.values
            index = self.strain
        elif dim == "position":
            d = self.values.T
            index = self.position
        else:
            raise ValueError(f"dim must be 'strain' or 'position', got {dim!r}")
        d = self._convert_to_sign_representation(d)
        cdmat = pdist(d, metric="cosine")
        return pd.DataFrame(squareform(cdmat), index=index, columns=index)

    def cosine_linkage(
        self, dim="strain", method="complete", optimal_ordering=True, **kwargs
    ):
        """Hierarchical linkage built from ``cosine_pdist`` along *dim*."""
        cdmat = squareform(self.cosine_pdist(dim=dim))
        return linkage(
            cdmat, method=method, optimal_ordering=optimal_ordering, **kwargs
        )

    def entropy(self, dim="strain"):
        """Sum the binary entropy of the genotype values for each entry of *dim*.

        Raises:
            ValueError: if *dim* is neither ``"strain"`` nor ``"position"``.
        """
        if dim == "strain":
            sum_over = "position"
        elif dim == "position":
            sum_over = "strain"
        else:
            raise ValueError(f"dim must be 'strain' or 'position', got {dim!r}")
        p = self.data
        ent = binary_entropy(p)
        return ent.sum(sum_over).rename("entropy")


class Missingness(WrappedDataArrayMixin):
    """Wrapped (strain, position) array validated to lie within ``[0, 1]``."""

    variable_name = "missingness"
    dims = ("strain", "position")
    constraints = {"on_2_simplex": _on_2_simplex}


class Communities(WrappedDataArrayMixin):
    """Wrapped (sample, strain) array of per-sample strain fractions."""

    dims = ("sample", "strain")
    # Compare with a tolerance: sums of floating-point fractions are rarely
    # exactly 1.0 (e.g. after the renormalisation done in ``fuzzed``).
    constraints = dict(
        strains_sum_to_1=lambda d: np.isclose(d.sum("strain"), 1.0).all()
    )
    variable_name = "communities"

    def fuzzed(self, eps=1e-5):
        """Add *eps* to every entry and renormalise each sample to sum to one."""
        new_data = self.data + eps
        new_data = new_data / new_data.sum("strain")
        return self.__class__(new_data)

    def pdist(self, dim="strain", quiet=True):
        """Square DataFrame of pairwise distances along *dim*.

        Strains are compared with the cosine metric and samples with
        Bray-Curtis.  *quiet* is accepted but not used here.

        Raises:
            ValueError: if *dim* is neither ``"strain"`` nor ``"sample"``.
        """
        index = getattr(self, dim)
        if dim == "strain":
            cdmat = pdist(self.values.T, metric="cosine")
        elif dim == "sample":
            cdmat = pdist(self.values, metric="braycurtis")
        else:
            raise ValueError(f"dim must be 'strain' or 'sample', got {dim!r}")
        # Reboxing
        dmat = pd.DataFrame(squareform(cdmat), index=index, columns=index)
        return dmat

    def linkage(
        self,
        dim="strain",
        quiet=True,
        method="average",
        optimal_ordering=True,
        **kwargs,
    ):
        """Hierarchical linkage built from ``pdist`` along *dim*."""
        dmat = self.pdist(dim=dim, quiet=quiet)
        cdmat = squareform(dmat)
        return linkage(
            cdmat, method=method, optimal_ordering=optimal_ordering, **kwargs
        )


class Overdispersion(WrappedDataArrayMixin):
    """Wrapped per-sample array whose values must all be greater than zero."""

    dims = ("sample",)
    # The key is what appears in the "Failed constraint: ..." assertion
    # message, so it is named after the check that is actually applied.
    constraints = dict(strictly_positive=_strictly_positive)
    variable_name = "overdispersion"


class ErrorRate(WrappedDataArrayMixin):
    """Wrapped per-sample array validated to lie within ``[0, 1]``."""

    variable_name = "error_rate"
    dims = ("sample",)
    constraints = {"on_2_simplex": _on_2_simplex}


class World:
    """Wrapper around a dataset (``self.data``) holding several variables.

    Variables registered in ``variables`` are exposed as attributes wrapped
    in their corresponding class; names in ``safe_unwrapped`` are passed
    through untouched and names in ``safe_lifted`` return a new ``World``.
    """

    safe_lifted = ["isel", "sel"]
    safe_unwrapped = ["sizes"]
    dims = ("sample", "position", "strain", "allele")
    variables = [Genotypes, Missingness, Communities, Metagenotypes]
    _variable_wrapper_map = {wrapper.variable_name: wrapper for wrapper in variables}

    def __init__(self, data):
        """Store *data* and run the cheap dimension check."""
        self.data = data  # self._align_dims(data)
        self.validate_fast()

    #     @classmethod
    #     def _align_dims(cls, data):
    #         missing_dims = [k for k in cls.dims if k not in data.dims]
    #         return data.expand_dims(missing_dims).transpose(*cls.dims)

    def validate_fast(self):
        """Assert that the data has no dimensions outside ``self.dims``."""
        assert not (
            set(self.data.dims) - set(self.dims)
        ), f"Found data dims that shouldn't exist: {self.data.dims}"

    def validate_constraints(self):
        """Validate every registered variable that is present in the data."""
        self.validate_fast()
        # The map is a class attribute, so it must be reached through self;
        # a bare name lookup inside a method does not see class scope.
        for variable_name in self._variable_wrapper_map:
            if variable_name in self.data:
                wrapped_variable = getattr(self, variable_name)
                wrapped_variable.validate_constraints()

    def random_sample(self, n, dim, replace=False, keep_order=True):
        """Return a ``World`` holding *n* randomly chosen entries along *dim*.

        Indices are drawn with ``np.random.choice``; with ``keep_order=True``
        they are sorted before selection.
        """
        dim_n = self.data.sizes[dim]
        ii = np.random.choice(np.arange(dim_n), size=n, replace=replace)
        if keep_order:
            ii = sorted(ii)
        return self.__class__(data=self.data.isel(**{dim: ii}))

    @property
    def masked_genotypes(self):
        """Genotypes soft-masked by this world's missingness."""
        return self.genotypes.softmask_missing(self.missingness)

    def __getattr__(self, name):
        """Resolve dims, wrapped variables and white-listed dataset attributes.

        Raises:
            AttributeError: when *name* is not registered anywhere.
        """
        if name in self.dims:
            # Return dims for those registered in self.dims.
            return getattr(self.data, name)
        elif name in self._variable_wrapper_map:
            # Return wrapped variables for those registered in self.variables.
            return self._variable_wrapper_map[name](self.data[name])
        elif name in self.safe_unwrapped:
            # Return a naked version of the variables registered in self.safe_unwrapped
            return getattr(self.data, name)
        elif name in self.safe_lifted:
            # Return a lifted version of the attributes registered in safe_lifted
            return lambda *args, **kwargs: self.__class__(
                getattr(self.data, name)(*args, **kwargs)
            )
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}' "
                f"and this name is not found in '{self.__class__.__name__}.dims', "
                f"'{self.__class__.__name__}.safe_unwrapped', "
                f"or '{self.__class__.__name__}.safe_lifted'. "
                f"Consider working with the '{self.__class__.__name__}.data' object directly."
            )

    @classmethod
    def concat(cls, data, dim, rename_coords=False):
        """Concatenate a mapping of worlds along *dim*.

        Each input gets a ``_concat_from`` variable recording its key.  Only
        variables present in every input are kept.  With
        ``rename_coords=True`` the labels along *dim* become
        ``"{key}_{label}"``; otherwise the original labels are reused.
        """
        new_coords = []
        # Add source metadata and rename concatenation coordinates
        renamed_data = []
        shared_variables = set([str(v) for v in list(data.values())[0].data.variables])
        for name in data:
            d = data[name].data.copy()
            d["_concat_from"] = xr.DataArray(name, dims=(dim,), coords={dim: d[dim]})
            if rename_coords:
                new_coords.extend([f"{name}_{i}" for i in d[dim].values])
            else:
                new_coords.extend(d[dim].values)
            shared_variables &= set([str(v) for v in d.variables])
            renamed_data.append(d)
        # Drop unshared variables
        ready_data = []
        for d in renamed_data:
            ready_data.append(d[list(shared_variables - set(cls.dims))])
        # Concatenate
        out_data = xr.concat(
            ready_data, dim, data_vars="minimal", coords="minimal", compat="override"
        )
        out_data[dim] = new_coords
        return cls(out_data)


def latent_metagenotypes_pdist(world, dim='sample'):
    """Pairwise distances computed on the ``p`` variable of *world*.

    ``world.data.p`` is wrapped as ``Genotypes`` after renaming its
    ``"sample"`` dimension to ``"strain"``; a request for ``dim='sample'`` is
    therefore forwarded as ``'strain'``.
    """
    genotype_dim = 'strain' if dim == 'sample' else dim
    latent = world.data.p.rename({"sample": "strain"})
    return Genotypes(latent).pdist(dim=genotype_dim)


def latent_metagenotypes_linkage(world, dim='sample', method="average", optimal_ordering=True):
    """Hierarchical linkage built from ``latent_metagenotypes_pdist``."""
    dmat = latent_metagenotypes_pdist(world, dim=dim)
    # scipy's linkage expects the condensed form of the distance matrix.
    condensed = squareform(dmat)
    return linkage(condensed, method=method, optimal_ordering=optimal_ordering)

### plot.py

In [None]:
%%writefile sfacts/plot.py
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy.spatial.distance import squareform
import pandas as pd
import numpy as np
from tqdm import tqdm
import sfacts as sf
from functools import partial
from sklearn.manifold import MDS


def _calculate_clustermap_sizes(
    nx,
    ny,
    scalex=0.15,
    scaley=0.02,
    cwidth=0,
    cheight=0,
    dwidth=0.2,
    dheight=1.0,
):
    # TODO: Incorporate colors.
    mwidth = nx * scalex
    mheight = ny * scaley
    fwidth = mwidth + cwidth + dwidth
    fheight = mheight + cheight + dheight
    dendrogram_ratio = (dwidth / fwidth, dheight / fheight)
    colors_ratio = (cwidth / fwidth, cheight / fheight)
    return (fwidth, fheight), dendrogram_ratio, colors_ratio


def _min_max_normalize(x):
    return (x - x.min()) / (x.max() - x.min())


def _scale_to_max_of_one(x):
    return x / x.max()


def dictionary_union(*args):
    """Merge any number of dictionaries, later ones overriding earlier ones.

    The result is a copy of the first argument updated with each subsequent
    one, so none of the inputs is modified.  Calling with no arguments
    returns an empty dict.
    """
    if not args:
        return {}
    first, *rest = args
    out = first.copy()
    for mapping in rest:
        out.update(mapping)
    return out


def plot_generic_clustermap_factory(
    matrix_func,
    col_colors_func=None,
    row_colors_func=None,
    col_linkage_func=None,
    row_linkage_func=None,
    row_col_annotation_cmap=mpl.cm.viridis,
    scalex=0.05,
    scaley=0.05,
    cwidth=0.1,
    cheight=0.1,
    dwidth=1.0,
    dheight=1.0,
    vmin=None,
    vmax=None,
    center=None,
    cmap=None,
    norm=mpl.colors.PowerNorm(1.0),
    xticklabels=0,
    yticklabels=0,
    metric="correlation",
    cbar_pos=None,
    transpose=False,
    isel=None
):
    """Build a function that draws a seaborn clustermap for a world.

    Every argument given here becomes the default of the same-named keyword
    argument of the returned function, so each can still be overridden per
    call.

    The ``*_func`` arguments are callables that receive the world:
    ``matrix_func`` returns the matrix to plot (it must support ``.isel`` and
    then ``.to_pandas``), the linkage functions return precomputed linkages,
    and the colour functions return values used for row/column annotations.

    NOTE(review): the default ``norm`` object is created once, when this
    function is defined, and is reused by every plot that does not override
    it - confirm that sharing one normaliser instance is intended.
    """
    def _plot_func(
        world,
        matrix_func=matrix_func,
        col_linkage_func=col_linkage_func,
        row_linkage_func=row_linkage_func,
        col_colors_func=col_colors_func,
        row_colors_func=row_colors_func,
        row_col_annotation_cmap=row_col_annotation_cmap,
        scalex=scalex,
        scaley=scaley,
        cwidth=cwidth,
        cheight=cheight,
        dwidth=dwidth,
        dheight=dheight,
        vmin=vmin,
        vmax=vmax,
        center=center,
        cmap=cmap,
        norm=norm,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        metric=metric,
        cbar_pos=cbar_pos,
        transpose=transpose,
        isel=isel,
        **kwargs,
    ):
        """Draw the clustermap for *world* and return the seaborn grid.

        Extra keyword arguments are merged last into the arguments passed to
        ``sns.clustermap`` and therefore override the values computed here.
        """
        matrix_data = matrix_func(world)

        # ``isel`` optionally subsets the matrix by integer position before
        # it is converted to pandas; no subsetting by default.
        if isel is None:
            isel={}

        matrix_data = matrix_data.isel(**isel).to_pandas()

        # Transposing swaps every row/column-specific setting along with the
        # matrix, so the remaining code can treat both orientations alike.
        if transpose:
            matrix_data = matrix_data.T
            col_linkage_func, row_linkage_func = (
                row_linkage_func,
                col_linkage_func,
            )
            col_colors_func, row_colors_func = row_colors_func, col_colors_func
            scalex, scaley = scaley, scalex
            cwidth, cheight = cheight, cwidth
            dwidth, dheight = dheight, dwidth
            xticklabels, yticklabels = yticklabels, xticklabels

        # A missing linkage function leaves the linkage as None.
        if col_linkage_func is None:
            col_linkage = None
        else:
            col_linkage = col_linkage_func(world)

        if row_linkage_func is None:
            row_linkage = None
        else:
            row_linkage = row_linkage_func(world)

        # Annotation values are scaled by their maximum and mapped through
        # the annotation colormap; the strip size grows with the number of
        # annotation columns.
        if col_colors_func is None:
            col_colors = None
        else:
            col_colors = (
                col_colors_func(world)
                .pipe(_scale_to_max_of_one)
                .to_dataframe()
                .applymap(row_col_annotation_cmap)
            )
            cwidth *= col_colors.shape[1]

        if row_colors_func is None:
            row_colors = None
        else:
            row_colors = (
                row_colors_func(world)
                .pipe(_scale_to_max_of_one)
                .to_dataframe()
                .applymap(row_col_annotation_cmap)
            )
            cheight *= row_colors.shape[1]

        # Figure size and margin ratios are derived from the matrix shape.
        ny, nx = matrix_data.shape
        figsize, dendrogram_ratio, colors_ratio = _calculate_clustermap_sizes(
            nx,
            ny,
            scalex=scalex,
            scaley=scaley,
            cwidth=cwidth,
            cheight=cheight,
            dwidth=dwidth,
            dheight=dheight,
        )
        #     sf.logging_util.info(matrix_data.shape, applied_scale_kwargs, figsize, dendrogram_ratio, colors_ratio)

        clustermap_kwargs = dict(
            vmin=vmin,
            vmax=vmax,
            center=center,
            norm=norm,
            cmap=cmap,
            xticklabels=xticklabels,
            yticklabels=yticklabels,
            col_linkage=col_linkage,
            row_linkage=row_linkage,
            row_colors=row_colors,
            col_colors=col_colors,
            figsize=figsize,
            dendrogram_ratio=dendrogram_ratio,
            colors_ratio=colors_ratio,
            metric=metric,
            cbar_pos=cbar_pos,
        )
        # Caller-supplied extras take precedence over the computed values.
        clustermap_kwargs.update(kwargs)

        grid = sns.clustermap(matrix_data, **clustermap_kwargs)
        return grid

    return _plot_func


# Clustermap of the alt-allele fraction (pseudocount of 1), transposed before
# plotting.  Rows and columns are ordered by the metagenotype linkage over
# positions and samples respectively, and columns are annotated with the
# square root of each sample's mean total count.
plot_metagenotype = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.metagenotypes.alt_allele_fraction(pseudo=1.0).T,
    row_linkage_func=lambda w: w.metagenotypes.linkage(dim="position"),
    col_linkage_func=lambda w: w.metagenotypes.linkage(dim="sample"),
    scalex=0.15,
    scaley=0.01,
    vmin=0,
    vmax=1,
    center=0.5,
    cmap="coolwarm",
    xticklabels=1,
    yticklabels=0,
    col_colors_func=(
        lambda w: (
            w.metagenotypes.sum("allele")
            .mean("position")
            .pipe(np.sqrt)
            .rename("mean_depth")
        )
    ),
)


# Clustermap of the "p" variable of the world's dataset (transposed), using
# the same metagenotype-based ordering and column annotation as
# plot_metagenotype.
plot_expected_fractions = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.data["p"].T,
    row_linkage_func=lambda w: w.metagenotypes.linkage(dim="position"),
    col_linkage_func=lambda w: w.metagenotypes.linkage(dim="sample"),
    scalex=0.15,
    scaley=0.01,
    vmin=0,
    vmax=1,
    center=0.5,
    cmap="coolwarm",
    xticklabels=1,
    yticklabels=0,
    col_colors_func=(
        lambda w: (
            w.metagenotypes.sum("allele")
            .mean("position")
            .pipe(np.sqrt)
            .rename("mean_depth")
        )
    ),
)

# Clustermap of the difference between the "p" variable and the observed
# alt-allele frequency (no pseudocount); missing differences are filled with
# zero.  The colour scale is centred on zero and spans [-1, 1].
plot_prediction_error = plot_generic_clustermap_factory(
    matrix_func=lambda w: (
        w.data["p"] - w.metagenotypes.frequencies().sel(allele="alt")
    )
    .fillna(0)
    .T,
    row_linkage_func=lambda w: w.metagenotypes.linkage(dim="position"),
    col_linkage_func=lambda w: w.metagenotypes.linkage(dim="sample"),
    scalex=0.15,
    scaley=0.01,
    vmin=-1,
    vmax=1,
    center=0,
    cmap="coolwarm",
    xticklabels=1,
    yticklabels=0,
)

# Clustermap of the dominant allele fraction (pseudocount of 1), transposed.
# Only the column linkage is precomputed; no row linkage function is given,
# and the "cosine" metric is passed through to the clustermap.
plot_dominance = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.metagenotypes.dominant_allele_fraction(pseudo=1.0)
    .T,
    col_linkage_func=lambda w: w.metagenotypes.linkage(dim="sample"),
    metric="cosine",
    scalex=0.15,
    scaley=0.01,
    vmin=0,
    vmax=1,
    xticklabels=1,
    yticklabels=0,
    col_colors_func=(
        lambda w: (
            w.metagenotypes.sum("allele")
            .mean("position")
            .pipe(np.sqrt)
            .rename("mean_depth")
        )
    ),
)

# Clustermap of the tallies summed over alleles (transposed), drawn in gray
# with a symmetric-log colour normalisation.
# NOTE(review): vmin/vmax/center match the fraction plots although this
# matrix holds summed counts - confirm they are intended alongside the
# SymLogNorm.
plot_depth = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.metagenotypes.sum("allele").T,
    row_linkage_func=lambda w: w.metagenotypes.linkage(dim="position"),
    col_linkage_func=lambda w: w.metagenotypes.linkage(dim="sample"),
    scalex=0.15,
    scaley=0.01,
    vmin=0,
    vmax=1,
    center=0.5,
    cmap="gray",
    xticklabels=1,
    yticklabels=0,
    col_colors_func=(
        lambda w: (
            w.metagenotypes.sum("allele")
            .mean("position")
            .pipe(np.sqrt)
            .rename("mean_depth")
        )
    ),
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
)

# Clustermap of the wrapped genotypes, ordered by the genotype linkage over
# strains (rows) and positions (columns); rows are annotated with the
# per-strain genotype entropy.
plot_genotype = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.genotypes,
    row_linkage_func=lambda w: w.genotypes.linkage(dim="strain"),
    col_linkage_func=lambda w: w.genotypes.linkage(dim="position"),
    scaley=0.20,
    scalex=0.01,
    vmin=0,
    center=0.5,
    vmax=1,
    cmap="coolwarm",
    yticklabels=1,
    xticklabels=0,
    row_colors_func=(lambda w: (w.genotypes.entropy())),
)

# Clustermap of the missingness-masked genotypes.  Rows use the linkage of
# the masked genotypes, while columns and the row annotation (entropy) are
# computed from the unmasked genotypes.
plot_masked_genotype = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.masked_genotypes,
    row_linkage_func=lambda w: w.masked_genotypes.linkage(dim="strain"),
    col_linkage_func=lambda w: w.genotypes.linkage(dim="position"),
    scaley=0.20,
    scalex=0.01,
    vmin=0,
    vmax=1,
    cmap="coolwarm",
    yticklabels=1,
    xticklabels=0,
    row_colors_func=(lambda w: (w.genotypes.entropy())),
)

# Clustermap of the missingness values, ordered by the genotype linkages.
# Rows are annotated with one minus the mean missingness value over
# positions.
plot_missing = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.missingness,
    row_linkage_func=lambda w: w.genotypes.linkage(dim="strain"),
    col_linkage_func=lambda w: w.genotypes.linkage(dim="position"),
    metric="cosine",
    scaley=0.20,
    scalex=0.01,
    vmin=0,
    vmax=1,
    cmap=None,
    yticklabels=1,
    xticklabels=0,
    row_colors_func=(lambda w: (1 - w.missingness.mean("position"))),
)

# Clustermap of the community matrix (transposed).  Rows are ordered by the
# genotype linkage over strains and columns by the community linkage over
# samples; rows are annotated with the square root of each strain's sum over
# samples.  Colours use a power-law normalisation with exponent 1/2.
plot_community = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.communities.data.T,
    row_linkage_func=lambda w: w.genotypes.linkage(dim="strain"),
    col_linkage_func=lambda w: w.communities.linkage(dim="sample"),
    row_colors_func=(lambda w: (w.communities.sum("sample").pipe(np.sqrt))),
    metric="cosine",
    scaley=0.20,
    scalex=0.14,
    dwidth=1.0,
    vmin=0,
    vmax=1,
    cmap=None,
    norm=mpl.colors.PowerNorm(1 / 2),
    xticklabels=1,
    yticklabels=1,
)


def plot_loss_history(trace):
    """Plot *trace* minus its minimum on the current axes with a log y-scale."""
    losses = np.array(trace)
    shifted = losses - losses.min()
    plt.plot(shifted)
    plt.yscale("log")


def nmds_ordination(dmat):
    """Two-dimensional non-metric MDS of a precomputed distance DataFrame.

    A default (metric) ``MDS`` fit provides the starting configuration for a
    second fit with ``metric=False``.  The result is a DataFrame indexed like
    *dmat* with columns ``PC1``, ``PC2``.
    """
    # Settings common to both MDS fits.
    shared = dict(
        n_components=2,
        max_iter=3000,
        random_state=1,
        dissimilarity="precomputed",
        n_jobs=1,
    )
    init = MDS(eps=1e-9, **shared).fit_transform(dmat)
    nmds = MDS(metric=False, eps=1e-12, n_init=1, **shared)
    coordinates = nmds.fit_transform(dmat, init=init)

    n_axes = coordinates.shape[1]
    axis_names = [f"PC{i}" for i in range(1, n_axes + 1)]
    return pd.DataFrame(coordinates, index=dmat.index, columns=axis_names)


def ordination_plot(
    world,
    dmat_func,
    ordin_func=nmds_ordination,
    colors_func=None,
    sizes_func=None,
    xy=("PC1", "PC2"),
    ax=None,
    **kwargs,
):
    """Scatter-plot an ordination of *world*.

    ``dmat_func(world)`` supplies the distance matrix that ``ordin_func``
    turns into coordinates.  ``colors_func`` and ``sizes_func``, when given,
    are called with the world to produce marker colours and sizes.  Extra
    keyword arguments override the default scatter settings.

    Returns:
        ``(ax, ordin)``: the axes drawn on and the ordination coordinates.
    """
    x, y = xy
    dmat = dmat_func(world)
    colors = None if colors_func is None else colors_func(world)
    sizes = None if sizes_func is None else sizes_func(world)
    ordin = ordin_func(dmat)

    # Defaults first, so anything passed by the caller wins.
    scatter_kwargs = {
        "c": colors,
        "cmap": "viridis",
        "s": sizes,
        "edgecolor": "k",
        "lw": 0.2,
        "alpha": 0.8,
        **kwargs,
    }
    if ax is None:
        ax = plt.gca()
    ax.scatter(x=x, y=y, data=ordin, **scatter_kwargs)
    ax.set_xlabel(f"{x}")
    ax.set_ylabel(f"{y}")
    return ax, ordin


def plot_metagenotype_frequency_spectrum(
    world,
    sample_list=None,
    show_dominant=False,
    axwidth=2,
    axheight=1.5,
    bins=None,
    axs=None,
    **kwargs,
):
    """Plot a sample-by-sample grid of dominant allele frequency histograms.

    Panel (i, j) is a histogram over positions of the largest allele
    frequency after averaging the frequencies of samples i and j. On the
    diagonal the same sample is selected twice, so those panels describe a
    single sample.

    *world* :: Object exposing `sample`, `metagenotypes` and (only when
        *show_dominant* is set) `communities`.
    *sample_list* :: Samples to include; defaults to `world.sample.values`.
    *show_dominant* :: If true, mark on each diagonal panel the largest
        strain fraction of that sample (dark blue) and its complement
        (dark red).
    *axwidth*, *axheight* :: Per-panel size used when a new figure is made.
    *bins* :: Histogram bins; defaults to 21 evenly spaced edges on
        [0.5, 1.0].
    *axs* :: Optional existing axes, n * n of them for n samples. A new
        grid with shared x and y axes is created when this is None.
    *kwargs* :: Forwarded to `Axes.hist`, overriding the default black color.
    """
    if sample_list is None:
        sample_list = world.sample.values

    hist_kwargs = dict(color="black")
    hist_kwargs.update(kwargs)

    n = len(sample_list)
    # Compare against None explicitly: callers normally pass the numpy array
    # returned by plt.subplots, and the truth value of a multi-element array
    # is ambiguous (raises ValueError).
    if axs is None:
        _fig, axs = plt.subplots(
            n, n, figsize=(axwidth * n, axheight * n), sharex=True, sharey=True
        )
    axs = np.asarray(axs).reshape((n, n))

    if bins is None:
        bins = np.linspace(0.5, 1.0, num=21)

    frequencies = world.metagenotypes.mlift("sel", sample=sample_list).frequencies()
    for sample_i, row in zip(sample_list, axs):
        for sample_j, ax in zip(sample_list, row):
            domfreq_ij = (
                frequencies.sel(sample=[sample_i, sample_j])
                .mean("sample")
                .max("allele")
            )
            ax.hist(domfreq_ij, bins=bins, **hist_kwargs)

    if show_dominant:
        max_frac = world.communities.sel(sample=sample_list).max("strain")
        max_frac_complement = 1 - max_frac
        for i, sample in enumerate(sample_list):
            diagonal_ax = axs[i, i]
            diagonal_ax.axvline(
                max_frac.sel(sample=sample),
                linestyle="--",
                lw=1,
                color="darkblue",
            )
            diagonal_ax.axvline(
                max_frac_complement.sel(sample=sample),
                linestyle="--",
                lw=1,
                color="darkred",
            )

    # Label the grid: row names down the left edge, column names on top.
    for i, sample in enumerate(sample_list):
        axs[i, 0].set_ylabel(sample)
        axs[0, i].set_title(sample)

    # Set the limits on every panel rather than on whichever axes the loops
    # above happened to leave behind; caller-supplied axes need not share x.
    for ax in axs.flat:
        ax.set_xlim(0.5, 1)


def plot_metagenotype_frequency_spectrum_comparison(
    worlds, sample, alpha=0.5, bins=None, ax=None
):
    """Overlay dominant-allele-fraction histograms of one sample across worlds.

    *worlds* :: Mapping from label to world; each label is used as the
        `label` of that world's histogram.
    *sample* :: The sample to select in every world; also used as the title.
    *alpha* :: Opacity passed to each histogram.
    *bins* :: Histogram bins; defaults to 21 evenly spaced edges on
        [0.5, 1.0].
    *ax* :: Axes to draw on; defaults to the current axes.

    Returns the axes that were drawn on.
    """
    bins = np.linspace(0.5, 1.0, num=21) if bins is None else bins
    ax = plt.gca() if ax is None else ax

    for label, world in worlds.items():
        selected = world.metagenotypes.mlift("sel", sample=[sample])
        fractions = selected.dominant_allele_fraction().values.squeeze()
        ax.hist(fractions, bins=bins, alpha=alpha, label=label)

    ax.set_title(sample)
    return ax


def plot_beta_diversity_comparison(
    worldA,
    worldB,
    **kwargs,
):
    """Joint plot of between-sample distances in *worldA* versus *worldB*.

    *worldA*, *worldB* :: Worlds whose `communities.pdist(dim="sample")`
        results are compared; both must yield the same number of distances.
    *kwargs* :: Forwarded to `sns.jointplot`.

    Returns whatever `sns.jointplot` returns.
    """
    # NOTE(review): presumably `pdist` returns a square distance matrix, so
    # `squareform` condenses it to one value per sample pair -- confirm.
    cdmatA = squareform(worldA.communities.pdist(dim="sample"))
    cdmatB = squareform(worldB.communities.pdist(dim="sample"))
    distances = pd.DataFrame(dict(a=cdmatA, b=cdmatB))

    # x and y are passed by keyword: recent seaborn releases make every
    # jointplot argument except `data` keyword-only.
    return sns.jointplot(data=distances, x="a", y="b", **kwargs)


### model.py

In [None]:
%%writefile sfacts/model.py
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from torch.nn.functional import pad as torch_pad
import xarray as xr
import sfacts as sf
from sfacts.pyro_util import all_torch
from sfacts.logging_util import info


class Structure:
    """A Pyro generative function bundled with the names of its dimensions.

    Calling an instance conditions the generative function on data and runs
    it with the given shape and hyperparameters.
    """

    def __init__(self, generative, dims, description, default_hyperparameters=None):
        """

        *generative* :: Pyro generative model function(shape_dim_0, shape_dim_1, shape_dim_2, ..., **hyper_parameters)
        *dims* :: Sequence of names for dim_0, dim_1, dim_2, ...
        *description* :: Mapping from model variable to its dims.
        *default_hyperparameters* :: Values to use for hyperparameters when not explicitly set.
        """
        if default_hyperparameters is None:
            default_hyperparameters = {}

        self.generative = generative
        self.dims = dims
        self.description = description
        self.default_hyperparameters = default_hyperparameters

    def __call__(self, shape, data, hyperparameters, unit):
        """Run the generative function conditioned on *data*.

        *shape* :: One size per entry of `self.dims`, in the same order.
        *data* :: Mapping of site name to observed value for `pyro.condition`.
        *hyperparameters* :: Keyword arguments for the generative function.
        *unit* :: Passed to the generative function as `_unit`.
        """
        assert len(shape) == len(self.dims)
        conditioned_generative = pyro.condition(self.generative, data)
        return conditioned_generative(*shape, **hyperparameters, _unit=unit)

    @property
    def _dummy_shape(self):
        """Placeholder shape 1, 2, 3, ... giving every dim a distinct size."""
        shape = range(1, len(self.dims) + 1)
        return shape

    def explain_shapes(self, shape=None):
        """Log the dim-to-size mapping and the model's shape information.

        *shape* :: Sizes to use; defaults to `self._dummy_shape`.
        """
        if shape is None:
            shape = self._dummy_shape
        info(dict(zip(self.dims, shape)))
        # __call__ requires data, hyperparameters and unit; the previous
        # `self(shape, **self.default_hyperparameters)` always raised
        # TypeError. Build the arguments the same way
        # ParameterizedModel.__call__ does, with no conditioning data.
        dtype, device = torch.float32, "cpu"
        model = partial(
            self,
            shape,
            data={},
            hyperparameters=all_torch(
                **self.default_hyperparameters, dtype=dtype, device=device
            ),
            unit=torch.tensor(1.0, dtype=dtype, device=device),
        )
        # NOTE(review): assumes sf.pyro_util.shape_info takes a zero-argument
        # model callable that it runs itself -- confirm against pyro_util.
        sf.pyro_util.shape_info(model)

    def __repr__(self):
        return (
            f"{self.generative.__name__}("
            f"dims={self.dims}, "
            f"description={self.description}, "
            f"default_hyperparameters={self.default_hyperparameters} "
            f")"
        )


# For decorator use.
def structure(dims, description, default_hyperparameters=None):
    """Return a decorator that wraps a generative function in a Structure.

    *dims* :: Sequence of dimension names, forwarded to Structure.
    *description* :: Mapping from model variable to its dims, forwarded to
        Structure.
    *default_hyperparameters* :: Default hyperparameter values, forwarded to
        Structure.
    """
    bound_options = {
        "dims": dims,
        "description": description,
        "default_hyperparameters": default_hyperparameters,
    }
    # The decorated function fills Structure's remaining first argument.
    return partial(Structure, **bound_options)


class ParameterizedModel:
    """A model structure bound to coordinates, hyperparameters and data.

    `with_hyperparameters`, `with_amended_coords` and `condition` each
    return a new instance rather than modifying this one. Calling an
    instance runs the underlying structure with everything cast to
    `dtype` / `device`.
    """

    def __init__(
        self,
        structure,
        coords,
        dtype=torch.float32,
        device="cpu",
        data=None,
        hyperparameters=None,
    ):
        """
        *structure* :: Object providing `dims`, `description`,
            `default_hyperparameters` and the model call itself.
        *coords* :: Mapping from dim name to its coordinate labels, or to an
            int size (expanded to `range(size)`). Must contain every entry
            of `structure.dims`; other keys are ignored.
        *dtype*, *device* :: Used to cast data and hyperparameters on call.
        *data* :: Mapping of site name to value the model is conditioned on.
        *hyperparameters* :: Overrides for `structure.default_hyperparameters`.
        """
        if hyperparameters is None:
            hyperparameters = {}

        if data is None:
            data = {}

        self.structure = structure
        self.coords = {k: self._coords_or_range(coords[k]) for k in self.structure.dims}
        self.dtype = dtype
        self.device = device
        # Copy so that overrides never leak into the structure's defaults.
        self.hyperparameters = self.structure.default_hyperparameters.copy()
        self.hyperparameters.update(hyperparameters)
        self.data = data

    @property
    def sizes(self):
        """Mapping from dim name to its length, in `structure.dims` order."""
        return {k: len(self.coords[k]) for k in self.structure.dims}

    @property
    def shape(self):
        """Tuple of dim lengths, in `structure.dims` order."""
        return tuple(self.sizes.values())

    def __repr__(self):
        return (
            f"{self.__class__.__name__}"
            f"({self.structure}, "
            f"coords={self.coords}, "
            f"dtype={self.dtype}, "
            f"device={self.device}, "
            f"hyperparameters={self.hyperparameters}, "
            f"data={self.data})"
        )

    def __call__(self):
        # Here's where all the action happens.
        # All parameters are cast based on dtype and device.
        # The model is conditioned on the
        # data, and then called with the shape tuple
        # and cast hyperparameters.
        return self.structure(
            self.shape,
            data=all_torch(**self.data, dtype=self.dtype, device=self.device),
            hyperparameters=all_torch(
                **self.hyperparameters, dtype=self.dtype, device=self.device
            ),
            unit=torch.tensor(1.0, dtype=self.dtype, device=self.device),
        )

    @staticmethod
    def _coords_or_range(coords):
        """Expand an int size into `range(size)`; pass anything else through."""
        if isinstance(coords, int):
            return range(coords)
        return coords

    def _replace(self, **changes):
        """Return a new instance with the given constructor arguments replaced."""
        init_kwargs = dict(
            structure=self.structure,
            coords=self.coords,
            dtype=self.dtype,
            device=self.device,
            hyperparameters=self.hyperparameters,
            data=self.data,
        )
        init_kwargs.update(changes)
        return self.__class__(**init_kwargs)

    def with_hyperparameters(self, **hyperparameters):
        """Return a copy with *hyperparameters* overriding the current ones."""
        new_hyperparameters = self.hyperparameters.copy()
        new_hyperparameters.update(hyperparameters)
        return self._replace(hyperparameters=new_hyperparameters)

    def with_amended_coords(self, **coords):
        """Return a copy with the given dims' coordinates replaced."""
        new_coords = self.coords.copy()
        new_coords.update(coords)
        return self._replace(coords=new_coords)

    def condition(self, **data):
        """Return a copy additionally conditioned on *data*."""
        new_data = self.data.copy()
        new_data.update(data)
        return self._replace(data=new_data)

    def format_world(self, data):
        """Wrap raw arrays in a World.

        *data* :: Mapping containing an array for every variable named in
            `structure.description`; each is labelled with that variable's
            dims and this model's coordinates.
        """
        out = {}
        for k in self.structure.description:
            out[k] = xr.DataArray(
                data[k],
                dims=self.structure.description[k],
                coords={dim: self.coords[dim] for dim in self.structure.description[k]},
            )
        return sf.data.World(xr.Dataset(out))

    def simulate(self, n=1, seed=None):
        """Draw *n* samples from the model.

        *seed* :: Passed to `sf.pyro_util.set_random_seed` before sampling.

        Returns a dict of numpy arrays keyed by site name. Note that
        `.squeeze()` removes every length-1 axis of each array, not only a
        leading sample axis.
        """
        sf.pyro_util.set_random_seed(seed)
        obs = pyro.infer.Predictive(self, num_samples=n)()
        obs = {k: obs[k].detach().cpu().numpy().squeeze() for k in obs.keys()}
        return obs

    def simulate_world(self, seed=None):
        """Draw a single sample and return it formatted as a World.

        *seed* :: Forwarded to `simulate` (it was previously dropped, so a
            seeded call was not reproducible).
        """
        return self.format_world(self.simulate(n=1, seed=seed))


### `model_zoo/__init__.py`

In [None]:
%%writefile sfacts/model_zoo/__init__.py
from sfacts.model_zoo.full_metagenotype import full_metagenotype_model_structure
from sfacts.model_zoo.full_metagenotype_dirichlet_rho import full_metagenotype_dirichlet_rho_model_structure

### model_zoo/components.py

In [None]:
%%writefile sfacts/model_zoo/components.py
import sfacts as sf
import pyro
import pyro.distributions as dist
import torch
from torch.nn.functional import pad as torch_pad


# Names of the dimensions shared by the models in the zoo.
SHARED_DIMS = ("sample", "position", "strain", "allele")

# Dimension names of each model variable, keyed by variable name.
# An empty tuple marks a scalar.
SHARED_DESCRIPTIONS = {
    "gamma": ("strain", "position"),
    "delta": ("strain", "position"),
    "rho": ("strain",),
    "pi": ("sample", "strain"),
    "epsilon": ("sample",),
    "m_hyper_r": ("sample",),
    "m_hyper_r_mean": (),
    "m_hyper_r_scale": (),
    "mu": ("sample",),
    "nu": ("sample", "position"),
    "p_noerr": ("sample", "position"),
    "p": ("sample", "position"),
    "alpha_hyper_mean": (),
    "alpha": ("sample",),
    "m": ("sample", "position"),
    "y": ("sample", "position"),
    "genotypes": ("strain", "position"),
    "missingness": ("strain", "position"),
    "communities": ("sample", "strain"),
    "metagenotypes": ("sample", "position", "allele"),
}


def _mapping_subset(mapping, keys):
    return {k: mapping[k] for k in keys}


def stickbreaking_betas_to_probs(beta):
    """Turn stick-breaking fractions along the last axis into probabilities.

    Entry k of the result is `beta[k]` times the product of `1 - beta[j]`
    over all j < k. One extra trailing entry holds the full product of
    `1 - beta`, so the last axis of the result is one longer than that of
    *beta*.
    """
    remaining = torch.cumprod(1 - beta, dim=-1)
    # The appended final piece takes everything that is still left ...
    piece_fraction = torch_pad(beta, (0, 1), value=1)
    # ... and the first piece is broken off an intact stick.
    stick_before = torch_pad(remaining, (1, 0), value=1)
    return piece_fraction * stick_before


def NegativeBinomialReparam(mu, r):
    """Build a NegativeBinomial from *mu* and *r* instead of a probability.

    *mu* :: Positive tensor (or number); the success odds are `mu / r`.
    *r* :: Positive tensor (or number), passed through as `total_count`.

    Equivalent to `logits = logit(p)` with `p = 1 / (r / mu + 1)`,
    i.e. `p = mu / (mu + r)`.
    """
    # logit(mu / (mu + r)) == log(mu / r) == log(mu) - log(r). Computing the
    # log-odds directly never forms p, which rounds to exactly 0 or 1 for
    # extreme mu / r and then turns into infinite logits (the reason a clamp
    # on p was once considered here).
    logits = torch.log(torch.as_tensor(mu)) - torch.log(torch.as_tensor(r))
    return dist.NegativeBinomial(
        total_count=r,
        logits=logits,
    )


def unit_interval_power_transformation(p, alpha, beta):
    """Return `p**alpha / (p**alpha + (1 - p)**beta)`, computed in log space.

    *p* :: Tensor of values in the unit interval.
    *alpha* :: Exponent applied to `p`.
    *beta* :: Exponent applied to `1 - p`.
    """
    powered_p = torch.log(p) * alpha
    powered_q = torch.log1p(-p) * beta
    # Log of the denominator: logsumexp over the two stacked log-terms.
    log_normalizer = torch.logsumexp(torch.stack([powered_p, powered_q]), dim=0)
    return torch.exp(powered_p - log_normalizer)


# def _pp_gamma_delta_module(s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp):
#     # Genotypes
#     #     delta_hyper_p = pyro.sample('delta_hyper_p', dist.Beta(1., 1.))
#     with pyro.plate("position", g, dim=-1):
#         with pyro.plate("strain", s, dim=-2):
#             _gamma = pyro.sample("_gamma", dist.Beta(1.0, 1.0))
#             gamma = pyro.deterministic(
#                 "gamma",
#                 unit_interval_power_transformation(
#                     _gamma, 1 / gamma_hyper, 1 / gamma_hyper
#                 ),
#             )
#             #                 Position presence/absence
#             _delta = pyro.sample("_delta", dist.Beta(1.0, 1.0))
#             delta = pyro.deterministic(
#                 "delta",
#                 unit_interval_power_transformation(
#                     _delta,
#                     2 * (1 - delta_hyper_r) / delta_hyper_temp,
#                     2 * delta_hyper_r / delta_hyper_temp,
#                 ),
#             )
#
#     #                 delta = pyro.sample(
#     #                     'delta',
#     #                     dist.RelaxedBernoulli(
#     #                         temperature=delta_hyper_temp, probs=delta_hyper_p
#     #                     ),
#     #                 )
#
#     # These deterministics are accessed by PointMixin class properties.
#     pyro.deterministic("genotypes", gamma)
#     pyro.deterministic("missingness", delta)
#     return gamma, delta
#
#
# def _gsm_gamma_delta_module(s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp):
#     with pyro.plate("position", g, dim=-1):
#         with pyro.plate("strain", s, dim=-2):
#             gamma = pyro.sample(
#                 "gamma",
#                 dist.RelaxedBernoulli(
#                     temperature=gamma_hyper,
#                     probs=0.5,
#                 ),
#             )
#             delta = pyro.sample(
#                 "delta",
#                 dist.RelaxedBernoulli(
#                     temperature=delta_hyper_temp, probs=delta_hyper_r
#                 ),
#             )
#     pyro.deterministic("genotypes", gamma)
#     pyro.deterministic("missingness", delta)
#     return gamma, delta
#
#
# def _beta_gamma_delta_module(s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp):
#     with pyro.plate("position", g, dim=-1):
#         with pyro.plate("strain", s, dim=-2):
#             gamma = pyro.sample(
#                 "gamma",
#                 dist.Beta(
#                     gamma_hyper,
#                     gamma_hyper,
#                 ),
#             )
#             delta = pyro.sample(
#                 "delta",
#                 dist.RelaxedBernoulli(
#                     delta_hyper_r * delta_hyper_temp,
#                     (1 - delta_hyper_r) * delta_hyper_temp,
#                 ),
#             )
#     pyro.deterministic("genotypes", gamma)
#     pyro.deterministic("missingness", delta)
#     return gamma, delta
#
#
# def _hybrid_gamma_delta_module(s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp):
#     with pyro.plate("position", g, dim=-1):
#         with pyro.plate("strain", s, dim=-2):
#             _gamma = pyro.sample("_gamma", dist.Beta(1.0, 1.0))
#             gamma = pyro.deterministic(
#                 "gamma",
#                 unit_interval_power_transformation(
#                     _gamma, 1 / gamma_hyper, 1 / gamma_hyper
#                 ),
#             )
#             #                 Position presence/absence
#             delta = pyro.sample(
#                 "delta",
#                 dist.RelaxedBernoulli(
#                     temperature=delta_hyper_temp, probs=delta_hyper_r
#                 ),
#             )
#     pyro.deterministic("genotypes", gamma)
#     pyro.deterministic("missingness", delta)
#     return gamma, delta
#
#
# def _dp_rho_module(s, rho_hyper):
#     #         # TODO: Will torch.ones(s) fail when I try to run this on the GPU because it's, by default on the CPU?
#     #         rho = pyro.sample(
#     #             "rho", dist.Dirichlet(rho_hyper * torch.ones(s))
#     #         )
#     # Meta-community composition
#     rho_betas = pyro.sample(
#         "rho_betas", dist.Beta(1.0, rho_hyper).expand([s - 1]).to_event()
#     )
#     rho = pyro.deterministic("rho", stickbreaking_betas_to_probs(rho_betas))
#     pyro.deterministic("metacommunity", rho)
#     return rho
#
#
# def _dirichlet_pi_epsilon_alpha_mu_module(
#     n,
#     pi_hyper,
#     rho,
#     epsilon_hyper_alpha,
#     epsilon_hyper_beta,
#     alpha_hyper_mean,
#     alpha_hyper_scale,
#     mu_hyper_mean,
#     mu_hyper_scale,
# ):
#     with pyro.plate("sample", n, dim=-1):
#         # Community composition
#         pi = pyro.sample("pi", dist.Dirichlet(pi_hyper * rho, validate_args=False))
#         # Sequencing error
#         epsilon = pyro.sample(
#             "epsilon", dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta)
#         ).unsqueeze(-1)
#         alpha = pyro.sample(
#             "alpha",
#             dist.LogNormal(loc=torch.log(alpha_hyper_mean), scale=alpha_hyper_scale),
#         ).unsqueeze(-1)
#         # Sample coverage
#         mu = pyro.sample(
#             "mu", dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale)
#         )
#     pyro.deterministic("communities", pi)
#     return pi, epsilon, alpha, mu
#
#
# def _dirichlet_pi_epsilon_alpha_r_mu_module(
#     n,
#     pi_hyper,
#     rho,
#     epsilon_hyper_alpha,
#     epsilon_hyper_beta,
#     alpha_hyper_mean,
#     alpha_hyper_scale,
#     mu_hyper_mean,
#     mu_hyper_scale,
# ):
#     with pyro.plate("sample", n, dim=-1):
#         # Community composition
#         pi = pyro.sample("pi", dist.Dirichlet(pi_hyper * rho, validate_args=False))
#         # Sequencing error
#         epsilon = pyro.sample(
#             "epsilon", dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta)
#         ).unsqueeze(-1)
#         alpha = pyro.sample(
#             "alpha",
#             dist.LogNormal(loc=torch.log(alpha_hyper_mean), scale=alpha_hyper_scale),
#         ).unsqueeze(-1)
#         # Sample coverage
#         mu = pyro.sample(
#             "mu", dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale)
#         )
#     pyro.deterministic("communities", pi)
#     return pi, epsilon, alpha, mu
#
#
# def _lognormal_alpha_hyper_mean_module(alpha_hyper_hyper_mean, alpha_hyper_hyper_scale):
#     alpha_hyper_mean = pyro.sample(
#         "alpha_hyper_mean",
#         dist.LogNormal(
#             loc=torch.log(alpha_hyper_hyper_mean), scale=alpha_hyper_hyper_scale
#         ),
#     )
#     return alpha_hyper_mean
#
#
# def _m_hyper_r_module(n, m_hyper_r_scale):
#     m_hyper_r_mean = pyro.sample("m_hyper_r_mean", dist.LogNormal(loc=0.0, scale=10.0))
#     m_hyper_r = pyro.sample(
#         "m_hyper_r",
#         dist.LogNormal(loc=torch.log(m_hyper_r_mean), scale=m_hyper_r_scale)
#         .expand([n, 1])
#         .to_event(),
#     )
#     return m_hyper_r
#
#
# def _betabinomial_observation_module(pi, gamma, delta, m_hyper_r, mu, epsilon, alpha):
#     # Depth at each position
#     nu = pyro.deterministic("nu", pi @ delta)
#     # TODO: Consider using pyro.distributions.GammaPoisson parameterization?
#     m = pyro.sample(
#         "m",
#         NegativeBinomialReparam(nu * mu.reshape((-1, 1)), m_hyper_r).to_event(),
#     )
#
#     # Expected fractions of each allele at each position
#     p_noerr = pyro.deterministic("p_noerr", pi @ (gamma * delta) / nu)
#     p = pyro.deterministic(
#         "p", (1 - epsilon / 2) * (p_noerr) + (epsilon / 2) * (1 - p_noerr)
#     )
#
#     # Observation
#     y = pyro.sample(
#         "y",
#         dist.BetaBinomial(
#             concentration1=alpha * p,
#             concentration0=alpha * (1 - p),
#             total_count=m,
#             #             validate_args=False,
#         ).to_event(),
#     )
#     # TODO: Check that dim=0 works?
#     metagenotypes = pyro.deterministic("metagenotypes", torch.stack([y, m - y], dim=-1))
#
#
# @sf.model.structure(
#     dims=SHARED_DIMS,
#     description=_mapping_subset(
#         SHARED_DESCRIPTIONS,
#         [
#             "rho",
#             "epsilon",
#             "m_hyper_r",
#             "mu",
#             "nu",
#             "p_noerr",
#             "p",
#             "m",
#             "y",
#             "alpha_hyper_mean",
#             "alpha",
#             "genotypes",
#             "missingness",
#             "communities",
#             "metagenotypes",
#         ],
#     ),
#     default_hyperparameters=dict(
#         gamma_hyper=0.01,
#         delta_hyper_temp=0.01,
#         delta_hyper_r=0.9,
#         rho_hyper=5.0,
#         pi_hyper=0.2,
#         epsilon_hyper_alpha=1.5,
#         epsilon_hyper_beta=1.5 / 0.01,
#         mu_hyper_mean=1.0,
#         mu_hyper_scale=10.0,
#         #         m_hyper_r_mu=1.,
#         m_hyper_r_scale=1.0,
#         alpha_hyper_hyper_mean=100.0,
#         alpha_hyper_hyper_scale=1.0,
#         alpha_hyper_scale=0.5,
#     ),
# )
# def pp_fuzzy_missing_dp_betabinomial_metagenotype(
#     n,
#     g,
#     s,
#     a,
#     gamma_hyper,
#     delta_hyper_r,
#     delta_hyper_temp,
#     rho_hyper,  # =1.0,
#     pi_hyper,  # =1.0,
#     alpha_hyper_hyper_mean,  # =100.0,
#     alpha_hyper_hyper_scale,  # =1.0,
#     alpha_hyper_scale,  # =0.5,
#     epsilon_hyper_alpha,  # =1.5,
#     epsilon_hyper_beta,  # =1.5 / 0.01,
#     mu_hyper_mean,  # =1.0,
#     mu_hyper_scale,  # =1.0,
#     #         m_hyper_r_mu,
#     m_hyper_r_scale,
# ):
#     gamma, delta = _pp_gamma_delta_module(
#         s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp
#     )
#     rho = _dp_rho_module(s, rho_hyper)
#     alpha_hyper_mean = _lognormal_alpha_hyper_mean_module(
#         alpha_hyper_hyper_mean, alpha_hyper_hyper_scale
#     )
#     pi, epsilon, alpha, mu = _dirichlet_pi_epsilon_alpha_mu_module(
#         n,
#         pi_hyper,
#         rho,
#         epsilon_hyper_alpha,
#         epsilon_hyper_beta,
#         alpha_hyper_mean,
#         alpha_hyper_scale,
#         mu_hyper_mean,
#         mu_hyper_scale,
#     )
#     m_hyper_r = _m_hyper_r_module(n, m_hyper_r_scale)
#     _betabinomial_observation_module(pi, gamma, delta, m_hyper_r, mu, epsilon, alpha)
#
#
# @sf.model.structure(
#     dims=SHARED_DIMS,
#     description=_mapping_subset(
#         SHARED_DESCRIPTIONS,
#         [
#             "rho",
#             "epsilon",
#             "m_hyper_r",
#             "mu",
#             "nu",
#             "p_noerr",
#             "p",
#             "m",
#             "y",
#             "alpha_hyper_mean",
#             "alpha",
#             "genotypes",
#             "missingness",
#             "communities",
#             "metagenotypes",
#         ],
#     ),
#     default_hyperparameters=dict(
#         gamma_hyper=0.01,
#         delta_hyper_temp=0.01,
#         delta_hyper_r=0.9,
#         rho_hyper=5.0,
#         pi_hyper=0.2,
#         epsilon_hyper_alpha=1.5,
#         epsilon_hyper_beta=1.5 / 0.01,
#         mu_hyper_mean=1.0,
#         mu_hyper_scale=10.0,
#         #         m_hyper_r_mu=1.,
#         m_hyper_r_scale=1.0,
#         alpha_hyper_hyper_mean=100.0,
#         alpha_hyper_hyper_scale=1.0,
#         alpha_hyper_scale=0.5,
#     ),
# )
# def gsm_fuzzy_missing_dp_betabinomial_metagenotype(
#     n,
#     g,
#     s,
#     a,
#     gamma_hyper,
#     delta_hyper_r,
#     delta_hyper_temp,
#     rho_hyper,  # =1.0,
#     pi_hyper,  # =1.0,
#     alpha_hyper_hyper_mean,  # =100.0,
#     alpha_hyper_hyper_scale,  # =1.0,
#     alpha_hyper_scale,  # =0.5,
#     epsilon_hyper_alpha,  # =1.5,
#     epsilon_hyper_beta,  # =1.5 / 0.01,
#     mu_hyper_mean,  # =1.0,
#     mu_hyper_scale,  # =1.0,
#     #         m_hyper_r_mu,
#     m_hyper_r_scale,
# ):
#     gamma, delta = _gsm_gamma_delta_module(
#         s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp
#     )
#     rho = _dp_rho_module(s, rho_hyper)
#     alpha_hyper_mean = _lognormal_alpha_hyper_mean_module(
#         alpha_hyper_hyper_mean, alpha_hyper_hyper_scale
#     )
#     pi, epsilon, alpha, mu = _dirichlet_pi_epsilon_alpha_mu_module(
#         n,
#         pi_hyper,
#         rho,
#         epsilon_hyper_alpha,
#         epsilon_hyper_beta,
#         alpha_hyper_mean,
#         alpha_hyper_scale,
#         mu_hyper_mean,
#         mu_hyper_scale,
#     )
#     m_hyper_r = _m_hyper_r_module(n, m_hyper_r_scale)
#     _betabinomial_observation_module(pi, gamma, delta, m_hyper_r, mu, epsilon, alpha)
#
#
# @sf.model.structure(
#     dims=SHARED_DIMS,
#     description=_mapping_subset(
#         SHARED_DESCRIPTIONS,
#         [
#             "rho",
#             "epsilon",
#             "m_hyper_r",
#             "mu",
#             "nu",
#             "p_noerr",
#             "p",
#             "m",
#             "y",
#             "alpha_hyper_mean",
#             "alpha",
#             "genotypes",
#             "missingness",
#             "communities",
#             "metagenotypes",
#         ],
#     ),
#     default_hyperparameters=dict(
#         gamma_hyper=0.01,
#         delta_hyper_temp=0.01,
#         delta_hyper_r=0.9,
#         rho_hyper=5.0,
#         pi_hyper=0.2,
#         epsilon_hyper_alpha=1.5,
#         epsilon_hyper_beta=1.5 / 0.01,
#         mu_hyper_mean=1.0,
#         mu_hyper_scale=10.0,
#         #         m_hyper_r_mu=1.,
#         m_hyper_r_scale=1.0,
#         alpha_hyper_hyper_mean=100.0,
#         alpha_hyper_hyper_scale=1.0,
#         alpha_hyper_scale=0.5,
#     ),
# )
# def hybrid_fuzzy_missing_dp_betabinomial_metagenotype(
#     n,
#     g,
#     s,
#     a,
#     gamma_hyper,
#     delta_hyper_r,
#     delta_hyper_temp,
#     rho_hyper,  # =1.0,
#     pi_hyper,  # =1.0,
#     alpha_hyper_hyper_mean,  # =100.0,
#     alpha_hyper_hyper_scale,  # =1.0,
#     alpha_hyper_scale,  # =0.5,
#     epsilon_hyper_alpha,  # =1.5,
#     epsilon_hyper_beta,  # =1.5 / 0.01,
#     mu_hyper_mean,  # =1.0,
#     mu_hyper_scale,  # =1.0,
#     #         m_hyper_r_mu,
#     m_hyper_r_scale,
# ):
#     gamma, delta = _hybrid_gamma_delta_module(
#         s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp
#     )
#     rho = _dp_rho_module(s, rho_hyper)
#     alpha_hyper_mean = _lognormal_alpha_hyper_mean_module(
#         alpha_hyper_hyper_mean, alpha_hyper_hyper_scale
#     )
#     pi, epsilon, alpha, mu = _dirichlet_pi_epsilon_alpha_mu_module(
#         n,
#         pi_hyper,
#         rho,
#         epsilon_hyper_alpha,
#         epsilon_hyper_beta,
#         alpha_hyper_mean,
#         alpha_hyper_scale,
#         mu_hyper_mean,
#         mu_hyper_scale,
#     )
#     m_hyper_r = _m_hyper_r_module(n, m_hyper_r_scale)
#     _betabinomial_observation_module(pi, gamma, delta, m_hyper_r, mu, epsilon, alpha)
#
#
# @sf.model.structure(
#     dims=SHARED_DIMS,
#     description=_mapping_subset(
#         SHARED_DESCRIPTIONS,
#         ["m", "y", "genotypes", "rho", "communities", "metagenotypes"],
#     ),
#     default_hyperparameters=dict(
#         gamma_hyper=0.01,
#         rho_hyper=0.01,
#         pi_hyper=0.2,
#     ),
# )
# def simple_metagenotype(
#     n,
#     g,
#     s,
#     a,
#     gamma_hyper,
#     rho_hyper,
#     pi_hyper,
# ):
#     with pyro.plate("position", g, dim=-1):
#         with pyro.plate("strain", s, dim=-2):
#             gamma = pyro.sample("gamma", dist.Beta(gamma_hyper, gamma_hyper))
#     pyro.deterministic("genotypes", gamma)
#
#     # Meta-community composition
#     rho_betas = pyro.sample(
#         "rho_betas", dist.Beta(1.0, rho_hyper).expand([s - 1]).to_event()
#     )
#     rho = pyro.deterministic("rho", stickbreaking_betas_to_probs(rho_betas))
#
#     with pyro.plate("sample", n, dim=-1):
#         # Community composition
#         pi = pyro.sample(
#             "pi",
#             dist.Dirichlet(
#                 pi_hyper * rho,
#                 validate_args=False,
#             ),
#         )
#     pyro.deterministic("communities", pi)
#
#     # Depth at each position
#     m = pyro.sample(
#         "m",
#         NegativeBinomialReparam(
#             torch.tensor(100),
#             torch.tensor(0.1),
#         )
#         .expand([n, g])
#         .to_event(),
#     )
#
#     # Expected fractions of each allele at each position
#     p = pyro.deterministic("p", pi @ gamma)
#
#     # Observation
#     y = pyro.sample(
#         "y",
#         dist.Binomial(
#             probs=p,
#             total_count=m,
#             validate_args=False,
#         ).to_event(),
#     )
#     metagenotypes = pyro.deterministic("metagenotypes", torch.stack([y, m - y], dim=-1))
#
#
# @sf.model.structure(
#     dims=SHARED_DIMS,
#     description=_mapping_subset(
#         SHARED_DESCRIPTIONS,
#         ["m", "y", "genotypes", "rho", "communities", "metagenotypes"],
#     ),
#     default_hyperparameters=dict(
#         gamma_hyper=0.01,
#         rho_hyper=0.01,
#         pi_hyper=0.2,
#     ),
# )
# def simple_metagenotype2(
#     n,
#     g,
#     s,
#     a,
#     gamma_hyper,
#     rho_hyper,
#     pi_hyper,
# ):
#     with pyro.plate("position", g, dim=-1):
#         with pyro.plate("strain", s, dim=-2):
#             _gamma = pyro.sample("_gamma", dist.Beta(1.0, 1.0))
#             gamma = pyro.deterministic(
#                 "gamma",
#                 unit_interval_power_transformation(
#                     _gamma, 1 / gamma_hyper, 1 / gamma_hyper
#                 ),
#             )
#     pyro.deterministic("genotypes", gamma)
#
#     # Meta-community composition
#     rho_betas = pyro.sample(
#         "rho_betas", dist.Beta(1.0, rho_hyper).expand([s - 1]).to_event()
#     )
#     rho = pyro.deterministic("rho", stickbreaking_betas_to_probs(rho_betas))
#
#     with pyro.plate("sample", n, dim=-1):
#         # Community composition
#         pi = pyro.sample(
#             "pi",
#             dist.Dirichlet(
#                 pi_hyper * rho,
#                 validate_args=False,
#             ),
#         )
#     pyro.deterministic("communities", pi)
#
#     # Depth at each position
#     m = pyro.sample(
#         "m",
#         NegativeBinomialReparam(
#             torch.tensor(100),
#             torch.tensor(0.1),
#         )
#         .expand([n, g])
#         .to_event(),
#     )
#
#     # Expected fractions of each allele at each position
#     p = pyro.deterministic("p", pi @ gamma)
#
#     # Observation
#     y = pyro.sample(
#         "y",
#         dist.Binomial(
#             probs=p,
#             total_count=m,
#         ).to_event(),
#     )
#     metagenotypes = pyro.deterministic("metagenotypes", torch.stack([y, m - y], dim=-1))
#
#
# @sf.model.structure(
#     dims=SHARED_DIMS,
#     description=_mapping_subset(
#         SHARED_DESCRIPTIONS,
#         ["m", "y", "epsilon", "genotypes", "rho", "communities", "metagenotypes"],
#     ),
#     default_hyperparameters=dict(
#         gamma_hyper=0.01,
#         rho_hyper=0.01,
#         pi_hyper=0.2,
#         epsilon_hyper_alpha=1.5,
#         epsilon_hyper_beta=1.5 / 0.01,
#     ),
# )
# def simple_metagenotype_plus_error(
#     n,
#     g,
#     s,
#     a,
#     gamma_hyper,
#     rho_hyper,
#     pi_hyper,
#     epsilon_hyper_alpha,
#     epsilon_hyper_beta,
# ):
#     with pyro.plate("position", g, dim=-1):
#         with pyro.plate("strain", s, dim=-2):
#             gamma = pyro.sample("gamma", dist.Beta(gamma_hyper, gamma_hyper))
#     pyro.deterministic("genotypes", gamma)
#
#     # Meta-community composition
#     rho_betas = pyro.sample(
#         "rho_betas", dist.Beta(1.0, rho_hyper).expand([s - 1]).to_event()
#     )
#     rho = pyro.deterministic("rho", stickbreaking_betas_to_probs(rho_betas))
#
#     with pyro.plate("sample", n, dim=-1):
#         # Community composition
#         pi = pyro.sample(
#             "pi",
#             dist.Dirichlet(
#                 pi_hyper * rho,
#                 validate_args=False,
#             ),
#         )
#         epsilon = pyro.sample(
#             "epsilon", dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta)
#         ).unsqueeze(-1)
#     pyro.deterministic("communities", pi)
#
#     # Depth at each position
#     m = pyro.sample(
#         "m",
#         NegativeBinomialReparam(
#             torch.tensor(100),
#             torch.tensor(0.1),
#         )
#         .expand([n, g])
#         .to_event(),
#     )
#
#     # Expected fractions of each allele at each position
#     p = pyro.deterministic("p", pi @ gamma)
#
#     # Observation
#     y = pyro.sample(
#         "y",
#         dist.Binomial(
#             probs=p,
#             total_count=m,
#             validate_args=False,
#         ).to_event(),
#     )
#     metagenotypes = pyro.deterministic("metagenotypes", torch.stack([y, m - y], dim=-1))


### model_zoo/full_metagenotype.py

In [None]:
%%writefile sfacts/model_zoo/full_metagenotype.py
import sfacts.model
from sfacts.model_zoo.components import (
    _mapping_subset,
    unit_interval_power_transformation,
    stickbreaking_betas_to_probs,
    NegativeBinomialReparam,
    SHARED_DESCRIPTIONS,
    SHARED_DIMS,
)
import torch
import pyro
import pyro.distributions as dist


# The decorator call below passes the shared dimension names, the entries of
# SHARED_DESCRIPTIONS for the listed variable names (presumably selected by
# `_mapping_subset` -- confirm in components.py), and default hyperparameters.
@sfacts.model.structure(
    dims=SHARED_DIMS,
    description=_mapping_subset(
        SHARED_DESCRIPTIONS,
        [
            "rho",
            "epsilon",
            "m_hyper_r_mean",
            "m_hyper_r_scale",
            "m_hyper_r",
            "mu",
            "nu",
            "p_noerr",
            "p",
            "m",
            "y",
            "alpha_hyper_mean",
            "alpha",
            "genotypes",
            "missingness",
            "communities",
            "metagenotypes",
        ],
    ),
    default_hyperparameters=dict(
        gamma_hyper=0.01,
        delta_hyper_temp=0.01,
        delta_hyper_r=0.9,
        rho_hyper=5.0,
        pi_hyper=0.2,
        mu_hyper_mean=1.0,
        mu_hyper_scale=10.0,
        epsilon_hyper_mode=0.01,
        epsilon_hyper_spread=1.5,
        alpha_hyper_hyper_mean=100.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=0.5,
    ),
)
def full_metagenotype_model_structure(
    n,
    g,
    s,
    a,
    gamma_hyper,
    delta_hyper_r,
    delta_hyper_temp,
    rho_hyper,
    pi_hyper,
    alpha_hyper_hyper_mean,
    alpha_hyper_hyper_scale,
    alpha_hyper_scale,
    mu_hyper_mean,
    mu_hyper_scale,
    epsilon_hyper_mode,
    epsilon_hyper_spread,
    _unit,
):
    """Pyro model for metagenotype counts with a stick-breaking ``rho`` prior.

    Sites are declared in this order:

    * ``_gamma`` / ``gamma`` and ``delta`` inside the "strain" (dim -2) and
      "position" (dim -1) plates; re-exported as the deterministic sites
      "genotypes" and "missingness".
    * ``rho_betas`` (``s - 1`` Beta(1, rho_hyper) draws) mapped to ``rho`` by
      ``stickbreaking_betas_to_probs``; re-exported as "metacommunity".
    * Global ``alpha_hyper_mean``, ``m_hyper_r_mean`` and ``m_hyper_r_scale``.
    * ``pi``, ``epsilon``, ``alpha``, ``m_hyper_r`` and ``mu`` inside the
      "sample" plate; ``pi`` is re-exported as "communities".
    * ``nu``, ``m``, ``p_noerr``, ``p`` and ``y``, and finally the
      deterministic "metagenotypes" site stacking ``y`` and ``m - y``.

    Parameters
    ----------
    n, g, s
        Sizes of the "sample", "position" and "strain" plates respectively.
    a
        Not referenced anywhere in the body.
    gamma_hyper, delta_hyper_r, delta_hyper_temp, rho_hyper, pi_hyper,
    alpha_hyper_hyper_mean, alpha_hyper_hyper_scale, alpha_hyper_scale,
    mu_hyper_mean, mu_hyper_scale, epsilon_hyper_mode, epsilon_hyper_spread
        Hyperparameters of the priors below.  Several are passed directly to
        ``torch.log``, so they are assumed to arrive as tensors -- TODO confirm
        against ``sfacts.model``.
    _unit
        Used as a multiplicative template (``_unit * 0.0``, ``_unit * 10.``)
        and as both Beta concentrations; presumably a tensor equal to one on
        the target dtype/device -- TODO confirm.

    The function has no ``return`` statement; results are exposed only through
    the Pyro sites it declares.
    """
    with pyro.plate("position", g, dim=-1):
        with pyro.plate("strain", s, dim=-2):
            # Latent draw on the unit interval, then reshaped by the power
            # transformation with both exponents set to 1 / gamma_hyper.
            _gamma = pyro.sample("_gamma", dist.Beta(_unit, _unit))
            gamma = pyro.deterministic(
                "gamma",
                unit_interval_power_transformation(
                    _gamma, 1 / gamma_hyper, 1 / gamma_hyper
                ),
            )
            # Position presence/absence
            delta = pyro.sample(
                "delta",
                dist.RelaxedBernoulli(
                    temperature=delta_hyper_temp, probs=delta_hyper_r
                ),
            )
    # Same values under the names used in the description list above.
    pyro.deterministic("genotypes", gamma)
    pyro.deterministic("missingness", delta)

    # Meta-community composition
    rho_betas = pyro.sample(
        "rho_betas", dist.Beta(1.0, rho_hyper).expand([s - 1]).to_event()
    )
    rho = pyro.deterministic("rho", stickbreaking_betas_to_probs(rho_betas))
    pyro.deterministic("metacommunity", rho)

    # Global hyper-priors for the per-sample `alpha` and `m_hyper_r` draws.
    alpha_hyper_mean = pyro.sample(
        "alpha_hyper_mean",
        dist.LogNormal(
            loc=torch.log(alpha_hyper_hyper_mean),
            scale=alpha_hyper_hyper_scale,
        ),
    )
    m_hyper_r_mean = pyro.sample("m_hyper_r_mean", dist.LogNormal(loc=_unit * 0.0, scale=_unit * 10.))
    m_hyper_r_scale = pyro.sample(
        "m_hyper_r_scale", dist.LogNormal(loc=_unit * 0.0, scale=_unit * 10.)
    )

    with pyro.plate("sample", n, dim=-1):
        # Community composition
        # (argument validation is explicitly switched off for this Dirichlet)
        pi = pyro.sample("pi", dist.Dirichlet(pi_hyper * rho, validate_args=False))
        # Sequencing error
        epsilon = pyro.sample(
            "epsilon",
            dist.Beta(epsilon_hyper_spread, epsilon_hyper_spread / epsilon_hyper_mode),
        ).unsqueeze(-1)
        # `epsilon`, `alpha` and `m_hyper_r` each get a trailing singleton
        # axis so they broadcast against the per-position quantities below.
        alpha = pyro.sample(
            "alpha",
            dist.LogNormal(loc=torch.log(alpha_hyper_mean), scale=alpha_hyper_scale),
        ).unsqueeze(-1)
        m_hyper_r = pyro.sample(
            "m_hyper_r",
            dist.LogNormal(loc=torch.log(m_hyper_r_mean), scale=m_hyper_r_scale),
        ).unsqueeze(-1)
        # Sample coverage
        mu = pyro.sample(
            "mu",
            dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale),
        )
    pyro.deterministic("communities", pi)

    # Depth at each position
    nu = pyro.deterministic("nu", pi @ delta)
    # TODO: Consider using pyro.distributions.GammaPoisson parameterization?
    # `mu` is reshaped to a column so each sample's value scales its row of `nu`.
    m = pyro.sample(
        "m",
        NegativeBinomialReparam(nu * mu.reshape((-1, 1)), m_hyper_r).to_event(),
    )

    # Expected fractions of each allele at each position
    p_noerr = pyro.deterministic("p_noerr", pi @ (gamma * delta) / nu)
    # Blend `p_noerr` with its complement, weighting the complement by epsilon / 2.
    p = pyro.deterministic(
        "p", (1 - epsilon / 2) * (p_noerr) + (epsilon / 2) * (1 - p_noerr)
    )

    # Observation
    y = pyro.sample(
        "y",
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ).to_event(),
    )
    # Stack `y` and `m - y` along a new last axis.  The local name is never
    # read afterwards; only the deterministic site is used.
    metagenotypes = pyro.deterministic("metagenotypes", torch.stack([y, m - y], dim=-1))


### model_zoo/full_metagenotype_dirichlet_rho.py

In [None]:
%%writefile sfacts/model_zoo/full_metagenotype_dirichlet_rho.py
import sfacts.model
from sfacts.model_zoo.components import (
    _mapping_subset,
    unit_interval_power_transformation,
    stickbreaking_betas_to_probs,
    NegativeBinomialReparam,
    SHARED_DESCRIPTIONS,
    SHARED_DIMS,
)
import torch
import pyro
import pyro.distributions as dist


# The decorator call below passes the shared dimension names, the entries of
# SHARED_DESCRIPTIONS for the listed variable names (presumably selected by
# `_mapping_subset` -- confirm in components.py), and default hyperparameters.
@sfacts.model.structure(
    dims=SHARED_DIMS,
    description=_mapping_subset(
        SHARED_DESCRIPTIONS,
        [
            "rho",
            "epsilon",
            "m_hyper_r_mean",
            "m_hyper_r_scale",
            "m_hyper_r",
            "mu",
            "nu",
            "p_noerr",
            "p",
            "m",
            "y",
            "alpha_hyper_mean",
            "alpha",
            "genotypes",
            "missingness",
            "communities",
            "metagenotypes",
        ],
    ),
    default_hyperparameters=dict(
        gamma_hyper=0.01,
        delta_hyper_temp=0.01,
        delta_hyper_r=0.9,
        rho_hyper=5.0,
        pi_hyper=0.2,
        mu_hyper_mean=1.0,
        mu_hyper_scale=10.0,
        epsilon_hyper_mode=0.01,
        epsilon_hyper_spread=1.5,
        alpha_hyper_hyper_mean=100.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=0.5,
    ),
)
def full_metagenotype_dirichlet_rho_model_structure(
    n,
    g,
    s,
    a,
    gamma_hyper,
    delta_hyper_r,
    delta_hyper_temp,
    rho_hyper,
    pi_hyper,
    alpha_hyper_hyper_mean,
    alpha_hyper_hyper_scale,
    alpha_hyper_scale,
    mu_hyper_mean,
    mu_hyper_scale,
    epsilon_hyper_mode,
    epsilon_hyper_spread,
    _unit,
):
    """Pyro model for metagenotype counts with a Dirichlet prior on ``rho``.

    Sites are declared in this order:

    * ``_gamma`` / ``gamma`` and ``delta`` inside the "strain" (dim -2) and
      "position" (dim -1) plates; re-exported as the deterministic sites
      "genotypes" and "missingness".
    * ``rho`` drawn directly from a Dirichlet whose concentration is
      ``rho_hyper`` repeated ``s`` times; re-exported as "metacommunity".
    * Global ``alpha_hyper_mean``, ``m_hyper_r_mean`` and ``m_hyper_r_scale``.
    * ``pi``, ``epsilon``, ``alpha``, ``m_hyper_r`` and ``mu`` inside the
      "sample" plate; ``pi`` is re-exported as "communities".
    * ``nu``, ``m``, ``p_noerr``, ``p`` and ``y``, and finally the
      deterministic "metagenotypes" site stacking ``y`` and ``m - y``.

    Parameters
    ----------
    n, g, s
        Sizes of the "sample", "position" and "strain" plates respectively.
    a
        Not referenced anywhere in the body.
    gamma_hyper, delta_hyper_r, delta_hyper_temp, rho_hyper, pi_hyper,
    alpha_hyper_hyper_mean, alpha_hyper_hyper_scale, alpha_hyper_scale,
    mu_hyper_mean, mu_hyper_scale, epsilon_hyper_mode, epsilon_hyper_spread
        Hyperparameters of the priors below.  Several are passed directly to
        ``torch.log``, so they are assumed to arrive as tensors -- TODO confirm
        against ``sfacts.model``.
    _unit
        Used as a multiplicative template (``_unit * 0.0``, ``_unit * 10.``),
        repeated with ``_unit.repeat(s)``, and used as both Beta
        concentrations; presumably a tensor equal to one on the target
        dtype/device -- TODO confirm.

    The function has no ``return`` statement; results are exposed only through
    the Pyro sites it declares.
    """
    with pyro.plate("position", g, dim=-1):
        with pyro.plate("strain", s, dim=-2):
            # Latent draw on the unit interval, then reshaped by the power
            # transformation with both exponents set to 1 / gamma_hyper.
            _gamma = pyro.sample("_gamma", dist.Beta(_unit, _unit))
            gamma = pyro.deterministic(
                "gamma",
                unit_interval_power_transformation(
                    _gamma, 1 / gamma_hyper, 1 / gamma_hyper
                ),
            )
            # Position presence/absence
            delta = pyro.sample(
                "delta",
                dist.RelaxedBernoulli(
                    temperature=delta_hyper_temp, probs=delta_hyper_r
                ),
            )
    # Same values under the names used in the description list above.
    pyro.deterministic("genotypes", gamma)
    pyro.deterministic("missingness", delta)

    # Meta-community composition
    # Symmetric Dirichlet: every one of the `s` concentrations equals rho_hyper
    # (given that `_unit` is one -- see the docstring).
    rho = pyro.sample("rho", dist.Dirichlet(_unit.repeat(s) * rho_hyper))
    pyro.deterministic("metacommunity", rho)

    # Global hyper-priors for the per-sample `alpha` and `m_hyper_r` draws.
    alpha_hyper_mean = pyro.sample(
        "alpha_hyper_mean",
        dist.LogNormal(
            loc=torch.log(alpha_hyper_hyper_mean),
            scale=alpha_hyper_hyper_scale,
        ),
    )
    m_hyper_r_mean = pyro.sample("m_hyper_r_mean", dist.LogNormal(loc=_unit * 0.0, scale=_unit * 10.))
    m_hyper_r_scale = pyro.sample(
        "m_hyper_r_scale", dist.LogNormal(loc=_unit * 0.0, scale=_unit * 10.)
    )

    with pyro.plate("sample", n, dim=-1):
        # Community composition
        # (argument validation is explicitly switched off for this Dirichlet)
        pi = pyro.sample("pi", dist.Dirichlet(pi_hyper * rho, validate_args=False))
        # Sequencing error
        epsilon = pyro.sample(
            "epsilon",
            dist.Beta(epsilon_hyper_spread, epsilon_hyper_spread / epsilon_hyper_mode),
        ).unsqueeze(-1)
        # `epsilon`, `alpha` and `m_hyper_r` each get a trailing singleton
        # axis so they broadcast against the per-position quantities below.
        alpha = pyro.sample(
            "alpha",
            dist.LogNormal(loc=torch.log(alpha_hyper_mean), scale=alpha_hyper_scale),
        ).unsqueeze(-1)
        m_hyper_r = pyro.sample(
            "m_hyper_r",
            dist.LogNormal(loc=torch.log(m_hyper_r_mean), scale=m_hyper_r_scale),
        ).unsqueeze(-1)
        # Sample coverage
        mu = pyro.sample(
            "mu",
            dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale),
        )
    pyro.deterministic("communities", pi)

    # Depth at each position
    nu = pyro.deterministic("nu", pi @ delta)
    # TODO: Consider using pyro.distributions.GammaPoisson parameterization?
    # `mu` is reshaped to a column so each sample's value scales its row of `nu`.
    m = pyro.sample(
        "m",
        NegativeBinomialReparam(nu * mu.reshape((-1, 1)), m_hyper_r).to_event(),
    )

    # Expected fractions of each allele at each position
    p_noerr = pyro.deterministic("p_noerr", pi @ (gamma * delta) / nu)
    # Blend `p_noerr` with its complement, weighting the complement by epsilon / 2.
    p = pyro.deterministic(
        "p", (1 - epsilon / 2) * (p_noerr) + (epsilon / 2) * (1 - p_noerr)
    )

    # Observation
    y = pyro.sample(
        "y",
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ).to_event(),
    )
    # Stack `y` and `m - y` along a new last axis.  The local name is never
    # read afterwards; only the deterministic site is used.
    metagenotypes = pyro.deterministic("metagenotypes", torch.stack([y, m - y], dim=-1))


### estimation.py

In [None]:
%%writefile sfacts/estimation.py
import sfacts as sf
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import non_negative_factorization

import xarray as xr

# from sklearn.decomposition import non_negative_factorization
# from sfacts.genotype import genotype_pdist, adjust_genotype_by_missing
from sfacts.pyro_util import all_torch
import pandas as pd
import numpy as np

# import scipy as sp
# from scipy.spatial.distance import squareform
import pyro

# import pyro.distributions as dist
import torch
from tqdm import tqdm
from sfacts.logging_util import info


def nmf_approximation(
    world,
    s,
    regularization="both",
    alpha=1.0,
    l1_ratio=1.0,
    tol=1e-4,
    max_iter=int(1e4),
    random_state=None,
    init="random",
    **kwargs,
):
    """Approximate communities and genotypes by non-negative matrix factorization.

    The metagenotype counts of ``world`` are laid out with one column per
    sample and factorized into ``s`` components.  The two factors are then
    rescaled and normalized so that genotype values sum to one over alleles
    and community values sum to one over strains.

    Parameters
    ----------
    world
        Object exposing ``metagenotypes`` (with ``to_series()`` and ``data``).
    s
        Number of components (strains) requested from the factorization.
    regularization
        Accepted but not forwarded to ``non_negative_factorization``.
        NOTE(review): confirm whether this should be passed through.
    alpha, l1_ratio, tol, max_iter, random_state, init, **kwargs
        Forwarded unchanged to ``non_negative_factorization``.

    Returns
    -------
    sf.data.World
        Built from a dataset with ``communities`` (sample x strain),
        ``genotypes`` (strain x position, "alt" allele only) and the input
        ``metagenotypes`` data.
    """
    # One row per non-sample index level, one column per sample.
    wide = world.metagenotypes.to_series().unstack("sample")

    row_factor, col_factor, _n_iter = non_negative_factorization(
        wide.values,
        n_components=s,
        alpha=alpha,
        l1_ratio=l1_ratio,
        tol=tol,
        max_iter=max_iter,
        random_state=random_state,
        init=init,
        **kwargs,
    )

    # Label each factor with the coordinates of the input table.
    community_frame = pd.DataFrame(col_factor, columns=wide.columns)
    community_raw = community_frame.rename_axis(index="strain").stack().to_xarray()
    genotype_frame = pd.DataFrame(row_factor, index=wide.index)
    genotype_raw = genotype_frame.rename_axis(columns="strain").stack().to_xarray()

    # Rebalance: give every strain a mean (over positions) allele total of 1,
    # and move that scale onto the community factor.
    strain_scale = genotype_raw.sum("allele").mean("position")
    genotype_scaled = genotype_raw / strain_scale
    community_scaled = community_raw * strain_scale

    # Normalize to fractions; undefined allele fractions become 0.5.
    genotype_frac = (genotype_scaled / genotype_scaled.sum("allele")).fillna(0.5)
    community_frac = community_scaled / community_scaled.sum("strain")

    dataset = xr.Dataset(
        {
            "communities": community_frac.transpose("sample", "strain"),
            "genotypes": genotype_frac.sel(allele="alt").transpose(
                "strain", "position"
            ),
            "metagenotypes": world.metagenotypes.data,
        }
    )
    return sf.data.World(dataset)


def estimate_parameters(
    model,
    dtype=torch.float32,
    device="cpu",
    initialize_params=None,
    jit=True,
    maxiter=10000,
    lagA=20,
    lagB=100,
    opt=None,
    quiet=False,
    seed=None,
):
    """Fit ``model`` by SVI with an ``AutoLaplaceApproximation`` guide.

    Parameters
    ----------
    model
        Callable Pyro model that also provides ``format_world``.
    dtype, device
        Used to convert ``initialize_params`` via ``all_torch``; ``device``
        additionally decides whether the CUDA cache is emptied at the end.
    initialize_params
        Optional mapping of site name to initial value for the guide.
    jit
        Use ``JitTrace_ELBO`` when true, ``Trace_ELBO`` otherwise.
    maxiter
        Maximum number of SVI steps.
    lagA, lagB
        Window lengths (in steps) for the two trailing average improvements
        used by the stopping rule.
    opt
        Pyro optimizer.  When ``None`` a fresh
        ``Adamax({"lr": 1e-2}, {"clip_norm": 100})`` is created for this call,
        so that no optimizer state is shared between separate fits.
    quiet
        Suppress the progress bar and log messages.
    seed
        Passed to ``sf.pyro_util.set_random_seed``.

    Returns
    -------
    tuple
        ``(model.format_world(est), history)`` where ``history`` is the list
        of values returned by ``svi.step()``.

    Raises
    ------
    RuntimeError
        If a step returns NaN.
    """
    if initialize_params is None:
        initialize_params = {}

    # Build the default optimizer per call rather than once at definition
    # time; a default instance in the signature would be reused by every fit.
    if opt is None:
        opt = pyro.optim.Adamax({"lr": 1e-2}, {"clip_norm": 100})

    if jit:
        loss = pyro.infer.JitTrace_ELBO()
    else:
        loss = pyro.infer.Trace_ELBO()

    sf.pyro_util.set_random_seed(seed, warn=(not quiet))

    _guide = pyro.infer.autoguide.AutoLaplaceApproximation(
        model,
        init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(
            values=all_torch(**initialize_params, dtype=dtype, device=device)
        ),
    )
    svi = pyro.infer.SVI(model, _guide, opt, loss=loss)
    pyro.clear_param_store()

    # Both look-backs index `history` from the end, so wait until the longer
    # of the two windows is available.
    min_steps = max(lagA, lagB)

    history = []
    pbar = tqdm(range(maxiter), disable=quiet)
    try:
        for i in pbar:
            elbo = svi.step()

            if np.isnan(elbo):
                pbar.close()
                raise RuntimeError("ELBO NaN?")

            # Fit tracking
            history.append(elbo)

            # Reporting/Breaking: only checked every tenth step.
            if i % 10 != 0 or i <= min_steps:
                continue

            delta = history[-2] - history[-1]
            delta_lagA = (history[-lagA] - history[-1]) / lagA
            delta_lagB = (history[-lagB] - history[-1]) / lagB
            pbar.set_postfix(
                {
                    "ELBO": history[-1],
                    "delta": delta,
                    f"lag{lagA}": delta_lagA,
                    f"lag{lagB}": delta_lagB,
                }
            )
            # Stop once neither window shows a decrease in the step value.
            if (delta_lagA <= 0) and (delta_lagB <= 0):
                pbar.close()
                break
    except KeyboardInterrupt:
        # A manual interrupt ends fitting early; the current state is still
        # summarized and returned below.
        pbar.close()
        info("Interrupted", quiet=quiet)

    est = pyro.infer.Predictive(model, guide=_guide, num_samples=1)()
    # Collapse the single-draw leading axis and drop singleton dimensions.
    est = {k: est[k].detach().cpu().numpy().mean(0).squeeze() for k in est.keys()}

    if device.startswith("cuda"):
        torch.cuda.empty_cache()

    return model.format_world(est), history


def strain_cluster(world, thresh, linkage="complete", pdist_func=None):
    """Cluster strains by agglomerative clustering on a precomputed distance matrix.

    Parameters
    ----------
    world
        Object exposing ``strain`` and, for the default ``pdist_func``,
        ``genotypes.pdist()``.
    thresh
        Distance threshold passed as ``distance_threshold``.
    linkage
        Linkage criterion passed to ``AgglomerativeClustering``.  Previously
        this argument was ignored and "complete" was always used.
    pdist_func
        Callable mapping ``world`` to the distance matrix that is fitted.
        Defaults to ``world.genotypes.pdist()``.

    Returns
    -------
    pandas.Series
        Cluster label for each strain, indexed by ``world.strain``.
    """
    if pdist_func is None:
        pdist_func = lambda w: w.genotypes.pdist()

    clusterer = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=thresh,
        linkage=linkage,
        affinity="precomputed",
    )
    labels = clusterer.fit(pdist_func(world)).labels_
    return pd.Series(labels, index=world.strain)


def communities_aggregated_by_strain_cluster(world, thresh, **kwargs):
    """Sum community values over the strain clusters found by ``strain_cluster``.

    ``thresh`` and any extra keyword arguments are forwarded to
    ``strain_cluster``.  The community table's columns are grouped by the
    resulting cluster labels and summed, and the column axis is named
    "strain" again before the result is wrapped in ``sf.data.Communities``.
    """
    cluster_labels = strain_cluster(world, thresh, **kwargs)
    per_strain = world.communities.to_pandas()
    per_cluster = per_strain.groupby(cluster_labels, axis="columns").sum()
    long_form = per_cluster.rename_axis(columns="strain").stack()
    return sf.data.Communities(long_form.to_xarray())


### evaluation.py

### workflow.py

In [None]:
%%writefile sfacts/workflow.py
import sfacts as sf
import pyro
import time
import torch


def _chunk_start_end_iterator(total, per):
    """Yield ``(start, end)`` index pairs covering ``range(total)`` in chunks.

    Every chunk has length ``per`` except possibly the last, which is
    truncated at ``total``.  Nothing is yielded when ``total`` is zero, and a
    single ``(0, total)`` pair is yielded when ``total < per``; the earlier
    implementation raised ``UnboundLocalError`` in both of these cases.
    """
    for start in range(0, total, per):
        yield start, min(start + per, total)


def fit_metagenotypes_simple(
    structure,
    metagenotypes,
    nstrain,
    hyperparameters=None,
    condition_on=None,
    device='cpu',
    dtype=torch.float32,
    quiet=False,
    estimation_kwargs=None,
):
    """Fit a single model to ``metagenotypes`` in one estimation pass.

    Parameters
    ----------
    structure
        Model structure handed to ``sf.model.ParameterizedModel``.
    metagenotypes
        Data providing ``sample``, ``position`` and ``allele`` coordinates,
        ``sizes`` and ``to_counts_and_totals()``.
    nstrain
        Number of strains; the strain coordinate is ``range(nstrain)``.
    hyperparameters, condition_on, device, dtype
        Forwarded to ``ParameterizedModel`` (``condition_on`` as ``data``).
    quiet
        Suppress log messages and the progress bar.
    estimation_kwargs
        Extra keyword arguments for ``sf.estimation.estimate_parameters``.
        ``None`` (the default) means no extra arguments.

    Returns
    -------
    tuple
        ``(est, history)`` as returned by ``estimate_parameters``.
    """
    # The default of None cannot be unpacked with `**`; use an empty mapping.
    if estimation_kwargs is None:
        estimation_kwargs = {}

    sf.logging_util.info(f"START: Fitting data with shape {metagenotypes.sizes}.", quiet=quiet)
    model = sf.model.ParameterizedModel(
        structure,
        coords=dict(
            sample=metagenotypes.sample.values,
            position=metagenotypes.position.values,
            allele=metagenotypes.allele.values,
            strain=range(nstrain),
        ),
        hyperparameters=hyperparameters,
        data=condition_on,
        device=device,
        dtype=dtype,
    ).condition(
        **metagenotypes.to_counts_and_totals()
    )
    start_time = time.time()
    est, history = sf.estimation.estimate_parameters(
        model,
        quiet=quiet,
        **estimation_kwargs,
    )
    end_time = time.time()
    delta_time = end_time - start_time
    sf.logging_util.info(f"END: Fit in {delta_time} seconds.", quiet=quiet)
    return est, history

def fit_metagenotypes_then_refit_genotypes(
    structure,
    metagenotypes,
    nstrain,
    hyperparameters=None,
    stage2_hyperparameters=None,
    condition_on=None,
    device='cpu',
    dtype=torch.float32,
    quiet=False,
    estimation_kwargs=None
):
    """Fit the model, then refit it with earlier estimates held fixed.

    Three estimation passes are run on the same data:

    1. an unconstrained fit;
    2. a refit conditioned on the first fit's ``pi``, ``mu``, ``alpha``,
       ``epsilon`` and ``m_hyper_r``;
    3. a refit additionally conditioned on the second fit's ``delta`` and
       using ``stage2_hyperparameters``.

    Parameters
    ----------
    structure, metagenotypes, nstrain, hyperparameters, condition_on, device, dtype
        Used to build the ``sf.model.ParameterizedModel``.
    stage2_hyperparameters
        Hyperparameter overrides for the third pass; ``None`` means none.
    quiet
        Suppress log messages and the progress bar.
    estimation_kwargs
        Extra keyword arguments for ``sf.estimation.estimate_parameters``.
        ``None`` (the default) means no extra arguments.

    Returns
    -------
    tuple
        ``((est0, est1), (history0, history1))``.

        NOTE(review): the third pass (``est2`` / ``history2``) is computed but
        not part of the return value.  The return shape is left unchanged for
        existing callers -- confirm whether ``est2`` should be returned.
    """
    if stage2_hyperparameters is None:
        stage2_hyperparameters = {}
    # The default of None cannot be unpacked with `**`; use an empty mapping.
    if estimation_kwargs is None:
        estimation_kwargs = {}

    def _estimate_parameters(model):
        return sf.estimation.estimate_parameters(
            model, quiet=quiet, **estimation_kwargs
        )

    sf.logging_util.info(f"START: Fitting data with shape {metagenotypes.sizes}.", quiet=quiet)
    model = sf.model.ParameterizedModel(
        structure,
        coords=dict(
            sample=metagenotypes.sample.values,
            position=metagenotypes.position.values,
            allele=metagenotypes.allele.values,
            strain=range(nstrain),
        ),
        hyperparameters=hyperparameters,
        data=condition_on,
        device=device,
        dtype=dtype,
    ).condition(**metagenotypes.to_counts_and_totals())

    start_time = time.time()
    est0, history0 = _estimate_parameters(model)
    # This message previously ignored `quiet`.
    sf.logging_util.info("Finished initial fitting.", quiet=quiet)
    sf.logging_util.info("Refitting missingness.", quiet=quiet)
    est1, history1 = _estimate_parameters(
        model
        .condition(
            pi=est0.data.communities.values,
            mu=est0.data.mu.values,
            alpha=est0.data.alpha.values,
            epsilon=est0.data.epsilon.values,
            m_hyper_r=est0.data.m_hyper_r.values,
        )
    )
    sf.logging_util.info("Refitting genotypes.", quiet=quiet)
    est2, history2 = _estimate_parameters(
        model
        .condition(
            delta=est1.data.missingness.values,
            pi=est1.data.communities.values,
            mu=est1.data.mu.values,
            alpha=est1.data.alpha.values,
            epsilon=est1.data.epsilon.values,
            m_hyper_r=est1.data.m_hyper_r.values,
        )
        .with_hyperparameters(**stage2_hyperparameters)
    )
    end_time = time.time()
    delta_time = end_time - start_time
    sf.logging_util.info(f"END: Fit in {delta_time} seconds.", quiet=quiet)
    return (est0, est1), (history0, history1)


def fit_metagenotype_subsample_collapse_then_iteratively_refit_full_genotypes(
    structure,
    metagenotypes,
    nstrain,
    nposition,
    thresh,
    hyperparameters=None,
    stage2_hyperparameters=None,
    condition_on=None,
    device='cpu',
    dtype=torch.float32,
    quiet=False,
    estimation_kwargs=None
):
    """Fit on a position subsample, collapse strains, then refit all positions in chunks.

    Steps:

    1. Fit the model on ``nposition`` randomly sampled positions.
    2. Refit on the same subsample with ``stage2_hyperparameters``, holding
       the first fit's ``delta``, ``pi``, ``mu``, ``alpha``, ``epsilon`` and
       ``m_hyper_r`` fixed.
    3. Merge strains with
       ``sf.estimation.communities_aggregated_by_strain_cluster`` at
       distance threshold ``thresh``.
    4. For each consecutive chunk of ``nposition`` positions, fit twice with
       the merged communities held fixed (the second time also fixing
       ``delta`` and applying ``stage2_hyperparameters``).
    5. Concatenate the per-chunk results along "position".

    Parameters
    ----------
    structure, metagenotypes, nstrain, hyperparameters, condition_on, device, dtype
        Used to build the ``sf.model.ParameterizedModel``.
    nposition
        Subsample size and chunk length, in positions.
    thresh
        Distance threshold for collapsing strains.
    stage2_hyperparameters
        Hyperparameter overrides for the refits; ``None`` means none.
    quiet
        Suppress log messages and progress bars.
    estimation_kwargs
        Extra keyword arguments for ``sf.estimation.estimate_parameters``.
        ``None`` (the default) means no extra arguments.

    Returns
    -------
    The object returned by ``sf.data.World.concat`` over all chunks.
    """
    if stage2_hyperparameters is None:
        stage2_hyperparameters = {}
    # The default of None cannot be unpacked with `**`; use an empty mapping.
    if estimation_kwargs is None:
        estimation_kwargs = {}

    def _estimate_parameters(model):
        return sf.estimation.estimate_parameters(
            model, quiet=quiet, **estimation_kwargs,
        )

    def _info(*args, **kwargs):
        return sf.logging_util.info(*args, quiet=quiet, **kwargs)

    _info(f"START: Fitting data with shape {metagenotypes.sizes}.")
    _info(f"Fitting strain compositions using {nposition} randomly sampled positions.")
    metagenotypes_ss = metagenotypes.random_sample(nposition, 'position')
    model = sf.model.ParameterizedModel(
        structure,
        coords=dict(
            sample=metagenotypes.sample.values,
            position=metagenotypes_ss.position.values,
            allele=metagenotypes.allele.values,
            strain=range(nstrain),
        ),
        hyperparameters=hyperparameters,
        data=condition_on,
        device=device,
        dtype=dtype,
    )

    start_time = time.time()
    est_curr, _ = _estimate_parameters(
        model
        .condition(**metagenotypes_ss.to_counts_and_totals())
    )
    _info("Finished initial fitting.")
    _info(f"Refitting genotypes with {stage2_hyperparameters}.")
    est_curr, _ = _estimate_parameters(
        model
        .with_hyperparameters(**stage2_hyperparameters)
        .condition(
            delta=est_curr.data.missingness.values,
            pi=est_curr.data.communities.values,
            mu=est_curr.data.mu.values,
            alpha=est_curr.data.alpha.values,
            epsilon=est_curr.data.epsilon.values,
            m_hyper_r=est_curr.data.m_hyper_r.values,
        )
        .condition(**metagenotypes_ss.to_counts_and_totals()),
    )
    _info(f"Collapsing {nstrain} initial strains.")
    agg_communities = sf.estimation.communities_aggregated_by_strain_cluster(
        est_curr, thresh=thresh, pdist_func=lambda w: w.genotypes.pdist(quiet=quiet),
    )
    _info(f"{agg_communities.sizes['strain']} strains after collapsing.")

    _info("Iteratively refitting missingness/genotypes.")
    # Per-chunk results keyed by the chunk's starting position index.
    chunks = {}
    for position_start, position_end in _chunk_start_end_iterator(
        metagenotypes.sizes['position'],
        nposition,
    ):
        _info(f"Fitting bin ({position_start}, {position_end}).")
        metagenotypes_chunk = metagenotypes.mlift('isel', position=slice(position_start, position_end))
        # First pass on this chunk: merged communities are held fixed; the
        # remaining conditioned values come from the most recent fit.
        est_curr, _ = _estimate_parameters(
            model
            .with_amended_coords(
                position=metagenotypes_chunk.position.values,
                strain=agg_communities.strain.values,
            )
            .condition(
                pi=agg_communities.values,
                mu=est_curr.data.mu.values,
                alpha=est_curr.data.alpha.values,
                epsilon=est_curr.data.epsilon.values,
                m_hyper_r=est_curr.data.m_hyper_r.values,
            )
            .condition(**metagenotypes_chunk.to_counts_and_totals()),
        )
        # Second pass: also hold `delta` fixed and apply the stage-2
        # hyperparameters.
        est_curr, _ = _estimate_parameters(
            model
            .with_amended_coords(
                position=metagenotypes_chunk.position.values,
                strain=agg_communities.strain.values,
            )
            .with_hyperparameters(**stage2_hyperparameters)
            .condition(
                delta=est_curr.data.missingness.values,
                pi=est_curr.data.communities.values,
                mu=est_curr.data.mu.values,
                alpha=est_curr.data.alpha.values,
                epsilon=est_curr.data.epsilon.values,
                m_hyper_r=est_curr.data.m_hyper_r.values,
            )
            .condition(**metagenotypes_chunk.to_counts_and_totals()),
        )
        chunks[position_start] = est_curr
    est_curr = sf.data.World.concat(chunks, dim='position', rename_coords=False)
    end_time = time.time()
    delta_time = end_time - start_time
    _info(f"END: Fit in {delta_time} seconds.")
    return est_curr


# def fit_then_relax_genotypes_and_collapse(
#     model,
#     world,
#     thresh,
#     initialize_params=None,
#     stage2_hyperparameters=None,
#     quiet=False,
#     **kwargs,
# ):
#     if stage2_hyperparameters is None:
#         stage2_hyperparameters = {}

#     est0, history0 = fit_simple(
#         model,
#         world,
#         initialize_params=initialize_params,
#         quiet=quiet,
#         **kwargs,
#     )
#     est1, history1 = fit_simple(
#         model.with_hyperparameters(**stage2_hyperparameters).condition(
#             pi=est0.data.communities.values,
#             mu=est0.data.mu.values,
#             alpha=est0.data.alpha.values,
#             epsilon=est0.data.epsilon.values,
#             m_hyper_r=est0.data.m_hyper_r.values,
#         ),
#         world,
#         quiet=quiet,
#         **kwargs,
#     )
#     agg_communities = sf.estimation.communities_aggregated_by_strain_cluster(
#         est1, thresh=thresh
#     )
#     est2, history2 = fit_simple(
#         model.with_amended_coords(strain=agg_communities.strain).condition(
#             pi=agg_communities.values,
#             mu=est1.data.mu.values,
#             alpha=est1.data.alpha.values,
#             epsilon=est1.data.epsilon.values,
#             m_hyper_r=est1.data.m_hyper_r.values,
#         ),
#         world,
#         quiet=quiet,
#         **kwargs,
#     )
#     est3, history3 = fit_simple(
#         model.with_amended_coords(strain=agg_communities.strain)
#         .with_hyperparameters(**stage2_hyperparameters)
#         .condition(
#             delta=est2.missingness.values,
#             pi=est2.communities.values,
#             mu=est2.data.mu.values,
#             alpha=est2.data.alpha.values,
#             epsilon=est2.data.epsilon.values,
#             m_hyper_r=est2.data.m_hyper_r.values,
#         ),
#         world,
#         quiet,
#         **kwargs,
#     )
#     return (est0, est1, est2, est3), (history0, history1, history2, history3)


# def fit_subsample_then_refit_relaxed_genotypes(
#     model,
#     world,
#     npositions,
#     stage2_hyperparameters,
#     initialize_params=None,
#     quiet=False,
#     **kwargs,
# ):
#     est0, history0 = fit_simple(model, world, initialize_params=initialize_params, quiet=quiet, **kwargs)
#     est1, history1 = fit_simple(
#         model.with_hyperparameters(**stage2_hyperparameters).condition(
#             # FIXME: Drop conditioning on delta for consistency with 3-stage workflow?
#             delta=est0.data.missingness.values,
#             pi=est0.data.communities.values,
#             mu=est0.data.mu.values,
#             alpha=est0.data.alpha.values,
#             epsilon=est0.data.epsilon.values,
#             m_hyper_r=est0.data.m_hyper_r.values,
#         ),
#         world,
#         quiet=quiet,
#         **kwargs,
#     )
#     return (est0, est1), (history0, history1)


# def simulation_benchmark(
#     nsample,
#     nposition,
#     sim_nstrain,
#     fit_nstrain,
#     sim_model,
#     fit_model=None,
#     sim_data=None,
#     sim_hyperparameters=None,
#     sim_seed=None,
#     fit_data=None,
#     fit_hyperparameters=None,
#     fit_seed=None,
#     opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
#     **fit_kwargs,
# ):
#     if fit_model is None:
#         fit_model = sim_model

#     sim = sf.model.ParameterizedModel(
#         sim_model,
#         coords=dict(
#             sample=nsample,
#             position=nposition,
#             allele=["alt", "ref"],
#             strain=sim_nstrain,
#         ),
#         data=sim_data,
#         hyperparameters=sim_hyperparameters,
#     ).simulate_world(seed=sim_seed)

#     start_time = time.time()
#     est, history, *_ = sf.estimation.estimate_parameters(
#         sf.model.ParameterizedModel(
#             fit_model,
#             coords=dict(
#                 sample=nsample,
#                 position=nposition,
#                 allele=["alt", "ref"],
#                 strain=fit_nstrain,
#             ),
#             data=fit_data,
#             hyperparameters=fit_hyperparameters,
#         ).condition(
#             **sim.metagenotypes.to_counts_and_totals(),
#         ),
#         opt=opt,
#         seed=fit_seed,
#         **fit_kwargs,
#     )
#     end_time = time.time()

#     return (
#         sf.evaluation.weighted_genotype_error(sim, est),
#         sf.evaluation.community_error(sim, est),
#         end_time - start_time,
#     )