## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf

In [None]:
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from tqdm import tqdm
import xarray as xr
import warnings
from torch.jit import TracerWarning

## Library

In [None]:
!ls sfacts

### `__init__.py`

In [None]:
%%writefile sfacts/__init__.py
from sfacts import (
    logging_util,
    pyro_util,
    pandas_util,
    model,
    model_zoo,
    plot,
    estimation,
    evaluation,
#     workflow,
    data,
#     app,
)

### pyro_util.py

In [None]:
%%writefile sfacts/pyro_util.py
import pyro
import pyro.distributions as dist
import torch
from sfacts.logging_util import info
import warnings


def as_torch(x, dtype=None, device=None):
    """Return a new tensor holding the values of *x*, cast to *dtype* / *device*.

    Existing tensors are detached and copied (the approach recommended by the
    ``torch.tensor`` documentation) instead of being round-tripped through
    NumPy, which fails for tensors that require grad or live on a non-CPU
    device. Anything else (scalars, lists, NumPy arrays) goes through
    ``torch.tensor``.

    A *dtype* of ``None`` keeps the inferred / existing dtype. A *device* of
    ``None`` keeps an existing tensor on its current device.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().clone().to(dtype=dtype, device=device)
    return torch.tensor(x, dtype=dtype, device=device)


def all_torch(dtype=None, device=None, **kwargs):
    """Convert every keyword argument with ``as_torch``.

    Returns a dict with the same keys (in the same order), each value cast
    to a tensor with the requested *dtype* and *device*.
    """
    converted = {}
    for name, value in kwargs.items():
        converted[name] = as_torch(value, dtype=dtype, device=device)
    return converted


def shape_info(model, *args, **kwargs):
    """Trace one run of ``model(*args, **kwargs)`` and log a table of site shapes.

    Log-probabilities are computed first so that the table also includes
    their shapes.
    """
    model_trace = pyro.poutine.trace(model).get_trace(*args, **kwargs)
    model_trace.compute_log_prob()
    shapes_table = model_trace.format_shapes()
    info(shapes_table)

def set_random_seed(seed, warn=True):
    """Seed Pyro's RNG via ``pyro.set_rng_seed``; a ``None`` seed does nothing.

    *warn* is accepted for interface compatibility but is currently unused.
    """
    if seed is None:
        return
    pyro.set_rng_seed(seed)

### data.py

In [None]:
%%writefile sfacts/data.py
from sfacts.logging_util import info
from sfacts.pandas_util import idxwhere
import xarray as xr
import numpy as np
from tqdm import tqdm
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage
import scipy as sp
from functools import partial


def _on_2_simplex(d):
    """Return truthy when every value of *d* lies in the closed interval [0, 1]."""
    lowest = d.min()
    highest = d.max()
    return (lowest >= 0) and (highest <= 1.0)


def _strictly_positive(d):
    """Return truthy when the smallest value of *d* is greater than zero."""
    smallest = d.min()
    return smallest > 0


def _positive_counts(d):
    """Return truthy when every value of *d* is a non-negative whole number.

    Both conditions are checked because integrality alone would accept
    negative values, which are not valid counts.
    """
    return ((d >= 0) & (d.astype(int) == d)).all()


class WrappedDataArrayMixin():
    """Base class for thin wrappers around a single labelled ``xr.DataArray``.

    Subclasses set:

    * ``dims`` -- the exact, ordered dims the wrapped array must have;
    * ``constraints`` -- name -> predicate on the DataArray, checked by
      ``validate_constraints``;
    * ``variable_name`` -- the key used when the array is placed into a
      ``World`` dataset (``None`` falls back to the array's own name).
    """
    constraints = {}

    # The following are all white-listed and
    # transparently passed through to self.data, but with
    # different semantics for the return value: names in
    # safe_unwrapped return the raw xarray result, while names
    # in safe_lifted return their result re-wrapped in this class.
    dims = ()
    safe_unwrapped = ['shape', 'sizes', 'to_pandas', 'to_dataframe', 'min', 'max', 'sum', 'mean', 'median', 'values', 'pipe', 'to_series']
    safe_lifted = ['isel', 'sel']
    variable_name = None

    @classmethod
    def _complete_coords(cls, shape, coords=None):
        """Return a new coords dict with ``range(size)`` for each dim lacking labels.

        Dims that are absent from *coords*, or mapped to ``None``, receive
        integer labels. Every other entry is passed through unchanged. The
        caller's mapping is never mutated.
        """
        sizes = dict(zip(cls.dims, shape))
        completed = dict(coords) if coords is not None else {}
        for dim in cls.dims:
            if completed.get(dim) is None:
                completed[dim] = range(sizes[dim])
        return completed

    @classmethod
    def from_ndarray(cls, x, coords=None):
        """Wrap array *x*, whose axes are in ``cls.dims`` order.

        Any dim without labels in *coords* is labelled ``0..size-1``.
        *coords* itself is not modified.
        """
        data = xr.DataArray(
            x,
            dims=cls.dims,
            coords=cls._complete_coords(x.shape, coords),
        )
        return cls(data)

    @classmethod
    def stack(cls, mapping, dim, prefix=False, validate=True):
        """Concatenate the wrapped 2D arrays in *mapping* along *dim*.

        With *prefix*, labels along *dim* become ``f"{key}_{label}"`` so that
        entries taken from different mapping keys stay distinct.
        """
        if not len(cls.dims) == 2:
            raise NotImplementedError("Generic stacking has only been implemented for 2D wrapped DataArrays")
        axis = cls.dims.index(dim)
        data = []
        for k, d in mapping.items():
            if prefix:
                d = d.to_pandas().rename(lambda s: f"{k}_{s}", axis=axis).stack().to_xarray()
            else:
                d = d.data
            data.append(d)
        out = cls(xr.concat(data, dim=dim))
        if validate:
            out.validate_constraints()
        return out

    def __init__(self, data):
        self.data = data
        self.validate_fast()

    def __getattr__(self, name):
        # Only called for names not found normally; forwards white-listed names to self.data.
        if name in self.dims:
            return getattr(self.data, name)
        elif name in self.safe_unwrapped:
            return getattr(self.data, name)
        elif name in self.safe_lifted:
            return lambda *args, **kwargs: self.__class__(getattr(self.data, name)(*args, **kwargs))
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}' "
                                 f"and this name is not found in '{self.__class__.__name__}.dims', "
                                 f"'{self.__class__.__name__}.safe_unwrapped', "
                                 f"or '{self.__class__.__name__}.safe_lifted'. "
                                 f"Consider working with the '{self.__class__.__name__}.data' "
                                 f"xr.DataArray object directly.")

    def validate_fast(self):
        """Assert that the wrapped array has exactly ``dims``, in order."""
        assert len(self.data.shape) == len(self.dims)
        assert self.data.dims == self.dims

    def validate_constraints(self):
        """Check dims, then assert every predicate in ``constraints``."""
        self.validate_fast()
        for name in self.constraints:
            assert self.constraints[name](self.data), f"Failed constraint: {name}"

    def lift(self, func, *args, **kwargs):
        """Apply ``func(self.data, *args, **kwargs)`` and re-wrap the result in this class."""
        return self.__class__(self.data.pipe(func, *args, **kwargs))

    def __repr__(self):
        return f'{self.__class__.__name__}({self.data})'

    @classmethod
    def concat(cls, data, dim):
        """Concatenate wrapped arrays in mapping *data* along *dim*.

        Each label along *dim* is rewritten as ``f"{key}_{label}"``.
        """
        out_data = []
        new_coords = []
        for name in data:
            d = data[name].data
            out_data.append(d)
            new_coords.extend([f"{name}_{i}" for i in d[dim].values])
        out_data = xr.concat(out_data, dim)
        out_data[dim] = new_coords
        return cls(out_data)

    def to_world(self):
        """Return a ``World`` whose dataset holds this array under ``variable_name``.

        Naming by ``variable_name`` keeps ``World`` attribute lookups (which
        are keyed by it) consistent, and it also works for unnamed arrays.
        """
        return World(self.data.to_dataset(name=self.variable_name))


class Metagenotypes(WrappedDataArrayMixin):
    """Allele counts with dims ``(sample, position, allele)``."""
    dims = ('sample', 'position', 'allele')
    constraints = dict(
        positive_counts = _positive_counts
    )
    variable_name = 'metagenotypes'

    @classmethod
    def load(cls, filename_or_obj, validate=True):
        """Read counts with ``xr.open_dataarray``, renaming ``library_id`` to ``sample``.

        Singleton dims are squeezed away and the array is named
        ``'metagenotypes'``. Constraints are checked unless *validate* is False.
        """
        data = xr.open_dataarray(filename_or_obj).rename({'library_id': 'sample'}).squeeze(drop=True)
        data.name = 'metagenotypes'
        result = cls(data)
        if validate:
            result.validate_constraints()
        return result

    @classmethod
    def from_counts_and_totals(cls, y, m, coords=None):
        """Build from alt-allele counts *y* and total counts *m* (same shape).

        The allele axis is stacked last as ``[y, m - y]``. It is labelled
        ``['alt', 'ref']`` unless *coords* supplies an ``'allele'`` entry.
        *coords* is copied, never modified.
        """
        coords = {} if coords is None else dict(coords)
        coords.setdefault('allele', ['alt', 'ref'])
        x = np.stack([y, m - y], axis=-1)
        return cls.from_ndarray(x, coords=coords)

    def dump(self, path, validate=True):
        """Write counts to NetCDF at *path* as a zlib-compressed ``tally`` variable.

        Counts are stored in the smallest integer type (at least uint8) that
        holds the largest count. A fixed uint8 would silently wrap any count
        above 255.
        """
        if validate:
            self.validate_constraints()
        max_count = int(self.data.max())
        dtype = np.promote_types(np.uint8, np.min_scalar_type(max_count))
        self.data.astype(dtype).to_dataset(name="tally").to_netcdf(
            path,
            encoding=dict(tally=dict(zlib=True, complevel=6))
        )

    def select_variable_positions(self, incid_thresh, allele_thresh=0):
        """Keep positions where every allele exceeds *allele_thresh* counts
        in more than an *incid_thresh* fraction of samples."""
        # TODO: Consider using .lift() to do this.
        x = self.data
        minor_allele_incid = (x > allele_thresh).mean("sample").min("allele")
        variable_positions = idxwhere(minor_allele_incid.to_series() > incid_thresh)
        return self.__class__(x.sel(position=variable_positions))

    def select_samples_with_coverage(self, cvrg_thresh):
        """Keep samples with a nonzero total count at more than a *cvrg_thresh* fraction of positions."""
        # TODO: Consider using .lift() to do this.
        x = self.data
        covered_samples = (x.sum("allele") > 0).mean("position") > cvrg_thresh
        return self.__class__(x.sel(sample=covered_samples))

    def frequencies(self, pseudo=0.):
        "Convert metagenotype counts to a frequencies with optional pseudocount."
        return (self.data + pseudo) / (self.data.sum('allele') + pseudo * self.sizes['allele'])

    def dominant_allele_fraction(self, pseudo=0.):
        "Frequency of the most common allele at each sample/position, with optional pseudocount."
        return self.frequencies(pseudo=pseudo).max('allele')

    def to_genotype_estimates(self, pseudo=1.):
        """Treat each sample's alt-allele frequency as a strain genotype (``sample`` -> ``strain``)."""
        data = self.frequencies(pseudo=pseudo).sel(allele='alt').rename({'sample': 'strain'})
        return Genotypes(data)

    def to_counts_and_totals(self, binary_allele='alt'):
        """Return ``dict(y=counts of binary_allele, m=total counts)`` as NumPy arrays."""
        return dict(y=self.data.sel(allele=binary_allele).values, m=self.data.sum('allele').values)

    def pdist(self, dim='strain', **kwargs):
        """Pairwise distances between genotype estimates (see ``Genotypes.pdist``).

        Extra keyword arguments (e.g. ``pseudo``, ``quiet``) are forwarded.
        """
        return self.to_genotype_estimates().pdist(dim=dim, **kwargs)

    def linkage(self, dim='strain', **kwargs):
        """Hierarchical clustering of genotype estimates (see ``Genotypes.linkage``)."""
        return self.to_genotype_estimates().linkage(dim=dim, **kwargs)


class Genotypes(WrappedDataArrayMixin):
    """Per-strain, per-position genotype values in [0, 1], dims ``(strain, position)``."""
    dims = ('strain', 'position')
    constraints = dict(
        on_2_simplex = _on_2_simplex
    )
    variable_name = 'genotypes'

    def fuzz_missing(self, missingness, eps=1e-10):
        """Scale each genotype's logit by the matching *missingness* value.

        Values are clipped to [eps, 1 - eps] first so that the logit stays
        finite. A missingness value of 1 leaves the genotype unchanged, and
        values toward 0 pull it toward 0.5.
        """
        clip = partial(np.clip, a_min=eps, a_max=(1 - eps))
        return self.lift(lambda g, m: sp.special.expit(sp.special.logit(clip(g)) * clip(m)), m=missingness.data)

    # TODO: Move distance metrics to a new module?
    @staticmethod
    def _convert_to_sign_representation(p):
        "Alternative representation of binary genotype on a [-1, 1] interval."
        return p * 2 - 1

    @staticmethod
    def _genotype_sign_representation_dissimilarity(x, y, pseudo=0.):
        """Dissimilarity between 1D genotypes in sign representation, accounting for fuzziness.

        The dissimilarity is a weighted mean of squared half-differences with
        weights ``(x * y) ** 2``. Positions where either genotype is near 0
        (i.e. 0.5 in [0, 1] form) therefore contribute little. *pseudo* is
        added to the mean weight in the denominator.
        """
        dist = ((x - y) / 2) ** 2
        weight = (x * y) ** 2
        wmean_dist = ((weight * dist).mean()) / ((weight.mean() + pseudo))
        return wmean_dist

    @staticmethod
    def _genotype_dissimilarity(x, y, pseudo=0.):
        """Dissimilarity between two 1D genotypes given as values in [0, 1]."""
        to_sign = Genotypes._convert_to_sign_representation
        return Genotypes._genotype_sign_representation_dissimilarity(
            to_sign(x), to_sign(y), pseudo=pseudo
        )

    @staticmethod
    def _genotype_dissimilarity_cdmat(unwrapped_values, pseudo=0., quiet=True):
        """Condensed (upper-triangle, row-major) strain-by-strain dissimilarity vector.

        *unwrapped_values* is a 2D array with one strain per row. A progress
        bar is shown unless *quiet*.
        """
        g_sign = Genotypes._convert_to_sign_representation(unwrapped_values)
        s, _ = g_sign.shape
        cdmat = np.empty((s * (s - 1)) // 2)
        k = 0
        with tqdm(total=len(cdmat), disable=quiet) as pbar:
            for i in range(0, s - 1):
                for j in range(i + 1, s):
                    cdmat[k] = Genotypes._genotype_sign_representation_dissimilarity(g_sign[i], g_sign[j], pseudo=pseudo)
                    k = k + 1
                    pbar.update()
        return cdmat

    def pdist(self, dim='strain', pseudo=0., quiet=True):
        """Square pairwise-distance DataFrame over *dim*.

        For ``'strain'``, the fuzzy genotype dissimilarity is used. For
        ``'position'``, the cosine distance of the sign representation is
        used.

        Raises:
            ValueError: if *dim* is neither ``'strain'`` nor ``'position'``.
        """
        if dim == 'strain':
            cdmat = self._genotype_dissimilarity_cdmat(self.values, quiet=quiet, pseudo=pseudo)
        elif dim == 'position':
            cdmat = pdist(self._convert_to_sign_representation(self.values.T), metric='cosine')
        else:
            raise ValueError(f"dim must be 'strain' or 'position', got {dim!r}")
        index = getattr(self, dim)
        # Reboxing
        dmat = pd.DataFrame(squareform(cdmat), index=index, columns=index)
        return dmat

    def linkage(self, dim='strain', pseudo=0., quiet=True, method="complete", optimal_ordering=True, **kwargs):
        """SciPy hierarchical-clustering linkage over *dim*, computed from ``pdist``."""
        dmat = self.pdist(dim=dim, pseudo=pseudo, quiet=quiet)
        cdmat = squareform(dmat)
        return linkage(cdmat, method=method, optimal_ordering=optimal_ordering, **kwargs)

    @property
    def entropy(self):
        """Binary entropy in bits, summed over positions for each strain.

        ``scipy.special.entr`` defines 0 * log(0) as 0. Genotypes of exactly
        0 or 1 therefore contribute zero instead of NaN.
        """
        p = self.data
        q = 1 - p
        ent = (sp.special.entr(p) + sp.special.entr(q)) / np.log(2)
        return ent.sum("position").rename("entropy")


class Missingness(WrappedDataArrayMixin):
    """Per-strain, per-position missingness values, dims ``(strain, position)``.

    All values must lie in [0, 1].
    """
    variable_name = 'missingness'
    dims = ('strain', 'position')
    constraints = {'on_2_simplex': _on_2_simplex}
        
        
class Communities(WrappedDataArrayMixin):
    """Per-sample strain abundances, dims ``(sample, strain)``; each sample sums to 1."""
    dims = ('sample', 'strain')
    constraints = dict(
        # Abundances are floating point, so compare sums with a tolerance
        # instead of requiring exact equality with 1.0.
        strains_sum_to_1 = lambda d: np.allclose(d.sum('strain'), 1.0, rtol=1e-5, atol=1e-6)
    )
    variable_name = 'communities'

    def pdist(self, dim='strain', quiet=True):
        """Square pairwise-distance DataFrame over *dim*.

        Strains are compared by cosine distance across samples. Samples are
        compared by Bray-Curtis distance across strains. *quiet* is accepted
        for interface parity and is unused.

        Raises:
            ValueError: if *dim* is neither ``'strain'`` nor ``'sample'``.
        """
        if dim == 'strain':
            cdmat = pdist(self.values.T, metric='cosine')
        elif dim == 'sample':
            cdmat = pdist(self.values, metric='braycurtis')
        else:
            raise ValueError(f"dim must be 'strain' or 'sample', got {dim!r}")
        index = getattr(self, dim)
        # Reboxing
        dmat = pd.DataFrame(squareform(cdmat), index=index, columns=index)
        return dmat

    def linkage(self, dim='strain', quiet=True, method="average", optimal_ordering=True, **kwargs):
        """SciPy hierarchical-clustering linkage over *dim*, computed from ``pdist``."""
        dmat = self.pdist(dim=dim, quiet=quiet)
        cdmat = squareform(dmat)
        return linkage(cdmat, method=method, optimal_ordering=optimal_ordering, **kwargs)


class Overdispersion(WrappedDataArrayMixin):
    """Per-sample overdispersion values, dims ``(sample,)``; all must be > 0."""
    dims = ('sample',)
    constraints = dict(
        # Named for what is actually checked, so that failure messages are accurate.
        strictly_positive = _strictly_positive
    )
    variable_name = 'overdispersion'


class ErrorRate(WrappedDataArrayMixin):
    """Per-sample error rates, dims ``(sample,)``; all values must lie in [0, 1]."""
    variable_name = 'error_rate'
    dims = ('sample',)
    constraints = {'on_2_simplex': _on_2_simplex}
    

class World():
    """Wrapper around an ``xr.Dataset`` that holds several model variables.

    Only the dims in ``dims`` may appear in the dataset. Attribute access
    returns registered variables wrapped in their classes (see
    ``variables``), dim coordinates, or white-listed passthroughs.
    """
    safe_lifted = ['isel', 'sel']
    safe_unwrapped = ['sizes']
    dims = ('sample', 'position', 'strain', 'allele')
    variables = [Genotypes, Missingness, Communities, Metagenotypes]
    _variable_wrapper_map = {wrapper.variable_name: wrapper for wrapper in variables}

    def __init__(self, data):
        self.data = data  # self._align_dims(data)
        self.validate_fast()

#     @classmethod
#     def _align_dims(cls, data):
#         missing_dims = [k for k in cls.dims if k not in data.dims]
#         return data.expand_dims(missing_dims).transpose(*cls.dims)

    def validate_fast(self):
        """Assert that the dataset uses no dims outside ``self.dims``."""
        assert not (set(self.data.dims) - set(self.dims)), f"Found data dims that shouldn't exist: {self.data.dims}"

    def validate_constraints(self):
        """Check dims, then the constraints of every registered variable present in the dataset."""
        self.validate_fast()
        for variable_name in self._variable_wrapper_map:
            if variable_name in self.data:
                wrapped_variable = getattr(self, variable_name)
                wrapped_variable.validate_constraints()

    @property
    def fuzzed_genotypes(self):
        """Genotypes with missingness folded in (see ``Genotypes.fuzz_missing``)."""
        return self.genotypes.fuzz_missing(self.missingness)

    def __getattr__(self, name):
        if name in self.dims:
            # Return dims for those registered in self.dims.
            return getattr(self.data, name)
        if name in self._variable_wrapper_map:
            # Return wrapped variables for those registered in self.variables.
            # A registered variable that is absent must raise AttributeError
            # (not KeyError) so that hasattr()/getattr(..., default) work.
            if name not in self.data:
                raise AttributeError(f"'{self.__class__.__name__}' object has no variable '{name}' in its data")
            return self._variable_wrapper_map[name](self.data[name])
        elif name in self.safe_unwrapped:
            # Return a naked version of the variables registered in self.safe_unwrapped
            return getattr(self.data, name)
        elif name in self.safe_lifted:
            # Return a lifted version of the the attributes registered in safe_lifted
            return lambda *args, **kwargs: self.__class__(getattr(self.data, name)(*args, **kwargs))
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}' "
                                 f"and this name is not found in '{self.__class__.__name__}.dims', "
                                 f"'{self.__class__.__name__}.safe_unwrapped', "
                                 f"or '{self.__class__.__name__}.safe_lifted'. "
                                 f"Consider working with the '{self.__class__.__name__}.data' object directly.")

    @classmethod
    def concat(cls, data, dim):
        """Concatenate the Worlds in mapping *data* along *dim*.

        Labels along *dim* become ``f"{key}_{label}"``, and a
        ``_concat_from`` variable records each entry's source key. The input
        Worlds are not modified.
        """
        out_data = []
        new_coords = []
        for name in data:
            d = data[name].data
            # assign() returns a new Dataset, leaving the caller's World untouched.
            d = d.assign(_concat_from=xr.DataArray(name, dims=(dim,), coords={dim: d[dim]}))
            out_data.append(d)
            new_coords.extend([f"{name}_{i}" for i in d[dim].values])
        out_data = xr.concat(out_data, dim, data_vars='minimal', coords='minimal', compat='override')
        out_data[dim] = new_coords
        return cls(out_data)

