In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
import pyro
import pyro.distributions as dist
import torch
from pyro.nn import PyroModule

'''
TODO: Every model class must inherit from Abstract Meta Model Class, e.g.

class GenerativeModel(PyroModule):
    def __init__(self, device: Optional[Union[torch.device, int]] = None):
        """Generative model base class.

        The class uses a fixed number of tensor dimensions, associated with
        - -3: samples
        - -2: factors
        - -1: features
        Parameters
        ----------
        device : Optional[Union[torch.device, int]], optional
            Device to run the model on, by default None
        """
        super().__init__(name="GenerativeModel")
        self.device = device

    def get_plates(self, *args, **kwargs) -> Dict[str, pyro.plate]:
        raise NotImplementedError()
    def forward(self, *args, **kwargs) -> None:
        raise NotImplementedError()
'''


  from .autonotebook import tqdm as notebook_tqdm


'\nTODO: Every model class must inherit from Abstract Meta Model Class, e.g.\n\nclass GenerativeModel(PyroModule):\n    def __init__(self, device: Optional[Union[torch.device, int]] = None):\n        """Generative model base class.\n\n        The class uses a fixed number of tensor dimensions, associated with\n        - -3: samples\n        - -2: factors\n        - -1: features\n        Parameters\n        ----------\n        device : Optional[Union[torch.device, int]], optional\n            Device to run the model on, by default None\n        """\n        super().__init__(name="GenerativeModel")\n        self.device = device\n\n    def get_plates(self, *args, **kwargs) -> Dict[str, pyro.plate]:\n        raise NotImplementedError()\n    def forward(self, *args, **kwargs) -> None:\n        raise NotImplementedError()\n'

<IPython.core.display.Javascript object>

In [3]:
class MOFA_Model(PyroModule):
    def __init__(self, n_factors: int):
        super().__init__(name="MOFA_Model")
        self.n_factors = n_factors

    def _setup(self, n_obs, feature_offsets):
        # TODO: at some point replace n_obs with obs_offsets
        self.n_obs = n_obs
        self.n_features = feature_offsets[-1]
        self.n_feature_groups = len(feature_offsets) - 1
        self.feature_offsets = feature_offsets

    def forward(self, X):
        """Generative model for MOFA."""
        plates = self.get_plates()

        feature_group_scale = pyro.sample(
            "feature_group_scale", dist.HalfCauchy(torch.ones(self.n_feature_groups))
        ).view(-1, self.n_feature_groups)

        with plates["obs"], plates["factors"]:
            z = pyro.sample("z", dist.Normal(torch.zeros(1), torch.ones(1))).view(-1, self.n_obs, self.n_factors, 1)

        # print(z.shape)

        with plates["features"], plates["factors"]:
            # implement the horseshoe prior
            w_shape = (-1, 1, self.n_factors, self.n_features)
            w_scale = pyro.sample("w_scale", dist.HalfCauchy(torch.ones(1))).view(w_shape)
            w_scale = torch.cat(
                [
                    w_scale[..., self.feature_offsets[m] : self.feature_offsets[m + 1]]
                    * feature_group_scale[..., m : m + 1]
                    for m in range(self.n_feature_groups)
                ],
                dim=-1,
            )
            w = pyro.sample("w", dist.Normal(torch.zeros(1), w_scale)).view(w_shape)

        # print(w.shape)
        with plates["features"]:
            sigma = pyro.sample("sigma", dist.InverseGamma(torch.tensor(3.0), torch.tensor(1.0))).view(
                -1, 1, 1, self.n_features
            )

        # print(sigma.shape)
        with plates["obs"]:
            prod = torch.einsum("...ikj,...ikj->...ij", z, w).view(-1, self.n_obs, 1, self.n_features)
            # print(prod.shape)
            y = pyro.sample("data", dist.Normal(prod, torch.sqrt(sigma)), obs=X.view(1, self.n_obs, 1, self.n_features))

        return {"z": z, "w": w, "sigma": sigma, "y": y}

    def get_plates(self):
        return {
            "obs": pyro.plate("obs", self.n_obs, dim=-3),
            "factors": pyro.plate("factors", self.n_factors, dim=-2),
            "features": pyro.plate("features", self.n_features, dim=-1),
        }


<IPython.core.display.Javascript object>

In [4]:
import numpy as np


<IPython.core.display.Javascript object>

In [5]:
n_obs = 100
feature_offsets = [0, 10, 50]
n_factors = 5

X = np.random.normal(size=(n_obs, feature_offsets[-1]))
model = MOFA_Model(n_factors)
model._setup(n_obs, feature_offsets)


<IPython.core.display.Javascript object>

In [6]:
model(torch.Tensor(X))['w'].shape


torch.Size([1, 1, 5, 50])

<IPython.core.display.Javascript object>

In [7]:
# from cellij.core.models import MOFA
# from cellij.core._data import Importer

  from pandas.core.index import RangeIndex


ImportError: cannot import name 'PathLike' from 'anndata.compat' (/Users/chigurh/miniconda3/envs/cellij/lib/python3.9/site-packages/anndata/compat/__init__.py)

<IPython.core.display.Javascript object>