In [61]:
import scanpy as sc
import muon as mu
import numpy as np
import pandas as pd
import mofax as mofa
import seaborn as sns
import matplotlib.pyplot as plt
import pyro
from pyro.nn import PyroSample, PyroModule
from pyro.infer import SVI, Trace_ELBO, autoguide
import torch
import torch.nn.functional as F
from torch.nn.functional import softplus
from sklearn.metrics import mean_squared_error
import random
import seaborn as sns
import muon as mu
import anndata

In [62]:
# dir="/scratch/deeplife/"
dir="data/"
pbmc = sc.read_10x_h5(dir+"5k_pbmc_protein_v3_nextgem_filtered_feature_bc_matrix.h5", gex_only=False)
pbmc.var_names_make_unique()
pbmc.layers["counts"] = pbmc.X.copy()

  utils.warn_names_duplicates("var")


In [63]:
protein = pbmc[:, pbmc.var["feature_types"] == "Antibody Capture"].copy()
rna = pbmc[:, pbmc.var["feature_types"] == "Gene Expression"].copy()


In [64]:
class MOFA(PyroModule):
    def __init__(self, Ys: dict[str, torch.Tensor], K, batch_size=False, num_iterations=4000):
        """
        Args:
            Y: Tensor (Samples x Features)
            K: Number of Latent Factors
        """
        super().__init__()
        pyro.clear_param_store()
        
        self.Ys = Ys  # data/observations
        self.K = K  # number of factors 
        
        # assert sample dim same in Ys
        num_samples = set(Y.shape[0] for Y in self.Ys.values())
        assert len(num_samples) == 1
        self.num_samples = next(iter(num_samples))
        self.num_features = {k: v.shape[1] for k, v in self.Ys.items()}
        
        self.batch_size = batch_size
        self.num_iterations = num_iterations
        
        self.latent_factor_plate = pyro.plate("latent factors", self.K) 
        
        
    def model(self):
        """ Creates the model.
        
        The model needs to be created repeatedly (not sure why), in any case, it is important now, when using 
        `subsample_size` batch size to subsample the dataset differently in each train iteration
        """
        
        # needs to be shared, so returns the same indices in one train step
        sample_plate = pyro.plate("sample", self.num_samples, subsample_size=self.batch_size)
        # the plates get assigned a dim, depending on when in the plate hierarchy they are used. Unfortunately we want to use
        #   feature plates once outside and once inside other plate (sample resp. latent_factor plates, see below)
        #   we therefore need to create separate plates for each of those usages
        get_feature_plates = lambda dim: {k: pyro.plate(f"feature_{k}_{dim}", num_feats) for k, num_feats in self.num_features.items()}

        # W matrices for each modality
        Ws = {}
    
        # for each modality create W matrix and alpha vectors
        for m, feature_plate in get_feature_plates(-2).items():
            # the actual dimensions obtained by plates are read from right to left/inner to outer
            with self.latent_factor_plate:
                # Sample alphas (controls narrowness of weight distr for each factor) from a Gamma distribution
                # Gamma parametrization k, theta or eq. a, b; (where k=a and theta=1/b) 
                # (if k integer) Gamma = the sum of k independent exponentially distributed random variables, each of 
                # which has a mean of theta
                alpha = pyro.sample(f"alpha_{m}", pyro.distributions.Gamma(1, 1))
                
                with feature_plate:
                    # sample weight matrix with Normal prior distribution with alpha narrowness
                    Ws[m] = pyro.sample(f"W_{m}", pyro.distributions.Normal(0., 1. / alpha))                
                
        # create Z matrix
        # (the actual dimensions are read from right to left/inner to outer)
        with self.latent_factor_plate, sample_plate:
            # sample factor matrix with Normal prior distribution
            Z = pyro.sample("Z", pyro.distributions.Normal(0., 1.))
    
        # estimate for Y
        Y_hats = {k: torch.matmul(Z, W.t()) for k, W in Ws.items()}
        
        for m, feature_plate in get_feature_plates(-1).items():
            with feature_plate:
                # sample scale (tau) parameter for each feature-~~sample~~ pair with LogNormal prior (has to be positive)
                scale_tau = pyro.sample(f"scale_{m}", pyro.distributions.LogNormal(0., 1.))
                
                with sample_plate as sub_indices:
                    Y, Y_hat = self.Ys[m][sub_indices], Y_hats[m]
                    
                    # masking the NA values such that they are not considered in the distributions
                    obs_mask = torch.logical_not(torch.isnan(Y))
                    
                    with pyro.poutine.mask(mask=obs_mask):
                        # a valid value for the NAs has to be defined even though these samples will be ignored later
                        Y = torch.nan_to_num(Y, nan=0) 
                
                        # # sample scale parameter for each feature-sample pair with LogNormal prior (has to be positive)
                        # scale = pyro.sample("scale", pyro.distributions.LogNormal(0., 1.))
                        
                        # compare sampled estimation to the true observation Y
                        pyro.sample(f"obs_{m}", pyro.distributions.Normal(Y_hat, scale_tau), obs=Y)
                        # pyro.sample("obs", pyro.distributions.NegativeBinomial(Y_hat, scale_tau), obs=Y)

    def train(self):
        # set training parameters
        optimizer = pyro.optim.Adam({"lr": 0.02})
        elbo = Trace_ELBO()
        guide = autoguide.AutoDelta(self.model)
        
        # initialize stochastic variational inference
        svi = SVI(
            model = self.model,
            guide = guide,
            optim = optimizer,
            loss = elbo
        )
        
        train_loss = []
        for j in range(self.num_iterations):
            # calculate the loss and take a gradient step
            # (loss should be already scaled down by the subsample_size)
            loss = svi.step()

            train_loss.append(loss/self.num_samples)
            if j % 200 == 0:
                print("[iteration %04d] loss: %.4f" % (j + 1, loss / self.num_samples))
        
        # Obtain maximum a posteriori estimates for W and Z
        # map_estimates = guide(self.Y)  # not sure why needed Y?
        # "Note that Pyro enforces that model() and guide() have the same call signature, i.e. both callables should take the same arguments."
        map_estimates = guide()
        
        return train_loss, map_estimates, guide