### plot.py

In [None]:
%%writefile sfacts/plot.py
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy.spatial.distance import squareform
import pandas as pd
import numpy as np
from tqdm import tqdm
import sfacts as sf
from functools import partial


def _calculate_clustermap_sizes(
    nx, ny, scalex=0.15, scaley=0.02, cwidth=0, cheight=0, dwidth=0.2, dheight=1.0
):
    """Derive seaborn clustermap geometry from the matrix dimensions.

    The figure is the scaled matrix (``nx * scalex`` by ``ny * scaley``) plus
    fixed allowances for colour annotations (``cwidth``/``cheight``) and
    dendrograms (``dwidth``/``dheight``).

    Returns ``(figsize, dendrogram_ratio, colors_ratio)``. Each ratio is the
    matching allowance as a fraction of the figure width/height.
    """
    # TODO: Incorporate colors.
    total_width = nx * scalex + cwidth + dwidth
    total_height = ny * scaley + cheight + dheight
    figsize = (total_width, total_height)
    dendrogram_ratio = (dwidth / total_width, dheight / total_height)
    colors_ratio = (cwidth / total_width, cheight / total_height)
    return figsize, dendrogram_ratio, colors_ratio


def _min_max_normalize(x):
    """Linearly rescale *x* so that its minimum maps to 0 and its maximum to 1.

    A constant input (zero range) maps to all zeros instead of dividing
    0 by 0 and producing NaN.
    """
    shifted = x - x.min()
    value_range = x.max() - x.min()
    if value_range == 0:
        return shifted
    return shifted / value_range


