# cNMF in `cellarium-ml`

Stephen Fleming, Yang Xu

2025.07.31

The `cellarium-ml` project:

https://github.com/cellarium-ai/cellarium-ml

The specific implementation of cNMF we are actively working on:

https://github.com/cellarium-ai/cellarium-ml/pull/196

## Overview

`cellarium-ml` implements a variety of algorithms in a way that is scalable to hundreds of millions of cells and beyond.
This notebook provides a demo run of Cellarium's implementation of consensus NMF (cNMF).

Here we demonstrate our ability to run on a large dataset of 4.17M cells.

Because this demo is intended to run on a laptop, Cellarium runs in "streaming" mode, continuously downloading files from a Google Cloud Storage bucket.
This is not the fastest way to run: downloading all the curriculum h5ad files locally first is much faster.
This notebook simply demonstrates that it is possible to run without downloading the full dataset all at once.

The specific algorithm for NMF is based on "Online learning for matrix factorization and sparse coding" by Mairal, Bach, Ponce, and Sapiro (JMLR 2010).
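To make the idea concrete, here is a minimal NumPy/SciPy sketch of an online scheme in the spirit of Mairal et al., adapted to non-negative factors. It is an illustration under our own assumptions, not the `cellarium-ml` implementation. For each minibatch of cells it fits non-negative per-cell loadings with `scipy.optimize.nnls`, adds them to running sufficient statistics (`A_kk`, the sum of `h h^T`, and `B_kg`, the sum of `h x^T`), and then updates each program (row of `D_kg`) by block coordinate descent, projecting onto non-negative values and the unit L2 ball. The function names are hypothetical; the `_ng`/`_kg` suffixes only mirror the shape conventions used later in this notebook.

```python
import numpy as np
from scipy.optimize import nnls


def fit_loadings(x_ng, D_kg):
    """Per-cell non-negative least squares: min ||x_g - h_k @ D_kg|| subject to h_k >= 0."""
    return np.stack([nnls(D_kg.T, x_g)[0] for x_g in x_ng])


def update_factors(D_kg, A_kk, B_kg, eps=1e-12):
    """One block-coordinate-descent sweep over the k rows (programs) of D_kg, in place."""
    for j in range(D_kg.shape[0]):
        if A_kk[j, j] < eps:
            continue  # program j has not been used by any cell yet; leave it unchanged
        u_g = D_kg[j] + (B_kg[j] - A_kk[j] @ D_kg) / A_kk[j, j]
        u_g = np.maximum(u_g, 0.0)  # keep the program non-negative
        D_kg[j] = u_g / max(np.linalg.norm(u_g), 1.0)  # keep it inside the unit L2 ball
    return D_kg


def online_nmf(minibatches, k, n_genes, seed=0):
    """Stream over minibatches of cells, updating D_kg after each minibatch."""
    rng = np.random.default_rng(seed)
    D_kg = rng.uniform(size=(k, n_genes))
    A_kk = np.zeros((k, k))  # running sum of h h^T over all cells seen so far
    B_kg = np.zeros((k, n_genes))  # running sum of h x^T over all cells seen so far
    for x_ng in minibatches:
        h_nk = fit_loadings(x_ng, D_kg)
        A_kk += h_nk.T @ h_nk
        B_kg += h_nk.T @ x_ng
        D_kg = update_factors(D_kg, A_kk, B_kg)
    return D_kg


def relative_error(x_ng, D_kg):
    """Frobenius reconstruction error after refitting loadings, relative to ||x_ng||."""
    h_nk = fit_loadings(x_ng, D_kg)
    return np.linalg.norm(x_ng - h_nk @ D_kg) / np.linalg.norm(x_ng)


# toy data built from 3 non-negative programs; three passes over six minibatches of 100 cells
rng = np.random.default_rng(1)
x_ng = rng.gamma(1.0, size=(600, 3)) @ rng.gamma(1.0, size=(3, 20))
minibatches = [x_ng[i : i + 100] for i in range(0, 600, 100)] * 3
D_kg = online_nmf(minibatches, k=3, n_genes=20)
print(relative_error(x_ng, D_kg))
```

Because only `A_kk` and `B_kg` are carried between minibatches, memory does not grow with the number of cells, which is what makes this style of algorithm attractive for millions of cells.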

## This notebook

This notebook shows an end-to-end cNMF run in `cellarium-ml`, starting with h5ad files and ending with results.
There are several steps involved.

## Description of analysis steps

1. Compute highly-variable genes.

2. Run cNMF on selected highly-variable genes.

3. Interactive plotting to help choose the number of programs `k`, as well as a
   `density_threshold` and `local_neighborhood_size` for the consensus step.

4. Compute consensus factors.

5. (Optional) Compute per-cell factor loadings.

6. (Optional) Re-compute the `k` factor definitions using all genes (not just highly variable genes).

## Inputs

- a list of file paths to h5ad files, which can be local, in a Google Cloud Storage bucket, or at a URL
    - ideally the h5ad files would be from an extracted curriculum from `cellarium-nexus`, but these files can be any h5ad files
        - until [#324](https://github.com/cellarium-ai/cellarium-ml/issues/324) is resolved, the h5ad files
          should be limited in size to what can fit in memory

## Outputs (work in progress... not complete)

- AnnData object for all cells (with an empty count matrix) containing:
    - `adata.obsm["X_cnmf_k20"]`: (cell, k) matrix of per-cell factor loadings (for the `k = 20` decomposition)
    - `adata.varm["cnmf_k20_factors_hvg"]`: (gene, k) matrix defining each of the `k` consensus programs
        - from the initial cNMF fit: all non-highly-variable genes have weight zero
    - `adata.varm["cnmf_k20_factors_hvg_tpm"]`: (gene, k) matrix defining each of the `k` consensus programs
        - same as above, but weights are recomputed to represent TPM values via a refitting step
    - `adata.varm["cnmf_k20_factors"]`: (gene, k) matrix defining each of the `k` consensus programs
        - computed by holding fixed the per-cell loadings obtained from `adata.varm["cnmf_k20_factors_hvg"]` and refitting the programs on the dataset including all genes
    - `adata.varm["cnmf_k20_factors_tpm"]`: (gene, k) matrix defining each of the `k` consensus programs
        - same as above, but weights are recomputed to represent TPM values via a refitting step
    - (optionally) all of the above for other choices of `k`

NOTE: You will need to use the `nmf_sf_singlenotebook` branch of `cellarium-ml` on GitHub.

In [None]:
import lightning.pytorch as pl

import cellarium.ml.api
import cellarium.ml.data
import cellarium.ml.models
import cellarium.ml.preprocessing
import cellarium.ml.transforms
import cellarium.ml.utilities.data

In [None]:
%load_ext autoreload
%autoreload 2

# Data

In [None]:
# for demonstration purposes: automatically grab h5ad file paths from a bucket prefix, like data from Nexus

example_cellarium_curriculum_h5ad_paths = cellarium.ml.api.h5ad_paths_from_google_bucket(
    "gs://cellarium-nexus-file-system-335649/pipeline/data-extracts/all_cells_cap_freeze1_20250721/extract_files"
)
print(f"[{example_cellarium_curriculum_h5ad_paths[0]}, ...]")

In [None]:
# here I have actually mounted the bucket via gcsfuse, using this command:
#   ~/go/bin/gcsfuse -o ro --only-dir pipeline/data-extracts/all_cells_cap_freeze1_20250721/extract_files \
#       cellarium-nexus-file-system-335649 /Users/sfleming/Desktop/fuse
# so the files appear to be local, under /Users/sfleming/Desktop/fuse

example_cellarium_curriculum_h5ad_paths = [
    f.replace(
        "gs://cellarium-nexus-file-system-335649/pipeline/data-extracts/all_cells_cap_freeze1_20250721/extract_files",
        "/Users/sfleming/Desktop/fuse",
    )
    for f in example_cellarium_curriculum_h5ad_paths
]
print(f"[{example_cellarium_curriculum_h5ad_paths[0]}, ...]")

## Cellarium data setup

For this demo we are using the Python API for Cellarium. It is also possible to use command-line versions of these tools.

In [None]:
h5ad_paths = example_cellarium_curriculum_h5ad_paths
h5ad_paths[:5]

In [None]:
len(h5ad_paths)

(For remote files, the next cell can take several minutes.)

In [None]:
# count the cells in each file (this took about 10 min here)

limits = cellarium.ml.api.get_h5ad_files_limits(h5ad_paths)
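How `limits` gets used is worth a short sketch. Our understanding, which is an assumption and has not been checked against the `DistributedAnnDataCollection` source, is that `limits` holds the cumulative number of cells at the end of each file, so finding the file that holds a given global cell index is a binary search. The helper `locate_cell` below is hypothetical and not part of `cellarium-ml`.

```python
import numpy as np


def locate_cell(global_index, limits):
    """Map a global cell index to (file index, row within that file).

    Assumes limits[i] is the cumulative number of cells in files 0..i,
    i.e. the exclusive end index of file i in the global ordering.
    """
    limits = np.asarray(limits)
    if not 0 <= global_index < limits[-1]:
        raise IndexError(f"cell index {global_index} out of range [0, {limits[-1]})")
    file_index = int(np.searchsorted(limits, global_index, side="right"))
    start = 0 if file_index == 0 else int(limits[file_index - 1])
    return file_index, global_index - start


# hypothetical limits for three files holding 100, 150 and 150 cells
limits_example = [100, 250, 400]
print(locate_cell(0, limits_example))
print(locate_cell(100, limits_example))
print(locate_cell(399, limits_example))
```

Under the same assumption, the number of cells per file is `np.diff(limits, prepend=0)`.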

In [None]:
datamodule = cellarium.ml.CellariumAnnDataDataModule(
    dadc=cellarium.ml.data.DistributedAnnDataCollection(
        filenames=h5ad_paths,
        limits=limits,
        obs_columns_to_validate=[],
        max_cache_size=3,
        cache_size_strictly_enforced=True,
    ),
    batch_keys={
        "x_ng": cellarium.ml.utilities.data.AnnDataField(attr="X", convert_fn=cellarium.ml.utilities.data.densify),
        "var_names_g": cellarium.ml.utilities.data.AnnDataField(attr="var_names"),
        "obs_names_n": cellarium.ml.utilities.data.AnnDataField(attr="obs_names"),
    },
    batch_size=5000,
    shuffle=True,
    train_size=1.0,
    prefetch_factor=2,
    num_workers=8,
    persistent_workers=True,
)

datamodule.setup(stage="fit")

NOTE: In practice, it is highly recommended to run this notebook on a machine where you can download all the h5ad files; everything will run much faster.

Try
```
mkdir -p data_extract_h5ads
gsutil -m cp gs://cellarium-nexus-file-system-335649/pipeline/data-extracts/all_cells_cap_freeze1_20250721/extract_files/*.h5ad data_extract_h5ads/
```

and then replace `h5ad_paths` above with local paths.


Example timings:
- the onepass model below takes about 1.5 hr on 4M cells, streaming over my home Wi-Fi
- the onepass model below takes about 30 min on 4M cells over my home Wi-Fi, using gcsfuse
- a onepass model on 4M cells read from a local disk should take about 10 min

# Highly variable genes

## Run onepass model

This computes the mean and variance of each gene, after total-count normalization and log1p (the transforms in the module below).

In [None]:
# get gene names to use later (and assume all files have the same genes)

var_names_g = cellarium.ml.api.get_h5ad_file_var_names_g(h5ad_paths[0])
var_names_g[:3]

In [None]:
# set up the model that will be used to compute mean and var of each gene

onepass_module = cellarium.ml.CellariumModule(
    transforms=[
        cellarium.ml.transforms.NormalizeTotal(),
        cellarium.ml.transforms.Log1p(),
    ],
    model=cellarium.ml.models.OnePassMeanVarStd(
        var_names_g=var_names_g,  # reuse the gene names loaded above
    ),
)

In [None]:
trainer = pl.Trainer(
    accelerator="cpu",
    devices=1,
    max_epochs=1,
    default_root_dir="tmp/onepass",
)
trainer.fit(onepass_module, datamodule)

In [None]:
# the onepass model computes a mean and variance per gene

mean_g = trainer.model.model.mean_g
var_g = trainer.model.model.var_g
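For intuition, the one-pass idea can be sketched with running per-gene sums: each minibatch is normalized and log-transformed, and only the cell count, the per-gene sum, and the per-gene sum of squares are kept. `StreamingMeanVar` is a hypothetical class, not `OnePassMeanVarStd`. The target of 10,000 counts per cell is our assumption and may not match the `NormalizeTotal` default, and the plain `E[x^2] - E[x]^2` formula is less numerically robust than the updates a production implementation may use.

```python
import numpy as np


class StreamingMeanVar:
    """Per-gene mean and (population) variance from a single pass over minibatches."""

    def __init__(self, n_genes, target_count=10_000.0):
        self.target_count = target_count  # assumed normalization target, not a cellarium default
        self.n = 0
        self.sum_g = np.zeros(n_genes)
        self.sumsq_g = np.zeros(n_genes)

    def update(self, x_ng):
        # normalize each cell to the same total count (empty cells stay zero), then log1p
        totals_n1 = x_ng.sum(axis=1, keepdims=True)
        scaled_ng = np.divide(x_ng, totals_n1, out=np.zeros_like(x_ng, dtype=float), where=totals_n1 > 0)
        y_ng = np.log1p(scaled_ng * self.target_count)
        self.n += y_ng.shape[0]
        self.sum_g += y_ng.sum(axis=0)
        self.sumsq_g += (y_ng**2).sum(axis=0)

    @property
    def mean_g(self):
        return self.sum_g / self.n

    @property
    def var_g(self):
        return self.sumsq_g / self.n - self.mean_g**2  # E[y^2] - E[y]^2


# toy counts streamed in minibatches of 64 cells
rng = np.random.default_rng(0)
counts_ng = rng.poisson(2.0, size=(300, 50)).astype(float)
stats = StreamingMeanVar(n_genes=50)
for i in range(0, 300, 64):
    stats.update(counts_ng[i : i + 64])
```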

## Compute hvgs

You can choose `n_top_genes` to suit your needs.

In [None]:
var = cellarium.ml.preprocessing.get_highly_variable_genes(
    gene_names=var_names_g,
    mean=mean_g,
    var=var_g,
    n_top_genes=2000,
)
var

In [None]:
var["highly_variable"].sum()

In [None]:
# the highly variable genes

hvg_var_names_g = var.index[var["highly_variable"]]
hvg_var_names_g
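As a rough picture of what HVG selection from `mean` and `var` can look like, the sketch below ranks genes by dispersion (`var / mean`) after z-scoring it within bins of genes with similar mean, in the spirit of scanpy's `flavor="seurat"`. The helper `select_hvgs_by_dispersion` is hypothetical; we have not checked that `cellarium.ml.preprocessing.get_highly_variable_genes` uses exactly this rule, although the returned table mimics the `highly_variable` column used above.

```python
import numpy as np
import pandas as pd


def select_hvgs_by_dispersion(gene_names, mean_g, var_g, n_top_genes, n_bins=20):
    """Rank genes by dispersion (var / mean), z-scored within bins of similar mean."""
    mean_g = np.asarray(mean_g, dtype=float)
    var_g = np.asarray(var_g, dtype=float)
    dispersion_g = np.divide(var_g, mean_g, out=np.zeros_like(var_g), where=mean_g > 0)
    df = pd.DataFrame({"means": mean_g, "dispersions": dispersion_g}, index=pd.Index(gene_names))
    mean_bin = pd.cut(df["means"], bins=n_bins)
    grouped = df.groupby(mean_bin, observed=True)["dispersions"]
    z = (df["dispersions"] - grouped.transform("mean")) / grouped.transform("std")
    # bins with a single gene (NaN std) or no spread (zero std) get a score of 0
    df["dispersions_norm"] = z.replace([np.inf, -np.inf], np.nan).fillna(0.0)
    top = np.argsort(-df["dispersions_norm"].to_numpy(), kind="stable")[:n_top_genes]
    highly_variable = np.zeros(len(df), dtype=bool)
    highly_variable[top] = True
    df["highly_variable"] = highly_variable
    return df


# toy per-gene statistics
rng = np.random.default_rng(0)
toy_mean_g = rng.gamma(2.0, 1.0, size=200)
toy_var_g = toy_mean_g * rng.gamma(2.0, 1.0, size=200)
toy_names = [f"gene_{i}" for i in range(200)]
toy_var = select_hvgs_by_dispersion(toy_names, toy_mean_g, toy_var_g, n_top_genes=50)
```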

# cNMF

## Run NMF

Set up and run cNMF in Cellarium.

With data in a Google Cloud Storage bucket, one epoch over 4M cells takes about 1.5 hr on my laptop (`mps` accelerator) with 20 values of `k` and 50 NMF replicates.

Using gcsfuse, timing is about the same: roughly 2 hr for the first epoch and 1 hr for each later epoch.

In [None]:
# user's choice for the number of components: must input a python list

# k_values = [10, 20, 30]
k_values = list(range(5, 25))

In [None]:
# user's choice for the number of NMF replicates that should go into consensus

nmf_replicates = 50

In [None]:
# get set up for training

nmf_model = cellarium.ml.models.NonNegativeMatrixFactorization(
    var_names_hvg=hvg_var_names_g,
    k_values=k_values,
    r=nmf_replicates,
)

nmf_module = cellarium.ml.CellariumModule(
    cpu_transforms=[cellarium.ml.transforms.Filter(filter_list=hvg_var_names_g)],
    model=nmf_model,
)

datamodule.setup(stage="fit")

trainer_nmf = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=3,  # the right number is up for debate; empirically, 5 epochs were enough for the donor regression benchmark
    default_root_dir="tmp/nmf",
)

In [None]:
# train on the data

pl.seed_everything(0)  # not required but helps make this notebook reproducible

trainer_nmf.fit(nmf_module, datamodule)

In [None]:
# see the shape of the NMF gene programs that have been inferred: [replicates, k, genes]

for k in nmf_model.k_values:
    print(getattr(nmf_model, f"D_{k}_rkg").shape)

## Set up to explore outputs

We have a helper class that facilitates downstream analysis steps. Here we instantiate it and use it to get various outputs.

In [None]:
from cellarium.ml.models.nmf import NMFOutput

nmf_output = NMFOutput(
    nmf_module=nmf_module,
    datamodule=datamodule,
)

In [None]:
nmf_output

## Default k-selection plot

This is the plot that Kotliar et al.'s cNMF would produce with its default values, `local_neighborhood_size=0.3` and `density_threshold=0.5`.

In [None]:
nmf_output.default_k_selection_plot()
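For intuition about what `density_threshold` and `local_neighborhood_size` control, here is a minimal sketch of the consensus step following our reading of Kotliar et al.: L2-normalize all replicate programs, drop those whose mean distance to their nearest neighbors (a `local_neighborhood_size` fraction of the replicates) exceeds `density_threshold`, cluster the rest into `k` groups with k-means, take the median of each cluster, and report the silhouette score as the stability. It is not taken from `NMFOutput`; the function `consensus_sketch` is hypothetical, and it uses scikit-learn's `KMeans` and `silhouette_score`.

```python
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score


def consensus_sketch(D_rkg, k, density_threshold=0.5, local_neighborhood_size=0.3, seed=1):
    """Density-filter the r * k replicate programs, cluster them into k groups, take medians."""
    r, _, n_genes = D_rkg.shape
    programs_mg = D_rkg.reshape(-1, n_genes)
    programs_mg = programs_mg / np.linalg.norm(programs_mg, axis=1, keepdims=True)

    # local density: mean distance to the nearest neighbors, with the
    # neighborhood size given as a fraction of the number of replicates
    n_neighbors = max(int(local_neighborhood_size * r), 1)
    dist_mm = cdist(programs_mg, programs_mg)
    nearest_mn = np.sort(dist_mm, axis=1)[:, 1 : n_neighbors + 1]  # column 0 is the self-distance
    kept_mg = programs_mg[nearest_mn.mean(axis=1) < density_threshold]

    labels_m = KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(kept_mg)
    consensus_kg = np.stack([np.median(kept_mg[labels_m == c], axis=0) for c in range(k)])
    consensus_kg /= consensus_kg.sum(axis=1, keepdims=True)  # each program sums to 1
    stability = silhouette_score(kept_mg, labels_m, metric="euclidean")
    return consensus_kg, stability


# toy replicates: 10 noisy copies of 3 programs over 30 genes
rng = np.random.default_rng(0)
true_kg = rng.gamma(1.0, size=(3, 30))
D_rkg = true_kg[None, :, :] + 0.01 * rng.random(size=(10, 3, 30))
consensus_kg, stability = consensus_sketch(D_rkg, k=3)
```

If `D_{k}_rkg` on the trained model is a torch tensor, it would need converting first, for example with `.detach().cpu().numpy()`.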

## Other versions of the k-selection plot

Kotliar et al. do not demonstrate this, but the k-selection plot itself depends on the values of `density_threshold` and `local_neighborhood_size`.

In principle, you could choose different hyperparameters for each `k`, run `nmf_output.compute_consensus_factors(k, <your selected params here>)` on all the `k`, and then re-run `nmf_output.calculate_reconstruction_error()` and re-create the k-selection plot using `nmf_output.k_selection_plot()`.

One guiding principle for this process could be the following: for each `k`, automatically choose a reasonable `density_threshold` that maximizes the stability for that `k`. Let's try it.

In [None]:
nmf_output.maximal_stability_k_selection_plot()

In [None]:
# visualize the results of this kind of automatic choosing of density_threshold
# note that the call to plot_clustermap() recomputes consensus if density_threshold is not None

for k in [10, 13, 21]:  # nmf_output.consensus:
    nmf_output.plot_clustermap(k=k, density_threshold=None)

(One caveat about the stability metric: a high stability does not guarantee that there are actually `k` distinct clusters.)

## Explore the factors

In [None]:
best_k = 13

Look at some of the genes involved in each factor.

In [None]:
nmf_output.nmf_module.model.var_names_hvg

In [None]:
import pandas as pd

# gene_name_lookup = adata.var['feature_name'].to_dict()

factor_df = pd.DataFrame(
    nmf_output.consensus[best_k]["consensus_D_kg"].t().numpy(),
    index=nmf_output.nmf_module.model.var_names_hvg,
)
# factor_df['gene_name'] = factor_df.index.map(gene_name_lookup)
factor_df

In [None]:
factor_df.sort_values(by=1, ascending=False).head(10)

In [None]:
factor_df.sort_values(by=2, ascending=False).head(10)

In [None]:
factor_df.sort_values(by=5, ascending=False).head(10)

In [None]:
factor_df.sort_values(by=6, ascending=False).head(10)
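Rather than sorting one factor at a time, a small helper can tabulate the top genes of every factor at once. `top_genes_per_factor` is a hypothetical convenience function built only on `pandas.Series.nlargest`; in this notebook it would be called as `top_genes_per_factor(factor_df, n=10)`.

```python
import pandas as pd


def top_genes_per_factor(factor_df, n=10):
    """Table whose column c lists the n genes with the largest weight in factor c."""
    return pd.DataFrame({c: factor_df[c].nlargest(n).index.to_list() for c in factor_df.columns})


# toy factor table: 3 genes x 2 factors
toy_df = pd.DataFrame({0: [0.1, 0.9, 0.5], 1: [0.7, 0.2, 0.3]}, index=["A", "B", "C"])
toy_top = top_genes_per_factor(toy_df, n=2)
```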

## Compute per-cell loadings

In [None]:
best_k

In [None]:
# load the first ten h5ad files (cells in these curriculum files are already randomly shuffled)
# in principle this could be done for all cells, but that is not needed here

n_cells_visualization = 100_000
n_anndata_shards_visualization = n_cells_visualization // 10_000  # assumes 10,000 cells per h5ad file

datamodule_small = cellarium.ml.CellariumAnnDataDataModule(
    dadc=cellarium.ml.data.DistributedAnnDataCollection(
        filenames=h5ad_paths[:n_anndata_shards_visualization],
        limits=limits[:n_anndata_shards_visualization],
        obs_columns_to_validate=[],
        max_cache_size=n_anndata_shards_visualization,
        cache_size_strictly_enforced=True,
    ),
    batch_keys={
        "x_ng": cellarium.ml.utilities.data.AnnDataField(attr="X", convert_fn=cellarium.ml.utilities.data.densify),
        "var_names_g": cellarium.ml.utilities.data.AnnDataField(attr="var_names"),
        "obs_names_n": cellarium.ml.utilities.data.AnnDataField(attr="obs_names"),
    },
    batch_size=5000,
    shuffle=False,
    train_size=1.0,
    prefetch_factor=None,
    num_workers=None,
    persistent_workers=False,
)

datamodule_small.setup(stage="predict")

Compute the loading of each factor for each cell.

In [None]:
# get per-cell factor loadings using the best k: this takes time
# `normalize` controls whether the per-cell loadings sum to 1

df = nmf_output.compute_loadings(k=best_k, datamodule=datamodule_small, normalize=True)

In [None]:
df.shape
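Conceptually, computing loadings with the factors held fixed is a separate small regression per cell. We have not checked which solver `compute_loadings` uses; the sketch below assumes non-negative least squares (`scipy.optimize.nnls`) with the consensus factors `D_kg` fixed, and mimics `normalize=True` by rescaling each cell's loadings to sum to 1. The function `loadings_with_fixed_factors` is hypothetical; it returns a DataFrame indexed by cell name, like the `df` used below with `df.loc[adata.obs_names]`.

```python
import numpy as np
import pandas as pd
from scipy.optimize import nnls


def loadings_with_fixed_factors(x_ng, D_kg, obs_names, normalize=True):
    """Per-cell NNLS loadings for fixed factors D_kg, optionally normalized to sum to 1."""
    h_nk = np.stack([nnls(D_kg.T, x_g)[0] for x_g in x_ng])
    if normalize:
        totals_n1 = h_nk.sum(axis=1, keepdims=True)
        # cells with all-zero loadings stay at zero instead of dividing by zero
        h_nk = np.divide(h_nk, totals_n1, out=np.zeros_like(h_nk), where=totals_n1 > 0)
    return pd.DataFrame(h_nk, index=pd.Index(obs_names))


# toy example: 2 factors over 3 genes, 3 cells
toy_D_kg = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
toy_x_ng = np.array([[2.0, 0.0, 0.0], [0.0, 3.0, 3.0], [1.0, 1.0, 1.0]])
toy_loadings = loadings_with_fixed_factors(toy_x_ng, toy_D_kg, ["c1", "c2", "c3"])
```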

In principle, you could add these loadings to a single AnnData object containing all cells.
Here we assume the full dataset may be too large for that, so we grab a chunk of cells and add the annotations for only those cells.

In [None]:
# grab cells as anndata object (this might take time to download data from bucket)

adata = datamodule_small.dadc[:n_cells_visualization]
if adata.raw is not None:
    adata.layers["counts"] = adata.raw.X.copy()
else:
    adata.layers["counts"] = adata.X.copy()
adata

In [None]:
# add cNMF loadings to obsm

adata.obsm["X_nmf"] = df.loc[adata.obs_names].values
adata.obsm["X_nmf"].shape

### Visualize factor loadings on a UMAP

Just for fun, if you have scanpy installed in your environment.

In [None]:
import scanpy as sc

In [None]:
sc.set_figure_params(figsize=(5, 5), fontsize=14, vector_friendly=True)

In [None]:
sc.pp.highly_variable_genes(adata, layer="counts", flavor="seurat_v3", n_top_genes=2000)

In [None]:
adata.X = adata.layers["counts"].copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
# sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, use_highly_variable=True)
sc.pp.neighbors(adata, method="umap", metric="cosine", n_pcs=10)
sc.tl.umap(adata)
adata.obsm["X_umap_counts"] = adata.obsm["X_umap"].copy()

In [None]:
# put these in obs for plotting

for k in range(adata.obsm["X_nmf"].shape[1]):
    adata.obs[f"nmf_{k}"] = adata.obsm["X_nmf"][:, k]

In [None]:
adata.obs.columns

In [None]:
sc.pl.embedding(
    adata,
    basis="umap_counts",
    color=["cell_type", "brain_region_abbreviation", "cohort", "scpred_class", "village"],
    # color_map='Oranges',
    ncols=1,
)

In [None]:
sc.pl.embedding(adata, basis="umap_counts", color=[f"nmf_{i}" for i in range(0, best_k)], color_map="Oranges", ncols=2)

In [None]:
# try a UMAP derived from the NMF components

sc.pp.neighbors(adata, use_rep="X_nmf", method="umap", metric="cosine")
sc.tl.umap(adata)
adata.obsm["X_umap_nmf"] = adata.obsm["X_umap"].copy()

In [None]:
sc.pl.embedding(adata, basis="umap_nmf", color=["scpred_class"])

A bit wonky, but clearly picking up on cell types.

In [None]:
sc.pl.embedding(adata, basis="umap_nmf", color=["neuropath_diagnosis"])

## Project factors back to all genes

Now refit the factors for all genes, not just the highly variable genes. In cNMF this involves solving an auxiliary linear regression problem with the per-cell loadings held fixed.
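Because the full count matrix need not fit in memory, one way to solve such a regression at scale is to accumulate the normal equations over minibatches. The sketch below (a hypothetical `refit_factors_least_squares`, not the `cellarium-ml` code) holds the per-cell loadings `h_nk` fixed and solves ordinary least squares for one weight per program and gene, over all genes. It omits the non-negativity constraint and any TPM or other normalization that cNMF applies.

```python
import numpy as np


def refit_factors_least_squares(minibatches, k, n_genes):
    """minibatches yields (h_nk, x_ng): fixed loadings and all-gene expression for the same cells."""
    A_kk = np.zeros((k, k))  # running sum of H^T H
    B_kg = np.zeros((k, n_genes))  # running sum of H^T X
    for h_nk, x_ng in minibatches:
        A_kk += h_nk.T @ h_nk
        B_kg += h_nk.T @ x_ng
    # normal equations (H^T H) W = H^T X, solved for the (k, n_genes) weights W
    return np.linalg.solve(A_kk, B_kg)


# toy data: 500 cells, 4 programs, 30 genes, streamed in chunks of 128 cells
rng = np.random.default_rng(0)
h_nk = rng.gamma(1.0, size=(500, 4))
x_ng = h_nk @ rng.gamma(1.0, size=(4, 30)) + 0.1 * rng.normal(size=(500, 30))
chunks = [(h_nk[i : i + 128], x_ng[i : i + 128]) for i in range(0, 500, 128)]
W_kg = refit_factors_least_squares(chunks, k=4, n_genes=30)
```

Only two small matrices are kept between minibatches, so the same pass works whether the data has thousands or millions of cells.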

## Results as a summary anndata

The results can be packaged up into an anndata object if desired, and perhaps saved that way as an h5ad file.

Here we omit the actual count matrix, since in theory it is too big to fit in memory.