# Incremental PCA

The notebook demonstrates the use of [scikit-learn IncrementalPCA](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.IncrementalPCA.html#sklearn.decomposition.IncrementalPCA) to perform PCA on Census data.

Approach:

* Use a SOMA query to define the cells to be embedded,
* From these cells, select N top genes using the `experimental.pp.highly_variable_genes` method,
* Incrementally train over the selected cells and the N top genes,
* Compute components, and annotate the `obs` dataframe.

Depending on the number of cells and genes selected, this can be a resource-intensive computation. It is known to complete successfully when trained on the top 5000 genes for all cells in the human and mouse Census data, but requires a large host. For example, the full human PCA has been successfully demonstrated on an AWS EC2 c6id.32xlarge instance.

In [3]:
import cellxgene_census
import numpy as np
import tiledbsoma as soma
from cellxgene_census.experimental.pp import highly_variable_genes
from sklearn.decomposition import IncrementalPCA
from sklearn.utils import gen_batches


"""
Configuration - the dataset and computational parameters.
"""
# Census release to open.
census_version = "latest"

# Organism experiment to read: mus_musculus or homo_sapiens.
experiment_name = "mus_musculus"

# Value filter selecting the cells used for both training and embedding.
# Set to None to use all cells.
obs_value_filter = "tissue_general == 'heart'"

# Number of principal components kept in the final result.
n_components = 30

# Number of top genes used as analysis input.
n_top_genes = 3000

In [5]:
# Upper bound on the number of rows densified and passed to each partial_fit call.
training_chunk_size = 2000
# Names of the obs columns that receive the embedding, one per component.
pca_columns = [f"X_pca_{c}" for c in range(n_components)]

with cellxgene_census.open_soma(census_version=census_version) as census:
    exp = census["census_data"][experiment_name]

    # Step 1: pick the genes to train on. This query has no var filter, so the
    # highly-variable-gene selection sees every gene for the selected cells.
    with exp.axis_query(
        measurement_name="RNA",
        obs_query=soma.AxisQuery(value_filter=obs_value_filter),
    ) as query:
        print(f"{query.n_obs} cells selected")
        print("Beginning HVG calculation")
        hvgs = highly_variable_genes(query, n_top_genes=n_top_genes)
        var_soma_joinids = hvgs[hvgs.highly_variable].index.to_numpy()
        del hvgs  # only the selected joinids are needed from here on
        print("Finished HVG calculation")

    # Step 2: same cells, restricted to the selected genes. The query is read
    # twice: once to fit the PCA incrementally, once to project every cell.
    with exp.axis_query(
        measurement_name="RNA",
        obs_query=soma.AxisQuery(value_filter=obs_value_filter),
        var_query=soma.AxisQuery(coords=(var_soma_joinids,)),
    ) as query:
        print("Start training")
        pca = IncrementalPCA(n_components=n_components)
        for chunk, _ in query.X("raw").blockwise(axis=0).scipy():
            # min_batch_size folds a short trailing slice into the preceding batch.
            # Some scikit-learn releases make partial_fit raise when a batch has
            # fewer rows than n_components, which a fixed-stride split can produce
            # for the tail of a block.
            # NOTE: a block with fewer than n_components rows in total is still
            # passed through as a single batch.
            for batch in gen_batches(chunk.shape[0], training_chunk_size, min_batch_size=n_components):
                pca.partial_fit(chunk[batch, :].toarray())
        print("End training")

        # Pre-allocate float64 result columns, indexed by soma_joinid, so each
        # block's projection can be written in place.
        obs = query.obs(column_names=["soma_joinid"]).concat().to_pandas().set_index("soma_joinid")
        for colname in pca_columns:
            obs[colname] = np.zeros((len(obs),), dtype=np.float64)

        print("Start transform")
        for chunk, (obs_join_ids, _) in query.X("raw").blockwise(axis=0).scipy():
            chunk_trnsfm = pca.transform(chunk.toarray())
            # One 2-D assignment per block: rows are matched by soma_joinid once,
            # rather than once per component.
            obs.loc[obs_join_ids, pca_columns] = chunk_trnsfm
        print("Complete")

obs

54846 cells selected
Beginning HVG calculation
Finished HVG calculation
Start training
End training
Start transform
Complete


 -6106.00199531 -6106.09804514]' has dtype incompatible with float32, please explicitly cast to a compatible dtype first.
  obs.loc[obs_join_ids, f"X_pca_{c}"] = chunk_trnsfm[:, c]
 -576.70016606]' has dtype incompatible with float32, please explicitly cast to a compatible dtype first.
  obs.loc[obs_join_ids, f"X_pca_{c}"] = chunk_trnsfm[:, c]
 -1015.17917319 -1011.41209974]' has dtype incompatible with float32, please explicitly cast to a compatible dtype first.
  obs.loc[obs_join_ids, f"X_pca_{c}"] = chunk_trnsfm[:, c]
 -1080.41530105 -1077.74939956]' has dtype incompatible with float32, please explicitly cast to a compatible dtype first.
  obs.loc[obs_join_ids, f"X_pca_{c}"] = chunk_trnsfm[:, c]
 -403.9789659 ]' has dtype incompatible with float32, please explicitly cast to a compatible dtype first.
  obs.loc[obs_join_ids, f"X_pca_{c}"] = chunk_trnsfm[:, c]
 -543.28581212]' has dtype incompatible with float32, please explicitly cast to a compatible dtype first.
  obs.loc[obs_join_id

Unnamed: 0_level_0,tissue_general,X_pca_0,X_pca_1,X_pca_2,X_pca_3,X_pca_4,X_pca_5,X_pca_6,X_pca_7,X_pca_8,...,X_pca_20,X_pca_21,X_pca_22,X_pca_23,X_pca_24,X_pca_25,X_pca_26,X_pca_27,X_pca_28,X_pca_29
soma_joinid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1098904,heart,-6103.897557,-576.094738,-1014.750946,-1076.603745,-404.040710,-545.348395,-1883.439240,167.353805,-928.235793,...,86.881651,236.440139,91.229725,184.084294,-45.376517,-65.396823,-279.537626,76.043615,14.196347,-39.152242
1098905,heart,-6036.403127,-531.924775,-1018.312283,-1059.016927,-403.877323,-538.872595,-1890.718545,148.791338,-940.647789,...,111.839093,176.946210,69.180401,217.722848,-54.336771,-52.895215,-277.535501,78.907913,18.370136,-39.080736
1098906,heart,-5966.964850,-552.323639,-1027.924142,-972.869535,-402.672055,-571.779429,-1901.798637,190.152645,-956.953252,...,107.388665,212.101164,89.270168,188.134473,-52.622232,-49.468095,-283.137515,82.380142,17.715630,-38.834921
1098907,heart,-6106.810835,-576.437019,-1014.225233,-1080.239086,-403.975038,-542.733454,-1881.976257,185.105403,-956.278856,...,103.591962,234.377189,93.730744,177.916249,-53.582481,-61.889106,-275.614753,63.275505,12.381560,-38.513311
1098908,heart,-6105.709163,-576.797047,-1015.435796,-1078.769678,-403.829511,-546.489207,-1901.371504,180.973627,-940.050741,...,101.460551,230.593577,95.965773,179.862356,-47.389897,-50.344945,-274.834650,66.330169,13.658505,-38.782565
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5319084,heart,-6107.054479,-576.587150,-786.633080,-1053.481896,-402.167164,-547.625854,-1906.217455,180.805991,-950.523210,...,98.710491,231.839639,95.178560,176.196306,-47.454766,-55.114983,-278.546931,73.568393,18.169595,-39.121850
5319085,heart,-6106.546675,-577.447876,-542.384272,-1024.999426,-400.482284,-546.255808,-1907.431921,181.835153,-952.357400,...,112.303364,234.827397,96.365195,176.753333,-44.485777,-62.055225,-276.452301,69.362886,16.810147,-39.298848
5319086,heart,-6106.346127,-576.517468,-1015.005396,-1079.435096,-403.719065,-547.139904,-1904.440632,181.556583,-944.288037,...,102.040556,230.710397,93.013368,174.963236,-47.427482,-52.918660,-277.060551,68.193486,15.742796,-38.665931
5319087,heart,-6106.001995,-575.606474,-1015.179173,-1080.415301,-403.728131,-546.869061,-1903.692560,180.381321,-938.620875,...,100.515246,230.548365,94.603516,174.249243,-46.325529,-55.394091,-280.171038,75.486676,18.861933,-38.850995


In [7]:
# Show the dtype of `chunk_trnsfm`, the array returned by the last
# `pca.transform` call in the transform loop of the previous code cell.
# This cell depends on that variable still being bound in the kernel.
chunk_trnsfm.dtype

dtype('float64')