def _scale_to_max_of_one(x):
    """Divide *x* by its maximum so that the largest value becomes 1.

    If the maximum is zero, *x* is returned unchanged instead of producing
    NaN (0/0) or infinities.
    """
    peak = x.max()
    if peak == 0:
        return x
    return x / peak


def dictionary_union(*args):
    """Merge mappings left to right; later mappings win on key collisions.

    The result is a ``.copy()`` of the first mapping (so its type is kept),
    updated with the rest. No argument is mutated. With no arguments, an
    empty dict is returned.
    """
    if not args:
        return {}
    first, *rest = args
    out = first.copy()
    for mapping in rest:
        out.update(mapping)
    return out


def plot_generic_clustermap_factory(
    matrix_func,
    col_colors_func=None,
    row_colors_func=None,
    col_linkage_func=None,
    row_linkage_func=None,
    row_col_annotation_cmap=mpl.cm.viridis,
    scalex=0.05,
    scaley=0.05,
    cwidth=0.4,
    cheight=0.4,
    dwidth=1.0,
    dheight=1.0,
    vmin=None,
    vmax=None,
    cmap=None,
    norm=None,
    xticklabels=0,
    yticklabels=0,
    metric='correlation',
    cbar_pos=None,
):
    """Build a function that draws a seaborn clustermap for a ``World``.

    The ``*_func`` arguments are callables that take the world:

    * ``matrix_func`` returns the matrix to draw;
    * ``*_linkage_func`` returns a precomputed linkage (``None`` lets seaborn
      cluster with ``metric``);
    * ``*_colors_func`` returns per-row/column values, which are scaled to a
      max of one and mapped through ``row_col_annotation_cmap``.

    Every other argument becomes a default of the returned ``_plot_func``
    and can be overridden per call. Extra keyword arguments to
    ``_plot_func`` go straight to ``sns.clustermap``.

    *norm* defaults to a linear ``PowerNorm(1.)`` built fresh for each call.
    A caller-supplied norm is deep-copied for each call. In both cases,
    *vmin* / *vmax* fill in any bound the norm leaves unset. Without this,
    matplotlib autoscaling would write the first plot's data range into a
    norm shared by every plot.
    """
    import copy

    def _annotation_colors(colors_func, world, annotation_cmap):
        # Scale per-row/column values to a max of one, then map each to a colour.
        if colors_func is None:
            return None
        return (
            colors_func(world)
            .pipe(_scale_to_max_of_one)
            .to_dataframe()
            .applymap(annotation_cmap)
        )

    def _per_call_norm(base_norm, vmin, vmax):
        # Never hand out a shared norm: autoscaling mutates vmin/vmax in place.
        if base_norm is None:
            return mpl.colors.PowerNorm(1., vmin=vmin, vmax=vmax)
        fresh = copy.deepcopy(base_norm)
        if fresh.vmin is None:
            fresh.vmin = vmin
        if fresh.vmax is None:
            fresh.vmax = vmax
        return fresh

    def _plot_func(
        world,
        matrix_func=matrix_func,
        col_linkage_func=col_linkage_func,
        row_linkage_func=row_linkage_func,
        col_colors_func=col_colors_func,
        row_colors_func=row_colors_func,
        row_col_annotation_cmap=row_col_annotation_cmap,
        scalex=scalex,
        scaley=scaley,
        cwidth=cwidth,
        cheight=cheight,
        dwidth=dwidth,
        dheight=dheight,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        norm=norm,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        metric=metric,
        cbar_pos=cbar_pos,
        **kwargs,
    ):
        col_linkage = None if col_linkage_func is None else col_linkage_func(world)
        row_linkage = None if row_linkage_func is None else row_linkage_func(world)
        col_colors = _annotation_colors(col_colors_func, world, row_col_annotation_cmap)
        row_colors = _annotation_colors(row_colors_func, world, row_col_annotation_cmap)

        matrix_data = matrix_func(world)

        ny, nx = matrix_data.shape
        figsize, dendrogram_ratio, colors_ratio = _calculate_clustermap_sizes(
            nx,
            ny,
            scalex=scalex,
            scaley=scaley,
            cwidth=cwidth,
            cheight=cheight,
            dwidth=dwidth,
            dheight=dheight,
        )

        clustermap_kwargs = dict(
            vmin=vmin,
            vmax=vmax,
            norm=_per_call_norm(norm, vmin, vmax),
            cmap=cmap,
            xticklabels=xticklabels,
            yticklabels=yticklabels,
            col_linkage=col_linkage,
            row_linkage=row_linkage,
            row_colors=row_colors,
            col_colors=col_colors,
            figsize=figsize,
            dendrogram_ratio=dendrogram_ratio,
            colors_ratio=colors_ratio,
            metric=metric,
            cbar_pos=cbar_pos,
        )
        clustermap_kwargs.update(kwargs)

        grid = sns.clustermap(
            matrix_data,
            **clustermap_kwargs
        )
        return grid
    return _plot_func
    

# Ready-made plotters. Each is a ``_plot_func(world, **overrides)`` built by
# ``plot_generic_clustermap_factory`` for one kind of World variable.

plot_metagenotype = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.metagenotypes.to_genotype_estimates().to_pandas().T,
    col_linkage_func=lambda w: w.metagenotypes.linkage(dim='strain', pseudo=1.),
    # Column colours: square root of each sample's mean total count across positions.
    col_colors_func=lambda w: (
        w.metagenotypes.sum('allele').mean('position').pipe(np.sqrt).rename('mean_depth')
    ),
    cmap=mpl.cm.coolwarm,
    vmin=0,
    vmax=1,
    scalex=0.15,
    scaley=0.01,
    xticklabels=1,
    yticklabels=0,
)

plot_genotype = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.genotypes.to_pandas().T,
    col_linkage_func=lambda w: w.genotypes.linkage(dim='strain'),
    row_linkage_func=lambda w: w.genotypes.linkage(dim='position'),
    # Column colours: per-strain genotype entropy.
    col_colors_func=lambda w: w.genotypes.entropy,
    cmap=mpl.cm.coolwarm,
    vmin=0,
    vmax=1,
    scalex=0.15,
    scaley=0.01,
    xticklabels=1,
    yticklabels=0,
)

plot_fuzzed_genotype = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.fuzzed_genotypes.to_pandas().T,
    col_linkage_func=lambda w: w.fuzzed_genotypes.linkage(dim='strain'),
    row_linkage_func=lambda w: w.genotypes.linkage(dim='position'),
    # Column colours: per-strain entropy of the unfuzzed genotypes.
    col_colors_func=lambda w: w.genotypes.entropy,
    cmap=mpl.cm.coolwarm,
    vmin=0,
    vmax=1,
    scalex=0.15,
    scaley=0.01,
    xticklabels=1,
    yticklabels=0,
)

plot_missing = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.missingness.to_pandas().T,
    col_linkage_func=lambda w: w.genotypes.linkage(dim='strain'),
    row_linkage_func=lambda w: w.genotypes.linkage(dim='position'),
    # Column colours: one minus each strain's mean missingness value.
    col_colors_func=lambda w: 1 - w.missingness.mean('position'),
    metric='cosine',
    cmap=None,
    vmin=0,
    vmax=1,
    scalex=0.15,
    scaley=0.01,
    xticklabels=1,
    yticklabels=0,
)

plot_community = plot_generic_clustermap_factory(
    matrix_func=lambda w: w.communities.to_pandas(),
    col_linkage_func=lambda w: w.genotypes.linkage(dim='strain'),
    # Column colours: square root of each strain's abundance summed over samples.
    col_colors_func=lambda w: w.communities.sum('sample').pipe(np.sqrt),
    metric='cosine',
    cmap=None,
    norm=mpl.colors.PowerNorm(1/2),
    vmin=0,
    vmax=1,
    scalex=0.15,
    scaley=0.14,
    dheight=1.0,
    xticklabels=1,
    yticklabels=1,
)

def plot_loss_history(trace):
    """Plot a loss trace on a log-scaled y axis, shifted so that its minimum is zero."""
    losses = np.asarray(trace)
    shifted = losses - losses.min()
    plt.plot(shifted)
    plt.yscale("log")

### model.py

In [None]:
%%writefile sfacts/model.py
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from torch.nn.functional import pad as torch_pad
import xarray as xr
import sfacts as sf
from sfacts.pyro_util import all_torch
from sfacts.logging_util import info


class Structure():
    """A Pyro generative function plus metadata about its dims and variables."""

    def __init__(self, generative, dims, description, default_hyperparameters=None):
        """

        *generative* :: Pyro generative model function(shape_dim_0, shape_dim_1, shape_dim_2, ..., **hyper_parameters)
        *dims* :: Sequence of names for dim_0, dim_1, dim_2, ...
        *description* :: Mapping from model variable to its dims.
        *default_hyperparameters* :: Values to use for hyperparameters when not explicitly set.
        """
        if default_hyperparameters is None:
            default_hyperparameters = {}

        self.generative = generative
        self.dims = dims
        self.description = description
        self.default_hyperparameters = default_hyperparameters

    def __call__(self, shape, data, hyperparameters):
        """Run the generative function with ``*shape`` and ``**hyperparameters``,
        conditioned on the observed values in *data*."""
        assert len(shape) == len(self.dims)
        conditioned_generative = pyro.condition(self.generative, data)
        return conditioned_generative(*shape, **hyperparameters)

    @property
    def _dummy_shape(self):
        # Distinct small sizes (1, 2, 3, ...) so that each dim is identifiable in shape dumps.
        shape = range(1, len(self.dims) + 1)
        return shape

    def explain_shapes(self, shape=None):
        """Log the dim sizes and the shape of every Pyro site for one model run.

        The model runs unconditioned, with ``default_hyperparameters`` cast
        to tensors. *shape* defaults to ``_dummy_shape``.
        """
        if shape is None:
            shape = self._dummy_shape
        info(dict(zip(self.dims, shape)))
        # shape_info traces the callable it is given, so pass the model itself.
        sf.pyro_util.shape_info(
            self,
            shape,
            data={},
            hyperparameters=all_torch(**self.default_hyperparameters),
        )

    
def structure(dims, description, default_hyperparameters=None):
    """Decorator factory: ``@structure(dims, description)`` turns a generative function into a ``Structure``.

    Returns ``functools.partial(Structure, ...)`` with every argument except
    ``generative`` bound.
    """
    bound_arguments = dict(
        dims=dims,
        description=description,
        default_hyperparameters=default_hyperparameters,
    )
    return partial(Structure, **bound_arguments)


