In [9]:
import scanpy as sc
import muon as mu
import numpy as np
import pandas as pd
import mofax as mofa
import seaborn as sns
import matplotlib.pyplot as plt
import pyro
from pyro.nn import PyroSample, PyroModule
from pyro.infer import SVI, Trace_ELBO, autoguide
import torch
import torch.nn.functional as F
from torch.nn.functional import softplus
from sklearn.metrics import mean_squared_error
import random
import seaborn as sns
import muon as mu
import anndata

In [10]:
# Location of the downloaded .h5 file.
# NOTE(review): `dir` shadows the Python builtin of the same name; the name is
# kept here because other cells may refer to it.
dir = "/scratch/deeplife/"

# gex_only=False — presumably so that non-gene-expression features are loaded
# as well (verify against the scanpy documentation).
pbmc = sc.read_10x_h5(
    dir + "5k_pbmc_protein_v3_nextgem_filtered_feature_bc_matrix.h5",
    gex_only=False,
)
pbmc.var_names_make_unique()

# Keep a copy of the matrix as loaded under the "counts" layer.
pbmc.layers["counts"] = pbmc.X.copy()

In [11]:
# Split the combined object into one AnnData per modality, selecting columns
# by the "feature_types" annotation. `.copy()` makes each an independent object.
rna = pbmc[:, pbmc.var["feature_types"] == "Gene Expression"].copy()
protein = pbmc[:, pbmc.var["feature_types"] == "Antibody Capture"].copy()


In [21]:
class FA(PyroModule):
    """Factor analysis model fitted with stochastic variational inference.

    Generative model (see ``model``):
        W     ~ Normal(0, 1)      weights, one row per feature, K columns
        Z     ~ Normal(0, 1)      factors, one row per sample, K columns
        scale ~ LogNormal(0, 1)   one value per (sample, feature) entry
        Y     ~ Normal(Z @ W.T, scale)

    NaN entries of ``Y`` are treated as missing and masked out of the
    likelihood. ``train`` fits the model with an ``AutoDelta`` guide, i.e. it
    returns point (MAP) estimates.
    """

    def __init__(self, Y, K):
        """
        Args:
            Y: Tensor (Samples x Features); NaN marks a missing entry.
            K: Number of Latent Factors
        """
        super().__init__()
        # Side effect: wipes Pyro's global parameter store.
        pyro.clear_param_store()

        # data
        self.Y = Y
        self.K = K

        self.num_samples = self.Y.shape[0]
        self.num_features = self.Y.shape[1]

        # Missing-value bookkeeping is derived ONCE from the raw input.
        # Previously `model` overwrote `self.Y` with the NaN-filled tensor on
        # its first call, so every later call saw no NaNs, built an all-True
        # mask and treated the missing entries as observed zeros.
        self._obs_mask = torch.logical_not(torch.isnan(Y))
        # Masked entries still need a valid (finite) value; 0 is arbitrary.
        self._Y_filled = torch.nan_to_num(Y, nan=0)

        self.sample_plate = pyro.plate("sample", self.num_samples)
        self.feature_plate = pyro.plate("feature", self.num_features)
        self.latent_factor_plate = pyro.plate("latent factors", self.K)

    def model(self):
        """
        how to generate a matrix
        """
        # `Z @ W.T` below relies on W being (features x K) and Z being
        # (samples x K), which is what this plate nesting is meant to produce.
        with self.latent_factor_plate:
            with self.feature_plate:
                # sample weight matrix with Normal prior distribution
                W = pyro.sample("W", pyro.distributions.Normal(0., 1.))

            with self.sample_plate:
                # sample factor matrix with Normal prior distribution
                Z = pyro.sample("Z", pyro.distributions.Normal(0., 1.))

        # estimate for Y
        Y_hat = torch.matmul(Z, W.t())

        # Separate plates for the likelihood (distinct names from the plates above).
        with pyro.plate("feature_", self.num_features), pyro.plate("sample_", self.num_samples):
            # masking the NA values such that they are not considered in the distributions
            with pyro.poutine.mask(mask=self._obs_mask):
                # sample scale parameter for each feature-sample pair with LogNormal prior (has to be positive)
                scale = pyro.sample("scale", pyro.distributions.LogNormal(0., 1.))
                # compare sampled estimation to the true observation Y
                pyro.sample("obs", pyro.distributions.Normal(Y_hat, scale), obs=self._Y_filled)

    def train(self, num_iterations=4000, lr=0.02, log_every=200, seed=None):
        """Fit the model with SVI and return the point estimates.

        NOTE(review): this method name shadows ``torch.nn.Module.train(mode)``
        (inherited through ``PyroModule``), so ``.eval()`` / ``.train(False)``
        do not behave as they do on ordinary modules. The name is kept for
        existing callers.

        Args:
            num_iterations: Number of SVI gradient steps.
            lr: Learning rate of the Adam optimizer.
            log_every: Print the loss every ``log_every`` iterations;
                0 or None disables printing.
            seed: If given, passed to ``pyro.set_rng_seed`` before training so
                that runs are repeatable.

        Returns:
            Tuple ``(train_loss, map_estimates, guide)``: the per-iteration
            loss divided by the number of samples, the dict returned by
            calling the trained guide, and the guide itself.
        """
        if seed is not None:
            pyro.set_rng_seed(seed)

        # set training parameters
        optimizer = pyro.optim.Adam({"lr": lr})
        elbo = Trace_ELBO()
        guide = autoguide.AutoDelta(self.model)

        # initialize stochastic variational inference
        svi = SVI(
            model=self.model,
            guide=guide,
            optim=optimizer,
            loss=elbo,
        )

        train_loss = []
        for j in range(num_iterations):
            # calculate the loss and take a gradient step
            loss = svi.step()

            loss_per_sample = loss / self.num_samples
            train_loss.append(loss_per_sample)
            if log_every and j % log_every == 0:
                print("[iteration %04d] loss: %.4f" % (j + 1, loss_per_sample))

        # Obtain maximum a posteriori estimates for W and Z.
        # The model takes no arguments, so the guide is called without any.
        map_estimates = guide()

        return train_loss, map_estimates, guide


In [49]:
# Fit a 5-factor model on the protein matrix; `.toarray()` turns `protein.X`
# into a dense array before it is wrapped in a tensor.
factor_model = FA(
    Y=torch.tensor(protein.X.toarray()),
    K=5,
)

# `train()` returns the loss history, the estimates produced by the trained
# guide, and the guide itself.
loss, map_estimates, trained_guide = factor_model.train()

[iteration 0001] loss: 2279998.1385
[iteration 0201] loss: 87566.6773
[iteration 0401] loss: 30587.6447
[iteration 0601] loss: 15317.3306
[iteration 0801] loss: 9143.7706
[iteration 1001] loss: 6099.0164
[iteration 1201] loss: 4400.4007
[iteration 1401] loss: 3368.6171
[iteration 1601] loss: 2695.5634
[iteration 1801] loss: 2233.0764
[iteration 2001] loss: 1901.8620
[iteration 2201] loss: 1658.9542
[iteration 2401] loss: 1476.1332
[iteration 2601] loss: 1334.7491
[iteration 2801] loss: 1223.7632
[iteration 3001] loss: 1135.0665
[iteration 3201] loss: 1062.6931
[iteration 3401] loss: 1001.9166
[iteration 3601] loss: 949.7983
[iteration 3801] loss: 903.5972