In [59]:
mofa = MOFA({
    'rna': torch.tensor(rna.X.toarray()),
    'protein': torch.tensor(protein.X.toarray()),
}, K=5, batch_size=33, num_iterations=4000)
loss, map_estimates, trained_guide = mofa.train()


[iteration 0001] loss: 1544855.6143
[iteration 0201] loss: 43426.8338
[iteration 0401] loss: -19751.4866
[iteration 0601] loss: -33614.6363
[iteration 0801] loss: -48270.5133
[iteration 1001] loss: -56885.2889
[iteration 1201] loss: -42953.5086
[iteration 1401] loss: -61820.0962
[iteration 1601] loss: -54214.0418
[iteration 1801] loss: -70193.2789
[iteration 2001] loss: -64776.4558
[iteration 2201] loss: -66183.7409
[iteration 2401] loss: -45832.1790
[iteration 2601] loss: -69544.2023
[iteration 2801] loss: -67679.5133
[iteration 3001] loss: -69013.3838
[iteration 3201] loss: -76465.5619
[iteration 3401] loss: -29830.0084
[iteration 3601] loss: -71609.8842
[iteration 3801] loss: -54700.3288


In [60]:
map_estimates

{'alpha_rna': tensor([ 0.6805,  0.7047, 51.2332,  6.3386, 44.6454],
        grad_fn=<ExpandBackward0>),
 'W_rna': tensor([[ 5.7743e-16,  4.8005e-16,  2.3599e-16,  2.2752e-17,  1.7289e-15],
         [ 1.5914e-04,  2.6798e-04,  1.2638e-03,  1.9229e-03,  1.1706e-03],
         [-7.3769e-06, -6.1725e-06,  9.5526e-06, -1.1013e-05,  1.0729e-05],
         ...,
         [-1.4335e-01,  5.1932e-02, -1.5062e-02,  6.9615e-02,  2.6003e-02],
         [ 7.9154e-22, -3.3275e-22, -7.4161e-23, -4.1880e-22,  1.9829e-22],
         [-1.4943e-03, -2.7019e-03, -7.7028e-03, -8.6110e-04,  1.1006e-02]],
        grad_fn=<ExpandBackward0>),
 'alpha_protein': tensor([0.3126, 0.5185, 1.0603, 0.4464, 0.7072], grad_fn=<ExpandBackward0>),
 'W_protein': tensor([[-2.5605, -0.4950, -0.4921, -2.5739,  0.4239],
         [-4.4372, -1.8422,  0.8739, -2.3409,  1.4852],
         [-2.8029, -0.1755, -0.7576, -1.4559,  0.1247],
         [-1.8529, -0.6747,  0.1978, -2.7917,  2.1802],
         [-2.6411, -0.4595, -0.2467, -2.2207,  1