class ParameterizedModel():
    """A ``Structure`` bound to coords, dtype/device, hyperparameters and observed data.

    Calling the instance runs the structure's generative model with the
    shape implied by ``coords``, conditioned on ``data``.
    """

    def __init__(self, structure, coords, dtype=torch.float32, device='cpu', data=None, hyperparameters=None):
        if hyperparameters is None:
            hyperparameters = {}

        if data is None:
            data = {}

        self.structure = structure
        # One entry per structure dim (in dim order); integer sizes become range(n).
        self.coords = {k: self._coords_or_range(coords[k]) for k in self.structure.dims}
        self.dtype = dtype
        self.device = device
        # Start from the structure's defaults, then apply explicit overrides.
        self.hyperparameters = self.structure.default_hyperparameters.copy()
        self.hyperparameters.update(hyperparameters)
        self.data = data

    @property
    def sizes(self):
        """Mapping from each dim name to its length."""
        return {k: len(self.coords[k]) for k in self.structure.dims}

    @property
    def shape(self):
        """Dim lengths as a tuple, in ``structure.dims`` order."""
        return tuple(self.sizes.values())

    def __repr__(self):
        return (f"{self.__class__.__name__}"
                f"({self.structure}, "
                f"coords={self.coords}, "
                f"dtype={self.dtype}, "
                f"device={self.device}, "
                f"hyperparameters={self.hyperparameters}, "
                f"data={self.data})")

    def __call__(self):
        # Here's where all the action happens.
        # All parameters are cast based on dtype and device.
        # The model is conditioned on the
        # data, and then called with the shape tuple
        # and cast hyperparameters.
        return self.structure(
            self.shape,
            data=all_torch(**self.data, dtype=self.dtype, device=self.device),
            hyperparameters=all_torch(**self.hyperparameters, dtype=self.dtype, device=self.device),
        )

    @staticmethod
    def _coords_or_range(coords):
        """Turn an integer size into ``range(size)`` and return anything else unchanged.

        Any integral type (including NumPy integers) counts as a size; ``bool`` does not.
        """
        import numbers
        if isinstance(coords, numbers.Integral) and not isinstance(coords, bool):
            return range(coords)
        return coords

    def _replace(self, **changes):
        # Build a new model from this one's constructor arguments, overriding *changes*.
        arguments = dict(
            structure=self.structure,
            coords=self.coords,
            dtype=self.dtype,
            device=self.device,
            hyperparameters=self.hyperparameters,
            data=self.data,
        )
        arguments.update(changes)
        return self.__class__(**arguments)

    def with_hyperparameters(self, **hyperparameters):
        """Return a copy with *hyperparameters* merged over the current ones."""
        new_hyperparameters = self.hyperparameters.copy()
        new_hyperparameters.update(hyperparameters)
        return self._replace(hyperparameters=new_hyperparameters)

    def condition(self, **data):
        """Return a copy with *data* merged over the currently observed values."""
        new_data = self.data.copy()
        new_data.update(data)
        return self._replace(data=new_data)

    def format_world(self, data):
        """Wrap raw arrays in *data* (one per ``structure.description`` entry) into a ``World``."""
        out = {}
        for k in self.structure.description:
            out[k] = xr.DataArray(
                data[k],
                dims=self.structure.description[k],
                coords={dim: self.coords[dim] for dim in self.structure.description[k]},
            )
        return sf.data.World(xr.Dataset(out))

    def simulate(self, n=1, seed=None):
        """Draw *n* samples of every site with ``pyro.infer.Predictive``.

        Returns a dict of NumPy arrays. NOTE: ``.squeeze()`` removes every
        size-1 axis, not only the sample axis when ``n == 1``.
        """
        sf.pyro_util.set_random_seed(seed)
        obs = pyro.infer.Predictive(self, num_samples=n)()
        obs = {k: obs[k].detach().cpu().numpy().squeeze() for k in obs.keys()}
        return obs

    def simulate_world(self, seed=None):
        """Simulate a single draw (seeded by *seed*) and return it as a ``World``."""
        return self.format_world(self.simulate(n=1, seed=seed))

### model_zoo.py

In [None]:
%%writefile sfacts/model_zoo.py
import sfacts as sf
import pyro
import pyro.distributions as dist
import torch
from torch.nn.functional import pad as torch_pad


# Canonical dimension names, in order.
SHARED_DIMS = ('sample', 'position', 'strain', 'allele')

# Model variable name -> the dims it is indexed by (an empty tuple is a scalar).
SHARED_DESCRIPTIONS = {
    'gamma': ('strain', 'position'),
    'delta': ('strain', 'position'),
    'rho': ('strain',),
    'pi': ('sample', 'strain'),
    'epsilon': ('sample',),
    'm_hyper_r': (),
    'mu': ('sample',),
    'nu': ('sample', 'position'),
    'p_noerr': ('sample', 'position'),
    'p': ('sample', 'position'),
    'alpha_hyper_mean': (),
    'alpha': ('sample',),
    'm': ('sample', 'position'),
    'y': ('sample', 'position'),
    'genotypes': ('strain', 'position'),
    'missingness': ('strain', 'position'),
    'communities': ('sample', 'strain'),
    'metagenotypes': ('sample', 'position', 'allele'),
}

def _mapping_subset(mapping, keys):
    """Return a new dict with only *keys* from *mapping*, in the order of *keys*.

    Raises KeyError if any key is missing from *mapping*.
    """
    subset = {}
    for key in keys:
        subset[key] = mapping[key]
    return subset

def stickbreaking_betas_to_probs(beta):
    """Convert K-1 stick-breaking fractions (last axis) into K probabilities.

    Probability k is ``beta[k] * prod_{j<k}(1 - beta[j])``, and the final
    entry takes whatever stick remains.
    """
    remaining = torch.cumprod(1 - beta, dim=-1)
    breaks = torch_pad(beta, (0, 1), value=1)
    leftover = torch_pad(remaining, (1, 0), value=1)
    return breaks * leftover


def NegativeBinomialReparam(mu, r):
    """Negative binomial parameterized by mean *mu* and total count *r*.

    With ``total_count=r``, the mean is ``r * exp(logits)``, so the matching
    logits are ``log(mu / r)``. Computing them directly avoids first forming
    ``p = mu / (mu + r)``. That ``p`` rounds to exactly 0 or 1 when *mu* and
    *r* differ by many orders of magnitude, which then gives infinite logits.
    """
    logits = torch.log(mu / r)
    return dist.NegativeBinomial(
        total_count=r,
        logits=logits,
    )

def unit_interval_power_transformation(p, alpha, beta):
    """Map p in [0, 1] to ``p**alpha / (p**alpha + (1 - p)**beta)``.

    The ratio is evaluated in log space as
    ``sigmoid(alpha * log p - beta * log(1 - p))``. This gives the same value
    as a log-sum-exp normalisation, but it broadcasts *p*, *alpha* and *beta*
    freely. A ``torch.stack`` of the two log terms would require them to have
    identical shapes. The endpoints p = 0 and p = 1 map to 0 and 1.
    """
    log_p_raised = torch.log(p) * alpha
    log_q_raised = torch.log1p(-p) * beta
    return torch.sigmoid(log_p_raised - log_q_raised)

def _pp_gamma_delta_module(s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp):
    """Sample genotypes (gamma) and missingness (delta) for *s* strains x *g* positions.

    Each starts from a Beta(1, 1) draw (sites "_gamma" / "_delta") that is
    passed through ``unit_interval_power_transformation``. The results are
    recorded as deterministic sites "gamma" / "delta" and again as
    "genotypes" / "missingness".
    """
    with pyro.plate("position", g, dim=-1), pyro.plate("strain", s, dim=-2):
        # Genotypes: symmetric exponents 1 / gamma_hyper.
        gamma_raw = pyro.sample("_gamma", dist.Beta(1., 1.))
        gamma_exponent = 1 / gamma_hyper
        gamma = pyro.deterministic(
            'gamma',
            unit_interval_power_transformation(gamma_raw, gamma_exponent, gamma_exponent),
        )
        # Position presence/absence: exponents set by delta_hyper_r and delta_hyper_temp.
        delta_raw = pyro.sample("_delta", dist.Beta(1., 1.))
        delta_alpha = 2 * (1 - delta_hyper_r) / delta_hyper_temp
        delta_beta = 2 * delta_hyper_r / delta_hyper_temp
        delta = pyro.deterministic(
            'delta',
            unit_interval_power_transformation(delta_raw, delta_alpha, delta_beta),
        )

    # These deterministics are accessed by PointMixin class properties.
    pyro.deterministic("genotypes", gamma)
    pyro.deterministic("missingness", delta)
    return gamma, delta


def _gsm_gamma_delta_module(s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp):
    """Sample genotypes and missingness from relaxed Bernoulli priors.

    gamma ~ RelaxedBernoulli(temperature=gamma_hyper, probs=0.5) and
    delta ~ RelaxedBernoulli(temperature=delta_hyper_temp, probs=delta_hyper_r),
    each over *s* strains x *g* positions. Both are also recorded as
    "genotypes" / "missingness".
    """
    with pyro.plate("position", g, dim=-1), pyro.plate("strain", s, dim=-2):
        gamma = pyro.sample(
            'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, probs=0.5)
        )
        delta = pyro.sample(
            'delta', dist.RelaxedBernoulli(temperature=delta_hyper_temp, probs=delta_hyper_r)
        )
    pyro.deterministic("genotypes", gamma)
    pyro.deterministic("missingness", delta)
    return gamma, delta


def _beta_gamma_delta_module(s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp):
    """Sample genotypes and missingness from Beta priors over *s* strains x *g* positions.

    gamma ~ Beta(gamma_hyper, gamma_hyper), and
    delta ~ Beta(r * temp, (1 - r) * temp), i.e. a Beta with mean
    ``delta_hyper_r`` and total concentration ``delta_hyper_temp``. Both are
    also recorded as "genotypes" / "missingness".
    """
    with pyro.plate("position", g, dim=-1):
        with pyro.plate("strain", s, dim=-2):
            gamma = pyro.sample(
                'gamma',
                dist.Beta(
                    gamma_hyper, gamma_hyper,
                ),
            )
            # These (concentration1, concentration0) arguments belong to a Beta.
            # RelaxedBernoulli's positional parameters are (temperature, probs),
            # and (1 - r) * temp is not a valid probability in general.
            delta = pyro.sample(
                'delta',
                dist.Beta(
                    delta_hyper_r * delta_hyper_temp,
                    (1 - delta_hyper_r) * delta_hyper_temp,
                ),
            )
    pyro.deterministic("genotypes", gamma)
    pyro.deterministic("missingness", delta)
    return gamma, delta


def _hybrid_gamma_delta_module(s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp):
    """Sample genotypes via a power-transformed uniform and missingness via RelaxedBernoulli.

    For every (strain, position) cell, with strains on plate dim -2 and
    positions on plate dim -1:

    * ``_gamma`` ~ Beta(1, 1) is pushed through
      ``unit_interval_power_transformation`` with both exponents equal to
      ``1 / gamma_hyper``; the result is the deterministic site "gamma".
    * ``delta`` ~ RelaxedBernoulli(temperature=delta_hyper_temp,
      probs=delta_hyper_r).

    Both are additionally recorded as "genotypes" and "missingness".

    Returns:
        Tuple ``(gamma, delta)``.
    """
    exponent = 1 / gamma_hyper
    with pyro.plate("position", g, dim=-1), pyro.plate("strain", s, dim=-2):
        uniform_draw = pyro.sample("_gamma", dist.Beta(1.0, 1.0))
        gamma = pyro.deterministic(
            'gamma',
            unit_interval_power_transformation(uniform_draw, exponent, exponent),
        )
        # Position presence/absence
        presence_prior = dist.RelaxedBernoulli(
            temperature=delta_hyper_temp, probs=delta_hyper_r
        )
        delta = pyro.sample('delta', presence_prior)
    pyro.deterministic("genotypes", gamma)
    pyro.deterministic("missingness", delta)
    return gamma, delta


def _dp_rho_module(s, rho_hyper):
    """Sample the metacommunity composition ``rho`` by stick-breaking.

    Draws ``s - 1`` stick-breaking fractions from Beta(1, rho_hyper) as a
    single event-shaped site "rho_betas". It converts them to strain weights
    with ``stickbreaking_betas_to_probs`` (presumably ``s`` weights summing
    to 1; confirm against that helper). The result is recorded as both
    "rho" and "metacommunity".

    Returns:
        ``rho``.
    """
    stick_fractions_prior = dist.Beta(1.0, rho_hyper).expand([s - 1]).to_event()
    rho_betas = pyro.sample('rho_betas', stick_fractions_prior)
    rho = pyro.deterministic('rho', stickbreaking_betas_to_probs(rho_betas))
    pyro.deterministic("metacommunity", rho)
    return rho


def _dirichlet_pi_epsilon_alpha_mu_module(
    n, pi_hyper, rho, epsilon_hyper_alpha, epsilon_hyper_beta, alpha_hyper_mean, alpha_hyper_scale, mu_hyper_mean, mu_hyper_scale
):
    """Sample per-sample community composition and observation parameters.

    Inside a "sample" plate of size ``n`` (dim -1):

    * ``pi`` ~ Dirichlet(pi_hyper * rho) (argument validation disabled)
    * ``epsilon`` ~ Beta(epsilon_hyper_alpha, epsilon_hyper_beta)
    * ``alpha`` ~ LogNormal(log(alpha_hyper_mean), alpha_hyper_scale)
    * ``mu`` ~ LogNormal(log(mu_hyper_mean), mu_hyper_scale)

    ``pi`` is also recorded as the deterministic site "communities".
    ``alpha_hyper_mean`` and ``mu_hyper_mean`` must be tensors, since they
    go through ``torch.log``.

    Returns:
        ``(pi, epsilon, alpha, mu)``. ``epsilon`` and ``alpha`` get a
        trailing singleton axis so they broadcast against the
        (sample, position) arrays in the observation module.
    """
    with pyro.plate("sample", n, dim=-1):
        # Community composition
        pi = pyro.sample("pi", dist.Dirichlet(pi_hyper * rho, validate_args=False))
        # Sequencing error
        error_prior = dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta)
        epsilon = pyro.sample("epsilon", error_prior)
        # BetaBinomial concentration
        concentration_prior = dist.LogNormal(
            loc=torch.log(alpha_hyper_mean), scale=alpha_hyper_scale
        )
        alpha = pyro.sample("alpha", concentration_prior)
        # Sample coverage
        coverage_prior = dist.LogNormal(
            loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale
        )
        mu = pyro.sample("mu", coverage_prior)
    pyro.deterministic("communities", pi)
    return pi, epsilon.unsqueeze(-1), alpha.unsqueeze(-1), mu


