# cNMF in `cellarium-ml`

Yang Xu

Stephen Fleming

2024.11.20

The `cellarium-ml` project:

https://github.com/cellarium-ai/cellarium-ml

The specific implementation of cNMF we are actively working on:

https://github.com/cellarium-ai/cellarium-ml/pull/196

## Overview

`cellarium-ml` implements a variety of algorithms in a way that is scalable to hundreds of millions of cells and beyond. This notebook provides a demo run of Cellarium's implementation of consensus NMF (cNMF). The specific algorithm for NMF is based on "Online learning for matrix factorization and sparse coding" by Mairal, Bach, Ponce, and Sapiro (JMLR 2010).

__Pre-processing__ (this can also be done using `cellarium-ml`, but it is not part of this notebook):

0. Compute highly-variable genes.

__Running cNMF happens in four stages:__

1. The initial fit on selected highly-variable genes.

    This amounts to creating a YAML file and then running a single command from the command line:
    ```bash
    cellarium-ml nmf fit --config config.yaml
    ```

2. Interactive plotting in the notebook to help determine optimal `k`, `density_threshold`, and `local_neighborhood_size`.

    Uses functions that are currently called `update_consensusD()` and `calculate_rec_error()` in this notebook, along with some plotting.

3. Computing per-cell factor loadings.

    Uses the function currently called `get_embedding()` in this notebook.

4. Re-computing the `k` factor definitions using all genes (not just highly-variable genes).

    Not yet part of this notebook.

NOTE: You will need to use the `cnmf-yx-streamline` branch of `cellarium-ml` on GitHub.

In [None]:
from cellarium.ml.core import CellariumModule
from cellarium.ml.utilities.data import AnnDataField, densify
from cellarium.ml.data import (
    DistributedAnnDataCollection,
    IterableDistributedAnnDataCollectionDataset,
)
from cellarium.ml.models.nmf import calculate_rec_error, get_embedding, update_consensusD

import os
from string import Template

import torch
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

## Parameters

In [None]:
# choose which values of k to use

k_values: list[int] = list(range(10, 40))  # must be a Python list (not a range object)

# choose how many repeats of NMF to run to create a "consensus"

num_repeats: int = 100

## Data

This demo uses a PBMC RNA dataset (per the file names below), which is hosted in a Google Cloud Storage bucket. We will first download the dataset to the machine where this notebook is running, and then we will run cNMF.

We require 2 files:

- The dataset in h5ad format. In this case the entire dataset is a single h5ad file, but `cellarium-ml` can use an arbitrary number of h5ad files.
- The highly variable genes in CSV format.

In [None]:
# change these values to run on a different dataset:

# define file paths

dataset_h5ad = "gs://yx-data/Strati_pbmc_rna.h5ad"
highly_variable_genes_csv = "gs://yx-data/Strati_pbmc_rna_2000hvg.csv"

# define working directory

working_dir = "./tmp"

In [None]:
!mkdir -p $working_dir

In [None]:
# localize data

local_h5ad = os.path.join(working_dir, "data.h5ad")
local_hvg_csv = os.path.join(working_dir, "hvg.csv")

!gsutil cp $dataset_h5ad $local_h5ad
!gsutil cp $highly_variable_genes_csv $local_hvg_csv

## Config file

This part contains a hack to make things seem simpler: we provide a template config.yaml file here and we modify that file according to the inputs above.

In reality, you would probably create the config.yaml file directly without using the little helpers in this section of the notebook. But this helper makes things smoother for this demo.
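
As a minimal illustration of the substitution mechanism, the sketch below applies `string.Template` to a made-up YAML fragment. The placeholder names match keys used in this notebook, but the YAML layout shown here is hypothetical and is not the actual content of `nmf_config_template.yaml`.

```python
from string import Template

# Hypothetical YAML fragment; the real template file may be laid out differently.
yaml_text = """\
model:
  k_range: ${k_range}
  num_repeats: ${num_repeats}
data:
  filenames: ${dataset_h5ad}
"""

substitutions = {
    "k_range": [10, 11, 12],
    "num_repeats": 100,
    "dataset_h5ad": "/abs/path/to/data.h5ad",  # placeholder path
}

# substitute() raises KeyError if a placeholder has no entry in the mapping.
customized_yaml = Template(yaml_text).substitute(substitutions)
print(customized_yaml)
```

Values are converted with `str()`, so a Python list is written as `[10, 11, 12]`, which YAML reads as a flow sequence. `Template.safe_substitute()` would leave unknown placeholders in place instead of raising, which is usually not what you want for a config file.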

In [None]:
# define the path to the config file

config_file_template = "../examples/cli_workflow/nmf_config_template.yaml"  # if you have cloned the repository, it's here

In [None]:
adata = sc.read_h5ad(local_h5ad)
dataset_ncells = adata.n_obs
n_genes_total = adata.n_vars

In [None]:
# modification of the yaml config file to point to this data

local_config_yaml = os.path.join(working_dir, "config.yaml")

with open(config_file_template, "r") as file:
    yaml_text = file.read()

substitutions = {
    "highly_variable_genes_csv": os.path.abspath(local_hvg_csv),
    "dataset_h5ad": os.path.abspath(local_h5ad),
    "dataset_ncells": dataset_ncells,
    "k_range": k_values,
    "num_repeats": num_repeats,
    "n_genes_total": n_genes_total,
}

template = Template(yaml_text)
customized_yaml = template.substitute(substitutions)

# write the customized YAML to a local file in the working directory
with open(local_config_yaml, "w") as file:
    file.write(customized_yaml)

print(f"Config YAML written to: {local_config_yaml}")

In [None]:
# take a look at the config file we end up with: again you could skip the above and write this file manually

!cat $local_config_yaml

## Run cNMF

In [None]:
!cellarium-ml nmf fit -c $local_config_yaml

## Load the trained NMF model and the dataset

In [None]:
# helper function

def get_cellarium_dataset_from_h5ad(
    h5ad: str, 
    batch_size: int = 1024, 
    shard_size: int | None = None, 
    shuffle: bool = False, 
    drop_last_indices: bool = False,
) -> IterableDistributedAnnDataCollectionDataset:
    """
    Get IterableDistributedAnnDataCollectionDataset from an h5ad file specifier.

    Args:
        h5ad: h5ad file, allowing brace notation for several files.
        batch_size: Batch size.
        shard_size: Shard size.
        shuffle: Whether to shuffle the dataset.
        drop_last_indices: Whether to drop the last incomplete batch.

    Returns:
        IterableDistributedAnnDataCollectionDataset.
    """
    dadc = DistributedAnnDataCollection(
        h5ad,
        shard_size=shard_size,
        max_cache_size=1,
    )

    dataset = IterableDistributedAnnDataCollectionDataset(
        dadc,
        batch_keys={
            "x_ng": AnnDataField(attr="X", convert_fn=densify),
            "var_names_g": AnnDataField(attr="var_names"),
            "obs_names_n": AnnDataField(attr="obs_names"),
        },
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last_indices=drop_last_indices,
    )
    return dataset

In [None]:
# load trained module

ckpt_file = "NMF.ckpt"  # currently this is hard-coded, but we will change this

module = CellariumModule.load_from_checkpoint(ckpt_file)
module.eval()
module

In [None]:
# get a dataset object in the necessary cellarium-ml format

dataset = get_cellarium_dataset_from_h5ad(
    os.path.abspath(local_h5ad), 
    shard_size=adata.n_obs,
    shuffle=False, 
    drop_last_indices=False,
)
dataset

### Compute consensus factors

In [None]:
# calculate consensus D for all Ks

# change these thresholds if desired
density_threshold = 0.1
local_neighborhood_size = 0.3

consensus_stat = update_consensusD(
    module.pipeline,
    density_threshold=density_threshold, 
    local_neighborhood_size=local_neighborhood_size,
)
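
To give some intuition for `density_threshold` and `local_neighborhood_size`, here is a toy NumPy sketch of a local-density outlier filter in the spirit of cNMF: each factor from each repeat is compared with its nearest neighbors, and factors whose mean neighbor distance exceeds the threshold are dropped before forming the consensus. This is a sketch under assumptions (synthetic data, L2 normalization, Euclidean distance); the internals of `update_consensusD()` may differ in detail.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the factors from repeated NMF runs:
# num_repeats * k rows (one per factor), one column per gene.
num_repeats, k, n_genes = 20, 3, 50
prototypes = rng.random((k, n_genes))
factors = np.repeat(prototypes, num_repeats, axis=0)
factors = factors + 0.01 * rng.random(factors.shape)
# Overwrite two factors with unrelated noise so that they act as outliers.
factors[:2] = rng.random((2, n_genes))

# L2-normalize each factor so distances compare direction rather than scale.
factors = factors / np.linalg.norm(factors, axis=1, keepdims=True)

# Pairwise Euclidean distances between all factors.
diff = factors[:, None, :] - factors[None, :, :]
dist = np.sqrt((diff**2).sum(axis=-1))

# Mean distance from each factor to its nearest neighbors (excluding itself).
local_neighborhood_size = 0.3
density_threshold = 0.1
n_neighbors = int(local_neighborhood_size * num_repeats)
sorted_dist = np.sort(dist, axis=1)
local_neigh_dist = sorted_dist[:, 1 : n_neighbors + 1].mean(axis=1)

# Keep only the factors that sit in a dense neighborhood.
keep = local_neigh_dist < density_threshold
print(f"kept {keep.sum()} of {len(keep)} factors")
```

A smaller `density_threshold` discards more factors as outliers. A larger `local_neighborhood_size` averages over more neighbors, which makes the density estimate smoother but less local.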

### Create clustermap plots

In [None]:
# (optional) get gene loadings (of highly variable genes) for a specific k

k_range = module.pipeline[-1].k_range


def get_gene_loadings_for_k(k: int, model=module.pipeline[-1]) -> torch.Tensor:
    D_kg = getattr(model, f"D_{k}_kg")
    D_kg = D_kg.detach().cpu()
    return D_kg


D_kg = get_gene_loadings_for_k(10)
print(D_kg.shape)
D_kg

In [None]:
sc.set_figure_params(scanpy=True, dpi=75, dpi_save=75, vector_friendly=True)

In [None]:
for k in k_range:
    cg = sns.clustermap(
        consensus_stat[k]['topk_euc_dist'].cpu().numpy(),
        row_cluster=True,
        col_cluster=True,
        cbar_pos=(0.05, 0.25, 0.03, 0.15),
        xticklabels=False,
        yticklabels=False,
    )
    cg.ax_row_dendrogram.set_visible(False)
    cg.ax_col_dendrogram.set_visible(False)
    # cg.cax.set_visible(False)
    cg.cax.set_ylabel('Euclidean distance')
    cg.ax_heatmap.set_title(f"k = {k}")
    plt.show()
    
    sns.histplot(consensus_stat[k]['local_neigh_dist'].cpu().numpy())
    ymax = plt.gca().get_ylim()[1]
    plt.vlines(density_threshold, ymin=0, ymax=ymax, color='Red')
    plt.xlabel(f'Mean distance to {int(num_repeats * local_neighborhood_size)} nearest neighbors')
    plt.ylabel('Runs of NMF')
    plt.title(f"k = {k} local density histogram")
    plt.ylim(0, ymax)
    
    plt.show()

### Compute reconstruction error at each k

In [None]:
# we need to calculate the reconstruction error: this takes time

rec_errors = calculate_rec_error(dataset, module.pipeline)
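
For intuition about the quantity being measured, the sketch below computes a squared Frobenius reconstruction error on toy matrices. The names `x_ng`, `alpha_nk`, and `D_kg` follow the notebook's naming convention but are hypothetical here, and the exact definition used by `calculate_rec_error()` (for example any normalization or preprocessing) may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_genes, k = 200, 30, 4

# Toy data generated from a known rank-k non-negative factorization.
alpha_nk = rng.random((n_cells, k))  # per-cell loadings
D_kg = rng.random((k, n_genes))      # factor definitions
x_ng = alpha_nk @ D_kg


def reconstruction_error(x_ng, alpha_nk, D_kg):
    """Sum of squared residuals between the data and its low-rank reconstruction."""
    residual = x_ng - alpha_nk @ D_kg
    return float((residual**2).sum())


# Error with all k factors versus error after dropping the last factor.
exact = reconstruction_error(x_ng, alpha_nk, D_kg)
truncated = reconstruction_error(x_ng, alpha_nk[:, :-1], D_kg[:-1])
print(f"all factors: {exact:.3e}, one factor dropped: {truncated:.3e}")
```

Reconstruction error generally decreases as `k` grows, so it is read together with stability rather than minimized on its own.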

### Create the k-selection plot

In [None]:
silhouette_scores = {}
for k in k_values:
    silhouette_scores[k] = consensus_stat[k]['stability']
eval_metrics = pd.DataFrame.from_dict(silhouette_scores, orient='index')
eval_metrics.columns = ['stability']
eval_metrics['rec_error'] = rec_errors

In [None]:
# plot stability and reconstruction error

sns.set_style("ticks")
plt.plot(eval_metrics.index, eval_metrics['stability'], 'o-', color='r')
plt.ylabel('Stability', color='r')
plt.xlabel('Number of components: k')
plt.gca().tick_params(axis='y', colors='r')
plt.twinx()
plt.plot(eval_metrics.index, eval_metrics['rec_error'], 'o-', color='b')
plt.ylabel('Reconstruction error', color='b')
plt.gca().tick_params(axis='y', colors='b')
plt.show()

### Compute per-cell loadings

This step computes the loading of each factor for each cell.

In [None]:
# get per-cell factor loadings using the best k: this takes time

best_k = 14
obsm_key_added = 'X_nmf'
df = get_embedding(dataset, module.pipeline, k=best_k)

In [None]:
# add this information to the anndata object

adata.obsm[obsm_key_added] = df.loc[adata.obs_names].values
adata.obsm[obsm_key_added].shape

In [None]:
# check whether each cell's loadings sum to (nearly) 1

adata.obsm['X_nmf'].sum(axis=1)
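
To make the idea of per-cell loadings concrete, the toy sketch below holds the factor definitions fixed and solves for non-negative loadings, then normalizes each cell's loadings to sum to 1. It uses Lee-Seung multiplicative updates purely for simplicity; this is a swapped-in technique, and `get_embedding()` may use a different solver and normalization. All variable names and data are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_genes, k = 100, 40, 3

# Toy ground truth: fixed factor definitions and non-negative loadings.
D_kg = rng.random((k, n_genes))
true_alpha_nk = rng.random((n_cells, k))
x_ng = true_alpha_nk @ D_kg


def fit_loadings(x_ng, D_kg, n_iter=500, eps=1e-12):
    """Approximately minimize ||x - alpha @ D||^2 over alpha >= 0, with D fixed."""
    alpha_nk = np.full((x_ng.shape[0], D_kg.shape[0]), 0.5)
    xDt_nk = x_ng @ D_kg.T  # constant across iterations
    DDt_kk = D_kg @ D_kg.T
    for _ in range(n_iter):
        # Multiplicative update: keeps alpha non-negative at every step.
        alpha_nk *= xDt_nk / (alpha_nk @ DDt_kk + eps)
    return alpha_nk


alpha_nk = fit_loadings(x_ng, D_kg)

# Normalize so that each cell's loadings sum to 1.
alpha_norm_nk = alpha_nk / alpha_nk.sum(axis=1, keepdims=True)

rel_err = np.linalg.norm(x_ng - alpha_nk @ D_kg) / np.linalg.norm(x_ng)
print(f"relative reconstruction error: {rel_err:.2e}")
```

After normalization, each row can be read as the fractional contribution of each factor to that cell, which is what the row-sum check above is verifying.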

### UMAP based on cNMF factor loadings

In [None]:
# compute UMAP

sc.pp.neighbors(adata, use_rep='X_nmf', n_neighbors=15, metric='cosine')
sc.tl.umap(adata)

In [None]:
sc.pl.embedding(adata, basis='umap', color=['cell_type', 'disease_setting', 'therapy', 'analysis_group', 'sex', 'donor_id'], ncols=1)