# GenePT cell classification performance on the Tabula Sapiens data set

This notebook downloads the Tabula Sapiens data set (https://cellxgene.cziscience.com/collections/e5f58829-1a66-40b5-a624-9046778e74f5) if it is not already present,
embeds the cells with a GenePT-style embedding, and then tests classification performance. Tabula Sapiens is a benchmark dataset, but here we use it for
both training and testing, in order to "benchmark" our GenePT embeddings on a large dataset with many cell types.


In [1]:
%run notebook_setup.ipynb

autoreload enabled
repo_dir set to /Users/rj/personal/GenePT-tools
File already exists at /Users/rj/personal/GenePT-tools/data/GenePT_emebdding_v2.zip
Extracting files...
Extracting GenePT_emebdding_v2/
Skipping GenePT_emebdding_v2/NCBI_UniProt_summary_of_genes.json - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_embedding_ada_text.pickle - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_protein_embedding_model_3_text.pickle. - already exists with same size
Skipping GenePT_emebdding_v2/NCBI_summary_of_genes.json - already exists with same size
Extraction complete!
Skipping embedding_original_ada_text.parquet - already exists
Skipping embedding_original_large_3.parquet - already exists
Skipping embedding_associations_age_cell_type_drugs_pathways_openai_large.parquet - already exists
Skipping embedding_associations_age_drugs_pathways_openai_large.parquet - already exists
Skipping embedding_associations_cell_type_openai_large.parquet - alrea

## Download the Tabula Sapiens data set

We save this dataset locally as `1m_cells.h5ad`, which is not a very descriptive name.

In [2]:
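import requests

dataset = "https://datasets.cellxgene.cziscience.com/10df7690-6d10-4029-a47e-0f071bb2df83.h5ad"
# dataset_id = "10df7690-6d10-4029-a47e-0f071bb2df83"

file_path = data_dir / "1m_cells.h5ad"  # adjust this path as needed

if not file_path.exists():
    # Download to a temporary file first so an interrupted download
    # does not leave a truncated .h5ad that the exists() check would accept
    tmp_path = file_path.with_name(file_path.name + ".part")
    with requests.get(dataset, stream=True, timeout=60) as response:
        response.raise_for_status()  # fail loudly on HTTP errors
        with open(tmp_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # filter out keep-alive chunks
                    file.write(chunk)
    tmp_path.rename(file_path)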

In [3]:
import h5py

with h5py.File(file_path, "r") as f:
    # Look at the structure of the X group
    print("Contents of X group:", list(f["X"].keys()))

    # List the cell (obs) and gene (var) metadata columns
    print("\nContents of obs group:", list(f["obs"].keys()))
    print("Contents of var group:", list(f["var"].keys()))

    # If X contains a sparse matrix, it likely has 'data', 'indices', and 'indptr'
    if "data" in f["X"]:
        print("\nShape of X/data:", f["X"]["data"].shape)
        print("Shape of X/indices:", f["X"]["indices"].shape)
        print("Shape of X/indptr:", f["X"]["indptr"].shape)

Contents of X group: ['data', 'indices', 'indptr']

Contents of obs group: ['10X_run', '_index', '_scvi_batch', '_scvi_labels', 'ambient_removal', 'anatomical_position', 'assay', 'assay_ontology_term_id', 'broad_cell_class', 'cdna_plate', 'cdna_well', 'cell_type', 'cell_type_ontology_term_id', 'compartment', 'development_stage', 'development_stage_ontology_term_id', 'disease', 'disease_ontology_term_id', 'donor_assay', 'donor_id', 'donor_method', 'donor_tissue', 'donor_tissue_assay', 'ethnicity_original', 'free_annotation', 'is_primary_data', 'library_plate', 'manually_annotated', 'method', 'n_genes_by_counts', 'notes', 'observation_joinid', 'organism', 'organism_ontology_term_id', 'pct_counts_ercc', 'pct_counts_mt', 'published_2022', 'replicate', 'sample_id', 'sample_number', 'scvi_leiden_donorassay_full', 'self_reported_ethnicity', 'self_reported_ethnicity_ontology_term_id', 'sex', 'sex_ontology_term_id', 'suspension_type', 'tissue', 'tissue_in_publication', 'tissue_ontology_term_id'

In [4]:
with h5py.File(file_path, "r") as f:
    # Count the distinct cell_type categories
    print("\nNumber of obs['cell_type'] categories:", len(f["obs"]["cell_type"]["categories"]))


Number of obs['cell_type'] categories: 180


## Load 100K cells 

We load a large subset of the data (the first 100,000 cells in the file) so that even the smaller classes have enough cells to train on.
Alternatively, we could select cells of each type from the whole dataset.
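Selecting cells of each type from the whole file, instead of taking the first 100,000 rows, could look roughly like the sketch below. It assumes the per-cell category codes have already been read into a NumPy array (in this `.h5ad` file they would come from `f["obs"]["cell_type"]["codes"]`); `sample_rows_per_class` and `max_per_class` are hypothetical names and are not part of `src.utils`.

```python
import numpy as np


def sample_rows_per_class(codes: np.ndarray, max_per_class: int, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: return sorted row indices with at most
    `max_per_class` rows for each distinct category code."""
    rng = np.random.default_rng(seed)
    selected = []
    for code in np.unique(codes):
        rows = np.flatnonzero(codes == code)
        if len(rows) > max_per_class:
            # Downsample large classes; small classes are kept whole
            rows = rng.choice(rows, size=max_per_class, replace=False)
        selected.append(rows)
    # Sorted indices keep reads from the HDF5 file in on-disk order
    return np.sort(np.concatenate(selected))


# Toy example: class 0 is common, classes 1 and 2 are rare
codes = np.array([0] * 10 + [1] * 3 + [2] * 2)
rows = sample_rows_per_class(codes, max_per_class=4)
print(rows, np.bincount(codes[rows]))
```

Capping the large classes rather than the total row count keeps rare cell types fully represented, at the cost of a non-contiguous read from the file.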

In [5]:
import h5py
import anndata as ad
from scipy import sparse
import pandas as pd
from src.utils import load_subset_anndata

# Load with specific obs metadata columns
adata_filtered = load_subset_anndata(
    file_path,
    start_row=0,
    n_rows=100000,
    obs_columns=["cell_type", "broad_cell_class", "donor_id"],
)

print("AnnData shape:", adata_filtered.shape)
print("Feature metadata columns:", adata_filtered.var.columns)  # Print all var metadata
print(
    "Selected Observation metadata columns:", adata_filtered.obs.columns
)  # Print selected obs metadata
print(
    "Matrix density:",
    adata_filtered.X.nnz / (adata_filtered.shape[0] * adata_filtered.shape[1]),
)

AnnData shape: (100000, 61759)
Feature metadata columns: Index(['ensembl_id', 'ensg', 'ercc', 'feature_biotype', 'feature_is_filtered',
       'feature_length', 'feature_name', 'feature_reference', 'feature_type',
       'genome', 'mean', 'mean_counts', 'mt', 'n_cells_by_counts',
       'pct_dropout_by_counts', 'std', 'total_counts'],
      dtype='object')
Selected Observation metadata columns: Index(['cell_type', 'broad_cell_class', 'donor_id'], dtype='object')
Matrix density: 0.045159779789180524




### Count the number of genes and cells in the dataset

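* `f['var']['feature_name']` contains the gene names, stored as a categorical (a `categories` array plus per-gene `codes`); here `categories` has one entry per gene
* `f['obs']` contains metadata about the cells
* `f['X']` stores the cell × gene expression matrix in compressed sparse row (CSR) form, one row per cell:
  * `f['X']['data']` holds the non-zero values
  * `f['X']['indices']` holds the gene (column) index of each non-zero value
  * `f['X']['indptr']` holds the row pointers: the non-zeros of cell `i` occupy positions `indptr[i]` to `indptr[i+1]` of `data` and `indices`, so the number of cells is `len(indptr) - 1`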
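A toy matrix makes the CSR layout concrete. This sketch uses SciPy's `csr_matrix` in memory rather than the h5ad file, but its `data`, `indices` and `indptr` arrays play the same roles as `X/data`, `X/indices` and `X/indptr`.

```python
import numpy as np
from scipy import sparse

# 3 cells x 4 genes; zeros are not stored
dense = np.array([
    [0, 2, 0, 1],
    [0, 0, 0, 0],
    [5, 0, 3, 0],
])
X = sparse.csr_matrix(dense)

print("data:   ", X.data)     # non-zero values, row by row
print("indices:", X.indices)  # gene (column) index of each value
print("indptr: ", X.indptr)   # row i spans data[indptr[i]:indptr[i+1]]

n_cells = len(X.indptr) - 1

# Recover the genes expressed in cell 2 directly from the three arrays
start, end = X.indptr[2], X.indptr[3]
print(n_cells, X.indices[start:end], X.data[start:end])
```

Note that the empty cell 1 still gets an `indptr` entry (`indptr[1] == indptr[2]`), which is why the cell count comes from `indptr` rather than from `indices`.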

In [6]:
with h5py.File(file_path, "r") as f:
    print(f["var"]["feature_name"]["categories"])
    # print(f['X']['indices'][:10])

<HDF5 dataset "categories": shape (61759,), type "|O">


In [7]:
gene_names = adata_filtered.var.feature_name
ensembl_ids = adata_filtered.var.ensembl_id
major_ensembl_ids = pd.Series(ensembl_id.split(".")[0] for ensembl_id in ensembl_ids)

## Load the GenePT embedding data
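The gene embeddings were downloaded from Hugging Face by the setup notebook above. Here we also load the gene info table, which maps gene names to Ensembl IDs, from the same Hugging Face dataset.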

In [8]:
from datasets import load_dataset

gene_info_table_dataset = load_dataset(
    "honicky/genept-composable-embeddings-source-data", "gene_info"
)
gene_info_table = gene_info_table_dataset["train"].to_pandas()
gene_info_table.head()

Using the latest cached version of the dataset since honicky/genept-composable-embeddings-source-data couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'gene_info' at /Users/rj/.cache/huggingface/datasets/honicky___genept-composable-embeddings-source-data/gene_info/0.0.0/5a0eee4f5ea012c2de528b150376dc0d4b3c9e45 (last modified on Tue Jan 14 05:59:50 2025).


| | gene_name | ensembl_id | gene_type |
|---|---|---|---|
| 0 | TSPAN6 | ENSG00000000003 | protein_coding |
| 1 | TNMD | ENSG00000000005 | protein_coding |
| 2 | DPM1 | ENSG00000000419 | protein_coding |
| 3 | SCYL3 | ENSG00000000457 | protein_coding |
| 4 | C1orf112 | ENSG00000000460 | protein_coding |


In [9]:
huggingface_model_dir = data_dir / "huggingface_model"
gene_embedding = pd.read_parquet(
    huggingface_model_dir
    / "embedding_associations_cell_type_tissue_drug_pathway_openai_large.parquet"
)

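## Merge duplicate Ensembl IDs by averaging their embeddings

Some Ensembl IDs appear under more than one gene name in the embedding table; for example, ENSG00000222005 is listed as both LINC01118 and LINC01119. As a rough heuristic, we take the mean of the duplicate embeddings and renormalize the result to unit length.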
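Written out, for an Ensembl ID $g$ whose set of gene names in the embedding table is $S_g$, with $e_n$ the embedding of gene name $n$, the merged embedding computed in this section is the mean followed by L2 renormalization:

$$
\bar{e}_g = \frac{1}{|S_g|} \sum_{n \in S_g} e_n,
\qquad
\hat{e}_g = \frac{\bar{e}_g}{\lVert \bar{e}_g \rVert_2}.
$$

For an ID with a single gene name, $\bar{e}_g = e_n$, so the only change is the renormalization, which is a no-op if the stored embeddings are already unit-length.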

In [10]:
gene_embeddings_with_ensembl_id = gene_embedding.merge(
    gene_info_table, left_index=True, right_on="gene_name"
)

In [11]:
gene_embeddings_with_ensembl_id[
    gene_embeddings_with_ensembl_id.ensembl_id == "ENSG00000222005"
]

| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 3065 | 3066 | 3067 | 3068 | 3069 | 3070 | 3071 | gene_name | ensembl_id | gene_type |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 83316 | -0.008708 | 0.019564 | -0.004104 | 0.007363 | 0.001401 | -0.015853 | 0.031096 | -0.013775 | 0.029875 | 0.012154 | ... | -0.027197 | -0.00172 | -0.017885 | 0.026281 | -0.014139 | -0.013329 | -0.027221 | LINC01118 | ENSG00000222005 | |
| 83105 | -0.018096 | 0.014083 | -0.011138 | -0.001666 | 0.012291 | -6.7e-05 | 0.038749 | -0.007913 | 0.007358 | 0.0082 | ... | -0.016436 | -0.012757 | -0.025251 | 0.022767 | -0.019888 | -0.012172 | -0.025777 | LINC01119 | ENSG00000222005 | |


In [12]:
import numpy as np

matching_ensembl_ids = pd.DataFrame(major_ensembl_ids, columns=["ensembl_id"]).merge(
    gene_embeddings_with_ensembl_id, left_on="ensembl_id", right_on="ensembl_id"
)[["gene_name", "ensembl_id"]]
ensembl_id_counts = matching_ensembl_ids.ensembl_id.value_counts()
ensembl_id_counts[ensembl_id_counts > 1]

ensembl_id
ENSG00000000003    2
ENSG00000243485    2
ENSG00000222005    2
ENSG00000264405    2
ENSG00000204792    2
ENSG00000201388    2
ENSG00000222345    2
ENSG00000276234    2
ENSG00000187838    2
ENSG00000202377    2
ENSG00000226364    2
ENSG00000269433    2
ENSG00000183598    2
ENSG00000267151    2
ENSG00000236790    2
ENSG00000264073    2
ENSG00000265134    2
ENSG00000238936    2
ENSG00000270722    2
ENSG00000226444    2
ENSG00000269099    2
ENSG00000204397    2
ENSG00000249532    2
ENSG00000206603    2
ENSG00000284917    2
ENSG00000268942    2
ENSG00000206903    2
ENSG00000063587    2
ENSG00000274020    2
ENSG00000255154    2
ENSG00000090857    2
ENSG00000254508    2
ENSG00000206897    2
ENSG00000251866    2
ENSG00000226419    2
ENSG00000264448    2
ENSG00000206785    2
ENSG00000197927    2
ENSG00000245080    2
ENSG00000269586    2
ENSG00000227518    2
ENSG00000145491    2
ENSG00000221164    2
ENSG00000250331    2
ENSG00000223770    2
ENSG00000269955    2
ENSG00000207187    2
EN

In [13]:
import numpy as np

# Get embeddings without metadata columns
embedding_cols = [
    col
    for col in gene_embeddings_with_ensembl_id.columns
    if col not in ["gene_name", "ensembl_id", "gene_type"]
]

# Group by ensembl_id and take mean of embeddings
merged_embeddings = (
    gene_embeddings_with_ensembl_id.groupby("ensembl_id")[embedding_cols]
    .mean()
    .reset_index()
)

# Renormalize the embeddings
embedding_values = merged_embeddings[embedding_cols].values
norms = np.linalg.norm(embedding_values, axis=1, keepdims=True)
merged_embeddings[embedding_cols] = embedding_values / norms

# Verify the results
print(f"Original shape: {gene_embeddings_with_ensembl_id.shape}")
print(f"After merging duplicates: {merged_embeddings.shape}")

# Verify all duplicates are resolved
duplicate_check = merged_embeddings.ensembl_id.value_counts()
print("\nNumber of remaining duplicates:", (duplicate_check > 1).sum())

Original shape: (37220, 3075)
After merging duplicates: (36573, 3073)

Number of remaining duplicates: 0


In [14]:
merged_embeddings.ensembl_id

0        ENSG00000000003
1        ENSG00000000005
2        ENSG00000000419
3        ENSG00000000457
4        ENSG00000000460
              ...       
36568    ENSGR0000230542
36569    ENSGR0000234958
36570    ENSGR0000236871
36571    ENSGR0000237040
36572    ENSGR0000265658
Name: ensembl_id, Length: 36573, dtype: object

In [15]:
# # Get the embedding values without the metadata columns
# embedding_cols = [col for col in merged_embeddings.columns
#                  if col not in ['ensembl_id']]

# # Create a mapping from major_ensembl_ids to column indices in cell_gene_matrix
# gene_idx_map = {gene_id: idx for idx, gene_id in enumerate(major_ensembl_ids)}

# # Find which embeddings correspond to genes in our expression matrix
# # and get their indices in the correct order
# valid_indices = []
# embedding_indices = []
# for i, ensembl_id in enumerate(merged_embeddings.ensembl_id):
#     if ensembl_id in gene_idx_map:
#         valid_indices.append(gene_idx_map[ensembl_id])
#         embedding_indices.append(i)

# # Create the reordered embedding matrix
# embedding_matrix = merged_embeddings[embedding_cols].iloc[embedding_indices].values.T

In [16]:
# def create_embedding_matrix(merged_embeddings, major_ensembl_ids):
#     """
#     Create a reordered embedding matrix that aligns gene embeddings with expression matrix columns.

#     Args:
#         merged_embeddings (pd.DataFrame): DataFrame containing gene embeddings with 'ensembl_id' column
#         major_ensembl_ids (pd.Series): Series of Ensembl IDs in the order they appear in expression matrix

#     Returns:
#         tuple: (embedding_matrix, valid_indices)
#             - embedding_matrix: numpy array of shape (n_embedding_dims, n_valid_genes)
#             - valid_indices: list of indices mapping to original expression matrix columns
#     """
#     # Get the embedding values without the metadata columns
#     embedding_cols = [
#         col for col in merged_embeddings.columns if col not in ["ensembl_id"]
#     ]

#     # Create a mapping from major_ensembl_ids to column indices in cell_gene_matrix
#     gene_idx_map = {gene_id: idx for idx, gene_id in enumerate(major_ensembl_ids)}

#     # Find which embeddings correspond to genes in our expression matrix
#     # and get their indices in the correct order
#     valid_indices = []
#     embedding_indices = []
#     for i, ensembl_id in enumerate(merged_embeddings.ensembl_id):
#         if ensembl_id in gene_idx_map:
#             valid_indices.append(gene_idx_map[ensembl_id])
#             embedding_indices.append(i)

#     # Create the reordered embedding matrix
#     embedding_matrix = (
#         merged_embeddings[embedding_cols].iloc[embedding_indices].values.T
#     )

#     return embedding_matrix, valid_indices

In [17]:
# def create_cell_embeddings(expression_matrix, embedding_matrix, valid_indices):
#     """
#     Create normalized cell embeddings from gene expression data and gene embeddings.

#     Args:
#         expression_matrix: scipy.sparse.csr_matrix or numpy array of shape (n_cells, n_genes)
#         embedding_matrix: numpy array of shape (n_embedding_dims, n_valid_genes)
#         valid_indices: list of indices to select genes that have embeddings

#     Returns:
#         numpy array of shape (n_cells, n_embedding_dims) containing normalized cell embeddings
#     """
#     # Select only the columns from expression matrix that have corresponding embeddings
#     filtered_expression = expression_matrix[:, valid_indices]

#     # Perform the matrix multiplication (n_cells x n_embedding_dimensions)
#     cell_embeddings = filtered_expression @ embedding_matrix.T

#     # Normalize the cell embeddings
#     norms = np.linalg.norm(cell_embeddings, axis=1, keepdims=True)
#     cell_embeddings = cell_embeddings / norms

#     return cell_embeddings
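The two commented-out helpers above amount to an expression-weighted sum of gene embeddings followed by L2 normalization. Below is a compact, self-contained sketch of the same idea on toy data. It is not the code in `src.inference` (which isn't shown here); `embed_cells` is a hypothetical name, the embedding matrix is laid out as (genes, dims) rather than transposed, and it adds one guard the commented version lacks: a cell with no counts in any embedded gene would otherwise be divided by a zero norm and come out as NaN.

```python
import numpy as np
from scipy import sparse


def embed_cells(expression, gene_ids, embeddings, embedding_gene_ids):
    """Hypothetical GenePT-style sketch: expression-weighted, L2-normalized cell embeddings.

    expression:         (n_cells, n_genes) sparse or dense matrix
    gene_ids:           IDs of the expression columns
    embeddings:         (n_embedded_genes, dim) gene embedding matrix
    embedding_gene_ids: IDs of the embedding rows
    """
    column_of = {gene_id: j for j, gene_id in enumerate(gene_ids)}
    # Keep only genes present in both tables, paired in matching order
    pairs = [(column_of[g], i) for i, g in enumerate(embedding_gene_ids) if g in column_of]
    expr_cols = [j for j, _ in pairs]
    emb_rows = [i for _, i in pairs]

    # (n_cells, n_shared) @ (n_shared, dim) -> (n_cells, dim)
    cell_emb = np.asarray(expression[:, expr_cols] @ embeddings[emb_rows])
    norms = np.linalg.norm(cell_emb, axis=1, keepdims=True)
    # Cells with no counts in any embedded gene keep a zero vector instead of NaN
    return np.divide(cell_emb, norms, out=np.zeros_like(cell_emb), where=norms > 0)


# Toy usage: 3 cells x 3 genes; "G2" has no embedding and "G_missing" is not measured
gene_ids = ["G1", "G2", "G3"]
X = sparse.csr_matrix(np.array([[1, 0, 2], [0, 0, 0], [0, 3, 0]], dtype=float))
emb_ids = ["G3", "G1", "G_missing"]
E = np.array([[0.0, 1.0], [1.0, 0.0], [5.0, 5.0]])
cells = embed_cells(X, gene_ids, E, emb_ids)
print(cells)
```

Cell 0 becomes $1 \cdot e_{G1} + 2 \cdot e_{G3} = (1, 2)$ before normalization; cells 1 and 2 have no counts in embedded genes and stay at zero.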

In [18]:
import src.inference as inference

embedding_matrix, valid_indices = inference.create_embedding_matrix(
    merged_embeddings, major_ensembl_ids
)

In [19]:
import time
import numpy as np

# Time the embedding creation
start_time = time.time()
cell_embeddings = inference.create_cell_embeddings(
    adata_filtered.X, embedding_matrix, valid_indices
)
end_time = time.time()

# Calculate metrics
total_time = end_time - start_time
cells_per_second = cell_embeddings.shape[0] / total_time

print(f"Shape of cell embeddings: {cell_embeddings.shape}")
print(f"Number of genes used: {len(valid_indices)} out of {len(major_ensembl_ids)} total genes")
print(f"Total time: {total_time:.2f} seconds")
print(f"Processing speed: {cells_per_second:.1f} cells/second")

Shape of cell embeddings: (100000, 3072)
Number of genes used: 33258 out of 61759 total genes
Total time: 246.57 seconds
Processing speed: 405.6 cells/second


In [38]:
# # Select only the columns from cell_gene_matrix that have corresponding embeddings
# filtered_expression = cell_gene_matrix[:, valid_indices]

# # Perform the matrix multiplication
# # This will give us (n_cells x n_embedding_dimensions)
# cell_embeddings = filtered_expression @ embedding_matrix.T
# # Normalize the cell embeddings
# norms = np.linalg.norm(cell_embeddings, axis=1, keepdims=True)
# cell_embeddings = cell_embeddings / norms


# print(f"Shape of cell embeddings: {cell_embeddings.shape}")
# print(f"Number of genes used: {len(valid_indices)} out of {len(major_ensembl_ids)} total genes")

In [20]:
adata_filtered.obs

| | cell_type | broad_cell_class | donor_id |
|---|---|---|---|
| 0 | naive thymus-derived CD4-positive, alpha-beta ... | t cell | TSP2 |
| 1 | B cell | lymphocyte of b lineage | TSP2 |
| 2 | B cell | lymphocyte of b lineage | TSP2 |
| 3 | B cell | lymphocyte of b lineage | TSP2 |
| 4 | CD8-positive, alpha-beta T cell | t cell | TSP2 |
| ... | ... | ... | ... |
| 99995 | endothelial cell of artery | endothelial cell | TSP2 |
| 99996 | mesenchymal stem cell | stem cell | TSP2 |
| 99997 | pericyte | contractile cell | TSP2 |
| 99998 | skeletal muscle satellite stem cell | stem cell | TSP2 |


In [21]:
# def load_cell_metadata(file_path, start_row=0, n_rows=1000, columns=None):
#     """
#     Load metadata for specific cells.

#     Args:
#         file_path: Path to h5ad file
#         start_row: Starting row index
#         n_rows: Number of rows to load
#         columns: List of metadata columns to load (if None, load all)

#     Returns:
#         Dictionary of metadata arrays
#     """
#     with h5py.File(file_path, 'r') as f:
#         metadata = {}
#         obs_keys = list(f['obs'].keys()) if columns is None else columns
#         for key in obs_keys:
#             if key in f['obs']:
#                 column_group = f['obs'][key]
#                 if isinstance(column_group, h5py.Dataset):
#                     metadata[key] = column_group[start_row:start_row + n_rows]
#                 else:
#                     if 'categories' in column_group and 'codes' in column_group:
#                         # Get categories and decode from bytes to strings
#                         categories = [cat.decode('utf-8') for cat in column_group['categories'][:]]
#                         codes = column_group['codes'][start_row:start_row + n_rows]
#                         metadata[key] = np.array([categories[code] for code in codes])
#     return metadata

# # Example usage:
# metadata = load_cell_metadata(file_path, start_row=0, n_rows=1136219,
#                             columns=['cell_type', 'total_counts', 'broad_cell_class', 'donor_id'])
metadata = adata_filtered.obs

print("Loaded metadata keys:", list(metadata.keys()))

# Print first few values of each column
for key in metadata:
    print(f"\nFirst 5 values of {key}:")
    print(metadata[key][:5])

Loaded metadata keys: ['cell_type', 'broad_cell_class', 'donor_id']

First 5 values of cell_type:
0    naive thymus-derived CD4-positive, alpha-beta ...
1                                               B cell
2                                               B cell
3                                               B cell
4                      CD8-positive, alpha-beta T cell
Name: cell_type, dtype: category
Categories (180, object): ['male germ cell', 'spermatocyte', 'spermatid', 'spermatogonium', ..., 'alveolar adventitial fibroblast', 'BEST4+ intestinal epithelial cell, human', 'enteroglial cell', 'unknown']

First 5 values of broad_cell_class:
0                     t cell
1    lymphocyte of b lineage
2    lymphocyte of b lineage
3    lymphocyte of b lineage
4                     t cell
Name: broad_cell_class, dtype: category
Categories (40, object): ['adventitial cell', 'cardiac endothelial cell', 'ciliated epithelial cell', 'conjunctival epithelial cell', ..., 'stromal cell', 't cell', 

In [22]:
embed_genept_pdf = pd.DataFrame(cell_embeddings, index=metadata.index).merge(
    metadata, left_index=True, right_index=True
)
embed_genept_pdf.columns = [str(col) for col in embed_genept_pdf.columns]
embed_genept_pdf

| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 3065 | 3066 | 3067 | 3068 | 3069 | 3070 | 3071 | cell_type | broad_cell_class | donor_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -0.007720 | 0.013851 | -0.005736 | 0.010341 | 0.000320 | -0.023683 | 0.041729 | -0.033866 | 0.035611 | 0.005238 | ... | -0.015777 | -0.000016 | -0.017384 | 0.023981 | -0.007408 | -0.003723 | -0.018266 | naive thymus-derived CD4-positive, alpha-beta ... | t cell | TSP2 |
| 1 | -0.005579 | 0.015617 | -0.005555 | 0.009382 | 0.001121 | -0.022421 | 0.040697 | -0.034355 | 0.035864 | 0.004520 | ... | -0.014589 | 0.000164 | -0.017984 | 0.023778 | -0.007307 | -0.004185 | -0.018247 | B cell | lymphocyte of b lineage | TSP2 |
| 2 | -0.007412 | 0.014734 | -0.005714 | 0.010164 | 0.001523 | -0.023222 | 0.042165 | -0.034563 | 0.036910 | 0.005400 | ... | -0.014859 | 0.000339 | -0.017272 | 0.024080 | -0.006957 | -0.004390 | -0.017801 | B cell | lymphocyte of b lineage | TSP2 |
| 3 | -0.008063 | 0.014125 | -0.005429 | 0.010397 | -0.000308 | -0.023580 | 0.041392 | -0.034021 | 0.035953 | 0.005418 | ... | -0.015378 | -0.000165 | -0.017942 | 0.023627 | -0.007249 | -0.003548 | -0.018380 | B cell | lymphocyte of b lineage | TSP2 |
| 4 | -0.006627 | 0.014602 | -0.005593 | 0.009993 | 0.002419 | -0.021874 | 0.041174 | -0.034943 | 0.036720 | 0.005483 | ... | -0.015101 | 0.000418 | -0.018152 | 0.023877 | -0.006763 | -0.004058 | -0.017972 | CD8-positive, alpha-beta T cell | t cell | TSP2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 99995 | -0.006644 | 0.016508 | -0.005705 | 0.007553 | 0.003714 | -0.020780 | 0.041159 | -0.035570 | 0.034712 | 0.004084 | ... | -0.014371 | 0.000565 | -0.016427 | 0.023029 | -0.006838 | -0.005125 | -0.018476 | endothelial cell of artery | endothelial cell | TSP2 |
| 99996 | -0.006253 | 0.017183 | -0.005748 | 0.008020 | 0.003911 | -0.021312 | 0.042195 | -0.035562 | 0.034941 | 0.004608 | ... | -0.014135 | 0.000487 | -0.016491 | 0.023043 | -0.006724 | -0.005049 | -0.018681 | mesenchymal stem cell | stem cell | TSP2 |
| 99997 | -0.006700 | 0.016229 | -0.005678 | 0.009247 | 0.003257 | -0.021733 | 0.041717 | -0.035562 | 0.035392 | 0.005058 | ... | -0.014526 | 0.001264 | -0.016985 | 0.023072 | -0.006645 | -0.005376 | -0.017690 | pericyte | contractile cell | TSP2 |
| 99998 | -0.006115 | 0.016010 | -0.005867 | 0.007888 | 0.001904 | -0.021571 | 0.041173 | -0.034858 | 0.034616 | 0.004558 | ... | -0.014814 | 0.000545 | -0.016843 | 0.023437 | -0.006314 | -0.004606 | -0.018471 | skeletal muscle satellite stem cell | stem cell | TSP2 |


In [23]:
embed_genept_pdf.to_parquet(data_dir / "tabula_sapiens_100k_genept_embedding_cell_type_tissue_drug_pathway_openai_large.parquet")

In [24]:
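# Note: despite the variable name, these labels come from broad_cell_class (40 categories),
# not from the finer-grained cell_type column (180 categories)
cell_type_labels = pd.Categorical(
    metadata["broad_cell_class"].iloc[: len(cell_embeddings)]
)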
cell_type_labels

['t cell', 'lymphocyte of b lineage', 'lymphocyte of b lineage', 'lymphocyte of b lineage', 't cell', ..., 'endothelial cell', 'stem cell', 'contractile cell', 'stem cell', 'stem cell']
Length: 100000
Categories (40, object): ['adventitial cell', 'cardiac endothelial cell', 'ciliated epithelial cell', 'conjunctival epithelial cell', ..., 'stromal cell', 't cell', 'transitional epithelial cell', 'vestibular dark cell']

In [25]:
metadata["broad_cell_class"].value_counts()

broad_cell_class
t cell                             14053
stromal cell                       13013
myeloid leukocyte                   8565
lymphocyte of b lineage             8499
contractile cell                    7916
fibroblast                          6995
endothelial cell                    6019
stem cell                           5937
granulocyte                         5797
intestinal epithelial cell          5764
transitional epithelial cell        5384
innate lymphoid cell                2507
glandular epithelial cell           1988
epithelial cell                     1810
cardiac endothelial cell            1088
epithelial cell of lung              819
endo-epithelial cell                 666
duct epithelial cell                 592
adventitial cell                     450
ciliated epithelial cell             415
hematopoietic cell                   375
erythroid lineage cell               313
secretory cell                       276
dendritic cell                       227

In [26]:
donor_ids = pd.Categorical(metadata["donor_id"][: len(cell_embeddings)])
donor_ids

['TSP2', 'TSP2', 'TSP2', 'TSP2', 'TSP2', ..., 'TSP2', 'TSP2', 'TSP2', 'TSP2', 'TSP2']
Length: 100000
Categories (24, object): ['TSP1', 'TSP2', 'TSP3', 'TSP4', ..., 'TSP26', 'TSP27', 'TSP28', 'TSP30']

In [27]:
cell_embeddings_pdf = pd.DataFrame(
    cell_embeddings, columns=list(range(cell_embeddings.shape[1]))
)
cell_embeddings_pdf["cell_type"] = cell_type_labels
cell_embeddings_pdf["donor_id"] = donor_ids

In [28]:
cell_embeddings_pdf.shape

(100000, 3074)
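With embeddings and labels in one table, a basic classification check could look like the sketch below. It is an illustration, not the evaluation this notebook runs: it assumes scikit-learn's `LogisticRegression` as the classifier and uses synthetic, well-separated clusters from `make_blobs` in place of `cell_embeddings_pdf`. Macro-averaged F1 is reported alongside accuracy because the `broad_cell_class` counts above are highly imbalanced (14,053 `t cell` versus 227 `dendritic cell` among the classes shown), and accuracy alone is dominated by the large classes.

```python
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for (cell embedding, label) pairs
X, y = make_blobs(n_samples=600, centers=3, n_features=16, cluster_std=0.5, random_state=0)

# Stratify so every class appears in both splits in the same proportions
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
macro_f1 = f1_score(y_test, y_pred, average="macro")  # every class weighted equally
print(f"accuracy={accuracy:.3f}  macro-F1={macro_f1:.3f}")
```

On the real data, `X` would be `cell_embeddings_pdf.drop(columns=["cell_type", "donor_id"])` and `y` its `cell_type` column (which here holds `broad_cell_class` labels). Since the `donor_id` column is available, a split grouped by donor (for example scikit-learn's `GroupShuffleSplit`) would give a stricter estimate than a random split.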

In [29]:
from sklearn.decomposition import PCA
import umap
import matplotlib.pyplot as plt
import plotly.express as px

# Fit UMAP directly on the cell embeddings of a random sample of 2,000 cells
# (fixed seeds so both the sample and the layout are reproducible)
reducer = umap.UMAP(random_state=42)
umap_sample_pdf = cell_embeddings_pdf.sample(2000, random_state=42)
umap_embeddings = reducer.fit_transform(
    umap_sample_pdf.drop(columns=["cell_type", "donor_id"])
)

# Create a DataFrame with the UMAP coordinates, keeping the sampled rows' index
umap_df = pd.DataFrame(
    umap_embeddings, columns=["UMAP1", "UMAP2"], index=umap_sample_pdf.index
)

# Attach the labels for coloring (the "cell_type" column holds broad_cell_class labels)
umap_df = umap_df.merge(
    umap_sample_pdf[["cell_type", "donor_id"]], left_index=True, right_index=True
)

# Create the plot
fig = px.scatter(
    umap_df,
    x="UMAP1",
    y="UMAP2",
    color="cell_type",
    opacity=0.7,
    title="UMAP of GenePT cell embeddings (2,000 sampled cells)",
)

# Center the title
fig.update_layout(title={"y": 0.95, "x": 0.5, "xanchor": "center", "yanchor": "top"})

fig.show()