def _lognormal_alpha_hyper_mean_module(alpha_hyper_hyper_mean, alpha_hyper_hyper_scale):
    """Sample the shared prior mean for the per-sample concentration ``alpha``.

    The site "alpha_hyper_mean" is LogNormal with median
    ``alpha_hyper_hyper_mean`` (its log is the loc) and log-scale standard
    deviation ``alpha_hyper_hyper_scale``. ``alpha_hyper_hyper_mean`` must be
    a tensor because it goes through ``torch.log``.
    """
    log_median = torch.log(alpha_hyper_hyper_mean)
    prior = dist.LogNormal(loc=log_median, scale=alpha_hyper_hyper_scale)
    return pyro.sample("alpha_hyper_mean", prior)


def _betabinomial_observation_module(
    pi, gamma, delta, m_hyper_r_mu, m_hyper_r_scale, mu, epsilon, alpha
):
    """Sample total depth ``m`` and allele counts ``y`` for each (sample, position).

    Sites recorded:

    * "nu" = pi @ delta, the presence-weighted strain abundance at each
      position. It scales the expected depth.
    * "m_hyper_r" ~ LogNormal(m_hyper_r_mu, m_hyper_r_scale). Unlike the
      alpha/mu priors, the loc is used as given, not log-transformed.
    * "m" ~ NegativeBinomialReparam(nu * mu[:, None], m_hyper_r). The exact
      parameterization is defined by that class elsewhere (presumably
      mean and dispersion).
    * "p_noerr" = pi @ (gamma * delta) / nu, the error-free allele fraction.
    * "p", the allele fraction after a symmetric epsilon / 2 allele flip.
    * "y" ~ BetaBinomial(alpha * p, alpha * (1 - p), total_count=m).
    * "metagenotypes" = stack([y, m - y]) along a trailing axis.

    Returns None; results are exposed only through the pyro sites.
    """
    # Depth at each position
    nu = pyro.deterministic("nu", pi @ delta)
    m_hyper_r = pyro.sample(
        "m_hyper_r", dist.LogNormal(loc=m_hyper_r_mu, scale=m_hyper_r_scale)
    )
    # TODO: Consider using pyro.distributions.GammaPoisson parameterization?
    expected_depth = nu * mu.reshape((-1, 1))
    m = pyro.sample(
        "m", NegativeBinomialReparam(expected_depth, m_hyper_r).to_event()
    )

    # Expected fractions of each allele at each position
    p_noerr = pyro.deterministic("p_noerr", pi @ (gamma * delta) / nu)
    half_error = epsilon / 2
    # Each read reports the other allele with probability epsilon / 2.
    p = pyro.deterministic(
        "p", (1 - half_error) * p_noerr + half_error * (1 - p_noerr)
    )

    # Observation
    count_dist = dist.BetaBinomial(
        concentration1=alpha * p,
        concentration0=alpha * (1 - p),
        total_count=m,
    )
    y = pyro.sample("y", count_dist.to_event())
    # Registered for downstream access; the value itself is not needed here.
    pyro.deterministic("metagenotypes", torch.stack([y, m - y], dim=-1))


@sf.model.structure(
    dims=SHARED_DIMS,
    description=_mapping_subset(
        SHARED_DESCRIPTIONS,
        ['rho', 'epsilon', 'm_hyper_r', 'mu',
         'nu', 'p_noerr', 'p', 'm', 'y',
         'alpha_hyper_mean', 'alpha',
         'genotypes', 'missingness',
         'communities', 'metagenotypes']
    ),
    default_hyperparameters=dict(
        gamma_hyper=0.01,
        delta_hyper_temp=0.01,
        delta_hyper_r=0.9,
        rho_hyper=5.0,
        pi_hyper=0.2,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        mu_hyper_mean=1.0,
        mu_hyper_scale=10.0,
        m_hyper_r_mu=1.,
        m_hyper_r_scale=1.,
        alpha_hyper_hyper_mean=100.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=0.5,
    ),
)
def pp_fuzzy_missing_dp_betabinomial_metagenotype(
        n,
        g,
        s,
        a,
        gamma_hyper,
        delta_hyper_r,
        delta_hyper_temp,
        rho_hyper,
        pi_hyper,
        alpha_hyper_hyper_mean,
        alpha_hyper_hyper_scale,
        alpha_hyper_scale,
        epsilon_hyper_alpha,
        epsilon_hyper_beta,
        mu_hyper_mean,
        mu_hyper_scale,
        m_hyper_r_mu,
        m_hyper_r_scale,
    ):
    """Metagenotype model using the ``_pp_gamma_delta_module`` genotype prior.

    Components, in sampling order:

    1. genotypes/missingness from ``_pp_gamma_delta_module``;
    2. stick-breaking metacommunity ``rho`` (``_dp_rho_module``);
    3. shared ``alpha_hyper_mean`` (LogNormal);
    4. per-sample ``pi``, ``epsilon``, ``alpha``, ``mu``;
    5. NegativeBinomial depth and BetaBinomial allele counts.

    ``n``, ``g`` and ``s`` are the numbers of samples, positions and strains.
    ``a`` is part of the shared dims signature but is not used in this body.
    Default hyperparameter values are those given in the decorator.
    """
    # Signature of _pp_gamma_delta_module is defined elsewhere: keep positional.
    gamma, delta = _pp_gamma_delta_module(
        s, g, gamma_hyper, delta_hyper_r, delta_hyper_temp
    )
    rho = _dp_rho_module(s=s, rho_hyper=rho_hyper)
    alpha_hyper_mean = _lognormal_alpha_hyper_mean_module(
        alpha_hyper_hyper_mean=alpha_hyper_hyper_mean,
        alpha_hyper_hyper_scale=alpha_hyper_hyper_scale,
    )
    pi, epsilon, alpha, mu = _dirichlet_pi_epsilon_alpha_mu_module(
        n=n,
        pi_hyper=pi_hyper,
        rho=rho,
        epsilon_hyper_alpha=epsilon_hyper_alpha,
        epsilon_hyper_beta=epsilon_hyper_beta,
        alpha_hyper_mean=alpha_hyper_mean,
        alpha_hyper_scale=alpha_hyper_scale,
        mu_hyper_mean=mu_hyper_mean,
        mu_hyper_scale=mu_hyper_scale,
    )
    _betabinomial_observation_module(
        pi=pi,
        gamma=gamma,
        delta=delta,
        m_hyper_r_mu=m_hyper_r_mu,
        m_hyper_r_scale=m_hyper_r_scale,
        mu=mu,
        epsilon=epsilon,
        alpha=alpha,
    )
    
@sf.model.structure(
    dims=SHARED_DIMS,
    description=_mapping_subset(
        SHARED_DESCRIPTIONS,
        ['rho', 'epsilon', 'm_hyper_r', 'mu',
         'nu', 'p_noerr', 'p', 'm', 'y',
         'alpha_hyper_mean', 'alpha',
         'genotypes', 'missingness',
         'communities', 'metagenotypes']
    ),
    default_hyperparameters=dict(
        gamma_hyper=0.01,
        delta_hyper_temp=0.01,
        delta_hyper_r=0.9,
        rho_hyper=5.0,
        pi_hyper=0.2,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        mu_hyper_mean=1.0,
        mu_hyper_scale=10.0,
        m_hyper_r_mu=1.,
        m_hyper_r_scale=1.,
        alpha_hyper_hyper_mean=100.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=0.5,
    ),
)
def gsm_fuzzy_missing_dp_betabinomial_metagenotype(
        n,
        g,
        s,
        a,
        gamma_hyper,
        delta_hyper_r,
        delta_hyper_temp,
        rho_hyper,
        pi_hyper,
        alpha_hyper_hyper_mean,
        alpha_hyper_hyper_scale,
        alpha_hyper_scale,
        epsilon_hyper_alpha,
        epsilon_hyper_beta,
        mu_hyper_mean,
        mu_hyper_scale,
        m_hyper_r_mu,
        m_hyper_r_scale,
    ):
    """Metagenotype model with relaxed-Bernoulli genotypes and missingness.

    Components, in sampling order:

    1. genotypes/missingness from ``_gsm_gamma_delta_module``;
    2. stick-breaking metacommunity ``rho`` (``_dp_rho_module``);
    3. shared ``alpha_hyper_mean`` (LogNormal);
    4. per-sample ``pi``, ``epsilon``, ``alpha``, ``mu``;
    5. NegativeBinomial depth and BetaBinomial allele counts.

    ``n``, ``g`` and ``s`` are the numbers of samples, positions and strains.
    ``a`` is part of the shared dims signature but is not used in this body.
    Default hyperparameter values are those given in the decorator.
    """
    gamma, delta = _gsm_gamma_delta_module(
        s=s,
        g=g,
        gamma_hyper=gamma_hyper,
        delta_hyper_r=delta_hyper_r,
        delta_hyper_temp=delta_hyper_temp,
    )
    rho = _dp_rho_module(s=s, rho_hyper=rho_hyper)
    alpha_hyper_mean = _lognormal_alpha_hyper_mean_module(
        alpha_hyper_hyper_mean=alpha_hyper_hyper_mean,
        alpha_hyper_hyper_scale=alpha_hyper_hyper_scale,
    )
    pi, epsilon, alpha, mu = _dirichlet_pi_epsilon_alpha_mu_module(
        n=n,
        pi_hyper=pi_hyper,
        rho=rho,
        epsilon_hyper_alpha=epsilon_hyper_alpha,
        epsilon_hyper_beta=epsilon_hyper_beta,
        alpha_hyper_mean=alpha_hyper_mean,
        alpha_hyper_scale=alpha_hyper_scale,
        mu_hyper_mean=mu_hyper_mean,
        mu_hyper_scale=mu_hyper_scale,
    )
    _betabinomial_observation_module(
        pi=pi,
        gamma=gamma,
        delta=delta,
        m_hyper_r_mu=m_hyper_r_mu,
        m_hyper_r_scale=m_hyper_r_scale,
        mu=mu,
        epsilon=epsilon,
        alpha=alpha,
    )
    
    
@sf.model.structure(
    dims=SHARED_DIMS,
    description=_mapping_subset(
        SHARED_DESCRIPTIONS,
        ['rho', 'epsilon', 'm_hyper_r', 'mu',
         'nu', 'p_noerr', 'p', 'm', 'y',
         'alpha_hyper_mean', 'alpha',
         'genotypes', 'missingness',
         'communities', 'metagenotypes']
    ),
    default_hyperparameters=dict(
        gamma_hyper=0.01,
        delta_hyper_temp=0.01,
        delta_hyper_r=0.9,
        rho_hyper=5.0,
        pi_hyper=0.2,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        mu_hyper_mean=1.0,
        mu_hyper_scale=10.0,
        m_hyper_r_mu=1.,
        m_hyper_r_scale=1.,
        alpha_hyper_hyper_mean=100.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=0.5,
    ),
)
def hybrid_fuzzy_missing_dp_betabinomial_metagenotype(
        n,
        g,
        s,
        a,
        gamma_hyper,
        delta_hyper_r,
        delta_hyper_temp,
        rho_hyper,
        pi_hyper,
        alpha_hyper_hyper_mean,
        alpha_hyper_hyper_scale,
        alpha_hyper_scale,
        epsilon_hyper_alpha,
        epsilon_hyper_beta,
        mu_hyper_mean,
        mu_hyper_scale,
        m_hyper_r_mu,
        m_hyper_r_scale,
    ):
    """Metagenotype model with power-transformed genotypes and relaxed-Bernoulli missingness.

    Components, in sampling order:

    1. genotypes/missingness from ``_hybrid_gamma_delta_module``;
    2. stick-breaking metacommunity ``rho`` (``_dp_rho_module``);
    3. shared ``alpha_hyper_mean`` (LogNormal);
    4. per-sample ``pi``, ``epsilon``, ``alpha``, ``mu``;
    5. NegativeBinomial depth and BetaBinomial allele counts.

    ``n``, ``g`` and ``s`` are the numbers of samples, positions and strains.
    ``a`` is part of the shared dims signature but is not used in this body.
    Default hyperparameter values are those given in the decorator.
    """
    gamma, delta = _hybrid_gamma_delta_module(
        s=s,
        g=g,
        gamma_hyper=gamma_hyper,
        delta_hyper_r=delta_hyper_r,
        delta_hyper_temp=delta_hyper_temp,
    )
    rho = _dp_rho_module(s=s, rho_hyper=rho_hyper)
    alpha_hyper_mean = _lognormal_alpha_hyper_mean_module(
        alpha_hyper_hyper_mean=alpha_hyper_hyper_mean,
        alpha_hyper_hyper_scale=alpha_hyper_hyper_scale,
    )
    pi, epsilon, alpha, mu = _dirichlet_pi_epsilon_alpha_mu_module(
        n=n,
        pi_hyper=pi_hyper,
        rho=rho,
        epsilon_hyper_alpha=epsilon_hyper_alpha,
        epsilon_hyper_beta=epsilon_hyper_beta,
        alpha_hyper_mean=alpha_hyper_mean,
        alpha_hyper_scale=alpha_hyper_scale,
        mu_hyper_mean=mu_hyper_mean,
        mu_hyper_scale=mu_hyper_scale,
    )
    _betabinomial_observation_module(
        pi=pi,
        gamma=gamma,
        delta=delta,
        m_hyper_r_mu=m_hyper_r_mu,
        m_hyper_r_scale=m_hyper_r_scale,
        mu=mu,
        epsilon=epsilon,
        alpha=alpha,
    )

### estimation.py

In [None]:
%%writefile sfacts/estimation.py
import sfacts as sf
# from sklearn.cluster import AgglomerativeClustering
# from sklearn.decomposition import non_negative_factorization
# from sfacts.genotype import genotype_pdist, adjust_genotype_by_missing
from sfacts.pyro_util import all_torch
# import pandas as pd
import numpy as np
# import scipy as sp
# from scipy.spatial.distance import squareform
import pyro
# import pyro.distributions as dist
import torch
from tqdm import tqdm
from sfacts.logging_util import info
# from sfacts.model import condition_model


# def cluster_genotypes(gamma, thresh, quiet=True, precomputed_pdist=None):
# 
#     if precomputed_pdist is None:
#         compressed_dmat = genotype_pdist(gamma, quiet=quiet)
#     else:
#         compressed_dmat = precomputed_pdist
# 
#     clust = pd.Series(
#         AgglomerativeClustering(
#             n_clusters=None,
#             affinity="precomputed",
#             linkage="complete",
#             distance_threshold=thresh,
#         )
#         .fit(squareform(compressed_dmat))
#         .labels_
#     )
# 
#     return clust, compressed_dmat
# 
# 
# def initialize_parameters_by_clustering_samples(
#     metagenotype,
#     thresh=None,
#     additional_strains_factor=0.5,
#     quiet=True,
#     precomputed_pdist=None,
# ):
#     n, g, a = metagenotype.shape
# 
#     sample_genotype = sf.genotype.metagenotype_to_genotype(metagenotype)
#     clust, cdmat = cluster_genotypes(
#         sample_genotype,
#         thresh=thresh,
#         quiet=quiet,
#         precomputed_pdist=precomputed_pdist,
#     )
#     s_clust = len(clust.value_counts())
#     
#     # FIXME: This probably doesn't work using xarray.
#     # How to use pandas style aggregation here?
#     clust_metagenotype = metagenotype.groupby(sample=clust)
#     s_additional_haplotypes = int(additional_strains_factor * s_clust)
#     s_init = s_clust + s_additional_haplotypes
# 
#     # FIXME: This probably doesn't work using xarray.
#     # The "additional_haplotypes" cells will need to be indexed.
#     # Consider doing the matrix building in numpy space,
#     # and then apply genotype.build_genotype(gamma).
#     gamma_init = xr.concat(
#         [
#             clust_metagenotype,
#             np.ones((s_additional_haplotypes, g)) * 0.5,
#         ]
#     ).values
# 
#     pi_init = np.ones((n, s_init))
#     for i in range(n):
#         pi_init[i, clust[i]] = s_init - 1
#     pi_init /= pi_init.sum(1, keepdims=True)
# 
#     assert (~np.isnan(gamma_init)).all()
# 
#     return gamma_init, pi_init, cdmat
# 
# 
# def initialize_parameters_by_nmf(
#     y, m, s, quiet=True, solver="mu", alpha=100.0, l1_ratio=1.0, tol=1e-2, **kwargs
# ):
#     n, g = y.shape
# 
#     # Fit to counts of both reference and alternative alleles by stacking them.
#     stacked_metagenotype = np.concatenate([y, m - y], axis=1)
#     pi_unnorm, gamma_unnorm, _ = non_negative_factorization(
#         stacked_metagenotype,
#         n_components=s,
#         solver=solver,
#         verbose=int(not quiet),
#         alpha=alpha,
#         l1_ratio=l1_ratio,
#         tol=tol,
#         **kwargs,
#     )
# 
#     # TODO: Find a more principled way to convert pi_unnorm into pi_init.
#     pi_init = (pi_unnorm + 1) / (pi_unnorm + 1).sum(1, keepdims=True)
#     gamma_init = (gamma_unnorm[:, :g] + 1) / (
#         gamma_unnorm[:, :g] + gamma_unnorm[:, -g:] + 2
#     )
# 
#     return gamma_init, pi_init, None


def estimate_parameters(
    model,
#     data,
    dtype=torch.float32,
    device="cpu",
    initialize_params=None,
    maxiter=10000,
    lagA=20,
    lagB=100,
    opt=None,
    quiet=False,
    seed=None,
#     **model_kwargs,
):
    """Fit ``model`` with SVI and return a point-estimate world plus loss history.

    The model is optimized against an ``AutoLaplaceApproximation`` guide
    using ``JitTrace_ELBO``. Afterwards, one draw of every site is taken
    with ``pyro.infer.Predictive`` and the fitted guide. Each draw is moved
    to numpy, averaged over the draw axis, squeezed, and handed to
    ``model.format_world``.

    Parameters
    ----------
    model
        Model callable used by pyro; it must also provide ``format_world``.
    dtype, device
        Used to cast the values in ``initialize_params``. ``device`` may be a
        string or a ``torch.device``; CUDA's cache is emptied afterwards if
        it names a CUDA device.
    initialize_params : dict, optional
        Site name -> initial value, passed to ``init_to_value``.
    maxiter : int
        Maximum number of SVI steps.
    lagA, lagB : int
        Every 10 steps, once more than ``lagB`` steps have run, the mean
        per-step loss decrease over the last ``lagA`` and ``lagB`` steps is
        computed. Fitting stops when both are <= 0. ``lagA`` must not exceed
        ``lagB``.
    opt : pyro.optim.PyroOptim, optional
        Optimizer. Defaults to a fresh ``Adamax({"lr": 1e-2}, {"clip_norm": 100})``
        created on each call.
    quiet : bool
        Disable the progress bar and status messages.
    seed : int, optional
        Passed to ``sf.pyro_util.set_random_seed``.

    Returns
    -------
    tuple
        ``(model.format_world(estimates), history)`` where ``history`` is
        the list of per-step losses returned by ``svi.step()``.

    Raises
    ------
    RuntimeError
        If a step returns a NaN loss.
    """
    if initialize_params is None:
        initialize_params = {}
    if opt is None:
        # Built per call rather than as a default argument: a PyroOptim keeps
        # per-parameter optimizer state internally, so one instance shared by
        # every call would carry state (and references to old parameter
        # tensors) from one fit into the next.
        opt = pyro.optim.Adamax({"lr": 1e-2}, {"clip_norm": 100})

    sf.pyro_util.set_random_seed(seed, warn=(not quiet))

    _guide = pyro.infer.autoguide.AutoLaplaceApproximation(
        model,
        init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(
            values=all_torch(**initialize_params, dtype=dtype, device=device)
        ),
    )
    svi = pyro.infer.SVI(
        model, _guide, opt, loss=pyro.infer.JitTrace_ELBO()
    )
    pyro.clear_param_store()

    history = []
    pbar = tqdm(range(maxiter), disable=quiet)
    try:
        for i in pbar:
            elbo = svi.step()

            if np.isnan(elbo):
                pbar.close()
                raise RuntimeError("ELBO NaN?")

            # Fit tracking
            history.append(elbo)

            # Reporting/Breaking: every 10 steps, once the lagB window is full.
            if i % 10 == 0 and i > lagB:
                delta = history[-2] - history[-1]
                delta_lagA = (history[-lagA] - history[-1]) / lagA
                delta_lagB = (history[-lagB] - history[-1]) / lagB
                pbar.set_postfix(
                    {
                        "ELBO": history[-1],
                        "delta": delta,
                        f"lag{lagA}": delta_lagA,
                        f"lag{lagB}": delta_lagB,
                    }
                )
                # No average improvement over either window: stop.
                if (delta_lagA <= 0) and (delta_lagB <= 0):
                    pbar.close()
                    info("Converged", quiet=quiet)
                    break
    except KeyboardInterrupt:
        # Manual early stop: fall through and return the current estimate.
        pbar.close()
        info("Interrupted", quiet=quiet)

    est = pyro.infer.Predictive(model, guide=_guide, num_samples=1)()
    est = {k: est[k].detach().cpu().numpy().mean(0).squeeze() for k in est.keys()}

    # str() so that torch.device objects work as well as strings.
    if str(device).startswith("cuda"):
        torch.cuda.empty_cache()

    return model.format_world(est), history


# def merge_similar_genotypes(
#     gamma,
#     pi,
#     thresh,
#     delta=None,
#     progress=False,
# ):
#     if delta is None:
#         delta = np.ones_like(gamma)

#     gamma_adjust = adjust_genotype_by_missing(gamma, delta)

#     clust, dmat = cluster_genotypes(gamma_adjust, thresh=thresh, progress=progress)
#     gamma_mean = (
#         pd.DataFrame(pd.DataFrame(gamma_adjust))
#         .groupby(clust)
#         .apply(lambda x: sp.special.expit(sp.special.logit(x)).mean(0))
#         .values
#     )
#     delta_mean = pd.DataFrame(pd.DataFrame(delta)).groupby(clust).mean().values
#     pi_sum = pd.DataFrame(pd.DataFrame(pi)).groupby(clust, axis="columns").sum().values

#     return gamma_mean, pi_sum, delta_mean

### evaluation.py

In [None]:
%%writefile sfacts/evaluation.py
from scipy.spatial.distance import cdist, pdist
import pandas as pd
import numpy as np
import xarray as xr


# def binary_entropy(p):
#     q = 1 - p
#     ent = -(p * np.log2(p) + q * np.log2(q))
#     return ent


# def sum_binary_entropy(p, normalize=False, axis=None):
#     q = 1 - p
#     ent = np.sum(-(p * np.log2(p) + q * np.log2(q)), axis=axis)
#     if normalize:
#         ent = ent / p.shape[axis]
#     return ent


# def mean_masked_genotype_entropy(gamma, delta):
#     return (binary_entropy(gamma) * delta).mean(1)


# def sample_mean_masked_genotype_entropy(pi, gamma, delta):
#     return (pi @ mean_masked_genotype_entropy(gamma, delta).reshape((-1, 1))).squeeze()



def _rmse(x, y):
    return np.sqrt(np.square(x - y).mean())


def _rss(x, y):
    return np.sqrt(np.square(x - y).sum())


def match_genotypes(worldA, worldB):
    """Match each worldA genotype to its closest worldB genotype.

    Genotype tables are taken from ``world.genotypes.data.to_pandas()``
    (rows are compared). Distance is cityblock (L1) between rows, divided
    by the number of columns (positions) of worldA's table.

    Returns:
        ``(best_match, best_distance)``: two Series indexed 0..rowsA-1.
        They give the positional index of the nearest worldB row and the
        normalized distance to it.
    """
    genotypes_a = worldA.genotypes.data.to_pandas()
    genotypes_b = worldB.genotypes.data.to_pandas()
    n_positions = genotypes_a.shape[1]
    distance = pd.DataFrame(cdist(genotypes_a, genotypes_b, metric="cityblock"))
    best_match = distance.idxmin(axis=1)
    best_distance = distance.min(axis=1) / n_positions
    return best_match, best_distance


def weighted_genotype_error(worldA, worldB):
    """Coverage-weighted mean distance from worldA genotypes to their best worldB match.

    Each worldA strain's distance (from ``match_genotypes``) is weighted by
    that strain's total expected coverage in worldA, i.e.
    ``(mu * communities)`` summed over samples.

    Returns:
        The weighted mean as a Python float.
    """
    # match_genotypes' second output is a distance (lower is better).
    _, distance = match_genotypes(worldA, worldB)
    strain_error = xr.DataArray(distance, dims=('strain',), coords=dict(strain=worldA.strain))
    coverage_weight = (worldA.data.mu * worldA.data.communities).sum("sample")
    weighted_total = (strain_error * coverage_weight).sum()
    return float(weighted_total / coverage_weight.sum())


def community_error(worldA, worldB, reps=99):
    """RMSE between the pairwise Bray-Curtis similarities of two worlds' communities.

    For each world, rows of ``communities.to_pandas()`` are compared
    pairwise with Bray-Curtis similarity (1 - distance). The two condensed
    similarity vectors are then compared by RMSE, so both worlds need the
    same number of rows in the same order.

    ``reps`` is accepted for signature parity with ``community_error_test``
    but is not used.
    """
    similarity_a, similarity_b = (
        1 - pdist(world.communities.to_pandas(), metric="braycurtis")
        for world in (worldA, worldB)
    )
    return _rmse(similarity_a, similarity_b)


def community_error_test(worldA, worldB, reps=99, random_state=None):
    """Compare two worlds' community structure against a permutation null.

    Computes the RMSE between the pairwise Bray-Curtis similarities of the
    rows of ``worldA.communities.to_pandas()`` and
    ``worldB.communities.to_pandas()``. Both need the same number of rows,
    compared positionally. A null distribution is formed from the RMSE
    between worldA's similarities and ``reps`` random permutations of them.
    The permutation shuffles the condensed similarity vector directly, not
    the row labels.

    Parameters
    ----------
    worldA, worldB
        Objects exposing ``communities.to_pandas()``.
    reps : int
        Number of permutations in the null.
    random_state : None, int, or numpy.random.Generator
        ``None`` (default) draws from the global ``np.random`` state.
        Anything else is passed to ``np.random.default_rng`` for a
        reproducible null.

    Returns
    -------
    tuple
        ``(err, null, ratio, frac_below)``:

        * ``err``: the observed RMSE.
        * ``null``: array of ``reps`` null RMSEs.
        * ``ratio``: ``err / mean(null)``.
        * ``frac_below``: fraction of null values strictly less than ``err``.
    """
    pi_sim = worldA.communities.to_pandas()
    pi_fit = worldB.communities.to_pandas()

    bc_sim = 1 - pdist(pi_sim, metric="braycurtis")
    bc_fit = 1 - pdist(pi_fit, metric="braycurtis")
    err = _rmse(bc_sim, bc_fit)

    if random_state is None:
        permute = np.random.permutation
    else:
        permute = np.random.default_rng(random_state).permutation
    null = np.array([_rmse(bc_sim, permute(bc_sim)) for _ in range(reps)])

    # Sorting is unnecessary: the fraction below err does not depend on order.
    return err, null, err / np.mean(null), (null < err).mean()


# def metacommunity_composition_rss(worldA, worldB):
#     pi_sim = worldA.communities.to_dataframe()
#     pi_fit = worldB.communities.to_dataframe()
#     mean_sim = pi_sim.mean(0)
#     mean_fit = pi_fit.mean(0)
#     s_sim = mean_sim.shape[0]
#     s_fit = mean_fit.shape[0]
#     s = max(s_sim, s_fit)
#     mean_sim = np.sort(np.pad(mean_sim, pad_width=(0, s - s_sim)))
#     mean_fit = np.sort(np.pad(mean_fit, pad_width=(0, s - s_fit)))
#     return _rss(mean_sim, mean_fit)


## Prototype

### Unit-tests

### Loading Data / Plotting

In [None]:
# Sanity check on sfacts/data.py
# Load the real-data metagenotypes, subset positions and samples with the
# data.py selection helpers, then validate and plot the result.
obs = (
    sf.data.Metagenotypes.load('data/ucfmt.sp-100022.gtpro-pileup.nc')
    .select_variable_positions(incid_thresh=0.2)
    .select_samples_with_coverage(0.1)
)

obs.genotypes.validate_constraints()

print(obs.sizes)
# sf.plot.plot_genotype(d.genotypes.lift(lambda d: d.isel(strain=range(30), position=range(1000))))
sf.plot.plot_genotype(obs)

In [None]:
# `d` is a stale name (see the commented-out line in the previous cell);
# the data object loaded in this section is `obs`.
sf.plot.plot_genotype_similarity(obs.genotypes)

### Simulation

In [None]:
# def unit_interval_power_transformation(p, alpha, beta):
#     log_p = torch.log(p)
#     log_q = torch.log1p(-p)
#     log_p_raised = log_p * alpha
#     log_q_raised = log_q * beta
#     return torch.exp(
#         log_p_raised -
#         torch.logsumexp(torch.stack([log_p_raised, log_q_raised]), dim=0)
#     )

# Visual check of sf.model_zoo.unit_interval_power_transformation applied to
# uniform draws, using the same exponents (2(1 - r)/temp, 2r/temp) as the
# missingness transformation in the model code above.
x = pyro.sample('x', dist.Beta(1., 1.).expand([10000]))
# `b` is only used by the commented-out histogram below.
b = pyro.sample('b', dist.Beta(0.2, 0.8).expand([10000]))

r = 0.6
temp = 1e-1
x_pt = sf.model_zoo.unit_interval_power_transformation(x, 2 * (1 - r) / temp, 2 * r / temp)

bins = np.linspace(0., 1.0, num=10)

plt.hist(x.numpy(), bins=bins, density=True)
# plt.hist(b.numpy(), bins=bins, alpha=0.5, density=True)
plt.hist(x_pt.numpy(), bins=bins, density=True, alpha=0.5)
print(x_pt.numpy().mean(), x_pt.numpy().min(), x_pt.numpy().max())
# plt.yscale('log')
# Trailing `None` keeps the notebook from echoing the last expression.
None

In [None]:
%%time
# Simulate a world from the pp_ model. Hyperparameters not listed here
# (the commented-out ones) presumably fall back to the decorator defaults.
sim = sf.model.ParameterizedModel(
    sf.model_zoo.pp_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=dict(
        sample=100,
        position=500,
        strain=20,
        allele=['alt', 'ref'],
    ),
    hyperparameters=dict(
        gamma_hyper=0.01,
        delta_hyper_r=0.9,
        delta_hyper_temp=0.01,
        rho_hyper=4.,
        pi_hyper=0.4,
#         alpha_hyper_hyper_mean=100.0,
#         alpha_hyper_hyper_scale=1.0,
#         alpha_hyper_scale=0.5,
#         epsilon_hyper_alpha=1.5,
#         epsilon_hyper_beta=1.5 / 0.01,
        mu_hyper_mean=50.0,
        mu_hyper_scale=2.0,
        m_hyper_r_mu=5.,
        m_hyper_r_scale=2.,
    )
).simulate_world()

sf.plot.plot_genotype(sim.metagenotypes)
sf.plot.plot_community(sim)
sf.plot.plot_genotype(sim)
sf.plot.plot_missing(sim)
None
# sf.plot.plot_genotype(sf.data.Metagenotypes.from_counts_and_totals(sim0.data['y'], sim0.data['m']))

In [None]:
# Plot the simulated world's `fuzzed_genotypes` view.
sf.plot.plot_genotype(sim.fuzzed_genotypes)

### Fitting

#### Real Data

In [None]:
# Fit a 20-strain model to the real-data metagenotypes loaded above.
mgen = obs.metagenotypes

print(mgen.sizes)

# NOTE(review): this uses `pp_fuzzy_missing_dirichlet_binomial_metagenotype`,
# while the simulated-data cell below uses
# `pp_fuzzy_missing_dp_betabinomial_metagenotype`; confirm the former still
# exists in sf.model_zoo.
model_fit = (
    sf.model.ParameterizedModel(
        sf.model_zoo.pp_fuzzy_missing_dirichlet_binomial_metagenotype,
        coords=dict(
            sample=mgen.sample.values,
            position=mgen.position.values,
            allele=mgen.allele.values,
            strain=range(20),
        ),
        hyperparameters=dict(
            gamma_hyper=0.1,
            delta_hyper_r=0.9,
            delta_hyper_temp=0.1,
            rho_hyper=1.,
            pi_hyper=0.5,
            epsilon_hyper_alpha=1.5,
            epsilon_hyper_beta=1.5 / 0.01,
            mu_hyper_mean=10.0,
            mu_hyper_scale=10.0,
        )
    )
    # Condition on the observed counts (y) and totals (m).
    .condition(
        **mgen.to_counts_and_totals()
#         y=metagenotype_sim.data.sel(allele='alt').values,
#         m=metagenotype_sim.data.sum('allele').values
    )
)

# A fresh optimizer is passed explicitly (lr = 1.0).
est, history = sf.estimation.estimate_parameters(
    model_fit,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100})
)

sf.plot.plot_loss_history(history)

In [None]:
# Inspect the real-data estimate.
sf.plot.plot_genotype(est)
sf.plot.plot_community(est)
sf.plot.plot_missing(est)

# sf.plot.plot_missing_comparison(dict(sim=sim0, est=est))


#### Simulated Data

In [None]:
# Fit the same model family used for simulation back to the simulated
# metagenotypes. Hyperparameters omitted here (alpha_*, m_hyper_r_*)
# presumably take the decorator defaults.
mgen = sim.metagenotypes

print(mgen.sizes)

model_fit = (
    sf.model.ParameterizedModel(
        sf.model_zoo.pp_fuzzy_missing_dp_betabinomial_metagenotype,
        coords=dict(
            sample=mgen.sample.values,
            position=mgen.position.values,
            allele=mgen.allele.values,
            strain=range(20),
        ),
        hyperparameters=dict(
            gamma_hyper=0.1,
            delta_hyper_r=0.9,
            delta_hyper_temp=0.1,
            rho_hyper=1.,
            pi_hyper=0.5,
#             alpha_hyper_hyper_mean=100.0,
#             alpha_hyper_hyper_scale=1.0,
#             alpha_hyper_scale=0.5,
            epsilon_hyper_alpha=1.5,
            epsilon_hyper_beta=1.5 / 0.01,
            mu_hyper_mean=10.0,
            mu_hyper_scale=10.0,

        )
    )
    .condition(
        **mgen.to_counts_and_totals()
#         y=metagenotype_sim.data.sel(allele='alt').values,
#         m=metagenotype_sim.data.sum('allele').values
    )
)

est, history = sf.estimation.estimate_parameters(
    model_fit,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100})
)

sf.plot.plot_loss_history(history)

In [None]:
# Side-by-side comparison of simulated truth and estimate.
sf.plot.plot_genotype_comparison(dict(sim=sim, est=est))
sf.plot.plot_community_comparison(dict(sim=sim, est=est))
# sf.plot.plot_missing_comparison(dict(sim=sim0, est=est))


### Evaluation

In [None]:
# Compare pairwise Bray-Curtis similarities of simulated vs. estimated
# communities against a permutation null. sfacts/evaluation.py defines
# community_error_test (taking world objects), not community_accuracy_test.
sf.evaluation.community_error_test(sim, est)

# ---------

## Model Specification

In [None]:
# Prior on the per-sample sequencing error rate epsilon:
# Beta(1.5, 150), mean 1.5 / 151.5 ~= 0.01.
epsilon_hyper_alpha, epsilon_hyper_beta = 1.5, 1.5 / 0.01
plt.hist(pyro.sample('epsilon_hyper', dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta).expand([10000])).cpu().numpy(), bins=100)
None

In [None]:
# Histogram of draws from NegativeBinomialReparam (the depth distribution
# used for `m`). NOTE(review): confirm it is exposed as
# sf.model.NegativeBinomialReparam and still accepts `r=` and `eps=`.
plt.hist(pyro.sample('test', sf.model.NegativeBinomialReparam(torch.tensor(10.), r=torch.tensor(1.), eps=torch.tensor(1e-5)).expand([1000])).numpy())

In [None]:
# Log the trace shapes of a model via sf.pyro_util.shape_info.
# NOTE(review): `sf.model.model` is the older single-model API used in the
# sections below; confirm it still exists.
sf.pyro_util.shape_info(sf.model.model, n=100, g=200, s=20)

## Simulation

### SimShape-1: Small study

In [None]:
# Legacy simulation (SimShape-1) using the older sf.model.simulate /
# condition_model / sf.model.model API, with alpha_hyper_mean fixed via `data=`.
# NOTE(review): this passes `delta_hyper_p` and `m_hyper_r`, whereas the
# model_zoo models above take `delta_hyper_r` and `m_hyper_r_mu` /
# `m_hyper_r_scale`; confirm the legacy API still exists before running.
seed = 1
pyro.util.set_rng_seed(seed)

n_sim = 100
g_sim = 5000
s_sim = 20

sim1 = sf.model.simulate(
    sf.model.condition_model(
        sf.model.model,
        data=dict(
            alpha_hyper_mean=100.
        ),
        n=n_sim,
        g=g_sim,
        s=s_sim,
        gamma_hyper=0.01,
        delta_hyper_temp=0.01,
        delta_hyper_p=0.7,
        pi_hyper=0.5,
        rho_hyper=10.,
        mu_hyper_mean=2.,
        mu_hyper_scale=0.5,
        m_hyper_r=10.,
        alpha_hyper_scale=0.5,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5/0.01,
        device='cpu'
    )
)

## Visualization

In [None]:
# Subset sizes used for plotting (samples, positions, strains).
n_plt = 100
g_plt = 200
s_plt = 20

In [None]:
# NOTE(review): pi is sliced [:s_plt, :n_plt] here but sim1_pi_init is sliced
# [:n_plt, :s_plt] later; in the current model pi is (sample, strain).
# Confirm the axis order of this legacy output.
sf.plot.plot_community(sim1['pi'][:s_plt, :n_plt])

In [None]:
# NOTE(review): sfacts/__init__.py (top of this notebook) does not import a
# `genotype` submodule, so `sf.genotype` may be unavailable.
sf.plot.plot_genotype(
    sf.genotype.counts_to_p_estimate(
        sim1['y'][:n_plt, :g_plt],
        sim1['m'][:n_plt, :g_plt]),
    linkage_kw=dict(progress=True)
)

In [None]:
sf.plot.plot_genotype_similarity(sf.genotype.counts_to_p_estimate(sim1['y'][:n_plt, :g_plt], sim1['m'][:n_plt, :g_plt]), linkage_kw=dict(progress=True))

In [None]:
sf.plot.plot_genotype(sim1['gamma'][:s_plt, :g_plt])

In [None]:
sf.plot.plot_missing(sim1['delta'][:s_plt, :g_plt])

In [None]:
sf.plot.plot_missing(sim1['nu'][:n_plt, :g_plt])

In [None]:
# Total depth on a symmetric-log color scale.
sns.clustermap(sim1['m'][:n_plt, :g_plt], norm=mpl.colors.SymLogNorm(linthresh=1))

In [None]:
sf.plot.plot_genotype_similarity(sim1['gamma'][:s_plt, :g_plt])

In [None]:
plt.hist(sim1['epsilon'], bins=50)
None

In [None]:
plt.hist(sim1['alpha'], bins=20)
None

## Estimation

### Initialization

In [None]:
g_fit = 1000  # sim1['y'].shape[1]
n_fit = sim1['y'].shape[0]

# NOTE(review): initialize_parameters_by_clustering_samples is commented out
# in sfacts/estimation.py above (and took a single metagenotype argument), so
# this call fails as written.
sim1_gamma_init, sim1_pi_init, sim1_dmat = sf.estimation.initialize_parameters_by_clustering_samples(
    sim1['y'][:n_fit, :g_fit],
    sim1['m'][:n_fit, :g_fit],
    thresh=0.05,
    additional_strains_factor=0.,
    progress=True,
)

print(sim1_pi_init.shape)

In [None]:
sf.plot.plot_genotype(sim1_gamma_init[:s_plt, :g_plt])

In [None]:
sf.plot.plot_genotype_similarity(sim1_gamma_init)

In [None]:
sf.plot.plot_community(sim1_pi_init[:n_plt, :s_plt])

### Fitting

In [None]:
# Legacy fit initialized from the clustering-based estimates.
# NOTE(review): estimate_parameters in sfacts/estimation.py (above) accepts
# model, dtype, device, initialize_params, maxiter, lagA, lagB, opt, quiet and
# seed only; `data`, `n`, `g`, `s`, the hyperparameters, `lag` and `lr` are not
# accepted, so this call raises TypeError as written.
s_fit = sim1_gamma_init.shape[0]
initialize_params = dict(gamma=sim1_gamma_init, pi=sim1_pi_init)

sim1_fit1, history = sf.estimation.estimate_parameters(
    sf.model.model,
    data=dict(y=sim1['y'][:, :g_fit], m=sim1['m'][:, :g_fit]),
    n=n_fit,
    g=g_fit,
    s=s_fit,
    gamma_hyper=0.1,
    pi_hyper=1.0,
    rho_hyper=0.5,
    mu_hyper_mean=5,
    mu_hyper_scale=5.,
    m_hyper_r=10.,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    alpha_hyper_hyper_mean=100.,
    alpha_hyper_hyper_scale=10.,
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.01,
    initialize_params=initialize_params,
    device='cpu',
    lag=100,
    lr=1e-1,
)

### Merging Strains

In [None]:
# NOTE(review): merge_similar_genotypes is commented out in
# sfacts/estimation.py above, so this call fails as written.
sim1_fit1_gamma_merge, sim1_fit1_pi_merge, sim1_fit1_delta_merge  = sf.estimation.merge_similar_genotypes(
    sim1_fit1['gamma'],
    sim1_fit1['pi'],
    delta=sim1_fit1['delta'],
    thresh=0.1,
)

# print(sim1_gamma_init.shape[0], sim1_fit1['gamma'].shape[0], sim1_fit1_gamma_merge.shape[0])
print(sim1_fit1['gamma'].shape[0], sim1_fit1_gamma_merge.shape[0])

## Evaluation

In [None]:
# Mask genotypes by missingness for truth and fit before comparing.
# NOTE(review): `sf.genotype` is not imported by sfacts/__init__.py shown at
# the top of this notebook; confirm it is available.
sim1_gamma_adjusted = sf.genotype.mask_missing_genotype(sim1['gamma'][:, :g_fit], sim1['delta'][:, :g_fit])
sim1_fit1_gamma_adjusted = sf.genotype.mask_missing_genotype(sim1_fit1['gamma'], sim1_fit1['delta'])

### Ground Truth

#### Visualization

In [None]:
sf.plot.plot_genotype_comparison(
    data=dict(
        true=sim1_gamma_adjusted[:, :g_plt],
#         fit=sim1_fit1['gamma'][:, :g_plt],
        adj=sim1_fit1_gamma_adjusted[:, :g_plt],
#         init=sim1_gamma_init,
#         merg=sim1_fit1_gamma_merge,
    ),
    linkage_kw=dict(progress=True),
)

In [None]:
sf.plot.plot_community_comparison(
    data=dict(
        true=sim1['pi'],
        fit=sim1_fit1['pi'],
#         init=sim1_pi_init,
#         merg=sim1_fit1_pi_merge,
    ),
)

In [None]:
# True vs. fitted scatter; the line is the y = x reference.
plt.scatter(sim1['epsilon'], sim1_fit1['epsilon'])
plt.plot([0, 0.04], [0, 0.04])

In [None]:
# True vs. fitted scatter; the line is the y = x reference.
plt.scatter(sim1['alpha'], sim1_fit1['alpha'])
plt.plot([0, 200], [0, 200])

In [None]:
sns.heatmap(sim1_fit1['delta'], vmin=0, vmax=1)

In [None]:
# True vs. fitted scatter; the line is the y = x reference.
plt.scatter(sim1['mu'], sim1_fit1['mu'])
plt.plot([0, 40], [0, 40])

In [None]:
# TODO: Plot comparing genotype accuracy to true strain abundance
# colored by mean entropy of the estimated genotype masked by delta

#### Fit scores

In [None]:
# NOTE(review): sample_mean_masked_genotype_entropy is commented out in
# sfacts/evaluation.py above and is not imported in this notebook.
plt.scatter(sim1_fit1['alpha'], sample_mean_masked_genotype_entropy(sim1_fit1['pi'], sim1_fit1['gamma'], sim1_fit1['delta']))
sample_mean_masked_genotype_entropy(sim1_fit1['pi'], sim1_fit1['gamma'], sim1_fit1['delta']).mean()

In [None]:
# NOTE(review): `match_genotypes` is not imported unqualified here, and
# sf.evaluation.match_genotypes takes world objects, not arrays.
best_hit, best_dist = match_genotypes(sim1_gamma_adjusted[:, :g_fit], sim1_fit1_gamma_adjusted[:, :g_fit])

print('weighted_mean_distance:', (best_dist * sim1['pi'].mean(0)).sum())
plt.scatter((sim1['pi'] * sim1['mu'].reshape(-1, 1)).sum(0), best_dist)

In [None]:
# NOTE(review): `pdist` is not in the notebook's import cell, and
# `community_accuracy_test` is not defined in sfacts/evaluation.py
# (the closest is sf.evaluation.community_error_test, which takes worlds).
bc_sim = 1 - pdist(sim1['pi'], metric='braycurtis')
bc_fit = 1 - pdist(sim1_fit1['pi'], metric='braycurtis')
plt.scatter(
    bc_sim,
    bc_fit,
    marker='.',
    alpha=0.2,
)

community_accuracy_test(sim1['pi'], sim1_fit1['pi'])

### No Ground Truth

#### Visualization

In [None]:
# Strains that are not representative of true haplotypes
# are high entropy (even after masking with delta)
# and have low estimated total coverage.
# NOTE(review): `match_genotypes` and `mean_masked_genotype_entropy` are
# unqualified here; the latter is commented out in sfacts/evaluation.py.

best_true_strain, best_true_strain_dist = match_genotypes(sim1_fit1_gamma_adjusted[:, :g_fit], sim1_gamma_adjusted[:, :g_fit])
# Not the last expression of the cell, so this is not displayed.
best_true_strain_dist

plt.scatter((sim1_fit1['pi'] * sim1_fit1['mu'].reshape((-1, 1))).sum(0), best_true_strain_dist, c=mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']))

In [None]:
# NOTE(review): unqualified plot_genotype / plot_community; elsewhere these
# are called as sf.plot.plot_genotype / sf.plot.plot_community.
plot_genotype(sim1_fit1_gamma_adjusted[:, :g_plt], linkage_kw=dict(progress=True))

In [None]:
plot_community(sim1_fit1['pi'])

#### Confidence Scores

In [None]:
plt.hist(mean_masked_genotype_entropy(sim1_fit1['gamma'][:, :g_plt], sim1_fit1['delta'][:, :g_plt]))
None

In [None]:
plt.hist(sim1_fit1['alpha'], bins=20)
None

In [None]:
# Plot only fitted strains whose masked mean genotype entropy is below 0.1.
plot_genotype(sim1_fit1['gamma'][mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']) < 0.1, :g_fit])